diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-17 16:35:46 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-17 16:35:46 +0900 |
commit | b3d041950855f5f3d2884012ab004394e430ff0d (patch) | |
tree | a2142e5061a25112b14c86cecb8d22778e53516d | |
parent | 7b599c4025b12331edb9c28fa2b06c2b5c4ed675 (diff) | |
download | nnfw-b3d041950855f5f3d2884012ab004394e430ff0d.tar.gz nnfw-b3d041950855f5f3d2884012ab004394e430ff0d.tar.bz2 nnfw-b3d041950855f5f3d2884012ab004394e430ff0d.zip |
[exo-tflite] Introduce Div and Sub (#7514)
This will introduce IR for Div and Sub
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.h | 32 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst | 4 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp | 22 | ||||
-rw-r--r-- | compiler/exo-tflite/src/TFLFormattedGraph.cpp | 16 |
4 files changed, 66 insertions, 8 deletions
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index f9ff2223f..42de74806 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -166,7 +166,21 @@ private: // TODO TFLDepthwiseConv2D -// TODO TFLDiv +/** + * @brief DIV in TensorFlow Lite + */ +class TFLDiv final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::DIV>> +{ +public: + TFLDiv() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; /** * @brief MAX_POOL_2D in TensorFlow Lite @@ -234,7 +248,21 @@ public: // TODO TFLSqrt -// TODO TFLSub +/** + * @brief SUB in TensorFlow Lite + */ +class TFLSub final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SUB>> +{ +public: + TFLSub() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; // TODO TFLTanh diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst index e87028258..a77ac80eb 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst @@ -10,7 +10,7 @@ TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D) // TODO TFLConcatenation // TODO TFLConv2D // TODO TFLDepthwiseConv2D -// TODO TFLDiv +TFL_NODE(DIV, locoex::TFLDiv) TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D) TFL_NODE(MUL, locoex::TFLMul) TFL_NODE(RELU, locoex::TFLRelu) @@ -18,6 +18,6 @@ TFL_NODE(RELU, locoex::TFLRelu) // TODO TFLReshape // TODO TFLSoftmax // TODO TFLSqrt -// TODO TFLSub +TFL_NODE(SUB, locoex::TFLSub) // TODO TFLTanh // TODO TFLTranspose diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp index efe23f150..b52b4525e 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp @@ -40,7 +40,16 @@ TEST(TFLAddTest, constructor) // TODO TFLDepthwiseConv2D -// TODO TFLDiv +TEST(TFLDivTest, constructor) +{ + locoex::TFLDiv div_node; + + ASSERT_EQ(div_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(div_node.opcode(), locoex::TFLOpcode::DIV); + + ASSERT_EQ(div_node.x(), nullptr); + ASSERT_EQ(div_node.y(), nullptr); +} // TODO TFLMaxPool2D @@ -73,7 +82,16 @@ TEST(TFLReluTest, constructor) // TODO TFLSqrt -// TODO TFLSub +TEST(TFLSubTest, constructor) +{ + locoex::TFLSub sub_node; + + ASSERT_EQ(sub_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(sub_node.opcode(), locoex::TFLOpcode::SUB); + + ASSERT_EQ(sub_node.x(), nullptr); + ASSERT_EQ(sub_node.y(), nullptr); +} // TODO TFLTanh diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 8b1cfde28..52ce0ff34 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -131,7 +131,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, // TODO TFLDepthwiseConv2D -// TODO TFLDiv +bool TFLNodeSummaryBuilder::summary(const locoex::TFLDiv *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::NodeSummary &s) const { @@ -163,7 +169,13 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSumm // TODO TFLSqrt -// TODO TFLSub +bool TFLNodeSummaryBuilder::summary(const locoex::TFLSub *node, locop::NodeSummary &s) const +{ + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} // TODO TFLTanh |