diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-16 13:09:57 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-16 13:09:57 +0900 |
commit | c853c5cb52bef85d3de98978adb5b8fc29dada51 (patch) | |
tree | 4b78bbfc0a395c87bc8445287aa24d826195f8b4 /compiler/exo-tflite | |
parent | 0c37457161414a36a32fe5559b3e3aa7c78aae01 (diff) | |
download | nnfw-c853c5cb52bef85d3de98978adb5b8fc29dada51.tar.gz nnfw-c853c5cb52bef85d3de98978adb5b8fc29dada51.tar.bz2 nnfw-c853c5cb52bef85d3de98978adb5b8fc29dada51.zip |
[exo-tflite] Introduce TFLMul (#7430)
This will introduce TFLMul IR for TensorFlow lite MUL node
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
Diffstat (limited to 'compiler/exo-tflite')
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.h | 16 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst | 2 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp | 11 | ||||
-rw-r--r-- | compiler/exo-tflite/src/TFLFormattedGraph.cpp | 9 |
4 files changed, 34 insertions, 4 deletions
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index 0d98b561f..f9ff2223f 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -200,7 +200,21 @@ private: Filter _filter; }; -// TODO TFLMul +/** + * @brief MUL in TensorFlow Lite + */ +class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>> +{ +public: + TFLMul() = default; + +public: + loco::Node *x(void) const { return at(0)->node(); } + void x(loco::Node *node) { at(0)->node(node); } + + loco::Node *y(void) const { return at(1)->node(); } + void y(loco::Node *node) { at(1)->node(node); } +}; class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>> { diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst index 60059d06d..e87028258 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst @@ -12,7 +12,7 @@ TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D) // TODO TFLDepthwiseConv2D // TODO TFLDiv TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D) -// TODO TFLMul +TFL_NODE(MUL, locoex::TFLMul) TFL_NODE(RELU, locoex::TFLRelu) // TODO TFLRelu6 // TODO TFLReshape diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp index 542fcf3b7..efe23f150 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp @@ -44,7 +44,16 @@ TEST(TFLAddTest, constructor) // TODO TFLMaxPool2D -// TODO TFLMul +TEST(TFLMulTest, constructor) +{ + locoex::TFLMul mul_node; + + ASSERT_EQ(mul_node.dialect(), locoex::TFLDialect::get()); + ASSERT_EQ(mul_node.opcode(), locoex::TFLOpcode::MUL); + + ASSERT_EQ(mul_node.x(), nullptr); + ASSERT_EQ(mul_node.y(), nullptr); +} TEST(TFLReluTest, constructor) { diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 31a6d99ec..d6489f28f 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -124,7 +124,14 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::Nod return true; } -// TODO TFLMul +bool TFLNodeSummaryBuilder::summary(const locoex::TFLMul *node, locop::NodeSummary &s) const +{ + s.opname("TFL.MUL"); + s.args().append("x", tbl()->lookup(node->x())); + s.args().append("y", tbl()->lookup(node->y())); + s.state(locop::NodeSummary::State::Complete); + return true; +} bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSummary &s) const { |