summaryrefslogtreecommitdiff
path: root/compiler/exo-tflite
diff options
context:
space:
mode:
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>2019-09-10 16:00:41 +0900
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>2019-09-10 16:00:41 +0900
commitf1cb3741daeca5694eacc46ace3ead0262f4d4f7 (patch)
treee76da223a41d25ffd9aa378c751438fe806ca344 /compiler/exo-tflite
parent0327f118702abaa238deaa815bbdd6c450e9413a (diff)
downloadnnfw-f1cb3741daeca5694eacc46ace3ead0262f4d4f7.tar.gz
nnfw-f1cb3741daeca5694eacc46ace3ead0262f4d4f7.tar.bz2
nnfw-f1cb3741daeca5694eacc46ace3ead0262f4d4f7.zip
[exo-tflite] Adding TFLAveragePool2D into OperationExporter (#7315)
TFLAveragePool2D was added into OperationExporter. Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
Diffstat (limited to 'compiler/exo-tflite')
-rw-r--r--compiler/exo-tflite/src/OperationExporter.cpp26
1 files changed, 24 insertions, 2 deletions
diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp
index e3f1cc439..fe6ba96ef 100644
--- a/compiler/exo-tflite/src/OperationExporter.cpp
+++ b/compiler/exo-tflite/src/OperationExporter.cpp
@@ -22,6 +22,8 @@
#include "Dialect/IR/TFLNodes.h"
#include "Dialect/IR/TFLNodeVisitor.h"
+#include "Check.h"
+
#include <loco/IR/CanonicalNode.h>
#include <loco/IR/CanonicalNodeVisitor.h>
#include <locoex/COpCall.h>
@@ -46,7 +48,7 @@ public:
public:
// FOR TFLNodes
void visit(locoex::TFLAdd *) final;
- // TODO TFLAveragePool2D
+ void visit(locoex::TFLAveragePool2D *) final;
// TODO TFLConcatenation
// TODO TFLConv2D
// TODO TFLDepthwiseConv2D
@@ -109,7 +111,27 @@ void OperationExporter::visit(locoex::TFLAdd *node)
gd._operators.push_back(op_offset);
}
-// TODO TFLAveragePool2D
+void OperationExporter::visit(locoex::TFLAveragePool2D *node)
+{
+ EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set");
+ EXO_ASSERT(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED,
+ "fused activation function is not set");
+
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D);
+ std::vector<int32_t> inputs_vec{get_tensor_index(node->value())};
+ std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+
+ tflite::Padding padding =
+ node->padding() == locoex::Padding::VALID ? tflite::Padding_VALID : tflite::Padding_SAME;
+
+ auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(),
+ node->filter()->w(), node->filter()->h());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_Pool2DOptions, options.Union());
+ gd._operators.push_back(op_offset);
+}
// TODO TFLConcatenation