diff options
author | 윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com> | 2019-09-02 13:16:38 +0900 |
---|---|---|
committer | 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com> | 2019-09-02 13:16:38 +0900 |
commit | c04ed020ac2c9440b9d20f4c49b9f0229b07c236 (patch) | |
tree | 39abf4752fff7d3eaae407ff22b0fcea2dc60bb9 | |
parent | 6618fdedab4e0a9fb498985fba041f1f6de39fd1 (diff) | |
download | nnfw-c04ed020ac2c9440b9d20f4c49b9f0229b07c236.tar.gz nnfw-c04ed020ac2c9440b9d20f4c49b9f0229b07c236.tar.bz2 nnfw-c04ed020ac2c9440b9d20f4c49b9f0229b07c236.zip |
[exo-tflite] Change the name of input of TFLRelu (#7081)
* [exo-tflite] Change the name of input of TFLRelu
Input name follows TensorFlow input naming. (input -> features)
Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* modify input -> features
6 files changed, 7 insertions, 7 deletions
diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h index de3b0a5e1..31184a91c 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.h @@ -164,8 +164,8 @@ public: TFLRelu() = default; public: - loco::Node *input(void) const { return at(0)->node(); } - void input(loco::Node *node) { at(0)->node(node); } + loco::Node *features(void) const { return at(0)->node(); } + void features(loco::Node *node) { at(0)->node(node); } }; // TODO TFLRelu6 diff --git a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp index 90c98f3a5..6f0abc837 100644 --- a/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp +++ b/compiler/exo-tflite/src/Dialect/IR/TFLNodes.test.cpp @@ -44,7 +44,7 @@ TEST(TFLReluTest, constructor) ASSERT_EQ(relu_node.dialect(), locoex::TFLDialect::get()); ASSERT_EQ(relu_node.opcode(), locoex::TFLOpcode::RELU); - ASSERT_EQ(relu_node.input(), nullptr); + ASSERT_EQ(relu_node.features(), nullptr); } // TODO TFLRelu6 diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp index de2e966d2..3bb24f959 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.test.cpp @@ -33,7 +33,7 @@ TEST(TFLShapeInferenceRuleTest, minimal_with_TFLRelu) auto pull_node = g->nodes()->create<loco::Pull>(); auto tfl_node = g->nodes()->create<locoex::TFLRelu>(); - tfl_node->input(pull_node); + tfl_node->features(pull_node); auto push_node = g->nodes()->create<loco::Push>(); push_node->from(tfl_node); diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp index 180dbe201..0190eeb1f 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInference.test.cpp @@ -32,7 +32,7 @@ TEST(TFLTypeInferenceRuleTest, minimal_with_TFLRelu) auto pull_node = g->nodes()->create<loco::Pull>(); auto tfl_node = g->nodes()->create<locoex::TFLRelu>(); - tfl_node->input(pull_node); + tfl_node->features(pull_node); auto push_node = g->nodes()->create<loco::Push>(); push_node->from(tfl_node); diff --git a/compiler/exo-tflite/src/OperationExporter.cpp b/compiler/exo-tflite/src/OperationExporter.cpp index 91fa2f8a6..a569052e3 100644 --- a/compiler/exo-tflite/src/OperationExporter.cpp +++ b/compiler/exo-tflite/src/OperationExporter.cpp @@ -115,7 +115,7 @@ private: void OperationExporter::visit(locoex::TFLRelu *node) { uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_RELU); - std::vector<int32_t> inputs_vec{get_tensor_index(node->input())}; + std::vector<int32_t> inputs_vec{get_tensor_index(node->features())}; std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))}; auto inputs = builder.CreateVector(inputs_vec); auto outputs = builder.CreateVector(outputs_vec); diff --git a/compiler/exo-tflite/src/TFLFormattedGraph.cpp b/compiler/exo-tflite/src/TFLFormattedGraph.cpp index 0f29aa38b..cb0a6f329 100644 --- a/compiler/exo-tflite/src/TFLFormattedGraph.cpp +++ b/compiler/exo-tflite/src/TFLFormattedGraph.cpp @@ -116,7 +116,7 @@ bool TFLNodeSummaryBuilder::summary(const locoex::TFLAveragePool2D *node, bool TFLNodeSummaryBuilder::summary(const locoex::TFLRelu *node, locop::NodeSummary &s) const { s.opname("TFL.RELU"); - s.args().append("input", tbl()->lookup(node->input())); + s.args().append("input", tbl()->lookup(node->features())); s.state(locop::NodeSummary::State::Complete); return true; } |