summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 07:35:46 (GMT)
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 07:35:46 (GMT)
commitb3d041950855f5f3d2884012ab004394e430ff0d (patch)
treea2142e5061a25112b14c86cecb8d22778e53516d
parent7b599c4025b12331edb9c28fa2b06c2b5c4ed675 (diff)
downloadnnfw-b3d041950855f5f3d2884012ab004394e430ff0d.zip
nnfw-b3d041950855f5f3d2884012ab004394e430ff0d.tar.gz
nnfw-b3d041950855f5f3d2884012ab004394e430ff0d.tar.bz2
[exo-tflite] Introduce Div and Sub (#7514)
This will introduce IR for Div and Sub Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r--compiler/exo-tflite/src/Dialect/IR/TFLNodes.h32
-rw-r--r--compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst4
-rw-r--r--compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp22
-rw-r--r--compiler/exo-tflite/src/TFLFormattedGraph.cpp16
4 files changed, 66 insertions, 8 deletions
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
index f9ff222..42de748 100644
--- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
+++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
@@ -166,7 +166,21 @@ private:
// TODO TFLDepthwiseConv2D
-// TODO TFLDiv
+/**
+ * @brief DIV in TensorFlow Lite
+ */
+class TFLDiv final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::DIV>>
+{
+public:
+ TFLDiv() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
/**
* @brief MAX_POOL_2D in TensorFlow Lite
@@ -234,7 +248,21 @@ public:
// TODO TFLSqrt
-// TODO TFLSub
+/**
+ * @brief SUB in TensorFlow Lite
+ */
+class TFLSub final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SUB>>
+{
+public:
+ TFLSub() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
// TODO TFLTanh
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
index e870282..a77ac80 100644
--- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
+++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
@@ -10,7 +10,7 @@ TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D)
// TODO TFLConcatenation
// TODO TFLConv2D
// TODO TFLDepthwiseConv2D
-// TODO TFLDiv
+TFL_NODE(DIV, locoex::TFLDiv)
TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D)
TFL_NODE(MUL, locoex::TFLMul)
TFL_NODE(RELU, locoex::TFLRelu)
@@ -18,6 +18,6 @@ TFL_NODE(RELU, locoex::TFLRelu)
// TODO TFLReshape
// TODO TFLSoftmax
// TODO TFLSqrt
-// TODO TFLSub
+TFL_NODE(SUB, locoex::TFLSub)
// TODO TFLTanh
// TODO TFLTranspose
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
index efe23f1..b52b452 100644
--- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
+++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
@@ -40,7 +40,16 @@ TEST(TFLAddTest, constructor)
// TODO TFLDepthwiseConv2D
-// TODO TFLDiv
+TEST(TFLDivTest, constructor)
+{
+ locoex::TFLDiv div_node;
+
+ ASSERT_EQ(div_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(div_node.opcode(), locoex::TFLOpcode::DIV);
+
+ ASSERT_EQ(div_node.x(), nullptr);
+ ASSERT_EQ(div_node.y(), nullptr);
+}
// TODO TFLMaxPool2D
@@ -73,7 +82,16 @@ TEST(TFLReluTest, constructor)
// TODO TFLSqrt
-// TODO TFLSub
+TEST(TFLSubTest, constructor)
+{
+ locoex::TFLSub sub_node;
+
+ ASSERT_EQ(sub_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(sub_node.opcode(), locoex::TFLOpcode::SUB);
+
+ ASSERT_EQ(sub_node.x(), nullptr);
+ ASSERT_EQ(sub_node.y(), nullptr);
+}
// TODO TFLTanh
diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp
index 8b1cfde..52ce0ff 100644
--- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp
+++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp
@@ -131,7 +131,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node,
// TODO TFLDepthwiseConv2D
-// TODO TFLDiv
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLDiv *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::NodeSummary &s) const
{
@@ -163,7 +169,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSumm
// TODO TFLSqrt
-// TODO TFLSub
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const
+{
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
// TODO TFLTanh