summaryrefslogtreecommitdiff
path: root/compiler/enco/frontend/tflite/src/Op/Activation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/enco/frontend/tflite/src/Op/Activation.cpp')
-rw-r--r--compiler/enco/frontend/tflite/src/Op/Activation.cpp96
1 files changed, 96 insertions, 0 deletions
diff --git a/compiler/enco/frontend/tflite/src/Op/Activation.cpp b/compiler/enco/frontend/tflite/src/Op/Activation.cpp
new file mode 100644
index 000000000..d6215ba34
--- /dev/null
+++ b/compiler/enco/frontend/tflite/src/Op/Activation.cpp
@@ -0,0 +1,96 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Activation.h"
+
+#include <IRBuilder.h>
+
+#include <coco/IR/Module.h>
+#include <coco/IR/FeatureLayouts.h>
+
+#include <nncc/core/ADT/tensor/Shape.h>
+#include <schema_generated.h>
+
+#include <cassert>
+
+using namespace nncc::core::ADT;
+
+namespace tflimport
+{
+
+coco::FeatureObject *build_activation(tflite::ActivationFunctionType act, coco::Block *block,
+ coco::FeatureObject *ifm)
+{
+ assert(ifm != nullptr && ifm->asFeature() != nullptr); // support feature only in this version
+
+ coco::Module *m = block->module();
+
+ auto shape = ifm->asFeature()->shape();
+
+ // creates output object
+ auto output_obj = m->entity()->object()->create<coco::FeatureObject>();
+ auto output_bag = m->entity()->bag()->create(num_elements(shape));
+ output_obj->bag(output_bag);
+ output_obj->layout(coco::FeatureLayouts::BHWC::create(shape));
+
+ switch (act)
+ {
+ case tflite::ActivationFunctionType::ActivationFunctionType_NONE:
+ {
+ // Create Copy Instr (copying from ifm to output_obj),
+ // redundant layer but optimized by backend
+ auto copy_ins = instr_builder(m).copy(output_obj, ifm);
+
+ // Append the instruction to the block
+ block->instr()->append(copy_ins);
+ break;
+ }
+ case tflite::ActivationFunctionType::ActivationFunctionType_RELU:
+ {
+ // Create Eval(output_obj, ReLU(load(ifm)))
+ auto load_op = op_builder(m).load(ifm).pop();
+ auto relu_op = m->entity()->op()->create<coco::ReLU>();
+ relu_op->arg(load_op);
+
+ auto eval_ins = instr_builder(m).eval(output_obj, relu_op);
+
+ // Append the instruction to the block
+ block->instr()->append(eval_ins);
+ break;
+ }
+ case tflite::ActivationFunctionType::ActivationFunctionType_RELU6:
+ {
+ // Create Eval(output_obj, ReLU6(load(ifm)))
+ auto load_op = op_builder(m).load(ifm).pop();
+ auto relu6_op = m->entity()->op()->create<coco::ReLU6>();
+ relu6_op->arg(load_op);
+
+ auto eval_ins = instr_builder(m).eval(output_obj, relu6_op);
+
+ // Append the instruction to the block
+ block->instr()->append(eval_ins);
+ break;
+ }
+ default:
+ // TODO support other fused activations
+ assert(false);
+ break;
+ }
+
+ return output_obj;
+}
+
+} // namespace tflimport