diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-16 17:16:36 +0900 |
---|---|---|
committer | 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com> | 2019-09-16 17:16:36 +0900 |
commit | ddf7a27b59078bde5b0bd3fb25c60cfc1e41c71a (patch) | |
tree | ac0dcf5574bf039c29f1384a19b895b54d528406 /compiler | |
parent | d610f092ec3d9f64162f69e4a6e93acdf687adf6 (diff) | |
download | nnfw-ddf7a27b59078bde5b0bd3fb25c60cfc1e41c71a.tar.gz nnfw-ddf7a27b59078bde5b0bd3fb25c60cfc1e41c71a.tar.bz2 nnfw-ddf7a27b59078bde5b0bd3fb25c60cfc1e41c71a.zip |
[exo-tflite] Converter for TensorBroadcast IR (#7432)
* [exo-tflite] Converter for TensorBroadcast IR
This will introduce TensorBroadcastConverter that resolves loco::TensorBroadcast IR that meets some certain condition
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* make compiler happy
* fix comments
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.cpp | 142 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.h | 40 |
2 files changed, 182 insertions, 0 deletions
diff --git a/compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.cpp b/compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.cpp new file mode 100644 index 000000000..1113c1189 --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "TensorBroadcastConverter.h" + +#include "Dialect/IR/TFLDialect.h" +#include "Dialect/IR/TFLNodeVisitor.h" + +#include <loco.h> +#include <loco/IR/CanonicalDialect.h> +#include <loco/IR/CanonicalNode.h> + +#include <set> + +namespace +{ + +template <class T> loco::TensorBroadcast *input_as_tbc(T *node) +{ + loco::TensorBroadcast *tbc = dynamic_cast<loco::TensorBroadcast *>(node->x()); + if (tbc == nullptr) + tbc = dynamic_cast<loco::TensorBroadcast *>(node->y()); + + return tbc; +} + +struct Collector final : public locoex::TFLNodeMutableVisitor<void> +{ + using NodePair = std::pair<loco::TensorBroadcast *, loco::Node *>; + + void visit(locoex::TFLAdd *node) final + { + if (auto tbc = input_as_tbc<locoex::TFLAdd>(node)) + { + NodePair pair(tbc, node); + candidates.insert(pair); + } + } + + // TODO ADD TFLDiv + + // TODO ADD TFLMul + + // TODO ADD TFLSub + + void visit(locoex::TFLNode *) final { return; } + + std::set<NodePair> candidates; +}; + +bool mapping_condition(Collector::NodePair &) +{ + // TODO fill condition + + return true; +} + +template <class T> void jump_connection(loco::TensorBroadcast *tbc, T *tflnode) +{ + if (tflnode->x() == tbc) + tflnode->x(tbc->input()); + else if (tflnode->y() == tbc) + tflnode->y(tbc->input()); + else + assert(false); + + tbc->input(nullptr); +} + +} // namespace + +namespace exo +{ + +/** + * @brief Disconnects loco::TensorBroadcast from the graph if following node + * is one of binary node: TFLAdd, TFLSub, TFLMul, TFLDiv + * and meets condition (TBA) + * @note + * Before: + * x --- TensorBroadcast --- TFLXXX --- output + * y ----------------------/ + * + * After: + * --- TensorBroadcast --- + * x --- TFLXXX --- output + * y --/ + */ +bool TensorBroadcastConverter::run(loco::Graph *graph) +{ + Collector collector; + + auto active_nodes = loco::active_nodes(loco::output_nodes(graph)); + + for (auto node : active_nodes) + { + if (node->dialect() == locoex::TFLDialect::get()) + { + auto tfl_node = dynamic_cast<locoex::TFLNode *>(node); + tfl_node->accept(&collector); + } + } + + bool changed = false; + + for (auto pair : collector.candidates) + { + if (mapping_condition(pair)) + { + loco::TensorBroadcast *tensorbroadcast = pair.first; + if (auto tfladd = dynamic_cast<locoex::TFLAdd *>(pair.second)) + { + jump_connection<locoex::TFLAdd>(tensorbroadcast, tfladd); + changed = true; + } + // TODO ADD TFLDiv + // TODO ADD TFLMul + // TODO ADD TFLSub + else + { + assert(false); + } + } + } + + return changed; +} + +} // namespace exo diff --git a/compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.h b/compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.h new file mode 100644 index 000000000..3cf79b0ba --- /dev/null +++ b/compiler/exo-tflite/src/Conversion/TensorBroadcastConverter.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __TENSOR_BROADCAST_CONVERTER_H__ +#define __TENSOR_BROADCAST_CONVERTER_H__ + +#include <loco.h> +#include <logo/Pass.h> + +namespace exo +{ + +/** + * @brief Pass to resolve TensorBroadcast IR + */ +class TensorBroadcastConverter : public logo::Pass +{ +public: + virtual const char *name(void) const { return "exo::TensorBroadcastConverter"; } + +public: + bool run(loco::Graph *graph); +}; + +} // namespace exo + +#endif //__TENSOR_BROADCAST_CONVERTER_H__ |