summaryrefslogtreecommitdiff
path: root/compiler/luci/service/src/TestGraph.h
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/service/src/TestGraph.h')
-rw-r--r--compiler/luci/service/src/TestGraph.h315
1 files changed, 315 insertions, 0 deletions
diff --git a/compiler/luci/service/src/TestGraph.h b/compiler/luci/service/src/TestGraph.h
new file mode 100644
index 000000000..73562040f
--- /dev/null
+++ b/compiler/luci/service/src/TestGraph.h
@@ -0,0 +1,315 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __TEST_GRAPH_H__
+#define __TEST_GRAPH_H__
+
+#include <luci/IR/CircleNodes.h>
+#include "GraphBlock.h"
+
+#include <loco.h>
+
+#include <cassert>
+#include <memory>
+
+// TODO Change all Canonical nodes to Circle nodes
+
+namespace luci
+{
+namespace test
+{
+
+class TestGraph
+{
+public:
+ std::unique_ptr<loco::Graph> g;
+ loco::Pull *pull;
+ loco::Push *push;
+
+ TestGraph() // creates Pull and Push
+ {
+ g = loco::make_graph();
+
+ pull = g->nodes()->create<loco::Pull>();
+
+ push = g->nodes()->create<loco::Push>();
+
+ auto input = g->inputs()->create();
+ {
+ input->name("input");
+ loco::link(input, pull);
+ }
+ auto output = g->outputs()->create();
+ {
+ output->name("output");
+ loco::link(output, push);
+ }
+
+ _next_input = pull;
+ }
+
+ loco::Graph *graph() { return g.get(); }
+
+ /// @brief Creates node with NO arg and appends it to graph
+ template <class T> T *append()
+ {
+ auto node = g->nodes()->create<T>();
+ _next_input = node;
+
+ return node;
+ }
+
+ /// @brief Creates op T (arity=1) with arg1 as an input and appends it to graph
+ template <class T> T *append(loco::Node *arg1)
+ {
+ auto node = g->nodes()->create<T>();
+ setInput(node, arg1);
+ _next_input = node;
+
+ return node;
+ }
+
+ /// @brief Creates op T (arity=2) with arg1, arg2 as inputs and appends it to graph
+ template <class T> T *append(loco::Node *arg1, loco::Node *arg2)
+ {
+ auto node = g->nodes()->create<T>();
+ setInput(node, arg1, arg2);
+ _next_input = node;
+
+ return node;
+ }
+
+ /// @brief Creates op T (arity=3) with arg1, arg2, arg3 as inputs and appends it to graph
+ template <class T> T *append(loco::Node *arg1, loco::Node *arg2, loco::Node *arg3)
+ {
+ auto node = g->nodes()->create<T>();
+ setInput(node, arg1, arg2, arg3);
+ _next_input = node;
+
+ return node;
+ }
+
+ // push will get the last appended node
+ void complete() { push->from(_next_input); }
+
+ void complete(loco::Node *last_node) { push->from(last_node); }
+
+private:
+ // arity 1
+ void setInput(loco::Node *node, loco::Node *) { assert(false && "NYI"); };
+
+ void setInput(loco::AvgPool2D *node, loco::Node *input) { node->ifm(input); }
+ void setInput(loco::BiasDecode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::BiasEncode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::FeatureDecode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::FeatureEncode *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::MaxPool2D *node, loco::Node *input) { node->ifm(input); }
+ void setInput(loco::Push *node, loco::Node *input) { node->from(input); };
+ void setInput(loco::ReLU *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::ReLU6 *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::Tanh *node, loco::Node *input) { node->input(input); };
+ void setInput(loco::TensorTranspose *node, loco::Node *input) { node->input(input); };
+
+ void setInput(luci::CircleAveragePool2D *node, loco::Node *input) { node->value(input); };
+ void setInput(luci::CircleMaxPool2D *node, loco::Node *input) { node->value(input); };
+ void setInput(luci::CircleRelu *node, loco::Node *input) { node->features(input); };
+ void setInput(luci::CircleRelu6 *node, loco::Node *input) { node->features(input); };
+
+ // arity 2
+ void setInput(loco::Node *node, loco::Node *, loco::Node *) { assert(false && "NYI"); };
+
+ void setInput(loco::Conv2D *node, loco::Node *input, loco::Node *filter)
+ {
+ node->ifm(input);
+ node->ker(filter);
+ }
+
+ void setInput(loco::EltwiseAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->lhs(arg1);
+ node->rhs(arg2);
+ };
+
+ void setInput(loco::FeatureBiasAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->value(arg1);
+ node->bias(arg2);
+ };
+
+ void setInput(luci::CircleAdd *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(luci::CircleMul *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(luci::CircleSub *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->x(arg1);
+ node->y(arg2);
+ };
+
+ void setInput(luci::CircleTranspose *node, loco::Node *arg1, loco::Node *arg2)
+ {
+ node->a(arg1);
+ node->perm(arg2);
+ };
+
+ // arity 3
+ void setInput(loco::Node *node, loco::Node *, loco::Node *, loco::Node *)
+ {
+ assert(false && "NYI");
+ };
+
+ void setInput(luci::CircleConv2D *node, loco::Node *input, loco::Node *filter, loco::Node *bias)
+ {
+ node->input(input);
+ node->filter(filter);
+ node->bias(bias);
+ }
+
+private:
+ loco::Node *_next_input;
+};
+
+enum class ExampleGraphType
+{
+ FeatureBiasAdd,
+ ConstGen_ReLU,
+ FilterEncode_FilterDecode,
+ Transpose,
+
+ CircleTranspose,
+};
+
+template <ExampleGraphType T> class ExampleGraph;
+
+/**
+ * @brief Class to create the following:
+ *
+ * Pull - FeatureEncoder - FeatureBiasAdd - FeatureDecode - Push
+ * |
+ * ConstGen - BiasEncode --+
+ */
+template <> class ExampleGraph<ExampleGraphType::FeatureBiasAdd> : public TestGraph
+{
+public:
+ loco::FeatureEncode *fea_enc = nullptr;
+ loco::ConstGen *constgen = nullptr;
+ loco::BiasEncode *bias_enc = nullptr;
+ loco::FeatureBiasAdd *fea_bias_add = nullptr;
+ loco::FeatureDecode *fea_dec = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ fea_enc = luci::make_feature_encode<luci::FeatureLayout::NHWC>(pull);
+ constgen = append<loco::ConstGen>();
+ bias_enc = append<loco::BiasEncode>(constgen);
+ fea_bias_add = append<loco::FeatureBiasAdd>(fea_enc, bias_enc);
+ fea_dec = luci::make_feature_decode<luci::FeatureLayout::NHWC>(fea_bias_add);
+ complete(fea_dec);
+ }
+};
+
+/**
+ * @brief Class to creates the following:
+ *
+ * ConstGen -- ReLU -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::ConstGen_ReLU> : public TestGraph
+{
+public:
+ loco::ConstGen *constgen = nullptr;
+ loco::ReLU *relu = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ constgen = append<loco::ConstGen>();
+ relu = append<loco::ReLU>(constgen);
+ complete(relu);
+ }
+};
+
+/**
+ * @brief Class to creates the following:
+ *
+ * Pull -- Transpose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::Transpose> : public TestGraph
+{
+public:
+ loco::TensorTranspose *transpose = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ transpose = append<loco::TensorTranspose>(pull);
+ complete(transpose);
+ }
+};
+
+/**
+ * @brief Class to creates the following:
+ *
+ * Pull -- FilterEncode -- FilterDecode -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::FilterEncode_FilterDecode> : public TestGraph
+{
+public:
+ loco::FilterEncode *filterEncode = nullptr;
+ loco::FilterDecode *filterDecode = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ filterEncode = luci::make_filter_encode<luci::FilterLayout::HWIO>(pull); // from Tensorflow
+ filterDecode =
+ luci::make_filter_decode<luci::FilterLayout::OHWI>(filterEncode); // to Tensorflow Lite
+ complete(filterDecode);
+ }
+};
+
+/**
+ * @brief Class to create the following:
+ *
+ * Pull -- CircleTranspose -- Push
+ */
+template <> class ExampleGraph<ExampleGraphType::CircleTranspose> : public TestGraph
+{
+public:
+ loco::ConstGen *const_perm = nullptr;
+ luci::CircleTranspose *transpose_node = nullptr;
+
+public:
+ ExampleGraph()
+ {
+ const_perm = append<loco::ConstGen>();
+ transpose_node = append<luci::CircleTranspose>(pull, const_perm);
+ complete(transpose_node);
+ }
+};
+
+} // namespace test
+} // namespace luci
+
+#endif // __TEST_GRAPH_H__