author     Xing Liu <xingl@fb.com>    2021-09-21 09:38:04 -0700
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>    2021-09-21 09:58:54 -0700
commit     600df80296cfca8d4ec388856c9d832da5b51214
tree       5f7e166a930fbfdc563aae86a9933c722297afc2
parent     7f6580a868026ea31c7320d8d0e60e90bbbaca94
[PT/ShardedTensor] Allow zero size local shard (#65007)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65007
Relax the shard size check in ShardMetadata to allow zero-size local shards.
When sharding a tensor across N ranks, some ranks may be allocated an empty shard. Since we assume SPMD, the ranks with empty shards still need to participate in all collectives, so ShardMetadata must accept them.
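For illustration, a minimal sketch (not part of this commit) of what the relaxed check permits, assuming ShardMetadata is exported from torch.distributed._sharding_spec as the test file below suggests:

    from torch.distributed._sharding_spec import ShardMetadata

    # After this change, a rank that holds no data can still publish valid
    # metadata for its (empty) local shard and join every collective.
    empty_shard = ShardMetadata(
        shard_offsets=[0, 0],
        shard_lengths=[0, 1],  # zero-length dim: rejected before this change
        placement="cuda:0",
    )

    # Negative lengths remain invalid and raise
    # ValueError('shard_lengths should be >= 0').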
Test Plan: Unit tests and CLI
Reviewed By: jiaqizhai, wanchaol
Differential Revision: D30926566
fbshipit-source-id: afa562c94ffa8f8d91d65ddb4c348156d871dc36
 test/distributed/_sharding_spec/test_sharding_spec.py | 4 ++--
 torch/distributed/_sharding_spec/_internals.py        | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/test/distributed/_sharding_spec/test_sharding_spec.py b/test/distributed/_sharding_spec/test_sharding_spec.py
index 409e7bd497..4709ff31cf 100644
--- a/test/distributed/_sharding_spec/test_sharding_spec.py
+++ b/test/distributed/_sharding_spec/test_sharding_spec.py
@@ -148,8 +148,8 @@ class TestShardingSpec(TestCase):
         with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
             ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0")
 
-        with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
-            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")
+        with self.assertRaisesRegex(ValueError, 'shard_lengths should be >= 0'):
+            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[-1, 1], placement="cuda:0")
 
         with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
             EnumerableShardingSpec([])
diff --git a/torch/distributed/_sharding_spec/_internals.py b/torch/distributed/_sharding_spec/_internals.py
index 568d11cfbb..afeeaeb1f6 100644
--- a/torch/distributed/_sharding_spec/_internals.py
+++ b/torch/distributed/_sharding_spec/_internals.py
@@ -40,8 +40,8 @@ class ShardMetadata(object):
         for i in range(len(self.shard_offsets)):
             if self.shard_offsets[i] < 0:
                 raise ValueError('shard_offsets should be >=0')
-            if self.shard_lengths[i] <= 0:
-                raise ValueError('shard_lengths should be > 0')
+            if self.shard_lengths[i] < 0:
+                raise ValueError('shard_lengths should be >= 0')