diff options
author | 윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com> | 2019-09-17 16:58:51 +0900 |
---|---|---|
committer | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-17 16:58:51 +0900 |
commit | 421490ad05ba56798ce92b2e10dd60fbae168b2b (patch) | |
tree | 63f638c8484fabd2daedd8481dcde4ab2d9dbc8a /compiler/exo-tflite | |
parent | b3d041950855f5f3d2884012ab004394e430ff0d (diff) | |
download | nnfw-421490ad05ba56798ce92b2e10dd60fbae168b2b.tar.gz nnfw-421490ad05ba56798ce92b2e10dd60fbae168b2b.tar.bz2 nnfw-421490ad05ba56798ce92b2e10dd60fbae168b2b.zip |
[exo-tflite] Enhancing TestGraph.h (#7469)
* [exo-tflite] Enhancing TestGraph.h
1) Adds more operations. 2) changed the symantics of append() with no arg.
Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
* typo, adds better comment
Diffstat (limited to 'compiler/exo-tflite')
-rw-r--r-- | compiler/exo-tflite/src/TestGraph.h | 35 |
1 files changed, 31 insertions, 4 deletions
diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h index b87545580..79d4d4b0e 100644 --- a/compiler/exo-tflite/src/TestGraph.h +++ b/compiler/exo-tflite/src/TestGraph.h @@ -76,7 +76,7 @@ private: // setInput of TFL nodes template <> void PullPushGraph<locoex::TFLAveragePool2D>::setInput() { middle_node->value(pull); } -struct TestGraph +class TestGraph { public: std::unique_ptr<loco::Graph> g; @@ -105,16 +105,17 @@ public: _next_input = pull; } - template <class T> T *append() // input will be previously appended node + /// @brief Creates node with NO arg and appends it to graph + template <class T> T *append() { auto node = g->nodes()->create<T>(); - setInput(node, _next_input); _next_input = node; return node; } - template <class T> T *append(loco::Node *arg1) // create [ arg1 - T ] subgraph + /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph + template <class T> T *append(loco::Node *arg1) { auto node = g->nodes()->create<T>(); setInput(node, arg1); @@ -123,6 +124,7 @@ public: return node; } + /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph template <class T> T *append(loco::Node *arg1, loco::Node *arg2) { auto node = g->nodes()->create<T>(); @@ -141,6 +143,13 @@ private: // arity 1 void setInput(loco::Node *node, loco::Node *) { assert(false && "NYI"); }; + void setInput(loco::AvgPool2D *node, loco::Node *input) { node->ifm(input); } + void setInput(loco::BiasDecode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::BiasEncode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::FeatureDecode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::FeatureEncode *node, loco::Node *input) { node->input(input); }; + void setInput(loco::MaxPool2D *node, loco::Node *input) { node->ifm(input); } + void setInput(loco::Push *node, loco::Node *input) { node->from(input); }; void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); }; void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); }; void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); }; @@ -158,6 +167,24 @@ private: node->rhs(arg2); }; + void setInput(loco::FeatureBiasAdd *node, loco::Node *arg1, loco::Node *arg2) + { + node->value(arg1); + node->bias(arg2); + }; + + void setInput(locoex::TFLAdd *node, loco::Node *arg1, loco::Node *arg2) + { + node->x(arg1); + node->y(arg2); + }; + + void setInput(locoex::TFLMul *node, loco::Node *arg1, loco::Node *arg2) + { + node->x(arg1); + node->y(arg2); + }; + private: loco::Node *_next_input; }; |