summaryrefslogtreecommitdiff
path: root/caffe2/python/examples
diff options
context:
space:
mode:
author Pieter Noordhuis <pietern@fb.com> 2017-09-08 10:42:08 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com> 2017-09-08 10:57:41 -0700
commit b8eb8ced7dadcb54238909b7bcd32a813593be4b (patch)
tree 12d505ae464417367d073ba026d559dc6134e33d /caffe2/python/examples
parent fdbfcfc43112dcea5e12b3826ec71d2ccc5026dc (diff)
download pytorch-b8eb8ced7dadcb54238909b7bcd32a813593be4b.tar.gz
pytorch-b8eb8ced7dadcb54238909b7bcd32a813593be4b.tar.bz2
pytorch-b8eb8ced7dadcb54238909b7bcd32a813593be4b.zip
Add transport/interface arguments to CreateCommonWorld operator
Summary: These arguments control which Gloo transport (TCP or IB) and which network interface are used for the common world. If not specified, it defaults to using TCP and the network interface for the IP that the machine's hostname resolves to. The valid values for the transport argument are "tcp" and "ibverbs". For ibverbs to work, Gloo must have been compiled with ibverbs support. If Gloo is built as part of Caffe2 (sourced from the third_party directory), then you can pass -DUSE_IBVERBS=ON to CMake to enable ibverbs support in Gloo.

Closes https://github.com/caffe2/caffe2/pull/1177

Reviewed By: akyrola
Differential Revision: D5789729
Pulled By: pietern
fbshipit-source-id: 0dea1a115c729e54c5c1f9fdd5fb29c14a834a82
Diffstat (limited to 'caffe2/python/examples')
-rw-r--r--caffe2/python/examples/resnet50_trainer.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/caffe2/python/examples/resnet50_trainer.py b/caffe2/python/examples/resnet50_trainer.py
index 746582470a..61c558e92d 100644
--- a/caffe2/python/examples/resnet50_trainer.py
+++ b/caffe2/python/examples/resnet50_trainer.py
@@ -250,11 +250,18 @@ def Train(args):
prefix=args.run_id,
)
)
+
+ # Expect interfaces to be comma separated.
+ # Use of multiple network interfaces is not yet complete,
+ # so simply use the first one in the list.
+ interfaces = args.distributed_interfaces.split(",")
rendezvous = dict(
kv_handler=store_handler,
shard_id=shard_id,
num_shards=num_shards,
engine="GLOO",
+ transport=args.distributed_transport,
+ interface=interfaces[0],
exit_nets=None)
else:
rendezvous = None
@@ -490,6 +497,10 @@ def main():
help='Data type used for training')
parser.add_argument('--enable-tensor-core', action='store_true',
help='Enable Tensor Core math for Conv and FC ops')
+ parser.add_argument("--distributed_transport", type=str, default="tcp",
+ help="Transport to use for distributed run [tcp|ibverbs]")
+ parser.add_argument("--distributed_interfaces", type=str, default="",
+ help="Network interfaces to use for distributed run")
args = parser.parse_args()