commit    695d40efc28bde1f01e29c9041f99a595072bd67
tree      51929f43cd998d9bf95b3aece056a7dbce3e7455
parent    75563674c4c9133f190726d5e06a2cd59904cb64
author    Pieter Noordhuis <pcnoordhuis@gmail.com>  2018-06-08 12:59:51 -0700
committer GitHub <noreply@github.com>               2018-06-08 12:59:51 -0700
Create initial Python bindings for c10d (#8119)
* Build and install c10d from tools/build_pytorch_libs.sh
* Create initial Python bindings for c10d
* clang-format
* Switch link order to include more symbols
* Add bindings and tests for ProcessGroupGloo
* Add broadcast test
* Separate build flag for c10d
* Explicit PIC property
* Skip c10d tests if not available
* Remove c10d from Windows blacklist
Let it skip by itself because it won't be available anyway.
* Make lint happy
* Comments
* Move c10d module into torch.distributed
* Close tempfile such that it is deleted
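
For orientation, the following is a minimal, single-process sketch of the Python API these bindings expose, distilled from the new test/test_c10d.py further down. It is illustrative only (not part of the commit) and assumes a build where WITH_C10D is enabled so that torch.distributed.c10d imports successfully.

    # Illustrative only: exercises the new bindings the same way
    # test/test_c10d.py does, but in a single process.
    import tempfile

    import torch
    import torch.distributed.c10d as c10d

    if c10d.is_available():
        f = tempfile.NamedTemporaryFile()
        store = c10d.FileStore(f.name)            # file-backed rendezvous store
        pg = c10d.ProcessGroupGloo(store, 0, 1)   # rank 0 in a world of size 1

        opts = c10d.BroadcastOptions()
        opts.rootRank = 0
        opts.rootTensor = 0
        x = torch.Tensor([1.0])
        work = pg.broadcast([x], opts)            # returns an async Work handle
        work.wait()

        f.close()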
-rw-r--r--  setup.py                             |  13
-rw-r--r--  test/run_test.py                     |   1
-rw-r--r--  test/test_c10d.py                    | 171
-rwxr-xr-x  tools/build_pytorch_libs.sh          |   4
-rw-r--r--  tools/setup_helpers/dist_check.py    |   4
-rw-r--r--  torch/csrc/Module.cpp                |   7
-rw-r--r--  torch/csrc/distributed/c10d/c10d.h   |  13
-rw-r--r--  torch/csrc/distributed/c10d/init.cpp | 125
-rw-r--r--  torch/distributed/c10d/__init__.py   |   8
-rw-r--r--  torch/lib/c10d/CMakeLists.txt        |  48
10 files changed, 384 insertions, 10 deletions
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -115,7 +115,7 @@ from tools.setup_helpers.nvtoolext import NVTOOLEXT_HOME
 from tools.setup_helpers.generate_code import generate_code
 from tools.setup_helpers.ninja_builder import NinjaBuilder, ninja_build_ext
 from tools.setup_helpers.dist_check import WITH_DISTRIBUTED, \
-    WITH_DISTRIBUTED_MW, WITH_GLOO_IBVERBS
+    WITH_DISTRIBUTED_MW, WITH_GLOO_IBVERBS, WITH_C10D
 
 ################################################################################
@@ -250,7 +250,7 @@ class create_version_file(PytorchCommand):
 # All libraries that torch could depend on
 dep_libs = [
     'nccl', 'caffe2',
-    'libshm', 'libshm_windows', 'gloo', 'THD', 'nanopb',
+    'libshm', 'libshm_windows', 'gloo', 'THD', 'nanopb', 'c10d',
 ]
 
 missing_pydep = '''
@@ -346,6 +346,8 @@ class build_deps(PytorchCommand):
             if sys.platform.startswith('linux'):
                 libs += ['gloo']
             libs += ['THD']
+            if WITH_C10D:
+                libs += ['c10d']
         build_libs(libs)
 
         # Use copies instead of symbolic files.
@@ -633,6 +635,8 @@ if WITH_CUDA or WITH_ROCM:
     CAFFE2_LIBS.extend(['-Wl,--no-as-needed', os.path.join(lib_path, 'libcaffe2_gpu.so'), '-Wl,--as-needed'])
 THD_LIB = os.path.join(lib_path, 'libTHD.a')
 NCCL_LIB = os.path.join(lib_path, 'libnccl.so.1')
+C10D_LIB = os.path.join(lib_path, 'libc10d.a')
+C10D_GLOO_LIB = os.path.join(lib_path, 'libc10d_gloo.a')
 # static library only
 NANOPB_STATIC_LIB = os.path.join(lib_path, 'libprotobuf-nanopb.a')
@@ -787,6 +791,11 @@ if WITH_DISTRIBUTED:
     include_dirs += [tmp_install_path + "/include/THD"]
     main_link_args += [THD_LIB]
 
+if WITH_C10D:
+    extra_compile_args += ['-DWITH_C10D']
+    main_sources += ['torch/csrc/distributed/c10d/init.cpp']
+    main_link_args += [C10D_GLOO_LIB, C10D_LIB]
+
 if WITH_CUDA:
     nvtoolext_lib_name = None
     if IS_WINDOWS:
diff --git a/test/run_test.py b/test/run_test.py
index 8eec3dfe48..65aa1003e3 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -17,6 +17,7 @@ from torch.utils import cpp_extension
 TESTS = [
     'autograd',
     'cpp_extensions',
+    'c10d',
     'cuda',
     'dataloader',
     'distributed',
diff --git a/test/test_c10d.py b/test/test_c10d.py
new file mode 100644
index 0000000000..0b19e2b334
--- /dev/null
+++ b/test/test_c10d.py
@@ -0,0 +1,171 @@
+import math
+import multiprocessing
+import sys
+import tempfile
+import unittest
+from functools import wraps
+
+import torch
+import torch.distributed.c10d as c10d
+
+from common import TestCase
+
+
+TCP_ADDR = '127.0.0.1'
+TCP_PORT = 29500
+
+TIMEOUT_DEFAULT = 5
+TIMEOUT_OVERRIDE = {}
+
+
+def get_timeout(test_id):
+    return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT)
+
+
+if not c10d.is_available():
+    print('c10d not available, skipping tests')
+    sys.exit(0)
+
+
+class StoreTestBase(object):
+    def _create_store(self, i):
+        raise RuntimeError("not implemented")
+
+    def _test_set_get(self, fs):
+        fs.set("key0", "value0")
+        fs.set("key1", "value1")
+        fs.set("key2", "value2")
+        self.assertEqual(b"value0", fs.get("key0"))
+        self.assertEqual(b"value1", fs.get("key1"))
+        self.assertEqual(b"value2", fs.get("key2"))
+
+    def test_set_get(self):
+        self._test_set_get(self._create_store())
+
+
+class FileStoreTest(TestCase, StoreTestBase):
+    def setUp(self):
+        self.file = tempfile.NamedTemporaryFile()
+
+    def tearDown(self):
+        self.file.close()
+
+    def _create_store(self):
+        return c10d.FileStore(self.file.name)
+
+
+class TCPStoreTest(TestCase, StoreTestBase):
+    def _create_store(self):
+        return c10d.TCPStore(TCP_ADDR, TCP_PORT, True)
+
+
+class ProcessGroupGlooTest(TestCase):
+    MAIN_PROCESS_RANK = -1
+
+    @staticmethod
+    def join_or_run(fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn(self)
+        return wrapper
+
+    # The main process spawns N subprocesses that run the test.
+    # This function patches overwrites every test function to either
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    @classmethod
+    def setUpClass(cls):
+        for attr in dir(cls):
+            if attr.startswith('test'):
+                fn = getattr(cls, attr)
+                setattr(cls, attr, cls.join_or_run(fn))
+
+    def setUp(self):
+        self.rank = self.MAIN_PROCESS_RANK
+        self.size = 4
+        self.file = tempfile.NamedTemporaryFile()
+        self.processes = [self._spawn_process(rank) for rank in range(int(self.size))]
+
+    def tearDown(self):
+        for p in self.processes:
+            p.terminate()
+        self.file.close()
+
+    def _spawn_process(self, rank):
+        name = 'process ' + str(rank)
+        process = multiprocessing.Process(target=self._run, name=name, args=(rank,))
+        process.start()
+        return process
+
+    def _run(self, rank):
+        self.rank = rank
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retreiving a corresponding test and executing it.
+        getattr(self, self.id().split(".")[2])()
+        sys.exit(0)
+
+    def _join_processes(self, fn):
+        timeout = get_timeout(self.id())
+        for p in self.processes:
+            p.join(timeout)
+
+    def test_broadcast_ops(self):
+        store = c10d.FileStore(self.file.name)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.size)
+
+        def broadcast(xs, rootRank, rootTensor):
+            opts = c10d.BroadcastOptions()
+            opts.rootRank = rootRank
+            opts.rootTensor = rootTensor
+            work = pg.broadcast(xs, opts)
+            work.wait()
+
+        # Every rank is root once, every tensor index is root once
+        for i in range(self.size):
+            for j in range(2):
+                xs = [
+                    torch.Tensor([self.rank * self.size + 0.0]),
+                    torch.Tensor([self.rank * self.size + 1.0]),
+                ]
+
+                broadcast(xs, i, j)
+                self.assertEqual(torch.Tensor([i * self.size + j]), xs[0])
+                self.assertEqual(torch.Tensor([i * self.size + j]), xs[1])
+
+    def test_allreduce_ops(self):
+        store = c10d.FileStore(self.file.name)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.size)
+
+        def allreduce(x, op):
+            opts = c10d.AllreduceOptions()
+            opts.reduceOp = op
+            work = pg.allreduce([x], opts)
+            work.wait()
+
+        # Sum
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.SUM)
+        self.assertEqual(torch.Tensor([float(self.size * (self.size + 1) / 2)]), x)
+
+        # Product
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.PRODUCT)
+        self.assertEqual(torch.Tensor([float(math.factorial(self.size))]), x)
+
+        # Min
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.MIN)
+        self.assertEqual(torch.Tensor([1.0]), x)
+
+        # Max
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.MAX)
+        self.assertEqual(torch.Tensor([self.size]), x)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/build_pytorch_libs.sh b/tools/build_pytorch_libs.sh
index 0323b96c83..8320111541 100755
--- a/tools/build_pytorch_libs.sh
+++ b/tools/build_pytorch_libs.sh
@@ -290,6 +290,10 @@ for arg in "$@"; do
     pushd "$TORCH_LIB_DIR"
     build $arg
     popd
+  elif [[ "$arg" == "c10d" ]]; then
+    pushd "$TORCH_LIB_DIR"
+    build c10d
+    popd
   else
     pushd "$THIRD_PARTY_DIR"
     build $arg
diff --git a/tools/setup_helpers/dist_check.py b/tools/setup_helpers/dist_check.py
index afec6ce14c..2d77b3a3e0 100644
--- a/tools/setup_helpers/dist_check.py
+++ b/tools/setup_helpers/dist_check.py
@@ -2,12 +2,14 @@ import os
 import subprocess
 import glob
 
-from .env import IS_CONDA, IS_WINDOWS, CONDA_DIR, check_env_flag, gather_paths
+from .env import IS_CONDA, IS_LINUX, IS_WINDOWS, CONDA_DIR, check_env_flag, gather_paths
+from .cuda import WITH_CUDA
 
 # On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl
 WITH_DISTRIBUTED = not check_env_flag("NO_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("WITH_ROCM")
 WITH_DISTRIBUTED_MW = WITH_DISTRIBUTED and check_env_flag("WITH_DISTRIBUTED_MW")
 WITH_GLOO_IBVERBS = False
+WITH_C10D = WITH_DISTRIBUTED and WITH_CUDA and IS_LINUX
 
 IB_DEVINFO_CMD = "ibv_devinfo"
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 9e945c1218..b4f3e2d0f8 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -42,6 +42,10 @@
 #include "cudnn.h"
 #endif
 
+#ifdef WITH_C10D
+#include "torch/csrc/distributed/c10d/c10d.h"
+#endif
+
 #define WITH_NUMPY_IMPORT_ARRAY
 #include "torch/csrc/utils/numpy_stub.h"
 
@@ -487,6 +491,9 @@ static PyObject* initModule() {
 #ifdef WITH_DISTRIBUTED
   THPUtils_addPyMethodDefs(methods, THDPModule_methods());
 #endif
+#ifdef WITH_C10D
+  THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());
+#endif
 
 #if PY_MAJOR_VERSION == 2
   ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));
diff --git a/torch/csrc/distributed/c10d/c10d.h b/torch/csrc/distributed/c10d/c10d.h
new file mode 100644
index 0000000000..e91d4cb63b
--- /dev/null
+++ b/torch/csrc/distributed/c10d/c10d.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "torch/csrc/python_headers.h"
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+
+PyMethodDef* python_functions();
+
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
new file mode 100644
index 0000000000..62d8b4a5bd
--- /dev/null
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -0,0 +1,125 @@
+#include "torch/csrc/python_headers.h"
+
+#include <c10d/FileStore.hpp>
+#include <c10d/ProcessGroup.hpp>
+#include <c10d/ProcessGroupGloo.hpp>
+#include <c10d/TCPStore.hpp>
+
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/utils/object_ptr.h"
+#include "torch/csrc/utils/pybind.h"
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+
+namespace {
+
+template <typename T>
+using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
+
+PyObject* c10d_init(PyObject* _unused) {
+  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed.c10d"));
+  if (!c10d_module) {
+    throw python_error();
+  }
+
+  auto module = py::handle(c10d_module).cast<py::module>();
+
+  auto store =
+      shared_ptr_class_<::c10d::Store>(module, "Store")
+          // Convert from std::string to std::vector<uint8>.
+          .def(
+              "set",
+              [](::c10d::Store& store,
+                 const std::string& key,
+                 const std::string& value) {
+                std::vector<uint8_t> value_(value.begin(), value.end());
+                store.set(key, value_);
+              },
+              py::call_guard<py::gil_scoped_release>())
+          // Convert from std::vector<uint8_t> to py::bytes.
+          // The returned value is not guaranteed to be valid UTF-8.
+          .def(
+              "get",
+              [](::c10d::Store& store, const std::string& key) -> py::bytes {
+                auto value = store.get(key);
+                return py::bytes(
+                    reinterpret_cast<char*>(value.data()), value.size());
+              },
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "add",
+              &::c10d::Store::add,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "wait",
+              &::c10d::Store::wait,
+              py::call_guard<py::gil_scoped_release>());
+
+  shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store)
+      .def(py::init<const std::string&>());
+
+  shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store)
+      .def(py::init<const std::string&, int, bool>());
+
+  auto processGroup =
+      shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup")
+          .def("rank", &::c10d::ProcessGroup::getRank)
+          .def("size", &::c10d::ProcessGroup::getSize)
+          .def(
+              "broadcast",
+              &::c10d::ProcessGroup::broadcast,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "allreduce",
+              &::c10d::ProcessGroup::allreduce,
+              py::call_guard<py::gil_scoped_release>());
+
+  shared_ptr_class_<::c10d::ProcessGroupGloo>(
+      module, "ProcessGroupGloo", processGroup)
+      .def(py::init<const std::shared_ptr<::c10d::Store>&, int, int>());
+
+  shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
+      .def("isCompleted", &::c10d::ProcessGroup::Work::isCompleted)
+      .def("isSuccess", &::c10d::ProcessGroup::Work::isSuccess)
+      .def("exception", &::c10d::ProcessGroup::Work::exception)
+      .def("synchronize", &::c10d::ProcessGroup::Work::synchronize)
+      .def(
+          "wait",
+          &::c10d::ProcessGroup::Work::wait,
+          py::call_guard<py::gil_scoped_release>());
+
+  // Algorithm specific option structs and enums
+  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
+      .def(py::init<>())
+      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
+      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor);
+
+  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
+      .def(py::init<>())
+      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp);
+
+  py::enum_<::c10d::ReduceOp>(module, "ReduceOp")
+      .value("SUM", ::c10d::ReduceOp::SUM)
+      .value("PRODUCT", ::c10d::ReduceOp::PRODUCT)
+      .value("MIN", ::c10d::ReduceOp::MIN)
+      .value("MAX", ::c10d::ReduceOp::MAX);
+
+  Py_RETURN_TRUE;
+}
+
+} // namespace
+
+// c10d methods on torch._C
+static PyMethodDef methods[] = {
+    {"_c10d_init", (PyCFunction)c10d_init, METH_NOARGS, nullptr},
+    {nullptr, nullptr, 0, nullptr}};
+
+PyMethodDef* python_functions() {
+  return methods;
+}
+
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
diff --git a/torch/distributed/c10d/__init__.py b/torch/distributed/c10d/__init__.py
new file mode 100644
index 0000000000..1d5c085295
--- /dev/null
+++ b/torch/distributed/c10d/__init__.py
@@ -0,0 +1,8 @@
+import torch
+
+
+def is_available():
+    return hasattr(torch._C, '_c10d_init')
+
+if is_available() and not torch._C._c10d_init():
+    raise RuntimeError("c10d initialization failed")
diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt
index 1fb2141dbf..dec330dfc0 100644
--- a/torch/lib/c10d/CMakeLists.txt
+++ b/torch/lib/c10d/CMakeLists.txt
@@ -59,16 +59,22 @@ if(NOT CUDA_FOUND)
   message(FATAL_ERROR "CUDA not found")
 endif()
 
+function(copy_header file)
+  configure_file(${file} ${CMAKE_BINARY_DIR}/include/c10d/${file} COPYONLY)
+endfunction()
+
 set(C10D_SRCS
-  Utils.cpp
-  Store.cpp
+  CUDAUtils.cpp
   FileStore.cpp
-  TCPStore.cpp
   ProcessGroup.cpp
-  CUDAUtils.cpp
+  Store.cpp
+  TCPStore.cpp
+  Utils.cpp
   )
 
 add_library(c10d ${C10D_SRCS})
+set_property(TARGET c10d PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET c10d PROPERTY CXX_STANDARD 11)
 target_compile_options(c10d PUBLIC
   -Wall
   -Wextra
@@ -78,6 +84,15 @@ target_compile_options(c10d PUBLIC
   -Wno-unknown-pragmas
   )
 target_link_libraries(c10d PUBLIC caffe2_gpu)
+copy_header(CUDAUtils.hpp)
+copy_header(FileStore.hpp)
+copy_header(ProcessGroup.hpp)
+copy_header(Store.hpp)
+copy_header(TCPStore.hpp)
+copy_header(Types.hpp)
+copy_header(Utils.hpp)
+target_include_directories(c10d PRIVATE ${CMAKE_BINARY_DIR}/include)
+install(TARGETS c10d ARCHIVE DESTINATION lib)
 
 # c10d links to Caffe2/ATen, but the targets don't add TH/THC to the include path
 target_include_directories(c10d PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../tmp_install/include/TH)
@@ -91,16 +106,26 @@ set(C10D_GLOO_SRCS
   )
 
 add_library(c10d_gloo ${C10D_GLOO_SRCS})
+set_property(TARGET c10d_gloo PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET c10d_gloo PROPERTY CXX_STANDARD 11)
 target_include_directories(c10d_gloo PUBLIC ${GLOO_INCLUDE_DIR})
 target_link_libraries(c10d_gloo PUBLIC c10d ${Gloo_NATIVE_LIBRARY} ${Gloo_LIBRARY})
+copy_header(ProcessGroupGloo.hpp)
+target_include_directories(c10d PRIVATE ${CMAKE_BINARY_DIR}/include)
+install(TARGETS c10d_gloo ARCHIVE DESTINATION lib)
 
 if(MPI_FOUND)
   set(C10D_MPI_SRCS
     ProcessGroupMPI.cpp
     )
   add_library(c10d_mpi ${C10D_MPI_SRCS})
+  set_property(TARGET c10d_mpi PROPERTY POSITION_INDEPENDENT_CODE ON)
+  set_property(TARGET c10d_mpi PROPERTY CXX_STANDARD 11)
   target_include_directories(c10d_mpi PUBLIC ${MPI_INCLUDE_PATH})
   target_link_libraries(c10d_mpi PUBLIC c10d ${MPI_LIBRARIES})
+  copy_header(ProcessGroupMPI.hpp)
+  target_include_directories(c10d_mpi PRIVATE ${CMAKE_BINARY_DIR}/include)
+  install(TARGETS c10d_mpi ARCHIVE DESTINATION lib)
 endif()
 
 if(DISTRIBUTED_NCCL_FOUND)
@@ -112,7 +137,16 @@ if(DISTRIBUTED_NCCL_FOUND)
   target_link_libraries(c10d_nccl PUBLIC c10d ${NCCL_LIBRARIES})
 endif()
 
-add_subdirectory(example)
+option(BUILD_EXAMPLES "Build examples" OFF)
+if(BUILD_EXAMPLES)
+  add_subdirectory(example)
+endif()
+
+option(BUILD_TEST "Build tests" OFF)
+if(BUILD_TEST)
+  enable_testing()
+  add_subdirectory(test)
+endif()
 
-enable_testing()
-add_subdirectory(test)
+# Install all header files that were prepared in the build directory
+install(DIRECTORY ${CMAKE_BINARY_DIR}/include/ DESTINATION include)
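
As a closing illustration (again not part of the commit), the sketch below mirrors ProcessGroupGlooTest.test_allreduce_ops with two real processes sharing a FileStore for rendezvous; it assumes the same WITH_C10D-enabled build and that Gloo can form a two-rank group on the local host.

    # Illustrative only: two ranks perform a SUM allreduce through the new
    # bindings, following the pattern of test/test_c10d.py.
    import multiprocessing
    import tempfile

    import torch
    import torch.distributed.c10d as c10d


    def worker(rank, size, path):
        store = c10d.FileStore(path)
        pg = c10d.ProcessGroupGloo(store, rank, size)

        opts = c10d.AllreduceOptions()
        opts.reduceOp = c10d.ReduceOp.SUM
        x = torch.Tensor([rank + 1.0])
        pg.allreduce([x], opts).wait()
        # Both ranks now hold 1.0 + 2.0 = 3.0 in x.


    if __name__ == '__main__':
        size = 2
        f = tempfile.NamedTemporaryFile()
        procs = [multiprocessing.Process(target=worker, args=(r, size, f.name))
                 for r in range(size)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        f.close()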