summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorPeter Goldsborough <psag@fb.com>2019-01-11 19:45:40 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-11 19:50:18 -0800
commita4c1aa4bc542c7ff6e600b67e9a0aeb233718514 (patch)
treef38d4f26b2558c00464a3365b1e99541ccae9cd4 /test
parente5266b4ba63182f9f86024965ca5330b80b8ad36 (diff)
downloadpytorch-a4c1aa4bc542c7ff6e600b67e9a0aeb233718514.tar.gz
pytorch-a4c1aa4bc542c7ff6e600b67e9a0aeb233718514.tar.bz2
pytorch-a4c1aa4bc542c7ff6e600b67e9a0aeb233718514.zip
Add the normalize transform to the core library (#15891)
Summary: Adds the `Normalize` transform to the core C++ frontend library. ebetica ezyang soumith Pull Request resolved: https://github.com/pytorch/pytorch/pull/15891 Differential Revision: D13642167 Pulled By: goldsborough fbshipit-source-id: 573428e626d6106cf2aadf3dc2e2aecb9a85efc3
Diffstat (limited to 'test')
-rw-r--r--test/cpp/api/dataloader.cpp74
1 files changed, 74 insertions, 0 deletions
diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp
index 461dfe5633..dbb6349383 100644
--- a/test/cpp/api/dataloader.cpp
+++ b/test/cpp/api/dataloader.cpp
@@ -450,6 +450,80 @@ TEST(DataTest, TensorLambdaWorksforAnyTargetType) {
ASSERT_EQ(batch[1].target, "2");
}
+struct DummyTensorDataset
+ : datasets::Dataset<DummyTensorDataset, Example<torch::Tensor, int>> {
+ Example<torch::Tensor, int> get(size_t index) override {
+ const auto channels = static_cast<int64_t>(index);
+ torch::Tensor tensor =
+ (channels > 0) ? torch::ones({channels, 4, 4}) : torch::ones({4, 4});
+ return {tensor, static_cast<int>(channels)};
+ }
+
+ torch::optional<size_t> size() const override {
+ return 100;
+ }
+};
+
+TEST(DataTest, NormalizeTransform) {
+ auto dataset = DummyTensorDataset().map(transforms::Normalize<int>(0.5, 0.1));
+
+ // Works for zero (one implicit) channels
+ std::vector<Example<torch::Tensor, int>> output = dataset.get_batch(0);
+ ASSERT_EQ(output.size(), 1);
+ // (1 - 0.5) / 0.1 = 5
+ ASSERT_TRUE(output[0].data.allclose(torch::ones({4, 4}) * 5))
+ << output[0].data;
+
+ // Works for one explicit channel
+ output = dataset.get_batch(1);
+ ASSERT_EQ(output.size(), 1);
+ ASSERT_EQ(output[0].data.size(0), 1);
+ ASSERT_TRUE(output[0].data.allclose(torch::ones({1, 4, 4}) * 5))
+ << output[0].data;
+
+ // Works for two channels with different moments
+ dataset = DummyTensorDataset().map(
+ transforms::Normalize<int>({0.5, 1.5}, {0.1, 0.2}));
+ output = dataset.get_batch(2);
+ ASSERT_EQ(output.size(), 1);
+ ASSERT_EQ(output[0].data.size(0), 2);
+ ASSERT_TRUE(output[0]
+ .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
+ .allclose(torch::ones({1, 4, 4}) * 5))
+ << output[0].data;
+ ASSERT_TRUE(output[0]
+ .data.slice(/*dim=*/0, /*start=*/1)
+ .allclose(torch::ones({1, 4, 4}) * -2.5))
+ << output[0].data;
+
+ // Works for three channels with one moment value
+ dataset = DummyTensorDataset().map(transforms::Normalize<int>(1.5, 0.2));
+ output = dataset.get_batch(3);
+ ASSERT_EQ(output.size(), 1);
+ ASSERT_EQ(output[0].data.size(0), 3);
+ ASSERT_TRUE(output[0].data.allclose(torch::ones({3, 4, 4}) * -2.5))
+ << output[0].data;
+
+ // Works for three channels with different moments
+ dataset = DummyTensorDataset().map(
+ transforms::Normalize<int>({0.5, 1.5, -1.5}, {0.1, 0.2, 0.2}));
+ output = dataset.get_batch(3);
+ ASSERT_EQ(output.size(), 1);
+ ASSERT_EQ(output[0].data.size(0), 3);
+ ASSERT_TRUE(output[0]
+ .data.slice(/*dim=*/0, /*start=*/0, /*end=*/1)
+ .allclose(torch::ones({1, 4, 4}) * 5))
+ << output[0].data;
+ ASSERT_TRUE(output[0]
+ .data.slice(/*dim=*/0, /*start=*/1, /*end=*/2)
+ .allclose(torch::ones({1, 4, 4}) * -2.5))
+ << output[0].data;
+ ASSERT_TRUE(output[0]
+ .data.slice(/*dim=*/0, /*start=*/2)
+ .allclose(torch::ones({1, 4, 4}) * 12.5))
+ << output[0].data;
+}
+
struct UnCopyableDataset : public datasets::Dataset<UnCopyableDataset> {
UnCopyableDataset() = default;