summaryrefslogtreecommitdiff
path: root/compiler/exo-tflite
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-16 13:09:57 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-16 13:09:57 +0900
commitc853c5cb52bef85d3de98978adb5b8fc29dada51 (patch)
tree4b78bbfc0a395c87bc8445287aa24d826195f8b4 /compiler/exo-tflite
parent0c37457161414a36a32fe5559b3e3aa7c78aae01 (diff)
downloadnnfw-c853c5cb52bef85d3de98978adb5b8fc29dada51.tar.gz
nnfw-c853c5cb52bef85d3de98978adb5b8fc29dada51.tar.bz2
nnfw-c853c5cb52bef85d3de98978adb5b8fc29dada51.zip
[exo-tflite] Introduce TFLMul (#7430)
This will introduce TFLMul IR for TensorFlow lite MUL node Signed-off-by: SaeHie Park <saehie.park@samsung.com>
Diffstat (limited to 'compiler/exo-tflite')
-rw-r--r--compiler/exo-tflite/src/Dialect/IR/TFLNodes.h16
-rw-r--r--compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst2
-rw-r--r--compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp11
-rw-r--r--compiler/exo-tflite/src/TFLFormattedGraph.cpp9
4 files changed, 34 insertions, 4 deletions
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
index 0d98b561f..f9ff2223f 100644
--- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
+++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h
@@ -200,7 +200,21 @@ private:
Filter _filter;
};
-// TODO TFLMul
+/**
+ * @brief MUL in TensorFlow Lite
+ */
+class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>
+{
+public:
+ TFLMul() = default;
+
+public:
+ loco::Node *x(void) const { return at(0)->node(); }
+ void x(loco::Node *node) { at(0)->node(node); }
+
+ loco::Node *y(void) const { return at(1)->node(); }
+ void y(loco::Node *node) { at(1)->node(node); }
+};
class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>
{
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
index 60059d06d..e87028258 100644
--- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
+++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.lst
@@ -12,7 +12,7 @@ TFL_NODE(AVERAGE_POOL_2D, locoex::TFLAveragePool2D)
// TODO TFLDepthwiseConv2D
// TODO TFLDiv
TFL_NODE(MAX_POOL_2D, locoex::TFLMaxPool2D)
-// TODO TFLMul
+TFL_NODE(MUL, locoex::TFLMul)
TFL_NODE(RELU, locoex::TFLRelu)
// TODO TFLRelu6
// TODO TFLReshape
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
index 542fcf3b7..efe23f150 100644
--- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
+++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp
@@ -44,7 +44,16 @@ TEST(TFLAddTest, constructor)
// TODO TFLMaxPool2D
-// TODO TFLMul
+TEST(TFLMulTest, constructor)
+{
+ locoex::TFLMul mul_node;
+
+ ASSERT_EQ(mul_node.dialect(), locoex::TFLDialect::get());
+ ASSERT_EQ(mul_node.opcode(), locoex::TFLOpcode::MUL);
+
+ ASSERT_EQ(mul_node.x(), nullptr);
+ ASSERT_EQ(mul_node.y(), nullptr);
+}
TEST(TFLReluTest, constructor)
{
diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp
index 31a6d99ec..d6489f28f 100644
--- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp
+++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp
@@ -124,7 +124,14 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLMaxPool2D *node, locop::Nod
return true;
}
-// TODO TFLMul
+bool TFLNodeSummaryBuilder::summary(const locoex::TFLMul *node, locop::NodeSummary &s) const
+{
+ s.opname("TFL.MUL");
+ s.args().append("x", tbl()->lookup(node->x()));
+ s.args().append("y", tbl()->lookup(node->y()));
+ s.state(locop::NodeSummary::State::Complete);
+ return true;
+}
bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSummary &s) const
{