diff options
Diffstat (limited to 'compiler/tflchef/tflite/src/TFliteOpRegistry.h')
-rw-r--r-- | compiler/tflchef/tflite/src/TFliteOpRegistry.h | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/compiler/tflchef/tflite/src/TFliteOpRegistry.h b/compiler/tflchef/tflite/src/TFliteOpRegistry.h new file mode 100644 index 000000000..f0aed2113 --- /dev/null +++ b/compiler/tflchef/tflite/src/TFliteOpRegistry.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TFLITE_OP_REGISTRY_H__ +#define __TFLITE_OP_REGISTRY_H__ + +#include "TFliteOpChef.h" +#include "TFliteOpChefs.h" + +#include <memory> + +namespace tflchef +{ + +/** + * @brief tflchef operator registry + */ +class TFliteOpRegistry +{ +public: + /** + * @brief Returns registered TFliteOpChef pointer for BuiltinOperator or + * nullptr if not registered + */ + const TFliteOpChef *lookup(tflite::BuiltinOperator op) const + { + if (_tfliteop_map.find(op) == _tfliteop_map.end()) + return nullptr; + + return _tfliteop_map.at(op).get(); + } + + static TFliteOpRegistry &get() + { + static TFliteOpRegistry me; + return me; + } + +private: + TFliteOpRegistry() + { +#define REG_TFL_OP(OPCODE, CLASS) \ + _tfliteop_map[tflite::BuiltinOperator_##OPCODE] = std::make_unique<CLASS>() + + REG_TFL_OP(ABS, TFliteOpAbs); + REG_TFL_OP(ADD, TFliteOpAdd); + REG_TFL_OP(ARG_MAX, TFliteOpArgMax); + REG_TFL_OP(AVERAGE_POOL_2D, TFliteOpAveragePool2D); + REG_TFL_OP(BATCH_TO_SPACE_ND, TFliteOpBatchToSpaceND); + REG_TFL_OP(CONCATENATION, TFliteOpConcatenation); + REG_TFL_OP(CONV_2D, TFliteOpConv2D); + REG_TFL_OP(COS, TFliteOpCos); + REG_TFL_OP(DEPTHWISE_CONV_2D, TFliteOpDepthwiseConv2D); + REG_TFL_OP(DIV, TFliteOpDiv); + REG_TFL_OP(EQUAL, TFliteOpEqual); + REG_TFL_OP(EXP, TFliteOpExp); + REG_TFL_OP(FLOOR_DIV, TFliteOpFloorDiv); + REG_TFL_OP(FULLY_CONNECTED, TFliteOpFullyConnected); + REG_TFL_OP(LOGICAL_NOT, TFliteOpLogicalNot); + REG_TFL_OP(LOGICAL_OR, TFliteOpLogicalOr); + REG_TFL_OP(MAX_POOL_2D, TFliteOpMaxPool2D); + REG_TFL_OP(MEAN, TFliteOpMean); + REG_TFL_OP(PACK, TFliteOpPack); + REG_TFL_OP(PAD, TFliteOpPad); + REG_TFL_OP(RELU, TFliteOpReLU); + REG_TFL_OP(RELU6, TFliteOpReLU6); + REG_TFL_OP(RESHAPE, TFliteOpReshape); + REG_TFL_OP(RSQRT, TFliteOpRsqrt); + REG_TFL_OP(SOFTMAX, TFliteOpSoftmax); + REG_TFL_OP(SQRT, TFliteOpSqrt); + REG_TFL_OP(SUB, TFliteOpSub); + REG_TFL_OP(TANH, TFliteOpTanh); + REG_TFL_OP(TRANSPOSE, TFliteOpTranspose); + +#undef REG_TFL_OP + } + +private: + std::map<tflite::BuiltinOperator, std::unique_ptr<TFliteOpChef>> _tfliteop_map; +}; + +} // namespace tflchef + +#endif // __TFLITE_OP_REGISTRY_H__ |