summaryrefslogtreecommitdiff
path: root/compiler/exo-tflite
diff options
context:
space:
mode:
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>2019-09-11 14:04:37 +0900
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>2019-09-11 14:04:37 +0900
commit1bac9f7bca3af605431a5f2b9fc70dc870789016 (patch)
tree4f1f8f9adca1db0246da0ee1a2fffa1b6772a3f0 /compiler/exo-tflite
parent4fd603630e3987a9f657bacac917e0a2e6c9c75e (diff)
downloadnnfw-1bac9f7bca3af605431a5f2b9fc70dc870789016.tar.gz
nnfw-1bac9f7bca3af605431a5f2b9fc70dc870789016.tar.bz2
nnfw-1bac9f7bca3af605431a5f2b9fc70dc870789016.zip
[exo-tflite] shape inference for TFLAveragePool2D (#7345)
Adding shape inference for TFLAveragePool2D and two test cases. Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
Diffstat (limited to 'compiler/exo-tflite')
-rw-r--r--compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp51
-rw-r--r--compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp90
-rw-r--r--compiler/exo-tflite/src/TestGraph.h79
3 files changed, 219 insertions, 1 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
index e3ba39caf..0f8218515 100644
--- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
+++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
@@ -22,11 +22,57 @@
#include "ShapeInference.h"
+#include "Check.h"
+
#include <cassert>
namespace
{
+// Call this for TFLAvgPool2D and TFLMaxPool2D only
+template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
+{
+ EXO_ASSERT(loco::shape_known(node->value()), "Shape must be known");
+
+ auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
+
+ uint32_t input_height = ifm_shape.dim(1).value();
+ uint32_t input_width = ifm_shape.dim(2).value();
+ uint32_t stride_height = node->stride()->h();
+ uint32_t stride_width = node->stride()->w();
+ uint32_t window_height = node->filter()->h();
+ uint32_t window_width = node->filter()->w();
+ uint32_t dilation_height = 1; // dilation for TFLAvgPool2D and TFLMaxPool2D is 1
+ uint32_t dilation_width = 1;
+ uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
+ uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
+
+ uint32_t output_height;
+ uint32_t output_width;
+
+ if (node->padding() == locoex::Padding::VALID)
+ {
+ output_height = (input_height + stride_height - effective_window_height) / stride_height;
+ output_width = (input_width + stride_width - effective_window_width) / stride_width;
+ }
+ else if (node->padding() == locoex::Padding::SAME)
+ {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ }
+ else
+ EXO_ASSERT(false, "Wrong padding type");
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(4);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = output_height;
+ ofm_shape.dim(2) = output_width;
+ ofm_shape.dim(3) = ifm_shape.dim(3);
+
+ return loco::NodeShape{ofm_shape};
+}
+
/**
* @brief Class to infer the shape of TFLNode
*
@@ -52,7 +98,10 @@ public:
// TFLAdd
- // TFLAveragePool2D
+ loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final
+ {
+ return infer_pool_2d_shape(node);
+ }
// TODO TFLConcatenation
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
index c5c375d58..eca72e618 100644
--- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
+++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#include "TestGraph.h"
+
#include "Dialect/IR/TFLNodes.h"
#include "Dialect/IR/TFLDialect.h"
#include "Dialect/Service/TFLShapeInferenceRule.h"
@@ -81,3 +83,91 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu)
ASSERT_EQ(shape.dim(1), 4);
}
}
+
+// based on the case shown in
+// https://www.corvil.com/kb/what-is-the-difference-between-same-and-valid-padding-in-tf-nn-max-pool-of-tensorflow
+TEST(TFLShapeInferenceRuleTest, avgpool2d_valid)
+{
+ exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
+ auto pull = test_graph.pull;
+ {
+ pull->shape({1, 4, 3, 1});
+ }
+ auto tfl_node = test_graph.middle_node;
+ {
+ tfl_node->filter()->h(2);
+ tfl_node->filter()->w(2);
+ tfl_node->stride()->h(2);
+ tfl_node->stride()->w(2);
+ tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ tfl_node->padding(locoex::Padding::VALID);
+ }
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ // shape inference
+ locoex::TFLShapeInferenceRule tfl_rule;
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(test_graph.g.get());
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 4);
+ ASSERT_EQ(shape.dim(0).value(), 1);
+ ASSERT_EQ(shape.dim(1).value(), 2);
+ ASSERT_EQ(shape.dim(2).value(), 1);
+ ASSERT_EQ(shape.dim(3).value(), 1);
+ }
+}
+
+TEST(TFLShapeInferenceRuleTest, avgpool2d_same)
+{
+ exo::test::PullPushGraph<locoex::TFLAveragePool2D> test_graph;
+ auto pull = test_graph.pull;
+ {
+ pull->shape({1, 4, 3, 1});
+ }
+
+ auto tfl_node = test_graph.middle_node;
+ {
+ tfl_node->filter()->h(2);
+ tfl_node->filter()->w(2);
+ tfl_node->stride()->h(2);
+ tfl_node->stride()->w(2);
+ tfl_node->fusedActivationFunction(locoex::FusedActFunc::NONE);
+ tfl_node->padding(locoex::Padding::SAME);
+ }
+
+ ASSERT_FALSE(loco::shape_known(tfl_node));
+
+ // shape inference
+ locoex::TFLShapeInferenceRule tfl_rule;
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(locoex::TFLDialect::get(), &tfl_rule);
+
+ loco::apply(&rules).to(test_graph.g.get());
+
+ // Verify
+ {
+ ASSERT_TRUE(loco::shape_known(tfl_node));
+ ASSERT_EQ(loco::shape_get(tfl_node).domain(), loco::Domain::Tensor);
+
+ auto shape = loco::shape_get(tfl_node).as<loco::TensorShape>();
+ ASSERT_EQ(shape.rank(), 4);
+ ASSERT_EQ(shape.dim(0).value(), 1);
+ ASSERT_EQ(shape.dim(1).value(), 2);
+ ASSERT_EQ(shape.dim(2).value(), 2);
+ ASSERT_EQ(shape.dim(3).value(), 1);
+ }
+}
diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h
new file mode 100644
index 000000000..11903d304
--- /dev/null
+++ b/compiler/exo-tflite/src/TestGraph.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_GRAPH_H__
+#define __TEST_GRAPH_H__
+
+#include "Dialect/IR/TFLNodes.h"
+
+#include <loco.h>
+
+#include <stdex/Memory.h>
+
+namespace exo
+{
+namespace test
+{
+
+// graph to build [Pull - some node of type T - Push]
+template <typename T> struct PullPushGraph
+{
+public:
+ std::unique_ptr<loco::Graph> g;
+ loco::Pull *pull;
+ loco::Push *push;
+ T *middle_node;
+
+ PullPushGraph()
+ {
+ // g = Pull - T - Push
+ g = loco::make_graph();
+
+ pull = g->nodes()->create<loco::Pull>();
+
+ middle_node = g->nodes()->create<T>();
+ {
+ setInput();
+ }
+
+ push = g->nodes()->create<loco::Push>();
+ {
+ push->from(middle_node);
+ }
+
+ auto input = g->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+ }
+
+private:
+ void setInput(); // set the input of T
+};
+
+// setInput of TFL nodes
+template <> void PullPushGraph<locoex::TFLAveragePool2D>::setInput() { middle_node->value(pull); }
+
+} // namespace test
+} // namespace exo
+
+#endif // __TEST_GRAPH_H__