diff options
author | Peter Goldsborough <psag@fb.com> | 2019-01-11 19:45:40 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-11 19:50:18 -0800 |
commit | a4c1aa4bc542c7ff6e600b67e9a0aeb233718514 (patch) | |
tree | f38d4f26b2558c00464a3365b1e99541ccae9cd4 /test | |
parent | e5266b4ba63182f9f86024965ca5330b80b8ad36 (diff) | |
download | pytorch-a4c1aa4bc542c7ff6e600b67e9a0aeb233718514.tar.gz pytorch-a4c1aa4bc542c7ff6e600b67e9a0aeb233718514.tar.bz2 pytorch-a4c1aa4bc542c7ff6e600b67e9a0aeb233718514.zip |
Add the normalize transform to the core library (#15891)
Summary:
Adds the `Normalize` transform to the core C++ frontend library.
ebetica ezyang soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15891
Differential Revision: D13642167
Pulled By: goldsborough
fbshipit-source-id: 573428e626d6106cf2aadf3dc2e2aecb9a85efc3
Diffstat (limited to 'test')
-rw-r--r-- | test/cpp/api/dataloader.cpp | 74 |
1 files changed, 74 insertions, 0 deletions
diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp index 461dfe5633..dbb6349383 100644 --- a/test/cpp/api/dataloader.cpp +++ b/test/cpp/api/dataloader.cpp @@ -450,6 +450,80 @@ TEST(DataTest, TensorLambdaWorksforAnyTargetType) { ASSERT_EQ(batch[1].target, "2"); } +struct DummyTensorDataset + : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> { + Example<torch::Tensor, int> get(size_t index) override { + const auto channels = static_cast<int64_t>(index); + torch::Tensor tensor = + (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4}); + return {tensor, static_cast<int>(channels)}; + } + + torch::optional<size_t> size() const override { + return 100; + } +}; + +TEST(DataTest, NormalizeTransform) { + auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1)); + + // Works for zero (one implicit) channels + std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0); + ASSERT_EQ(output.size(), 1); + // (1 - 0.5) / 0.1 = 5 + ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5)) + << output[0].data; + + // Works for one explicit channel + output = dataset.get_batch(1); + ASSERT_EQ(output.size(), 1); + ASSERT_EQ(output[0].data.size(0), 1); + ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5)) + << output[0].data; + + // Works for two channels with different moments + dataset = DummyTensorDataset().map( + transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2})); + output = dataset.get_batch(2); + ASSERT_EQ(output.size(), 1); + ASSERT_EQ(output[0].data.size(0), 2); + ASSERT_TRUE(output[0] + .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1) + .allclose(torch::ones({1, 4, 4}) * 5)) + << output[0].data; + ASSERT_TRUE(output[0] + .data.slice(/*dim=*/0, /*start=*/1) + .allclose(torch::ones({1, 4, 4}) * -2.5)) + << output[0].data; + + // Works for three channels with one moment value + dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2)); + output = dataset.get_batch(3); + ASSERT_EQ(output.size(), 1); + ASSERT_EQ(output[0].data.size(0), 3); + ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5)) + << output[0].data; + + // Works for three channels with different moments + dataset = DummyTensorDataset().map( + transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2})); + output = dataset.get_batch(3); + ASSERT_EQ(output.size(), 1); + ASSERT_EQ(output[0].data.size(0), 3); + ASSERT_TRUE(output[0] + .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1) + .allclose(torch::ones({1, 4, 4}) * 5)) + << output[0].data; + ASSERT_TRUE(output[0] + .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2) + .allclose(torch::ones({1, 4, 4}) * -2.5)) + << output[0].data; + ASSERT_TRUE(output[0] + .data.slice(/*dim=*/0, /*start=*/2) + .allclose(torch::ones({1, 4, 4}) * 12.5)) + << output[0].data; +} + struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> { UnCopyableDataset() = default; |