diff options
author | Wei Yang <weiyang@fb.com> | 2018-09-11 20:20:54 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-09-11 20:27:07 -0700 |
commit | 54107ae8cf476b4ebe7c631f75273e0b014e748c (patch) | |
tree | e6df6e45f870cf5e464daaeacffd0202a9c43f01 /test/test_c10d.py | |
parent | 045f862574063dfe1f92b84f46cd97b2aeeaf829 (diff) | |
download | pytorch-54107ae8cf476b4ebe7c631f75273e0b014e748c.tar.gz pytorch-54107ae8cf476b4ebe7c631f75273e0b014e748c.tar.bz2 pytorch-54107ae8cf476b4ebe7c631f75273e0b014e748c.zip |
convert output_device at data_parallel from torch.device to index (#10189)
Summary:
- fixes #9984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189
Differential Revision: D9545390
Pulled By: weiyangfb
fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
Diffstat (limited to 'test/test_c10d.py')
-rw-r--r-- | test/test_c10d.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/test/test_c10d.py b/test/test_c10d.py index 64bedb3183..ff9d87be76 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -567,8 +567,7 @@ class DistributedDataParallelTest(MultiProcessTestCase): def world_size(self): return 2 - def _test_ddp_with_process_group(self, process_group): - gpus = gpus_for_rank(self.world_size)[self.rank] + def _test_ddp_with_process_group(self, process_group, gpus): model = Net() ddp_model = DistributedDataParallel( copy.deepcopy(model).cuda(gpus[0]), @@ -620,14 +619,18 @@ class DistributedDataParallelTest(MultiProcessTestCase): options = c10d.ProcessGroupGloo.Options() options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) - self._test_ddp_with_process_group(process_group) + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_ddp_with_process_group(process_group, gpus) + self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus))) @skip_if_not_multigpu @skip_if_not_nccl def test_nccl_backend(self): store = c10d.TCPStore('localhost', self.port, self.is_master) process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) - self._test_ddp_with_process_group(process_group) + gpus = gpus_for_rank(self.world_size)[self.rank] + self._test_ddp_with_process_group(process_group, gpus) + self._test_ddp_with_process_group(process_group, list(map(lambda i: torch.device('cuda:' + str(i)), gpus))) @skip_if_not_multigpu def test_dist_broadcast_coalesced(self): |