diff options
author | Pieter Noordhuis <pietern@fb.com> | 2017-09-08 10:42:08 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2017-09-08 10:57:41 -0700 |
commit | b8eb8ced7dadcb54238909b7bcd32a813593be4b (patch) | |
tree | 12d505ae464417367d073ba026d559dc6134e33d /caffe2/python/examples | |
parent | fdbfcfc43112dcea5e12b3826ec71d2ccc5026dc (diff) | |
download | pytorch-b8eb8ced7dadcb54238909b7bcd32a813593be4b.tar.gz pytorch-b8eb8ced7dadcb54238909b7bcd32a813593be4b.tar.bz2 pytorch-b8eb8ced7dadcb54238909b7bcd32a813593be4b.zip |
Add transport/interface arguments to CreateCommonWorld operator
Summary:
These arguments control which Gloo transport (TCP or IB) and which
network interface is used for the common world. If not specified, it
defaults to using TCP and the network interface for the IP that the
machine's hostname resolves to.
The valid values for the transport argument are "tcp" and "ibverbs".
For ibverbs to work, Gloo must have been compiled with ibverbs
support. If Gloo is built as part of Caffe2 (sourced from the
third_party directory), then you can pass -DUSE_IBVERBS=ON to CMake to
enable ibverbs support in Gloo.
Closes https://github.com/caffe2/caffe2/pull/1177
Reviewed By: akyrola
Differential Revision: D5789729
Pulled By: pietern
fbshipit-source-id: 0dea1a115c729e54c5c1f9fdd5fb29c14a834a82
Diffstat (limited to 'caffe2/python/examples')
-rw-r--r-- | caffe2/python/examples/resnet50_trainer.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/caffe2/python/examples/resnet50_trainer.py b/caffe2/python/examples/resnet50_trainer.py index 746582470a..61c558e92d 100644 --- a/caffe2/python/examples/resnet50_trainer.py +++ b/caffe2/python/examples/resnet50_trainer.py @@ -250,11 +250,18 @@ def Train(args): prefix=args.run_id, ) ) + + # Expect interfaces to be comma separated. + # Use of multiple network interfaces is not yet complete, + # so simply use the first one in the list. + interfaces = args.distributed_interfaces.split(",") rendezvous = dict( kv_handler=store_handler, shard_id=shard_id, num_shards=num_shards, engine="GLOO", + transport=args.distributed_transport, + interface=interfaces[0], exit_nets=None) else: rendezvous = None @@ -490,6 +497,10 @@ def main(): help='Data type used for training') parser.add_argument('--enable-tensor-core', action='store_true', help='Enable Tensor Core math for Conv and FC ops') + parser.add_argument("--distributed_transport", type=str, default="tcp", + help="Transport to use for distributed run [tcp|ibverbs]") + parser.add_argument("--distributed_interfaces", type=str, default="", + help="Network interfaces to use for distributed run") args = parser.parse_args() |