author     Xing Liu <xingl@fb.com>  2021-09-21 09:38:04 -0700
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>  2021-09-21 09:58:54 -0700
commit     600df80296cfca8d4ec388856c9d832da5b51214 (patch)
tree       5f7e166a930fbfdc563aae86a9933c722297afc2
parent     7f6580a868026ea31c7320d8d0e60e90bbbaca94 (diff)
[PT/ShardedTensor] Allow zero size local shard (#65007)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65007

Relax the shard size check in ShardMetadata to allow zero-size local shards. When sharding a tensor on N ranks, some ranks may have an empty shard allocated. Since we assume SPMD, the ranks with an empty shard still need to participate in all collectives, so ShardMetadata needs to allow this.

Test Plan: Unit tests and CLI

Reviewed By: jiaqizhai, wanchaol

Differential Revision: D30926566

fbshipit-source-id: afa562c94ffa8f8d91d65ddb4c348156d871dc36
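For illustration, a minimal sketch of the scenario the summary describes, using only the private ShardMetadata and EnumerableShardingSpec API that appears in the diff below (the tensor shape, placements, and row-wise layout are hypothetical): a tensor sharded across four ranks where the last rank is assigned a zero-size local shard but still takes part in the SPMD program.

from torch.distributed._sharding_spec import EnumerableShardingSpec, ShardMetadata

# Row-wise sharding of a 3 x 4 tensor across 4 ranks: only three rows exist,
# so rank 3 is left with a zero-size shard.  Before this change the zero-length
# entry below would have been rejected with 'shard_lengths should be > 0'.
spec = EnumerableShardingSpec([
    ShardMetadata(shard_offsets=[0, 0], shard_lengths=[1, 4], placement="cuda:0"),
    ShardMetadata(shard_offsets=[1, 0], shard_lengths=[1, 4], placement="cuda:1"),
    ShardMetadata(shard_offsets=[2, 0], shard_lengths=[1, 4], placement="cuda:2"),
    # Empty shard: rank 3 holds no data but must still join every collective.
    ShardMetadata(shard_offsets=[3, 0], shard_lengths=[0, 4], placement="cuda:3"),
])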
-rw-r--r--  test/distributed/_sharding_spec/test_sharding_spec.py  4
-rw-r--r--  torch/distributed/_sharding_spec/_internals.py          4
2 files changed, 4 insertions, 4 deletions
diff --git a/test/distributed/_sharding_spec/test_sharding_spec.py b/test/distributed/_sharding_spec/test_sharding_spec.py
index 409e7bd497..4709ff31cf 100644
--- a/test/distributed/_sharding_spec/test_sharding_spec.py
+++ b/test/distributed/_sharding_spec/test_sharding_spec.py
@@ -148,8 +148,8 @@ class TestShardingSpec(TestCase):
         with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
             ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0")
-        with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
-            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")
+        with self.assertRaisesRegex(ValueError, 'shard_lengths should be >= 0'):
+            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[-1, 1], placement="cuda:0")
         with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
             EnumerableShardingSpec([])
diff --git a/torch/distributed/_sharding_spec/_internals.py b/torch/distributed/_sharding_spec/_internals.py
index 568d11cfbb..afeeaeb1f6 100644
--- a/torch/distributed/_sharding_spec/_internals.py
+++ b/torch/distributed/_sharding_spec/_internals.py
@@ -40,8 +40,8 @@ class ShardMetadata(object):
         for i in range(len(self.shard_offsets)):
             if self.shard_offsets[i] < 0:
                 raise ValueError('shard_offsets should be >=0')
-            if self.shard_lengths[i] <= 0:
-                raise ValueError('shard_lengths should be > 0')
+            if self.shard_lengths[i] < 0:
+                raise ValueError('shard_lengths should be >= 0')
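As a quick, hedged illustration of the relaxed check (same private API as in the diff): a zero-length shard now constructs cleanly, while negative lengths are still rejected.

from torch.distributed._sharding_spec import ShardMetadata

# Accepted after this change: a zero-size local shard.
ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")

# Still rejected: negative lengths.
try:
    ShardMetadata(shard_offsets=[0, 0], shard_lengths=[-1, 1], placement="cuda:0")
except ValueError as err:
    print(err)  # shard_lengths should be >= 0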