summaryrefslogtreecommitdiff
path: root/compiler/exo/src/TFLite/TFLTypeInference.test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/exo/src/TFLite/TFLTypeInference.test.cpp')
-rw-r--r--compiler/exo/src/TFLite/TFLTypeInference.test.cpp118
1 files changed, 118 insertions, 0 deletions
diff --git a/compiler/exo/src/TFLite/TFLTypeInference.test.cpp b/compiler/exo/src/TFLite/TFLTypeInference.test.cpp
new file mode 100644
index 000000000..0712f0a25
--- /dev/null
+++ b/compiler/exo/src/TFLite/TFLTypeInference.test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TFLTypeInference.h"
+#include "Pass/TypeInferencePass.h"
+
+#include <loco/IR/PermutingCodec.h>
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+using stdex::make_unique;
+
+namespace
+{
+
+class Sequential
+{
+public:
+ loco::Pull *addPullLayer(const loco::DataType &dtype = loco::DataType::FLOAT32)
+ {
+ loco::Pull *pull = _graph.nodes()->create<loco::Pull>();
+
+ auto graph_input = _graph.inputs()->create();
+ graph_input->name("graph_input");
+ loco::link(graph_input, pull);
+
+ pull->dtype(dtype);
+ setSampleShape(pull);
+
+ return last(pull);
+ }
+
+ loco::ReLU *addReLULayer(void)
+ {
+ loco::ReLU *relu = _graph.nodes()->create<loco::ReLU>();
+
+ relu->input(_last);
+
+ return last(relu);
+ }
+
+ loco::Push *addPushLayer(void)
+ {
+ loco::Push *push = _graph.nodes()->create<loco::Push>();
+
+ auto graph_output = _graph.outputs()->create();
+ graph_output->name("graph_output");
+ loco::link(graph_output, push);
+
+ push->from(_last);
+
+ return last(push);
+ }
+
+ loco::Graph *graph() { return &_graph; }
+
+private:
+ template <typename T> uint32_t setSampleShape(T *op)
+ {
+ const uint32_t n = 1;
+ const uint32_t h = 100;
+ const uint32_t w = 100;
+ const uint32_t c = 3;
+ op->rank(4);
+ op->dim(0).set(n);
+ op->dim(1).set(c);
+ op->dim(2).set(h);
+ op->dim(3).set(w);
+ return n * h * w * c;
+ }
+
+ template <typename T> T *last(T *node)
+ {
+ _last = node;
+ return node;
+ }
+
+private:
+ loco::Graph _graph;
+ loco::Node *_last;
+};
+
+struct TypeInferenceTest : public Sequential, public ::testing::Test
+{
+ virtual ~TypeInferenceTest() = default;
+};
+
+} // namespace
+
+// TypeInference SHOULD PROPAGATE type information properly
+TEST_F(TypeInferenceTest, Regression_0000)
+{
+ auto pull = addPullLayer(loco::DataType::S8);
+ auto relu = addReLULayer();
+ auto push = addPushLayer();
+
+ using namespace exo;
+
+ TypeInferencePass type_inf_pass;
+ type_inf_pass.run(graph());
+
+ ASSERT_EQ(TypeInference::get(relu), tflite::TensorType_INT8);
+ ASSERT_EQ(TypeInference::get(push), tflite::TensorType_INT8);
+}