author    Pieter Noordhuis <pcnoordhuis@gmail.com>   2018-06-08 12:59:51 -0700
committer GitHub <noreply@github.com>                2018-06-08 12:59:51 -0700
commit    695d40efc28bde1f01e29c9041f99a595072bd67 (patch)
tree      51929f43cd998d9bf95b3aece056a7dbce3e7455
parent    75563674c4c9133f190726d5e06a2cd59904cb64 (diff)
Create initial Python bindings for c10d (#8119)
* Build and install c10d from tools/build_pytorch_libs.sh
* Create initial Python bindings for c10d
* clang-format
* Switch link order to include more symbols
* Add bindings and tests for ProcessGroupGloo
* Add broadcast test
* Separate build flag for c10d
* Explicit PIC property
* Skip c10d tests if not available
* Remove c10d from Windows blacklist
  Let it skip by itself because it won't be available anyway.
* Make lint happy
* Comments
* Move c10d module into torch.distributed
* Close tempfile such that it is deleted
-rw-r--r--  setup.py                                13
-rw-r--r--  test/run_test.py                         1
-rw-r--r--  test/test_c10d.py                      171
-rwxr-xr-x  tools/build_pytorch_libs.sh              4
-rw-r--r--  tools/setup_helpers/dist_check.py        4
-rw-r--r--  torch/csrc/Module.cpp                    7
-rw-r--r--  torch/csrc/distributed/c10d/c10d.h      13
-rw-r--r--  torch/csrc/distributed/c10d/init.cpp   125
-rw-r--r--  torch/distributed/c10d/__init__.py       8
-rw-r--r--  torch/lib/c10d/CMakeLists.txt           48
10 files changed, 384 insertions, 10 deletions
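
For orientation before the diff itself: the bindings added below are driven from Python roughly the way the new test/test_c10d.py does it. The following is a minimal single-process sketch (world size 1, so there are no peers to wait on); it is illustrative only and not part of the patch:

    import tempfile

    import torch
    import torch.distributed.c10d as c10d

    # Ranks rendezvous through a file-backed store; all ranks pass the same path.
    tmp_file = tempfile.NamedTemporaryFile()
    store = c10d.FileStore(tmp_file.name)

    # Rank 0 in a group of size 1, backed by Gloo.
    pg = c10d.ProcessGroupGloo(store, 0, 1)

    opts = c10d.BroadcastOptions()
    opts.rootRank = 0
    opts.rootTensor = 0

    x = torch.Tensor([42.0])
    work = pg.broadcast([x], opts)  # returns a Work handle
    work.wait()                     # blocks; the binding releases the GIL while waiting

test/test_c10d.py below runs the same pattern across four subprocesses.
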
diff --git a/setup.py b/setup.py
index c411ce7d8b..7e3f44a45c 100644
--- a/setup.py
+++ b/setup.py
@@ -115,7 +115,7 @@ from tools.setup_helpers.nvtoolext import NVTOOLEXT_HOME
from tools.setup_helpers.generate_code import generate_code
from tools.setup_helpers.ninja_builder import NinjaBuilder, ninja_build_ext
from tools.setup_helpers.dist_check import WITH_DISTRIBUTED, \
-    WITH_DISTRIBUTED_MW, WITH_GLOO_IBVERBS
+    WITH_DISTRIBUTED_MW, WITH_GLOO_IBVERBS, WITH_C10D
################################################################################
@@ -250,7 +250,7 @@ class create_version_file(PytorchCommand):
# All libraries that torch could depend on
dep_libs = [
    'nccl', 'caffe2',
-    'libshm', 'libshm_windows', 'gloo', 'THD', 'nanopb',
+    'libshm', 'libshm_windows', 'gloo', 'THD', 'nanopb', 'c10d',
]
missing_pydep = '''
@@ -346,6 +346,8 @@ class build_deps(PytorchCommand):
            if sys.platform.startswith('linux'):
                libs += ['gloo']
            libs += ['THD']
+            if WITH_C10D:
+                libs += ['c10d']
        build_libs(libs)

        # Use copies instead of symbolic files.
@@ -633,6 +635,8 @@ if WITH_CUDA or WITH_ROCM:
CAFFE2_LIBS.extend(['-Wl,--no-as-needed', os.path.join(lib_path, 'libcaffe2_gpu.so'), '-Wl,--as-needed'])
THD_LIB = os.path.join(lib_path, 'libTHD.a')
NCCL_LIB = os.path.join(lib_path, 'libnccl.so.1')
+C10D_LIB = os.path.join(lib_path, 'libc10d.a')
+C10D_GLOO_LIB = os.path.join(lib_path, 'libc10d_gloo.a')
# static library only
NANOPB_STATIC_LIB = os.path.join(lib_path, 'libprotobuf-nanopb.a')
@@ -787,6 +791,11 @@ if WITH_DISTRIBUTED:
    include_dirs += [tmp_install_path + "/include/THD"]
    main_link_args += [THD_LIB]

+if WITH_C10D:
+    extra_compile_args += ['-DWITH_C10D']
+    main_sources += ['torch/csrc/distributed/c10d/init.cpp']
+    main_link_args += [C10D_GLOO_LIB, C10D_LIB]
+
if WITH_CUDA:
    nvtoolext_lib_name = None
    if IS_WINDOWS:
diff --git a/test/run_test.py b/test/run_test.py
index 8eec3dfe48..65aa1003e3 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -17,6 +17,7 @@ from torch.utils import cpp_extension
TESTS = [
    'autograd',
    'cpp_extensions',
+    'c10d',
    'cuda',
    'dataloader',
    'distributed',
diff --git a/test/test_c10d.py b/test/test_c10d.py
new file mode 100644
index 0000000000..0b19e2b334
--- /dev/null
+++ b/test/test_c10d.py
@@ -0,0 +1,171 @@
+import math
+import multiprocessing
+import sys
+import tempfile
+import unittest
+from functools import wraps
+
+import torch
+import torch.distributed.c10d as c10d
+
+from common import TestCase
+
+
+TCP_ADDR = '127.0.0.1'
+TCP_PORT = 29500
+
+TIMEOUT_DEFAULT = 5
+TIMEOUT_OVERRIDE = {}
+
+
+def get_timeout(test_id):
+    return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT)
+
+
+if not c10d.is_available():
+    print('c10d not available, skipping tests')
+    sys.exit(0)
+
+
+class StoreTestBase(object):
+    def _create_store(self, i):
+        raise RuntimeError("not implemented")
+
+    def _test_set_get(self, fs):
+        fs.set("key0", "value0")
+        fs.set("key1", "value1")
+        fs.set("key2", "value2")
+        self.assertEqual(b"value0", fs.get("key0"))
+        self.assertEqual(b"value1", fs.get("key1"))
+        self.assertEqual(b"value2", fs.get("key2"))
+
+    def test_set_get(self):
+        self._test_set_get(self._create_store())
+
+
+class FileStoreTest(TestCase, StoreTestBase):
+    def setUp(self):
+        self.file = tempfile.NamedTemporaryFile()
+
+    def tearDown(self):
+        self.file.close()
+
+    def _create_store(self):
+        return c10d.FileStore(self.file.name)
+
+
+class TCPStoreTest(TestCase, StoreTestBase):
+    def _create_store(self):
+        return c10d.TCPStore(TCP_ADDR, TCP_PORT, True)
+
+
+class ProcessGroupGlooTest(TestCase):
+    MAIN_PROCESS_RANK = -1
+
+    @staticmethod
+    def join_or_run(fn):
+        @wraps(fn)
+        def wrapper(self):
+            if self.rank == self.MAIN_PROCESS_RANK:
+                self._join_processes(fn)
+            else:
+                fn(self)
+        return wrapper
+
+    # The main process spawns N subprocesses that run the test.
+    # This function overwrites every test function to either
+    # assume the role of the main process and join its subprocesses,
+    # or run the underlying test function.
+    @classmethod
+    def setUpClass(cls):
+        for attr in dir(cls):
+            if attr.startswith('test'):
+                fn = getattr(cls, attr)
+                setattr(cls, attr, cls.join_or_run(fn))
+
+    def setUp(self):
+        self.rank = self.MAIN_PROCESS_RANK
+        self.size = 4
+        self.file = tempfile.NamedTemporaryFile()
+        self.processes = [self._spawn_process(rank) for rank in range(int(self.size))]
+
+    def tearDown(self):
+        for p in self.processes:
+            p.terminate()
+        self.file.close()
+
+    def _spawn_process(self, rank):
+        name = 'process ' + str(rank)
+        process = multiprocessing.Process(target=self._run, name=name, args=(rank,))
+        process.start()
+        return process
+
+    def _run(self, rank):
+        self.rank = rank
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving the corresponding test and executing it.
+        getattr(self, self.id().split(".")[2])()
+        sys.exit(0)
+
+    def _join_processes(self, fn):
+        timeout = get_timeout(self.id())
+        for p in self.processes:
+            p.join(timeout)
+
+    def test_broadcast_ops(self):
+        store = c10d.FileStore(self.file.name)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.size)
+
+        def broadcast(xs, rootRank, rootTensor):
+            opts = c10d.BroadcastOptions()
+            opts.rootRank = rootRank
+            opts.rootTensor = rootTensor
+            work = pg.broadcast(xs, opts)
+            work.wait()
+
+        # Every rank is root once, every tensor index is root once
+        for i in range(self.size):
+            for j in range(2):
+                xs = [
+                    torch.Tensor([self.rank * self.size + 0.0]),
+                    torch.Tensor([self.rank * self.size + 1.0]),
+                ]
+
+                broadcast(xs, i, j)
+                self.assertEqual(torch.Tensor([i * self.size + j]), xs[0])
+                self.assertEqual(torch.Tensor([i * self.size + j]), xs[1])
+
+    def test_allreduce_ops(self):
+        store = c10d.FileStore(self.file.name)
+        pg = c10d.ProcessGroupGloo(store, self.rank, self.size)
+
+        def allreduce(x, op):
+            opts = c10d.AllreduceOptions()
+            opts.reduceOp = op
+            work = pg.allreduce([x], opts)
+            work.wait()
+
+        # Sum
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.SUM)
+        self.assertEqual(torch.Tensor([float(self.size * (self.size + 1) / 2)]), x)
+
+        # Product
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.PRODUCT)
+        self.assertEqual(torch.Tensor([float(math.factorial(self.size))]), x)
+
+        # Min
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.MIN)
+        self.assertEqual(torch.Tensor([1.0]), x)
+
+        # Max
+        x = torch.Tensor([self.rank + 1.0])
+        allreduce(x, c10d.ReduceOp.MAX)
+        self.assertEqual(torch.Tensor([self.size]), x)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tools/build_pytorch_libs.sh b/tools/build_pytorch_libs.sh
index 0323b96c83..8320111541 100755
--- a/tools/build_pytorch_libs.sh
+++ b/tools/build_pytorch_libs.sh
@@ -290,6 +290,10 @@ for arg in "$@"; do
pushd "$TORCH_LIB_DIR"
build $arg
popd
+ elif [[ "$arg" == "c10d" ]]; then
+ pushd "$TORCH_LIB_DIR"
+ build c10d
+ popd
else
pushd "$THIRD_PARTY_DIR"
build $arg
diff --git a/tools/setup_helpers/dist_check.py b/tools/setup_helpers/dist_check.py
index afec6ce14c..2d77b3a3e0 100644
--- a/tools/setup_helpers/dist_check.py
+++ b/tools/setup_helpers/dist_check.py
@@ -2,12 +2,14 @@ import os
import subprocess
import glob
-from .env import IS_CONDA, IS_WINDOWS, CONDA_DIR, check_env_flag, gather_paths
+from .env import IS_CONDA, IS_LINUX, IS_WINDOWS, CONDA_DIR, check_env_flag, gather_paths
+from .cuda import WITH_CUDA
# On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl
WITH_DISTRIBUTED = not check_env_flag("NO_DISTRIBUTED") and not IS_WINDOWS and not check_env_flag("WITH_ROCM")
WITH_DISTRIBUTED_MW = WITH_DISTRIBUTED and check_env_flag("WITH_DISTRIBUTED_MW")
WITH_GLOO_IBVERBS = False
+WITH_C10D = WITH_DISTRIBUTED and WITH_CUDA and IS_LINUX
IB_DEVINFO_CMD = "ibv_devinfo"
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 9e945c1218..b4f3e2d0f8 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -42,6 +42,10 @@
#include "cudnn.h"
#endif
+#ifdef WITH_C10D
+#include "torch/csrc/distributed/c10d/c10d.h"
+#endif
+
#define WITH_NUMPY_IMPORT_ARRAY
#include "torch/csrc/utils/numpy_stub.h"
@@ -487,6 +491,9 @@ static PyObject* initModule() {
#ifdef WITH_DISTRIBUTED
  THPUtils_addPyMethodDefs(methods, THDPModule_methods());
#endif
+#ifdef WITH_C10D
+  THPUtils_addPyMethodDefs(methods, torch::distributed::c10d::python_functions());
+#endif

#if PY_MAJOR_VERSION == 2
  ASSERT_TRUE(module = Py_InitModule("torch._C", methods.data()));
diff --git a/torch/csrc/distributed/c10d/c10d.h b/torch/csrc/distributed/c10d/c10d.h
new file mode 100644
index 0000000000..e91d4cb63b
--- /dev/null
+++ b/torch/csrc/distributed/c10d/c10d.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include "torch/csrc/python_headers.h"
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+
+PyMethodDef* python_functions();
+
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
new file mode 100644
index 0000000000..62d8b4a5bd
--- /dev/null
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -0,0 +1,125 @@
+#include "torch/csrc/python_headers.h"
+
+#include <c10d/FileStore.hpp>
+#include <c10d/ProcessGroup.hpp>
+#include <c10d/ProcessGroupGloo.hpp>
+#include <c10d/TCPStore.hpp>
+
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/utils/object_ptr.h"
+#include "torch/csrc/utils/pybind.h"
+
+namespace torch {
+namespace distributed {
+namespace c10d {
+
+namespace {
+
+template <typename T>
+using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
+
+PyObject* c10d_init(PyObject* _unused) {
+  auto c10d_module = THPObjectPtr(PyImport_ImportModule("torch.distributed.c10d"));
+  if (!c10d_module) {
+    throw python_error();
+  }
+
+  auto module = py::handle(c10d_module).cast<py::module>();
+
+  auto store =
+      shared_ptr_class_<::c10d::Store>(module, "Store")
+          // Convert from std::string to std::vector<uint8>.
+          .def(
+              "set",
+              [](::c10d::Store& store,
+                 const std::string& key,
+                 const std::string& value) {
+                std::vector<uint8_t> value_(value.begin(), value.end());
+                store.set(key, value_);
+              },
+              py::call_guard<py::gil_scoped_release>())
+          // Convert from std::vector<uint8_t> to py::bytes.
+          // The returned value is not guaranteed to be valid UTF-8.
+          .def(
+              "get",
+              [](::c10d::Store& store, const std::string& key) -> py::bytes {
+                auto value = store.get(key);
+                return py::bytes(
+                    reinterpret_cast<char*>(value.data()), value.size());
+              },
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "add",
+              &::c10d::Store::add,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "wait",
+              &::c10d::Store::wait,
+              py::call_guard<py::gil_scoped_release>());
+
+  shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store)
+      .def(py::init<const std::string&>());
+
+  shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store)
+      .def(py::init<const std::string&, int, bool>());
+
+  auto processGroup =
+      shared_ptr_class_<::c10d::ProcessGroup>(module, "ProcessGroup")
+          .def("rank", &::c10d::ProcessGroup::getRank)
+          .def("size", &::c10d::ProcessGroup::getSize)
+          .def(
+              "broadcast",
+              &::c10d::ProcessGroup::broadcast,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "allreduce",
+              &::c10d::ProcessGroup::allreduce,
+              py::call_guard<py::gil_scoped_release>());
+
+  shared_ptr_class_<::c10d::ProcessGroupGloo>(
+      module, "ProcessGroupGloo", processGroup)
+      .def(py::init<const std::shared_ptr<::c10d::Store>&, int, int>());
+
+  shared_ptr_class_<::c10d::ProcessGroup::Work>(module, "Work")
+      .def("isCompleted", &::c10d::ProcessGroup::Work::isCompleted)
+      .def("isSuccess", &::c10d::ProcessGroup::Work::isSuccess)
+      .def("exception", &::c10d::ProcessGroup::Work::exception)
+      .def("synchronize", &::c10d::ProcessGroup::Work::synchronize)
+      .def(
+          "wait",
+          &::c10d::ProcessGroup::Work::wait,
+          py::call_guard<py::gil_scoped_release>());
+
+  // Algorithm specific option structs and enums
+  py::class_<::c10d::BroadcastOptions>(module, "BroadcastOptions")
+      .def(py::init<>())
+      .def_readwrite("rootRank", &::c10d::BroadcastOptions::rootRank)
+      .def_readwrite("rootTensor", &::c10d::BroadcastOptions::rootTensor);
+
+  py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions")
+      .def(py::init<>())
+      .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp);
+
+  py::enum_<::c10d::ReduceOp>(module, "ReduceOp")
+      .value("SUM", ::c10d::ReduceOp::SUM)
+      .value("PRODUCT", ::c10d::ReduceOp::PRODUCT)
+      .value("MIN", ::c10d::ReduceOp::MIN)
+      .value("MAX", ::c10d::ReduceOp::MAX);
+
+  Py_RETURN_TRUE;
+}
+
+} // namespace
+
+// c10d methods on torch._C
+static PyMethodDef methods[] = {
+    {"_c10d_init", (PyCFunction)c10d_init, METH_NOARGS, nullptr},
+    {nullptr, nullptr, 0, nullptr}};
+
+PyMethodDef* python_functions() {
+  return methods;
+}
+
+} // namespace c10d
+} // namespace distributed
+} // namespace torch
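
Note the asymmetry introduced by the Store bindings above: set() takes a Python str and copies it into a byte vector, while get() returns raw bytes that are not guaranteed to be valid UTF-8. A small illustrative round trip (not part of the patch):

    import tempfile

    import torch.distributed.c10d as c10d

    tmp = tempfile.NamedTemporaryFile()
    store = c10d.FileStore(tmp.name)

    store.set("worker0/addr", "127.0.0.1")            # str in
    assert store.get("worker0/addr") == b"127.0.0.1"  # bytes out, so compare against bytes

add() and wait() are bound as well but are not exercised by the tests in this patch.
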
diff --git a/torch/distributed/c10d/__init__.py b/torch/distributed/c10d/__init__.py
new file mode 100644
index 0000000000..1d5c085295
--- /dev/null
+++ b/torch/distributed/c10d/__init__.py
@@ -0,0 +1,8 @@
+import torch
+
+
+def is_available():
+    return hasattr(torch._C, '_c10d_init')
+
+if is_available() and not torch._C._c10d_init():
+    raise RuntimeError("c10d initialization failed")
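
Because the extension is only compiled when WITH_C10D is set (Linux, CUDA and distributed enabled, per dist_check.py above), callers are expected to gate on is_available() just as test_c10d.py does. A hypothetical caller, sketched here for illustration:

    import sys

    import torch.distributed.c10d as c10d

    if not c10d.is_available():
        # torch._C has no _c10d_init in this build (e.g. no CUDA, or not Linux).
        print('c10d not available, skipping')
        sys.exit(0)

    # Safe to construct stores and process groups from here on.
    # The third TCPStore argument mirrors TCPStoreTest; it presumably selects
    # the listening (master) role.
    store = c10d.TCPStore('127.0.0.1', 29500, True)
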
diff --git a/torch/lib/c10d/CMakeLists.txt b/torch/lib/c10d/CMakeLists.txt
index 1fb2141dbf..dec330dfc0 100644
--- a/torch/lib/c10d/CMakeLists.txt
+++ b/torch/lib/c10d/CMakeLists.txt
@@ -59,16 +59,22 @@ if(NOT CUDA_FOUND)
message(FATAL_ERROR "CUDA not found")
endif()
+function(copy_header file)
+ configure_file(${file} ${CMAKE_BINARY_DIR}/include/c10d/${file} COPYONLY)
+endfunction()
+
set(C10D_SRCS
-  Utils.cpp
-  Store.cpp
+  CUDAUtils.cpp
  FileStore.cpp
-  TCPStore.cpp
  ProcessGroup.cpp
-  CUDAUtils.cpp
+  Store.cpp
+  TCPStore.cpp
+  Utils.cpp
  )
add_library(c10d ${C10D_SRCS})
+set_property(TARGET c10d PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET c10d PROPERTY CXX_STANDARD 11)
target_compile_options(c10d PUBLIC
-Wall
-Wextra
@@ -78,6 +84,15 @@ target_compile_options(c10d PUBLIC
-Wno-unknown-pragmas
)
target_link_libraries(c10d PUBLIC caffe2_gpu)
+copy_header(CUDAUtils.hpp)
+copy_header(FileStore.hpp)
+copy_header(ProcessGroup.hpp)
+copy_header(Store.hpp)
+copy_header(TCPStore.hpp)
+copy_header(Types.hpp)
+copy_header(Utils.hpp)
+target_include_directories(c10d PRIVATE ${CMAKE_BINARY_DIR}/include)
+install(TARGETS c10d ARCHIVE DESTINATION lib)
# c10d links to Caffe2/ATen, but the targets don't add TH/THC to the include path
target_include_directories(c10d PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../tmp_install/include/TH)
@@ -91,16 +106,26 @@ set(C10D_GLOO_SRCS
)
add_library(c10d_gloo ${C10D_GLOO_SRCS})
+set_property(TARGET c10d_gloo PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET c10d_gloo PROPERTY CXX_STANDARD 11)
target_include_directories(c10d_gloo PUBLIC ${GLOO_INCLUDE_DIR})
target_link_libraries(c10d_gloo PUBLIC c10d ${Gloo_NATIVE_LIBRARY} ${Gloo_LIBRARY})
+copy_header(ProcessGroupGloo.hpp)
+target_include_directories(c10d_gloo PRIVATE ${CMAKE_BINARY_DIR}/include)
+install(TARGETS c10d_gloo ARCHIVE DESTINATION lib)
if(MPI_FOUND)
  set(C10D_MPI_SRCS
    ProcessGroupMPI.cpp
  )
  add_library(c10d_mpi ${C10D_MPI_SRCS})
+  set_property(TARGET c10d_mpi PROPERTY POSITION_INDEPENDENT_CODE ON)
+  set_property(TARGET c10d_mpi PROPERTY CXX_STANDARD 11)
  target_include_directories(c10d_mpi PUBLIC ${MPI_INCLUDE_PATH})
  target_link_libraries(c10d_mpi PUBLIC c10d ${MPI_LIBRARIES})
+  copy_header(ProcessGroupMPI.hpp)
+  target_include_directories(c10d_mpi PRIVATE ${CMAKE_BINARY_DIR}/include)
+  install(TARGETS c10d_mpi ARCHIVE DESTINATION lib)
endif()
if(DISTRIBUTED_NCCL_FOUND)
@@ -112,7 +137,16 @@ if(DISTRIBUTED_NCCL_FOUND)
target_link_libraries(c10d_nccl PUBLIC c10d ${NCCL_LIBRARIES})
endif()
-add_subdirectory(example)
+option(BUILD_EXAMPLES "Build examples" OFF)
+if(BUILD_EXAMPLES)
+  add_subdirectory(example)
+endif()
+
+option(BUILD_TEST "Build tests" OFF)
+if(BUILD_TEST)
+  enable_testing()
+  add_subdirectory(test)
+endif()
-enable_testing()
-add_subdirectory(test)
+# Install all header files that were prepared in the build directory
+install(DIRECTORY ${CMAKE_BINARY_DIR}/include/ DESTINATION include)