author    Wei Yang <weiyang@fb.com>  2018-09-11 20:20:54 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-09-11 20:27:07 -0700
commit    54107ae8cf476b4ebe7c631f75273e0b014e748c (patch)
tree      e6df6e45f870cf5e464daaeacffd0202a9c43f01  /test/test_c10d.py
parent    045f862574063dfe1f92b84f46cd97b2aeeaf829 (diff)
convert output_device at data_parallel from torch.device to index (#10189)
Summary:
- fixes #9984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189
Differential Revision: D9545390
Pulled By: weiyangfb
fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
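The change lets output_device (and device_ids) be passed as torch.device objects by normalizing them to plain integer GPU indices before use. A minimal sketch of that normalization, with a hypothetical helper name device_to_index (not the PR's actual implementation):

import torch

def device_to_index(device):
    # Hypothetical helper: normalize an output_device argument that may be a
    # torch.device, an int index, or None into a plain integer GPU index.
    if device is None:
        return torch.cuda.current_device()
    if isinstance(device, torch.device):
        if device.type != 'cuda':
            raise ValueError('expected a CUDA device, got {}'.format(device))
        # torch.device('cuda') carries no index; fall back to the current device.
        return device.index if device.index is not None else torch.cuda.current_device()
    return int(device)

# e.g. device_to_index(torch.device('cuda:1')) == 1, device_to_index(1) == 1

The tests below exercise both call styles: the same GPU list is passed once as integer indices and once as torch.device objects.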
Diffstat (limited to 'test/test_c10d.py')
-rw-r--r--  test/test_c10d.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/test/test_c10d.py b/test/test_c10d.py
index 64bedb3183..ff9d87be76 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -567,8 +567,7 @@ class DistributedDataParallelTest(MultiProcessTestCase):
     def world_size(self):
         return 2
 
-    def _test_ddp_with_process_group(self, process_group):
-        gpus = gpus_for_rank(self.world_size)[self.rank]
+    def _test_ddp_with_process_group(self, process_group, gpus):
         model = Net()
         ddp_model = DistributedDataParallel(
             copy.deepcopy(model).cuda(gpus[0]),
@@ -620,14 +619,18 @@ class DistributedDataParallelTest(MultiProcessTestCase):
         options = c10d.ProcessGroupGloo.Options()
         options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
         process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
 
     @skip_if_not_multigpu
     @skip_if_not_nccl
     def test_nccl_backend(self):
         store = c10d.TCPStore('localhost', self.port, self.is_master)
         process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
-        self._test_ddp_with_process_group(process_group)
+        gpus = gpus_for_rank(self.world_size)[self.rank]
+        self._test_ddp_with_process_group(process_group, gpus)
+        self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
 
     @skip_if_not_multigpu
     def test_dist_broadcast_coalesced(self):