/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Relu6.h"

#include "GraphBuilder.h"
#include "Knob.h"

#include "IR/TFRelu6.h"

// NOTE(review): the header name of this #include was missing from the reviewed copy.
// <stdex/Memory.h> is assumed because the code below calls stdex::make_unique - confirm.
#include <stdex/Memory.h>

#include <cassert>
#include <utility>

namespace
{

using namespace moco::tf;

/**
 * @brief Queued update that connects the input of a loco::ReLU6 node
 *
 * Stores the node together with the tensor name of its input. input() resolves that
 * name through the given SymbolTable and attaches the resulting node.
 */
class ReLU6GraphUpdate final : public GraphUpdate
{
public:
  /**
   * @param node  ReLU6 node whose input will be connected by input()
   * @param name  Tensor name of the input; taken by value and moved into the member so
   *              that rvalue arguments are not copied
   */
  ReLU6GraphUpdate(loco::ReLU6 *node, TensorName name) : _node(node), _name(std::move(name)) {}

  void input(const SymbolTable *) const override;

private:
  loco::ReLU6 *_node;
  const TensorName _name;
};

/**
 * @brief Queued update that connects the "features" input of a moco::tf::TFRelu6 node
 *
 * Stores the node together with the tensor name of its input. input() resolves that
 * name through the given SymbolTable and attaches the resulting node.
 */
class TFRelu6GraphUpdate final : public GraphUpdate
{
public:
  /**
   * @param node  TFRelu6 node whose features input will be connected by input()
   * @param name  Tensor name of the input; taken by value and moved into the member so
   *              that rvalue arguments are not copied
   */
  TFRelu6GraphUpdate(moco::tf::TFRelu6 *node, TensorName name)
    : _node(node), _name(std::move(name))
  {
  }

  void input(const SymbolTable *) const override;

private:
  moco::tf::TFRelu6 *_node;
  const TensorName _name;
};

void ReLU6GraphUpdate::input(const SymbolTable *table) const
{
  // Resolve the recorded tensor name and use the node as the ReLU6 input
  loco::Node *target = table->node(_name);
  _node->input(target);
}

void TFRelu6GraphUpdate::input(const SymbolTable *table) const
{
  // Resolve the recorded tensor name and use the node as the TFRelu6 "features" input
  loco::Node *target = table->node(_name);
  _node->features(target);
}

} // namespace

namespace moco
{
namespace tf
{

/**
 * @brief GraphBuilder for Relu6 node
 */
class Relu6GraphBuilder final : public Relu6GraphBuilderBase
{
public:
  void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
};

/**
 * @brief Check that a NodeDef can be imported as Relu6
 *
 * @return true only when the node has exactly one input
 */
bool Relu6GraphBuilderBase::validate(const tensorflow::NodeDef &node) const
{
  // ReLU6 node SHOULD have only one input
  return node.input_size() == 1;
}

/**
 * @brief Build the graph node for a Relu6 NodeDef by delegating to one of the two
 *        Relu6GraphBuilderImpl variants
 *
 * @param node     NodeDef to import
 * @param context  Import context; must not be null
 */
void Relu6GraphBuilder::build(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
{
  assert(context != nullptr);

  // NOTE(review): the template arguments of get<> and Relu6GraphBuilderImpl<> were missing
  // from the reviewed copy. Knob::ImportAsTFType and ImportTarget::{TensorFlow, Canonical}
  // are assumed here - confirm against Knob.h and Relu6.h.
  if (moco::tf::get<moco::tf::Knob::ImportAsTFType>())
  {
    Relu6GraphBuilderImpl<ImportTarget::TensorFlow> builder;
    builder.build(node, context);
  }
  else
  {
    Relu6GraphBuilderImpl<ImportTarget::Canonical> builder;
    builder.build(node, context);
  }
}

/**
 * @brief Import a Relu6 NodeDef as a loco::ReLU6 node
 *
 * Creates the node, registers it under output 0 of the NodeDef name, and queues an
 * update that connects its input (the first NodeDef input) later.
 *
 * NOTE(review): the specialization argument was missing from the reviewed copy;
 * ImportTarget::Canonical is assumed - confirm against Relu6.h.
 */
void Relu6GraphBuilderImpl<ImportTarget::Canonical>::build(const tensorflow::NodeDef &node,
                                                           GraphBuilderContext *context) const
{
  assert(context != nullptr);

  loco::Graph *graph = context->graph();
  SymbolTable *tensor_names = context->tensor_names();
  UpdateQueue *updates = context->updates();

  // Create a "ReLU6" node for Relu6
  auto relu6_node = graph->nodes()->create<loco::ReLU6>();

  // register string-name to node
  TensorName output_name(node.name(), 0);
  tensor_names->enroll(output_name, relu6_node);

  // Queue node input update
  auto update = stdex::make_unique<ReLU6GraphUpdate>(relu6_node, TensorName(node.input(0)));
  updates->enroll(std::move(update));
}

/**
 * @brief Import a Relu6 NodeDef as a moco::tf::TFRelu6 node
 *
 * Creates the node, registers it under output 0 of the NodeDef name, and queues an
 * update that connects its features input (the first NodeDef input) later.
 *
 * NOTE(review): the specialization argument was missing from the reviewed copy;
 * ImportTarget::TensorFlow is assumed - confirm against Relu6.h.
 */
void Relu6GraphBuilderImpl<ImportTarget::TensorFlow>::build(const tensorflow::NodeDef &node,
                                                            GraphBuilderContext *context) const
{
  assert(context != nullptr);

  loco::Graph *graph = context->graph();
  SymbolTable *tensor_names = context->tensor_names();
  UpdateQueue *updates = context->updates();

  // Create a "TFRelu6" node for Relu6
  auto relu6_node = graph->nodes()->create<moco::tf::TFRelu6>();

  // register string-name to node
  TensorName output_name(node.name(), 0);
  tensor_names->enroll(output_name, relu6_node);

  // Queue node input update
  auto update = stdex::make_unique<TFRelu6GraphUpdate>(relu6_node, TensorName(node.input(0)));
  updates->enroll(std::move(update));
}

} // namespace tf
} // namespace moco

#include "GraphBuilderRegistry.h"

REGISTER_OP_BUILDER(Relu6, Relu6GraphBuilder)