summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>2019-09-17 16:58:51 +0900
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 16:58:51 +0900
commit421490ad05ba56798ce92b2e10dd60fbae168b2b (patch)
tree63f638c8484fabd2daedd8481dcde4ab2d9dbc8a
parentb3d041950855f5f3d2884012ab004394e430ff0d (diff)
downloadnnfw-421490ad05ba56798ce92b2e10dd60fbae168b2b.tar.gz
nnfw-421490ad05ba56798ce92b2e10dd60fbae168b2b.tar.bz2
nnfw-421490ad05ba56798ce92b2e10dd60fbae168b2b.zip
[exo-tflite] Enhancing TestGraph.h (#7469)
* [exo-tflite] Enhancing TestGraph.h 1) Adds more operations. 2) changed the symantics of append() with no arg. Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com> * typo, adds better comment
-rw-r--r--compiler/exo-tflite/src/TestGraph.h35
1 files changed, 31 insertions, 4 deletions
diff --git a/compiler/exo-tflite/src/TestGraph.h b/compiler/exo-tflite/src/TestGraph.h
index b87545580..79d4d4b0e 100644
--- a/compiler/exo-tflite/src/TestGraph.h
+++ b/compiler/exo-tflite/src/TestGraph.h
@@ -76,7 +76,7 @@ private:
// setInput of TFL nodes
template <> void PullPushGraph<locoex::TFLAveragePool2D>::setInput() { middle_node->value(pull); }
-struct TestGraph
+class TestGraph
{
public:
std::unique_ptr<loco::Graph> g;
@@ -105,16 +105,17 @@ public:
_next_input = pull;
}
- template <class T> T *append() // input will be previously appended node
+ /// @brief Creates node with NO arg and appends it to graph
+ template <class T> T *append()
{
auto node = g->nodes()->create<T>();
- setInput(node, _next_input);
_next_input = node;
return node;
}
- template <class T> T *append(loco::Node *arg1) // create [ arg1 - T ] subgraph
+ /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
+ template <class T> T *append(loco::Node *arg1)
{
auto node = g->nodes()->create<T>();
setInput(node, arg1);
@@ -123,6 +124,7 @@ public:
return node;
}
+ /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
template <class T> T *append(loco::Node *arg1, loco::Node *arg2)
{
auto node = g->nodes()->create<T>();
@@ -141,6 +143,13 @@ private:
// arity 1
void setInput(loco::Node *node, loco::Node *) { assert(false && "NYI"); };
+ void setInput(loco::AvgPool2D *node, loco::Node *input) { node->ifm(input); }
+ void setInput(loco::BiasDecode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::BiasEncode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::FeatureDecode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::FeatureEncode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::MaxPool2D *node, loco::Node *input) { node->ifm(input); }
+ void setInput(loco::Push *node, loco::Node *input) { node->from(input); };
void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); };
void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); };
void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); };
@@ -158,6 +167,24 @@ private:
node->rhs(arg2);
};
+ void setInput(loco::FeatureBiasAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->value(arg1);
+ node->bias(arg2);
+ };
+
+ void setInput(locoex::TFLAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(locoex::TFLMul *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
private:
loco::Node *_next_input;
};