summaryrefslogtreecommitdiff
path: root/compiler/loco/src
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/loco/src')
-rw-r--r--compiler/loco/src/ADT/AnnotatedItem.test.cpp75
-rw-r--r--compiler/loco/src/ADT/ObjectPool.cpp19
-rw-r--r--compiler/loco/src/IR/Algorithm.cpp121
-rw-r--r--compiler/loco/src/IR/Algorithm.test.cpp122
-rw-r--r--compiler/loco/src/IR/BiasShape.test.cpp26
-rw-r--r--compiler/loco/src/IR/CanonicalDialect.cpp67
-rw-r--r--compiler/loco/src/IR/CanonicalDialect.test.cpp29
-rw-r--r--compiler/loco/src/IR/CanonicalNode.cpp25
-rw-r--r--compiler/loco/src/IR/CanonicalNode.test.cpp72
-rw-r--r--compiler/loco/src/IR/CanonicalOpcode.cpp19
-rw-r--r--compiler/loco/src/IR/DataType.cpp19
-rw-r--r--compiler/loco/src/IR/DataTypeTraits.test.cpp29
-rw-r--r--compiler/loco/src/IR/DepthwiseFilterAxis.cpp19
-rw-r--r--compiler/loco/src/IR/DepthwiseFilterCodec.cpp19
-rw-r--r--compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp67
-rw-r--r--compiler/loco/src/IR/DepthwiseFilterShape.test.cpp80
-rw-r--r--compiler/loco/src/IR/Dialect.cpp19
-rw-r--r--compiler/loco/src/IR/Dialect.test.cpp41
-rw-r--r--compiler/loco/src/IR/DialectService.cpp19
-rw-r--r--compiler/loco/src/IR/Dimension.cpp32
-rw-r--r--compiler/loco/src/IR/Dimension.test.cpp100
-rw-r--r--compiler/loco/src/IR/Domain.cpp19
-rw-r--r--compiler/loco/src/IR/FeatureAxis.cpp19
-rw-r--r--compiler/loco/src/IR/FeatureCodec.cpp19
-rw-r--r--compiler/loco/src/IR/FeatureIndex.test.cpp67
-rw-r--r--compiler/loco/src/IR/FeatureShape.test.cpp80
-rw-r--r--compiler/loco/src/IR/FilterAxis.cpp19
-rw-r--r--compiler/loco/src/IR/FilterCodec.cpp19
-rw-r--r--compiler/loco/src/IR/FilterIndex.test.cpp67
-rw-r--r--compiler/loco/src/IR/FilterShape.test.cpp80
-rw-r--r--compiler/loco/src/IR/Graph.cpp137
-rw-r--r--compiler/loco/src/IR/Graph.test.cpp218
-rw-r--r--compiler/loco/src/IR/GraphInputIndex.cpp19
-rw-r--r--compiler/loco/src/IR/GraphOutputIndex.cpp19
-rw-r--r--compiler/loco/src/IR/MatrixAxis.cpp19
-rw-r--r--compiler/loco/src/IR/MatrixCodec.cpp19
-rw-r--r--compiler/loco/src/IR/MockupNode.h58
-rw-r--r--compiler/loco/src/IR/Node.cpp88
-rw-r--r--compiler/loco/src/IR/Node.test.cpp102
-rw-r--r--compiler/loco/src/IR/NodeMixins.cpp19
-rw-r--r--compiler/loco/src/IR/NodePool.cpp31
-rw-r--r--compiler/loco/src/IR/NodeShape.cpp284
-rw-r--r--compiler/loco/src/IR/NodeShape.test.cpp125
-rw-r--r--compiler/loco/src/IR/Nodes.cpp243
-rw-r--r--compiler/loco/src/IR/Nodes.test.cpp588
-rw-r--r--compiler/loco/src/IR/Padding2D.test.cpp29
-rw-r--r--compiler/loco/src/IR/PaddingND.test.cpp32
-rw-r--r--compiler/loco/src/IR/PermutingCodec.cpp630
-rw-r--r--compiler/loco/src/IR/PermutingCodec.test.cpp553
-rw-r--r--compiler/loco/src/IR/Stride.test.cpp42
-rw-r--r--compiler/loco/src/IR/TensorAxis.cpp19
-rw-r--r--compiler/loco/src/IR/TensorAxisSet.cpp19
-rw-r--r--compiler/loco/src/IR/TensorIndex.cpp19
-rw-r--r--compiler/loco/src/IR/TensorShape.cpp39
-rw-r--r--compiler/loco/src/IR/TensorShape.test.cpp109
-rw-r--r--compiler/loco/src/IR/Use.cpp45
-rw-r--r--compiler/loco/src/IR/Use.test.cpp42
-rw-r--r--compiler/loco/src/IR/Verifier.cpp119
-rw-r--r--compiler/loco/src/IR/Verifier.test.cpp64
-rw-r--r--compiler/loco/src/IR/Window.test.cpp42
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp774
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp400
-rw-r--r--compiler/loco/src/Service/GraphBuilder.h547
-rw-r--r--compiler/loco/src/Service/GraphBuilder.test.cpp47
-rw-r--r--compiler/loco/src/Service/GraphTestcase.h541
-rw-r--r--compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp67
-rw-r--r--compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp134
-rw-r--r--compiler/loco/src/Service/ShapeInference.cpp105
-rw-r--r--compiler/loco/src/Service/ShapeInference.test.cpp87
-rw-r--r--compiler/loco/src/Service/ShapeInferenceRule.cpp31
-rw-r--r--compiler/loco/src/Service/TypeInference.cpp228
-rw-r--r--compiler/loco/src/Service/TypeInference.test.cpp282
-rw-r--r--compiler/loco/src/loco.test.cpp108
-rw-r--r--compiler/loco/src/tensorflow.test.cpp386
74 files changed, 8917 insertions, 0 deletions
diff --git a/compiler/loco/src/ADT/AnnotatedItem.test.cpp b/compiler/loco/src/ADT/AnnotatedItem.test.cpp
new file mode 100644
index 000000000..42113ff7b
--- /dev/null
+++ b/compiler/loco/src/ADT/AnnotatedItem.test.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/ADT/AnnotatedItem.h"
+
+#include <gtest/gtest.h>
+#include <stdex/Memory.h>
+
+namespace
+{
+
+struct Annotation
+{
+ virtual ~Annotation() = default;
+};
+
+template <int N> struct DerivedAnnotation final : public Annotation
+{
+ static std::unique_ptr<DerivedAnnotation<N>> make(void)
+ {
+ return stdex::make_unique<DerivedAnnotation<N>>();
+ }
+};
+
+} // namespace
+
+TEST(AnnotatedItemTest, annotation)
+{
+ loco::AnnotatedItem<::Annotation> item;
+
+ ASSERT_EQ(item.annot<DerivedAnnotation<0>>(), nullptr);
+
+ item.annot(DerivedAnnotation<0>::make());
+
+ ASSERT_NE(item.annot<DerivedAnnotation<0>>(), nullptr);
+ ASSERT_EQ(item.annot<DerivedAnnotation<1>>(), nullptr);
+
+ item.annot<DerivedAnnotation<0>>(nullptr);
+ ASSERT_EQ(item.annot<DerivedAnnotation<0>>(), nullptr);
+
+ // Below check guarantees that "annot<T>(nullptr)" is allowed even when there is no annotation.
+ // This guarantee allows us to simplify code for some cases.
+ //
+ // Let us consider the following example:
+ //
+ // void f(loco::AnnotatedItem<T> *item)
+ // {
+ // /* DO SOMETHING */
+  //     if (cond) { item->annot<T>(nullptr); }
+ // }
+ //
+ // void g(loco::AnnotatedItem<T> *item)
+ // {
+ // f(item);
+ // item->annot<T>(nullptr);
+ // }
+ //
+ // The implementation of "g" gets complicated if annot<T>(nullptr) is not allowed if there is
+ // no annotation.
+ //
+ item.annot<DerivedAnnotation<0>>(nullptr);
+}
diff --git a/compiler/loco/src/ADT/ObjectPool.cpp b/compiler/loco/src/ADT/ObjectPool.cpp
new file mode 100644
index 000000000..d15a30a99
--- /dev/null
+++ b/compiler/loco/src/ADT/ObjectPool.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/ADT/ObjectPool.h"
+
+// This file validates "ObjectPool.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/Algorithm.cpp b/compiler/loco/src/IR/Algorithm.cpp
new file mode 100644
index 000000000..712e29975
--- /dev/null
+++ b/compiler/loco/src/IR/Algorithm.cpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Algorithm.h"
+
+#include <cassert>
+#include <set>
+#include <stack>
+
+namespace
+{
+
+class Frame final
+{
+public:
+ Frame(loco::Node *ptr) : _ptr{ptr}, _pos{-1}
+ {
+ // DO NOTHING
+ }
+
+public:
+ loco::Node *ptr(void) const { return _ptr; }
+ int64_t pos(void) const { return _pos; }
+
+ loco::Node &node(void) const { return *_ptr; }
+
+ void advance(void) { _pos += 1; }
+
+private:
+ loco::Node *_ptr = nullptr;
+ int64_t _pos = -1;
+};
+
+} // namespace
+
+namespace loco
+{
+
+// TODO Support cyclic graphs
+std::vector<loco::Node *> postorder_traversal(const std::vector<loco::Node *> &roots)
+{
+ std::vector<loco::Node *> res;
+
+ std::set<loco::Node *> visited_nodes;
+ std::stack<Frame> frames;
+
+ auto visited = [&visited_nodes](loco::Node *node) {
+ return visited_nodes.find(node) != visited_nodes.end();
+ };
+
+ // NOTE There is not much difference between "auto" and "auto &" as node is of "loco::Node *"
+ // type.
+ for (auto node : roots)
+ {
+ assert((node != nullptr) && "root is invalid");
+ frames.push(Frame{node});
+ }
+
+ while (!frames.empty())
+ {
+ auto &top_frame = frames.top();
+
+ if (top_frame.pos() == -1)
+ {
+ if (visited(top_frame.ptr()))
+ {
+ frames.pop();
+ continue;
+ }
+ visited_nodes.insert(top_frame.ptr());
+ }
+
+ top_frame.advance();
+
+ assert(top_frame.pos() >= 0);
+
+ if (top_frame.pos() < static_cast<int64_t>(top_frame.node().arity()))
+ {
+ // Let's visit the next argument
+ //
+ // NOTE "next" may be nullptr if a graph is under construction.
+ if (auto next = top_frame.node().arg(top_frame.pos()))
+ {
+ frames.push(Frame{next});
+ }
+ }
+ else
+ {
+ // Let's visit the current argument (all the arguments are already visited)
+ auto curr = top_frame.ptr();
+ res.emplace_back(curr);
+ frames.pop();
+ }
+ }
+
+ return res;
+}
+
+std::set<loco::Node *> active_nodes(const std::vector<loco::Node *> &roots)
+{
+ // This implementation works but may be inefficient
+ //
+ // TODO Use efficient implementation if necessary
+ auto nodes = postorder_traversal(roots);
+ return std::set<loco::Node *>{nodes.begin(), nodes.end()};
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Algorithm.test.cpp b/compiler/loco/src/IR/Algorithm.test.cpp
new file mode 100644
index 000000000..f0a3585c0
--- /dev/null
+++ b/compiler/loco/src/IR/Algorithm.test.cpp
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Algorithm.h"
+#include "loco/IR/Graph.h"
+
+#include <algorithm>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+bool contains(const std::vector<loco::Node *> &vec, loco::Node *val)
+{
+ return std::any_of(vec.begin(), vec.end(), [val](loco::Node *node) { return node == val; });
+}
+
+bool contains(const std::set<loco::Node *> &s, loco::Node *val)
+{
+ return std::any_of(s.begin(), s.end(), [val](loco::Node *node) { return node == val; });
+}
+
+} // namespace
+
+TEST(AlgorithmTest, postorder_traversal)
+{
+ auto g = loco::make_graph();
+
+ auto pull_1 = g->nodes()->create<loco::Pull>();
+ auto push = g->nodes()->create<loco::Push>();
+
+ push->from(pull_1);
+
+ // Create a dummy node unreachable from the above "push" node
+ g->nodes()->create<loco::Pull>();
+
+ auto seq = loco::postorder_traversal({push});
+
+ ASSERT_EQ(seq.size(), 2);
+ ASSERT_EQ(seq.at(0), pull_1);
+ ASSERT_EQ(seq.at(1), push);
+}
+
+TEST(AlgorithmTest, postorder_traversal_visit_once)
+{
+ auto g = loco::make_graph();
+
+ // Create a network of the following form:
+ //
+ // Push1 Push2 <-- outputs
+ // \ /
+ // Pull <-- input
+ //
+ auto pull = g->nodes()->create<loco::Pull>();
+ auto push_1 = g->nodes()->create<loco::Push>();
+ auto push_2 = g->nodes()->create<loco::Push>();
+
+ push_1->from(pull);
+ push_2->from(pull);
+
+ auto seq = loco::postorder_traversal({push_1, push_2});
+
+ ASSERT_EQ(seq.size(), 3);
+ ASSERT_TRUE(contains(seq, pull));
+ ASSERT_TRUE(contains(seq, push_1));
+ ASSERT_TRUE(contains(seq, push_2));
+}
+
+TEST(AlgorithmTest, postorder_traversal_incomplte_graph)
+{
+ auto g = loco::make_graph();
+
+ // Create a network of the following form:
+ //
+ // TensorConcat
+ // / \
+ // Pull X
+ //
+ auto pull = g->nodes()->create<loco::Pull>();
+ auto concat = g->nodes()->create<loco::TensorConcat>();
+
+ concat->lhs(pull);
+
+ auto seq = loco::postorder_traversal({concat});
+
+ ASSERT_EQ(seq.size(), 2);
+ ASSERT_EQ(seq.at(0), pull);
+ ASSERT_EQ(seq.at(1), concat);
+}
+
+TEST(AlgorithmTest, active_nodes)
+{
+ auto g = loco::make_graph();
+
+ auto pull = g->nodes()->create<loco::Pull>();
+ auto push = g->nodes()->create<loco::Push>();
+
+ push->from(pull);
+
+ // NOTE This new Push node is unnecessary to compute "push"
+ g->nodes()->create<loco::Push>();
+
+ auto s = loco::active_nodes({push});
+
+ ASSERT_EQ(s.size(), 2);
+ ASSERT_TRUE(contains(s, pull));
+ ASSERT_TRUE(contains(s, push));
+}
diff --git a/compiler/loco/src/IR/BiasShape.test.cpp b/compiler/loco/src/IR/BiasShape.test.cpp
new file mode 100644
index 000000000..7f9b8dfed
--- /dev/null
+++ b/compiler/loco/src/IR/BiasShape.test.cpp
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/BiasShape.h"
+
+#include <gtest/gtest.h>
+
+TEST(BiasShapeTest, default_constructor)
+{
+ loco::BiasShape shape;
+
+ ASSERT_FALSE(shape.length().known());
+}
diff --git a/compiler/loco/src/IR/CanonicalDialect.cpp b/compiler/loco/src/IR/CanonicalDialect.cpp
new file mode 100644
index 000000000..ea956b80e
--- /dev/null
+++ b/compiler/loco/src/IR/CanonicalDialect.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/CanonicalDialect.h"
+#include "loco/IR/Graph.h"
+#include "loco/IR/Nodes.h"
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+#include <stdexcept>
+
+namespace
+{
+
+struct GraphOutputIndexQueryServiceImpl final : public loco::GraphOutputIndexQueryService
+{
+ bool associated(const loco::Node *node) const final
+ {
+ if (auto push = dynamic_cast<const loco::Push *>(node))
+ {
+ return push->indexed();
+ }
+ return false;
+ }
+
+ loco::GraphOutputIndex index(const loco::Node *node) const final
+ {
+ assert(associated(node));
+ if (auto push = dynamic_cast<const loco::Push *>(node))
+ {
+ return push->index();
+ }
+ throw std::invalid_argument("node");
+ }
+};
+
+} // namespace
+
+namespace loco
+{
+
+CanonicalDialect::CanonicalDialect()
+{
+ service<GraphOutputIndexQueryService>(stdex::make_unique<GraphOutputIndexQueryServiceImpl>());
+}
+
+Dialect *CanonicalDialect::get(void)
+{
+ static CanonicalDialect d;
+ return &d;
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/CanonicalDialect.test.cpp b/compiler/loco/src/IR/CanonicalDialect.test.cpp
new file mode 100644
index 000000000..96b48218d
--- /dev/null
+++ b/compiler/loco/src/IR/CanonicalDialect.test.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/CanonicalDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(CanonicalDialectTest, get)
+{
+ auto d = loco::CanonicalDialect::get();
+
+ // get() SHOULD return a valid(non-null) pointer
+ ASSERT_NE(d, nullptr);
+ // The return value SHOULD be stable across multiple invocations
+ ASSERT_EQ(d, loco::CanonicalDialect::get());
+}
diff --git a/compiler/loco/src/IR/CanonicalNode.cpp b/compiler/loco/src/IR/CanonicalNode.cpp
new file mode 100644
index 000000000..d5e13a415
--- /dev/null
+++ b/compiler/loco/src/IR/CanonicalNode.cpp
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/CanonicalNode.h"
+#include "loco/IR/CanonicalDialect.h"
+
+namespace loco
+{
+
+const Dialect *CanonicalNode::dialect(void) const { return CanonicalDialect::get(); }
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/CanonicalNode.test.cpp b/compiler/loco/src/IR/CanonicalNode.test.cpp
new file mode 100644
index 000000000..cb61b5e83
--- /dev/null
+++ b/compiler/loco/src/IR/CanonicalNode.test.cpp
@@ -0,0 +1,72 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/CanonicalNode.h"
+
+#include <gtest/gtest.h>
+
+TEST(CanonicalNodeTest, visitor_with_user_default_impl)
+{
+ struct MyVisitor final : public loco::CanonicalNodeVisitor<uint32_t>
+ {
+ // This visitor returns 128 if it visits a Forward node.
+ uint32_t visit(const loco::Forward *) final { return 128; }
+
+ // Otherwise, this visitor returns 256.
+ uint32_t visit(const loco::Node *) final { return 256; }
+ };
+
+ loco::Forward forward;
+ loco::ConstGen constgen;
+
+ MyVisitor v;
+
+ ASSERT_EQ(forward.accept(&v), 128);
+ ASSERT_EQ(constgen.accept(&v), 256);
+}
+
+TEST(CanonicalNodeTest, visitor)
+{
+ struct CountingVisitor final : public loco::CanonicalNodeVisitor<uint32_t>
+ {
+ uint32_t visit(const loco::Forward *) final { return 1; }
+ };
+
+ // Visitor can visit constant nodes
+ const loco::Forward node;
+
+ CountingVisitor v;
+
+ ASSERT_EQ(node.accept(&v), 1);
+}
+
+TEST(CanonicalNodeTest, mutable_visitor)
+{
+ struct ResetForward final : public loco::CanonicalNodeMutableVisitor<void>
+ {
+ void visit(loco::Forward *node) final { node->input(nullptr); }
+ };
+
+ loco::Pull pull_node;
+ loco::Forward forward_node;
+
+ forward_node.input(&pull_node);
+
+ ResetForward v;
+ forward_node.accept(&v);
+
+ ASSERT_EQ(forward_node.input(), nullptr);
+}
diff --git a/compiler/loco/src/IR/CanonicalOpcode.cpp b/compiler/loco/src/IR/CanonicalOpcode.cpp
new file mode 100644
index 000000000..6355ecf1f
--- /dev/null
+++ b/compiler/loco/src/IR/CanonicalOpcode.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/CanonicalOpcode.h"
+
+// NOTE This file validates "CanonicalOpcode.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/DataType.cpp b/compiler/loco/src/IR/DataType.cpp
new file mode 100644
index 000000000..56794dac7
--- /dev/null
+++ b/compiler/loco/src/IR/DataType.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DataType.h"
+
+// This file validates "DataType.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/DataTypeTraits.test.cpp b/compiler/loco/src/IR/DataTypeTraits.test.cpp
new file mode 100644
index 000000000..76d2515a9
--- /dev/null
+++ b/compiler/loco/src/IR/DataTypeTraits.test.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DataTypeTraits.h"
+
+#include <typeindex>
+
+#include <gtest/gtest.h>
+
+TEST(DataTypeTraitsTest, FLOAT32)
+{
+ auto obtained = std::type_index(typeid(loco::DataTypeImpl<loco::DataType::FLOAT32>::Type));
+ auto expected = std::type_index(typeid(float));
+
+ ASSERT_EQ(obtained, expected);
+}
diff --git a/compiler/loco/src/IR/DepthwiseFilterAxis.cpp b/compiler/loco/src/IR/DepthwiseFilterAxis.cpp
new file mode 100644
index 000000000..9d58795b2
--- /dev/null
+++ b/compiler/loco/src/IR/DepthwiseFilterAxis.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DepthwiseFilterAxis.h"
+
+// NOTE This file validates "DepthwiseFilterAxis.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/DepthwiseFilterCodec.cpp b/compiler/loco/src/IR/DepthwiseFilterCodec.cpp
new file mode 100644
index 000000000..05a7fd723
--- /dev/null
+++ b/compiler/loco/src/IR/DepthwiseFilterCodec.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DepthwiseFilterCodec.h"
+
+// NOTE This file validates "DepthwiseFilterCodec.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp b/compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp
new file mode 100644
index 000000000..202647cfc
--- /dev/null
+++ b/compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DepthwiseFilterIndex.h"
+
+#include <gtest/gtest.h>
+
+TEST(DepthwiseFilterIndexTest, default_constructor)
+{
+ loco::DepthwiseFilterIndex index;
+
+ // All the values are 0 at the beginning
+ ASSERT_EQ(index.channel(), 0);
+ ASSERT_EQ(index.nth(), 0);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+}
+
+TEST(DepthwiseFilterIndexTest, settet_and_getter)
+{
+ loco::DepthwiseFilterIndex index;
+
+ // Set depth
+ index.channel() = 2;
+
+ ASSERT_EQ(index.channel(), 2);
+ ASSERT_EQ(index.nth(), 0);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set multiplier
+ index.nth() = 3;
+
+ ASSERT_EQ(index.channel(), 2);
+ ASSERT_EQ(index.nth(), 3);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set height
+ index.row() = 4;
+
+ ASSERT_EQ(index.channel(), 2);
+ ASSERT_EQ(index.nth(), 3);
+ ASSERT_EQ(index.row(), 4);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set width
+ index.column() = 5;
+
+ ASSERT_EQ(index.channel(), 2);
+ ASSERT_EQ(index.nth(), 3);
+ ASSERT_EQ(index.row(), 4);
+ ASSERT_EQ(index.column(), 5);
+}
diff --git a/compiler/loco/src/IR/DepthwiseFilterShape.test.cpp b/compiler/loco/src/IR/DepthwiseFilterShape.test.cpp
new file mode 100644
index 000000000..2b9518c1f
--- /dev/null
+++ b/compiler/loco/src/IR/DepthwiseFilterShape.test.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DepthwiseFilterShape.h"
+
+#include <gtest/gtest.h>
+
+TEST(DepthwiseFilterShapeTest, default_constructor)
+{
+ loco::DepthwiseFilterShape shape;
+
+ ASSERT_FALSE(shape.depth().known());
+ ASSERT_FALSE(shape.multiplier().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+}
+
+TEST(DepthwiseFilterShapeTest, settet_and_getter)
+{
+ loco::DepthwiseFilterShape shape;
+
+ // Set depth
+ shape.depth() = 2;
+
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_FALSE(shape.multiplier().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.depth(), 2);
+
+ // Set multiplier
+ shape.multiplier() = 3;
+
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.multiplier().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.depth(), 2);
+ ASSERT_EQ(shape.multiplier(), 3);
+
+ // Set height
+ shape.height() = 4;
+
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.multiplier().known());
+ ASSERT_TRUE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.depth(), 2);
+ ASSERT_EQ(shape.multiplier(), 3);
+ ASSERT_EQ(shape.height(), 4);
+
+ // Set width
+ shape.width() = 5;
+
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.multiplier().known());
+ ASSERT_TRUE(shape.height().known());
+ ASSERT_TRUE(shape.width().known());
+
+ ASSERT_EQ(shape.depth(), 2);
+ ASSERT_EQ(shape.multiplier(), 3);
+ ASSERT_EQ(shape.height(), 4);
+ ASSERT_EQ(shape.width(), 5);
+}
diff --git a/compiler/loco/src/IR/Dialect.cpp b/compiler/loco/src/IR/Dialect.cpp
new file mode 100644
index 000000000..a381b47eb
--- /dev/null
+++ b/compiler/loco/src/IR/Dialect.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Dialect.h"
+
+// NOTE This file validates "Dialect.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/Dialect.test.cpp b/compiler/loco/src/IR/Dialect.test.cpp
new file mode 100644
index 000000000..312bb52ef
--- /dev/null
+++ b/compiler/loco/src/IR/Dialect.test.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Dialect.h"
+
+#include <stdex/Memory.h>
+
+#include <gtest/gtest.h>
+
+TEST(DialectTest, service)
+{
+ struct S0 final : public loco::DialectService
+ {
+ };
+ struct S1 final : public loco::DialectService
+ {
+ };
+
+ struct MockDialect final : public loco::Dialect
+ {
+ MockDialect() { service<S1>(stdex::make_unique<S1>()); }
+ };
+
+ MockDialect dialect;
+
+ ASSERT_EQ(dialect.service<S0>(), nullptr);
+ ASSERT_NE(dialect.service<S1>(), nullptr);
+}
diff --git a/compiler/loco/src/IR/DialectService.cpp b/compiler/loco/src/IR/DialectService.cpp
new file mode 100644
index 000000000..fb8041e47
--- /dev/null
+++ b/compiler/loco/src/IR/DialectService.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/DialectService.h"
+
+// NOTE This file validates "DialectService.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/Dimension.cpp b/compiler/loco/src/IR/Dimension.cpp
new file mode 100644
index 000000000..0d11c83e8
--- /dev/null
+++ b/compiler/loco/src/IR/Dimension.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Dimension.h"
+
+namespace loco
+{
+
+bool operator==(const Dimension &lhs, const Dimension &rhs)
+{
+ return lhs.known() && rhs.known() && lhs.value() == rhs.value();
+}
+
+bool operator==(const Dimension &lhs, uint32_t rhs) { return lhs.known() && lhs.value() == rhs; }
+bool operator==(uint32_t lhs, const Dimension &rhs) { return rhs.known() && lhs == rhs.value(); }
+
+Dimension make_dimension(void) { return Dimension{}; }
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Dimension.test.cpp b/compiler/loco/src/IR/Dimension.test.cpp
new file mode 100644
index 000000000..4faf78ac8
--- /dev/null
+++ b/compiler/loco/src/IR/Dimension.test.cpp
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Dimension.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+struct DimensionTest : public ::testing::Test
+{
+protected:
+ uint32_t value(void) const { return _value; }
+
+private:
+ uint32_t const _value{3};
+};
+
+} // namespace
+
+TEST_F(DimensionTest, default_constructor)
+{
+ loco::Dimension dim;
+
+ ASSERT_FALSE(dim.known());
+}
+
+TEST_F(DimensionTest, value_constructor)
+{
+ loco::Dimension dim{value()};
+
+ ASSERT_TRUE(dim.known());
+ ASSERT_EQ(dim.value(), value());
+}
+
+TEST_F(DimensionTest, set)
+{
+ loco::Dimension dim;
+
+ dim.set(value());
+
+ ASSERT_TRUE(dim.known());
+ ASSERT_EQ(dim.value(), value());
+}
+
+TEST_F(DimensionTest, unset)
+{
+ loco::Dimension dim{value()};
+
+ dim.unset();
+
+ ASSERT_FALSE(dim.known());
+}
+
+TEST_F(DimensionTest, operator_eq)
+{
+ loco::Dimension unknown;
+ loco::Dimension known{3};
+
+ // Compare uint32_t and an unknown dimension
+ ASSERT_FALSE(unknown == 3);
+ ASSERT_FALSE(3 == unknown);
+
+ // Compare uint32_t and a known dimension
+ ASSERT_TRUE(known == 3);
+ ASSERT_TRUE(3 == known);
+
+ ASSERT_FALSE(known == 4);
+ ASSERT_FALSE(4 == known);
+
+ // Compare two known dimensions
+ loco::Dimension another_known{3};
+ ASSERT_TRUE(known == another_known);
+
+ // Compare two unknown dimensions
+ loco::Dimension unknown_a, unknown_b;
+ ASSERT_TRUE(unknown_a.known() == false && unknown_b.known() == false);
+ ASSERT_FALSE(unknown_a == unknown_b);
+}
+
+TEST_F(DimensionTest, make_unknown_dimension)
+{
+ auto dim = loco::make_dimension();
+
+ ASSERT_FALSE(dim.known());
+}
diff --git a/compiler/loco/src/IR/Domain.cpp b/compiler/loco/src/IR/Domain.cpp
new file mode 100644
index 000000000..7bad04750
--- /dev/null
+++ b/compiler/loco/src/IR/Domain.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Domain.h"
+
+// NOTE This file validates "Domain.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/FeatureAxis.cpp b/compiler/loco/src/IR/FeatureAxis.cpp
new file mode 100644
index 000000000..b0f560677
--- /dev/null
+++ b/compiler/loco/src/IR/FeatureAxis.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FeatureAxis.h"
+
+// NOTE This file validates "FeatureAxis.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/FeatureCodec.cpp b/compiler/loco/src/IR/FeatureCodec.cpp
new file mode 100644
index 000000000..99d39a489
--- /dev/null
+++ b/compiler/loco/src/IR/FeatureCodec.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FeatureCodec.h"
+
+// NOTE This file validates "FeatureCodec.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/FeatureIndex.test.cpp b/compiler/loco/src/IR/FeatureIndex.test.cpp
new file mode 100644
index 000000000..82b563986
--- /dev/null
+++ b/compiler/loco/src/IR/FeatureIndex.test.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FeatureIndex.h"
+
+#include <gtest/gtest.h>
+
+TEST(FeatureIndexTest, default_constructor)
+{
+ loco::FeatureIndex index;
+
+ // All the values are 0 at the beginning
+ ASSERT_EQ(index.batch(), 0);
+ ASSERT_EQ(index.channel(), 0);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+}
+
+TEST(FeatureIndexTest, settet_and_getter)
+{
+ loco::FeatureIndex index;
+
+ // Set batch
+ index.batch() = 2;
+
+ ASSERT_EQ(index.batch(), 2);
+ ASSERT_EQ(index.channel(), 0);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set channel
+ index.channel() = 3;
+
+ ASSERT_EQ(index.batch(), 2);
+ ASSERT_EQ(index.channel(), 3);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set row
+ index.row() = 4;
+
+ ASSERT_EQ(index.batch(), 2);
+ ASSERT_EQ(index.channel(), 3);
+ ASSERT_EQ(index.row(), 4);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set column
+ index.column() = 5;
+
+ ASSERT_EQ(index.batch(), 2);
+ ASSERT_EQ(index.channel(), 3);
+ ASSERT_EQ(index.row(), 4);
+ ASSERT_EQ(index.column(), 5);
+}
diff --git a/compiler/loco/src/IR/FeatureShape.test.cpp b/compiler/loco/src/IR/FeatureShape.test.cpp
new file mode 100644
index 000000000..59e25ac23
--- /dev/null
+++ b/compiler/loco/src/IR/FeatureShape.test.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FeatureShape.h"
+
+#include <gtest/gtest.h>
+
+TEST(FeatureShapeTest, default_constructor)
+{
+ loco::FeatureShape shape;
+
+ ASSERT_FALSE(shape.count().known());
+ ASSERT_FALSE(shape.depth().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+}
+
+TEST(FeatureShapeTest, settet_and_getter)
+{
+ loco::FeatureShape shape;
+
+ // Set count
+ shape.count() = 2;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_FALSE(shape.depth().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+
+ // Set depth
+ shape.depth() = 3;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+ ASSERT_EQ(shape.depth(), 3);
+
+ // Set height
+ shape.height() = 4;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+ ASSERT_EQ(shape.depth(), 3);
+ ASSERT_EQ(shape.height(), 4);
+
+ // Set width
+ shape.width() = 5;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.height().known());
+ ASSERT_TRUE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+ ASSERT_EQ(shape.depth(), 3);
+ ASSERT_EQ(shape.height(), 4);
+ ASSERT_EQ(shape.width(), 5);
+}
diff --git a/compiler/loco/src/IR/FilterAxis.cpp b/compiler/loco/src/IR/FilterAxis.cpp
new file mode 100644
index 000000000..be4234e6a
--- /dev/null
+++ b/compiler/loco/src/IR/FilterAxis.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FilterAxis.h"
+
+// NOTE This file validates "FilterAxis.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/FilterCodec.cpp b/compiler/loco/src/IR/FilterCodec.cpp
new file mode 100644
index 000000000..f48cf1821
--- /dev/null
+++ b/compiler/loco/src/IR/FilterCodec.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FilterCodec.h"
+
+// NOTE This file validates "FilterCodec.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/FilterIndex.test.cpp b/compiler/loco/src/IR/FilterIndex.test.cpp
new file mode 100644
index 000000000..58f38718e
--- /dev/null
+++ b/compiler/loco/src/IR/FilterIndex.test.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FilterIndex.h"
+
+#include <gtest/gtest.h>
+
+TEST(FilterIndexTest, default_constructor)
+{
+ loco::FilterIndex index;
+
+ // All the values are 0 at the beginning
+ ASSERT_EQ(index.nth(), 0);
+ ASSERT_EQ(index.channel(), 0);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+}
+
+TEST(FilterIndexTest, settet_and_getter)
+{
+ loco::FilterIndex index;
+
+ // Set nth
+ index.nth() = 2;
+
+ ASSERT_EQ(index.nth(), 2);
+ ASSERT_EQ(index.channel(), 0);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set channel
+ index.channel() = 3;
+
+ ASSERT_EQ(index.nth(), 2);
+ ASSERT_EQ(index.channel(), 3);
+ ASSERT_EQ(index.row(), 0);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set row
+ index.row() = 4;
+
+ ASSERT_EQ(index.nth(), 2);
+ ASSERT_EQ(index.channel(), 3);
+ ASSERT_EQ(index.row(), 4);
+ ASSERT_EQ(index.column(), 0);
+
+ // Set column
+ index.column() = 5;
+
+ ASSERT_EQ(index.nth(), 2);
+ ASSERT_EQ(index.channel(), 3);
+ ASSERT_EQ(index.row(), 4);
+ ASSERT_EQ(index.column(), 5);
+}
diff --git a/compiler/loco/src/IR/FilterShape.test.cpp b/compiler/loco/src/IR/FilterShape.test.cpp
new file mode 100644
index 000000000..ccb60ed76
--- /dev/null
+++ b/compiler/loco/src/IR/FilterShape.test.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/FilterShape.h"
+
+#include <gtest/gtest.h>
+
+TEST(FilterShapeTest, default_constructor)
+{
+ loco::FilterShape shape;
+
+ ASSERT_FALSE(shape.count().known());
+ ASSERT_FALSE(shape.depth().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+}
+
+TEST(FilterShapeTest, settet_and_getter)
+{
+ loco::FilterShape shape;
+
+ // Set count
+ shape.count() = 2;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_FALSE(shape.depth().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+
+ // Set depth
+ shape.depth() = 3;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_FALSE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+ ASSERT_EQ(shape.depth(), 3);
+
+ // Set height
+ shape.height() = 4;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.height().known());
+ ASSERT_FALSE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+ ASSERT_EQ(shape.depth(), 3);
+ ASSERT_EQ(shape.height(), 4);
+
+ // Set width
+ shape.width() = 5;
+
+ ASSERT_TRUE(shape.count().known());
+ ASSERT_TRUE(shape.depth().known());
+ ASSERT_TRUE(shape.height().known());
+ ASSERT_TRUE(shape.width().known());
+
+ ASSERT_EQ(shape.count(), 2);
+ ASSERT_EQ(shape.depth(), 3);
+ ASSERT_EQ(shape.height(), 4);
+ ASSERT_EQ(shape.width(), 5);
+}
diff --git a/compiler/loco/src/IR/Graph.cpp b/compiler/loco/src/IR/Graph.cpp
new file mode 100644
index 000000000..1d8752252
--- /dev/null
+++ b/compiler/loco/src/IR/Graph.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Graph.h"
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+
+namespace
+{
+
+std::unique_ptr<loco::TensorShape> make_tensor_shape(std::initializer_list<loco::Dimension> dims)
+{
+ auto tensor_shape = stdex::make_unique<loco::TensorShape>();
+
+ tensor_shape->rank(dims.size());
+ {
+ uint32_t axis = 0;
+ for (auto it = dims.begin(); it != dims.end(); ++it)
+ {
+ tensor_shape->dim(axis++) = *it;
+ }
+ assert(axis == dims.size());
+ }
+
+ return std::move(tensor_shape);
+}
+
+} // namespace
+
+namespace loco
+{
+
+void Mixin<Trait::TensorShaped>::shape(std::initializer_list<Dimension> dims)
+{
+ shape(make_tensor_shape(dims));
+}
+
+GraphInput *Graph::InputContext::create(void)
+{
+ return take(stdex::make_unique<GraphInput>(size()));
+}
+
+GraphOutput *Graph::OutputContext::create(void)
+{
+ return take(stdex::make_unique<GraphOutput>(size()));
+}
+
+std::set<loco::Node *> all_nodes(loco::Graph *g)
+{
+ std::set<loco::Node *> res;
+
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ res.insert(g->nodes()->at(n));
+ }
+
+ return res;
+}
+
+std::vector<Node *> input_nodes(const Graph *g)
+{
+ std::map<GraphInputIndex, loco::Node *> table;
+
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ auto node = g->nodes()->at(n);
+
+ if (auto service = node->dialect()->service<GraphInputIndexQueryService>())
+ {
+ if (service->associated(node))
+ {
+ auto input_index = service->index(node);
+ assert(table.find(input_index) == table.end());
+ table[input_index] = node;
+ }
+ }
+ }
+
+ std::vector<loco::Node *> res;
+
+ for (uint32_t n = 0; n < g->inputs()->size(); ++n)
+ {
+ auto it = table.find(n);
+ res.emplace_back(it == table.end() ? nullptr : it->second);
+ }
+
+ return res;
+}
+
+std::vector<loco::Node *> output_nodes(loco::Graph *g)
+{
+ std::map<GraphOutputIndex, loco::Node *> table;
+
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ auto node = g->nodes()->at(n);
+
+ if (auto service = node->dialect()->service<GraphOutputIndexQueryService>())
+ {
+ if (service->associated(node))
+ {
+ auto output_index = service->index(node);
+ assert(table.find(output_index) == table.end());
+ table[output_index] = node;
+ }
+ }
+ }
+
+ std::vector<loco::Node *> res;
+
+ for (uint32_t n = 0; n < g->outputs()->size(); ++n)
+ {
+ auto it = table.find(n);
+ res.emplace_back(it == table.end() ? nullptr : it->second);
+ }
+
+ return res;
+}
+
+std::unique_ptr<Graph> make_graph(void) { return std::unique_ptr<Graph>{new Graph}; }
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Graph.test.cpp b/compiler/loco/src/IR/Graph.test.cpp
new file mode 100644
index 000000000..6df630b0f
--- /dev/null
+++ b/compiler/loco/src/IR/Graph.test.cpp
@@ -0,0 +1,218 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Graph.h"
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+// @brief Mockup class for loco::NamedEntity
+struct NamedElement final : private loco::NamedEntity
+{
+ LOCO_NAMED_ENTITY_EXPOSE;
+};
+
+} // namespace
+
+TEST(NamedTest, constructor)
+{
+ NamedElement elem;
+
+ ASSERT_EQ(elem.name(), "");
+}
+
+TEST(NamedTest, setter_and_getter)
+{
+ NamedElement elem;
+
+ elem.name("name");
+ ASSERT_EQ(elem.name(), "name");
+}
+
+TEST(DataTypedMixinTest, constructor)
+{
+ loco::Mixin<loco::Trait::DataTyped> mixin;
+
+ ASSERT_EQ(mixin.dtype(), loco::DataType::Unknown);
+}
+
+TEST(DataTypedMixinTest, setter_and_getter)
+{
+ loco::Mixin<loco::Trait::DataTyped> mixin;
+
+ mixin.dtype(loco::DataType::FLOAT32);
+ ASSERT_EQ(mixin.dtype(), loco::DataType::FLOAT32);
+}
+
+TEST(TensorShapedMixinTest, setter_and_getter)
+{
+ loco::Mixin<loco::Trait::TensorShaped> mixin;
+
+ mixin.shape({1, 2, 3, 4});
+ ASSERT_NE(mixin.shape(), nullptr);
+ ASSERT_EQ(mixin.shape()->rank(), 4);
+ ASSERT_EQ(mixin.shape()->dim(0), 1);
+ ASSERT_EQ(mixin.shape()->dim(1), 2);
+ ASSERT_EQ(mixin.shape()->dim(2), 3);
+ ASSERT_EQ(mixin.shape()->dim(3), 4);
+}
+
+TEST(GraphTest, create_and_destroy_node)
+{
+ auto g = loco::make_graph();
+
+ auto pull = g->nodes()->create<loco::Pull>();
+
+ ASSERT_NO_THROW(g->nodes()->destroy(pull));
+ ASSERT_THROW(g->nodes()->destroy(pull), std::invalid_argument);
+}
+
+TEST(GraphTest, create_input)
+{
+ auto g = loco::make_graph();
+
+ auto input = g->inputs()->create();
+
+ // TODO Add more checks
+ ASSERT_EQ(input->shape(), nullptr);
+ ASSERT_EQ(input->index(), 0);
+}
+
+TEST(GraphTest, create_output)
+{
+ auto g = loco::make_graph();
+
+ auto output = g->outputs()->create();
+
+ // TODO Add more checks
+ ASSERT_EQ(output->shape(), nullptr);
+ ASSERT_EQ(output->index(), 0);
+}
+
+namespace
+{
+// temp node with multiple params for ctor. loco::CanonicalOpcode::ReLU is used for simplicity
+class ParamCtorNode
+ : public loco::CanonicalNodeDef<loco::CanonicalOpcode::ReLU, loco::FixedArity<0>::Mixin>
+{
+public:
+ ParamCtorNode(int i, float f)
+ {
+ _i = i;
+ _f = f;
+ }
+
+ int i() { return _i; }
+ float f() { return _f; }
+
+private:
+ int _i;
+ float _f;
+};
+} // namespace
+
+TEST(GraphTest, consturctor_with_param_node)
+{
+ auto g = loco::make_graph();
+
+ auto test_node = g->nodes()->create<ParamCtorNode>(22, 11.11);
+
+ ASSERT_EQ(test_node->graph(), g.get());
+ ASSERT_EQ(const_cast<const ParamCtorNode *>(test_node)->graph(), g.get());
+
+ ASSERT_EQ(test_node->i(), 22);
+ ASSERT_FLOAT_EQ(test_node->f(), 11.11);
+
+ ASSERT_NO_THROW(g->nodes()->destroy(test_node));
+ ASSERT_THROW(g->nodes()->destroy(test_node), std::invalid_argument);
+}
+
+TEST(GraphTest, getters_over_const_instance)
+{
+ auto g = loco::make_graph();
+
+ auto pull = g->nodes()->create<loco::Pull>();
+ auto push = g->nodes()->create<loco::Push>();
+
+ loco::link(g->inputs()->create(), pull);
+ loco::link(g->outputs()->create(), push);
+
+ auto ptr = const_cast<const loco::Graph *>(g.get());
+
+ EXPECT_EQ(ptr->nodes()->size(), 2);
+ EXPECT_EQ(ptr->inputs()->size(), 1);
+}
+
+TEST(GraphTest, graph_node_enumeration)
+{
+ auto g = loco::make_graph();
+
+ auto pull_1 = g->nodes()->create<loco::Pull>();
+ auto push_1 = g->nodes()->create<loco::Push>();
+
+ auto nodes = loco::all_nodes(g.get());
+
+ // Returns true if "nodes" includes a given node
+ auto member = [&nodes](loco::Node *node) { return nodes.find(node) != nodes.end(); };
+
+ ASSERT_EQ(nodes.size(), 2);
+ ASSERT_TRUE(member(pull_1));
+ ASSERT_TRUE(member(push_1));
+}
+
+TEST(GraphTest, graph_inout_enumeration)
+{
+ auto g = loco::make_graph();
+
+ std::vector<loco::Pull *> pull_nodes;
+
+ auto pull_1 = g->nodes()->create<loco::Pull>();
+ auto pull_2 = g->nodes()->create<loco::Pull>();
+ auto pull_3 = g->nodes()->create<loco::Pull>();
+
+ auto push_1 = g->nodes()->create<loco::Push>();
+ auto push_2 = g->nodes()->create<loco::Push>();
+ auto push_3 = g->nodes()->create<loco::Push>();
+
+ loco::link(g->inputs()->create(), pull_2);
+ loco::link(g->inputs()->create(), pull_1);
+
+ loco::link(g->outputs()->create(), push_1);
+ loco::link(g->outputs()->create(), push_3);
+
+ auto output_nodes = loco::output_nodes(g.get());
+
+ ASSERT_EQ(output_nodes.size(), 2);
+ ASSERT_EQ(output_nodes.at(0), push_1);
+ ASSERT_EQ(output_nodes.at(1), push_3);
+}
+
+TEST(GraphTest, graph_name)
+{
+ auto g = loco::make_graph();
+
+ g->name("HelloGraph");
+ ASSERT_TRUE(g->name() == "HelloGraph");
+}
+
+TEST(GraphTest, graph_name_nullptr_NEG)
+{
+ auto g = loco::make_graph();
+
+ EXPECT_ANY_THROW(g->name(nullptr));
+}
diff --git a/compiler/loco/src/IR/GraphInputIndex.cpp b/compiler/loco/src/IR/GraphInputIndex.cpp
new file mode 100644
index 000000000..0c94d704c
--- /dev/null
+++ b/compiler/loco/src/IR/GraphInputIndex.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/GraphInputIndex.h"
+
+// NOTE This file validates "GraphInputIndex.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/GraphOutputIndex.cpp b/compiler/loco/src/IR/GraphOutputIndex.cpp
new file mode 100644
index 000000000..e6fdb9f94
--- /dev/null
+++ b/compiler/loco/src/IR/GraphOutputIndex.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/GraphOutputIndex.h"
+
+// NOTE This file validates "GraphOutputIndex.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/MatrixAxis.cpp b/compiler/loco/src/IR/MatrixAxis.cpp
new file mode 100644
index 000000000..d0773f758
--- /dev/null
+++ b/compiler/loco/src/IR/MatrixAxis.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/MatrixAxis.h"
+
+// NOTE This file validates "MatrixAxis.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/MatrixCodec.cpp b/compiler/loco/src/IR/MatrixCodec.cpp
new file mode 100644
index 000000000..87ae42610
--- /dev/null
+++ b/compiler/loco/src/IR/MatrixCodec.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/MatrixCodec.h"
+
+// NOTE This file validates "MatrixCodec.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/MockupNode.h b/compiler/loco/src/IR/MockupNode.h
new file mode 100644
index 000000000..ec56c90e2
--- /dev/null
+++ b/compiler/loco/src/IR/MockupNode.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCO_IR_MOCKUP_NODE_H__
+#define __LOCO_IR_MOCKUP_NODE_H__
+
+#include "loco/IR/Use.h"
+#include "loco/IR/Node.h"
+
+namespace
+{
+
+struct MockDialect final : public loco::Dialect
+{
+ static loco::Dialect *get(void)
+ {
+ static MockDialect d;
+ return &d;
+ }
+};
+
+// @brief Mockup node for internal testing
+class MockupNode final : public loco::Node
+{
+public:
+ MockupNode() = default;
+
+public:
+ const loco::Dialect *dialect(void) const final { return MockDialect::get(); }
+ uint32_t opnum(void) const final { return 0; }
+
+ uint32_t arity(void) const final { return 1; }
+ Node *arg(uint32_t N) const final { return _arg.node(); }
+ void drop(void) final { _arg.node(nullptr); }
+
+ Node *in(void)const { return _arg.node(); }
+ void in(Node *node) { _arg.node(node); }
+
+private:
+ loco::Use _arg{this};
+};
+
+} // namespace
+
+#endif // __LOCO_IR_MOCKUP_NODE_H__
diff --git a/compiler/loco/src/IR/Node.cpp b/compiler/loco/src/IR/Node.cpp
new file mode 100644
index 000000000..90ec5c997
--- /dev/null
+++ b/compiler/loco/src/IR/Node.cpp
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Node.h"
+#include "loco/IR/Use.h"
+
+#include <cassert>
+
+namespace loco
+{
+
+Node::~Node()
+{
+ // To detect dangling references
+ assert(_uses.size() == 0);
+}
+
+std::set<Node *> preds(const Node *node)
+{
+ std::set<Node *> res;
+
+ for (uint32_t n = 0; n < node->arity(); ++n)
+ {
+ if (auto pred = node->arg(n))
+ {
+ res.insert(pred);
+ }
+ }
+
+ return res;
+}
+
+std::set<Node *> succs(const Node *node)
+{
+ std::set<Node *> res;
+
+ for (auto use : node->_uses)
+ {
+ auto user = use->user();
+ assert(user != nullptr);
+ res.insert(user);
+ }
+
+ return res;
+}
+
+Subst<SubstQualifier::Default>::Subst(Node *from) : _from{from}
+{
+ // _from SHOULD be valid
+ assert(_from != nullptr);
+}
+
+void Subst<SubstQualifier::Default>::with(Node *into) const
+{
+ if (_from == into)
+ {
+ return;
+ }
+
+ auto *uses = &(_from->_uses);
+
+ while (!uses->empty())
+ {
+ auto use = *(uses->begin());
+ use->node(into);
+ }
+}
+
+Subst<SubstQualifier::Default> replace(Node *node)
+{
+ // Let's create Subst<SubstQualifier::Default>!
+ return Subst<SubstQualifier::Default>{node};
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Node.test.cpp b/compiler/loco/src/IR/Node.test.cpp
new file mode 100644
index 000000000..00e444465
--- /dev/null
+++ b/compiler/loco/src/IR/Node.test.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Node.h"
+
+#include "MockupNode.h"
+
+#include <gtest/gtest.h>
+
+TEST(NodeTest, preds)
+{
+ ::MockupNode arg;
+ ::MockupNode node;
+
+ node.in(&arg);
+
+ auto preds = loco::preds(&node);
+
+ ASSERT_EQ(preds.size(), 1);
+ ASSERT_NE(preds.find(&arg), preds.end());
+}
+
+TEST(NodeTest, succs)
+{
+ ::MockupNode node;
+ ::MockupNode succ_1;
+ ::MockupNode succ_2;
+
+ succ_1.in(&node);
+ succ_2.in(&node);
+
+ auto succs = loco::succs(&node);
+
+ ASSERT_EQ(succs.size(), 2);
+ ASSERT_NE(succs.find(&succ_1), succs.end());
+ ASSERT_NE(succs.find(&succ_2), succs.end());
+}
+
+TEST(NodeTest, replace_with)
+{
+ ::MockupNode node_1;
+ ::MockupNode node_2;
+
+ ::MockupNode node_3;
+ ::MockupNode node_4;
+
+ node_3.in(&node_1);
+ node_4.in(&node_2);
+
+ // The following holds at this point
+ // - node_3 USE node_1
+ // - node_4 USE node_2
+ ASSERT_EQ(node_3.in(), &node_1);
+ ASSERT_EQ(node_4.in(), &node_2);
+
+ // Replace all the usage of node_1 with node_2
+ replace(&node_1).with(&node_2);
+
+ // The following holds at this point
+ // - node_3 USE node_2
+ // - node_4 USE node_2
+ ASSERT_EQ(node_3.in(), &node_2);
+ ASSERT_EQ(node_4.in(), &node_2);
+}
+
+TEST(NodeTest, constructor)
+{
+ MockupNode node;
+
+ // graph() SHOULD return nullptr if node is not constructed through "Graph"
+ ASSERT_EQ(node.graph(), nullptr);
+}
+
+// TODO Rewrite this as a FixedAritry mix-in test
+#if 0
+TEST(FixedArityNodeTest, constructor)
+{
+ struct DerivedNode final : public loco::FixedArityNode<1, loco::Node>
+ {
+ loco::Dialect *dialect(void) const final { return MockDialect::get(); }
+ uint32_t opnum(void) const final { return 0; }
+ };
+
+ DerivedNode node;
+
+ ASSERT_EQ(node.arity(), 1);
+ ASSERT_EQ(node.arg(0), nullptr);
+}
+#endif
diff --git a/compiler/loco/src/IR/NodeMixins.cpp b/compiler/loco/src/IR/NodeMixins.cpp
new file mode 100644
index 000000000..66037b17a
--- /dev/null
+++ b/compiler/loco/src/IR/NodeMixins.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/NodeMixins.h"
+
+// NOTE This file validates "NodeMixins.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/NodePool.cpp b/compiler/loco/src/IR/NodePool.cpp
new file mode 100644
index 000000000..553f15eb5
--- /dev/null
+++ b/compiler/loco/src/IR/NodePool.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/NodePool.h"
+
+namespace loco
+{
+
+NodePool::~NodePool()
+{
+ // Drop all the references before deallocation
+ for (uint32_t n = 0; n < size(); ++n)
+ {
+ at(n)->drop();
+ }
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/NodeShape.cpp b/compiler/loco/src/IR/NodeShape.cpp
new file mode 100644
index 000000000..0130cfbdb
--- /dev/null
+++ b/compiler/loco/src/IR/NodeShape.cpp
@@ -0,0 +1,284 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/NodeShape.h"
+
+#include <cassert>
+#include <stdexcept>
+
+//
+// BiasShape Support
+//
+namespace loco
+{
+
+void NodeShape::set(const BiasShape &shape)
+{
+ _domain = Domain::Bias;
+
+ _dims.resize(1);
+ _dims.at(0) = shape.length();
+}
+
+template <> BiasShape NodeShape::as(void) const
+{
+ assert(_domain == Domain::Bias);
+
+ BiasShape res;
+
+ res.length() = _dims.at(0);
+
+ return res;
+}
+
+} // namespace loco
+
+//
+// DepthwiseFilterShape Support
+//
+namespace loco
+{
+
+void NodeShape::set(const DepthwiseFilterShape &shape)
+{
+ _domain = Domain::DepthwiseFilter;
+
+ _dims.resize(4);
+ _dims.at(0) = shape.multiplier();
+ _dims.at(1) = shape.depth();
+ _dims.at(2) = shape.height();
+ _dims.at(3) = shape.width();
+}
+
+template <> DepthwiseFilterShape NodeShape::as(void) const
+{
+ assert(_domain == Domain::DepthwiseFilter);
+
+ DepthwiseFilterShape res;
+
+ res.multiplier() = _dims.at(0);
+ res.depth() = _dims.at(1);
+ res.height() = _dims.at(2);
+ res.width() = _dims.at(3);
+
+ return res;
+}
+
+} // namespace loco
+
+//
+// FeatureShape Support
+//
+namespace loco
+{
+
+void NodeShape::set(const FeatureShape &shape)
+{
+ _domain = Domain::Feature;
+
+ _dims.resize(4);
+ _dims.at(0) = shape.count();
+ _dims.at(1) = shape.depth();
+ _dims.at(2) = shape.height();
+ _dims.at(3) = shape.width();
+}
+
+template <> FeatureShape NodeShape::as(void) const
+{
+ assert(_domain == Domain::Feature);
+
+ FeatureShape res;
+
+ res.count() = _dims.at(0);
+ res.depth() = _dims.at(1);
+ res.height() = _dims.at(2);
+ res.width() = _dims.at(3);
+
+ return res;
+}
+
+} // namespace loco
+
+//
+// FilterShape Support
+//
+namespace loco
+{
+
+void NodeShape::set(const FilterShape &shape)
+{
+ _domain = Domain::Filter;
+
+ _dims.resize(4);
+ _dims.at(0) = shape.count();
+ _dims.at(1) = shape.depth();
+ _dims.at(2) = shape.height();
+ _dims.at(3) = shape.width();
+}
+
+template <> FilterShape NodeShape::as(void) const
+{
+ assert(_domain == Domain::Filter);
+
+ FilterShape res;
+
+ res.count() = _dims.at(0);
+ res.depth() = _dims.at(1);
+ res.height() = _dims.at(2);
+ res.width() = _dims.at(3);
+
+ return res;
+}
+
+} // namespace loco
+
+//
+// MatrixShape Support
+//
+namespace loco
+{
+
+void NodeShape::set(const MatrixShape &shape)
+{
+ _domain = Domain::Matrix;
+
+ _dims.resize(2);
+ _dims.at(0) = shape.height();
+ _dims.at(1) = shape.width();
+}
+
+template <> MatrixShape NodeShape::as(void) const
+{
+ assert(_domain == Domain::Matrix);
+
+ MatrixShape res;
+
+ res.height() = _dims.at(0);
+ res.width() = _dims.at(1);
+
+ return res;
+}
+
+} // namespace loco
+
+//
+// TensorShape Support
+//
+namespace loco
+{
+
+void NodeShape::set(const TensorShape &shape)
+{
+ _domain = Domain::Tensor;
+
+ _dims.resize(shape.rank());
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ _dims.at(axis) = shape.dim(axis);
+ }
+}
+
+template <> TensorShape NodeShape::as(void) const
+{
+ assert(_domain == Domain::Tensor);
+
+ TensorShape res;
+
+ res.rank(_dims.size());
+ for (uint32_t axis = 0; axis < _dims.size(); ++axis)
+ {
+ res.dim(axis) = _dims.at(axis);
+ }
+
+ return res;
+}
+
+} // namespace loco
+
+namespace loco
+{
+
+bool operator==(const NodeShape &lhs, const NodeShape &rhs)
+{
+ if (lhs.domain() != rhs.domain())
+ return false;
+
+ switch (lhs.domain())
+ {
+ case loco::Domain::Tensor:
+ {
+ auto lhs_t = lhs.as<TensorShape>();
+ auto rhs_t = rhs.as<TensorShape>();
+ if (lhs_t.rank() != rhs_t.rank())
+ return false;
+ for (uint32_t axis = 0; axis < lhs_t.rank(); ++axis)
+ {
+ if (!(lhs_t.dim(axis) == rhs_t.dim(axis)))
+ return false;
+ }
+ return true;
+ }
+
+ case loco::Domain::Feature:
+ {
+ auto lhs_f = lhs.as<FeatureShape>();
+ auto rhs_f = rhs.as<FeatureShape>();
+
+ return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() &&
+ lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+ }
+
+ case loco::Domain::Filter:
+ {
+ auto lhs_f = lhs.as<FilterShape>();
+ auto rhs_f = rhs.as<FilterShape>();
+
+ return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() &&
+ lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+ }
+
+ case loco::Domain::DepthwiseFilter:
+ {
+ auto lhs_f = lhs.as<DepthwiseFilterShape>();
+ auto rhs_f = rhs.as<DepthwiseFilterShape>();
+
+ return (lhs_f.multiplier() == rhs_f.multiplier() && lhs_f.depth() == rhs_f.depth() &&
+ lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+ }
+
+ case loco::Domain::Bias:
+ {
+ auto lhs_f = lhs.as<BiasShape>();
+ auto rhs_f = rhs.as<BiasShape>();
+
+ return (lhs_f.length() == rhs_f.length());
+ }
+
+ case loco::Domain::Matrix:
+ {
+ auto lhs_f = lhs.as<MatrixShape>();
+ auto rhs_f = rhs.as<MatrixShape>();
+
+ return (lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width());
+ }
+
+ default:
+ throw std::runtime_error("Not supported domain for NodeShape equality");
+ }
+ return false;
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/NodeShape.test.cpp b/compiler/loco/src/IR/NodeShape.test.cpp
new file mode 100644
index 000000000..4f092e024
--- /dev/null
+++ b/compiler/loco/src/IR/NodeShape.test.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/NodeShape.h"
+
+#include <gtest/gtest.h>
+
+TEST(NodeShapeTest, default_constructor)
+{
+ loco::NodeShape node_shape;
+
+ ASSERT_EQ(node_shape.domain(), loco::Domain::Unknown);
+}
+
+TEST(NodeShapeTest, bias_shape_constructor)
+{
+ loco::BiasShape bias_shape;
+
+ bias_shape.length() = 4;
+
+ loco::NodeShape node_shape{bias_shape};
+
+ ASSERT_EQ(node_shape.domain(), loco::Domain::Bias);
+ ASSERT_EQ(node_shape.as<loco::BiasShape>().length(), 4);
+}
+
+TEST(NodeShapeTest, dwfilter_shape_constructor)
+{
+ loco::DepthwiseFilterShape dwfilter_shape;
+
+ dwfilter_shape.depth() = 2;
+ dwfilter_shape.multiplier() = 3;
+ dwfilter_shape.height() = 4;
+ dwfilter_shape.width() = 5;
+
+ loco::NodeShape node_shape{dwfilter_shape};
+
+ ASSERT_EQ(node_shape.domain(), loco::Domain::DepthwiseFilter);
+ ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().depth(), 2);
+ ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().multiplier(), 3);
+ ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().height(), 4);
+ ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().width(), 5);
+}
+
+TEST(NodeShapeTest, feature_shape_constructor)
+{
+ loco::FeatureShape feature_shape;
+
+ feature_shape.count() = 2;
+ feature_shape.depth() = 3;
+ feature_shape.height() = 4;
+ feature_shape.width() = 5;
+
+ loco::NodeShape node_shape{feature_shape};
+
+ ASSERT_EQ(node_shape.domain(), loco::Domain::Feature);
+ ASSERT_EQ(node_shape.as<loco::FeatureShape>().count(), 2);
+ ASSERT_EQ(node_shape.as<loco::FeatureShape>().depth(), 3);
+ ASSERT_EQ(node_shape.as<loco::FeatureShape>().height(), 4);
+ ASSERT_EQ(node_shape.as<loco::FeatureShape>().width(), 5);
+}
+
+TEST(NodeShapeTest, filter_shape_constructor)
+{
+ loco::FilterShape filter_shape;
+
+ filter_shape.count() = 2;
+ filter_shape.depth() = 3;
+ filter_shape.height() = 4;
+ filter_shape.width() = 5;
+
+ loco::NodeShape node_shape{filter_shape};
+
+ ASSERT_EQ(node_shape.domain(), loco::Domain::Filter);
+ ASSERT_EQ(node_shape.as<loco::FilterShape>().count(), 2);
+ ASSERT_EQ(node_shape.as<loco::FilterShape>().depth(), 3);
+ ASSERT_EQ(node_shape.as<loco::FilterShape>().height(), 4);
+ ASSERT_EQ(node_shape.as<loco::FilterShape>().width(), 5);
+}
+
+TEST(NodeShapeTest, tensor_shape_constructor)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+ tensor_shape.dim(0) = 4;
+ tensor_shape.dim(1) = 5;
+
+ loco::NodeShape node_shape{tensor_shape};
+
+ ASSERT_EQ(node_shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(node_shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(node_shape.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(node_shape.as<loco::TensorShape>().dim(1), 5);
+}
+
+TEST(NodeShapeTest, copy_constructible)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+ tensor_shape.dim(0) = 4;
+ tensor_shape.dim(1) = 5;
+
+ loco::NodeShape orig{tensor_shape};
+ loco::NodeShape copy{orig}; // Call Copy Constructor
+
+ ASSERT_EQ(copy.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(copy.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(copy.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(copy.as<loco::TensorShape>().dim(1), 5);
+}
diff --git a/compiler/loco/src/IR/Nodes.cpp b/compiler/loco/src/IR/Nodes.cpp
new file mode 100644
index 000000000..133b69430
--- /dev/null
+++ b/compiler/loco/src/IR/Nodes.cpp
@@ -0,0 +1,243 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Nodes.h"
+#include "loco/IR/Graph.h"
+
+#include <cassert>
+#include <limits>
+
+// This file validates "Nodes.h". Please DO NOT remove this file.
+namespace
+{
+
+/**
+ * @note This function is currently only used in assert. Compiler will
+ * warn/error this function as unused in Release build.
+ * Making inline will make compiler happy.
+ */
+// Is it possible to update lhs as rhs?
+inline bool dtype_assignable(loco::DataType lhs, loco::DataType rhs)
+{
+ if (lhs == loco::DataType::Unknown)
+ {
+ return true;
+ }
+
+ // lhs is already known, and thus rhs should be matched
+ return lhs == rhs;
+}
+
+} // namespace
+
+/**
+ * Push
+ */
+namespace loco
+{
+
+void Push::index(const GraphOutputIndex &index)
+{
+ // Push internally stores "GraphOutputIndex" as int64_t
+ _index = static_cast<int64_t>(index);
+}
+
+GraphOutputIndex Push::index(void) const
+{
+ assert(_index >= std::numeric_limits<GraphOutputIndex>::min());
+ assert(_index <= std::numeric_limits<GraphOutputIndex>::max());
+ return static_cast<GraphOutputIndex>(_index);
+}
+
+void link(GraphOutput *output, Push *push) { push->index(output->index()); }
+
+Push *push_node(Graph *g, const GraphOutputIndex &index)
+{
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ if (auto push = dynamic_cast<Push *>(g->nodes()->at(n)))
+ {
+ if (push->indexed() && push->index() == index)
+ {
+ return push;
+ }
+ }
+ }
+ return nullptr;
+}
+
+} // namespace loco
+
+/**
+ * Pull
+ */
+namespace loco
+{
+
+void Pull::index(const GraphInputIndex &index)
+{
+ // ASSUMPTION
+ //
+ // It is possible to update index multiple times, but only with the same value!
+ assert(!indexed() or _index == index);
+
+ if (indexed())
+ {
+ assert(_index == index);
+ return;
+ }
+
+ // Push internally stores "GraphInputIndex" as int64_t
+ _index = static_cast<int64_t>(index);
+
+ // ASSUMPTION: The return value of graph() never changes!
+ if (graph() != nullptr && _dtype != loco::DataType::Unknown)
+ {
+ // Update Graph-level input only if it is not yet specified
+ if (graph()->inputs()->at(_index)->dtype() == DataType::Unknown)
+ {
+ graph()->inputs()->at(_index)->dtype(_dtype);
+ }
+ assert(graph()->inputs()->at(_index)->dtype() == _dtype);
+ graph()->inputs()->at(_index)->dtype(_dtype);
+
+ // Reset the locally cached data
+ _dtype = DataType::Unknown;
+ }
+}
+
+GraphInputIndex Pull::index(void) const
+{
+ assert(_index >= std::numeric_limits<GraphInputIndex>::min());
+ assert(_index <= std::numeric_limits<GraphInputIndex>::max());
+ return static_cast<GraphInputIndex>(_index);
+}
+
+void Pull::dtype(const DataType &dt)
+{
+ // ASSUMPTION: "dtype" is never invalidated!
+ assert(dt != loco::DataType::Unknown);
+ // ASSUMPTION
+ //
+ // It is possible to update index multiple times, but only with the same value!
+ if (indexed())
+ {
+ assert(dtype_assignable(graph()->inputs()->at(_index)->dtype(), dt));
+ graph()->inputs()->at(_index)->dtype(dt);
+ return;
+ }
+
+ // Use local cache
+ _dtype = dt;
+}
+
+DataType Pull::dtype(void) const
+{
+ if (graph() != nullptr and _index >= 0)
+ {
+ assert(_dtype == DataType::Unknown);
+ return graph()->inputs()->at(_index)->dtype();
+ }
+ else
+ {
+ return _dtype;
+ }
+}
+
+void link(GraphInput *input, Pull *pull) { pull->index(input->index()); }
+
+Pull *pull_node(Graph *g, const GraphInputIndex &index)
+{
+ for (uint32_t n = 0; n < g->nodes()->size(); ++n)
+ {
+ if (auto pull = dynamic_cast<Pull *>(g->nodes()->at(n)))
+ {
+ if (pull->indexed() && pull->index() == index)
+ {
+ return pull;
+ }
+ }
+ }
+ return nullptr;
+}
+
+} // namespace loco
+
+/**
+ * ConstGen
+ */
+namespace loco
+{
+
+template <DataType DT> uint32_t ConstGen::size(void) const
+{
+ assert(dtype() == DT);
+ assert(_data.size() % sizeof(typename DataTypeImpl<DT>::Type) == 0);
+ return _data.size() / sizeof(typename DataTypeImpl<DT>::Type);
+}
+
+template <DataType DT> void ConstGen::size(uint32_t l)
+{
+ assert(dtype() == DT);
+ _data.resize(l * sizeof(typename DataTypeImpl<DT>::Type));
+}
+
+template <DataType DT> const typename DataTypeImpl<DT>::Type &ConstGen::at(uint32_t n) const
+{
+ assert(dtype() == DT);
+ assert(n < size<DT>());
+ return *(reinterpret_cast<const typename DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+template <DataType DT> typename DataTypeImpl<DT>::Type &ConstGen::at(uint32_t n)
+{
+ assert(dtype() == DT);
+ assert(n < size<DT>());
+ return *(reinterpret_cast<typename DataTypeImpl<DT>::Type *>(_data.data()) + n);
+}
+
+#define INSTANTIATE(DT) \
+ template uint32_t ConstGen::size<DT>(void) const; \
+ template void ConstGen::size<DT>(uint32_t); \
+ template const typename DataTypeImpl<DT>::Type &ConstGen::at<DT>(uint32_t) const; \
+ template typename DataTypeImpl<DT>::Type &ConstGen::at<DT>(uint32_t);
+
+INSTANTIATE(DataType::S32);
+INSTANTIATE(DataType::FLOAT32);
+
+#undef INSTANTIATE
+
+} // namespace loco
+
+/**
+ * TensorBroadcast
+ */
+namespace loco
+{
+
+bool TensorBroadcast::Mapping::defined(const TensorAxis &axis) const
+{
+ return _content.find(axis) != _content.end();
+}
+
+const Dimension &TensorBroadcast::Mapping::dim(const TensorAxis &axis) const
+{
+ return _content.at(axis);
+}
+
+Dimension &TensorBroadcast::Mapping::dim(const TensorAxis &axis) { return _content[axis]; }
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Nodes.test.cpp b/compiler/loco/src/IR/Nodes.test.cpp
new file mode 100644
index 000000000..cd51f46c0
--- /dev/null
+++ b/compiler/loco/src/IR/Nodes.test.cpp
@@ -0,0 +1,588 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Nodes.h"
+#include "loco/IR/CanonicalDialect.h"
+
+#include <gtest/gtest.h>
+
+TEST(PushTest, constructor)
+{
+ loco::Push push_node;
+
+ ASSERT_EQ(push_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(push_node.opcode(), loco::CanonicalOpcode::Push);
+
+ ASSERT_FALSE(push_node.indexed());
+}
+
+TEST(PushTest, shape)
+{
+ const std::vector<uint32_t> dims{1, 8, 16, 3};
+
+ loco::Pull push_node;
+
+ push_node.shape({dims[0], dims[1], dims[2], dims[3]});
+
+ ASSERT_EQ(push_node.rank(), dims.size());
+ ASSERT_EQ(push_node.dim(0), dims[0]);
+ ASSERT_EQ(push_node.dim(1), dims[1]);
+ ASSERT_EQ(push_node.dim(2), dims[2]);
+ ASSERT_EQ(push_node.dim(3), dims[3]);
+}
+
+TEST(PullTest, constructor)
+{
+ loco::Pull pull_node;
+
+ ASSERT_EQ(pull_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(pull_node.opcode(), loco::CanonicalOpcode::Pull);
+
+ ASSERT_FALSE(pull_node.indexed());
+
+ ASSERT_EQ(pull_node.dtype(), loco::DataType::Unknown);
+ ASSERT_EQ(pull_node.rank(), 0);
+}
+
+TEST(PullTest, shape)
+{
+ const std::vector<uint32_t> dims{1, 8, 16, 3};
+
+ loco::Pull pull_node;
+
+ pull_node.shape({dims[0], dims[1], dims[2], dims[3]});
+
+ ASSERT_EQ(pull_node.rank(), dims.size());
+ ASSERT_EQ(pull_node.dim(0), dims[0]);
+ ASSERT_EQ(pull_node.dim(1), dims[1]);
+ ASSERT_EQ(pull_node.dim(2), dims[2]);
+ ASSERT_EQ(pull_node.dim(3), dims[3]);
+}
+
+TEST(ForwardTest, constructor)
+{
+ loco::Forward forward_node;
+
+ ASSERT_EQ(forward_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(forward_node.opcode(), loco::CanonicalOpcode::Forward);
+
+ ASSERT_EQ(forward_node.input(), nullptr);
+}
+
+TEST(ReLUTest, constructor)
+{
+ loco::ReLU relu_node;
+
+ ASSERT_EQ(relu_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(relu_node.opcode(), loco::CanonicalOpcode::ReLU);
+
+ ASSERT_EQ(relu_node.input(), nullptr);
+}
+
+TEST(ReLU6Test, constructor)
+{
+ loco::ReLU6 relu6_node;
+
+ ASSERT_EQ(relu6_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(relu6_node.opcode(), loco::CanonicalOpcode::ReLU6);
+
+ ASSERT_EQ(relu6_node.input(), nullptr);
+}
+
+TEST(ConstGenTest, constructor)
+{
+ loco::ConstGen constgen_node;
+
+ ASSERT_EQ(constgen_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(constgen_node.opcode(), loco::CanonicalOpcode::ConstGen);
+
+ ASSERT_EQ(constgen_node.dtype(), loco::DataType::Unknown);
+ ASSERT_EQ(constgen_node.rank(), 0);
+
+ constgen_node.dtype(loco::DataType::FLOAT32);
+ ASSERT_EQ(constgen_node.dtype(), loco::DataType::FLOAT32);
+
+ constgen_node.rank(2);
+ ASSERT_EQ(constgen_node.rank(), 2);
+
+ constgen_node.dim(0) = 2;
+ constgen_node.dim(1) = 3;
+
+ ASSERT_TRUE(constgen_node.dim(0).known());
+ ASSERT_TRUE(constgen_node.dim(1).known());
+
+ ASSERT_EQ(constgen_node.dim(0), 2);
+ ASSERT_EQ(constgen_node.dim(1), 3);
+
+ constgen_node.size<loco::DataType::FLOAT32>(6);
+
+ ASSERT_EQ(constgen_node.size<loco::DataType::FLOAT32>(), 6);
+
+ constgen_node.at<loco::DataType::FLOAT32>(0) = 0.0f; // Set 0,0
+ constgen_node.at<loco::DataType::FLOAT32>(1) = 1.0f; // Set 0,1
+ constgen_node.at<loco::DataType::FLOAT32>(2) = 2.0f; // Set 0,2
+ constgen_node.at<loco::DataType::FLOAT32>(3) = 3.0f; // Set 1,0
+ constgen_node.at<loco::DataType::FLOAT32>(4) = 4.0f; // Set 1,1
+ constgen_node.at<loco::DataType::FLOAT32>(5) = 5.0f; // Set 1,2
+
+ ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(0), 0.0f);
+ ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(1), 1.0f);
+ ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(2), 2.0f);
+ ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(3), 3.0f);
+ ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(4), 4.0f);
+ ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(5), 5.0f);
+}
+
+TEST(ConstGenTest, constructor_s32)
+{
+ loco::ConstGen constgen_node;
+
+ ASSERT_EQ(constgen_node.dtype(), loco::DataType::Unknown);
+ ASSERT_EQ(constgen_node.rank(), 0);
+
+ constgen_node.dtype(loco::DataType::S32);
+ ASSERT_EQ(constgen_node.dtype(), loco::DataType::S32);
+
+ constgen_node.rank(2);
+ ASSERT_EQ(constgen_node.rank(), 2);
+
+ constgen_node.dim(0) = 2;
+ constgen_node.dim(1) = 3;
+
+ ASSERT_TRUE(constgen_node.dim(0).known());
+ ASSERT_TRUE(constgen_node.dim(1).known());
+
+ ASSERT_EQ(constgen_node.dim(0), 2);
+ ASSERT_EQ(constgen_node.dim(1), 3);
+
+ constgen_node.size<loco::DataType::S32>(6);
+
+ ASSERT_EQ(constgen_node.size<loco::DataType::S32>(), 6);
+
+ constgen_node.at<loco::DataType::S32>(0) = 0; // Set 0,0
+ constgen_node.at<loco::DataType::S32>(1) = 1; // Set 0,1
+ constgen_node.at<loco::DataType::S32>(2) = 2; // Set 0,2
+ constgen_node.at<loco::DataType::S32>(3) = -3; // Set 1,0
+ constgen_node.at<loco::DataType::S32>(4) = -4; // Set 1,1
+ constgen_node.at<loco::DataType::S32>(5) = -5; // Set 1,2
+
+ ASSERT_EQ(constgen_node.at<loco::DataType::S32>(0), 0);
+ ASSERT_EQ(constgen_node.at<loco::DataType::S32>(1), 1);
+ ASSERT_EQ(constgen_node.at<loco::DataType::S32>(2), 2);
+ ASSERT_EQ(constgen_node.at<loco::DataType::S32>(3), -3);
+ ASSERT_EQ(constgen_node.at<loco::DataType::S32>(4), -4);
+ ASSERT_EQ(constgen_node.at<loco::DataType::S32>(5), -5);
+}
+
+TEST(MaxPool2DTest, constructor)
+{
+ loco::MaxPool2D maxpool_node;
+
+ ASSERT_EQ(maxpool_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(maxpool_node.opcode(), loco::CanonicalOpcode::MaxPool2D);
+
+ ASSERT_EQ(maxpool_node.ifm(), nullptr);
+
+ ASSERT_EQ(maxpool_node.pad()->top(), 0);
+ ASSERT_EQ(maxpool_node.pad()->bottom(), 0);
+ ASSERT_EQ(maxpool_node.pad()->left(), 0);
+ ASSERT_EQ(maxpool_node.pad()->right(), 0);
+
+ ASSERT_EQ(maxpool_node.window()->vertical(), 1);
+ ASSERT_EQ(maxpool_node.window()->horizontal(), 1);
+
+ ASSERT_EQ(maxpool_node.stride()->vertical(), 1);
+ ASSERT_EQ(maxpool_node.stride()->horizontal(), 1);
+}
+
+TEST(MaxPool2DTest, pad)
+{
+ const uint32_t t = 1;
+ const uint32_t b = 2;
+ const uint32_t l = 3;
+ const uint32_t r = 4;
+
+ loco::MaxPool2D maxpool_node;
+
+ maxpool_node.pad()->top(t);
+ ASSERT_EQ(maxpool_node.pad()->top(), t);
+
+ maxpool_node.pad()->bottom(b);
+ ASSERT_EQ(maxpool_node.pad()->bottom(), b);
+
+ maxpool_node.pad()->left(l);
+ ASSERT_EQ(maxpool_node.pad()->left(), l);
+
+ maxpool_node.pad()->right(r);
+ ASSERT_EQ(maxpool_node.pad()->right(), r);
+}
+
+TEST(AvgPool2DTest, constructor)
+{
+ loco::AvgPool2D avgpool_node;
+
+ ASSERT_EQ(avgpool_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(avgpool_node.opcode(), loco::CanonicalOpcode::AvgPool2D);
+
+ ASSERT_EQ(avgpool_node.ifm(), nullptr);
+
+ ASSERT_EQ(avgpool_node.convention(), loco::AvgPool2D::Convention::Unknown);
+
+ ASSERT_EQ(avgpool_node.pad()->top(), 0);
+ ASSERT_EQ(avgpool_node.pad()->bottom(), 0);
+ ASSERT_EQ(avgpool_node.pad()->left(), 0);
+ ASSERT_EQ(avgpool_node.pad()->right(), 0);
+
+ ASSERT_EQ(avgpool_node.window()->vertical(), 1);
+ ASSERT_EQ(avgpool_node.window()->horizontal(), 1);
+
+ ASSERT_EQ(avgpool_node.stride()->vertical(), 1);
+ ASSERT_EQ(avgpool_node.stride()->horizontal(), 1);
+}
+
+TEST(FeatureEncodeTest, constructor)
+{
+ loco::FeatureEncode feature_encode;
+
+ ASSERT_EQ(feature_encode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(feature_encode.opcode(), loco::CanonicalOpcode::FeatureEncode);
+
+ ASSERT_EQ(feature_encode.input(), nullptr);
+ ASSERT_EQ(feature_encode.encoder(), nullptr);
+}
+
+TEST(FeatureDecodeTest, constructor)
+{
+ loco::FeatureDecode feature_decode;
+
+ ASSERT_EQ(feature_decode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(feature_decode.opcode(), loco::CanonicalOpcode::FeatureDecode);
+
+ ASSERT_EQ(feature_decode.input(), nullptr);
+ ASSERT_EQ(feature_decode.decoder(), nullptr);
+}
+
+TEST(Reshape_Fixed_Test, constructor)
+{
+ loco::Reshape<loco::ReshapeType::Fixed> reshape;
+
+ ASSERT_EQ(reshape.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(reshape.opcode(), loco::CanonicalOpcode::FixedReshape);
+
+ ASSERT_EQ(reshape.rank(), 0);
+}
+
+TEST(Reshape_Fixed_Test, shape)
+{
+ loco::Reshape<loco::ReshapeType::Fixed> reshape;
+ reshape.shape({2, 3});
+
+ ASSERT_EQ(reshape.rank(), 2);
+ ASSERT_EQ(reshape.dim(0), 2);
+ ASSERT_EQ(reshape.dim(1), 3);
+}
+
+TEST(FilterEncodeTest, constructor)
+{
+ loco::FilterEncode filter_encode;
+
+ ASSERT_EQ(filter_encode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(filter_encode.opcode(), loco::CanonicalOpcode::FilterEncode);
+
+ ASSERT_EQ(filter_encode.input(), nullptr);
+ ASSERT_EQ(filter_encode.encoder(), nullptr);
+}
+
+TEST(FilterDecodeTest, constructor)
+{
+ loco::FilterDecode filter_decode;
+
+ ASSERT_EQ(filter_decode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(filter_decode.opcode(), loco::CanonicalOpcode::FilterDecode);
+
+ ASSERT_EQ(filter_decode.input(), nullptr);
+ ASSERT_EQ(filter_decode.decoder(), nullptr);
+}
+
+TEST(DepthwiseFilterEncodeTest, constructor)
+{
+ loco::DepthwiseFilterEncode dw_filter_encode;
+
+ ASSERT_EQ(dw_filter_encode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(dw_filter_encode.opcode(), loco::CanonicalOpcode::DepthwiseFilterEncode);
+
+ ASSERT_EQ(dw_filter_encode.input(), nullptr);
+ ASSERT_EQ(dw_filter_encode.encoder(), nullptr);
+}
+
+TEST(DepthwiseFilterDecodeTest, constructor)
+{
+ loco::DepthwiseFilterDecode dw_filter_decode;
+
+ ASSERT_EQ(dw_filter_decode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(dw_filter_decode.opcode(), loco::CanonicalOpcode::DepthwiseFilterDecode);
+
+ ASSERT_EQ(dw_filter_decode.input(), nullptr);
+ ASSERT_EQ(dw_filter_decode.decoder(), nullptr);
+}
+
+TEST(TensorConcatTest, constructor)
+{
+ loco::TensorConcat tensor_concat;
+
+ ASSERT_EQ(tensor_concat.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(tensor_concat.opcode(), loco::CanonicalOpcode::TensorConcat);
+
+ ASSERT_EQ(tensor_concat.lhs(), nullptr);
+ ASSERT_EQ(tensor_concat.rhs(), nullptr);
+ ASSERT_EQ(tensor_concat.axis(), 0);
+
+ tensor_concat.axis(3);
+ ASSERT_EQ(tensor_concat.axis(), 3);
+}
+
+TEST(Conv2DTest, constructor)
+{
+ loco::Conv2D conv2d;
+
+ ASSERT_EQ(conv2d.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(conv2d.opcode(), loco::CanonicalOpcode::Conv2D);
+
+ ASSERT_EQ(conv2d.ifm(), nullptr);
+ ASSERT_EQ(conv2d.ker(), nullptr);
+
+ ASSERT_NE(conv2d.pad(), nullptr);
+ ASSERT_EQ(conv2d.pad()->top(), 0);
+ ASSERT_EQ(conv2d.pad()->bottom(), 0);
+ ASSERT_EQ(conv2d.pad()->left(), 0);
+ ASSERT_EQ(conv2d.pad()->right(), 0);
+
+ ASSERT_NE(conv2d.stride(), nullptr);
+ ASSERT_EQ(conv2d.stride()->vertical(), 1);
+ ASSERT_EQ(conv2d.stride()->horizontal(), 1);
+}
+
+TEST(DepthwiseConv2DTest, constructor)
+{
+ loco::DepthwiseConv2D dw_conv2d;
+
+ ASSERT_EQ(dw_conv2d.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(dw_conv2d.opcode(), loco::CanonicalOpcode::DepthwiseConv2D);
+
+ ASSERT_EQ(dw_conv2d.ifm(), nullptr);
+ ASSERT_EQ(dw_conv2d.ker(), nullptr);
+
+ ASSERT_NE(dw_conv2d.pad(), nullptr);
+ ASSERT_EQ(dw_conv2d.pad()->top(), 0);
+ ASSERT_EQ(dw_conv2d.pad()->bottom(), 0);
+ ASSERT_EQ(dw_conv2d.pad()->left(), 0);
+ ASSERT_EQ(dw_conv2d.pad()->right(), 0);
+
+ ASSERT_NE(dw_conv2d.stride(), nullptr);
+ ASSERT_EQ(dw_conv2d.stride()->vertical(), 1);
+ ASSERT_EQ(dw_conv2d.stride()->horizontal(), 1);
+}
+
+TEST(TransposedConv2DTest, constructor)
+{
+ loco::TransposedConv2D tr_conv2d;
+
+ ASSERT_EQ(tr_conv2d.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(tr_conv2d.opcode(), loco::CanonicalOpcode::TransposedConv2D);
+
+ ASSERT_EQ(tr_conv2d.ifm(), nullptr);
+ ASSERT_EQ(tr_conv2d.ker(), nullptr);
+
+ ASSERT_NE(tr_conv2d.pad(), nullptr);
+ ASSERT_EQ(tr_conv2d.pad()->top(), 0);
+ ASSERT_EQ(tr_conv2d.pad()->bottom(), 0);
+ ASSERT_EQ(tr_conv2d.pad()->left(), 0);
+ ASSERT_EQ(tr_conv2d.pad()->right(), 0);
+
+ ASSERT_NE(tr_conv2d.stride(), nullptr);
+ ASSERT_EQ(tr_conv2d.stride()->vertical(), 1);
+ ASSERT_EQ(tr_conv2d.stride()->horizontal(), 1);
+}
+
+TEST(BiasEncodeTest, constructor)
+{
+ loco::BiasEncode bias_encode;
+
+ ASSERT_EQ(bias_encode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(bias_encode.opcode(), loco::CanonicalOpcode::BiasEncode);
+
+ ASSERT_EQ(bias_encode.input(), nullptr);
+}
+
+TEST(TensorBiasAddTest, constructor)
+{
+ loco::BiasAdd<loco::Domain::Tensor> bias_add;
+
+ ASSERT_EQ(bias_add.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(bias_add.opcode(), loco::CanonicalOpcode::TensorBiasAdd);
+
+ ASSERT_EQ(bias_add.value(), nullptr);
+ ASSERT_EQ(bias_add.bias(), nullptr);
+ ASSERT_EQ(bias_add.axis(), 0);
+}
+
+TEST(TensorBiasAddTest, alias)
+{
+ loco::TensorBiasAdd bias_add;
+
+ SUCCEED();
+}
+
+TEST(FeatureBiasAddTest, constructor)
+{
+ loco::BiasAdd<loco::Domain::Feature> bias_add;
+
+ ASSERT_EQ(bias_add.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(bias_add.opcode(), loco::CanonicalOpcode::FeatureBiasAdd);
+
+ ASSERT_EQ(bias_add.value(), nullptr);
+ ASSERT_EQ(bias_add.bias(), nullptr);
+}
+
+TEST(FeatureBiasAddTest, alias)
+{
+ loco::FeatureBiasAdd bias_add;
+
+ SUCCEED();
+}
+
+TEST(EltwiseAddTest, constructor)
+{
+ loco::EltwiseAdd eltwise_add;
+
+ SUCCEED();
+}
+
+TEST(EltwiseMaxTest, constructor)
+{
+ loco::EltwiseMax eltwise_max;
+
+ SUCCEED();
+}
+
+TEST(EltwiseMulTest, constructor)
+{
+ loco::EltwiseMul eltwise_mul;
+
+ SUCCEED();
+}
+
+TEST(EltwiseSubTest, constructor)
+{
+ loco::EltwiseSub eltwise_sub;
+
+ SUCCEED();
+}
+
+TEST(EltwiseDivTest, constructor)
+{
+ loco::EltwiseDiv eltwise_div;
+
+ SUCCEED();
+}
+
+TEST(EltwiseSqrtTest, constructor)
+{
+ loco::EltwiseSqrt sqrt_node;
+
+ ASSERT_EQ(sqrt_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(sqrt_node.opcode(), loco::CanonicalOpcode::EltwiseSqrt);
+
+ ASSERT_EQ(sqrt_node.input(), nullptr);
+}
+
+TEST(TensorBroadcastTest, constructor)
+{
+ loco::TensorBroadcast tensor_broadcast_node;
+
+ ASSERT_EQ(tensor_broadcast_node.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(tensor_broadcast_node.opcode(), loco::CanonicalOpcode::TensorBroadcast);
+
+ ASSERT_EQ(tensor_broadcast_node.input(), nullptr);
+}
+
+TEST(TensorBroadcastTest, mapping)
+{
+ loco::TensorBroadcast tensor_broadcast_node;
+
+ ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), false);
+
+ tensor_broadcast_node.mapping()->dim(0) = 3;
+
+ ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), true);
+ ASSERT_EQ(tensor_broadcast_node.mapping()->dim(0), 3);
+}
+
+TEST(MatrixEncodeTest, constructor)
+{
+ loco::MatrixEncode matrix_encode;
+
+ ASSERT_EQ(matrix_encode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(matrix_encode.opcode(), loco::CanonicalOpcode::MatrixEncode);
+
+ ASSERT_EQ(matrix_encode.input(), nullptr);
+}
+
+TEST(MatrixDecodeTest, constructor)
+{
+ loco::MatrixDecode matrix_decode;
+
+ ASSERT_EQ(matrix_decode.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(matrix_decode.opcode(), loco::CanonicalOpcode::MatrixDecode);
+
+ ASSERT_EQ(matrix_decode.input(), nullptr);
+}
+
+TEST(MatMulTest, constructor)
+{
+ loco::MatMul mat_mul;
+
+ ASSERT_EQ(mat_mul.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(mat_mul.opcode(), loco::CanonicalOpcode::MatMul);
+
+ ASSERT_EQ(mat_mul.lhs(), nullptr);
+ ASSERT_EQ(mat_mul.rhs(), nullptr);
+}
+
+TEST(TransposeTest, constructor)
+{
+ loco::TensorTranspose transpose;
+
+ ASSERT_EQ(transpose.dialect(), loco::CanonicalDialect::get());
+ ASSERT_EQ(transpose.opcode(), loco::CanonicalOpcode::TensorTranspose);
+
+ ASSERT_EQ(transpose.input(), nullptr);
+ ASSERT_EQ(transpose.perm()->size(), 0);
+}
+
+TEST(TransposeTest, perm)
+{
+ loco::TensorTranspose transpose;
+
+ transpose.perm()->size(3);
+ transpose.perm()->axis(0) = 1;
+ transpose.perm()->axis(1) = 2;
+ transpose.perm()->axis(2) = 0;
+
+ ASSERT_EQ(transpose.perm()->axis(0), 1);
+ ASSERT_EQ(transpose.perm()->axis(1), 2);
+ ASSERT_EQ(transpose.perm()->axis(2), 0);
+}
diff --git a/compiler/loco/src/IR/Padding2D.test.cpp b/compiler/loco/src/IR/Padding2D.test.cpp
new file mode 100644
index 000000000..2e3d4af87
--- /dev/null
+++ b/compiler/loco/src/IR/Padding2D.test.cpp
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Padding2D.h"
+
+#include <gtest/gtest.h>
+
+TEST(PadTest, default_constructor_2D)
+{
+ loco::Padding2D pad;
+
+ ASSERT_EQ(pad.top(), 0);
+ ASSERT_EQ(pad.bottom(), 0);
+ ASSERT_EQ(pad.left(), 0);
+ ASSERT_EQ(pad.right(), 0);
+}
diff --git a/compiler/loco/src/IR/PaddingND.test.cpp b/compiler/loco/src/IR/PaddingND.test.cpp
new file mode 100644
index 000000000..0e20406ff
--- /dev/null
+++ b/compiler/loco/src/IR/PaddingND.test.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/PaddingND.h"
+
+#include <gtest/gtest.h>
+
+TEST(PaddingNDTest, default_constructor_ND)
+{
+ loco::PaddingND padding;
+
+ padding.rank(1);
+ padding.front(0) = 1;
+ padding.back(0) = 2;
+
+ ASSERT_EQ(padding.rank(), 1);
+ ASSERT_EQ(padding.front(0), 1);
+ ASSERT_EQ(padding.back(0), 2);
+}
diff --git a/compiler/loco/src/IR/PermutingCodec.cpp b/compiler/loco/src/IR/PermutingCodec.cpp
new file mode 100644
index 000000000..2857e5e28
--- /dev/null
+++ b/compiler/loco/src/IR/PermutingCodec.cpp
@@ -0,0 +1,630 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/PermutingCodec.h"
+
+#include <stdex/Memory.h>
+
+#include <cassert>
+#include <set>
+#include <stdexcept>
+
+/**
+ * Feature Domain
+ */
+namespace
+{
+
+using loco::FeatureAxis;
+
+inline bool valid(const FeatureAxis &axis)
+{
+ switch (axis)
+ {
+ case FeatureAxis::Count:
+ return true;
+ case FeatureAxis::Depth:
+ return true;
+ case FeatureAxis::Height:
+ return true;
+ case FeatureAxis::Width:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+inline bool valid(const loco::Permutation<loco::Domain::Feature> &perm)
+{
+ auto check = [&perm](FeatureAxis axis_f) {
+ if (!perm.mapped(axis_f))
+ return false;
+ return perm.axis(axis_f) < 4;
+ };
+
+ if (!check(FeatureAxis::Count))
+ return false;
+ if (!check(FeatureAxis::Depth))
+ return false;
+ if (!check(FeatureAxis::Height))
+ return false;
+ if (!check(FeatureAxis::Width))
+ return false;
+
+ // Check whether tensor axes are all distinct
+ std::set<loco::TensorAxis> values;
+
+ values.insert(perm[FeatureAxis::Count]);
+ values.insert(perm[FeatureAxis::Depth]);
+ values.insert(perm[FeatureAxis::Height]);
+ values.insert(perm[FeatureAxis::Width]);
+
+ return values.size() == 4;
+}
+
+} // namespace
+
+namespace loco
+{
+
+//
+// Permutation
+//
+bool Permutation<Domain::Feature>::mapped(const FeatureAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid feature axis");
+ return _map.find(axis_f) != _map.end();
+}
+
+uint32_t Permutation<Domain::Feature>::axis(const FeatureAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid feature axis");
+ assert(mapped(axis_f) && "unmapped feature axis");
+ return _map.at(axis_f);
+}
+
+uint32_t &Permutation<Domain::Feature>::axis(const FeatureAxis &axis_f)
+{
+ assert(valid(axis_f) && "invalid feature axis");
+ return _map[axis_f];
+}
+
+//
+// Permuting Encoder
+//
+FeatureShape PermutingEncoder<Domain::Feature>::shape(const TensorShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ FeatureShape out;
+
+ out.count() = in.dim(_perm[FeatureAxis::Count]);
+ out.depth() = in.dim(_perm[FeatureAxis::Depth]);
+ out.height() = in.dim(_perm[FeatureAxis::Height]);
+ out.width() = in.dim(_perm[FeatureAxis::Width]);
+
+ return out;
+}
+
+TensorIndex PermutingEncoder<Domain::Feature>::value(const FeatureIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorIndex out;
+
+ out.resize(4);
+
+ out.at(_perm[FeatureAxis::Count]) = in.batch();
+ out.at(_perm[FeatureAxis::Depth]) = in.channel();
+ out.at(_perm[FeatureAxis::Height]) = in.row();
+ out.at(_perm[FeatureAxis::Width]) = in.column();
+
+ return out;
+}
+
+std::unique_ptr<FeatureEncoder> PermutingEncoder<Domain::Feature>::clone(void) const
+{
+ return stdex::make_unique<PermutingEncoder<Domain::Feature>>(_perm);
+}
+
+bool PermutingEncoder<Domain::Feature>::valid(void) const { return ::valid(_perm); }
+
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::Feature>::shape(const FeatureShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorShape out;
+
+ out.rank(4);
+
+ out.dim(_perm[FeatureAxis::Count]) = in.count();
+ out.dim(_perm[FeatureAxis::Depth]) = in.depth();
+ out.dim(_perm[FeatureAxis::Height]) = in.height();
+ out.dim(_perm[FeatureAxis::Width]) = in.width();
+
+ return out;
+}
+
+FeatureIndex PermutingDecoder<Domain::Feature>::value(const TensorIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ FeatureIndex out;
+
+ out.batch() = in.at(_perm[FeatureAxis::Count]);
+ out.channel() = in.at(_perm[FeatureAxis::Depth]);
+ out.row() = in.at(_perm[FeatureAxis::Height]);
+ out.column() = in.at(_perm[FeatureAxis::Width]);
+
+ return out;
+}
+
+std::unique_ptr<FeatureDecoder> PermutingDecoder<Domain::Feature>::clone(void) const
+{
+ return stdex::make_unique<PermutingDecoder<Domain::Feature>>(_perm);
+}
+
+bool PermutingDecoder<Domain::Feature>::valid(void) const { return ::valid(_perm); }
+
+} // namespace loco
+
+/**
+ * Filter Domain
+ */
+namespace
+{
+
+using loco::FilterAxis;
+
+inline bool valid(const FilterAxis &axis)
+{
+ switch (axis)
+ {
+ case FilterAxis::Count:
+ return true;
+ case FilterAxis::Depth:
+ return true;
+ case FilterAxis::Height:
+ return true;
+ case FilterAxis::Width:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+inline bool valid(const loco::Permutation<loco::Domain::Filter> &perm)
+{
+ auto check = [&perm](FilterAxis axis_f) {
+ if (!perm.mapped(axis_f))
+ return false;
+ return perm.axis(axis_f) < 4;
+ };
+
+ if (!check(FilterAxis::Count))
+ return false;
+ if (!check(FilterAxis::Depth))
+ return false;
+ if (!check(FilterAxis::Height))
+ return false;
+ if (!check(FilterAxis::Width))
+ return false;
+
+ // Check whether tensor axes are all distinct
+ std::set<loco::TensorAxis> values;
+
+ values.insert(perm[FilterAxis::Count]);
+ values.insert(perm[FilterAxis::Depth]);
+ values.insert(perm[FilterAxis::Height]);
+ values.insert(perm[FilterAxis::Width]);
+
+ return values.size() == 4;
+}
+
+} // namespace
+
+namespace loco
+{
+
+//
+// Permutation
+//
+bool Permutation<Domain::Filter>::mapped(const FilterAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid filter axis");
+ return _map.find(axis_f) != _map.end();
+}
+
+const uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid filter axis");
+ assert(mapped(axis_f) && "unmapped filter axis");
+ return _map.at(axis_f);
+}
+
+uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f)
+{
+ assert(valid(axis_f) && "invalid filter axis");
+ return _map[axis_f];
+}
+
+//
+// Permuting Encoder
+//
+FilterShape PermutingEncoder<Domain::Filter>::shape(const TensorShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ FilterShape out;
+
+ out.count() = in.dim(_perm[FilterAxis::Count]);
+ out.depth() = in.dim(_perm[FilterAxis::Depth]);
+ out.height() = in.dim(_perm[FilterAxis::Height]);
+ out.width() = in.dim(_perm[FilterAxis::Width]);
+
+ return out;
+}
+
+TensorIndex PermutingEncoder<Domain::Filter>::value(const FilterIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorIndex out;
+
+ out.resize(4);
+
+ out.at(_perm[FilterAxis::Count]) = in.nth();
+ out.at(_perm[FilterAxis::Depth]) = in.channel();
+ out.at(_perm[FilterAxis::Height]) = in.row();
+ out.at(_perm[FilterAxis::Width]) = in.column();
+
+ return out;
+}
+
+bool PermutingEncoder<Domain::Filter>::valid(void) const { return ::valid(_perm); }
+
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::Filter>::shape(const FilterShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorShape out;
+
+ out.rank(4);
+ out.dim(_perm[FilterAxis::Count]) = in.count();
+ out.dim(_perm[FilterAxis::Depth]) = in.depth();
+ out.dim(_perm[FilterAxis::Height]) = in.height();
+ out.dim(_perm[FilterAxis::Width]) = in.width();
+
+ return out;
+}
+
+FilterIndex PermutingDecoder<Domain::Filter>::value(const TensorIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ FilterIndex out;
+
+ out.nth() = in.at(_perm[FilterAxis::Count]);
+ out.channel() = in.at(_perm[FilterAxis::Depth]);
+ out.row() = in.at(_perm[FilterAxis::Height]);
+ out.column() = in.at(_perm[FilterAxis::Width]);
+
+ return out;
+}
+
+bool PermutingDecoder<Domain::Filter>::valid(void) const { return ::valid(_perm); }
+
+} // namespace loco
+
+/**
+ * DepthwiseFilter Domain
+ */
+namespace
+{
+
+using loco::DepthwiseFilterAxis;
+
+inline bool valid(const DepthwiseFilterAxis &axis)
+{
+ switch (axis)
+ {
+ case DepthwiseFilterAxis::Depth:
+ return true;
+ case DepthwiseFilterAxis::Multiplier:
+ return true;
+ case DepthwiseFilterAxis::Height:
+ return true;
+ case DepthwiseFilterAxis::Width:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+inline bool valid(const loco::Permutation<loco::Domain::DepthwiseFilter> &perm)
+{
+ auto check = [&perm](DepthwiseFilterAxis axis_f) {
+ if (!perm.mapped(axis_f))
+ return false;
+ return perm.axis(axis_f) < 4;
+ };
+
+ if (!check(DepthwiseFilterAxis::Depth))
+ return false;
+ if (!check(DepthwiseFilterAxis::Multiplier))
+ return false;
+ if (!check(DepthwiseFilterAxis::Height))
+ return false;
+ if (!check(DepthwiseFilterAxis::Width))
+ return false;
+
+ // Check whether tensor axes are all distinct
+ std::set<loco::TensorAxis> values;
+
+ values.insert(perm[DepthwiseFilterAxis::Depth]);
+ values.insert(perm[DepthwiseFilterAxis::Multiplier]);
+ values.insert(perm[DepthwiseFilterAxis::Height]);
+ values.insert(perm[DepthwiseFilterAxis::Width]);
+
+ return values.size() == 4;
+}
+
+} // namespace
+
+namespace loco
+{
+
+//
+// Permutation
+//
+bool Permutation<Domain::DepthwiseFilter>::mapped(const DepthwiseFilterAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid depthwise filter axis");
+ return _map.find(axis_f) != _map.end();
+}
+
+const uint32_t &Permutation<Domain::DepthwiseFilter>::axis(const DepthwiseFilterAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid depthwise filter axis");
+ assert(mapped(axis_f) && "unmapped depthwise filter axis");
+ return _map.at(axis_f);
+}
+
+uint32_t &Permutation<Domain::DepthwiseFilter>::axis(const DepthwiseFilterAxis &axis_f)
+{
+ assert(valid(axis_f) && "invalid depthwise filter axis");
+ return _map[axis_f];
+}
+
+//
+// Permuting Encoder
+//
+DepthwiseFilterShape PermutingEncoder<Domain::DepthwiseFilter>::shape(const TensorShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ DepthwiseFilterShape out;
+
+ out.depth() = in.dim(_perm[DepthwiseFilterAxis::Depth]);
+ out.multiplier() = in.dim(_perm[DepthwiseFilterAxis::Multiplier]);
+ out.height() = in.dim(_perm[DepthwiseFilterAxis::Height]);
+ out.width() = in.dim(_perm[DepthwiseFilterAxis::Width]);
+
+ return out;
+}
+
+TensorIndex PermutingEncoder<Domain::DepthwiseFilter>::value(const DepthwiseFilterIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorIndex out;
+
+ out.resize(4);
+
+ out.at(_perm[DepthwiseFilterAxis::Depth]) = in.channel();
+ out.at(_perm[DepthwiseFilterAxis::Multiplier]) = in.nth();
+ out.at(_perm[DepthwiseFilterAxis::Height]) = in.row();
+ out.at(_perm[DepthwiseFilterAxis::Width]) = in.column();
+
+ return out;
+}
+
+bool PermutingEncoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); }
+
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::DepthwiseFilter>::shape(const DepthwiseFilterShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorShape out;
+ out.rank(4);
+
+ out.dim(_perm[DepthwiseFilterAxis::Depth]) = in.depth();
+ out.dim(_perm[DepthwiseFilterAxis::Multiplier]) = in.multiplier();
+ out.dim(_perm[DepthwiseFilterAxis::Height]) = in.height();
+ out.dim(_perm[DepthwiseFilterAxis::Width]) = in.width();
+
+ return out;
+}
+
+DepthwiseFilterIndex PermutingDecoder<Domain::DepthwiseFilter>::value(const TensorIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+ assert(in.rank() == 4);
+
+ DepthwiseFilterIndex out;
+
+ out.channel() = in.at(_perm[DepthwiseFilterAxis::Depth]);
+ out.nth() = in.at(_perm[DepthwiseFilterAxis::Multiplier]);
+ out.row() = in.at(_perm[DepthwiseFilterAxis::Height]);
+ out.column() = in.at(_perm[DepthwiseFilterAxis::Width]);
+
+ return out;
+}
+
+bool PermutingDecoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); }
+
+} // namespace loco
+
+/**
+ * Matrix Domain
+ */
+namespace
+{
+
+using loco::MatrixAxis;
+
+inline bool valid(const MatrixAxis &axis)
+{
+ switch (axis)
+ {
+ case MatrixAxis::Height:
+ return true;
+ case MatrixAxis::Width:
+ return true;
+ default:
+ break;
+ }
+
+ return false;
+}
+
+inline bool valid(const loco::Permutation<loco::Domain::Matrix> &perm)
+{
+ auto check = [&perm](MatrixAxis axis_f) {
+ if (!perm.mapped(axis_f))
+ return false;
+ return perm.axis(axis_f) < 2;
+ };
+
+ if (!check(MatrixAxis::Height))
+ return false;
+ if (!check(MatrixAxis::Width))
+ return false;
+
+ // Check whether tensor axes are all distinct
+ std::set<loco::TensorAxis> values;
+
+ values.insert(perm[MatrixAxis::Height]);
+ values.insert(perm[MatrixAxis::Width]);
+
+ return values.size() == 2;
+}
+
+} // namespace
+
+namespace loco
+{
+
+//
+// Permutation
+//
+bool Permutation<Domain::Matrix>::mapped(const MatrixAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid matrix axis");
+ return _map.find(axis_f) != _map.end();
+}
+
+uint32_t Permutation<Domain::Matrix>::axis(const MatrixAxis &axis_f) const
+{
+ assert(valid(axis_f) && "invalid matrix axis");
+ assert(mapped(axis_f) && "unmapped matrix axis");
+ return _map.at(axis_f);
+}
+
+uint32_t &Permutation<Domain::Matrix>::axis(const MatrixAxis &axis_f)
+{
+ assert(valid(axis_f) && "invalid matrix axis");
+ return _map[axis_f];
+}
+
+//
+// Permuting Encoder
+//
+MatrixShape PermutingEncoder<Domain::Matrix>::shape(const TensorShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ MatrixShape out;
+
+ out.height() = in.dim(_perm[MatrixAxis::Height]);
+ out.width() = in.dim(_perm[MatrixAxis::Width]);
+
+ return out;
+}
+
+TensorIndex PermutingEncoder<Domain::Matrix>::value(const MatrixIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorIndex out;
+
+ out.resize(2);
+
+ out.at(_perm[MatrixAxis::Height]) = in.row();
+ out.at(_perm[MatrixAxis::Width]) = in.column();
+
+ return out;
+}
+
+bool PermutingEncoder<Domain::Matrix>::valid(void) const { return ::valid(_perm); }
+
+//
+// Permuting Decoder
+//
+TensorShape PermutingDecoder<Domain::Matrix>::shape(const MatrixShape &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ TensorShape out;
+
+ out.rank(2);
+
+ out.dim(_perm[MatrixAxis::Height]) = in.height();
+ out.dim(_perm[MatrixAxis::Width]) = in.width();
+
+ return out;
+}
+
+MatrixIndex PermutingDecoder<Domain::Matrix>::value(const TensorIndex &in) const
+{
+ assert(valid() && "invalid permutation");
+
+ MatrixIndex out;
+
+ out.row() = in.at(_perm[MatrixAxis::Height]);
+ out.column() = in.at(_perm[MatrixAxis::Width]);
+
+ return out;
+}
+
+bool PermutingDecoder<Domain::Matrix>::valid(void) const { return ::valid(_perm); }
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/PermutingCodec.test.cpp b/compiler/loco/src/IR/PermutingCodec.test.cpp
new file mode 100644
index 000000000..2eff286d0
--- /dev/null
+++ b/compiler/loco/src/IR/PermutingCodec.test.cpp
@@ -0,0 +1,553 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/PermutingCodec.h"
+
+#include <gtest/gtest.h>
+
+using namespace loco;
+
+TEST(PermutationTest, feature)
+{
+ Permutation<Domain::Feature> perm;
+
+ // All values are invalid at the beginning
+ ASSERT_FALSE(perm.mapped(FeatureAxis::Count));
+ ASSERT_FALSE(perm.mapped(FeatureAxis::Depth));
+ ASSERT_FALSE(perm.mapped(FeatureAxis::Height));
+ ASSERT_FALSE(perm.mapped(FeatureAxis::Width));
+
+ // Update mapping
+ perm[FeatureAxis::Count] = 5;
+ perm[FeatureAxis::Depth] = 6;
+ perm[FeatureAxis::Height] = 7;
+ perm[FeatureAxis::Width] = 8;
+
+ // Now perm has a mapping for all the axes
+ ASSERT_TRUE(perm.mapped(FeatureAxis::Count));
+ ASSERT_TRUE(perm.mapped(FeatureAxis::Depth));
+ ASSERT_TRUE(perm.mapped(FeatureAxis::Height));
+ ASSERT_TRUE(perm.mapped(FeatureAxis::Width));
+
+ // Check the value
+ ASSERT_EQ(perm[FeatureAxis::Count], 5);
+ ASSERT_EQ(perm[FeatureAxis::Depth], 6);
+ ASSERT_EQ(perm[FeatureAxis::Height], 7);
+ ASSERT_EQ(perm[FeatureAxis::Width], 8);
+}
+
+TEST(PermutationTest, filter)
+{
+ Permutation<Domain::Filter> perm;
+
+ // All values are invalid at the beginning
+ ASSERT_FALSE(perm.mapped(FilterAxis::Count));
+ ASSERT_FALSE(perm.mapped(FilterAxis::Depth));
+ ASSERT_FALSE(perm.mapped(FilterAxis::Height));
+ ASSERT_FALSE(perm.mapped(FilterAxis::Width));
+
+ // Update mapping
+ perm[FilterAxis::Count] = 5;
+ perm[FilterAxis::Depth] = 6;
+ perm[FilterAxis::Height] = 7;
+ perm[FilterAxis::Width] = 8;
+
+ // Now perm has a mapping for all the axes
+ ASSERT_TRUE(perm.mapped(FilterAxis::Count));
+ ASSERT_TRUE(perm.mapped(FilterAxis::Depth));
+ ASSERT_TRUE(perm.mapped(FilterAxis::Height));
+ ASSERT_TRUE(perm.mapped(FilterAxis::Width));
+
+ // Check the value
+ ASSERT_EQ(perm[FilterAxis::Count], 5);
+ ASSERT_EQ(perm[FilterAxis::Depth], 6);
+ ASSERT_EQ(perm[FilterAxis::Height], 7);
+ ASSERT_EQ(perm[FilterAxis::Width], 8);
+}
+
+TEST(PermutationTest, depthwise_filter)
+{
+ Permutation<Domain::DepthwiseFilter> perm;
+
+ // All values are invalid at the beginning
+ ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Depth));
+ ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Multiplier));
+ ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Height));
+ ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Width));
+
+ // Update mapping
+ perm[DepthwiseFilterAxis::Depth] = 5;
+ perm[DepthwiseFilterAxis::Multiplier] = 6;
+ perm[DepthwiseFilterAxis::Height] = 7;
+ perm[DepthwiseFilterAxis::Width] = 8;
+
+ // Now perm has a mapping for all the axes
+ ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Depth));
+ ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Multiplier));
+ ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Height));
+ ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Width));
+
+ // Check the value
+ ASSERT_EQ(perm[DepthwiseFilterAxis::Depth], 5);
+ ASSERT_EQ(perm[DepthwiseFilterAxis::Multiplier], 6);
+ ASSERT_EQ(perm[DepthwiseFilterAxis::Height], 7);
+ ASSERT_EQ(perm[DepthwiseFilterAxis::Width], 8);
+}
+
+TEST(PermutingEncoderTest, feature)
+{
+ PermutingEncoder<Domain::Feature> enc;
+
+ // Encoder is invalid at the beginning
+ ASSERT_FALSE(enc.valid());
+
+ // Set "invalid" mapping
+ enc.perm()->axis(FeatureAxis::Count) = 0;
+ enc.perm()->axis(FeatureAxis::Depth) = 6;
+ enc.perm()->axis(FeatureAxis::Height) = 1;
+ enc.perm()->axis(FeatureAxis::Width) = 2;
+
+ // Encoder is still invalid
+ ASSERT_FALSE(enc.valid());
+
+ // Set another "invalid" mapping
+ enc.perm()->axis(FeatureAxis::Depth) = 1;
+
+ // Encoder is still invalid
+ ASSERT_FALSE(enc.valid());
+
+ // Set "valid" mapping
+ enc.perm()->axis(FeatureAxis::Depth) = 3;
+
+ // Encoder is now valid
+ ASSERT_TRUE(enc.valid());
+
+ // Let's test with a HD (1280x720) RGB image
+ TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 1; // COUNT
+ tensor_shape.dim(1) = 720; // HEIGHT
+ tensor_shape.dim(2) = 1280; // WIDTH
+ tensor_shape.dim(3) = 3; // DEPTH
+
+ // Get the feature shape corresponding to a given image
+ auto feature_shape = enc.shape(tensor_shape);
+
+ ASSERT_EQ(feature_shape.count(), 1);
+ ASSERT_EQ(feature_shape.depth(), 3);
+ ASSERT_EQ(feature_shape.height(), 720);
+ ASSERT_EQ(feature_shape.width(), 1280);
+
+ // Let's find a source tensor index!
+ FeatureIndex feature_index;
+
+ feature_index.batch() = 0;
+ feature_index.channel() = 1;
+ feature_index.row() = 2;
+ feature_index.column() = 3;
+
+ auto tensor_index = enc.value(feature_index);
+
+ ASSERT_EQ(tensor_index.at(0), 0); // BATCH(COUNT)
+ ASSERT_EQ(tensor_index.at(1), 2); // ROW(HEIGHT)
+ ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH)
+ ASSERT_EQ(tensor_index.at(3), 1); // CHANNEL(DEPTH)
+}
+
+TEST(PermutingEncoderTest, feature_clone)
+{
+ PermutingEncoder<Domain::Feature> src_enc;
+
+ auto src_perm = src_enc.perm();
+
+ src_perm->axis(FeatureAxis::Count) = 0;
+ src_perm->axis(FeatureAxis::Depth) = 3;
+ src_perm->axis(FeatureAxis::Height) = 1;
+ src_perm->axis(FeatureAxis::Width) = 2;
+
+ auto dst_enc = src_enc.clone();
+ auto dst_perm = dynamic_cast<PermutingEncoder<Domain::Feature> *>(dst_enc.get())->perm();
+
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count));
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth));
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height));
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width));
+
+ // Update on cloned encoder SHOULD NOT affect the original encoder
+ dst_perm->axis(FeatureAxis::Height) += 1;
+
+ EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
+}
+
+TEST(PermutingEncoderTest, filter)
+{
+ PermutingEncoder<Domain::Filter> enc;
+
+ // Encoder is invalid at the beginning
+ ASSERT_FALSE(enc.valid());
+
+ // Set "invalid" mapping
+ enc.perm()->axis(FilterAxis::Count) = 0;
+ enc.perm()->axis(FilterAxis::Depth) = 6;
+ enc.perm()->axis(FilterAxis::Height) = 1;
+ enc.perm()->axis(FilterAxis::Width) = 2;
+
+ // Encoder is still invalid
+ ASSERT_FALSE(enc.valid());
+
+ // Set another "invalid" mapping
+ enc.perm()->axis(FilterAxis::Depth) = 1;
+
+ // Encoder is still invalid
+ ASSERT_FALSE(enc.valid());
+
+ // Set "valid" mapping
+ enc.perm()->axis(FilterAxis::Depth) = 3;
+
+ // Encoder is now valid
+ ASSERT_TRUE(enc.valid());
+
+ TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 8; // COUNT
+ tensor_shape.dim(1) = 1; // HEIGHT
+ tensor_shape.dim(2) = 7; // WIDTH
+ tensor_shape.dim(3) = 4; // DEPTH
+
+ // Get the corresponding filter shape
+ auto filter_shape = enc.shape(tensor_shape);
+
+ ASSERT_EQ(filter_shape.count(), 8);
+ ASSERT_EQ(filter_shape.depth(), 4);
+ ASSERT_EQ(filter_shape.height(), 1);
+ ASSERT_EQ(filter_shape.width(), 7);
+
+ // Let's find a source tensor index!
+ FilterIndex filter_index;
+
+ filter_index.nth() = 1;
+ filter_index.channel() = 2;
+ filter_index.row() = 0;
+ filter_index.column() = 3;
+
+ auto tensor_index = enc.value(filter_index);
+
+ ASSERT_EQ(tensor_index.at(0), 1); // NTH(COUNT)
+ ASSERT_EQ(tensor_index.at(1), 0); // ROW(HEIGHT)
+ ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH)
+ ASSERT_EQ(tensor_index.at(3), 2); // CHANNEL(DEPTH)
+}
+
+TEST(PermutingEncoderTest, depthwise_filter)
+{
+ PermutingEncoder<Domain::DepthwiseFilter> enc;
+
+ // Encoder is invalid at the beginning
+ ASSERT_FALSE(enc.valid());
+
+ // Set "invalid" mapping
+ enc.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
+ enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
+ enc.perm()->axis(DepthwiseFilterAxis::Height) = 1;
+ enc.perm()->axis(DepthwiseFilterAxis::Width) = 2;
+
+ // Encoder is still invalid
+ ASSERT_FALSE(enc.valid());
+
+ // Set another "invalid" mapping
+ enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
+
+ // Encoder is still invalid
+ ASSERT_FALSE(enc.valid());
+
+ // Set "valid" mapping
+ enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
+
+ // Encoder is now valid
+ ASSERT_TRUE(enc.valid());
+
+ TensorShape tensor_shape;
+
+ tensor_shape.rank(4);
+ tensor_shape.dim(0) = 8; // DEPTH
+ tensor_shape.dim(1) = 1; // HEIGHT
+ tensor_shape.dim(2) = 7; // WIDTH
+ tensor_shape.dim(3) = 4; // MULTIPLIER
+
+ // Get the corresponding depthwise filter shape
+ auto filter_shape = enc.shape(tensor_shape);
+
+ ASSERT_EQ(filter_shape.depth(), 8);
+ ASSERT_EQ(filter_shape.multiplier(), 4);
+ ASSERT_EQ(filter_shape.height(), 1);
+ ASSERT_EQ(filter_shape.width(), 7);
+
+ // Let's find a source tensor index!
+ DepthwiseFilterIndex filter_index;
+
+ filter_index.channel() = 1;
+ filter_index.nth() = 2;
+ filter_index.row() = 0;
+ filter_index.column() = 3;
+
+ auto tensor_index = enc.value(filter_index);
+
+ ASSERT_EQ(tensor_index.at(0), 1); // CHANNEL(DEPTH)
+ ASSERT_EQ(tensor_index.at(1), 0); // ROW(HEIGHT)
+ ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH)
+ ASSERT_EQ(tensor_index.at(3), 2); // NTH(MULTIPLIER)
+}
+
+TEST(PermutingEncoderTest, depthwisefilter_init)
+{
+ Permutation<Domain::DepthwiseFilter> src_perm;
+
+ src_perm.axis(DepthwiseFilterAxis::Multiplier) = 0;
+ src_perm.axis(DepthwiseFilterAxis::Depth) = 3;
+ src_perm.axis(DepthwiseFilterAxis::Height) = 1;
+ src_perm.axis(DepthwiseFilterAxis::Width) = 2;
+
+ PermutingEncoder<Domain::DepthwiseFilter> dst_enc{src_perm};
+ auto dst_perm = dst_enc.perm();
+
+ EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Multiplier),
+ src_perm.axis(DepthwiseFilterAxis::Multiplier));
+ EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Depth), src_perm.axis(DepthwiseFilterAxis::Depth));
+ EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height),
+ src_perm.axis(DepthwiseFilterAxis::Height));
+ EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Width), src_perm.axis(DepthwiseFilterAxis::Width));
+
+ // Update on dst perm SHOULD NOT affect the src perm
+ dst_perm->axis(DepthwiseFilterAxis::Height) += 1;
+
+ EXPECT_EQ(src_perm.axis(DepthwiseFilterAxis::Height), 1);
+ EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height), 2);
+}
+
+TEST(PermutingDecoderTest, feature)
+{
+ PermutingDecoder<Domain::Feature> dec;
+
+ // Decoder is invalid at the beginning
+ ASSERT_FALSE(dec.valid());
+
+ // Set "invalid" mapping
+ dec.perm()->axis(FeatureAxis::Count) = 0;
+ dec.perm()->axis(FeatureAxis::Depth) = 6;
+ dec.perm()->axis(FeatureAxis::Height) = 1;
+ dec.perm()->axis(FeatureAxis::Width) = 2;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set another "invalid" mapping
+ dec.perm()->axis(FeatureAxis::Depth) = 1;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set "valid" mapping
+ dec.perm()->axis(FeatureAxis::Depth) = 3;
+
+ // Decoder is now valid
+ ASSERT_TRUE(dec.valid());
+
+ // Let's test with a HD (1280x720) RGB image
+ FeatureShape feature_shape;
+
+ feature_shape.count() = 1;
+ feature_shape.depth() = 3;
+ feature_shape.height() = 720;
+ feature_shape.width() = 1280;
+
+ // Get the tensor shape corresponding to a given image
+ auto tensor_shape = dec.shape(feature_shape);
+
+ ASSERT_EQ(tensor_shape.rank(), 4);
+ ASSERT_EQ(tensor_shape.dim(0), 1); // COUNT
+ ASSERT_EQ(tensor_shape.dim(1), 720); // HEIGHT
+ ASSERT_EQ(tensor_shape.dim(2), 1280); // WIDTH
+ ASSERT_EQ(tensor_shape.dim(3), 3); // DEPTH
+
+ // Let's find a source feature index!
+ TensorIndex tensor_index;
+
+ tensor_index.resize(4);
+
+ tensor_index.at(0) = 0; // BATCH(COUNT)
+ tensor_index.at(3) = 1; // CHANNEL(DEPTH)
+ tensor_index.at(1) = 2; // ROW(HEIGHT)
+ tensor_index.at(2) = 3; // COLUMN(WIDTH)
+
+ auto feature_index = dec.value(tensor_index);
+
+ ASSERT_EQ(feature_index.batch(), 0);
+ ASSERT_EQ(feature_index.channel(), 1);
+ ASSERT_EQ(feature_index.row(), 2);
+ ASSERT_EQ(feature_index.column(), 3);
+}
+
+TEST(PermutingDecoderTest, feature_clone)
+{
+ PermutingDecoder<Domain::Feature> src_enc;
+
+ auto src_perm = src_enc.perm();
+
+ src_perm->axis(FeatureAxis::Count) = 0;
+ src_perm->axis(FeatureAxis::Depth) = 3;
+ src_perm->axis(FeatureAxis::Height) = 1;
+ src_perm->axis(FeatureAxis::Width) = 2;
+
+ auto dst_enc = src_enc.clone();
+ auto dst_perm = dynamic_cast<PermutingDecoder<Domain::Feature> *>(dst_enc.get())->perm();
+
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count));
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth));
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height));
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width));
+
+ // Update on cloned decoder SHOULD NOT affect the original decoder
+ dst_perm->axis(FeatureAxis::Height) += 1;
+
+ EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1);
+ EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2);
+}
+
+TEST(PermutingDecoderTest, filter)
+{
+ PermutingDecoder<Domain::Filter> dec;
+
+ // Decoder is invalid at the beginning
+ ASSERT_FALSE(dec.valid());
+
+ // Set "invalid" mapping
+ dec.perm()->axis(FilterAxis::Count) = 0;
+ dec.perm()->axis(FilterAxis::Depth) = 6;
+ dec.perm()->axis(FilterAxis::Height) = 1;
+ dec.perm()->axis(FilterAxis::Width) = 2;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set another "invalid" mapping
+ dec.perm()->axis(FilterAxis::Depth) = 1;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set "valid" mapping
+ dec.perm()->axis(FilterAxis::Depth) = 3;
+
+ // Decoder is now valid
+ ASSERT_TRUE(dec.valid());
+
+ // Let's test with a small filter
+ FilterShape filter_shape;
+
+ filter_shape.count() = 10;
+ filter_shape.depth() = 3;
+ filter_shape.height() = 6;
+ filter_shape.width() = 8;
+
+ // Get the tensor shape corresponding to a given image
+ auto tensor_shape = dec.shape(filter_shape);
+
+ ASSERT_EQ(tensor_shape.rank(), 4);
+ ASSERT_EQ(tensor_shape.dim(0), 10); // COUNT
+ ASSERT_EQ(tensor_shape.dim(1), 6); // HEIGHT
+ ASSERT_EQ(tensor_shape.dim(2), 8); // WIDTH
+ ASSERT_EQ(tensor_shape.dim(3), 3); // DEPTH
+
+ // Let's find a source filter index!
+ TensorIndex tensor_index;
+
+ tensor_index.resize(4);
+
+ tensor_index.at(0) = 0; // BATCH(COUNT)
+ tensor_index.at(3) = 1; // CHANNEL(DEPTH)
+ tensor_index.at(1) = 2; // ROW(HEIGHT)
+ tensor_index.at(2) = 3; // COLUMN(WIDTH)
+
+ auto filter_index = dec.value(tensor_index);
+
+ ASSERT_EQ(filter_index.nth(), 0);
+ ASSERT_EQ(filter_index.channel(), 1);
+ ASSERT_EQ(filter_index.row(), 2);
+ ASSERT_EQ(filter_index.column(), 3);
+}
+
+TEST(PermutingDecoderTest, depthwise_filter)
+{
+ PermutingDecoder<Domain::DepthwiseFilter> dec;
+
+ // Decoder is invalid at the beginning
+ ASSERT_FALSE(dec.valid());
+
+ // Set "invalid" mapping
+ dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0;
+ dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6;
+ dec.perm()->axis(DepthwiseFilterAxis::Height) = 1;
+ dec.perm()->axis(DepthwiseFilterAxis::Width) = 2;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set another "invalid" mapping
+ dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1;
+
+ // Decoder is still invalid
+ ASSERT_FALSE(dec.valid());
+
+ // Set "valid" mapping
+ dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3;
+
+ // Decoder is now valid
+ ASSERT_TRUE(dec.valid());
+
+ DepthwiseFilterShape dw_filter_shape;
+
+ dw_filter_shape.depth() = 8;
+ dw_filter_shape.multiplier() = 1;
+ dw_filter_shape.height() = 7;
+ dw_filter_shape.width() = 4;
+
+ // Get the corresponding depthwise filter shape
+ auto tensor_shape = dec.shape(dw_filter_shape);
+
+ ASSERT_EQ(tensor_shape.dim(0).value(), 8);
+ ASSERT_EQ(tensor_shape.dim(1).value(), 7);
+ ASSERT_EQ(tensor_shape.dim(2).value(), 4);
+ ASSERT_EQ(tensor_shape.dim(3).value(), 1);
+
+ // Let's find a source tensor index!
+ TensorIndex tensor_index;
+ tensor_index.resize(4);
+
+ tensor_index.at(0) = 4;
+ tensor_index.at(1) = 2;
+ tensor_index.at(2) = 1;
+ tensor_index.at(3) = 0;
+
+ auto dw_filter_index = dec.value(tensor_index);
+
+ ASSERT_EQ(dw_filter_index.channel(), 4);
+ ASSERT_EQ(dw_filter_index.nth(), 0);
+ ASSERT_EQ(dw_filter_index.row(), 2);
+ ASSERT_EQ(dw_filter_index.column(), 1);
+}
diff --git a/compiler/loco/src/IR/Stride.test.cpp b/compiler/loco/src/IR/Stride.test.cpp
new file mode 100644
index 000000000..60deb5c6f
--- /dev/null
+++ b/compiler/loco/src/IR/Stride.test.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Stride.h"
+
+#include <gtest/gtest.h>
+
+TEST(StrideTest, default_constructor_2D)
+{
+ loco::Stride<2> stride;
+
+ ASSERT_EQ(stride.vertical(), 1);
+ ASSERT_EQ(stride.horizontal(), 1);
+}
+
+TEST(StrideTest, setter_and_getter_2D)
+{
+ loco::Stride<2> stride;
+
+ stride.vertical(2);
+
+ ASSERT_EQ(stride.vertical(), 2);
+ ASSERT_EQ(stride.horizontal(), 1);
+
+ stride.horizontal(3);
+
+ ASSERT_EQ(stride.vertical(), 2);
+ ASSERT_EQ(stride.horizontal(), 3);
+}
diff --git a/compiler/loco/src/IR/TensorAxis.cpp b/compiler/loco/src/IR/TensorAxis.cpp
new file mode 100644
index 000000000..b083847fc
--- /dev/null
+++ b/compiler/loco/src/IR/TensorAxis.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/TensorAxis.h"
+
+// NOTE This file validates "TensorAxis.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/TensorAxisSet.cpp b/compiler/loco/src/IR/TensorAxisSet.cpp
new file mode 100644
index 000000000..c58237bf7
--- /dev/null
+++ b/compiler/loco/src/IR/TensorAxisSet.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/TensorAxisSet.h"
+
+// NOTE This file validates "TensorAxisSet.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/TensorIndex.cpp b/compiler/loco/src/IR/TensorIndex.cpp
new file mode 100644
index 000000000..cbd3698eb
--- /dev/null
+++ b/compiler/loco/src/IR/TensorIndex.cpp
@@ -0,0 +1,19 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/TensorIndex.h"
+
+// NOTE This file validates "TensorIndex.h". Please DO NOT remove this file.
diff --git a/compiler/loco/src/IR/TensorShape.cpp b/compiler/loco/src/IR/TensorShape.cpp
new file mode 100644
index 000000000..ad30dcbc0
--- /dev/null
+++ b/compiler/loco/src/IR/TensorShape.cpp
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/TensorShape.h"
+
+#include <cassert>
+
+namespace loco
+{
+
+uint32_t element_count(const loco::TensorShape *tensor_shape)
+{
+ uint32_t res = 1;
+
+ for (uint32_t axis = 0; axis < tensor_shape->rank(); ++axis)
+ {
+ // Let's use "assert" here as "caller" is responsible for this check.
+ // Please refer to the header for details.
+ assert(tensor_shape->dim(axis).known());
+ res *= tensor_shape->dim(axis).value();
+ }
+
+ return res;
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/TensorShape.test.cpp b/compiler/loco/src/IR/TensorShape.test.cpp
new file mode 100644
index 000000000..ce03ccbd4
--- /dev/null
+++ b/compiler/loco/src/IR/TensorShape.test.cpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/TensorShape.h"
+
+#include <gtest/gtest.h>
+
+TEST(TensorShapeTest, default_constructor)
+{
+ loco::TensorShape tensor_shape;
+
+ ASSERT_EQ(tensor_shape.rank(), 0);
+}
+
+TEST(TensorShapeTest, initializer_list_constructor)
+{
+ loco::TensorShape tensor_shape{3, 5};
+
+ ASSERT_EQ(tensor_shape.rank(), 2);
+
+ ASSERT_TRUE(tensor_shape.dim(0).known());
+ ASSERT_TRUE(tensor_shape.dim(1).known());
+
+ ASSERT_EQ(tensor_shape.dim(0).value(), 3);
+ ASSERT_EQ(tensor_shape.dim(1).value(), 5);
+}
+
+TEST(TensorShapeTest, rank)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+
+ ASSERT_EQ(tensor_shape.rank(), 2);
+ ASSERT_FALSE(tensor_shape.dim(0).known());
+ ASSERT_FALSE(tensor_shape.dim(1).known());
+}
+
+TEST(TensorShapeTest, dim)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+
+ tensor_shape.dim(0) = 3;
+
+ ASSERT_TRUE(tensor_shape.dim(0).known());
+ ASSERT_FALSE(tensor_shape.dim(1).known());
+
+ ASSERT_EQ(tensor_shape.dim(0), 3);
+}
+
+TEST(TensorShapeTest, rank_update)
+{
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+
+ tensor_shape.dim(1) = 3;
+
+ tensor_shape.rank(4);
+
+ ASSERT_FALSE(tensor_shape.dim(0).known());
+ ASSERT_TRUE(tensor_shape.dim(1).known());
+ ASSERT_FALSE(tensor_shape.dim(2).known());
+ ASSERT_FALSE(tensor_shape.dim(3).known());
+
+ ASSERT_EQ(tensor_shape.dim(1), 3);
+}
+
+TEST(TensorShapeTest, copy)
+{
+ loco::TensorShape src;
+
+ src.rank(2);
+ src.dim(1) = 3;
+
+ loco::TensorShape dst;
+
+ dst = src;
+
+ ASSERT_EQ(dst.rank(), 2);
+
+ ASSERT_FALSE(dst.dim(0).known());
+ ASSERT_TRUE(dst.dim(1).known());
+
+ ASSERT_EQ(dst.dim(1), 3);
+}
+
+TEST(TensorShapeTest, element_count)
+{
+ // Check Rank-0 case
+ loco::TensorShape src;
+
+ ASSERT_EQ(loco::element_count(&src), 1);
+}
diff --git a/compiler/loco/src/IR/Use.cpp b/compiler/loco/src/IR/Use.cpp
new file mode 100644
index 000000000..fed562c65
--- /dev/null
+++ b/compiler/loco/src/IR/Use.cpp
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Use.h"
+#include "loco/IR/Node.h"
+
+#include <cassert>
+
+namespace loco
+{
+
+void Use::node(Node *node)
+{
+ if (_node != nullptr)
+ {
+ assert(_node->_uses.find(this) != _node->_uses.end());
+ _node->_uses.erase(this);
+ _node = nullptr;
+ }
+
+ assert(_node == nullptr);
+
+ if (node != nullptr)
+ {
+ _node = node;
+ _node->_uses.insert(this);
+ }
+
+ assert(_node == node);
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Use.test.cpp b/compiler/loco/src/IR/Use.test.cpp
new file mode 100644
index 000000000..4a2f1cc25
--- /dev/null
+++ b/compiler/loco/src/IR/Use.test.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Use.h"
+
+#include "MockupNode.h"
+
+#include <gtest/gtest.h>
+
+TEST(UseTest, constructor)
+{
+ MockupNode user;
+ loco::Use use{&user};
+
+ ASSERT_EQ(use.user(), &user);
+ ASSERT_EQ(use.node(), nullptr);
+}
+
+TEST(UseTest, link_node)
+{
+ MockupNode def;
+ MockupNode user;
+ loco::Use use{&user};
+
+ use.node(&def);
+
+ ASSERT_EQ(use.user(), &user);
+ ASSERT_EQ(use.node(), &def);
+}
diff --git a/compiler/loco/src/IR/Verifier.cpp b/compiler/loco/src/IR/Verifier.cpp
new file mode 100644
index 000000000..42735a327
--- /dev/null
+++ b/compiler/loco/src/IR/Verifier.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Verifier.h"
+
+#include <set>
+#include <cassert>
+
+namespace
+{
+
+using namespace loco;
+
+struct GraphVerifier final
+{
+public:
+ GraphVerifier(loco::Graph *graph) : _graph{graph}
+ {
+ // graph SHOULD NOT BE null
+ assert(_graph != nullptr);
+ }
+
+public:
+ // ErrorListener SHOULD outlive GraphVerifier
+ GraphVerifier &enroll(ErrorListener *l)
+ {
+ if (l != nullptr)
+ {
+ _listeners.insert(l);
+ }
+ return (*this);
+ }
+
+ GraphVerifier &enroll(std::unique_ptr<ErrorListener> &&l)
+ {
+ if (l != nullptr)
+ {
+ _listeners.insert(l.get());
+ // Take the ownership of a given listener
+ _owned_listeners.insert(std::move(l));
+ }
+ return (*this);
+ }
+
+public:
+ void run(void) const
+ {
+ for (auto node : loco::all_nodes(_graph))
+ {
+ // Verify nodes
+ for (uint32_t n = 0; n < node->arity(); ++n)
+ {
+ if (node->arg(n) == nullptr)
+ {
+ notify(ErrorDetail<ErrorCategory::MissingArgument>{node, n});
+ }
+ }
+ }
+ }
+
+private:
+ template <typename Error> void notify(const Error &error) const
+ {
+ for (const auto &listener : _listeners)
+ {
+ listener->notify(error);
+ }
+ }
+
+private:
+ loco::Graph *_graph = nullptr;
+
+ // All active error listeners
+ std::set<ErrorListener *> _listeners;
+
+ // Owned error listeners
+ std::set<std::unique_ptr<ErrorListener>> _owned_listeners;
+};
+
+inline GraphVerifier graph_verifier(loco::Graph *graph) { return GraphVerifier{graph}; }
+
+} // namespace
+
+namespace loco
+{
+
+bool valid(Graph *g, std::unique_ptr<ErrorListener> &&l)
+{
+ class ErrorCounter final : public ErrorListener
+ {
+ public:
+ uint32_t count(void) const { return _count; }
+
+ public:
+ void notify(const ErrorDetail<ErrorCategory::MissingArgument> &) { _count += 1; }
+
+ private:
+ uint32_t _count = 0;
+ };
+
+ ErrorCounter counter;
+ graph_verifier(g).enroll(&counter).enroll(std::move(l)).run();
+ return counter.count() == 0;
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/IR/Verifier.test.cpp b/compiler/loco/src/IR/Verifier.test.cpp
new file mode 100644
index 000000000..247a59390
--- /dev/null
+++ b/compiler/loco/src/IR/Verifier.test.cpp
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Verifier.h"
+
+#include <gtest/gtest.h>
+
+#include <stdex/Memory.h>
+#include <vector>
+
+using stdex::make_unique;
+
+TEST(VerifierTest, valid_minimal)
+{
+ auto g = loco::make_graph();
+ auto push = g->nodes()->create<loco::Push>();
+
+ ASSERT_FALSE(loco::valid(g.get()));
+}
+
+TEST(VerifierTest, valid_error_reporter)
+{
+ using namespace loco;
+
+ auto g = loco::make_graph();
+ auto push = g->nodes()->create<loco::Push>();
+
+ class Collector final : public loco::ErrorListener
+ {
+ public:
+ Collector(std::vector<ErrorDetail<ErrorCategory::MissingArgument>> *out) : _out{out}
+ {
+ // DO NOTHING
+ }
+
+ public:
+ void notify(const ErrorDetail<ErrorCategory::MissingArgument> &d) override
+ {
+ _out->emplace_back(d);
+ }
+
+ private:
+ std::vector<ErrorDetail<ErrorCategory::MissingArgument>> *_out;
+ };
+
+ std::vector<ErrorDetail<ErrorCategory::MissingArgument>> errors;
+ ASSERT_FALSE(loco::valid(g.get(), make_unique<Collector>(&errors)));
+ ASSERT_EQ(errors.size(), 1);
+ ASSERT_EQ(errors.at(0).node(), push);
+ ASSERT_EQ(errors.at(0).index(), 0);
+}
diff --git a/compiler/loco/src/IR/Window.test.cpp b/compiler/loco/src/IR/Window.test.cpp
new file mode 100644
index 000000000..c112e0f96
--- /dev/null
+++ b/compiler/loco/src/IR/Window.test.cpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/IR/Window.h"
+
+#include <gtest/gtest.h>
+
+TEST(WindowTest, default_constructor_2D)
+{
+  loco::Window<2> window;
+
+  // A default-constructed 2D window is 1x1
+  ASSERT_EQ(window.vertical(), 1);
+  ASSERT_EQ(window.horizontal(), 1);
+}
+
+TEST(WindowTest, setter_and_getter_2D)
+{
+  loco::Window<2> window;
+
+  // Updating "vertical" must not disturb "horizontal" (and vice versa)
+  window.vertical(2);
+
+  ASSERT_EQ(window.vertical(), 2);
+  ASSERT_EQ(window.horizontal(), 1);
+
+  window.horizontal(3);
+
+  ASSERT_EQ(window.vertical(), 2);
+  ASSERT_EQ(window.horizontal(), 3);
+}
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
new file mode 100644
index 000000000..d30a8279a
--- /dev/null
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -0,0 +1,774 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+#include <cassert>
+
+namespace
+{
+
+// 2D spatial (height/width) sub-shape used while inferring feature shapes
+struct PlaneShape
+{
+  loco::Dimension height;
+  loco::Dimension width;
+};
+
+// Extract the spatial (height/width) part of a feature shape
+PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape)
+{
+  return PlaneShape{feature_shape.height(), feature_shape.width()};
+}
+
+// Helper that copies a PlaneShape's height/width back into a FeatureShape.
+// Count/depth of the target feature shape are left untouched.
+class FeatureShapeUpdater final
+{
+public:
+  FeatureShapeUpdater(loco::FeatureShape *ptr) : _feature_shape_ptr{ptr}
+  {
+    // DO NOTHING
+  }
+
+public:
+  void with(const PlaneShape &plane_shape) const
+  {
+    _feature_shape_ptr->height() = plane_shape.height;
+    _feature_shape_ptr->width() = plane_shape.width;
+  }
+
+private:
+  // Non-owning; the pointee must outlive this updater
+  loco::FeatureShape *_feature_shape_ptr;
+};
+
+/**
+ * HOW TO USE
+ *
+ * loco::FeatureShape feature_shape = ...;
+ *
+ * update(feature_shape).with(...)
+ */
+// Returns a helper that overwrites height/width of the given feature shape
+FeatureShapeUpdater update(loco::FeatureShape &feature_shape)
+{
+  return FeatureShapeUpdater{&feature_shape};
+}
+
+// Derive the 2D spatial window (vertical/horizontal) from a filter shape
+loco::Window<2> window_of(const loco::FilterShape &filter_shape)
+{
+  loco::Window<2> win;
+
+  win.vertical(filter_shape.height().value());
+  win.horizontal(filter_shape.width().value());
+
+  return win;
+}
+
+// Derive the 2D spatial window (vertical/horizontal) from a depthwise filter shape
+loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_shape)
+{
+  loco::Window<2> win;
+
+  win.vertical(depthwise_filter_shape.height().value());
+  win.horizontal(depthwise_filter_shape.width().value());
+
+  return win;
+}
+
+// Direction of plane-shape propagation:
+// - Forward computes the output plane from the input plane (Conv2D, pools, ...)
+// - Backward recovers the input-side plane extent (used by TransposedConv2D)
+enum class Direction
+{
+  Forward,
+  Backward,
+};
+
+template <Direction> class PlaneInference;
+
+// Computes the output plane shape of a sliding-window operation
+// (convolution/pooling) from the input plane shape.
+//
+// pad/window/stride MUST be configured before operator() is invoked.
+template <> class PlaneInference<Direction::Forward> final
+{
+public:
+  PlaneShape operator()(const PlaneShape &in) const
+  {
+    assert(_pad != nullptr);
+    assert(_window != nullptr);
+    assert(_stride != nullptr);
+
+    uint32_t const raw_input_height = in.height.value();
+    uint32_t const raw_input_width = in.width.value();
+
+    uint32_t const raw_window_height = _window->vertical();
+    uint32_t const raw_window_width = _window->horizontal();
+
+    uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+    uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+    uint32_t const effective_input_height = raw_input_height + vertical_padding;
+    uint32_t const effective_input_width = raw_input_width + horizontal_padding;
+
+    // NOTE To support "dilation" later
+    uint32_t const effective_window_height = raw_window_height;
+    uint32_t const effective_window_width = raw_window_width;
+
+    uint32_t const vertical_stride = _stride->vertical();
+    uint32_t const horizontal_stride = _stride->horizontal();
+
+    // These asserts require the window to fit the (padded) input exactly;
+    // inputs that would need flooring/ceiling are currently rejected.
+    assert((effective_input_height - effective_window_height) % vertical_stride == 0);
+    assert((effective_input_width - effective_window_width) % horizontal_stride == 0);
+
+    PlaneShape res;
+
+    // Standard sliding-window output size: (I - W) / S + 1 (per axis)
+    res.height = (effective_input_height - effective_window_height) / vertical_stride + 1;
+    res.width = (effective_input_width - effective_window_width) / horizontal_stride + 1;
+
+    return res;
+  }
+
+public:
+  // All setters take non-owning pointers; the pointees must outlive this object
+  void pad(const loco::Padding2D *value) { _pad = value; }
+  void window(const loco::Window<2> *value) { _window = value; }
+  void stride(const loco::Stride<2> *value) { _stride = value; }
+
+private:
+  const loco::Padding2D *_pad = nullptr;
+  const loco::Window<2> *_window = nullptr;
+  const loco::Stride<2> *_stride = nullptr;
+};
+
+// Inverse of the Forward inference: given the plane shape on the "small"
+// side of a strided window op, recover the plane shape on the "large" side
+// (used to compute the output of TransposedConv2D).
+//
+// pad/window/stride MUST be configured before operator() is invoked.
+template <> class PlaneInference<Direction::Backward> final
+{
+public:
+  PlaneShape operator()(const PlaneShape &in) const
+  {
+    assert(_pad != nullptr);
+    assert(_window != nullptr);
+    assert(_stride != nullptr);
+
+    uint32_t const input_height = in.height.value();
+    uint32_t const input_width = in.width.value();
+
+    uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+    uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+    uint32_t const raw_window_height = _window->vertical();
+    uint32_t const raw_window_width = _window->horizontal();
+
+    // TODO Support "dilation"
+    uint32_t const effective_window_height = raw_window_height;
+    uint32_t const effective_window_width = raw_window_width;
+
+    uint32_t const vertical_stride = _stride->vertical();
+    uint32_t const horizontal_stride = _stride->horizontal();
+
+    PlaneShape res;
+
+    // Inverse of forward formula O = (I + P - W) / S + 1, solved for I:
+    // I = S * (O - 1) + W - P (per axis)
+    res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding;
+    res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding;
+
+    return res;
+  }
+
+public:
+  // All setters take non-owning pointers; the pointees must outlive this object
+  void pad(const loco::Padding2D *value) { _pad = value; }
+  void window(const loco::Window<2> *value) { _window = value; }
+  void stride(const loco::Stride<2> *value) { _stride = value; }
+
+private:
+  const loco::Padding2D *_pad = nullptr;
+  const loco::Window<2> *_window = nullptr;
+  const loco::Stride<2> *_stride = nullptr;
+};
+
+/**
+ * There are two possible maintenance policies.
+ * - Introduce a new canonical node first, and then extend this algorithm later
+ * - Introduce a new canonical node and extend this algorithm at the same time
+ *
+ * The current implementation assumes the former one (for historical reason).
+ *
+ * TODO Evaluate the impact of the latter one
+ *
+ * NOTE "Forward" means that this algorithm computes the output shape from input shapes
+ */
+class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor<loco::NodeShape>
+{
+public:
+  ForwardShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx}
+  {
+    // DO NOTHING
+  }
+
+private:
+  // Non-owning; all shape queries go through this context so that both the
+  // annotation-based (V1) and caller-provided (V2) contexts work.
+  const loco::ShapeInferenceRule::Context *_ctx;
+
+private:
+  bool shape_known(const loco::Node *node) const { return _ctx->known(node); }
+  loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); }
+
+private:
+  loco::NodeShape eltwise_binary_node_shape(const loco::Node *node)
+  {
+    // This helper works only for binary node.
+    assert(node->arity() == 2);
+
+    auto lhs_shape = node_shape(node->arg(0));
+    auto rhs_shape = node_shape(node->arg(1));
+
+    // ASSERT: lhs_shape == rhs_shape
+
+    return lhs_shape;
+  }
+
+public:
+  // CASE: AvgPool2D
+  loco::NodeShape visit(const loco::AvgPool2D *node) final
+  {
+    PlaneInference<Direction::Forward> infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(node->window());
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+    auto output_feature_shape = input_feature_shape; // AvgPool2D does not change count/depth
+
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
+  // CASE: BiasDecode
+  loco::NodeShape visit(const loco::BiasDecode *node) final
+  {
+    // The input of BiasDecode SHOULD BE a bias!
+    assert(node_shape(node->input()).domain() == loco::Domain::Bias);
+    auto input_bias_shape = node_shape(node->input()).as<loco::BiasShape>();
+
+    loco::TensorShape output_tensor_shape;
+
+    output_tensor_shape.rank(1);
+    output_tensor_shape.dim(0) = input_bias_shape.length();
+
+    return loco::NodeShape{output_tensor_shape};
+  }
+
+  // CASE: BiasEncode
+  loco::NodeShape visit(const loco::BiasEncode *node) final
+  {
+    // The input of BiasEncode SHOULD BE a tensor!
+    assert(node_shape(node->input()).domain() == loco::Domain::Tensor);
+    auto input_tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+
+    loco::BiasShape output_bias_shape;
+
+    output_bias_shape.length() = input_tensor_shape.dim(0);
+
+    return loco::NodeShape{output_bias_shape};
+  }
+
+  // CASE: ConstGen
+  loco::NodeShape visit(const loco::ConstGen *node) final
+  {
+    loco::TensorShape tensor_shape;
+
+    tensor_shape.rank(node->rank());
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      tensor_shape.dim(axis) = node->dim(axis);
+    }
+
+    return loco::NodeShape{tensor_shape};
+  }
+
+  // CASE: Conv2D
+  loco::NodeShape visit(const loco::Conv2D *node) final
+  {
+    auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>();
+    auto filter_window = window_of(filter_shape);
+
+    PlaneInference<Direction::Forward> infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(&filter_window);
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+    loco::FeatureShape output_feature_shape;
+
+    // "COUNT" does not change
+    output_feature_shape.count() = input_feature_shape.count();
+    // "DEPTH" depends on # of filters
+    output_feature_shape.depth() = filter_shape.count();
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
+  // CASE: DepthwiseConv2D
+  loco::NodeShape visit(const loco::DepthwiseConv2D *node) final
+  {
+    auto depthwise_filter_shape = node_shape(node->ker()).as<loco::DepthwiseFilterShape>();
+    auto depthwise_filter_window = window_of(depthwise_filter_shape);
+
+    PlaneInference<Direction::Forward> infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(&depthwise_filter_window);
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+    loco::FeatureShape output_feature_shape;
+
+    // "COUNT" does not change
+    output_feature_shape.count() = input_feature_shape.count();
+    // "DEPTH" depends on [in_channels * channel_multiplier] of filters
+    output_feature_shape.depth() = loco::Dimension(depthwise_filter_shape.depth().value() *
+                                                   depthwise_filter_shape.multiplier().value());
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
+  // CASE: DepthwiseFilterEncode
+  loco::NodeShape visit(const loco::DepthwiseFilterEncode *node) final
+  {
+    auto input_tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+    return loco::NodeShape{node->encoder()->shape(input_tensor_shape)};
+  }
+
+  // CASE: DepthwiseFilterDecode
+  loco::NodeShape visit(const loco::DepthwiseFilterDecode *node) final
+  {
+    auto input_dw_filter_shape = node_shape(node->input()).as<loco::DepthwiseFilterShape>();
+    return loco::NodeShape{node->decoder()->shape(input_dw_filter_shape)};
+  }
+
+  // CASE: EltwiseAdd
+  loco::NodeShape visit(const loco::EltwiseAdd *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseDiv
+  loco::NodeShape visit(const loco::EltwiseDiv *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseMax
+  loco::NodeShape visit(const loco::EltwiseMax *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseMul
+  loco::NodeShape visit(const loco::EltwiseMul *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: EltwiseSqrt
+  loco::NodeShape visit(const loco::EltwiseSqrt *node) final { return node_shape(node->input()); }
+
+  // CASE: EltwiseSub
+  loco::NodeShape visit(const loco::EltwiseSub *node) final
+  {
+    return eltwise_binary_node_shape(node);
+  }
+
+  // CASE: Forward
+  loco::NodeShape visit(const loco::Forward *node) final { return node_shape(node->input()); }
+
+  // CASE: FeatureBiasAdd
+  loco::NodeShape visit(const loco::FeatureBiasAdd *node) final
+  {
+    assert(node_shape(node->value()).domain() == loco::Domain::Feature);
+    assert(node_shape(node->bias()).domain() == loco::Domain::Bias);
+
+    // Q. What to do when there is a mismatch between value's depth and bias's length?
+
+    return node_shape(node->value());
+  }
+
+  // CASE: FeatureDecode
+  loco::NodeShape visit(const loco::FeatureDecode *node) final
+  {
+    auto input_node_shape = node_shape(node->input());
+    return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::FeatureShape>())};
+  }
+
+  // CASE: FeatureEncode
+  loco::NodeShape visit(const loco::FeatureEncode *node) final
+  {
+    auto input_node_shape = node_shape(node->input());
+    return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
+  }
+
+  // CASE: FilterDecode
+  loco::NodeShape visit(const loco::FilterDecode *node) final
+  {
+    auto input_filter_shape = node_shape(node->input()).as<loco::FilterShape>();
+    return loco::NodeShape{node->decoder()->shape(input_filter_shape)};
+  }
+
+  // CASE: FilterEncode
+  loco::NodeShape visit(const loco::FilterEncode *node) final
+  {
+    auto input_tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+    return loco::NodeShape{node->encoder()->shape(input_tensor_shape)};
+  }
+
+  // CASE: FixedReshape
+  loco::NodeShape visit(const loco::FixedReshape *node) final
+  {
+    loco::TensorShape tensor_shape;
+
+    tensor_shape.rank(node->rank());
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      tensor_shape.dim(axis) = node->dim(axis);
+    }
+
+    return loco::NodeShape{tensor_shape};
+  }
+
+  // CASE: MatMul
+  loco::NodeShape visit(const loco::MatMul *node) final
+  {
+    assert(shape_known(node->lhs()));
+    assert(shape_known(node->rhs()));
+    auto const lhs_shape = node_shape(node->lhs()).as<loco::MatrixShape>();
+    auto const rhs_shape = node_shape(node->rhs()).as<loco::MatrixShape>();
+
+    loco::MatrixShape out_shape;
+
+    // Checking shape capability for multiplication
+    assert(lhs_shape.width() == rhs_shape.height());
+
+    out_shape.height() = lhs_shape.height();
+    out_shape.width() = rhs_shape.width();
+
+    // Wrap explicitly (consistent with every other "visit" implementation)
+    return loco::NodeShape{out_shape};
+  }
+
+  // CASE: MatrixDecode
+  loco::NodeShape visit(const loco::MatrixDecode *node) final
+  {
+    auto input_node_shape = node_shape(node->input());
+    return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::MatrixShape>())};
+  }
+
+  // CASE: MatrixEncode
+  loco::NodeShape visit(const loco::MatrixEncode *node) final
+  {
+    auto input_node_shape = node_shape(node->input());
+    return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())};
+  }
+
+  // CASE: MaxPool2D
+  loco::NodeShape visit(const loco::MaxPool2D *node) final
+  {
+    PlaneInference<Direction::Forward> infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(node->window());
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+    auto output_feature_shape = input_feature_shape; // MaxPool2D does not change count/depth
+
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
+  // CASE: Push
+  loco::NodeShape visit(const loco::Push *node) final
+  {
+    assert(shape_known(node->from()));
+    return node_shape(node->from());
+  }
+
+  // CASE: Pull
+  loco::NodeShape visit(const loco::Pull *node) final
+  {
+    // Build a tensor shape from "Pull" node
+    loco::TensorShape tensor_shape;
+
+    tensor_shape.rank(node->rank());
+    for (uint32_t axis = 0; axis < node->rank(); ++axis)
+    {
+      tensor_shape.dim(axis) = node->dim(axis);
+    }
+
+    return loco::NodeShape{tensor_shape};
+  }
+
+  // CASE: ReLU
+  loco::NodeShape visit(const loco::ReLU *node) final { return node_shape(node->input()); }
+
+  // CASE: ReLU6
+  loco::NodeShape visit(const loco::ReLU6 *node) final { return node_shape(node->input()); }
+
+  // CASE: Tanh
+  loco::NodeShape visit(const loco::Tanh *node) final { return node_shape(node->input()); }
+
+  // CASE: TensorBiasAdd
+  loco::NodeShape visit(const loco::TensorBiasAdd *node) final
+  {
+    assert(node_shape(node->value()).domain() == loco::Domain::Tensor);
+    assert(node_shape(node->bias()).domain() == loco::Domain::Bias);
+
+    // Q. What to do when there is a mismatch between value's dim and bias's length?
+
+    return node_shape(node->value());
+  }
+
+  // CASE: TensorConcat
+  // NOTE "final" was missing here (the only visit without it); added for consistency
+  loco::NodeShape visit(const loco::TensorConcat *node) final
+  {
+    auto const lhs_shape = node_shape(node->lhs()).as<loco::TensorShape>();
+    auto const rhs_shape = node_shape(node->rhs()).as<loco::TensorShape>();
+
+    assert(lhs_shape.rank() == rhs_shape.rank());
+    uint32_t const out_rank = lhs_shape.rank();
+
+    loco::TensorShape out_shape;
+
+    out_shape.rank(out_rank);
+
+    for (uint32_t axis = 0; axis < out_rank; ++axis)
+    {
+      if (axis == node->axis())
+      {
+        out_shape.dim(axis) = lhs_shape.dim(axis).value() + rhs_shape.dim(axis).value();
+      }
+      else
+      {
+        assert(lhs_shape.dim(axis) == rhs_shape.dim(axis));
+        out_shape.dim(axis) = lhs_shape.dim(axis);
+      }
+    }
+
+    return loco::NodeShape{out_shape};
+  }
+
+  // CASE: TensorBroadcast
+  loco::NodeShape visit(const loco::TensorBroadcast *node) final
+  {
+    auto tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+    auto const tensor_rank = tensor_shape.rank();
+
+    for (uint32_t axis = 0; axis < tensor_rank; ++axis)
+    {
+      if (node->mapping()->defined(axis))
+      {
+        tensor_shape.dim(axis) = node->mapping()->dim(axis);
+      }
+    }
+
+    return loco::NodeShape{tensor_shape};
+  }
+
+  // CASE: TensorReduce
+  loco::NodeShape visit(const loco::TensorReduce *node) final
+  {
+    auto tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+    auto const tensor_rank = tensor_shape.rank();
+
+    // Reduced axes are kept with extent 1 (keep-dims behavior)
+    for (uint32_t d = 0; d < tensor_rank; ++d)
+      if (node->axes()->defined(d))
+        tensor_shape.dim(d) = 1;
+
+    return loco::NodeShape{tensor_shape};
+  }
+
+  // CASE: TensorSoftmax
+  loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); }
+
+  // CASE: TensorTranspose
+  loco::NodeShape visit(const loco::TensorTranspose *node) final
+  {
+    loco::TensorShape output_shape;
+
+    auto input_shape = node_shape(node->input()).as<loco::TensorShape>();
+    assert(input_shape.rank() == node->perm()->size());
+
+    output_shape.rank(input_shape.rank());
+
+    for (uint32_t output_axis = 0; output_axis < output_shape.rank(); output_axis++)
+    {
+      auto new_dim = input_shape.dim(node->perm()->axis(output_axis));
+      output_shape.dim(output_axis) = new_dim;
+    }
+
+    return loco::NodeShape(output_shape);
+  }
+
+  // CASE: TransposedConv2D
+  loco::NodeShape visit(const loco::TransposedConv2D *node) final
+  {
+    auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>();
+    auto filter_window = window_of(filter_shape);
+
+    PlaneInference<Direction::Backward> infer_plane_shape;
+
+    infer_plane_shape.pad(node->pad());
+    infer_plane_shape.window(&filter_window);
+    infer_plane_shape.stride(node->stride());
+
+    auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>();
+    auto input_plane_shape = make_plane_shape(input_feature_shape);
+    auto output_plane_shape = infer_plane_shape(input_plane_shape);
+
+    loco::FeatureShape output_feature_shape;
+
+    // "COUNT" does not change
+    output_feature_shape.count() = input_feature_shape.count();
+    // Output "DEPTH" depends on count of filters
+    output_feature_shape.depth() = filter_shape.count();
+    // Update the height/width of output_feature_shape with that of output_plane_shape
+    update(output_feature_shape).with(output_plane_shape);
+
+    return loco::NodeShape{output_feature_shape};
+  }
+
+  // CASE: TensorConstantPad
+  loco::NodeShape visit(const loco::TensorConstantPad *node) final
+  {
+    // Use node_shape (the Context) rather than loco::shape_get directly, so that
+    // this visit also works with a caller-provided (V2) Context.
+    auto const tensor_shape = node_shape(node->input()).as<loco::TensorShape>();
+    auto padding = node->padding();
+
+    loco::TensorShape out_shape;
+    out_shape.rank(tensor_shape.rank());
+    for (uint32_t axis = 0; axis < out_shape.rank(); ++axis)
+    {
+      out_shape.dim(axis) =
+        tensor_shape.dim(axis).value() + padding->front(axis) + padding->back(axis);
+    }
+
+    return loco::NodeShape{out_shape};
+  }
+};
+
+} // namespace
+
+namespace
+{
+namespace compat
+{
+
+// Adapter that implements the V2 Context API on top of the legacy (V1)
+// annotation-based shape queries.
+struct Context final : public loco::ShapeInferenceRule::Context
+{
+  bool known(const loco::Node *node) const final { return loco::shape_known(node); }
+  loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); }
+};
+
+// Sink that records the outcome of a single V2-style inference so the
+// legacy (V1) "infer" entry point can return a bool + shape.
+class Sink final : public loco::ShapeInferenceRule::Sink
+{
+public:
+  enum Status
+  {
+    Unknown, // Neither okay() nor fail() has been invoked yet
+    Okay,
+    Fail,
+  };
+
+public:
+  const Status &status(void) const { return _status; }
+  // Only meaningful when status() == Okay
+  const loco::NodeShape &shape(void) const { return _shape; }
+
+public:
+  void okay(const loco::NodeShape &shape) final
+  {
+    _status = Okay;
+    _shape = shape;
+  }
+
+  void fail(void) final
+  {
+    // Notify failure
+    _status = Fail;
+  }
+
+private:
+  Status _status = Unknown;
+  loco::NodeShape _shape;
+};
+
+} // namespace compat
+} // namespace
+
+namespace loco
+{
+
+bool CanonicalShapeInferenceRule::support(const API &api) const
+{
+  // Both the legacy (V1) and context-based (V2) APIs are implemented
+  switch (api)
+  {
+    case API::V1:
+    case API::V2:
+      return true;
+    default:
+      return false;
+  }
+}
+
+bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const
+{
+  // This rule handles nodes of the canonical dialect only
+  return d == CanonicalDialect::get();
+}
+
+// Legacy (V1) entry point: delegates to the V2 overload through the
+// compat Context/Sink adapters and unpacks the recorded result.
+bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
+{
+  ::compat::Context ctx;
+  ::compat::Sink sink;
+
+  infer(&ctx, node, &sink);
+
+  // The V2 overload must have reported either success or failure
+  assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail);
+
+  if (sink.status() == ::compat::Sink::Fail)
+  {
+    return false;
+  }
+
+  // "shape" is updated only on success
+  shape = sink.shape();
+  return true;
+}
+
+// V2 entry point: runs the forward shape-inference visitor over a single
+// canonical node and reports the result through "sink".
+void CanonicalShapeInferenceRule::infer(const Context *ctx, const Node *node, Sink *sink) const
+{
+  assert(node->dialect() == loco::CanonicalDialect::get());
+
+  // Cast once and reuse; the original code repeated the dynamic_cast, leaving
+  // the second (dereferenced) cast unchecked in NDEBUG builds.
+  auto canonical_node = dynamic_cast<const loco::CanonicalNode *>(node);
+  assert(canonical_node != nullptr);
+
+  ForwardShapeInferenceAlgorithm alg{ctx};
+  sink->okay(canonical_node->accept(&alg));
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
new file mode 100644
index 000000000..5cc8c3808
--- /dev/null
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp
@@ -0,0 +1,400 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include "GraphTestcase.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+TEST(CanonicalShapeInferenceRuleTest, minimal)
+{
+  // Create a simple identity network, which takes Tensor<1x2x3x4> as input.
+  GraphTestcase<GraphCode::Identity> testcase{1, 2, 3, 4};
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! The input shape must propagate unchanged to the Push node.
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, const_gen)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::ConstGen> testcase;
+
+  testcase.const_node->dtype(loco::DataType::FLOAT32);
+  testcase.const_node->shape({1, 2});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! ConstGen's declared shape (1x2) must reach the Push node.
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, relu)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::Relu> testcase;
+
+  testcase.pull_node->shape({1, 2, 3, 4});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! ReLU is shape-preserving, so the Pull shape must pass through.
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, feature_codec)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::FeatureCodec> testcase;
+
+  testcase.pull_node->shape({1, 2, 3, 4});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! Encode maps Tensor -> Feature, and Decode maps it back,
+  // so the decoded tensor shape must equal the original Pull shape.
+  ASSERT_TRUE(loco::shape_known(testcase.encode_node));
+  ASSERT_EQ(loco::shape_get(testcase.encode_node).domain(), loco::Domain::Feature);
+
+  ASSERT_TRUE(loco::shape_known(testcase.decode_node));
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().rank(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(0), 1);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(1), 2);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(2), 3);
+  ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(3), 4);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, avgpool2d)
+{
+  using namespace loco;
+
+  // Create a sample network
+  GraphTestcase<GraphCode::AvgPool2D> testcase;
+
+  auto perm = make_NHWC_perm<Domain::Feature>();
+
+  // NHWC input: N=1, H=8, W=4, C=3
+  testcase.pull_node->shape({1, 8, 4, 3});
+
+  testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+  // 2x2 window with 2x2 stride halves the spatial extent
+  testcase.avgpool2d_node->window()->vertical(2);
+  testcase.avgpool2d_node->window()->horizontal(2);
+
+  testcase.avgpool2d_node->stride()->vertical(2);
+  testcase.avgpool2d_node->stride()->horizontal(2);
+
+  testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! Expected: count/depth unchanged, H 8->4, W 4->2
+  //
+  // NOTE AvgPool2D testcase assumes NHWC layout
+  ASSERT_TRUE(loco::shape_known(testcase.avgpool2d_node));
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).domain(), loco::Domain::Feature);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().count(), 1);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().depth(), 3);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().height(), 4);
+  ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, depthwiseconv2d)
+{
+  using namespace loco;
+
+  // Create a sample network
+  GraphTestcase<GraphCode::DepthwiseConv2D> testcase;
+
+  // NHWC input: N=1, H=4, W=4, C=3
+  testcase.pull_node->shape({1, 4, 4, 3});
+
+  // Depthwise filter: H=2, W=2, C=3, multiplier=2
+  testcase.const_node->dtype(loco::DataType::FLOAT32);
+  testcase.const_node->shape({2, 2, 3, 2});
+
+  testcase.depthwiseconv2d_node->stride()->vertical(1);
+  testcase.depthwiseconv2d_node->stride()->horizontal(1);
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! Expected: depth = C * multiplier = 3 * 2 = 6, H/W = (4 - 2) / 1 + 1 = 3
+  //
+  // NOTE DepthwiseConv2D testcase assumes NHWC layout
+  ASSERT_TRUE(loco::shape_known(testcase.depthwiseconv2d_node));
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).domain(), loco::Domain::Feature);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().count(), 1);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().depth(), 6);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().height(), 3);
+  ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().width(), 3);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, transposedconv2d)
+{
+  using namespace loco;
+
+  // Create a sample network
+  GraphTestcase<GraphCode::TransposedConv2D> testcase;
+
+  testcase.pull_node->shape({1, 270, 480, 24}); // NHWC
+
+  testcase.const_node->dtype(loco::DataType::FLOAT32);
+  testcase.const_node->shape({3, 3, 24, 12}); // HWCN (or HWIO)
+
+  testcase.tr_conv2d_node->stride()->vertical(2);
+  testcase.tr_conv2d_node->stride()->horizontal(2);
+
+  testcase.tr_conv2d_node->pad()->top(0);
+  testcase.tr_conv2d_node->pad()->bottom(1);
+  testcase.tr_conv2d_node->pad()->left(0);
+  testcase.tr_conv2d_node->pad()->right(1);
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify! Backward plane formula: S * (I - 1) + W - P
+  //   height: 2 * (270 - 1) + 3 - 1 = 540, width: 2 * (480 - 1) + 3 - 1 = 960
+  ASSERT_TRUE(loco::shape_known(testcase.tr_conv2d_node));
+  ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).domain(), loco::Domain::Feature);
+  ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().count(), 1);
+  ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().height(), 540);
+  ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().width(), 960);
+  ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().depth(), 12);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, maxpool2d)
+{
+ using namespace loco;
+
+ // Create a sample network
+ GraphTestcase<GraphCode::MaxPool2D> testcase;
+
+ auto perm = make_NHWC_perm<Domain::Feature>();
+
+ testcase.pull_node->shape({1, 8, 4, 3});
+
+ testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+
+ testcase.maxpool2d_node->window()->vertical(2);
+ testcase.maxpool2d_node->window()->horizontal(2);
+
+ testcase.maxpool2d_node->stride()->vertical(2);
+ testcase.maxpool2d_node->stride()->horizontal(2);
+
+ testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ //
+ // NOTE MaxPool2D testcase assumes NHWC layout
+ ASSERT_TRUE(loco::shape_known(testcase.maxpool2d_node));
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).domain(), loco::Domain::Feature);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().count(), 1);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().depth(), 3);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().height(), 4);
+ ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().width(), 2);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, tensor_concat)
+{
+ using namespace loco;
+
+ // Create a sample network
+ GraphTestcase<GraphCode::TensorConcat> testcase;
+
+ testcase.lhs_node->shape({1, 2, 3});
+ testcase.rhs_node->shape({1, 4, 3});
+ testcase.concat_node->axis(1);
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(testcase.concat_node));
+ ASSERT_EQ(loco::shape_get(testcase.concat_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().rank(), 3);
+ ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(0), 1);
+ ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(1), 6);
+ ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(2), 3);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, fixed_reshape)
+{
+  // Create a sample network
+  GraphTestcase<GraphCode::FixedReshape> testcase;
+
+  // 6 x 6 -> 4 x 9 (element count 36 is preserved by the reshape)
+  testcase.pull_node->shape({6, 6});
+  testcase.reshape_node->shape({4, 9});
+
+  // Run Inference
+  loco::CanonicalShapeInferenceRule rule;
+
+  loco::apply(&rule).to(testcase.graph());
+
+  // Verify!
+  //
+  // The reshape's output shape should propagate unchanged through Push
+  ASSERT_TRUE(loco::shape_known(testcase.push_node));
+  ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4);
+  ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 9);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::TensorBroadcast> testcase{1, 2};
+
+ testcase.broadcast_node->mapping()->dim(0) = 4;
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(testcase.push_node));
+ ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2);
+}
+
+TEST(CanonicalShapeInferenceRuleTest, tensor_transpose)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::TensorTranspose> tc;
+
+ tc.pull_node->shape({10, 20, 30, 40});
+
+ tc.transpose_node->perm()->size(4);
+ tc.transpose_node->perm()->axis(0) = 2;
+ tc.transpose_node->perm()->axis(1) = 3;
+ tc.transpose_node->perm()->axis(2) = 0;
+ tc.transpose_node->perm()->axis(3) = 1;
+
+ // Run Inference
+ loco::CanonicalShapeInferenceRule rule;
+
+ loco::apply(&rule).to(tc.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(tc.push_node));
+ ASSERT_EQ(loco::shape_get(tc.push_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().rank(), 4);
+ ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(0), 30);
+ ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(1), 40);
+ ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(2), 10);
+ ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(3), 20);
+}
+
+namespace
+{
+
+// Shape inference context backed by a plain std::map.
+//
+// A node's shape is "known" iff an entry for it exists in "_content".
+struct MockContext final : public loco::ShapeInferenceRule::Context
+{
+  bool known(const loco::Node *node) const final { return _content.find(node) != _content.end(); }
+  loco::NodeShape get(const loco::Node *node) const final { return _content.at(node); }
+
+  std::map<const loco::Node *, loco::NodeShape> _content;
+};
+
+// Shape inference sink that records the last successfully inferred shape.
+//
+// NOTE "fail" is intentionally a no-op; "shape" keeps its previous value.
+struct MockSink final : public loco::ShapeInferenceRule::Sink
+{
+  void okay(const loco::NodeShape &res) final { shape = res; }
+  void fail(void) final { return; }
+
+  loco::NodeShape shape;
+};
+
+} // namespace
+
+TEST(CanonicalShapeInferenceRuleTest, infer_v2)
+{
+ auto g = loco::make_graph();
+
+ // Create an incomplete graph
+ auto relu_1 = g->nodes()->create<loco::ReLU>();
+ auto relu_2 = g->nodes()->create<loco::ReLU>();
+
+ relu_2->input(relu_1);
+
+ // Set up Context
+ MockContext ctx;
+
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(2);
+ tensor_shape.dim(0) = 4;
+ tensor_shape.dim(1) = 5;
+
+ ctx._content[relu_1] = tensor_shape;
+
+ // Create a Sink
+ MockSink sink;
+
+ loco::CanonicalShapeInferenceRule rule;
+
+ rule.infer(&ctx, relu_2, &sink);
+
+ ASSERT_EQ(sink.shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(1), 5);
+}
diff --git a/compiler/loco/src/Service/GraphBuilder.h b/compiler/loco/src/Service/GraphBuilder.h
new file mode 100644
index 000000000..71084673c
--- /dev/null
+++ b/compiler/loco/src/Service/GraphBuilder.h
@@ -0,0 +1,547 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_BUILDER_H__
+#define __GRAPH_BUILDER_H__
+
+// loco-internal headers
+#include "loco/IR/Graph.h"
+
+// repo-internal headers
+#include <stdex/Memory.h>
+
+// C++ standard headers
+#include <stack>
+
+//
+// This file includes a stack-based loco graph builder
+//
+// HOW TO USE
+//
+// loco::Graph *g = ...
+// auto builder = make_graph_builder(g);
+//
+// builder->push<YourAwesomeLayer>(...);
+//
+
+// Stack-based builder: each "Layer" pops its input node(s) from the operand
+// stack and pushes the node it creates, so layers compose like an RPN
+// expression (push inputs first, then the consuming layer).
+class GraphBuilder final
+{
+public:
+  // Thin wrapper over std::stack<loco::Node *> used as the operand stack
+  class Stack final
+  {
+  public:
+    Stack() = default;
+
+  public:
+    // Peek at the top node without removing it
+    loco::Node *top(void) const { return _content.top(); }
+
+  public:
+    // Remove and return the top node
+    loco::Node *pop(void)
+    {
+      auto ret = top();
+      _content.pop();
+      return ret;
+    }
+
+  public:
+    void push(loco::Node *node) { _content.push(node); }
+
+  private:
+    std::stack<loco::Node *> _content;
+  };
+
+  // Mutable state shared by all layers: the graph under construction and
+  // the operand stack
+  class Context final
+  {
+  public:
+    Context(loco::Graph *graph) : _graph{graph}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::Graph *graph(void) { return _graph; }
+    Stack *stack(void) { return &_stack; }
+
+  private:
+    loco::Graph *_graph = nullptr;
+    Stack _stack;
+  };
+
+public:
+  GraphBuilder(loco::Graph *graph) : _context{graph}
+  {
+    // DO NOTHING
+  }
+
+public:
+  // "Layer" is in theory a subgraph builder.
+  //
+  // Constructs a Layer from "args" and invokes it on the current context.
+  // The return type is whatever Layer::operator()(Context *) returns
+  // (usually a std::unique_ptr<...::Return> fluent handle).
+  template <typename Layer, typename... Args>
+  auto push(Args &&... args)
+      -> decltype(static_cast<Layer *>(nullptr)->operator()(static_cast<Context *>(nullptr)))
+  {
+    Layer layer{std::forward<Args>(args)...};
+    return layer(ctx());
+  }
+
+public:
+  // Pop the most recently built node off the operand stack
+  loco::Node *pop(void) { return ctx()->stack()->pop(); }
+
+private:
+  Context *ctx(void) { return &_context; }
+
+private:
+  Context _context;
+};
+
+// Convenience factory: wrap an existing (non-owned) graph in a GraphBuilder
+static inline std::unique_ptr<GraphBuilder> make_graph_builder(loco::Graph *g)
+{
+  return stdex::make_unique<GraphBuilder>(g);
+}
+
+// "InputLayer" creates both GraphInput and Pull node at once
+struct InputLayer final
+{
+  // Fluent handle for configuring the created input (name/shape)
+  class Return
+  {
+  public:
+    Return(loco::GraphInput *input, loco::Pull *node) : _input{input}, _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::Pull *node(void) { return _node; }
+
+  public:
+    // Set the name of the underlying GraphInput
+    Return *name(const std::string &value)
+    {
+      _input->name(value);
+      return this;
+    }
+
+  public:
+    // Set the shape of the created Pull node
+    Return *shape(std::initializer_list<uint32_t> dims)
+    {
+      // TODO Uncomment this line when GraphInput is ready
+      // _input->shape(dims)
+      _node->shape(dims);
+      return this;
+    }
+
+  private:
+    loco::GraphInput *_input = nullptr;
+    loco::Pull *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // Create a GraphInput and a Pull node, link them, and leave the Pull
+    // node on the stack as this layer's output
+    auto input_index = ctx->graph()->inputs()->size();
+    auto graph_input = ctx->graph()->inputs()->create();
+
+    auto pull_node = ctx->graph()->nodes()->create<loco::Pull>();
+
+    pull_node->index(input_index);
+
+    loco::link(graph_input, pull_node);
+
+    ctx->stack()->push(pull_node);
+
+    return stdex::make_unique<Return>(graph_input, pull_node);
+  }
+};
+
+// "OutputLayer" creates both GraphOutput and Push node at once.
+struct OutputLayer final
+{
+  // Fluent handle for configuring the created output
+  class Return
+  {
+  public:
+    Return(loco::GraphOutput *output, loco::Push *node) : _output{output}, _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::Push *node(void) { return _node; }
+
+  public:
+    // Set the name of the underlying GraphOutput
+    Return *name(const std::string &value)
+    {
+      _output->name(value);
+      return this;
+    }
+
+  private:
+    loco::GraphOutput *_output = nullptr;
+    loco::Push *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // Create a GraphOutput and a Push node fed by the node currently on
+    // top of the stack, link them, and leave the Push node on the stack
+    auto output_index = ctx->graph()->outputs()->size();
+    auto graph_output = ctx->graph()->outputs()->create();
+
+    auto push_node = ctx->graph()->nodes()->create<loco::Push>();
+
+    push_node->from(ctx->stack()->pop());
+    push_node->index(output_index);
+
+    loco::link(graph_output, push_node);
+
+    ctx->stack()->push(push_node);
+
+    return stdex::make_unique<Return>(graph_output, push_node);
+  }
+};
+
+struct ReLULayer final
+{
+  // This "Return" is unnecessary for ReLU (as ReLU has no attributes), but
+  // introduced for consistency.
+  class Return
+  {
+  public:
+    Return(loco::ReLU *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::ReLU *node(void) { return _node; }
+
+  private:
+    loco::ReLU *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // ReLU consumes the node on top of the stack as its input
+    auto relu_node = ctx->graph()->nodes()->create<loco::ReLU>();
+
+    relu_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(relu_node);
+
+    return stdex::make_unique<Return>(relu_node);
+  }
+};
+
+struct ConstGenLayer final
+{
+  // Fluent handle for a freshly created ConstGen node
+  class Return
+  {
+  public:
+    Return(loco::ConstGen *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::ConstGen *node(void) { return _node; }
+
+  private:
+    loco::ConstGen *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // ConstGen has no input - just create it and leave it on the stack
+    auto node = ctx->graph()->nodes()->create<loco::ConstGen>();
+
+    ctx->stack()->push(node);
+
+    return stdex::make_unique<Return>(node);
+  }
+};
+
+#include "loco/IR/PermutingCodec.h"
+
+struct FeatureEncodeLayer final
+{
+  // Fluent handle; "perm" installs a PermutingEncoder built from the given
+  // Feature-domain permutation.
+  class Return
+  {
+  public:
+    Return(loco::FeatureEncode *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    Return *perm(const loco::Permutation<loco::Domain::Feature> &perm)
+    {
+      using namespace loco;
+      _node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm));
+      return this;
+    }
+
+  public:
+    loco::FeatureEncode *node(void) { return _node; }
+
+  private:
+    loco::FeatureEncode *_node;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // FeatureEncode consumes the node on top of the stack as its input
+    auto encode_node = ctx->graph()->nodes()->create<loco::FeatureEncode>();
+
+    encode_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(encode_node);
+
+    return stdex::make_unique<Return>(encode_node);
+  }
+};
+
+struct FeatureDecodeLayer final
+{
+  // Fluent handle; "perm" installs a PermutingDecoder built from the given
+  // Feature-domain permutation.
+  class Return
+  {
+  public:
+    Return(loco::FeatureDecode *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    Return *perm(const loco::Permutation<loco::Domain::Feature> &perm)
+    {
+      using namespace loco;
+      _node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm));
+      return this;
+    }
+
+  public:
+    loco::FeatureDecode *node(void) { return _node; }
+
+  private:
+    loco::FeatureDecode *_node;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    using namespace loco;
+
+    // FeatureDecode consumes the node on top of the stack as its input
+    auto decode_node = ctx->graph()->nodes()->create<FeatureDecode>();
+
+    decode_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(decode_node);
+
+    return stdex::make_unique<Return>(decode_node);
+  }
+};
+
+struct FilterEncodeLayer final
+{
+  // Fluent handle; "perm" installs a PermutingEncoder configured with the
+  // given Filter-domain permutation.
+  class Return
+  {
+  public:
+    Return(loco::FilterEncode *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    Return *perm(const loco::Permutation<loco::Domain::Filter> &perm)
+    {
+      // NOTE(review): unlike FeatureEncodeLayer, the encoder is
+      //   default-constructed and configured via the "perm" setter -
+      //   presumably because PermutingEncoder<Domain::Filter> lacks a
+      //   Permutation constructor. TODO unify the style once confirmed.
+      auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+      encoder->perm(perm);
+      _node->encoder(std::move(encoder));
+      return this;
+    }
+
+  public:
+    loco::FilterEncode *node(void) { return _node; }
+
+  private:
+    loco::FilterEncode *_node;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // FilterEncode consumes the node on top of the stack as its input
+    auto encode_node = ctx->graph()->nodes()->create<loco::FilterEncode>();
+
+    encode_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(encode_node);
+
+    return stdex::make_unique<Return>(encode_node);
+  }
+};
+
+struct DepthwiseFilterEncodeLayer final
+{
+  // Fluent handle; "perm" installs a PermutingEncoder built from the given
+  // DepthwiseFilter-domain permutation.
+  class Return
+  {
+  public:
+    Return(loco::DepthwiseFilterEncode *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    Return *perm(const loco::Permutation<loco::Domain::DepthwiseFilter> &perm)
+    {
+      using namespace loco;
+      _node->encoder(stdex::make_unique<PermutingEncoder<Domain::DepthwiseFilter>>(perm));
+      return this;
+    }
+
+  public:
+    loco::DepthwiseFilterEncode *node(void) { return _node; }
+
+  private:
+    loco::DepthwiseFilterEncode *_node;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // DepthwiseFilterEncode consumes the node on top of the stack
+    auto encode_node = ctx->graph()->nodes()->create<loco::DepthwiseFilterEncode>();
+
+    encode_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(encode_node);
+
+    return stdex::make_unique<Return>(encode_node);
+  }
+};
+
+struct DepthwiseConv2DLayer final
+{
+  // Fluent handle for a freshly created DepthwiseConv2D node
+  class Return
+  {
+  public:
+    Return(loco::DepthwiseConv2D *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::DepthwiseConv2D *node(void) { return _node; }
+
+  private:
+    loco::DepthwiseConv2D *_node;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    auto depthwiseconv2d_node = ctx->graph()->nodes()->create<loco::DepthwiseConv2D>();
+
+    // NOTE "ker" is popped first, so callers must push ifm BEFORE ker
+    depthwiseconv2d_node->ker(ctx->stack()->pop());
+    depthwiseconv2d_node->ifm(ctx->stack()->pop());
+
+    ctx->stack()->push(depthwiseconv2d_node);
+
+    return stdex::make_unique<Return>(depthwiseconv2d_node);
+  }
+};
+
+struct TransposedConv2DLayer final
+{
+  // Fluent handle for a freshly created TransposedConv2D node
+  class Return
+  {
+  public:
+    Return(loco::TransposedConv2D *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::TransposedConv2D *node(void) { return _node; }
+
+  private:
+    loco::TransposedConv2D *_node;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    auto tr_conv2d_node = ctx->graph()->nodes()->create<loco::TransposedConv2D>();
+
+    // NOTE "ker" is popped first, so callers must push ifm BEFORE ker
+    tr_conv2d_node->ker(ctx->stack()->pop());
+    tr_conv2d_node->ifm(ctx->stack()->pop());
+
+    ctx->stack()->push(tr_conv2d_node);
+
+    return stdex::make_unique<Return>(tr_conv2d_node);
+  }
+};
+
+struct FixedReshapeLayer final
+{
+  // Fluent handle; "shape" sets the (fixed) output shape of the reshape
+  class Return
+  {
+  public:
+    Return(loco::FixedReshape *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    Return *shape(std::initializer_list<uint32_t> dims)
+    {
+      _node->shape(dims);
+      return this;
+    }
+
+  public:
+    loco::FixedReshape *node(void) { return _node; }
+
+  private:
+    loco::FixedReshape *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // FixedReshape consumes the node on top of the stack as its input
+    auto reshape_node = ctx->graph()->nodes()->create<loco::FixedReshape>();
+
+    reshape_node->input(ctx->stack()->pop());
+
+    ctx->stack()->push(reshape_node);
+
+    return stdex::make_unique<Return>(reshape_node);
+  }
+};
+
+struct TensorBroadcastLayer final
+{
+  // Fluent handle for a freshly created TensorBroadcast node
+  class Return
+  {
+  public:
+    Return(loco::TensorBroadcast *node) : _node{node}
+    {
+      // DO NOTHING
+    }
+
+  public:
+    loco::TensorBroadcast *node(void) { return _node; }
+
+  private:
+    loco::TensorBroadcast *_node = nullptr;
+  };
+
+  std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx)
+  {
+    // TensorBroadcast consumes the node on top of the stack as its input
+    auto broadcast_node = ctx->graph()->nodes()->create<loco::TensorBroadcast>();
+
+    broadcast_node->input(ctx->stack()->pop());
+    ctx->stack()->push(broadcast_node);
+
+    return stdex::make_unique<Return>(broadcast_node);
+  }
+};
+
+#endif // __GRAPH_BUILDER_H__
diff --git a/compiler/loco/src/Service/GraphBuilder.test.cpp b/compiler/loco/src/Service/GraphBuilder.test.cpp
new file mode 100644
index 000000000..7b2ea5198
--- /dev/null
+++ b/compiler/loco/src/Service/GraphBuilder.test.cpp
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "GraphBuilder.h"
+
+#include "loco/IR/Nodes.h"
+#include "loco/IR/CanonicalDialect.h"
+#include "loco/IR/CanonicalOpcode.h"
+
+#include <gtest/gtest.h>
+
+TEST(GraphBuilderTest, Usecase_000)
+{
+  // Minimal custom "Layer": creates a ConstGen and leaves it on the stack
+  struct SampleLayer final
+  {
+    loco::Node *operator()(GraphBuilder::Context *ctx)
+    {
+      auto node = ctx->graph()->nodes()->create<loco::ConstGen>();
+      ctx->stack()->push(node);
+      return node;
+    }
+  };
+
+  auto g = loco::make_graph();
+  auto gbuilder = make_graph_builder(g.get());
+
+  gbuilder->push<SampleLayer>();
+
+  // pop() should hand back the node SampleLayer pushed
+  auto node = gbuilder->pop();
+
+  ASSERT_EQ(g->nodes()->size(), 1);
+  ASSERT_EQ(node->dialect(), loco::CanonicalDialect::get());
+  ASSERT_EQ(node->opnum(), static_cast<uint32_t>(loco::CanonicalOpcode::ConstGen));
+}
diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h
new file mode 100644
index 000000000..6743b9a14
--- /dev/null
+++ b/compiler/loco/src/Service/GraphTestcase.h
@@ -0,0 +1,541 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __GRAPH_TESTCASE_H__
+#define __GRAPH_TESTCASE_H__
+
+#include "loco/IR/Graph.h"
+#include "loco/IR/PermutingCodec.h"
+
+#include "GraphBuilder.h"
+
+#include <stdex/Memory.h>
+
+enum class GraphCode
+{
+ Identity,
+ ConstGen,
+ Relu,
+ FeatureCodec,
+ AvgPool2D,
+ DepthwiseConv2D,
+ TransposedConv2D,
+ MaxPool2D,
+ TensorBroadcast,
+ TensorConcat,
+ TensorTranspose,
+ FixedReshape,
+};
+
+namespace
+{
+
+// NOTE(review): an unnamed namespace in a header gives every including TU its
+// own copy of these helpers. That is tolerable for a test-only header, but
+// this header should not be included from production code.
+
+template <loco::Domain D> loco::Permutation<D> make_NHWC_perm(void);
+
+// Feature-domain NHWC layout: Count/Height/Width/Depth -> axes 0/1/2/3
+template <> loco::Permutation<loco::Domain::Feature> make_NHWC_perm(void)
+{
+  loco::Permutation<loco::Domain::Feature> perm;
+
+  perm[loco::FeatureAxis::Count] = 0;
+  perm[loco::FeatureAxis::Height] = 1;
+  perm[loco::FeatureAxis::Width] = 2;
+  perm[loco::FeatureAxis::Depth] = 3;
+
+  return perm;
+}
+
+template <loco::Domain D> loco::Permutation<D> make_HWCN_perm(void);
+
+// @note Also known as HWIO permutation
+template <> loco::Permutation<loco::Domain::Filter> make_HWCN_perm(void)
+{
+  loco::Permutation<loco::Domain::Filter> perm;
+
+  perm[loco::FilterAxis::Height] = 0;
+  perm[loco::FilterAxis::Width] = 1;
+  perm[loco::FilterAxis::Depth] = 2;
+  perm[loco::FilterAxis::Count] = 3;
+
+  return perm;
+}
+
+template <loco::Domain D> loco::Permutation<D> make_HWCM_perm(void);
+
+// DepthwiseFilter layout: Height/Width/Depth(Channel)/Multiplier -> 0/1/2/3
+template <> loco::Permutation<loco::Domain::DepthwiseFilter> make_HWCM_perm(void)
+{
+  loco::Permutation<loco::Domain::DepthwiseFilter> perm;
+
+  perm[loco::DepthwiseFilterAxis::Height] = 0;
+  perm[loco::DepthwiseFilterAxis::Width] = 1;
+  perm[loco::DepthwiseFilterAxis::Depth] = 2;
+  perm[loco::DepthwiseFilterAxis::Multiplier] = 3;
+
+  return perm;
+}
+
+} // namespace
+
+template <GraphCode Code> class GraphTestcase;
+
+template <> class GraphTestcase<GraphCode::Identity> final
+{
+private:
+ void init(std::initializer_list<uint32_t> dims)
+ {
+ // Create a sample network
+ _graph = loco::make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ // NOTE This default constructor guarantees backward compatbility.
+ GraphTestcase() { init({1, 4, 8, 3}); }
+ GraphTestcase(std::initializer_list<uint32_t> dims) { init(dims); }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::ConstGen> final
+{
+public:
+ GraphTestcase()
+ {
+ _graph = loco::make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ const_node = graph_builder->push<ConstGenLayer>()->node();
+
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::ConstGen *const_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::Relu> final
+{
+public:
+ GraphTestcase()
+ {
+ // Create a sample network
+ _graph = loco::make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+ relu_node = graph_builder->push<ReLULayer>()->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::ReLU *relu_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::FeatureCodec> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ Permutation<Domain::Feature> perm;
+
+ perm[FeatureAxis::Count] = 0;
+ perm[FeatureAxis::Height] = 1;
+ perm[FeatureAxis::Width] = 2;
+ perm[FeatureAxis::Depth] = 3;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+ encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(perm)->node();
+ decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(perm)->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::AvgPool2D> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Graph Input/Output
+ auto graph_input = _graph->inputs()->create();
+ auto graph_output = _graph->outputs()->create();
+
+ graph_input->name("input");
+ graph_output->name("output");
+
+ // Create and connect nodes
+ pull_node = _graph->nodes()->create<Pull>();
+ pull_node->index(0);
+
+ encode_node = _graph->nodes()->create<FeatureEncode>();
+ encode_node->input(pull_node);
+
+ avgpool2d_node = _graph->nodes()->create<AvgPool2D>();
+ avgpool2d_node->ifm(encode_node);
+
+ decode_node = _graph->nodes()->create<FeatureDecode>();
+ decode_node->input(avgpool2d_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->index(0);
+ push_node->from(decode_node);
+
+ // Create a link between input/output and corresponding nodes
+ loco::link(graph_input, pull_node);
+ loco::link(graph_output, push_node);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::AvgPool2D *avgpool2d_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::DepthwiseConv2D> final
+{
+public:
+  GraphTestcase()
+  {
+    using namespace loco;
+
+    _graph = make_graph();
+
+    auto graph_builder = make_graph_builder(_graph.get());
+
+    Permutation<Domain::Feature> perm = make_NHWC_perm<Domain::Feature>();
+    Permutation<Domain::DepthwiseFilter> filter_perm = make_HWCM_perm<Domain::DepthwiseFilter>();
+
+    // Build the ifm path first, then the kernel path, so that
+    // DepthwiseConv2DLayer pops ker first and ifm second (stack order)
+    pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+    encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(perm)->node();
+
+    const_node = graph_builder->push<ConstGenLayer>()->node();
+
+    filter_encode_node =
+        graph_builder->push<DepthwiseFilterEncodeLayer>()->perm(filter_perm)->node();
+
+    depthwiseconv2d_node = graph_builder->push<DepthwiseConv2DLayer>()->node();
+
+    decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(perm)->node();
+    push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+  }
+
+public:
+  loco::Graph *graph() { return _graph.get(); }
+
+  loco::Pull *pull_node = nullptr;
+  loco::FeatureEncode *encode_node = nullptr;
+  loco::ConstGen *const_node = nullptr;
+  loco::DepthwiseFilterEncode *filter_encode_node = nullptr;
+  loco::DepthwiseConv2D *depthwiseconv2d_node = nullptr;
+  loco::FeatureDecode *decode_node = nullptr;
+  loco::Push *push_node = nullptr;
+
+private:
+  std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::TransposedConv2D> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Prepare permutations
+ Permutation<Domain::Feature> feature_perm = make_NHWC_perm<Domain::Feature>();
+ Permutation<Domain::Filter> filter_perm = make_HWCN_perm<Domain::Filter>();
+
+ // Build graph
+ _graph = make_graph();
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+ encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(feature_perm)->node();
+ const_node = graph_builder->push<ConstGenLayer>()->node();
+ filter_encode_node = graph_builder->push<FilterEncodeLayer>()->perm(filter_perm)->node();
+ tr_conv2d_node = graph_builder->push<TransposedConv2DLayer>()->node();
+ decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(feature_perm)->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::ConstGen *const_node = nullptr;
+ loco::FilterEncode *filter_encode_node = nullptr;
+ loco::TransposedConv2D *tr_conv2d_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::MaxPool2D> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Graph Input/Output
+ auto graph_input = _graph->inputs()->create();
+ auto graph_output = _graph->outputs()->create();
+
+ graph_input->name("input");
+ graph_output->name("output");
+
+ // Create and connect nodes
+ pull_node = _graph->nodes()->create<Pull>();
+ pull_node->index(0);
+
+ encode_node = _graph->nodes()->create<FeatureEncode>();
+ encode_node->input(pull_node);
+
+ maxpool2d_node = _graph->nodes()->create<MaxPool2D>();
+ maxpool2d_node->ifm(encode_node);
+
+ decode_node = _graph->nodes()->create<FeatureDecode>();
+ decode_node->input(maxpool2d_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->index(0);
+ push_node->from(decode_node);
+
+ // Create a link between input/output and corresponding nodes
+ loco::link(graph_input, pull_node);
+ loco::link(graph_output, push_node);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FeatureEncode *encode_node = nullptr;
+ loco::MaxPool2D *maxpool2d_node = nullptr;
+ loco::FeatureDecode *decode_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::TensorConcat> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Graph Input/Output
+ auto graph_lhs = _graph->inputs()->create();
+ auto graph_rhs = _graph->inputs()->create();
+ auto graph_out = _graph->outputs()->create();
+
+ graph_lhs->name("lhs");
+ graph_rhs->name("rhs");
+ graph_out->name("output");
+
+ // Create and connect nodes
+ lhs_node = _graph->nodes()->create<Pull>();
+ lhs_node->index(0);
+
+ rhs_node = _graph->nodes()->create<Pull>();
+ rhs_node->index(1);
+
+ concat_node = _graph->nodes()->create<TensorConcat>();
+ concat_node->lhs(lhs_node);
+ concat_node->rhs(rhs_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->index(0);
+ push_node->from(concat_node);
+
+ // Create a link between input/output and corresponding nodes
+ loco::link(graph_lhs, lhs_node);
+ loco::link(graph_rhs, rhs_node);
+ loco::link(graph_out, push_node);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *lhs_node = nullptr;
+ loco::Pull *rhs_node = nullptr;
+ loco::TensorConcat *concat_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::FixedReshape> final
+{
+public:
+ GraphTestcase()
+ {
+ _graph = loco::make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->node();
+ reshape_node = graph_builder->push<FixedReshapeLayer>()->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::FixedReshape *reshape_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::TensorBroadcast> final
+{
+public:
+ GraphTestcase(std::initializer_list<uint32_t> dims)
+ {
+ _graph = loco::make_graph();
+
+ auto graph_builder = make_graph_builder(_graph.get());
+
+ pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node();
+ broadcast_node = graph_builder->push<TensorBroadcastLayer>()->node();
+ push_node = graph_builder->push<OutputLayer>()->name("output")->node();
+ }
+
+public:
+ loco::Graph *graph(void) { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::TensorBroadcast *broadcast_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+template <> class GraphTestcase<GraphCode::TensorTranspose> final
+{
+public:
+ GraphTestcase()
+ {
+ using namespace loco;
+
+ // Create a sample network
+ _graph = make_graph();
+
+ // Create Graph Input/Output
+ auto graph_input = _graph->inputs()->create();
+ auto graph_output = _graph->outputs()->create();
+
+ graph_input->name("input");
+ graph_output->name("output");
+
+ // Create and connect nodes
+ pull_node = _graph->nodes()->create<Pull>();
+ pull_node->index(0);
+
+ transpose_node = _graph->nodes()->create<TensorTranspose>();
+ transpose_node->input(pull_node);
+
+ push_node = _graph->nodes()->create<loco::Push>();
+ push_node->index(0);
+ push_node->from(transpose_node);
+
+ // Create a link between input/output and corresponding nodes
+ loco::link(graph_input, pull_node);
+ loco::link(graph_output, push_node);
+ }
+
+public:
+ loco::Graph *graph() { return _graph.get(); }
+
+ loco::Pull *pull_node = nullptr;
+ loco::TensorTranspose *transpose_node = nullptr;
+ loco::Push *push_node = nullptr;
+
+private:
+ std::unique_ptr<loco::Graph> _graph;
+};
+
+#endif // __GRAPH_TESTCASE_H__
diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp
new file mode 100644
index 000000000..2178f5d05
--- /dev/null
+++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/MultiDialectShapeInferenceRule.h"
+#include "loco/Service/ShapeInferenceRule.h"
+
+#include <loco/IR/Dialect.h>
+#include <loco/IR/Node.h>
+#include <loco/IR/NodeShape.h>
+
+#include <cassert>
+
+namespace loco
+{
+
+bool MultiDialectShapeInferenceRule::recognize(const Dialect *d) const
+{
+ const auto found = _rules.find(d);
+
+ if (found == _rules.cend())
+ return false;
+
+ auto rule = found->second;
+ auto result = rule->recognize(d);
+
+ return result;
+}
+
+bool MultiDialectShapeInferenceRule::infer(const Node *node, NodeShape &shape) const
+{
+ const auto found = _rules.find(node->dialect());
+
+ if (found == _rules.cend())
+ return false;
+
+ auto rule = found->second;
+ if (rule->infer(node, shape))
+ return true;
+
+ return false;
+}
+
+MultiDialectShapeInferenceRule &MultiDialectShapeInferenceRule::bind(const Dialect *d,
+ const ShapeInferenceRule *rule)
+{
+ assert(_rules.find(d) == _rules.end());
+ assert(rule->recognize(d));
+
+ _rules[d] = rule;
+
+ return (*this);
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp
new file mode 100644
index 000000000..ffa9ee5ca
--- /dev/null
+++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/CanonicalShapeInferenceRule.h"
+#include "loco/Service/MultiDialectShapeInferenceRule.h"
+#include "loco/Service/ShapeInference.h"
+
+#include <loco/IR/Dialect.h>
+#include <loco/IR/CanonicalDialect.h>
+
+#include <gtest/gtest.h>
+
+#include <cassert>
+#include <vector>
+
+// mockup for MultiDialectShapeInferenceRule
+// Each class is dedicated for handling shape { D1, D2 } and D1, D2 are declared as a template
+namespace
+{
+
+template <uint32_t D1, uint32_t D2> class TestDialect final : public loco::Dialect
+{
+public:
+ static Dialect *get(void)
+ {
+ static TestDialect<D1, D2> d;
+ return &d;
+ }
+};
+
+template <uint32_t D1, uint32_t D2>
+struct TestOpNode final : public loco::FixedArity<1>::Mixin<loco::Node>,
+ public loco::NodeMixin<loco::NodeTrait::TensorShape>
+{
+ void input(Node *node) { at(0)->node(node); }
+ const loco::Dialect *dialect(void) const final { return TestDialect<D1, D2>::get(); }
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(D1); /* not used */ }
+};
+
+template <uint32_t D1, uint32_t D2>
+struct TestShapeInferenceRule final : public loco::ShapeInferenceRule
+{
+public:
+ bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<D1, D2>::get()); }
+
+ bool infer(const loco::Node *node, loco::NodeShape &node_shape) const final
+ {
+ assert(recognize(node->dialect()));
+ auto test_node = dynamic_cast<const TestOpNode<D1, D2> *>(node);
+ assert(test_node != nullptr);
+
+ loco::TensorShape ts;
+ {
+ ts.rank(2);
+ ts.dim(0) = D1;
+ ts.dim(1) = D2; // making shape : { D1, D2 }
+ }
+
+ node_shape.set(ts);
+
+ return true;
+ }
+};
+
+} // namespace
+
+TEST(MultiDialectShapeInferenceRuleTest, test1)
+{
+ // Create a simple network : Pull ------- t23<2,3> ------------ t45<4,5> ---------- Push
+ // TensorShape({2, 3}) TensorShape({4, 5})
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+ auto t23_node = g->nodes()->create<TestOpNode<2, 3>>();
+ auto t45_node = g->nodes()->create<TestOpNode<4, 5>>();
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ t23_node->input(pull_node);
+ t45_node->input(t23_node);
+ push_node->from(t45_node);
+
+ auto graph_input = g->inputs()->create();
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // initially they don't have shape info
+ ASSERT_FALSE(loco::shape_known(t23_node));
+ ASSERT_FALSE(loco::shape_known(t45_node));
+
+ // Run Shape Inference
+ loco::CanonicalShapeInferenceRule canonical_rule;
+ TestShapeInferenceRule<2, 3> t23_rule;
+ TestShapeInferenceRule<4, 5> t45_rule;
+
+ loco::MultiDialectShapeInferenceRule rules;
+
+ rules.bind(loco::CanonicalDialect::get(), &canonical_rule)
+ .bind(TestDialect<2, 3>::get(), &t23_rule)
+ .bind(TestDialect<4, 5>::get(), &t45_rule);
+
+ loco::apply(&rules).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::shape_known(t23_node));
+ auto t23_shape = loco::shape_get(t23_node);
+ ASSERT_EQ(t23_shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(t23_shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(0), 2);
+ ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(1), 3);
+
+ ASSERT_TRUE(loco::shape_known(t45_node));
+ auto t45_shape = loco::shape_get(t45_node);
+ ASSERT_EQ(t45_shape.domain(), loco::Domain::Tensor);
+ ASSERT_EQ(t45_shape.as<loco::TensorShape>().rank(), 2);
+ ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(0), 4);
+ ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(1), 5);
+}
diff --git a/compiler/loco/src/Service/ShapeInference.cpp b/compiler/loco/src/Service/ShapeInference.cpp
new file mode 100644
index 000000000..84eb10963
--- /dev/null
+++ b/compiler/loco/src/Service/ShapeInference.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/ShapeInference.h"
+#include "loco/IR/Algorithm.h"
+
+#include <cassert>
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+bool inputs_shape_ready(loco::Node *node)
+{
+ assert(node != nullptr);
+
+ for (uint32_t arity = 0; arity < node->arity(); ++arity)
+ {
+ if (!loco::ShapeInference::known(node->arg(arity)))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+//
+// Infrastructure
+//
+namespace
+{
+
+struct ShapeAnnotation : public loco::NodeAnnotation
+{
+public:
+ ShapeAnnotation(const loco::NodeShape &shape) : _shape{shape}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const loco::NodeShape &shape(void) const { return _shape; }
+
+private:
+ loco::NodeShape _shape;
+};
+
+} // namespace
+
+namespace loco
+{
+
+bool ShapeInferenceSession::to(Graph *g) const
+{
+ assert(_rule->support(ShapeInferenceRule::API::V1) && "API v1 is unavailable");
+
+ bool changed = false;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ if (_rule->recognize(node->dialect()))
+ {
+ loco::NodeShape shape;
+
+ if (!shape_known(node) && inputs_shape_ready(node))
+ {
+ if (_rule->infer(node, shape))
+ {
+ node->annot(stdex::make_unique<ShapeAnnotation>(shape));
+ changed = true;
+ }
+ }
+ }
+ }
+
+ return changed;
+}
+
+bool ShapeInference::known(const Node *node) { return node->annot<ShapeAnnotation>() != nullptr; }
+
+NodeShape ShapeInference::get(const Node *node)
+{
+ assert(known(node));
+ return node->annot<ShapeAnnotation>()->shape();
+}
+
+void ShapeInference::erase(Node *node) { node->annot<ShapeAnnotation>(nullptr); }
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/ShapeInference.test.cpp b/compiler/loco/src/Service/ShapeInference.test.cpp
new file mode 100644
index 000000000..e10b98844
--- /dev/null
+++ b/compiler/loco/src/Service/ShapeInference.test.cpp
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/ShapeInference.h"
+#include "GraphTestcase.h"
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+// This test validates whether framework works as expected.
+TEST(ShapeInferenceTest, framework)
+{
+ // Mock-up Shape Inference Rule
+ struct SampleShapeInferenceRule final : public loco::ShapeInferenceRule
+ {
+ public:
+ SampleShapeInferenceRule(std::vector<const loco::Node *> *nodes) : _nodes{nodes}
+ {
+ // DO NOTHING
+ }
+
+ public:
+ // Accept all the dialects
+ bool recognize(const loco::Dialect *) const final { return true; }
+
+ bool infer(const loco::Node *node, loco::NodeShape &shape) const final
+ {
+ // Record the order of inference
+ _nodes->emplace_back(node);
+
+ if (_nodes->size() != 1)
+ {
+ return false;
+ }
+
+ // Set the first node as Tensor<1>
+ loco::TensorShape tensor_shape;
+
+ tensor_shape.rank(1);
+ tensor_shape.dim(0) = 4;
+
+ shape.set(tensor_shape);
+
+ return true;
+ }
+
+ private:
+ std::vector<const loco::Node *> *_nodes;
+ };
+
+ GraphTestcase<GraphCode::Identity> testcase;
+
+ std::vector<const loco::Node *> nodes;
+
+ SampleShapeInferenceRule rule{&nodes};
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Framework SHOULD visit all the nodes
+ ASSERT_EQ(nodes.size(), 2);
+ // Framework SHOULD visit "pull" before "push"
+ ASSERT_EQ(nodes.at(0), testcase.pull_node);
+ ASSERT_EQ(nodes.at(1), testcase.push_node);
+
+ // Framework SHOULD make an annotation if "rule" returns TRUE
+ ASSERT_TRUE(loco::shape_known(testcase.pull_node));
+ ASSERT_EQ(loco::shape_get(testcase.pull_node).domain(), loco::Domain::Tensor);
+ ASSERT_EQ(loco::shape_get(testcase.pull_node).as<loco::TensorShape>().rank(), 1);
+ ASSERT_EQ(loco::shape_get(testcase.pull_node).as<loco::TensorShape>().dim(0), 4);
+
+ // Framework SHOULD NOT make any annotation if "rule" returns FALSE
+ ASSERT_FALSE(loco::shape_known(testcase.push_node));
+}
diff --git a/compiler/loco/src/Service/ShapeInferenceRule.cpp b/compiler/loco/src/Service/ShapeInferenceRule.cpp
new file mode 100644
index 000000000..bed841260
--- /dev/null
+++ b/compiler/loco/src/Service/ShapeInferenceRule.cpp
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/ShapeInferenceRule.h"
+
+#include <stdexcept>
+
+// This file validates "ShapeInferenceRule.h". Please DO NOT remove this file.
+
+namespace loco
+{
+
+void ShapeInferenceRule::infer(const Context *, const Node *, Sink *) const
+{
+ throw std::runtime_error{"API v2 is not supported"};
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp
new file mode 100644
index 000000000..fbf0033ee
--- /dev/null
+++ b/compiler/loco/src/Service/TypeInference.cpp
@@ -0,0 +1,228 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/TypeInference.h"
+
+#include "loco/IR/Algorithm.h"
+
+#include <cassert>
+
+#include <stdex/Memory.h>
+
+namespace
+{
+
+struct DataTypeAnnotation : public loco::NodeAnnotation
+{
+public:
+ DataTypeAnnotation(const loco::DataType &dtype) : _dtype{dtype}
+ {
+ // DO NOTHING
+ }
+
+public:
+ const loco::DataType &dtype(void) const { return _dtype; }
+
+private:
+ loco::DataType _dtype;
+};
+
+bool inputs_dtype_ready(loco::Node *node)
+{
+ assert(node != nullptr);
+
+ for (uint32_t arity = 0; arity < node->arity(); ++arity)
+ {
+ if (!loco::TypeInference::known(node->arg(arity)))
+ {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+namespace loco
+{
+
+bool TypeInferenceSession::to(Graph *g) const
+{
+ bool changed = false;
+
+ for (auto node : postorder_traversal(output_nodes(g)))
+ {
+ if (_rule->recognize(node->dialect()))
+ {
+ DataType dtype = DataType::Unknown;
+
+ if (!dtype_known(node) && inputs_dtype_ready(node))
+ {
+ if (_rule->infer(node, dtype))
+ {
+ node->annot(stdex::make_unique<DataTypeAnnotation>(dtype));
+ changed = true;
+ }
+ }
+ }
+ }
+
+ return changed;
+}
+
+bool TypeInference::known(const Node *node) { return node->annot<DataTypeAnnotation>() != nullptr; }
+
+DataType TypeInference::get(const Node *node)
+{
+ assert(known(node));
+ return node->annot<DataTypeAnnotation>()->dtype();
+}
+
+void TypeInference::erase(Node *node) { return node->annot<DataTypeAnnotation>(nullptr); }
+
+} // namespace loco
+
+//
+// Canonical (Data) Type Inference Rule
+//
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNode.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+
+namespace
+{
+
+/**
+ * There are two possible maintenance policies.
+ * - Introduce a new canonical node first, and then extend this algorithm later
+ * - Introduce a new canonical node and extend this algorithm at the same time
+ *
+ * The current implementation assumes the former one (for historical reason).
+ *
+ * TODO Evaluate the impact of the latter one
+ */
+struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<loco::DataType>
+{
+ loco::DataType visit(const loco::AvgPool2D *node) { return loco::dtype_get(node->ifm()); }
+ loco::DataType visit(const loco::BiasDecode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::BiasEncode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::ConstGen *node) { return node->dtype(); }
+ loco::DataType visit(const loco::Conv2D *node) { return loco::dtype_get(node->ifm()); }
+ loco::DataType visit(const loco::DepthwiseConv2D *node) { return loco::dtype_get(node->ifm()); }
+ loco::DataType visit(const loco::DepthwiseFilterEncode *node)
+ {
+ return loco::dtype_get(node->input());
+ }
+ loco::DataType visit(const loco::DepthwiseFilterDecode *node)
+ {
+ return loco::dtype_get(node->input());
+ }
+ loco::DataType visit(const loco::EltwiseAdd *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::EltwiseDiv *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::EltwiseMax *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::EltwiseMul *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::EltwiseSqrt *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::EltwiseSub *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::Forward *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::FeatureBiasAdd *node) { return loco::dtype_get(node->value()); }
+ loco::DataType visit(const loco::FeatureDecode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::FeatureEncode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::FilterDecode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::MatrixEncode *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::MatMul *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::MaxPool2D *node) { return loco::dtype_get(node->ifm()); }
+ loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); }
+ loco::DataType visit(const loco::Pull *node) { return node->dtype(); }
+ loco::DataType visit(const loco::ReLU *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::Tanh *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); }
+ loco::DataType visit(const loco::TensorConstantPad *node)
+ {
+ return loco::dtype_get(node->input());
+ }
+ loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); }
+ loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TensorTranspose *node) { return loco::dtype_get(node->input()); }
+ loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); }
+};
+
+} // namespace
+
+namespace loco
+{
+
+bool CanonicalTypeInferenceRule::recognize(const Dialect *d) const
+{
+ // This rule recognizes only "loco.canonical" dialect!
+ return CanonicalDialect::get() == d;
+}
+
+bool CanonicalTypeInferenceRule::infer(const Node *node, DataType &dtype) const
+{
+ assert(node->dialect() == loco::CanonicalDialect::get());
+ assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr);
+
+ CanonicalTypeForwardAlgorithm alg;
+ dtype = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg);
+
+ return true;
+}
+
+bool MultiDialectTypeInferenceRule::recognize(const Dialect *d) const
+{
+ const auto found = _rules.find(d);
+
+ if (found == _rules.cend())
+ return false;
+
+ auto rule = found->second;
+ auto result = rule->recognize(d);
+
+ return result;
+}
+
+bool MultiDialectTypeInferenceRule::infer(const Node *node, DataType &dtype) const
+{
+ const auto found = _rules.find(node->dialect());
+
+ if (found == _rules.cend())
+ return false;
+
+ auto rule = found->second;
+ if (rule->infer(node, dtype))
+ return true;
+
+ return false;
+}
+
+MultiDialectTypeInferenceRule &MultiDialectTypeInferenceRule::bind(const Dialect *d,
+ const TypeInferenceRule *rule)
+{
+ assert(_rules.find(d) == _rules.end());
+ assert(rule->recognize(d));
+
+ _rules[d] = rule;
+
+ return (*this);
+}
+
+} // namespace loco
diff --git a/compiler/loco/src/Service/TypeInference.test.cpp b/compiler/loco/src/Service/TypeInference.test.cpp
new file mode 100644
index 000000000..4660401db
--- /dev/null
+++ b/compiler/loco/src/Service/TypeInference.test.cpp
@@ -0,0 +1,282 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco/Service/TypeInference.h"
+
+#include "GraphTestcase.h"
+
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/Service/TypeInference.h>
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+// This test validates whether framework works as expected.
+TEST(TypeInferenceTest, framework)
+{
+ // Create a sample network
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ push_node->from(pull_node);
+
+ // Create Graph Input & Output
+ auto graph_input = g->inputs()->create();
+
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // Mock-up Type Inference Rule
+ struct SampleTypeInferenceRule final : public loco::TypeInferenceRule
+ {
+ public:
+ SampleTypeInferenceRule(std::vector<const loco::Node *> *nodes) : _nodes{nodes}
+ {
+ // DO NOTHING
+ }
+
+ public:
+ bool recognize(const loco::Dialect *) const final
+ {
+ // Accept all the dialects
+ return true;
+ }
+
+ bool infer(const loco::Node *node, loco::DataType &dtype) const final
+ {
+ // Record the order of inference
+ _nodes->emplace_back(node);
+
+ if (_nodes->size() != 1)
+ {
+ return false;
+ }
+
+ // Annotate the first node as "U8"
+ dtype = loco::DataType::U8;
+ return true;
+ }
+
+ private:
+ std::vector<const loco::Node *> *_nodes;
+ };
+
+ std::vector<const loco::Node *> nodes;
+
+ SampleTypeInferenceRule rule{&nodes};
+
+ loco::apply(&rule).to(g.get());
+
+ ASSERT_EQ(nodes.size(), 2); // Framework SHOULD visit all the nodes
+ ASSERT_EQ(nodes.at(0), pull_node); // Framework SHOULD visit "pull" before "push"
+ ASSERT_EQ(nodes.at(1), push_node);
+
+ // Framework SHOULD make an annotation if "rule" returns TRUE
+ ASSERT_TRUE(loco::dtype_known(pull_node));
+ ASSERT_EQ(loco::dtype_get(pull_node), loco::DataType::U8);
+ // Framework SHOULD NOT make any annotation if "rule" returns FALSE
+ ASSERT_FALSE(loco::dtype_known(push_node));
+}
+
+TEST(CanonicalTypeInferenceRuleTest, minimal)
+{
+ // Create a simple network
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ pull_node->dtype(loco::DataType::U8);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ push_node->from(pull_node);
+
+ auto graph_input = g->inputs()->create();
+
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // Run Type Inference
+ loco::CanonicalTypeInferenceRule rule;
+
+ loco::apply(&rule).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::dtype_known(push_node));
+ ASSERT_EQ(loco::dtype_get(push_node), loco::DataType::U8);
+}
+
+TEST(CanonicalTypeInferenceRuleTest, relu6)
+{
+ // Create a simple Relu6 network
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ pull_node->dtype(loco::DataType::FLOAT32);
+
+ auto relu6_node = g->nodes()->create<loco::ReLU6>();
+
+ relu6_node->input(pull_node);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ push_node->from(relu6_node);
+
+ auto graph_input = g->inputs()->create();
+
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // Run Type Inference
+ loco::CanonicalTypeInferenceRule rule;
+
+ loco::apply(&rule).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::dtype_known(relu6_node));
+ ASSERT_EQ(loco::dtype_get(relu6_node), loco::DataType::FLOAT32);
+}
+
+TEST(CanonicalTypeInferenceRuleTest, tensor_broadcast)
+{
+ // Create a sample network
+ GraphTestcase<GraphCode::TensorBroadcast> testcase{1, 2};
+
+ testcase.graph()->inputs()->at(0)->dtype(loco::DataType::U8);
+
+ // Run Type Inference
+ loco::CanonicalTypeInferenceRule rule;
+
+ loco::apply(&rule).to(testcase.graph());
+
+ // Verify!
+ ASSERT_TRUE(loco::dtype_known(testcase.push_node));
+ ASSERT_EQ(loco::dtype_get(testcase.push_node), loco::DataType::U8);
+}
+
+// mockup for MultiDialectTypeInferenceRule
+// OpNode of a specific loco datatype (defined in template) will be used.
+// And a Dialect for the OpNode and its inference rules are created.
+#include <loco/IR/Dialect.h>
+
+namespace
+{
+
+template <loco::DataType N> class TestDialect final : public loco::Dialect
+{
+public:
+ static Dialect *get(void)
+ {
+ static TestDialect<N> d;
+ return &d;
+ }
+};
+
+template <loco::DataType N>
+struct TestOpNode final : public loco::FixedArity<1>::Mixin<loco::Node>,
+ public loco::NodeMixin<loco::NodeTrait::DataType>
+{
+ void input(Node *node) { at(0)->node(node); }
+ const loco::Dialect *dialect(void) const final { return TestDialect<N>::get(); }
+ uint32_t opnum(void) const final { return static_cast<uint32_t>(N); }
+};
+
+template <loco::DataType N> struct TestTypeInferenceRule final : public loco::TypeInferenceRule
+{
+public:
+ bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<N>::get()); }
+
+ bool infer(const loco::Node *node, loco::DataType &dtype) const final
+ {
+ assert(node->dialect() == TestDialect<N>::get());
+ auto test_node = dynamic_cast<const TestOpNode<N> *>(node);
+ assert(test_node != nullptr);
+
+ dtype = N;
+ return true;
+ }
+};
+
+} // namespace
+
+TEST(MultiDialectTypeInferenceRuleTest, test1)
+{
+ // Create a simple network : Pull - S8 - U8 - Push
+ auto g = loco::make_graph();
+
+ auto pull_node = g->nodes()->create<loco::Pull>();
+ pull_node->dtype(loco::DataType::FLOAT32);
+
+ auto s8_node = g->nodes()->create<TestOpNode<loco::DataType::S8>>();
+ s8_node->input(pull_node);
+
+ auto u8_node = g->nodes()->create<TestOpNode<loco::DataType::U8>>();
+ u8_node->input(s8_node);
+
+ auto push_node = g->nodes()->create<loco::Push>();
+ push_node->from(u8_node);
+
+ auto graph_input = g->inputs()->create();
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+
+ auto graph_output = g->outputs()->create();
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // initially they don't have type info
+ ASSERT_FALSE(loco::dtype_known(s8_node));
+ ASSERT_FALSE(loco::dtype_known(u8_node));
+
+ // Run Type Inference
+ TestTypeInferenceRule<loco::DataType::U8> u8_rule;
+ TestTypeInferenceRule<loco::DataType::S8> s8_rule;
+ loco::CanonicalTypeInferenceRule canon_rule;
+
+ loco::MultiDialectTypeInferenceRule rules;
+
+ rules.bind(TestDialect<loco::DataType::S8>::get(), &s8_rule)
+ .bind(TestDialect<loco::DataType::U8>::get(), &u8_rule)
+ .bind(loco::CanonicalDialect::get(), &canon_rule);
+
+ loco::apply(&rules).to(g.get());
+
+ // Verify!
+ ASSERT_TRUE(loco::dtype_known(s8_node));
+ ASSERT_EQ(loco::dtype_get(s8_node), loco::DataType::S8);
+
+ ASSERT_TRUE(loco::dtype_known(u8_node));
+ ASSERT_EQ(loco::dtype_get(u8_node), loco::DataType::U8);
+}
diff --git a/compiler/loco/src/loco.test.cpp b/compiler/loco/src/loco.test.cpp
new file mode 100644
index 000000000..4c4f51aa5
--- /dev/null
+++ b/compiler/loco/src/loco.test.cpp
@@ -0,0 +1,108 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "loco.h"
+
+#include <gtest/gtest.h>
+
+// This test shows how to create an "identity" network with loco.
+//
+// What is "identity" network?
+// - A network simply passes its input as its output
+//
+// TODO Create "Output" first and then create "Push" later
+TEST(LOCO, identity_network)
+{
+ auto g = loco::make_graph();
+
+ // Create a "pull" node as an input
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ // Set "data type"
+ pull_node->dtype(loco::DataType::FLOAT32);
+
+ // Set "data shape"
+ pull_node->rank(2);
+ pull_node->dim(0) = 3;
+ pull_node->dim(1) = 4;
+
+ // Create a "push" node as an output
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ // Set "source"
+ push_node->from(pull_node);
+
+ // Create Graph Input & Output
+ auto graph_input = g->inputs()->create();
+
+ graph_input->name("input");
+ loco::link(graph_input, pull_node);
+ graph_input->dtype(loco::DataType::FLOAT32);
+
+ auto graph_output = g->outputs()->create();
+
+ graph_output->name("output");
+ loco::link(graph_output, push_node);
+
+ // loco::link SHOULD update "index"
+ ASSERT_EQ(pull_node->index(), 0);
+ ASSERT_EQ(graph_input->dtype(), loco::DataType::FLOAT32);
+
+ // loco::link SHOULD update "index"
+ ASSERT_EQ(push_node->index(), 0);
+}
+
+#if 0
+"identity_network_V2" test shows how to use loco when loco.core and loco.canonical are decoupled.
+
+NOTE "identity_network" test is left for backward compatibility check
+TODO Remove "identity_network" test once all the clients are migrated.
+#endif
+TEST(LOCO, identity_network_V2)
+{
+ auto g = loco::make_graph();
+
+ // Create Graph Input & Output
+ auto graph_input = g->inputs()->create();
+
+ graph_input->name("input");
+ graph_input->dtype(loco::DataType::FLOAT32);
+ // TODO Set Shape
+
+ auto graph_output = g->outputs()->create();
+
+ graph_output->name("output");
+ graph_output->dtype(loco::DataType::FLOAT32);
+ // TODO Set Shape
+
+ // Create a "pull" node as an input
+ auto pull_node = g->nodes()->create<loco::Pull>();
+
+ pull_node->index(0);
+
+ // Create a "push" node as an output
+ auto push_node = g->nodes()->create<loco::Push>();
+
+ push_node->index(0);
+ push_node->from(pull_node);
+
+ ASSERT_EQ(pull_node->dtype(), loco::DataType::FLOAT32);
+ // TODO Check Shape of pull_node
+ // TODO Check Shape of push_node
+
+ ASSERT_EQ(loco::pull_node(g.get(), 0), pull_node);
+ ASSERT_EQ(loco::push_node(g.get(), 0), push_node);
+}
diff --git a/compiler/loco/src/tensorflow.test.cpp b/compiler/loco/src/tensorflow.test.cpp
new file mode 100644
index 000000000..f534aee7b
--- /dev/null
+++ b/compiler/loco/src/tensorflow.test.cpp
@@ -0,0 +1,386 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * @brief This file includes various tests that shows how to encode TensorFlow models using loco.
+ *
+ * @note All the python examples below assume TensorFlow v1.13
+ */
+#include "loco.h"
+
+#include <gtest/gtest.h>
+
+#include <stdex/Memory.h>
+
+using stdex::make_unique;
+
+namespace
+{
+
+/**
+ * @brief Create a NHWC permutation (the layout tf.nn ops use with
+ *        data_format="NHWC"): Count=0, Height=1, Width=2, Depth=3
+ */
+loco::Permutation<loco::Domain::Feature> make_NHWC_permutation(void)
+{
+  loco::Permutation<loco::Domain::Feature> perm;
+
+  perm.axis(loco::FeatureAxis::Depth) = 3;
+  perm.axis(loco::FeatureAxis::Width) = 2;
+  perm.axis(loco::FeatureAxis::Height) = 1;
+  perm.axis(loco::FeatureAxis::Count) = 0;
+
+  return perm;
+}
+
+/**
+ * @brief Create a HxWxIxO (or HxWxCxN) permutation which tf.nn.conv2d uses
+ *
+ * Reference: [tf.nn.conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d)
+ * > Given an input tensor of shape [batch, in_height, in_width, in_channels] and a filter /
+ * > kernel tensor of shape [filter_height, filter_width, in_channels, out_channels], ...
+ *
+ * NOTE "HWIO" is borrowed from TensorFlow Lite Converter
+ *
+ * https://github.com/tensorflow/tensorflow/blob/v1.13.1/tensorflow/lite/toco/model.h#L169
+ */
+loco::Permutation<loco::Domain::Filter> make_HWIO_permutation(void)
+{
+  loco::Permutation<loco::Domain::Filter> perm;
+
+  perm.axis(loco::FilterAxis::Count) = 3;  // O, a.k.a. N
+  perm.axis(loco::FilterAxis::Depth) = 2;  // I, a.k.a. C
+  perm.axis(loco::FilterAxis::Width) = 1;  // W
+  perm.axis(loco::FilterAxis::Height) = 0; // H
+
+  return perm;
+}
+
+} // namespace
+
+#if 0
+>>> MaxPool_Float_000 testcase
+
+MaxPool_Float_000 test guarantees that loco is expressive enough to encode the following example.
+
+Python:
+```
+import tensorflow as tf
+value = tf.placeholder(dtype=tf.float32, shape=[1, 16, 16, 2], name="value")
+maxpool = tf.nn.max_pool(value, [1, 3, 3, 1], [1, 1, 1, 1], 'VALID', name="maxpool")
+tf.get_default_graph().as_graph_def()
+```
+
+The above code produces the following TensorFlow GraphDef:
+
+node {
+ name: "value"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 1 }
+ dim { size: 16 }
+ dim { size: 16 }
+ dim { size: 2 }
+ }
+ }
+ }
+}
+node {
+ name: "maxpool"
+ op: "MaxPool"
+  input: "value"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "data_format"
+ value { s: "NHWC" }
+ }
+ attr {
+ key: "ksize"
+ value { list { i: 1 i: 3 i: 3 i: 1 } }
+ }
+ attr {
+ key: "padding"
+ value { s: "VALID" }
+ }
+ attr {
+ key: "strides"
+ value { list { i: 1 i: 1 i: 1 i: 1 } }
+ }
+}
+
+Below test guarantees that loco is expressive enough to encode this example.
+#endif
+TEST(TensorFlowTest, MaxPool_Float_000)
+{
+  auto g = loco::make_graph();
+
+  // The "value" Placeholder corresponds to the following "Pull" node:
+  //
+  // %value = Pull(dtype: FLOAT32, shape: [1, 16, 16, 2])
+  auto value = g->nodes()->create<loco::Pull>();
+  {
+    value->dtype(loco::DataType::FLOAT32);
+    value->shape({1, 16, 16, 2});
+  }
+
+  // The "maxpool" node is decomposed into a sequence of loco nodes:
+  // - "FeatureEncode"
+  // - "MaxPool2D"
+  // - "FeatureDecode"
+  //
+  // "maxpool.data_format" is 'NHWC', which corresponds to the permutation below:
+  //   Count <-> 0, Height <-> 1, Width <-> 2, Depth <-> 3
+  loco::Permutation<loco::Domain::Feature> NHWC;
+
+  NHWC.axis(loco::FeatureAxis::Count) = 0;
+  NHWC.axis(loco::FeatureAxis::Height) = 1;
+  NHWC.axis(loco::FeatureAxis::Width) = 2;
+  NHWC.axis(loco::FeatureAxis::Depth) = 3;
+
+  // %node_0 = FeatureEncode(%value, perm { Count = 0, Height = 1, Width = 2, Depth = 3 })
+  auto node_0 = g->nodes()->create<loco::FeatureEncode>();
+  {
+    auto encoder = make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+    encoder->perm(NHWC);
+
+    node_0->input(value);
+    node_0->encoder(std::move(encoder));
+  }
+
+  // %node_1 = MaxPool(%node_0, window.H: 3, window.W: 3, stride.H: 1, stride.W: 1)
+  auto node_1 = g->nodes()->create<loco::MaxPool2D>();
+  {
+    node_1->ifm(node_0);
+
+    // From the "ksize" attribute
+    node_1->window()->horizontal(3);
+    node_1->window()->vertical(3);
+
+    // From the "strides" attribute
+    node_1->stride()->horizontal(1);
+    node_1->stride()->vertical(1);
+  }
+
+  // %output = FeatureDecode(%node_1, perm { Count = 0, Height = 1, Width = 2, Depth = 3 })
+  auto output = g->nodes()->create<loco::FeatureDecode>();
+  {
+    auto decoder = make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+    decoder->perm(NHWC);
+
+    output->input(node_1);
+    output->decoder(std::move(decoder));
+  }
+
+  // %push = Push(%output)
+  auto push = g->nodes()->create<loco::Push>();
+  {
+    push->from(output);
+  }
+
+  //
+  // Mark network-level input/output
+  //
+  auto input_0 = g->inputs()->create();
+  loco::link(input_0, value);
+
+  auto output_0 = g->outputs()->create();
+  loco::link(output_0, push);
+
+  // NOTE This example SHOULD BE valid.
+  ASSERT_TRUE(loco::valid(g.get()));
+}
+
+#if 0
+>>> Conv2D_Float_000 testcase
+
+Conv2D_Float_000 test guarantees that loco is expressive enough to encode the following example.
+
+Python:
+```
+import tensorflow as tf
+inp = tf.placeholder(dtype=tf.float32, shape=[1, 16, 16, 2], name="inp")
+ker = tf.constant(value=[1.0], dtype=tf.float32, shape=[7, 1, 2, 4], name="ker")
+conv2d = tf.nn.conv2d(input=inp, filter=ker, strides=[1, 1, 1, 1], padding='VALID', name="conv2d")
+tf.get_default_graph().as_graph_def()
+```
+
+TensorFlow GraphDef:
+```
+node {
+ name: "inp"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim { size: 1 }
+ dim { size: 16 }
+ dim { size: 16 }
+ dim { size: 2 }
+ }
+ }
+ }
+}
+node {
+ name: "ker"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim { size: 7 }
+ dim { size: 1 }
+ dim { size: 2 }
+ dim { size: 4 }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "conv2d"
+ op: "Conv2D"
+ input: "inp"
+ input: "ker"
+ attr {
+ key: "T"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "data_format"
+ value { s: "NHWC" }
+ }
+ attr {
+ key: "dilations"
+ value { list { i: 1 i: 1 i: 1 i: 1 } }
+ }
+ attr {
+ key: "padding"
+ value { s: "VALID" }
+ }
+ attr {
+ key: "strides"
+ value { list { i: 1 i: 1 i: 1 i: 1 } }
+ }
+}
+```
+#endif
+TEST(TensorFlowTest, Conv2D_Float_000)
+{
+  auto g = loco::make_graph();
+
+  // The first "inp" node corresponds to "Pull"
+  auto inp = g->nodes()->create<loco::Pull>();
+  {
+    inp->dtype(loco::DataType::FLOAT32);
+    inp->shape({1, 16, 16, 2});
+  }
+
+  // The second "ker" node corresponds to "ConstGen"
+  auto ker = g->nodes()->create<loco::ConstGen>();
+  {
+    ker->dtype(loco::DataType::FLOAT32);
+    // 'I' denotes IFM DEPTH, and 'O' denotes OFM DEPTH
+    // NOTE The example above declares a [7, 1, 2, 4] kernel (see "ker" in the GraphDef),
+    //      so O is 4 here
+    ker->shape({7 /*H*/, 1 /*W*/, 2 /*I*/, 4 /*O*/});
+    ker->size<loco::DataType::FLOAT32>(7 * 1 * 2 * 4);
+    for (uint32_t n = 0; n < 7 * 1 * 2 * 4; ++n)
+    {
+      // NOTE TensorFlow uses the last value to fill unspecified region
+      ker->at<loco::DataType::FLOAT32>(n) = 1.0f;
+    }
+  }
+
+  // The next "conv2d" node is decomposed into the following loco nodes
+  // - "FeatureEncode"
+  // - "FilterEncode"
+  // - "Conv2D"
+  // - "FeatureDecode"
+  auto encoded_ifm = g->nodes()->create<loco::FeatureEncode>();
+  {
+    // From "conv2d.data_format" attribute
+    auto encoder = make_unique<loco::PermutingEncoder<loco::Domain::Feature>>();
+    encoder->perm(make_NHWC_permutation());
+
+    encoded_ifm->input(inp);
+    encoded_ifm->encoder(std::move(encoder));
+  }
+
+  auto encoded_ker = g->nodes()->create<loco::FilterEncode>();
+  {
+    // From "tf.nn.conv2d" specification
+    auto encoder = make_unique<loco::PermutingEncoder<loco::Domain::Filter>>();
+    encoder->perm(make_HWIO_permutation());
+
+    encoded_ker->input(ker);
+    encoded_ker->encoder(std::move(encoder));
+  }
+
+  auto conv2d = g->nodes()->create<loco::Conv2D>();
+  {
+    conv2d->ifm(encoded_ifm);
+    conv2d->ker(encoded_ker);
+
+    // From "stride" attribute
+    conv2d->stride()->horizontal(1);
+    conv2d->stride()->vertical(1);
+  }
+
+  // "decoded_ofm" corresponds to the output of "conv2d" node.
+  auto decoded_ofm = g->nodes()->create<loco::FeatureDecode>();
+  {
+    // From "conv2d.data_format" attribute
+    auto decoder = make_unique<loco::PermutingDecoder<loco::Domain::Feature>>();
+    decoder->perm(make_NHWC_permutation());
+
+    decoded_ofm->input(conv2d);
+    decoded_ofm->decoder(std::move(decoder));
+  }
+
+  // Mark "conv2d" as a network-level output with Push
+  auto push = g->nodes()->create<loco::Push>();
+  {
+    push->from(decoded_ofm);
+  }
+
+  //
+  // Mark network-level input/output
+  //
+  auto input_0 = g->inputs()->create();
+  loco::link(input_0, inp);
+
+  auto output_0 = g->outputs()->create();
+  loco::link(output_0, push);
+
+  // NOTE This example SHOULD BE valid.
+  ASSERT_TRUE(loco::valid(g.get()));
+}