summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorErjia Guan <68879799+ejguan@users.noreply.github.com>2021-10-08 10:20:03 -0400
committerGitHub <noreply@github.com>2021-10-08 07:20:03 -0700
commita27906c250946ed2258e0e6a6ebb8e2a1e99eb60 (patch)
treea53076e1ba39e60272f06020ceba796c9e795f22
parent49f52b6c074a049f7c301d4f4f2e78788b4cbbc1 (diff)
downloadpytorch-a27906c250946ed2258e0e6a6ebb8e2a1e99eb60.tar.gz
pytorch-a27906c250946ed2258e0e6a6ebb8e2a1e99eb60.tar.bz2
pytorch-a27906c250946ed2258e0e6a6ebb8e2a1e99eb60.zip
Convert Sampler back to lazily construction (#63646) (#65926)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63646 Fixes #63609 Test Plan: Imported from OSS Reviewed By: NivekT Differential Revision: D30451774 Pulled By: ejguan fbshipit-source-id: 550d77494326446d1a42b5da0559e0d384c47413
-rw-r--r--test/test_dataloader.py22
-rw-r--r--torch/utils/data/sampler.py18
2 files changed, 33 insertions, 7 deletions
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 07628f6559..6802cc8890 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -1524,6 +1524,28 @@ except RuntimeError as e:
):
self.assertEqual(list(fn()), list(fn()))
+ for sampler in (
+ RandomSampler(self.dataset, num_samples=5, replacement=True),
+ RandomSampler(self.dataset, replacement=False),
+ WeightedRandomSampler(weights, num_samples=5, replacement=True),
+ WeightedRandomSampler(weights, num_samples=5, replacement=False),
+ SubsetRandomSampler(range(10)),
+ ):
+ torch.manual_seed(0)
+ l1 = list(sampler) + list(sampler)
+
+ torch.manual_seed(0)
+ l2 = list(sampler) + list(sampler)
+ self.assertEqual(l1, l2)
+
+ its = (iter(sampler), iter(sampler))
+ ls = ([], [])
+ for idx in range(len(sampler)):
+ for i in range(2):
+ if idx == 0:
+ torch.manual_seed(0)
+ ls[i].append(next(its[i]))
+ self.assertEqual(ls[0], ls[1])
def _test_sampler(self, **kwargs):
indices = range(2, 12) # using a regular iterable
diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py
index 7903347099..232302e53c 100644
--- a/torch/utils/data/sampler.py
+++ b/torch/utils/data/sampler.py
@@ -112,15 +112,18 @@ class RandomSampler(Sampler[int]):
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
- self.generator = torch.Generator()
- self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ generator = torch.Generator()
+ generator.manual_seed(seed)
+ else:
+ generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
- yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist()
- yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist()
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
- yield from torch.randperm(n, generator=self.generator).tolist()
+ yield from torch.randperm(n, generator=generator).tolist()
def __len__(self) -> int:
return self.num_samples
@@ -140,7 +143,8 @@ class SubsetRandomSampler(Sampler[int]):
self.generator = generator
def __iter__(self) -> Iterator[int]:
- return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator))
+ for i in torch.randperm(len(self.indices), generator=self.generator):
+ yield self.indices[i]
def __len__(self) -> int:
return len(self.indices)
@@ -183,7 +187,7 @@ class WeightedRandomSampler(Sampler[int]):
def __iter__(self) -> Iterator[int]:
rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
- return iter(rand_tensor.tolist())
+ yield from iter(rand_tensor.tolist())
def __len__(self) -> int:
return self.num_samples