diff options
author | 윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com> | 2019-09-10 16:00:41 +0900 |
---|---|---|
committer | 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com> | 2019-09-10 16:00:41 +0900 |
commit | f1cb3741daeca5694eacc46ace3ead0262f4d4f7 (patch) | |
tree | e76da223a41d25ffd9aa378c751438fe806ca344 /compiler/exo-tflite | |
parent | 0327f118702abaa238deaa815bbdd6c450e9413a (diff) | |
download | nnfw-f1cb3741daeca5694eacc46ace3ead0262f4d4f7.tar.gz nnfw-f1cb3741daeca5694eacc46ace3ead0262f4d4f7.tar.bz2 nnfw-f1cb3741daeca5694eacc46ace3ead0262f4d4f7.zip |
[exo-tflite] Adding TFLAveragePool2D into OperationExporter (#7315)
TFLAveragePool2D was added into OperationExporter.
Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
Diffstat (limited to 'compiler/exo-tflite')
-rw-r--r-- | compiler/exo-tflite/src/OperationExporter.cpp | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index e3f1cc439..fe6ba96ef 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -22,6 +22,8 @@ #include "Dialect/IR/TFLNodes.h" #include "Dialect/IR/TFLNodeVisitor.h" +#include "Check.h" + #include <loco/IR/CanonicalNode.h> #include <loco/IR/CanonicalNodeVisitor.h> #include <locoex/COpCall.h> @@ -46,7 +48,7 @@ public: public: // FOR TFLNodes void visit(locoex::TFLAdd *) final; - // TODO TFLAveragePool2D + void visit(locoex::TFLAveragePool2D *) final; // TODO TFLConcatenation // TODO TFLConv2D // TODO TFLDepthwiseConv2D @@ -109,7 +111,27 @@ void OperationExporter::visit(locoex::TFLAdd *node) gd._operators.push_back(op_offset); } -// TODO TFLAveragePool2D +void OperationExporter::visit(locoex::TFLAveragePool2D *node) +{ + EXO_ASSERT(node->padding() != locoex::Padding::UNDEFINED, "Padding is not set"); + EXO_ASSERT(node->fusedActivationFunction() != locoex::FusedActFunc::UNDEFINED, + "fused activation function is not set"); + + uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_AVERAGE_POOL_2D); + std::vector<int32_t> inputs_vec{get_tensor_index(node->value())}; + std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; + auto inputs = builder.CreateVector(inputs_vec); + auto outputs = builder.CreateVector(outputs_vec); + + tflite::Padding padding = + node->padding() == locoex::Padding::VALID ? tflite::Padding_VALID : tflite::Padding_SAME; + + auto options = CreatePool2DOptions(builder, padding, node->stride()->w(), node->stride()->h(), + node->filter()->w(), node->filter()->h()); + auto op_offset = CreateOperator(builder, op_idx, inputs, outputs, + tflite::BuiltinOptions_Pool2DOptions, options.Union()); + gd._operators.push_back(op_offset); +} // TODO TFLConcatenation |