diff options
author | Erjia Guan <68879799+ejguan@users.noreply.github.com> | 2021-10-08 10:20:03 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-10-08 07:20:03 -0700 |
commit | a27906c250946ed2258e0e6a6ebb8e2a1e99eb60 (patch) | |
tree | a53076e1ba39e60272f06020ceba796c9e795f22 | |
parent | 49f52b6c074a049f7c301d4f4f2e78788b4cbbc1 (diff) | |
download | pytorch-a27906c250946ed2258e0e6a6ebb8e2a1e99eb60.tar.gz pytorch-a27906c250946ed2258e0e6a6ebb8e2a1e99eb60.tar.bz2 pytorch-a27906c250946ed2258e0e6a6ebb8e2a1e99eb60.zip |
Convert Sampler back to lazily construction (#63646) (#65926)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63646
Fixes #63609
Test Plan: Imported from OSS
Reviewed By: NivekT
Differential Revision: D30451774
Pulled By: ejguan
fbshipit-source-id: 550d77494326446d1a42b5da0559e0d384c47413
-rw-r--r-- | test/test_dataloader.py | 22 | ||||
-rw-r--r-- | torch/utils/data/sampler.py | 18 |
2 files changed, 33 insertions, 7 deletions
diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 07628f6559..6802cc8890 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1524,6 +1524,28 @@ except RuntimeError as e: ): self.assertEqual(list(fn()), list(fn())) + for sampler in ( + RandomSampler(self.dataset, num_samples=5, replacement=True), + RandomSampler(self.dataset, replacement=False), + WeightedRandomSampler(weights, num_samples=5, replacement=True), + WeightedRandomSampler(weights, num_samples=5, replacement=False), + SubsetRandomSampler(range(10)), + ): + torch.manual_seed(0) + l1 = list(sampler) + list(sampler) + + torch.manual_seed(0) + l2 = list(sampler) + list(sampler) + self.assertEqual(l1, l2) + + its = (iter(sampler), iter(sampler)) + ls = ([], []) + for idx in range(len(sampler)): + for i in range(2): + if idx == 0: + torch.manual_seed(0) + ls[i].append(next(its[i])) + self.assertEqual(ls[0], ls[1]) def _test_sampler(self, **kwargs): indices = range(2, 12) # using a regular iterable diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 7903347099..232302e53c 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -112,15 +112,18 @@ class RandomSampler(Sampler[int]): def __iter__(self) -> Iterator[int]: n = len(self.data_source) if self.generator is None: - self.generator = torch.Generator() - self.generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item())) + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator if self.replacement: for _ in range(self.num_samples // 32): - yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=self.generator).tolist() - yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=self.generator).tolist() + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() else: - yield from torch.randperm(n, generator=self.generator).tolist() + yield from torch.randperm(n, generator=generator).tolist() def __len__(self) -> int: return self.num_samples @@ -140,7 +143,8 @@ class SubsetRandomSampler(Sampler[int]): self.generator = generator def __iter__(self) -> Iterator[int]: - return (self.indices[i] for i in torch.randperm(len(self.indices), generator=self.generator)) + for i in torch.randperm(len(self.indices), generator=self.generator): + yield self.indices[i] def __len__(self) -> int: return len(self.indices) @@ -183,7 +187,7 @@ class WeightedRandomSampler(Sampler[int]): def __iter__(self) -> Iterator[int]: rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator) - return iter(rand_tensor.tolist()) + yield from iter(rand_tensor.tolist()) def __len__(self) -> int: return self.num_samples |