diff options
Diffstat (limited to 'compiler/loco/src')
74 files changed, 8917 insertions, 0 deletions
diff --git a/compiler/loco/src/ADT/AnnotatedItem.test.cpp b/compiler/loco/src/ADT/AnnotatedItem.test.cpp new file mode 100644 index 000000000..42113ff7b --- /dev/null +++ b/compiler/loco/src/ADT/AnnotatedItem.test.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/ADT/AnnotatedItem.h" + +#include <gtest/gtest.h> +#include <stdex/Memory.h> + +namespace +{ + +struct Annotation +{ + virtual ~Annotation() = default; +}; + +template <int N> struct DerivedAnnotation final : public Annotation +{ + static std::unique_ptr<DerivedAnnotation<N>> make(void) + { + return stdex::make_unique<DerivedAnnotation<N>>(); + } +}; + +} // namespace + +TEST(AnnotatedItemTest, annotation) +{ + loco::AnnotatedItem<::Annotation> item; + + ASSERT_EQ(item.annot<DerivedAnnotation<0>>(), nullptr); + + item.annot(DerivedAnnotation<0>::make()); + + ASSERT_NE(item.annot<DerivedAnnotation<0>>(), nullptr); + ASSERT_EQ(item.annot<DerivedAnnotation<1>>(), nullptr); + + item.annot<DerivedAnnotation<0>>(nullptr); + ASSERT_EQ(item.annot<DerivedAnnotation<0>>(), nullptr); + + // Below check guarantees that "annot<T>(nullptr)" is allowed even when there is no annotation. + // This guarantee allows us to simplify code for some cases. + // + // Let us consider the following example: + // + // void f(loco::AnnotatedItem<T> *item) + // { + // /* DO SOMETHING */ + // if (cond) { item->annot<T>(nullptr); + // } + // + // void g(loco::AnnotatedItem<T> *item) + // { + // f(item); + // item->annot<T>(nullptr); + // } + // + // The implementation of "g" gets complicated if annot<T>(nullptr) is not allowed if there is + // no annotation. + // + item.annot<DerivedAnnotation<0>>(nullptr); +} diff --git a/compiler/loco/src/ADT/ObjectPool.cpp b/compiler/loco/src/ADT/ObjectPool.cpp new file mode 100644 index 000000000..d15a30a99 --- /dev/null +++ b/compiler/loco/src/ADT/ObjectPool.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/ADT/ObjectPool.h" + +// This file validates "ObjectPool.h". Pleaes DO NOT remove this file. diff --git a/compiler/loco/src/IR/Algorithm.cpp b/compiler/loco/src/IR/Algorithm.cpp new file mode 100644 index 000000000..712e29975 --- /dev/null +++ b/compiler/loco/src/IR/Algorithm.cpp @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Algorithm.h" + +#include <cassert> +#include <set> +#include <stack> + +namespace +{ + +class Frame final +{ +public: + Frame(loco::Node *ptr) : _ptr{ptr}, _pos{-1} + { + // DO NOTHING + } + +public: + loco::Node *ptr(void) const { return _ptr; } + int64_t pos(void) const { return _pos; } + + loco::Node &node(void) const { return *_ptr; } + + void advance(void) { _pos += 1; } + +private: + loco::Node *_ptr = nullptr; + int64_t _pos = -1; +}; + +} // namespace + +namespace loco +{ + +// TODO Support cyclic graphs +std::vector<loco::Node *> postorder_traversal(const std::vector<loco::Node *> &roots) +{ + std::vector<loco::Node *> res; + + std::set<loco::Node *> visited_nodes; + std::stack<Frame> frames; + + auto visited = [&visited_nodes](loco::Node *node) { + return visited_nodes.find(node) != visited_nodes.end(); + }; + + // NOTE There is not much difference between "auto" and "auto &" as node is of "loco::Node *" + // type. + for (auto node : roots) + { + assert((node != nullptr) && "root is invalid"); + frames.push(Frame{node}); + } + + while (!frames.empty()) + { + auto &top_frame = frames.top(); + + if (top_frame.pos() == -1) + { + if (visited(top_frame.ptr())) + { + frames.pop(); + continue; + } + visited_nodes.insert(top_frame.ptr()); + } + + top_frame.advance(); + + assert(top_frame.pos() >= 0); + + if (top_frame.pos() < static_cast<int64_t>(top_frame.node().arity())) + { + // Let's visit the next argument + // + // NOTE "next" may be nullptr if a graph is under construction. + if (auto next = top_frame.node().arg(top_frame.pos())) + { + frames.push(Frame{next}); + } + } + else + { + // Let's visit the current argument (all the arguments are already visited) + auto curr = top_frame.ptr(); + res.emplace_back(curr); + frames.pop(); + } + } + + return res; +} + +std::set<loco::Node *> active_nodes(const std::vector<loco::Node *> &roots) +{ + // This implementation works but may be inefficient + // + // TODO Use efficient implementation if necessary + auto nodes = postorder_traversal(roots); + return std::set<loco::Node *>{nodes.begin(), nodes.end()}; +} + +} // namespace loco diff --git a/compiler/loco/src/IR/Algorithm.test.cpp b/compiler/loco/src/IR/Algorithm.test.cpp new file mode 100644 index 000000000..f0a3585c0 --- /dev/null +++ b/compiler/loco/src/IR/Algorithm.test.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Algorithm.h" +#include "loco/IR/Graph.h" + +#include <algorithm> + +#include <gtest/gtest.h> + +namespace +{ + +bool contains(const std::vector<loco::Node *> &vec, loco::Node *val) +{ + return std::any_of(vec.begin(), vec.end(), [val](loco::Node *node) { return node == val; }); +} + +bool contains(const std::set<loco::Node *> &s, loco::Node *val) +{ + return std::any_of(s.begin(), s.end(), [val](loco::Node *node) { return node == val; }); +} + +} // namespace + +TEST(AlgorithmTest, postorder_traversal) +{ + auto g = loco::make_graph(); + + auto pull_1 = g->nodes()->create<loco::Pull>(); + auto push = g->nodes()->create<loco::Push>(); + + push->from(pull_1); + + // Create a dummy node unreachable from the above "push" node + g->nodes()->create<loco::Pull>(); + + auto seq = loco::postorder_traversal({push}); + + ASSERT_EQ(seq.size(), 2); + ASSERT_EQ(seq.at(0), pull_1); + ASSERT_EQ(seq.at(1), push); +} + +TEST(AlgorithmTest, postorder_traversal_visit_once) +{ + auto g = loco::make_graph(); + + // Create a network of the following form: + // + // Push1 Push2 <-- outputs + // \ / + // Pull <-- input + // + auto pull = g->nodes()->create<loco::Pull>(); + auto push_1 = g->nodes()->create<loco::Push>(); + auto push_2 = g->nodes()->create<loco::Push>(); + + push_1->from(pull); + push_2->from(pull); + + auto seq = loco::postorder_traversal({push_1, push_2}); + + ASSERT_EQ(seq.size(), 3); + ASSERT_TRUE(contains(seq, pull)); + ASSERT_TRUE(contains(seq, push_1)); + ASSERT_TRUE(contains(seq, push_2)); +} + +TEST(AlgorithmTest, postorder_traversal_incomplte_graph) +{ + auto g = loco::make_graph(); + + // Create a network of the following form: + // + // TensorConcat + // / \ + // Pull X + // + auto pull = g->nodes()->create<loco::Pull>(); + auto concat = g->nodes()->create<loco::TensorConcat>(); + + concat->lhs(pull); + + auto seq = loco::postorder_traversal({concat}); + + ASSERT_EQ(seq.size(), 2); + ASSERT_EQ(seq.at(0), pull); + ASSERT_EQ(seq.at(1), concat); +} + +TEST(AlgorithmTest, active_nodes) +{ + auto g = loco::make_graph(); + + auto pull = g->nodes()->create<loco::Pull>(); + auto push = g->nodes()->create<loco::Push>(); + + push->from(pull); + + // NOTE This new Push node is unnecessary to compute "push" + g->nodes()->create<loco::Push>(); + + auto s = loco::active_nodes({push}); + + ASSERT_EQ(s.size(), 2); + ASSERT_TRUE(contains(s, pull)); + ASSERT_TRUE(contains(s, push)); +} diff --git a/compiler/loco/src/IR/BiasShape.test.cpp b/compiler/loco/src/IR/BiasShape.test.cpp new file mode 100644 index 000000000..7f9b8dfed --- /dev/null +++ b/compiler/loco/src/IR/BiasShape.test.cpp @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/BiasShape.h" + +#include <gtest/gtest.h> + +TEST(BiasShapeTest, default_constructor) +{ + loco::BiasShape shape; + + ASSERT_FALSE(shape.length().known()); +} diff --git a/compiler/loco/src/IR/CanonicalDialect.cpp b/compiler/loco/src/IR/CanonicalDialect.cpp new file mode 100644 index 000000000..ea956b80e --- /dev/null +++ b/compiler/loco/src/IR/CanonicalDialect.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/CanonicalDialect.h" +#include "loco/IR/Graph.h" +#include "loco/IR/Nodes.h" + +#include <stdex/Memory.h> + +#include <cassert> +#include <stdexcept> + +namespace +{ + +struct GraphOutputIndexQueryServiceImpl final : public loco::GraphOutputIndexQueryService +{ + bool associated(const loco::Node *node) const final + { + if (auto push = dynamic_cast<const loco::Push *>(node)) + { + return push->indexed(); + } + return false; + } + + loco::GraphOutputIndex index(const loco::Node *node) const final + { + assert(associated(node)); + if (auto push = dynamic_cast<const loco::Push *>(node)) + { + return push->index(); + } + throw std::invalid_argument("node"); + } +}; + +} // namespace + +namespace loco +{ + +CanonicalDialect::CanonicalDialect() +{ + service<GraphOutputIndexQueryService>(stdex::make_unique<GraphOutputIndexQueryServiceImpl>()); +} + +Dialect *CanonicalDialect::get(void) +{ + static CanonicalDialect d; + return &d; +} + +} // namespace loco diff --git a/compiler/loco/src/IR/CanonicalDialect.test.cpp b/compiler/loco/src/IR/CanonicalDialect.test.cpp new file mode 100644 index 000000000..96b48218d --- /dev/null +++ b/compiler/loco/src/IR/CanonicalDialect.test.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/CanonicalDialect.h" + +#include <gtest/gtest.h> + +TEST(CanonicalDialectTest, get) +{ + auto d = loco::CanonicalDialect::get(); + + // get() SHOULD return a valid(non-null) pointer + ASSERT_NE(d, nullptr); + // The return value SHOULD be stable across multiple invocations + ASSERT_EQ(d, loco::CanonicalDialect::get()); +} diff --git a/compiler/loco/src/IR/CanonicalNode.cpp b/compiler/loco/src/IR/CanonicalNode.cpp new file mode 100644 index 000000000..d5e13a415 --- /dev/null +++ b/compiler/loco/src/IR/CanonicalNode.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/CanonicalNode.h" +#include "loco/IR/CanonicalDialect.h" + +namespace loco +{ + +const Dialect *CanonicalNode::dialect(void) const { return CanonicalDialect::get(); } + +} // namespace loco diff --git a/compiler/loco/src/IR/CanonicalNode.test.cpp b/compiler/loco/src/IR/CanonicalNode.test.cpp new file mode 100644 index 000000000..cb61b5e83 --- /dev/null +++ b/compiler/loco/src/IR/CanonicalNode.test.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/CanonicalNode.h" + +#include <gtest/gtest.h> + +TEST(CanonicalNodeTest, visitor_with_user_default_impl) +{ + struct MyVisitor final : public loco::CanonicalNodeVisitor<uint32_t> + { + // This visitor returns 128 if it visits a Forward node. + uint32_t visit(const loco::Forward *) final { return 128; } + + // Otherwise, this visitor returns 256. + uint32_t visit(const loco::Node *) final { return 256; } + }; + + loco::Forward forward; + loco::ConstGen constgen; + + MyVisitor v; + + ASSERT_EQ(forward.accept(&v), 128); + ASSERT_EQ(constgen.accept(&v), 256); +} + +TEST(CanonicalNodeTest, visitor) +{ + struct CountingVisitor final : public loco::CanonicalNodeVisitor<uint32_t> + { + uint32_t visit(const loco::Forward *) final { return 1; } + }; + + // Visitor can visit constant nodes + const loco::Forward node; + + CountingVisitor v; + + ASSERT_EQ(node.accept(&v), 1); +} + +TEST(CanonicalNodeTest, mutable_visitor) +{ + struct ResetForward final : public loco::CanonicalNodeMutableVisitor<void> + { + void visit(loco::Forward *node) final { node->input(nullptr); } + }; + + loco::Pull pull_node; + loco::Forward forward_node; + + forward_node.input(&pull_node); + + ResetForward v; + forward_node.accept(&v); + + ASSERT_EQ(forward_node.input(), nullptr); +} diff --git a/compiler/loco/src/IR/CanonicalOpcode.cpp b/compiler/loco/src/IR/CanonicalOpcode.cpp new file mode 100644 index 000000000..6355ecf1f --- /dev/null +++ b/compiler/loco/src/IR/CanonicalOpcode.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/CanonicalOpcode.h" + +// NOTE This file validates "CanonicalOpcode.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/DataType.cpp b/compiler/loco/src/IR/DataType.cpp new file mode 100644 index 000000000..56794dac7 --- /dev/null +++ b/compiler/loco/src/IR/DataType.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DataType.h" + +// This file validates "DataType.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/DataTypeTraits.test.cpp b/compiler/loco/src/IR/DataTypeTraits.test.cpp new file mode 100644 index 000000000..76d2515a9 --- /dev/null +++ b/compiler/loco/src/IR/DataTypeTraits.test.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DataTypeTraits.h" + +#include <typeindex> + +#include <gtest/gtest.h> + +TEST(DataTypeTraitsTest, FLOAT32) +{ + auto obtained = std::type_index(typeid(loco::DataTypeImpl<loco::DataType::FLOAT32>::Type)); + auto expected = std::type_index(typeid(float)); + + ASSERT_EQ(obtained, expected); +} diff --git a/compiler/loco/src/IR/DepthwiseFilterAxis.cpp b/compiler/loco/src/IR/DepthwiseFilterAxis.cpp new file mode 100644 index 000000000..9d58795b2 --- /dev/null +++ b/compiler/loco/src/IR/DepthwiseFilterAxis.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DepthwiseFilterAxis.h" + +// NOTE This file validates "DepthwiseFilterAxis.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/DepthwiseFilterCodec.cpp b/compiler/loco/src/IR/DepthwiseFilterCodec.cpp new file mode 100644 index 000000000..05a7fd723 --- /dev/null +++ b/compiler/loco/src/IR/DepthwiseFilterCodec.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DepthwiseFilterCodec.h" + +// NOTE This file validates "DepthwiseFilterCodec.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp b/compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp new file mode 100644 index 000000000..202647cfc --- /dev/null +++ b/compiler/loco/src/IR/DepthwiseFilterIndex.test.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DepthwiseFilterIndex.h" + +#include <gtest/gtest.h> + +TEST(DepthwiseFilterIndexTest, default_constructor) +{ + loco::DepthwiseFilterIndex index; + + // All the values are 0 at the beginning + ASSERT_EQ(index.channel(), 0); + ASSERT_EQ(index.nth(), 0); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); +} + +TEST(DepthwiseFilterIndexTest, settet_and_getter) +{ + loco::DepthwiseFilterIndex index; + + // Set depth + index.channel() = 2; + + ASSERT_EQ(index.channel(), 2); + ASSERT_EQ(index.nth(), 0); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); + + // Set multiplier + index.nth() = 3; + + ASSERT_EQ(index.channel(), 2); + ASSERT_EQ(index.nth(), 3); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); + + // Set height + index.row() = 4; + + ASSERT_EQ(index.channel(), 2); + ASSERT_EQ(index.nth(), 3); + ASSERT_EQ(index.row(), 4); + ASSERT_EQ(index.column(), 0); + + // Set width + index.column() = 5; + + ASSERT_EQ(index.channel(), 2); + ASSERT_EQ(index.nth(), 3); + ASSERT_EQ(index.row(), 4); + ASSERT_EQ(index.column(), 5); +} diff --git a/compiler/loco/src/IR/DepthwiseFilterShape.test.cpp b/compiler/loco/src/IR/DepthwiseFilterShape.test.cpp new file mode 100644 index 000000000..2b9518c1f --- /dev/null +++ b/compiler/loco/src/IR/DepthwiseFilterShape.test.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DepthwiseFilterShape.h" + +#include <gtest/gtest.h> + +TEST(DepthwiseFilterShapeTest, default_constructor) +{ + loco::DepthwiseFilterShape shape; + + ASSERT_FALSE(shape.depth().known()); + ASSERT_FALSE(shape.multiplier().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); +} + +TEST(DepthwiseFilterShapeTest, settet_and_getter) +{ + loco::DepthwiseFilterShape shape; + + // Set depth + shape.depth() = 2; + + ASSERT_TRUE(shape.depth().known()); + ASSERT_FALSE(shape.multiplier().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.depth(), 2); + + // Set multiplier + shape.multiplier() = 3; + + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.multiplier().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.depth(), 2); + ASSERT_EQ(shape.multiplier(), 3); + + // Set height + shape.height() = 4; + + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.multiplier().known()); + ASSERT_TRUE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.depth(), 2); + ASSERT_EQ(shape.multiplier(), 3); + ASSERT_EQ(shape.height(), 4); + + // Set width + shape.width() = 5; + + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.multiplier().known()); + ASSERT_TRUE(shape.height().known()); + ASSERT_TRUE(shape.width().known()); + + ASSERT_EQ(shape.depth(), 2); + ASSERT_EQ(shape.multiplier(), 3); + ASSERT_EQ(shape.height(), 4); + ASSERT_EQ(shape.width(), 5); +} diff --git a/compiler/loco/src/IR/Dialect.cpp b/compiler/loco/src/IR/Dialect.cpp new file mode 100644 index 000000000..a381b47eb --- /dev/null +++ b/compiler/loco/src/IR/Dialect.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Dialect.h" + +// NOTE This file validates "Dialect.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/Dialect.test.cpp b/compiler/loco/src/IR/Dialect.test.cpp new file mode 100644 index 000000000..312bb52ef --- /dev/null +++ b/compiler/loco/src/IR/Dialect.test.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Dialect.h" + +#include <stdex/Memory.h> + +#include <gtest/gtest.h> + +TEST(DialectTest, service) +{ + struct S0 final : public loco::DialectService + { + }; + struct S1 final : public loco::DialectService + { + }; + + struct MockDialect final : public loco::Dialect + { + MockDialect() { service<S1>(stdex::make_unique<S1>()); } + }; + + MockDialect dialect; + + ASSERT_EQ(dialect.service<S0>(), nullptr); + ASSERT_NE(dialect.service<S1>(), nullptr); +} diff --git a/compiler/loco/src/IR/DialectService.cpp b/compiler/loco/src/IR/DialectService.cpp new file mode 100644 index 000000000..fb8041e47 --- /dev/null +++ b/compiler/loco/src/IR/DialectService.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/DialectService.h" + +// NOTE This file validates "DialectService.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/Dimension.cpp b/compiler/loco/src/IR/Dimension.cpp new file mode 100644 index 000000000..0d11c83e8 --- /dev/null +++ b/compiler/loco/src/IR/Dimension.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Dimension.h" + +namespace loco +{ + +bool operator==(const Dimension &lhs, const Dimension &rhs) +{ + return lhs.known() && rhs.known() && lhs.value() == rhs.value(); +} + +bool operator==(const Dimension &lhs, uint32_t rhs) { return lhs.known() && lhs.value() == rhs; } +bool operator==(uint32_t lhs, const Dimension &rhs) { return rhs.known() && lhs == rhs.value(); } + +Dimension make_dimension(void) { return Dimension{}; } + +} // namespace loco diff --git a/compiler/loco/src/IR/Dimension.test.cpp b/compiler/loco/src/IR/Dimension.test.cpp new file mode 100644 index 000000000..4faf78ac8 --- /dev/null +++ b/compiler/loco/src/IR/Dimension.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Dimension.h" + +#include <gtest/gtest.h> + +namespace +{ + +struct DimensionTest : public ::testing::Test +{ +protected: + uint32_t value(void) const { return _value; } + +private: + uint32_t const _value{3}; +}; + +} // namespace + +TEST_F(DimensionTest, default_constructor) +{ + loco::Dimension dim; + + ASSERT_FALSE(dim.known()); +} + +TEST_F(DimensionTest, value_constructor) +{ + loco::Dimension dim{value()}; + + ASSERT_TRUE(dim.known()); + ASSERT_EQ(dim.value(), value()); +} + +TEST_F(DimensionTest, set) +{ + loco::Dimension dim; + + dim.set(value()); + + ASSERT_TRUE(dim.known()); + ASSERT_EQ(dim.value(), value()); +} + +TEST_F(DimensionTest, unset) +{ + loco::Dimension dim{value()}; + + dim.unset(); + + ASSERT_FALSE(dim.known()); +} + +TEST_F(DimensionTest, operator_eq) +{ + loco::Dimension unknown; + loco::Dimension known{3}; + + // Compare uint32_t and an unknown dimension + ASSERT_FALSE(unknown == 3); + ASSERT_FALSE(3 == unknown); + + // Compare uint32_t and a known dimension + ASSERT_TRUE(known == 3); + ASSERT_TRUE(3 == known); + + ASSERT_FALSE(known == 4); + ASSERT_FALSE(4 == known); + + // Compare two known dimensions + loco::Dimension another_known{3}; + ASSERT_TRUE(known == another_known); + + // Compare two unknown dimensions + loco::Dimension unknown_a, unknown_b; + ASSERT_TRUE(unknown_a.known() == false && unknown_b.known() == false); + ASSERT_FALSE(unknown_a == unknown_b); +} + +TEST_F(DimensionTest, make_unknown_dimension) +{ + auto dim = loco::make_dimension(); + + ASSERT_FALSE(dim.known()); +} diff --git a/compiler/loco/src/IR/Domain.cpp b/compiler/loco/src/IR/Domain.cpp new file mode 100644 index 000000000..7bad04750 --- /dev/null +++ b/compiler/loco/src/IR/Domain.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Domain.h" + +// NOTE This file validates "Domain.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/FeatureAxis.cpp b/compiler/loco/src/IR/FeatureAxis.cpp new file mode 100644 index 000000000..b0f560677 --- /dev/null +++ b/compiler/loco/src/IR/FeatureAxis.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FeatureAxis.h" + +// NOTE This file validates "FeatureAxis.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/FeatureCodec.cpp b/compiler/loco/src/IR/FeatureCodec.cpp new file mode 100644 index 000000000..99d39a489 --- /dev/null +++ b/compiler/loco/src/IR/FeatureCodec.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FeatureCodec.h" + +// NOTE This file validates "FeatureCodec.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/FeatureIndex.test.cpp b/compiler/loco/src/IR/FeatureIndex.test.cpp new file mode 100644 index 000000000..82b563986 --- /dev/null +++ b/compiler/loco/src/IR/FeatureIndex.test.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FeatureIndex.h" + +#include <gtest/gtest.h> + +TEST(FeatureIndexTest, default_constructor) +{ + loco::FeatureIndex index; + + // All the values are 0 at the beginning + ASSERT_EQ(index.batch(), 0); + ASSERT_EQ(index.channel(), 0); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); +} + +TEST(FeatureIndexTest, settet_and_getter) +{ + loco::FeatureIndex index; + + // Set count + index.batch() = 2; + + ASSERT_EQ(index.batch(), 2); + ASSERT_EQ(index.channel(), 0); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); + + // Set channel + index.channel() = 3; + + ASSERT_EQ(index.batch(), 2); + ASSERT_EQ(index.channel(), 3); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); + + // Set height + index.row() = 4; + + ASSERT_EQ(index.batch(), 2); + ASSERT_EQ(index.channel(), 3); + ASSERT_EQ(index.row(), 4); + ASSERT_EQ(index.column(), 0); + + // Set width + index.column() = 5; + + ASSERT_EQ(index.batch(), 2); + ASSERT_EQ(index.channel(), 3); + ASSERT_EQ(index.row(), 4); + ASSERT_EQ(index.column(), 5); +} diff --git a/compiler/loco/src/IR/FeatureShape.test.cpp b/compiler/loco/src/IR/FeatureShape.test.cpp new file mode 100644 index 000000000..59e25ac23 --- /dev/null +++ b/compiler/loco/src/IR/FeatureShape.test.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FeatureShape.h" + +#include <gtest/gtest.h> + +TEST(FeatureShapeTest, default_constructor) +{ + loco::FeatureShape shape; + + ASSERT_FALSE(shape.count().known()); + ASSERT_FALSE(shape.depth().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); +} + +TEST(FeatureShapeTest, settet_and_getter) +{ + loco::FeatureShape shape; + + // Set count + shape.count() = 2; + + ASSERT_TRUE(shape.count().known()); + ASSERT_FALSE(shape.depth().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + + // Set depth + shape.depth() = 3; + + ASSERT_TRUE(shape.count().known()); + ASSERT_TRUE(shape.depth().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + ASSERT_EQ(shape.depth(), 3); + + // Set height + shape.height() = 4; + + ASSERT_TRUE(shape.count().known()); + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + ASSERT_EQ(shape.depth(), 3); + ASSERT_EQ(shape.height(), 4); + + // Set width + shape.width() = 5; + + ASSERT_TRUE(shape.count().known()); + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.height().known()); + ASSERT_TRUE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + ASSERT_EQ(shape.depth(), 3); + ASSERT_EQ(shape.height(), 4); + ASSERT_EQ(shape.width(), 5); +} diff --git a/compiler/loco/src/IR/FilterAxis.cpp b/compiler/loco/src/IR/FilterAxis.cpp new file mode 100644 index 000000000..be4234e6a --- /dev/null +++ b/compiler/loco/src/IR/FilterAxis.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FilterAxis.h" + +// NOTE This file validates "FilterAxis.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/FilterCodec.cpp b/compiler/loco/src/IR/FilterCodec.cpp new file mode 100644 index 000000000..f48cf1821 --- /dev/null +++ b/compiler/loco/src/IR/FilterCodec.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FilterCodec.h" + +// NOTE This file validates "FilterCodec.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/FilterIndex.test.cpp b/compiler/loco/src/IR/FilterIndex.test.cpp new file mode 100644 index 000000000..58f38718e --- /dev/null +++ b/compiler/loco/src/IR/FilterIndex.test.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FilterIndex.h" + +#include <gtest/gtest.h> + +TEST(FilterIndexTest, default_constructor) +{ + loco::FilterIndex index; + + // All the values are 0 at the beginning + ASSERT_EQ(index.nth(), 0); + ASSERT_EQ(index.channel(), 0); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); +} + +TEST(FilterIndexTest, settet_and_getter) +{ + loco::FilterIndex index; + + // Set count + index.nth() = 2; + + ASSERT_EQ(index.nth(), 2); + ASSERT_EQ(index.channel(), 0); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); + + // Set channel + index.channel() = 3; + + ASSERT_EQ(index.nth(), 2); + ASSERT_EQ(index.channel(), 3); + ASSERT_EQ(index.row(), 0); + ASSERT_EQ(index.column(), 0); + + // Set height + index.row() = 4; + + ASSERT_EQ(index.nth(), 2); + ASSERT_EQ(index.channel(), 3); + ASSERT_EQ(index.row(), 4); + ASSERT_EQ(index.column(), 0); + + // Set width + index.column() = 5; + + ASSERT_EQ(index.nth(), 2); + ASSERT_EQ(index.channel(), 3); + ASSERT_EQ(index.row(), 4); + ASSERT_EQ(index.column(), 5); +} diff --git a/compiler/loco/src/IR/FilterShape.test.cpp b/compiler/loco/src/IR/FilterShape.test.cpp new file mode 100644 index 000000000..ccb60ed76 --- /dev/null +++ b/compiler/loco/src/IR/FilterShape.test.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/FilterShape.h" + +#include <gtest/gtest.h> + +TEST(FilterShapeTest, default_constructor) +{ + loco::FilterShape shape; + + ASSERT_FALSE(shape.count().known()); + ASSERT_FALSE(shape.depth().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); +} + +TEST(FilterShapeTest, settet_and_getter) +{ + loco::FilterShape shape; + + // Set count + shape.count() = 2; + + ASSERT_TRUE(shape.count().known()); + ASSERT_FALSE(shape.depth().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + + // Set depth + shape.depth() = 3; + + ASSERT_TRUE(shape.count().known()); + ASSERT_TRUE(shape.depth().known()); + ASSERT_FALSE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + ASSERT_EQ(shape.depth(), 3); + + // Set height + shape.height() = 4; + + ASSERT_TRUE(shape.count().known()); + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.height().known()); + ASSERT_FALSE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + ASSERT_EQ(shape.depth(), 3); + ASSERT_EQ(shape.height(), 4); + + // Set width + shape.width() = 5; + + ASSERT_TRUE(shape.count().known()); + ASSERT_TRUE(shape.depth().known()); + ASSERT_TRUE(shape.height().known()); + ASSERT_TRUE(shape.width().known()); + + ASSERT_EQ(shape.count(), 2); + ASSERT_EQ(shape.depth(), 3); + ASSERT_EQ(shape.height(), 4); + ASSERT_EQ(shape.width(), 5); +} diff --git a/compiler/loco/src/IR/Graph.cpp b/compiler/loco/src/IR/Graph.cpp new file mode 100644 index 000000000..1d8752252 --- /dev/null +++ b/compiler/loco/src/IR/Graph.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Graph.h" + +#include <stdex/Memory.h> + +#include <cassert> + +namespace +{ + +std::unique_ptr<loco::TensorShape> make_tensor_shape(std::initializer_list<loco::Dimension> dims) +{ + auto tensor_shape = stdex::make_unique<loco::TensorShape>(); + + tensor_shape->rank(dims.size()); + { + uint32_t axis = 0; + for (auto it = dims.begin(); it != dims.end(); ++it) + { + tensor_shape->dim(axis++) = *it; + } + assert(axis == dims.size()); + } + + return std::move(tensor_shape); +} + +} // namespace + +namespace loco +{ + +void Mixin<Trait::TensorShaped>::shape(std::initializer_list<Dimension> dims) +{ + shape(make_tensor_shape(dims)); +} + +GraphInput *Graph::InputContext::create(void) +{ + return take(stdex::make_unique<GraphInput>(size())); +} + +GraphOutput *Graph::OutputContext::create(void) +{ + return take(stdex::make_unique<GraphOutput>(size())); +} + +std::set<loco::Node *> all_nodes(loco::Graph *g) +{ + std::set<loco::Node *> res; + + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + res.insert(g->nodes()->at(n)); + } + + return res; +} + +std::vector<Node *> input_nodes(const Graph *g) +{ + std::map<GraphInputIndex, loco::Node *> table; + + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + auto node = g->nodes()->at(n); + + if (auto service = node->dialect()->service<GraphInputIndexQueryService>()) + { + if (service->associated(node)) + { + auto input_index = service->index(node); + assert(table.find(input_index) == table.end()); + table[input_index] = node; + } + } + } + + std::vector<loco::Node *> res; + + for (uint32_t n = 0; n < g->inputs()->size(); ++n) + { + auto it = table.find(n); + res.emplace_back(it == table.end() ? nullptr : it->second); + } + + return res; +} + +std::vector<loco::Node *> output_nodes(loco::Graph *g) +{ + std::map<GraphOutputIndex, loco::Node *> table; + + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + auto node = g->nodes()->at(n); + + if (auto service = node->dialect()->service<GraphOutputIndexQueryService>()) + { + if (service->associated(node)) + { + auto output_index = service->index(node); + assert(table.find(output_index) == table.end()); + table[output_index] = node; + } + } + } + + std::vector<loco::Node *> res; + + for (uint32_t n = 0; n < g->outputs()->size(); ++n) + { + auto it = table.find(n); + res.emplace_back(it == table.end() ? nullptr : it->second); + } + + return res; +} + +std::unique_ptr<Graph> make_graph(void) { return std::unique_ptr<Graph>{new Graph}; } + +} // namespace loco diff --git a/compiler/loco/src/IR/Graph.test.cpp b/compiler/loco/src/IR/Graph.test.cpp new file mode 100644 index 000000000..6df630b0f --- /dev/null +++ b/compiler/loco/src/IR/Graph.test.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Graph.h" + +#include <gtest/gtest.h> + +namespace +{ + +// @brief Mockup class for loco::NamedEntity +struct NamedElement final : private loco::NamedEntity +{ + LOCO_NAMED_ENTITY_EXPOSE; +}; + +} // namespace + +TEST(NamedTest, constructor) +{ + NamedElement elem; + + ASSERT_EQ(elem.name(), ""); +} + +TEST(NamedTest, setter_and_getter) +{ + NamedElement elem; + + elem.name("name"); + ASSERT_EQ(elem.name(), "name"); +} + +TEST(DataTypedMixinTest, constructor) +{ + loco::Mixin<loco::Trait::DataTyped> mixin; + + ASSERT_EQ(mixin.dtype(), loco::DataType::Unknown); +} + +TEST(DataTypedMixinTest, setter_and_getter) +{ + loco::Mixin<loco::Trait::DataTyped> mixin; + + mixin.dtype(loco::DataType::FLOAT32); + ASSERT_EQ(mixin.dtype(), loco::DataType::FLOAT32); +} + +TEST(TensorShapedMixinTest, setter_and_getter) +{ + loco::Mixin<loco::Trait::TensorShaped> mixin; + + mixin.shape({1, 2, 3, 4}); + ASSERT_NE(mixin.shape(), nullptr); + ASSERT_EQ(mixin.shape()->rank(), 4); + ASSERT_EQ(mixin.shape()->dim(0), 1); + ASSERT_EQ(mixin.shape()->dim(1), 2); + ASSERT_EQ(mixin.shape()->dim(2), 3); + ASSERT_EQ(mixin.shape()->dim(3), 4); +} + +TEST(GraphTest, create_and_destroy_node) +{ + auto g = loco::make_graph(); + + auto pull = g->nodes()->create<loco::Pull>(); + + ASSERT_NO_THROW(g->nodes()->destroy(pull)); + ASSERT_THROW(g->nodes()->destroy(pull), std::invalid_argument); +} + +TEST(GraphTest, create_input) +{ + auto g = loco::make_graph(); + + auto input = g->inputs()->create(); + + // TODO Add more checks + ASSERT_EQ(input->shape(), nullptr); + ASSERT_EQ(input->index(), 0); +} + +TEST(GraphTest, create_output) +{ + auto g = loco::make_graph(); + + auto output = g->outputs()->create(); + + // TODO Add more checks + ASSERT_EQ(output->shape(), nullptr); + ASSERT_EQ(output->index(), 0); +} + +namespace +{ +// temp node with multple params for ctor. loco::CanonicalOpcode::ReLU is used for simplicity +class ParamCtorNode + : public loco::CanonicalNodeDef<loco::CanonicalOpcode::ReLU, loco::FixedArity<0>::Mixin> +{ +public: + ParamCtorNode(int i, float f) + { + _i = i; + _f = f; + } + + int i() { return _i; } + float f() { return _f; } + +private: + int _i; + float _f; +}; +} // namespace + +TEST(GraphTest, consturctor_with_param_node) +{ + auto g = loco::make_graph(); + + auto test_node = g->nodes()->create<ParamCtorNode>(22, 11.11); + + ASSERT_EQ(test_node->graph(), g.get()); + ASSERT_EQ(const_cast<const ParamCtorNode *>(test_node)->graph(), g.get()); + + ASSERT_EQ(test_node->i(), 22); + ASSERT_FLOAT_EQ(test_node->f(), 11.11); + + ASSERT_NO_THROW(g->nodes()->destroy(test_node)); + ASSERT_THROW(g->nodes()->destroy(test_node), std::invalid_argument); +} + +TEST(GraphTest, getters_over_const_instance) +{ + auto g = loco::make_graph(); + + auto pull = g->nodes()->create<loco::Pull>(); + auto push = g->nodes()->create<loco::Push>(); + + loco::link(g->inputs()->create(), pull); + loco::link(g->outputs()->create(), push); + + auto ptr = const_cast<const loco::Graph *>(g.get()); + + EXPECT_EQ(ptr->nodes()->size(), 2); + EXPECT_EQ(ptr->inputs()->size(), 1); +} + +TEST(GraphTest, graph_node_enumeration) +{ + auto g = loco::make_graph(); + + auto pull_1 = g->nodes()->create<loco::Pull>(); + auto push_1 = g->nodes()->create<loco::Push>(); + + auto nodes = loco::all_nodes(g.get()); + + // Returns true if "nodes" includes a given node + auto member = [&nodes](loco::Node *node) { return nodes.find(node) != nodes.end(); }; + + ASSERT_EQ(nodes.size(), 2); + ASSERT_TRUE(member(pull_1)); + ASSERT_TRUE(member(push_1)); +} + +TEST(GraphTest, graph_inout_enumeration) +{ + auto g = loco::make_graph(); + + std::vector<loco::Pull *> pull_nodes; + + auto pull_1 = g->nodes()->create<loco::Pull>(); + auto pull_2 = g->nodes()->create<loco::Pull>(); + auto pull_3 = g->nodes()->create<loco::Pull>(); + + auto push_1 = g->nodes()->create<loco::Push>(); + auto push_2 = g->nodes()->create<loco::Push>(); + auto push_3 = g->nodes()->create<loco::Push>(); + + loco::link(g->inputs()->create(), pull_2); + loco::link(g->inputs()->create(), pull_1); + + loco::link(g->outputs()->create(), push_1); + loco::link(g->outputs()->create(), push_3); + + auto output_nodes = loco::output_nodes(g.get()); + + ASSERT_EQ(output_nodes.size(), 2); + ASSERT_EQ(output_nodes.at(0), push_1); + ASSERT_EQ(output_nodes.at(1), push_3); +} + +TEST(GraphTest, graph_name) +{ + auto g = loco::make_graph(); + + g->name("HelloGraph"); + ASSERT_TRUE(g->name() == "HelloGraph"); +} + +TEST(GraphTest, graph_name_nullptr_NEG) +{ + auto g = loco::make_graph(); + + EXPECT_ANY_THROW(g->name(nullptr)); +} diff --git a/compiler/loco/src/IR/GraphInputIndex.cpp b/compiler/loco/src/IR/GraphInputIndex.cpp new file mode 100644 index 000000000..0c94d704c --- /dev/null +++ b/compiler/loco/src/IR/GraphInputIndex.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/GraphInputIndex.h" + +// NOTE This file validates "GraphInputIndex.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/GraphOutputIndex.cpp b/compiler/loco/src/IR/GraphOutputIndex.cpp new file mode 100644 index 000000000..e6fdb9f94 --- /dev/null +++ b/compiler/loco/src/IR/GraphOutputIndex.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/GraphOutputIndex.h" + +// NOTE This file validates "GraphOutputIndex.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/MatrixAxis.cpp b/compiler/loco/src/IR/MatrixAxis.cpp new file mode 100644 index 000000000..d0773f758 --- /dev/null +++ b/compiler/loco/src/IR/MatrixAxis.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/MatrixAxis.h" + +// NOTE This file validates "MatrixAxis.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/MatrixCodec.cpp b/compiler/loco/src/IR/MatrixCodec.cpp new file mode 100644 index 000000000..87ae42610 --- /dev/null +++ b/compiler/loco/src/IR/MatrixCodec.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/MatrixCodec.h" + +// NOTE This file validates "MatrixCodec.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/MockupNode.h b/compiler/loco/src/IR/MockupNode.h new file mode 100644 index 000000000..ec56c90e2 --- /dev/null +++ b/compiler/loco/src/IR/MockupNode.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LOCO_IR_MOCKUP_NODE_H__ +#define __LOCO_IR_MOCKUP_NODE_H__ + +#include "loco/IR/Use.h" +#include "loco/IR/Node.h" + +namespace +{ + +struct MockDialect final : public loco::Dialect +{ + static loco::Dialect *get(void) + { + static MockDialect d; + return &d; + } +}; + +// @brief Mockup node for internal testing +class MockupNode final : public loco::Node +{ +public: + MockupNode() = default; + +public: + const loco::Dialect *dialect(void) const final { return MockDialect::get(); } + uint32_t opnum(void) const final { return 0; } + + uint32_t arity(void) const final { return 1; } + Node *arg(uint32_t N) const final { return _arg.node(); } + void drop(void) final { _arg.node(nullptr); } + + Node *in(void)const { return _arg.node(); } + void in(Node *node) { _arg.node(node); } + +private: + loco::Use _arg{this}; +}; + +} // namespace + +#endif // __LOCO_IR_MOCKUP_NODE_H__ diff --git a/compiler/loco/src/IR/Node.cpp b/compiler/loco/src/IR/Node.cpp new file mode 100644 index 000000000..90ec5c997 --- /dev/null +++ b/compiler/loco/src/IR/Node.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Node.h" +#include "loco/IR/Use.h" + +#include <cassert> + +namespace loco +{ + +Node::~Node() +{ + // To detect dangling references + assert(_uses.size() == 0); +} + +std::set<Node *> preds(const Node *node) +{ + std::set<Node *> res; + + for (uint32_t n = 0; n < node->arity(); ++n) + { + if (auto pred = node->arg(n)) + { + res.insert(pred); + } + } + + return res; +} + +std::set<Node *> succs(const Node *node) +{ + std::set<Node *> res; + + for (auto use : node->_uses) + { + auto user = use->user(); + assert(user != nullptr); + res.insert(user); + } + + return res; +} + +Subst<SubstQualifier::Default>::Subst(Node *from) : _from{from} +{ + // _from SHOULD be valid + assert(_from != nullptr); +} + +void Subst<SubstQualifier::Default>::with(Node *into) const +{ + if (_from == into) + { + return; + } + + auto *uses = &(_from->_uses); + + while (!uses->empty()) + { + auto use = *(uses->begin()); + use->node(into); + } +} + +Subst<SubstQualifier::Default> replace(Node *node) +{ + // Let's create Subst<SubstQualifier::Default>! + return Subst<SubstQualifier::Default>{node}; +} + +} // namespace loco diff --git a/compiler/loco/src/IR/Node.test.cpp b/compiler/loco/src/IR/Node.test.cpp new file mode 100644 index 000000000..00e444465 --- /dev/null +++ b/compiler/loco/src/IR/Node.test.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Node.h" + +#include "MockupNode.h" + +#include <gtest/gtest.h> + +TEST(NodeTest, preds) +{ + ::MockupNode arg; + ::MockupNode node; + + node.in(&arg); + + auto preds = loco::preds(&node); + + ASSERT_EQ(preds.size(), 1); + ASSERT_NE(preds.find(&arg), preds.end()); +} + +TEST(NodeTest, succs) +{ + ::MockupNode node; + ::MockupNode succ_1; + ::MockupNode succ_2; + + succ_1.in(&node); + succ_2.in(&node); + + auto succs = loco::succs(&node); + + ASSERT_EQ(succs.size(), 2); + ASSERT_NE(succs.find(&succ_1), succs.end()); + ASSERT_NE(succs.find(&succ_2), succs.end()); +} + +TEST(NodeTest, replace_with) +{ + ::MockupNode node_1; + ::MockupNode node_2; + + ::MockupNode node_3; + ::MockupNode node_4; + + node_3.in(&node_1); + node_4.in(&node_2); + + // The following holds at this point + // - node_3 USE node_1 + // - node_4 USE node_2 + ASSERT_EQ(node_3.in(), &node_1); + ASSERT_EQ(node_4.in(), &node_2); + + // Replace all the usage of node_1 with node_2 + replace(&node_1).with(&node_2); + + // The following holds at this point + // - node_3 USE node_2 + // - node_4 USE node_2 + ASSERT_EQ(node_3.in(), &node_2); + ASSERT_EQ(node_4.in(), &node_2); +} + +TEST(NodeTest, constructor) +{ + MockupNode node; + + // graph() SHOULD return nullptr if node is not constructed through "Graph" + ASSERT_EQ(node.graph(), nullptr); +} + +// TODO Rewrite this as a FixedAritry mix-in test +#if 0 +TEST(FixedArityNodeTest, constructor) +{ + struct DerivedNode final : public loco::FixedArityNode<1, loco::Node> + { + loco::Dialect *dialect(void) const final { return MockDialect::get(); } + uint32_t opnum(void) const final { return 0; } + }; + + DerivedNode node; + + ASSERT_EQ(node.arity(), 1); + ASSERT_EQ(node.arg(0), nullptr); +} +#endif diff --git a/compiler/loco/src/IR/NodeMixins.cpp b/compiler/loco/src/IR/NodeMixins.cpp new file mode 100644 index 000000000..66037b17a --- /dev/null +++ b/compiler/loco/src/IR/NodeMixins.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/NodeMixins.h" + +// NOTE This file validates "NodeMixins.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/NodePool.cpp b/compiler/loco/src/IR/NodePool.cpp new file mode 100644 index 000000000..553f15eb5 --- /dev/null +++ b/compiler/loco/src/IR/NodePool.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/NodePool.h" + +namespace loco +{ + +NodePool::~NodePool() +{ + // Drop all the references before deallocation + for (uint32_t n = 0; n < size(); ++n) + { + at(n)->drop(); + } +} + +} // namespace loco diff --git a/compiler/loco/src/IR/NodeShape.cpp b/compiler/loco/src/IR/NodeShape.cpp new file mode 100644 index 000000000..0130cfbdb --- /dev/null +++ b/compiler/loco/src/IR/NodeShape.cpp @@ -0,0 +1,284 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/NodeShape.h" + +#include <cassert> +#include <stdexcept> + +// +// BiasShape Support +// +namespace loco +{ + +void NodeShape::set(const BiasShape &shape) +{ + _domain = Domain::Bias; + + _dims.resize(1); + _dims.at(0) = shape.length(); +} + +template <> BiasShape NodeShape::as(void) const +{ + assert(_domain == Domain::Bias); + + BiasShape res; + + res.length() = _dims.at(0); + + return res; +} + +} // namespace loco + +// +// DepthwiseFilterShape Support +// +namespace loco +{ + +void NodeShape::set(const DepthwiseFilterShape &shape) +{ + _domain = Domain::DepthwiseFilter; + + _dims.resize(4); + _dims.at(0) = shape.multiplier(); + _dims.at(1) = shape.depth(); + _dims.at(2) = shape.height(); + _dims.at(3) = shape.width(); +} + +template <> DepthwiseFilterShape NodeShape::as(void) const +{ + assert(_domain == Domain::DepthwiseFilter); + + DepthwiseFilterShape res; + + res.multiplier() = _dims.at(0); + res.depth() = _dims.at(1); + res.height() = _dims.at(2); + res.width() = _dims.at(3); + + return res; +} + +} // namespace loco + +// +// FeatureShape Support +// +namespace loco +{ + +void NodeShape::set(const FeatureShape &shape) +{ + _domain = Domain::Feature; + + _dims.resize(4); + _dims.at(0) = shape.count(); + _dims.at(1) = shape.depth(); + _dims.at(2) = shape.height(); + _dims.at(3) = shape.width(); +} + +template <> FeatureShape NodeShape::as(void) const +{ + assert(_domain == Domain::Feature); + + FeatureShape res; + + res.count() = _dims.at(0); + res.depth() = _dims.at(1); + res.height() = _dims.at(2); + res.width() = _dims.at(3); + + return res; +} + +} // namespace loco + +// +// FilterShape Support +// +namespace loco +{ + +void NodeShape::set(const FilterShape &shape) +{ + _domain = Domain::Filter; + + _dims.resize(4); + _dims.at(0) = shape.count(); + _dims.at(1) = shape.depth(); + _dims.at(2) = shape.height(); + _dims.at(3) = shape.width(); +} + +template <> FilterShape NodeShape::as(void) const +{ + assert(_domain == Domain::Filter); + + FilterShape res; + + res.count() = _dims.at(0); + res.depth() = _dims.at(1); + res.height() = _dims.at(2); + res.width() = _dims.at(3); + + return res; +} + +} // namespace loco + +// +// MatrixShape Support +// +namespace loco +{ + +void NodeShape::set(const MatrixShape &shape) +{ + _domain = Domain::Matrix; + + _dims.resize(2); + _dims.at(0) = shape.height(); + _dims.at(1) = shape.width(); +} + +template <> MatrixShape NodeShape::as(void) const +{ + assert(_domain == Domain::Matrix); + + MatrixShape res; + + res.height() = _dims.at(0); + res.width() = _dims.at(1); + + return res; +} + +} // namespace loco + +// +// TensorShape Support +// +namespace loco +{ + +void NodeShape::set(const TensorShape &shape) +{ + _domain = Domain::Tensor; + + _dims.resize(shape.rank()); + for (uint32_t axis = 0; axis < shape.rank(); ++axis) + { + _dims.at(axis) = shape.dim(axis); + } +} + +template <> TensorShape NodeShape::as(void) const +{ + assert(_domain == Domain::Tensor); + + TensorShape res; + + res.rank(_dims.size()); + for (uint32_t axis = 0; axis < _dims.size(); ++axis) + { + res.dim(axis) = _dims.at(axis); + } + + return res; +} + +} // namespace loco + +namespace loco +{ + +bool operator==(const NodeShape &lhs, const NodeShape &rhs) +{ + if (lhs.domain() != rhs.domain()) + return false; + + switch (lhs.domain()) + { + case loco::Domain::Tensor: + { + auto lhs_t = lhs.as<TensorShape>(); + auto rhs_t = rhs.as<TensorShape>(); + if (lhs_t.rank() != rhs_t.rank()) + return false; + for (uint32_t axis = 0; axis < lhs_t.rank(); ++axis) + { + if (!(lhs_t.dim(axis) == rhs_t.dim(axis))) + return false; + } + return true; + } + + case loco::Domain::Feature: + { + auto lhs_f = lhs.as<FeatureShape>(); + auto rhs_f = rhs.as<FeatureShape>(); + + return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() && + lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + case loco::Domain::Filter: + { + auto lhs_f = lhs.as<FilterShape>(); + auto rhs_f = rhs.as<FilterShape>(); + + return (lhs_f.count() == rhs_f.count() && lhs_f.depth() == rhs_f.depth() && + lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + case loco::Domain::DepthwiseFilter: + { + auto lhs_f = lhs.as<DepthwiseFilterShape>(); + auto rhs_f = rhs.as<DepthwiseFilterShape>(); + + return (lhs_f.multiplier() == rhs_f.multiplier() && lhs_f.depth() == rhs_f.depth() && + lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + case loco::Domain::Bias: + { + auto lhs_f = lhs.as<BiasShape>(); + auto rhs_f = rhs.as<BiasShape>(); + + return (lhs_f.length() == rhs_f.length()); + } + + case loco::Domain::Matrix: + { + auto lhs_f = lhs.as<MatrixShape>(); + auto rhs_f = rhs.as<MatrixShape>(); + + return (lhs_f.height() == rhs_f.height() && lhs_f.width() == rhs_f.width()); + } + + default: + throw std::runtime_error("Not supported domain for NodeShape equality"); + } + return false; +} + +} // namespace loco diff --git a/compiler/loco/src/IR/NodeShape.test.cpp b/compiler/loco/src/IR/NodeShape.test.cpp new file mode 100644 index 000000000..4f092e024 --- /dev/null +++ b/compiler/loco/src/IR/NodeShape.test.cpp @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/NodeShape.h" + +#include <gtest/gtest.h> + +TEST(NodeShapeTest, default_constructor) +{ + loco::NodeShape node_shape; + + ASSERT_EQ(node_shape.domain(), loco::Domain::Unknown); +} + +TEST(NodeShapeTest, bias_shape_constructor) +{ + loco::BiasShape bias_shape; + + bias_shape.length() = 4; + + loco::NodeShape node_shape{bias_shape}; + + ASSERT_EQ(node_shape.domain(), loco::Domain::Bias); + ASSERT_EQ(node_shape.as<loco::BiasShape>().length(), 4); +} + +TEST(NodeShapeTest, dwfilter_shape_constructor) +{ + loco::DepthwiseFilterShape dwfilter_shape; + + dwfilter_shape.depth() = 2; + dwfilter_shape.multiplier() = 3; + dwfilter_shape.height() = 4; + dwfilter_shape.width() = 5; + + loco::NodeShape node_shape{dwfilter_shape}; + + ASSERT_EQ(node_shape.domain(), loco::Domain::DepthwiseFilter); + ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().depth(), 2); + ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().multiplier(), 3); + ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().height(), 4); + ASSERT_EQ(node_shape.as<loco::DepthwiseFilterShape>().width(), 5); +} + +TEST(NodeShapeTest, feature_shape_constructor) +{ + loco::FeatureShape feature_shape; + + feature_shape.count() = 2; + feature_shape.depth() = 3; + feature_shape.height() = 4; + feature_shape.width() = 5; + + loco::NodeShape node_shape{feature_shape}; + + ASSERT_EQ(node_shape.domain(), loco::Domain::Feature); + ASSERT_EQ(node_shape.as<loco::FeatureShape>().count(), 2); + ASSERT_EQ(node_shape.as<loco::FeatureShape>().depth(), 3); + ASSERT_EQ(node_shape.as<loco::FeatureShape>().height(), 4); + ASSERT_EQ(node_shape.as<loco::FeatureShape>().width(), 5); +} + +TEST(NodeShapeTest, filter_shape_constructor) +{ + loco::FilterShape filter_shape; + + filter_shape.count() = 2; + filter_shape.depth() = 3; + filter_shape.height() = 4; + filter_shape.width() = 5; + + loco::NodeShape node_shape{filter_shape}; + + ASSERT_EQ(node_shape.domain(), loco::Domain::Filter); + ASSERT_EQ(node_shape.as<loco::FilterShape>().count(), 2); + ASSERT_EQ(node_shape.as<loco::FilterShape>().depth(), 3); + ASSERT_EQ(node_shape.as<loco::FilterShape>().height(), 4); + ASSERT_EQ(node_shape.as<loco::FilterShape>().width(), 5); +} + +TEST(NodeShapeTest, tensor_shape_constructor) +{ + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + tensor_shape.dim(0) = 4; + tensor_shape.dim(1) = 5; + + loco::NodeShape node_shape{tensor_shape}; + + ASSERT_EQ(node_shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(node_shape.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(node_shape.as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(node_shape.as<loco::TensorShape>().dim(1), 5); +} + +TEST(NodeShapeTest, copy_constructible) +{ + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + tensor_shape.dim(0) = 4; + tensor_shape.dim(1) = 5; + + loco::NodeShape orig{tensor_shape}; + loco::NodeShape copy{orig}; // Call Copy Constructor + + ASSERT_EQ(copy.domain(), loco::Domain::Tensor); + ASSERT_EQ(copy.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(copy.as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(copy.as<loco::TensorShape>().dim(1), 5); +} diff --git a/compiler/loco/src/IR/Nodes.cpp b/compiler/loco/src/IR/Nodes.cpp new file mode 100644 index 000000000..133b69430 --- /dev/null +++ b/compiler/loco/src/IR/Nodes.cpp @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Nodes.h" +#include "loco/IR/Graph.h" + +#include <cassert> +#include <limits> + +// This file validates "Nodes.h". Please DO NOT remove this file. +namespace +{ + +/** + * @note This function is currently only used in assert. Compiler will + * warn/error this function as unused in Release build. + * Making inline will make compiler happy. + */ +// Is it possible to update lhs as rhs? +inline bool dtype_assignable(loco::DataType lhs, loco::DataType rhs) +{ + if (lhs == loco::DataType::Unknown) + { + return true; + } + + // lhs is already known, and thus rhs should be matched + return lhs == rhs; +} + +} // namespace + +/** + * Push + */ +namespace loco +{ + +void Push::index(const GraphOutputIndex &index) +{ + // Push internally stores "GraphOutputIndex" as int64_t + _index = static_cast<int64_t>(index); +} + +GraphOutputIndex Push::index(void) const +{ + assert(_index >= std::numeric_limits<GraphOutputIndex>::min()); + assert(_index <= std::numeric_limits<GraphOutputIndex>::max()); + return static_cast<GraphOutputIndex>(_index); +} + +void link(GraphOutput *output, Push *push) { push->index(output->index()); } + +Push *push_node(Graph *g, const GraphOutputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto push = dynamic_cast<Push *>(g->nodes()->at(n))) + { + if (push->indexed() && push->index() == index) + { + return push; + } + } + } + return nullptr; +} + +} // namespace loco + +/** + * Pull + */ +namespace loco +{ + +void Pull::index(const GraphInputIndex &index) +{ + // ASSUMPTION + // + // It is possible to update index multiple times, but only with the same value! + assert(!indexed() or _index == index); + + if (indexed()) + { + assert(_index == index); + return; + } + + // Push internally stores "GraphInputIndex" as int64_t + _index = static_cast<int64_t>(index); + + // ASSUMPTION: The return value of graph() never changes! + if (graph() != nullptr && _dtype != loco::DataType::Unknown) + { + // Update Graph-level input only if it is not yet specified + if (graph()->inputs()->at(_index)->dtype() == DataType::Unknown) + { + graph()->inputs()->at(_index)->dtype(_dtype); + } + assert(graph()->inputs()->at(_index)->dtype() == _dtype); + graph()->inputs()->at(_index)->dtype(_dtype); + + // Reset the locally cached data + _dtype = DataType::Unknown; + } +} + +GraphInputIndex Pull::index(void) const +{ + assert(_index >= std::numeric_limits<GraphInputIndex>::min()); + assert(_index <= std::numeric_limits<GraphInputIndex>::max()); + return static_cast<GraphInputIndex>(_index); +} + +void Pull::dtype(const DataType &dt) +{ + // ASSUMPTION: "dtype" is never invalidated! + assert(dt != loco::DataType::Unknown); + // ASSUMPTION + // + // It is possible to update index multiple times, but only with the same value! + if (indexed()) + { + assert(dtype_assignable(graph()->inputs()->at(_index)->dtype(), dt)); + graph()->inputs()->at(_index)->dtype(dt); + return; + } + + // Use local cache + _dtype = dt; +} + +DataType Pull::dtype(void) const +{ + if (graph() != nullptr and _index >= 0) + { + assert(_dtype == DataType::Unknown); + return graph()->inputs()->at(_index)->dtype(); + } + else + { + return _dtype; + } +} + +void link(GraphInput *input, Pull *pull) { pull->index(input->index()); } + +Pull *pull_node(Graph *g, const GraphInputIndex &index) +{ + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + if (auto pull = dynamic_cast<Pull *>(g->nodes()->at(n))) + { + if (pull->indexed() && pull->index() == index) + { + return pull; + } + } + } + return nullptr; +} + +} // namespace loco + +/** + * ConstGen + */ +namespace loco +{ + +template <DataType DT> uint32_t ConstGen::size(void) const +{ + assert(dtype() == DT); + assert(_data.size() % sizeof(typename DataTypeImpl<DT>::Type) == 0); + return _data.size() / sizeof(typename DataTypeImpl<DT>::Type); +} + +template <DataType DT> void ConstGen::size(uint32_t l) +{ + assert(dtype() == DT); + _data.resize(l * sizeof(typename DataTypeImpl<DT>::Type)); +} + +template <DataType DT> const typename DataTypeImpl<DT>::Type &ConstGen::at(uint32_t n) const +{ + assert(dtype() == DT); + assert(n < size<DT>()); + return *(reinterpret_cast<const typename DataTypeImpl<DT>::Type *>(_data.data()) + n); +} + +template <DataType DT> typename DataTypeImpl<DT>::Type &ConstGen::at(uint32_t n) +{ + assert(dtype() == DT); + assert(n < size<DT>()); + return *(reinterpret_cast<typename DataTypeImpl<DT>::Type *>(_data.data()) + n); +} + +#define INSTANTIATE(DT) \ + template uint32_t ConstGen::size<DT>(void) const; \ + template void ConstGen::size<DT>(uint32_t); \ + template const typename DataTypeImpl<DT>::Type &ConstGen::at<DT>(uint32_t) const; \ + template typename DataTypeImpl<DT>::Type &ConstGen::at<DT>(uint32_t); + +INSTANTIATE(DataType::S32); +INSTANTIATE(DataType::FLOAT32); + +#undef INSTANTIATE + +} // namespace loco + +/** + * TensorBroadcast + */ +namespace loco +{ + +bool TensorBroadcast::Mapping::defined(const TensorAxis &axis) const +{ + return _content.find(axis) != _content.end(); +} + +const Dimension &TensorBroadcast::Mapping::dim(const TensorAxis &axis) const +{ + return _content.at(axis); +} + +Dimension &TensorBroadcast::Mapping::dim(const TensorAxis &axis) { return _content[axis]; } + +} // namespace loco diff --git a/compiler/loco/src/IR/Nodes.test.cpp b/compiler/loco/src/IR/Nodes.test.cpp new file mode 100644 index 000000000..cd51f46c0 --- /dev/null +++ b/compiler/loco/src/IR/Nodes.test.cpp @@ -0,0 +1,588 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Nodes.h" +#include "loco/IR/CanonicalDialect.h" + +#include <gtest/gtest.h> + +TEST(PushTest, constructor) +{ + loco::Push push_node; + + ASSERT_EQ(push_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(push_node.opcode(), loco::CanonicalOpcode::Push); + + ASSERT_FALSE(push_node.indexed()); +} + +TEST(PushTest, shape) +{ + const std::vector<uint32_t> dims{1, 8, 16, 3}; + + loco::Pull push_node; + + push_node.shape({dims[0], dims[1], dims[2], dims[3]}); + + ASSERT_EQ(push_node.rank(), dims.size()); + ASSERT_EQ(push_node.dim(0), dims[0]); + ASSERT_EQ(push_node.dim(1), dims[1]); + ASSERT_EQ(push_node.dim(2), dims[2]); + ASSERT_EQ(push_node.dim(3), dims[3]); +} + +TEST(PullTest, constructor) +{ + loco::Pull pull_node; + + ASSERT_EQ(pull_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(pull_node.opcode(), loco::CanonicalOpcode::Pull); + + ASSERT_FALSE(pull_node.indexed()); + + ASSERT_EQ(pull_node.dtype(), loco::DataType::Unknown); + ASSERT_EQ(pull_node.rank(), 0); +} + +TEST(PullTest, shape) +{ + const std::vector<uint32_t> dims{1, 8, 16, 3}; + + loco::Pull pull_node; + + pull_node.shape({dims[0], dims[1], dims[2], dims[3]}); + + ASSERT_EQ(pull_node.rank(), dims.size()); + ASSERT_EQ(pull_node.dim(0), dims[0]); + ASSERT_EQ(pull_node.dim(1), dims[1]); + ASSERT_EQ(pull_node.dim(2), dims[2]); + ASSERT_EQ(pull_node.dim(3), dims[3]); +} + +TEST(ForwardTest, constructor) +{ + loco::Forward forward_node; + + ASSERT_EQ(forward_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(forward_node.opcode(), loco::CanonicalOpcode::Forward); + + ASSERT_EQ(forward_node.input(), nullptr); +} + +TEST(ReLUTest, constructor) +{ + loco::ReLU relu_node; + + ASSERT_EQ(relu_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(relu_node.opcode(), loco::CanonicalOpcode::ReLU); + + ASSERT_EQ(relu_node.input(), nullptr); +} + +TEST(ReLU6Test, constructor) +{ + loco::ReLU6 relu6_node; + + ASSERT_EQ(relu6_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(relu6_node.opcode(), loco::CanonicalOpcode::ReLU6); + + ASSERT_EQ(relu6_node.input(), nullptr); +} + +TEST(ConstGenTest, constructor) +{ + loco::ConstGen constgen_node; + + ASSERT_EQ(constgen_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(constgen_node.opcode(), loco::CanonicalOpcode::ConstGen); + + ASSERT_EQ(constgen_node.dtype(), loco::DataType::Unknown); + ASSERT_EQ(constgen_node.rank(), 0); + + constgen_node.dtype(loco::DataType::FLOAT32); + ASSERT_EQ(constgen_node.dtype(), loco::DataType::FLOAT32); + + constgen_node.rank(2); + ASSERT_EQ(constgen_node.rank(), 2); + + constgen_node.dim(0) = 2; + constgen_node.dim(1) = 3; + + ASSERT_TRUE(constgen_node.dim(0).known()); + ASSERT_TRUE(constgen_node.dim(1).known()); + + ASSERT_EQ(constgen_node.dim(0), 2); + ASSERT_EQ(constgen_node.dim(1), 3); + + constgen_node.size<loco::DataType::FLOAT32>(6); + + ASSERT_EQ(constgen_node.size<loco::DataType::FLOAT32>(), 6); + + constgen_node.at<loco::DataType::FLOAT32>(0) = 0.0f; // Set 0,0 + constgen_node.at<loco::DataType::FLOAT32>(1) = 1.0f; // Set 0,1 + constgen_node.at<loco::DataType::FLOAT32>(2) = 2.0f; // Set 0,2 + constgen_node.at<loco::DataType::FLOAT32>(3) = 3.0f; // Set 1,0 + constgen_node.at<loco::DataType::FLOAT32>(4) = 4.0f; // Set 1,1 + constgen_node.at<loco::DataType::FLOAT32>(5) = 5.0f; // Set 1,2 + + ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(0), 0.0f); + ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(1), 1.0f); + ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(2), 2.0f); + ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(3), 3.0f); + ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(4), 4.0f); + ASSERT_EQ(constgen_node.at<loco::DataType::FLOAT32>(5), 5.0f); +} + +TEST(ConstGenTest, constructor_s32) +{ + loco::ConstGen constgen_node; + + ASSERT_EQ(constgen_node.dtype(), loco::DataType::Unknown); + ASSERT_EQ(constgen_node.rank(), 0); + + constgen_node.dtype(loco::DataType::S32); + ASSERT_EQ(constgen_node.dtype(), loco::DataType::S32); + + constgen_node.rank(2); + ASSERT_EQ(constgen_node.rank(), 2); + + constgen_node.dim(0) = 2; + constgen_node.dim(1) = 3; + + ASSERT_TRUE(constgen_node.dim(0).known()); + ASSERT_TRUE(constgen_node.dim(1).known()); + + ASSERT_EQ(constgen_node.dim(0), 2); + ASSERT_EQ(constgen_node.dim(1), 3); + + constgen_node.size<loco::DataType::S32>(6); + + ASSERT_EQ(constgen_node.size<loco::DataType::S32>(), 6); + + constgen_node.at<loco::DataType::S32>(0) = 0; // Set 0,0 + constgen_node.at<loco::DataType::S32>(1) = 1; // Set 0,1 + constgen_node.at<loco::DataType::S32>(2) = 2; // Set 0,2 + constgen_node.at<loco::DataType::S32>(3) = -3; // Set 1,0 + constgen_node.at<loco::DataType::S32>(4) = -4; // Set 1,1 + constgen_node.at<loco::DataType::S32>(5) = -5; // Set 1,2 + + ASSERT_EQ(constgen_node.at<loco::DataType::S32>(0), 0); + ASSERT_EQ(constgen_node.at<loco::DataType::S32>(1), 1); + ASSERT_EQ(constgen_node.at<loco::DataType::S32>(2), 2); + ASSERT_EQ(constgen_node.at<loco::DataType::S32>(3), -3); + ASSERT_EQ(constgen_node.at<loco::DataType::S32>(4), -4); + ASSERT_EQ(constgen_node.at<loco::DataType::S32>(5), -5); +} + +TEST(MaxPool2DTest, constructor) +{ + loco::MaxPool2D maxpool_node; + + ASSERT_EQ(maxpool_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(maxpool_node.opcode(), loco::CanonicalOpcode::MaxPool2D); + + ASSERT_EQ(maxpool_node.ifm(), nullptr); + + ASSERT_EQ(maxpool_node.pad()->top(), 0); + ASSERT_EQ(maxpool_node.pad()->bottom(), 0); + ASSERT_EQ(maxpool_node.pad()->left(), 0); + ASSERT_EQ(maxpool_node.pad()->right(), 0); + + ASSERT_EQ(maxpool_node.window()->vertical(), 1); + ASSERT_EQ(maxpool_node.window()->horizontal(), 1); + + ASSERT_EQ(maxpool_node.stride()->vertical(), 1); + ASSERT_EQ(maxpool_node.stride()->horizontal(), 1); +} + +TEST(MaxPool2DTest, pad) +{ + const uint32_t t = 1; + const uint32_t b = 2; + const uint32_t l = 3; + const uint32_t r = 4; + + loco::MaxPool2D maxpool_node; + + maxpool_node.pad()->top(t); + ASSERT_EQ(maxpool_node.pad()->top(), t); + + maxpool_node.pad()->bottom(b); + ASSERT_EQ(maxpool_node.pad()->bottom(), b); + + maxpool_node.pad()->left(l); + ASSERT_EQ(maxpool_node.pad()->left(), l); + + maxpool_node.pad()->right(r); + ASSERT_EQ(maxpool_node.pad()->right(), r); +} + +TEST(AvgPool2DTest, constructor) +{ + loco::AvgPool2D avgpool_node; + + ASSERT_EQ(avgpool_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(avgpool_node.opcode(), loco::CanonicalOpcode::AvgPool2D); + + ASSERT_EQ(avgpool_node.ifm(), nullptr); + + ASSERT_EQ(avgpool_node.convention(), loco::AvgPool2D::Convention::Unknown); + + ASSERT_EQ(avgpool_node.pad()->top(), 0); + ASSERT_EQ(avgpool_node.pad()->bottom(), 0); + ASSERT_EQ(avgpool_node.pad()->left(), 0); + ASSERT_EQ(avgpool_node.pad()->right(), 0); + + ASSERT_EQ(avgpool_node.window()->vertical(), 1); + ASSERT_EQ(avgpool_node.window()->horizontal(), 1); + + ASSERT_EQ(avgpool_node.stride()->vertical(), 1); + ASSERT_EQ(avgpool_node.stride()->horizontal(), 1); +} + +TEST(FeatureEncodeTest, constructor) +{ + loco::FeatureEncode feature_encode; + + ASSERT_EQ(feature_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(feature_encode.opcode(), loco::CanonicalOpcode::FeatureEncode); + + ASSERT_EQ(feature_encode.input(), nullptr); + ASSERT_EQ(feature_encode.encoder(), nullptr); +} + +TEST(FeatureDecodeTest, constructor) +{ + loco::FeatureDecode feature_decode; + + ASSERT_EQ(feature_decode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(feature_decode.opcode(), loco::CanonicalOpcode::FeatureDecode); + + ASSERT_EQ(feature_decode.input(), nullptr); + ASSERT_EQ(feature_decode.decoder(), nullptr); +} + +TEST(Reshape_Fixed_Test, constructor) +{ + loco::Reshape<loco::ReshapeType::Fixed> reshape; + + ASSERT_EQ(reshape.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(reshape.opcode(), loco::CanonicalOpcode::FixedReshape); + + ASSERT_EQ(reshape.rank(), 0); +} + +TEST(Reshape_Fixed_Test, shape) +{ + loco::Reshape<loco::ReshapeType::Fixed> reshape; + reshape.shape({2, 3}); + + ASSERT_EQ(reshape.rank(), 2); + ASSERT_EQ(reshape.dim(0), 2); + ASSERT_EQ(reshape.dim(1), 3); +} + +TEST(FilterEncodeTest, constructor) +{ + loco::FilterEncode filter_encode; + + ASSERT_EQ(filter_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(filter_encode.opcode(), loco::CanonicalOpcode::FilterEncode); + + ASSERT_EQ(filter_encode.input(), nullptr); + ASSERT_EQ(filter_encode.encoder(), nullptr); +} + +TEST(FilterDecodeTest, constructor) +{ + loco::FilterDecode filter_decode; + + ASSERT_EQ(filter_decode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(filter_decode.opcode(), loco::CanonicalOpcode::FilterDecode); + + ASSERT_EQ(filter_decode.input(), nullptr); + ASSERT_EQ(filter_decode.decoder(), nullptr); +} + +TEST(DepthwiseFilterEncodeTest, constructor) +{ + loco::DepthwiseFilterEncode dw_filter_encode; + + ASSERT_EQ(dw_filter_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(dw_filter_encode.opcode(), loco::CanonicalOpcode::DepthwiseFilterEncode); + + ASSERT_EQ(dw_filter_encode.input(), nullptr); + ASSERT_EQ(dw_filter_encode.encoder(), nullptr); +} + +TEST(DepthwiseFilterDecodeTest, constructor) +{ + loco::DepthwiseFilterDecode dw_filter_decode; + + ASSERT_EQ(dw_filter_decode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(dw_filter_decode.opcode(), loco::CanonicalOpcode::DepthwiseFilterDecode); + + ASSERT_EQ(dw_filter_decode.input(), nullptr); + ASSERT_EQ(dw_filter_decode.decoder(), nullptr); +} + +TEST(TensorConcatTest, constructor) +{ + loco::TensorConcat tensor_concat; + + ASSERT_EQ(tensor_concat.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(tensor_concat.opcode(), loco::CanonicalOpcode::TensorConcat); + + ASSERT_EQ(tensor_concat.lhs(), nullptr); + ASSERT_EQ(tensor_concat.rhs(), nullptr); + ASSERT_EQ(tensor_concat.axis(), 0); + + tensor_concat.axis(3); + ASSERT_EQ(tensor_concat.axis(), 3); +} + +TEST(Conv2DTest, constructor) +{ + loco::Conv2D conv2d; + + ASSERT_EQ(conv2d.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(conv2d.opcode(), loco::CanonicalOpcode::Conv2D); + + ASSERT_EQ(conv2d.ifm(), nullptr); + ASSERT_EQ(conv2d.ker(), nullptr); + + ASSERT_NE(conv2d.pad(), nullptr); + ASSERT_EQ(conv2d.pad()->top(), 0); + ASSERT_EQ(conv2d.pad()->bottom(), 0); + ASSERT_EQ(conv2d.pad()->left(), 0); + ASSERT_EQ(conv2d.pad()->right(), 0); + + ASSERT_NE(conv2d.stride(), nullptr); + ASSERT_EQ(conv2d.stride()->vertical(), 1); + ASSERT_EQ(conv2d.stride()->horizontal(), 1); +} + +TEST(DepthwiseConv2DTest, constructor) +{ + loco::DepthwiseConv2D dw_conv2d; + + ASSERT_EQ(dw_conv2d.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(dw_conv2d.opcode(), loco::CanonicalOpcode::DepthwiseConv2D); + + ASSERT_EQ(dw_conv2d.ifm(), nullptr); + ASSERT_EQ(dw_conv2d.ker(), nullptr); + + ASSERT_NE(dw_conv2d.pad(), nullptr); + ASSERT_EQ(dw_conv2d.pad()->top(), 0); + ASSERT_EQ(dw_conv2d.pad()->bottom(), 0); + ASSERT_EQ(dw_conv2d.pad()->left(), 0); + ASSERT_EQ(dw_conv2d.pad()->right(), 0); + + ASSERT_NE(dw_conv2d.stride(), nullptr); + ASSERT_EQ(dw_conv2d.stride()->vertical(), 1); + ASSERT_EQ(dw_conv2d.stride()->horizontal(), 1); +} + +TEST(TransposedConv2DTest, constructor) +{ + loco::TransposedConv2D tr_conv2d; + + ASSERT_EQ(tr_conv2d.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(tr_conv2d.opcode(), loco::CanonicalOpcode::TransposedConv2D); + + ASSERT_EQ(tr_conv2d.ifm(), nullptr); + ASSERT_EQ(tr_conv2d.ker(), nullptr); + + ASSERT_NE(tr_conv2d.pad(), nullptr); + ASSERT_EQ(tr_conv2d.pad()->top(), 0); + ASSERT_EQ(tr_conv2d.pad()->bottom(), 0); + ASSERT_EQ(tr_conv2d.pad()->left(), 0); + ASSERT_EQ(tr_conv2d.pad()->right(), 0); + + ASSERT_NE(tr_conv2d.stride(), nullptr); + ASSERT_EQ(tr_conv2d.stride()->vertical(), 1); + ASSERT_EQ(tr_conv2d.stride()->horizontal(), 1); +} + +TEST(BiasEncodeTest, constructor) +{ + loco::BiasEncode bias_encode; + + ASSERT_EQ(bias_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(bias_encode.opcode(), loco::CanonicalOpcode::BiasEncode); + + ASSERT_EQ(bias_encode.input(), nullptr); +} + +TEST(TensorBiasAddTest, constructor) +{ + loco::BiasAdd<loco::Domain::Tensor> bias_add; + + ASSERT_EQ(bias_add.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(bias_add.opcode(), loco::CanonicalOpcode::TensorBiasAdd); + + ASSERT_EQ(bias_add.value(), nullptr); + ASSERT_EQ(bias_add.bias(), nullptr); + ASSERT_EQ(bias_add.axis(), 0); +} + +TEST(TensorBiasAddTest, alias) +{ + loco::TensorBiasAdd bias_add; + + SUCCEED(); +} + +TEST(FeatureBiasAddTest, constructor) +{ + loco::BiasAdd<loco::Domain::Feature> bias_add; + + ASSERT_EQ(bias_add.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(bias_add.opcode(), loco::CanonicalOpcode::FeatureBiasAdd); + + ASSERT_EQ(bias_add.value(), nullptr); + ASSERT_EQ(bias_add.bias(), nullptr); +} + +TEST(FeatureBiasAddTest, alias) +{ + loco::FeatureBiasAdd bias_add; + + SUCCEED(); +} + +TEST(EltwiseAddTest, constructor) +{ + loco::EltwiseAdd eltwise_add; + + SUCCEED(); +} + +TEST(EltwiseMaxTest, constructor) +{ + loco::EltwiseMax eltwise_max; + + SUCCEED(); +} + +TEST(EltwiseMulTest, constructor) +{ + loco::EltwiseMul eltwise_mul; + + SUCCEED(); +} + +TEST(EltwiseSubTest, constructor) +{ + loco::EltwiseSub eltwise_sub; + + SUCCEED(); +} + +TEST(EltwiseDivTest, constructor) +{ + loco::EltwiseDiv eltwise_div; + + SUCCEED(); +} + +TEST(EltwiseSqrtTest, constructor) +{ + loco::EltwiseSqrt sqrt_node; + + ASSERT_EQ(sqrt_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(sqrt_node.opcode(), loco::CanonicalOpcode::EltwiseSqrt); + + ASSERT_EQ(sqrt_node.input(), nullptr); +} + +TEST(TensorBroadcastTest, constructor) +{ + loco::TensorBroadcast tensor_broadcast_node; + + ASSERT_EQ(tensor_broadcast_node.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(tensor_broadcast_node.opcode(), loco::CanonicalOpcode::TensorBroadcast); + + ASSERT_EQ(tensor_broadcast_node.input(), nullptr); +} + +TEST(TensorBroadcastTest, mapping) +{ + loco::TensorBroadcast tensor_broadcast_node; + + ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), false); + + tensor_broadcast_node.mapping()->dim(0) = 3; + + ASSERT_EQ(tensor_broadcast_node.mapping()->defined(0), true); + ASSERT_EQ(tensor_broadcast_node.mapping()->dim(0), 3); +} + +TEST(MatrixEncodeTest, constructor) +{ + loco::MatrixEncode matrix_encode; + + ASSERT_EQ(matrix_encode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(matrix_encode.opcode(), loco::CanonicalOpcode::MatrixEncode); + + ASSERT_EQ(matrix_encode.input(), nullptr); +} + +TEST(MatrixDecodeTest, constructor) +{ + loco::MatrixDecode matrix_decode; + + ASSERT_EQ(matrix_decode.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(matrix_decode.opcode(), loco::CanonicalOpcode::MatrixDecode); + + ASSERT_EQ(matrix_decode.input(), nullptr); +} + +TEST(MatMulTest, constructor) +{ + loco::MatMul mat_mul; + + ASSERT_EQ(mat_mul.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(mat_mul.opcode(), loco::CanonicalOpcode::MatMul); + + ASSERT_EQ(mat_mul.lhs(), nullptr); + ASSERT_EQ(mat_mul.rhs(), nullptr); +} + +TEST(TransposeTest, constructor) +{ + loco::TensorTranspose transpose; + + ASSERT_EQ(transpose.dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(transpose.opcode(), loco::CanonicalOpcode::TensorTranspose); + + ASSERT_EQ(transpose.input(), nullptr); + ASSERT_EQ(transpose.perm()->size(), 0); +} + +TEST(TransposeTest, perm) +{ + loco::TensorTranspose transpose; + + transpose.perm()->size(3); + transpose.perm()->axis(0) = 1; + transpose.perm()->axis(1) = 2; + transpose.perm()->axis(2) = 0; + + ASSERT_EQ(transpose.perm()->axis(0), 1); + ASSERT_EQ(transpose.perm()->axis(1), 2); + ASSERT_EQ(transpose.perm()->axis(2), 0); +} diff --git a/compiler/loco/src/IR/Padding2D.test.cpp b/compiler/loco/src/IR/Padding2D.test.cpp new file mode 100644 index 000000000..2e3d4af87 --- /dev/null +++ b/compiler/loco/src/IR/Padding2D.test.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Padding2D.h" + +#include <gtest/gtest.h> + +TEST(PadTest, default_constructor_2D) +{ + loco::Padding2D pad; + + ASSERT_EQ(pad.top(), 0); + ASSERT_EQ(pad.bottom(), 0); + ASSERT_EQ(pad.left(), 0); + ASSERT_EQ(pad.right(), 0); +} diff --git a/compiler/loco/src/IR/PaddingND.test.cpp b/compiler/loco/src/IR/PaddingND.test.cpp new file mode 100644 index 000000000..0e20406ff --- /dev/null +++ b/compiler/loco/src/IR/PaddingND.test.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/PaddingND.h" + +#include <gtest/gtest.h> + +TEST(PaddingNDTest, default_constructor_ND) +{ + loco::PaddingND padding; + + padding.rank(1); + padding.front(0) = 1; + padding.back(0) = 2; + + ASSERT_EQ(padding.rank(), 1); + ASSERT_EQ(padding.front(0), 1); + ASSERT_EQ(padding.back(0), 2); +} diff --git a/compiler/loco/src/IR/PermutingCodec.cpp b/compiler/loco/src/IR/PermutingCodec.cpp new file mode 100644 index 000000000..2857e5e28 --- /dev/null +++ b/compiler/loco/src/IR/PermutingCodec.cpp @@ -0,0 +1,630 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/PermutingCodec.h" + +#include <stdex/Memory.h> + +#include <cassert> +#include <set> +#include <stdexcept> + +/** + * Feature Domain + */ +namespace +{ + +using loco::FeatureAxis; + +inline bool valid(const FeatureAxis &axis) +{ + switch (axis) + { + case FeatureAxis::Count: + return true; + case FeatureAxis::Depth: + return true; + case FeatureAxis::Height: + return true; + case FeatureAxis::Width: + return true; + default: + break; + } + + return false; +} + +inline bool valid(const loco::Permutation<loco::Domain::Feature> &perm) +{ + auto check = [&perm](FeatureAxis axis_f) { + if (!perm.mapped(axis_f)) + return false; + return perm.axis(axis_f) < 4; + }; + + if (!check(FeatureAxis::Count)) + return false; + if (!check(FeatureAxis::Depth)) + return false; + if (!check(FeatureAxis::Height)) + return false; + if (!check(FeatureAxis::Width)) + return false; + + // Check whether tensor axes are all distinct + std::set<loco::TensorAxis> values; + + values.insert(perm[FeatureAxis::Count]); + values.insert(perm[FeatureAxis::Depth]); + values.insert(perm[FeatureAxis::Height]); + values.insert(perm[FeatureAxis::Width]); + + return values.size() == 4; +} + +} // namespace + +namespace loco +{ + +// +// Permutation +// +bool Permutation<Domain::Feature>::mapped(const FeatureAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid feature axis"); + return _map.find(axis_f) != _map.end(); +} + +uint32_t Permutation<Domain::Feature>::axis(const FeatureAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid feature axis"); + assert(mapped(axis_f) && "unmapped feature axis"); + return _map.at(axis_f); +} + +uint32_t &Permutation<Domain::Feature>::axis(const FeatureAxis &axis_f) +{ + assert(valid(axis_f) && "invalid feature axis"); + return _map[axis_f]; +} + +// +// Permuting Encoder +// +FeatureShape PermutingEncoder<Domain::Feature>::shape(const TensorShape &in) const +{ + assert(valid() && "invalid permutation"); + + FeatureShape out; + + out.count() = in.dim(_perm[FeatureAxis::Count]); + out.depth() = in.dim(_perm[FeatureAxis::Depth]); + out.height() = in.dim(_perm[FeatureAxis::Height]); + out.width() = in.dim(_perm[FeatureAxis::Width]); + + return out; +} + +TensorIndex PermutingEncoder<Domain::Feature>::value(const FeatureIndex &in) const +{ + assert(valid() && "invalid permutation"); + + TensorIndex out; + + out.resize(4); + + out.at(_perm[FeatureAxis::Count]) = in.batch(); + out.at(_perm[FeatureAxis::Depth]) = in.channel(); + out.at(_perm[FeatureAxis::Height]) = in.row(); + out.at(_perm[FeatureAxis::Width]) = in.column(); + + return out; +} + +std::unique_ptr<FeatureEncoder> PermutingEncoder<Domain::Feature>::clone(void) const +{ + return stdex::make_unique<PermutingEncoder<Domain::Feature>>(_perm); +} + +bool PermutingEncoder<Domain::Feature>::valid(void) const { return ::valid(_perm); } + +// +// Permuting Decoder +// +TensorShape PermutingDecoder<Domain::Feature>::shape(const FeatureShape &in) const +{ + assert(valid() && "invalid permuation"); + + TensorShape out; + + out.rank(4); + + out.dim(_perm[FeatureAxis::Count]) = in.count(); + out.dim(_perm[FeatureAxis::Depth]) = in.depth(); + out.dim(_perm[FeatureAxis::Height]) = in.height(); + out.dim(_perm[FeatureAxis::Width]) = in.width(); + + return out; +} + +FeatureIndex PermutingDecoder<Domain::Feature>::value(const TensorIndex &in) const +{ + assert(valid() && "invalid permutation"); + + FeatureIndex out; + + out.batch() = in.at(_perm[FeatureAxis::Count]); + out.channel() = in.at(_perm[FeatureAxis::Depth]); + out.row() = in.at(_perm[FeatureAxis::Height]); + out.column() = in.at(_perm[FeatureAxis::Width]); + + return out; +} + +std::unique_ptr<FeatureDecoder> PermutingDecoder<Domain::Feature>::clone(void) const +{ + return stdex::make_unique<PermutingDecoder<Domain::Feature>>(_perm); +} + +bool PermutingDecoder<Domain::Feature>::valid(void) const { return ::valid(_perm); } + +} // namespace loco + +/** + * Filter Domain + */ +namespace +{ + +using loco::FilterAxis; + +inline bool valid(const FilterAxis &axis) +{ + switch (axis) + { + case FilterAxis::Count: + return true; + case FilterAxis::Depth: + return true; + case FilterAxis::Height: + return true; + case FilterAxis::Width: + return true; + default: + break; + } + + return false; +} + +inline bool valid(const loco::Permutation<loco::Domain::Filter> &perm) +{ + auto check = [&perm](FilterAxis axis_f) { + if (!perm.mapped(axis_f)) + return false; + return perm.axis(axis_f) < 4; + }; + + if (!check(FilterAxis::Count)) + return false; + if (!check(FilterAxis::Depth)) + return false; + if (!check(FilterAxis::Height)) + return false; + if (!check(FilterAxis::Width)) + return false; + + // Check whether tensor axes are all distinct + std::set<loco::TensorAxis> values; + + values.insert(perm[FilterAxis::Count]); + values.insert(perm[FilterAxis::Depth]); + values.insert(perm[FilterAxis::Height]); + values.insert(perm[FilterAxis::Width]); + + return values.size() == 4; +} + +} // namespace + +namespace loco +{ + +// +// Permutation +// +bool Permutation<Domain::Filter>::mapped(const FilterAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid filter axis"); + return _map.find(axis_f) != _map.end(); +} + +const uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid filter axis"); + assert(mapped(axis_f) && "unmapped filter axis"); + return _map.at(axis_f); +} + +uint32_t &Permutation<Domain::Filter>::axis(const FilterAxis &axis_f) +{ + assert(valid(axis_f) && "invalid filter axis"); + return _map[axis_f]; +} + +// +// Permuting Encoder +// +FilterShape PermutingEncoder<Domain::Filter>::shape(const TensorShape &in) const +{ + assert(valid() && "invalid permutation"); + + FilterShape out; + + out.count() = in.dim(_perm[FilterAxis::Count]); + out.depth() = in.dim(_perm[FilterAxis::Depth]); + out.height() = in.dim(_perm[FilterAxis::Height]); + out.width() = in.dim(_perm[FilterAxis::Width]); + + return out; +} + +TensorIndex PermutingEncoder<Domain::Filter>::value(const FilterIndex &in) const +{ + assert(valid() && "invalid permutation"); + + TensorIndex out; + + out.resize(4); + + out.at(_perm[FilterAxis::Count]) = in.nth(); + out.at(_perm[FilterAxis::Depth]) = in.channel(); + out.at(_perm[FilterAxis::Height]) = in.row(); + out.at(_perm[FilterAxis::Width]) = in.column(); + + return out; +} + +bool PermutingEncoder<Domain::Filter>::valid(void) const { return ::valid(_perm); } + +// +// Permuting Decoder +// +TensorShape PermutingDecoder<Domain::Filter>::shape(const FilterShape &in) const +{ + assert(valid() && "invalid permutation"); + + TensorShape out; + + out.rank(4); + out.dim(_perm[FilterAxis::Count]) = in.count(); + out.dim(_perm[FilterAxis::Depth]) = in.depth(); + out.dim(_perm[FilterAxis::Height]) = in.height(); + out.dim(_perm[FilterAxis::Width]) = in.width(); + + return out; +} + +FilterIndex PermutingDecoder<Domain::Filter>::value(const TensorIndex &in) const +{ + assert(valid() && "invalid permutation"); + + FilterIndex out; + + out.nth() = in.at(_perm[FilterAxis::Count]); + out.channel() = in.at(_perm[FilterAxis::Depth]); + out.row() = in.at(_perm[FilterAxis::Height]); + out.column() = in.at(_perm[FilterAxis::Width]); + + return out; +} + +bool PermutingDecoder<Domain::Filter>::valid(void) const { return ::valid(_perm); } + +} // namespace loco + +/** + * DepthwiseFilter Domain + */ +namespace +{ + +using loco::DepthwiseFilterAxis; + +inline bool valid(const DepthwiseFilterAxis &axis) +{ + switch (axis) + { + case DepthwiseFilterAxis::Depth: + return true; + case DepthwiseFilterAxis::Multiplier: + return true; + case DepthwiseFilterAxis::Height: + return true; + case DepthwiseFilterAxis::Width: + return true; + default: + break; + } + + return false; +} + +inline bool valid(const loco::Permutation<loco::Domain::DepthwiseFilter> &perm) +{ + auto check = [&perm](DepthwiseFilterAxis axis_f) { + if (!perm.mapped(axis_f)) + return false; + return perm.axis(axis_f) < 4; + }; + + if (!check(DepthwiseFilterAxis::Depth)) + return false; + if (!check(DepthwiseFilterAxis::Multiplier)) + return false; + if (!check(DepthwiseFilterAxis::Height)) + return false; + if (!check(DepthwiseFilterAxis::Width)) + return false; + + // Check whether tensor axes are all distinct + std::set<loco::TensorAxis> values; + + values.insert(perm[DepthwiseFilterAxis::Depth]); + values.insert(perm[DepthwiseFilterAxis::Multiplier]); + values.insert(perm[DepthwiseFilterAxis::Height]); + values.insert(perm[DepthwiseFilterAxis::Width]); + + return values.size() == 4; +} + +} // namespace + +namespace loco +{ + +// +// Permutation +// +bool Permutation<Domain::DepthwiseFilter>::mapped(const DepthwiseFilterAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid depthwise filter axis"); + return _map.find(axis_f) != _map.end(); +} + +const uint32_t &Permutation<Domain::DepthwiseFilter>::axis(const DepthwiseFilterAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid depthwise filter axis"); + assert(mapped(axis_f) && "unmapped depthwise filter axis"); + return _map.at(axis_f); +} + +uint32_t &Permutation<Domain::DepthwiseFilter>::axis(const DepthwiseFilterAxis &axis_f) +{ + assert(valid(axis_f) && "invalid depthwise filter axis"); + return _map[axis_f]; +} + +// +// Permuting Encoder +// +DepthwiseFilterShape PermutingEncoder<Domain::DepthwiseFilter>::shape(const TensorShape &in) const +{ + assert(valid() && "invalid permutation"); + + DepthwiseFilterShape out; + + out.depth() = in.dim(_perm[DepthwiseFilterAxis::Depth]); + out.multiplier() = in.dim(_perm[DepthwiseFilterAxis::Multiplier]); + out.height() = in.dim(_perm[DepthwiseFilterAxis::Height]); + out.width() = in.dim(_perm[DepthwiseFilterAxis::Width]); + + return out; +} + +TensorIndex PermutingEncoder<Domain::DepthwiseFilter>::value(const DepthwiseFilterIndex &in) const +{ + assert(valid() && "invalid permutation"); + + TensorIndex out; + + out.resize(4); + + out.at(_perm[DepthwiseFilterAxis::Depth]) = in.channel(); + out.at(_perm[DepthwiseFilterAxis::Multiplier]) = in.nth(); + out.at(_perm[DepthwiseFilterAxis::Height]) = in.row(); + out.at(_perm[DepthwiseFilterAxis::Width]) = in.column(); + + return out; +} + +bool PermutingEncoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); } + +// +// Permuting Decoder +// +TensorShape PermutingDecoder<Domain::DepthwiseFilter>::shape(const DepthwiseFilterShape &in) const +{ + assert(valid() && "invalid permutation"); + + TensorShape out; + out.rank(4); + + out.dim(_perm[DepthwiseFilterAxis::Depth]) = in.depth(); + out.dim(_perm[DepthwiseFilterAxis::Multiplier]) = in.multiplier(); + out.dim(_perm[DepthwiseFilterAxis::Height]) = in.height(); + out.dim(_perm[DepthwiseFilterAxis::Width]) = in.width(); + + return out; +} + +DepthwiseFilterIndex PermutingDecoder<Domain::DepthwiseFilter>::value(const TensorIndex &in) const +{ + assert(valid() && "invalid permutation"); + assert(in.rank() == 4); + + DepthwiseFilterIndex out; + + out.channel() = in.at(_perm[DepthwiseFilterAxis::Depth]); + out.nth() = in.at(_perm[DepthwiseFilterAxis::Multiplier]); + out.row() = in.at(_perm[DepthwiseFilterAxis::Height]); + out.column() = in.at(_perm[DepthwiseFilterAxis::Width]); + + return out; +} + +bool PermutingDecoder<Domain::DepthwiseFilter>::valid(void) const { return ::valid(_perm); } + +} // namespace loco + +/** + * Matrix Domain + */ +namespace +{ + +using loco::MatrixAxis; + +inline bool valid(const MatrixAxis &axis) +{ + switch (axis) + { + case MatrixAxis::Height: + return true; + case MatrixAxis::Width: + return true; + default: + break; + } + + return false; +} + +inline bool valid(const loco::Permutation<loco::Domain::Matrix> &perm) +{ + auto check = [&perm](MatrixAxis axis_f) { + if (!perm.mapped(axis_f)) + return false; + return perm.axis(axis_f) < 2; + }; + + if (!check(MatrixAxis::Height)) + return false; + if (!check(MatrixAxis::Width)) + return false; + + // Check whether tensor axes are all distinct + std::set<loco::TensorAxis> values; + + values.insert(perm[MatrixAxis::Height]); + values.insert(perm[MatrixAxis::Width]); + + return values.size() == 2; +} + +} // namespace + +namespace loco +{ + +// +// Permutation +// +bool Permutation<Domain::Matrix>::mapped(const MatrixAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid matrix axis"); + return _map.find(axis_f) != _map.end(); +} + +uint32_t Permutation<Domain::Matrix>::axis(const MatrixAxis &axis_f) const +{ + assert(valid(axis_f) && "invalid matrix axis"); + assert(mapped(axis_f) && "unmapped matrix axis"); + return _map.at(axis_f); +} + +uint32_t &Permutation<Domain::Matrix>::axis(const MatrixAxis &axis_f) +{ + assert(valid(axis_f) && "invalid matrix axis"); + return _map[axis_f]; +} + +// +// Permuting Encoder +// +MatrixShape PermutingEncoder<Domain::Matrix>::shape(const TensorShape &in) const +{ + assert(valid() && "invalid permutation"); + + MatrixShape out; + + out.height() = in.dim(_perm[MatrixAxis::Height]); + out.width() = in.dim(_perm[MatrixAxis::Width]); + + return out; +} + +TensorIndex PermutingEncoder<Domain::Matrix>::value(const MatrixIndex &in) const +{ + assert(valid() && "invalid permutation"); + + TensorIndex out; + + out.resize(2); + + out.at(_perm[MatrixAxis::Height]) = in.row(); + out.at(_perm[MatrixAxis::Width]) = in.column(); + + return out; +} + +bool PermutingEncoder<Domain::Matrix>::valid(void) const { return ::valid(_perm); } + +// +// Permuting Decoder +// +TensorShape PermutingDecoder<Domain::Matrix>::shape(const MatrixShape &in) const +{ + assert(valid() && "invalid permuation"); + + TensorShape out; + + out.rank(2); + + out.dim(_perm[MatrixAxis::Height]) = in.height(); + out.dim(_perm[MatrixAxis::Width]) = in.width(); + + return out; +} + +MatrixIndex PermutingDecoder<Domain::Matrix>::value(const TensorIndex &in) const +{ + assert(valid() && "invalid permutation"); + + MatrixIndex out; + + out.row() = in.at(_perm[MatrixAxis::Height]); + out.column() = in.at(_perm[MatrixAxis::Width]); + + return out; +} + +bool PermutingDecoder<Domain::Matrix>::valid(void) const { return ::valid(_perm); } + +} // namespace loco diff --git a/compiler/loco/src/IR/PermutingCodec.test.cpp b/compiler/loco/src/IR/PermutingCodec.test.cpp new file mode 100644 index 000000000..2eff286d0 --- /dev/null +++ b/compiler/loco/src/IR/PermutingCodec.test.cpp @@ -0,0 +1,553 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/PermutingCodec.h" + +#include <gtest/gtest.h> + +using namespace loco; + +TEST(PemutationTest, feature) +{ + Permutation<Domain::Feature> perm; + + // All values are invalid at the beginning + ASSERT_FALSE(perm.mapped(FeatureAxis::Count)); + ASSERT_FALSE(perm.mapped(FeatureAxis::Depth)); + ASSERT_FALSE(perm.mapped(FeatureAxis::Height)); + ASSERT_FALSE(perm.mapped(FeatureAxis::Width)); + + // Update mapping + perm[FeatureAxis::Count] = 5; + perm[FeatureAxis::Depth] = 6; + perm[FeatureAxis::Height] = 7; + perm[FeatureAxis::Width] = 8; + + // Now perm has a mapping for all the axes + ASSERT_TRUE(perm.mapped(FeatureAxis::Count)); + ASSERT_TRUE(perm.mapped(FeatureAxis::Depth)); + ASSERT_TRUE(perm.mapped(FeatureAxis::Height)); + ASSERT_TRUE(perm.mapped(FeatureAxis::Width)); + + // Check the value + ASSERT_EQ(perm[FeatureAxis::Count], 5); + ASSERT_EQ(perm[FeatureAxis::Depth], 6); + ASSERT_EQ(perm[FeatureAxis::Height], 7); + ASSERT_EQ(perm[FeatureAxis::Width], 8); +} + +TEST(PemutationTest, filter) +{ + Permutation<Domain::Filter> perm; + + // All values are invalid at the beginning + ASSERT_FALSE(perm.mapped(FilterAxis::Count)); + ASSERT_FALSE(perm.mapped(FilterAxis::Depth)); + ASSERT_FALSE(perm.mapped(FilterAxis::Height)); + ASSERT_FALSE(perm.mapped(FilterAxis::Width)); + + // Update mapping + perm[FilterAxis::Count] = 5; + perm[FilterAxis::Depth] = 6; + perm[FilterAxis::Height] = 7; + perm[FilterAxis::Width] = 8; + + // Now perm has a mapping for all the axes + ASSERT_TRUE(perm.mapped(FilterAxis::Count)); + ASSERT_TRUE(perm.mapped(FilterAxis::Depth)); + ASSERT_TRUE(perm.mapped(FilterAxis::Height)); + ASSERT_TRUE(perm.mapped(FilterAxis::Width)); + + // Check the value + ASSERT_EQ(perm[FilterAxis::Count], 5); + ASSERT_EQ(perm[FilterAxis::Depth], 6); + ASSERT_EQ(perm[FilterAxis::Height], 7); + ASSERT_EQ(perm[FilterAxis::Width], 8); +} + +TEST(PemutationTest, depthwise_filter) +{ + Permutation<Domain::DepthwiseFilter> perm; + + // All values are invalid at the beginning + ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Depth)); + ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Multiplier)); + ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Height)); + ASSERT_FALSE(perm.mapped(DepthwiseFilterAxis::Width)); + + // Update mapping + perm[DepthwiseFilterAxis::Depth] = 5; + perm[DepthwiseFilterAxis::Multiplier] = 6; + perm[DepthwiseFilterAxis::Height] = 7; + perm[DepthwiseFilterAxis::Width] = 8; + + // Now perm has a mapping for all the axes + ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Depth)); + ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Multiplier)); + ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Height)); + ASSERT_TRUE(perm.mapped(DepthwiseFilterAxis::Width)); + + // Check the value + ASSERT_EQ(perm[DepthwiseFilterAxis::Depth], 5); + ASSERT_EQ(perm[DepthwiseFilterAxis::Multiplier], 6); + ASSERT_EQ(perm[DepthwiseFilterAxis::Height], 7); + ASSERT_EQ(perm[DepthwiseFilterAxis::Width], 8); +} + +TEST(PermutingEncoderTest, feature) +{ + PermutingEncoder<Domain::Feature> enc; + + // Encoder is invalid at the beginning + ASSERT_FALSE(enc.valid()); + + // Set "invalid" mapping + enc.perm()->axis(FeatureAxis::Count) = 0; + enc.perm()->axis(FeatureAxis::Depth) = 6; + enc.perm()->axis(FeatureAxis::Height) = 1; + enc.perm()->axis(FeatureAxis::Width) = 2; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set another "invalid" mapping + enc.perm()->axis(FeatureAxis::Depth) = 1; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set "valid" mapping + enc.perm()->axis(FeatureAxis::Depth) = 3; + + // Encoder is now valid + ASSERT_TRUE(enc.valid()); + + // Let's test with a HD (1280x720) RGB image + TensorShape tensor_shape; + + tensor_shape.rank(4); + tensor_shape.dim(0) = 1; // COUNT + tensor_shape.dim(1) = 720; // HEIGHT + tensor_shape.dim(2) = 1280; // WIDTH + tensor_shape.dim(3) = 3; // DEPTH + + // Get the feature shape corresponding to a given image + auto feature_shape = enc.shape(tensor_shape); + + ASSERT_EQ(feature_shape.count(), 1); + ASSERT_EQ(feature_shape.depth(), 3); + ASSERT_EQ(feature_shape.height(), 720); + ASSERT_EQ(feature_shape.width(), 1280); + + // Let's find a source tensor index! + FeatureIndex feature_index; + + feature_index.batch() = 0; + feature_index.channel() = 1; + feature_index.row() = 2; + feature_index.column() = 3; + + auto tensor_index = enc.value(feature_index); + + ASSERT_EQ(tensor_index.at(0), 0); // BATCH(COUNT) + ASSERT_EQ(tensor_index.at(1), 2); // ROW(HEIGHT) + ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH) + ASSERT_EQ(tensor_index.at(3), 1); // CHANNEL(DEPTH) +} + +TEST(PermutingEncoderTest, feature_clone) +{ + PermutingEncoder<Domain::Feature> src_enc; + + auto src_perm = src_enc.perm(); + + src_perm->axis(FeatureAxis::Count) = 0; + src_perm->axis(FeatureAxis::Depth) = 3; + src_perm->axis(FeatureAxis::Height) = 1; + src_perm->axis(FeatureAxis::Width) = 2; + + auto dst_enc = src_enc.clone(); + auto dst_perm = dynamic_cast<PermutingEncoder<Domain::Feature> *>(dst_enc.get())->perm(); + + EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width)); + + // Update on cloned encoder SHOULD NOT affect the original encoder + dst_perm->axis(FeatureAxis::Height) += 1; + + EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2); +} + +TEST(PermutingEncoderTest, filter) +{ + PermutingEncoder<Domain::Filter> enc; + + // Encoder is invalid at the beginning + ASSERT_FALSE(enc.valid()); + + // Set "invalid" mapping + enc.perm()->axis(FilterAxis::Count) = 0; + enc.perm()->axis(FilterAxis::Depth) = 6; + enc.perm()->axis(FilterAxis::Height) = 1; + enc.perm()->axis(FilterAxis::Width) = 2; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set another "invalid" mapping + enc.perm()->axis(FilterAxis::Depth) = 1; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set "valid" mapping + enc.perm()->axis(FilterAxis::Depth) = 3; + + // Encoder is now valid + ASSERT_TRUE(enc.valid()); + + TensorShape tensor_shape; + + tensor_shape.rank(4); + tensor_shape.dim(0) = 8; // COUNT + tensor_shape.dim(1) = 1; // HEIGHT + tensor_shape.dim(2) = 7; // WIDTH + tensor_shape.dim(3) = 4; // DEPTH + + // Get the corresponding filter shape + auto filter_shape = enc.shape(tensor_shape); + + ASSERT_EQ(filter_shape.count(), 8); + ASSERT_EQ(filter_shape.depth(), 4); + ASSERT_EQ(filter_shape.height(), 1); + ASSERT_EQ(filter_shape.width(), 7); + + // Let's find a source tensor index! + FilterIndex filter_index; + + filter_index.nth() = 1; + filter_index.channel() = 2; + filter_index.row() = 0; + filter_index.column() = 3; + + auto tensor_index = enc.value(filter_index); + + ASSERT_EQ(tensor_index.at(0), 1); // NTH(COUNT) + ASSERT_EQ(tensor_index.at(1), 0); // ROW(HEIGHT) + ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH) + ASSERT_EQ(tensor_index.at(3), 2); // CHANNEL(DEPTH) +} + +TEST(PermutingEncoderTest, depthwise_filter) +{ + PermutingEncoder<Domain::DepthwiseFilter> enc; + + // Encoder is invalid at the beginning + ASSERT_FALSE(enc.valid()); + + // Set "invalid" mapping + enc.perm()->axis(DepthwiseFilterAxis::Depth) = 0; + enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6; + enc.perm()->axis(DepthwiseFilterAxis::Height) = 1; + enc.perm()->axis(DepthwiseFilterAxis::Width) = 2; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set another "invalid" mapping + enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1; + + // Encoder is still invalid + ASSERT_FALSE(enc.valid()); + + // Set "valid" mapping + enc.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3; + + // Encoder is now valid + ASSERT_TRUE(enc.valid()); + + TensorShape tensor_shape; + + tensor_shape.rank(4); + tensor_shape.dim(0) = 8; // DEPTH + tensor_shape.dim(1) = 1; // HEIGHT + tensor_shape.dim(2) = 7; // WIDTH + tensor_shape.dim(3) = 4; // MULTIPLIER + + // Get the corresponding depthwise filter shape + auto filter_shape = enc.shape(tensor_shape); + + ASSERT_EQ(filter_shape.depth(), 8); + ASSERT_EQ(filter_shape.multiplier(), 4); + ASSERT_EQ(filter_shape.height(), 1); + ASSERT_EQ(filter_shape.width(), 7); + + // Let's find a source tensor index! + DepthwiseFilterIndex filter_index; + + filter_index.channel() = 1; + filter_index.nth() = 2; + filter_index.row() = 0; + filter_index.column() = 3; + + auto tensor_index = enc.value(filter_index); + + ASSERT_EQ(tensor_index.at(0), 1); // CHANNEL(DEPTH) + ASSERT_EQ(tensor_index.at(1), 0); // ROW(HEIGHT) + ASSERT_EQ(tensor_index.at(2), 3); // COLUMN(WIDTH) + ASSERT_EQ(tensor_index.at(3), 2); // NTH(MULTIPLIER) +} + +TEST(PermutingEncoderTest, depthwisefilter_init) +{ + Permutation<Domain::DepthwiseFilter> src_perm; + + src_perm.axis(DepthwiseFilterAxis::Multiplier) = 0; + src_perm.axis(DepthwiseFilterAxis::Depth) = 3; + src_perm.axis(DepthwiseFilterAxis::Height) = 1; + src_perm.axis(DepthwiseFilterAxis::Width) = 2; + + PermutingEncoder<Domain::DepthwiseFilter> dst_enc{src_perm}; + auto dst_perm = dst_enc.perm(); + + EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Multiplier), + src_perm.axis(DepthwiseFilterAxis::Multiplier)); + EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Depth), src_perm.axis(DepthwiseFilterAxis::Depth)); + EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height), + src_perm.axis(DepthwiseFilterAxis::Height)); + EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Width), src_perm.axis(DepthwiseFilterAxis::Width)); + + // Update on dst perm SHOULD NOT affect the src perm + dst_perm->axis(DepthwiseFilterAxis::Height) += 1; + + EXPECT_EQ(src_perm.axis(DepthwiseFilterAxis::Height), 1); + EXPECT_EQ(dst_perm->axis(DepthwiseFilterAxis::Height), 2); +} + +TEST(PermutingDecoderTest, feature) +{ + PermutingDecoder<Domain::Feature> dec; + + // Decoder is invalid at the beginning + ASSERT_FALSE(dec.valid()); + + // Set "invalid" mapping + dec.perm()->axis(FeatureAxis::Count) = 0; + dec.perm()->axis(FeatureAxis::Depth) = 6; + dec.perm()->axis(FeatureAxis::Height) = 1; + dec.perm()->axis(FeatureAxis::Width) = 2; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set another "invalid" mapping + dec.perm()->axis(FeatureAxis::Depth) = 1; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set "valid" mapping + dec.perm()->axis(FeatureAxis::Depth) = 3; + + // Decoder is now valid + ASSERT_TRUE(dec.valid()); + + // Let's test with a HD (1280x720) RGB image + FeatureShape feature_shape; + + feature_shape.count() = 1; + feature_shape.depth() = 3; + feature_shape.height() = 720; + feature_shape.width() = 1280; + + // Get the tensor shape corresponding to a given image + auto tensor_shape = dec.shape(feature_shape); + + ASSERT_EQ(tensor_shape.rank(), 4); + ASSERT_EQ(tensor_shape.dim(0), 1); // COUNT + ASSERT_EQ(tensor_shape.dim(1), 720); // HEIGHT + ASSERT_EQ(tensor_shape.dim(2), 1280); // WIDTH + ASSERT_EQ(tensor_shape.dim(3), 3); // DEPTH + + // Let's find a source feature index! + TensorIndex tensor_index; + + tensor_index.resize(4); + + tensor_index.at(0) = 0; // BATCH(COUNT) + tensor_index.at(3) = 1; // CHANNEL(DEPTH) + tensor_index.at(1) = 2; // ROW(HEIGHT) + tensor_index.at(2) = 3; // COLUMN(WIDTH) + + auto feature_index = dec.value(tensor_index); + + ASSERT_EQ(feature_index.batch(), 0); + ASSERT_EQ(feature_index.channel(), 1); + ASSERT_EQ(feature_index.row(), 2); + ASSERT_EQ(feature_index.column(), 3); +} + +TEST(PermutingDecoderTest, feature_clone) +{ + PermutingDecoder<Domain::Feature> src_enc; + + auto src_perm = src_enc.perm(); + + src_perm->axis(FeatureAxis::Count) = 0; + src_perm->axis(FeatureAxis::Depth) = 3; + src_perm->axis(FeatureAxis::Height) = 1; + src_perm->axis(FeatureAxis::Width) = 2; + + auto dst_enc = src_enc.clone(); + auto dst_perm = dynamic_cast<PermutingDecoder<Domain::Feature> *>(dst_enc.get())->perm(); + + EXPECT_EQ(dst_perm->axis(FeatureAxis::Count), src_perm->axis(FeatureAxis::Count)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Depth), src_perm->axis(FeatureAxis::Depth)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), src_perm->axis(FeatureAxis::Height)); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Width), src_perm->axis(FeatureAxis::Width)); + + // Update on cloned decoder SHOULD NOT affect the original decoder + dst_perm->axis(FeatureAxis::Height) += 1; + + EXPECT_EQ(src_perm->axis(FeatureAxis::Height), 1); + EXPECT_EQ(dst_perm->axis(FeatureAxis::Height), 2); +} + +TEST(PermutingDecoderTest, filter) +{ + PermutingDecoder<Domain::Filter> dec; + + // Decoder is invalid at the beginning + ASSERT_FALSE(dec.valid()); + + // Set "invalid" mapping + dec.perm()->axis(FilterAxis::Count) = 0; + dec.perm()->axis(FilterAxis::Depth) = 6; + dec.perm()->axis(FilterAxis::Height) = 1; + dec.perm()->axis(FilterAxis::Width) = 2; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set another "invalid" mapping + dec.perm()->axis(FilterAxis::Depth) = 1; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set "valid" mapping + dec.perm()->axis(FilterAxis::Depth) = 3; + + // Decoder is now valid + ASSERT_TRUE(dec.valid()); + + // Let's test with a small filter + FilterShape filter_shape; + + filter_shape.count() = 10; + filter_shape.depth() = 3; + filter_shape.height() = 6; + filter_shape.width() = 8; + + // Get the tensor shape corresponding to a given image + auto tensor_shape = dec.shape(filter_shape); + + ASSERT_EQ(tensor_shape.rank(), 4); + ASSERT_EQ(tensor_shape.dim(0), 10); // COUNT + ASSERT_EQ(tensor_shape.dim(1), 6); // HEIGHT + ASSERT_EQ(tensor_shape.dim(2), 8); // WIDTH + ASSERT_EQ(tensor_shape.dim(3), 3); // DEPTH + + // Let's find a source filter index! + TensorIndex tensor_index; + + tensor_index.resize(4); + + tensor_index.at(0) = 0; // BATCH(COUNT) + tensor_index.at(3) = 1; // CHANNEL(DEPTH) + tensor_index.at(1) = 2; // ROW(HEIGHT) + tensor_index.at(2) = 3; // COLUMN(WIDTH) + + auto filter_index = dec.value(tensor_index); + + ASSERT_EQ(filter_index.nth(), 0); + ASSERT_EQ(filter_index.channel(), 1); + ASSERT_EQ(filter_index.row(), 2); + ASSERT_EQ(filter_index.column(), 3); +} + +TEST(PermutingDecoderTest, depthwise_filter) +{ + PermutingDecoder<Domain::DepthwiseFilter> dec; + + // Decoder is invalid at the beginning + ASSERT_FALSE(dec.valid()); + + // Set "invalid" mapping + dec.perm()->axis(DepthwiseFilterAxis::Depth) = 0; + dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 6; + dec.perm()->axis(DepthwiseFilterAxis::Height) = 1; + dec.perm()->axis(DepthwiseFilterAxis::Width) = 2; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set another "invalid" mapping + dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 1; + + // Decoder is still invalid + ASSERT_FALSE(dec.valid()); + + // Set "valid" mapping + dec.perm()->axis(DepthwiseFilterAxis::Multiplier) = 3; + + // Decoder is now valid + ASSERT_TRUE(dec.valid()); + + DepthwiseFilterShape dw_filter_shape; + + dw_filter_shape.depth() = 8; + dw_filter_shape.multiplier() = 1; + dw_filter_shape.height() = 7; + dw_filter_shape.width() = 4; + + // Get the corresponding depthwise filter shape + auto tensor_shape = dec.shape(dw_filter_shape); + + ASSERT_EQ(tensor_shape.dim(0).value(), 8); + ASSERT_EQ(tensor_shape.dim(1).value(), 7); + ASSERT_EQ(tensor_shape.dim(2).value(), 4); + ASSERT_EQ(tensor_shape.dim(3).value(), 1); + + // Let's find a source tensor index! + TensorIndex tensor_index; + tensor_index.resize(4); + + tensor_index.at(0) = 4; + tensor_index.at(1) = 2; + tensor_index.at(2) = 1; + tensor_index.at(3) = 0; + + auto dw_filter_index = dec.value(tensor_index); + + ASSERT_EQ(dw_filter_index.channel(), 4); + ASSERT_EQ(dw_filter_index.nth(), 0); + ASSERT_EQ(dw_filter_index.row(), 2); + ASSERT_EQ(dw_filter_index.column(), 1); +} diff --git a/compiler/loco/src/IR/Stride.test.cpp b/compiler/loco/src/IR/Stride.test.cpp new file mode 100644 index 000000000..60deb5c6f --- /dev/null +++ b/compiler/loco/src/IR/Stride.test.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Stride.h" + +#include <gtest/gtest.h> + +TEST(StrideTest, default_constructor_2D) +{ + loco::Stride<2> stride; + + ASSERT_EQ(stride.vertical(), 1); + ASSERT_EQ(stride.horizontal(), 1); +} + +TEST(StrideTest, setter_and_getter_2D) +{ + loco::Stride<2> stride; + + stride.vertical(2); + + ASSERT_EQ(stride.vertical(), 2); + ASSERT_EQ(stride.horizontal(), 1); + + stride.horizontal(3); + + ASSERT_EQ(stride.vertical(), 2); + ASSERT_EQ(stride.horizontal(), 3); +} diff --git a/compiler/loco/src/IR/TensorAxis.cpp b/compiler/loco/src/IR/TensorAxis.cpp new file mode 100644 index 000000000..b083847fc --- /dev/null +++ b/compiler/loco/src/IR/TensorAxis.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/TensorAxis.h" + +// NOTE This file validates "TensorAxis.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/TensorAxisSet.cpp b/compiler/loco/src/IR/TensorAxisSet.cpp new file mode 100644 index 000000000..c58237bf7 --- /dev/null +++ b/compiler/loco/src/IR/TensorAxisSet.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/TensorAxisSet.h" + +// NOTE This file validates "TensorAxisSet.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/TensorIndex.cpp b/compiler/loco/src/IR/TensorIndex.cpp new file mode 100644 index 000000000..cbd3698eb --- /dev/null +++ b/compiler/loco/src/IR/TensorIndex.cpp @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/TensorIndex.h" + +// NOTE This file validates "TensorIndex.h". Please DO NOT remove this file. diff --git a/compiler/loco/src/IR/TensorShape.cpp b/compiler/loco/src/IR/TensorShape.cpp new file mode 100644 index 000000000..ad30dcbc0 --- /dev/null +++ b/compiler/loco/src/IR/TensorShape.cpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/TensorShape.h" + +#include <cassert> + +namespace loco +{ + +uint32_t element_count(const loco::TensorShape *tensor_shape) +{ + uint32_t res = 1; + + for (uint32_t axis = 0; axis < tensor_shape->rank(); ++axis) + { + // Let's use "assert" here as "caller" is responsible for this check. + // Please refer to the header for details. + assert(tensor_shape->dim(axis).known()); + res *= tensor_shape->dim(axis).value(); + } + + return res; +} + +} // namespace loco diff --git a/compiler/loco/src/IR/TensorShape.test.cpp b/compiler/loco/src/IR/TensorShape.test.cpp new file mode 100644 index 000000000..ce03ccbd4 --- /dev/null +++ b/compiler/loco/src/IR/TensorShape.test.cpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/TensorShape.h" + +#include <gtest/gtest.h> + +TEST(TensorShapeTest, default_constructor) +{ + loco::TensorShape tensor_shape; + + ASSERT_EQ(tensor_shape.rank(), 0); +} + +TEST(TensorShapeTest, initializer_list_constructor) +{ + loco::TensorShape tensor_shape{3, 5}; + + ASSERT_EQ(tensor_shape.rank(), 2); + + ASSERT_TRUE(tensor_shape.dim(0).known()); + ASSERT_TRUE(tensor_shape.dim(1).known()); + + ASSERT_EQ(tensor_shape.dim(0).value(), 3); + ASSERT_EQ(tensor_shape.dim(1).value(), 5); +} + +TEST(TensorShapeTest, rank) +{ + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + + ASSERT_EQ(tensor_shape.rank(), 2); + ASSERT_FALSE(tensor_shape.dim(0).known()); + ASSERT_FALSE(tensor_shape.dim(1).known()); +} + +TEST(TensorShapeTest, dim) +{ + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + + tensor_shape.dim(0) = 3; + + ASSERT_TRUE(tensor_shape.dim(0).known()); + ASSERT_FALSE(tensor_shape.dim(1).known()); + + ASSERT_EQ(tensor_shape.dim(0), 3); +} + +TEST(TensorShapeTest, rank_update) +{ + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + + tensor_shape.dim(1) = 3; + + tensor_shape.rank(4); + + ASSERT_FALSE(tensor_shape.dim(0).known()); + ASSERT_TRUE(tensor_shape.dim(1).known()); + ASSERT_FALSE(tensor_shape.dim(2).known()); + ASSERT_FALSE(tensor_shape.dim(3).known()); + + ASSERT_EQ(tensor_shape.dim(1), 3); +} + +TEST(TensorShapeTest, copy) +{ + loco::TensorShape src; + + src.rank(2); + src.dim(1) = 3; + + loco::TensorShape dst; + + dst = src; + + ASSERT_EQ(dst.rank(), 2); + + ASSERT_FALSE(dst.dim(0).known()); + ASSERT_TRUE(dst.dim(1).known()); + + ASSERT_EQ(dst.dim(1), 3); +} + +TEST(TensorShapeTest, element_count) +{ + // Check Rank-0 case + loco::TensorShape src; + + ASSERT_EQ(loco::element_count(&src), 1); +} diff --git a/compiler/loco/src/IR/Use.cpp b/compiler/loco/src/IR/Use.cpp new file mode 100644 index 000000000..fed562c65 --- /dev/null +++ b/compiler/loco/src/IR/Use.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Use.h" +#include "loco/IR/Node.h" + +#include <cassert> + +namespace loco +{ + +void Use::node(Node *node) +{ + if (_node != nullptr) + { + assert(_node->_uses.find(this) != _node->_uses.end()); + _node->_uses.erase(this); + _node = nullptr; + } + + assert(_node == nullptr); + + if (node != nullptr) + { + _node = node; + _node->_uses.insert(this); + } + + assert(_node == node); +} + +} // namespace loco diff --git a/compiler/loco/src/IR/Use.test.cpp b/compiler/loco/src/IR/Use.test.cpp new file mode 100644 index 000000000..4a2f1cc25 --- /dev/null +++ b/compiler/loco/src/IR/Use.test.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Use.h" + +#include "MockupNode.h" + +#include <gtest/gtest.h> + +TEST(UseTest, constructor) +{ + MockupNode user; + loco::Use use{&user}; + + ASSERT_EQ(use.user(), &user); + ASSERT_EQ(use.node(), nullptr); +} + +TEST(UseTest, link_node) +{ + MockupNode def; + MockupNode user; + loco::Use use{&user}; + + use.node(&def); + + ASSERT_EQ(use.user(), &user); + ASSERT_EQ(use.node(), &def); +} diff --git a/compiler/loco/src/IR/Verifier.cpp b/compiler/loco/src/IR/Verifier.cpp new file mode 100644 index 000000000..42735a327 --- /dev/null +++ b/compiler/loco/src/IR/Verifier.cpp @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Verifier.h" + +#include <set> +#include <cassert> + +namespace +{ + +using namespace loco; + +struct GraphVerifier final +{ +public: + GraphVerifier(loco::Graph *graph) : _graph{graph} + { + // graph SHOULD NOT BE null + assert(_graph != nullptr); + } + +public: + // ErrorListener SHOULD outlive GraphVerifier + GraphVerifier &enroll(ErrorListener *l) + { + if (l != nullptr) + { + _listeners.insert(l); + } + return (*this); + } + + GraphVerifier &enroll(std::unique_ptr<ErrorListener> &&l) + { + if (l != nullptr) + { + _listeners.insert(l.get()); + // Take the ownership of a given listener + _owned_listeners.insert(std::move(l)); + } + return (*this); + } + +public: + void run(void) const + { + for (auto node : loco::all_nodes(_graph)) + { + // Verify nodes + for (uint32_t n = 0; n < node->arity(); ++n) + { + if (node->arg(n) == nullptr) + { + notify(ErrorDetail<ErrorCategory::MissingArgument>{node, n}); + } + } + } + } + +private: + template <typename Error> void notify(const Error &error) const + { + for (const auto &listener : _listeners) + { + listener->notify(error); + } + } + +private: + loco::Graph *_graph = nullptr; + + // All active error listeners + std::set<ErrorListener *> _listeners; + + // Owned error listeners + std::set<std::unique_ptr<ErrorListener>> _owned_listeners; +}; + +inline GraphVerifier graph_verifier(loco::Graph *graph) { return GraphVerifier{graph}; } + +} // namespace + +namespace loco +{ + +bool valid(Graph *g, std::unique_ptr<ErrorListener> &&l) +{ + class ErrorCounter final : public ErrorListener + { + public: + uint32_t count(void) const { return _count; } + + public: + void notify(const ErrorDetail<ErrorCategory::MissingArgument> &) { _count += 1; } + + private: + uint32_t _count = 0; + }; + + ErrorCounter counter; + graph_verifier(g).enroll(&counter).enroll(std::move(l)).run(); + return counter.count() == 0; +} + +} // namespace loco diff --git a/compiler/loco/src/IR/Verifier.test.cpp b/compiler/loco/src/IR/Verifier.test.cpp new file mode 100644 index 000000000..247a59390 --- /dev/null +++ b/compiler/loco/src/IR/Verifier.test.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Verifier.h" + +#include <gtest/gtest.h> + +#include <stdex/Memory.h> +#include <vector> + +using stdex::make_unique; + +TEST(VerifierTest, valid_minimal) +{ + auto g = loco::make_graph(); + auto push = g->nodes()->create<loco::Push>(); + + ASSERT_FALSE(loco::valid(g.get())); +} + +TEST(VerifierTest, valid_error_reporter) +{ + using namespace loco; + + auto g = loco::make_graph(); + auto push = g->nodes()->create<loco::Push>(); + + class Collector final : public loco::ErrorListener + { + public: + Collector(std::vector<ErrorDetail<ErrorCategory::MissingArgument>> *out) : _out{out} + { + // DO NOTHING + } + + public: + void notify(const ErrorDetail<ErrorCategory::MissingArgument> &d) override + { + _out->emplace_back(d); + } + + private: + std::vector<ErrorDetail<ErrorCategory::MissingArgument>> *_out; + }; + + std::vector<ErrorDetail<ErrorCategory::MissingArgument>> errors; + ASSERT_FALSE(loco::valid(g.get(), make_unique<Collector>(&errors))); + ASSERT_EQ(errors.size(), 1); + ASSERT_EQ(errors.at(0).node(), push); + ASSERT_EQ(errors.at(0).index(), 0); +} diff --git a/compiler/loco/src/IR/Window.test.cpp b/compiler/loco/src/IR/Window.test.cpp new file mode 100644 index 000000000..c112e0f96 --- /dev/null +++ b/compiler/loco/src/IR/Window.test.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/IR/Window.h" + +#include <gtest/gtest.h> + +TEST(WindowTest, default_constructor_2D) +{ + loco::Window<2> window; + + ASSERT_EQ(window.vertical(), 1); + ASSERT_EQ(window.horizontal(), 1); +} + +TEST(WindowTest, setter_and_getter_2D) +{ + loco::Window<2> window; + + window.vertical(2); + + ASSERT_EQ(window.vertical(), 2); + ASSERT_EQ(window.horizontal(), 1); + + window.horizontal(3); + + ASSERT_EQ(window.vertical(), 2); + ASSERT_EQ(window.horizontal(), 3); +} diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp new file mode 100644 index 000000000..d30a8279a --- /dev/null +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -0,0 +1,774 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/CanonicalShapeInferenceRule.h" +#include "loco/Service/ShapeInference.h" + +#include <loco/IR/CanonicalDialect.h> +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalNodeVisitor.h> + +#include <cassert> + +namespace +{ + +struct PlaneShape +{ + loco::Dimension height; + loco::Dimension width; +}; + +PlaneShape make_plane_shape(const loco::FeatureShape &feature_shape) +{ + PlaneShape plane_shape; + + plane_shape.height = feature_shape.height(); + plane_shape.width = feature_shape.width(); + + return plane_shape; +} + +class FeatureShapeUpdater final +{ +public: + FeatureShapeUpdater(loco::FeatureShape *ptr) : _feature_shape_ptr{ptr} + { + // DO NOTHING + } + +public: + void with(const PlaneShape &plane_shape) const + { + _feature_shape_ptr->height() = plane_shape.height; + _feature_shape_ptr->width() = plane_shape.width; + } + +private: + loco::FeatureShape *_feature_shape_ptr; +}; + +/** + * HOW TO USE + * + * loco::FeatureShape feature_shape = ...; + * + * update(feature_shape).with(...) + */ +FeatureShapeUpdater update(loco::FeatureShape &feature_shape) +{ + return FeatureShapeUpdater{&feature_shape}; +} + +loco::Window<2> window_of(const loco::FilterShape &filter_shape) +{ + loco::Window<2> window; + + window.vertical(filter_shape.height().value()); + window.horizontal(filter_shape.width().value()); + + return window; +} + +loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_shape) +{ + loco::Window<2> window; + + window.vertical(depthwise_filter_shape.height().value()); + window.horizontal(depthwise_filter_shape.width().value()); + + return window; +} + +enum class Direction +{ + Forward, + Backward, +}; + +template <Direction> class PlaneInference; + +template <> class PlaneInference<Direction::Forward> final +{ +public: + PlaneShape operator()(const PlaneShape &in) const + { + assert(_pad != nullptr); + assert(_window != nullptr); + assert(_stride != nullptr); + + uint32_t const raw_input_height = in.height.value(); + uint32_t const raw_input_width = in.width.value(); + + uint32_t const raw_window_height = _window->vertical(); + uint32_t const raw_window_width = _window->horizontal(); + + uint32_t const vertical_padding = _pad->top() + _pad->bottom(); + uint32_t const horizontal_padding = _pad->left() + _pad->right(); + + uint32_t const effective_input_height = raw_input_height + vertical_padding; + uint32_t const effective_input_width = raw_input_width + horizontal_padding; + + // NOTE To support "dilation" later + uint32_t const effective_window_height = raw_window_height; + uint32_t const effective_window_width = raw_window_width; + + uint32_t const vertical_stride = _stride->vertical(); + uint32_t const horizontal_stride = _stride->horizontal(); + + assert((effective_input_height - effective_window_height) % vertical_stride == 0); + assert((effective_input_width - effective_window_width) % horizontal_stride == 0); + + PlaneShape res; + + res.height = (effective_input_height - effective_window_height) / vertical_stride + 1; + res.width = (effective_input_width - effective_window_width) / horizontal_stride + 1; + + return res; + } + +public: + void pad(const loco::Padding2D *value) { _pad = value; } + void window(const loco::Window<2> *value) { _window = value; } + void stride(const loco::Stride<2> *value) { _stride = value; } + +private: + const loco::Padding2D *_pad = nullptr; + const loco::Window<2> *_window = nullptr; + const loco::Stride<2> *_stride = nullptr; +}; + +template <> class PlaneInference<Direction::Backward> final +{ +public: + PlaneShape operator()(const PlaneShape &in) const + { + assert(_pad != nullptr); + assert(_window != nullptr); + assert(_stride != nullptr); + + uint32_t const input_height = in.height.value(); + uint32_t const input_width = in.width.value(); + + uint32_t const vertical_padding = _pad->top() + _pad->bottom(); + uint32_t const horizontal_padding = _pad->left() + _pad->right(); + + uint32_t const raw_window_height = _window->vertical(); + uint32_t const raw_window_width = _window->horizontal(); + + // TODO Support "dilation" + uint32_t const effective_window_height = raw_window_height; + uint32_t const effective_window_width = raw_window_width; + + uint32_t const vertical_stride = _stride->vertical(); + uint32_t const horizontal_stride = _stride->horizontal(); + + PlaneShape res; + + res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding; + res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding; + + return res; + } + +public: + void pad(const loco::Padding2D *value) { _pad = value; } + void window(const loco::Window<2> *value) { _window = value; } + void stride(const loco::Stride<2> *value) { _stride = value; } + +private: + const loco::Padding2D *_pad = nullptr; + const loco::Window<2> *_window = nullptr; + const loco::Stride<2> *_stride = nullptr; +}; + +/** + * There are two possible maintenance policies. + * - Introduce a new canonical node first, and then extend this algorithm later + * - Introduce a new canonical node and extend this algorithm at the same time + * + * The current implementation assumes the former one (for historical reason). + * + * TODO Evaluate the impact of the latter one + * + * NOTE "Forward" means that this algorithm computes the ouput shape from inputs shapes + */ +class ForwardShapeInferenceAlgorithm final : public loco::CanonicalNodeVisitor<loco::NodeShape> +{ +public: + ForwardShapeInferenceAlgorithm(const loco::ShapeInferenceRule::Context *ctx) : _ctx{ctx} + { + // DO NOTHING + } + +private: + const loco::ShapeInferenceRule::Context *_ctx; + +private: + bool shape_known(const loco::Node *node) const { return _ctx->known(node); } + loco::NodeShape node_shape(const loco::Node *node) const { return _ctx->get(node); } + +private: + loco::NodeShape eltwise_binary_node_shape(const loco::Node *node) + { + // This helper works only for binary node. + assert(node->arity() == 2); + + auto lhs_shape = node_shape(node->arg(0)); + auto rhs_shape = node_shape(node->arg(1)); + + // ASSERT: lhs_shape == rhs_shape + + return lhs_shape; + } + +public: + // CASE: AvgPool2D + loco::NodeShape visit(const loco::AvgPool2D *node) final + { + PlaneInference<Direction::Forward> infer_plane_shape; + + infer_plane_shape.pad(node->pad()); + infer_plane_shape.window(node->window()); + infer_plane_shape.stride(node->stride()); + + auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>(); + auto input_plane_shape = make_plane_shape(input_feature_shape); + auto output_plane_shape = infer_plane_shape(input_plane_shape); + auto output_feature_shape = input_feature_shape; // AvgPool2D does not change count/depth + + // Update the height/width of output_feature_shape with that of output_plane_shape + update(output_feature_shape).with(output_plane_shape); + + return loco::NodeShape{output_feature_shape}; + } + + // CASE: BiasDecode + loco::NodeShape visit(const loco::BiasDecode *node) final + { + // The input of BiasDecode SHOULD BE a bias! + assert(node_shape(node->input()).domain() == loco::Domain::Bias); + auto input_bias_shape = node_shape(node->input()).as<loco::BiasShape>(); + + loco::TensorShape output_tensor_shape; + + output_tensor_shape.rank(1); + output_tensor_shape.dim(0) = input_bias_shape.length(); + + return loco::NodeShape{output_tensor_shape}; + } + + // CASE: BiasEncode + loco::NodeShape visit(const loco::BiasEncode *node) final + { + // The input of BiasEncode SHOULD BE a tensor! + assert(node_shape(node->input()).domain() == loco::Domain::Tensor); + auto input_tensor_shape = node_shape(node->input()).as<loco::TensorShape>(); + + loco::BiasShape output_bias_shape; + + output_bias_shape.length() = input_tensor_shape.dim(0); + + return loco::NodeShape{output_bias_shape}; + } + + // CASE: ConstGen + loco::NodeShape visit(const loco::ConstGen *node) final + { + loco::TensorShape tensor_shape; + + tensor_shape.rank(node->rank()); + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + tensor_shape.dim(axis) = node->dim(axis); + } + + return loco::NodeShape{tensor_shape}; + } + + // CASE: Conv2D + loco::NodeShape visit(const loco::Conv2D *node) final + { + auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>(); + auto filter_window = window_of(filter_shape); + + PlaneInference<Direction::Forward> infer_plane_shape; + + infer_plane_shape.pad(node->pad()); + infer_plane_shape.window(&filter_window); + infer_plane_shape.stride(node->stride()); + + auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>(); + auto input_plane_shape = make_plane_shape(input_feature_shape); + auto output_plane_shape = infer_plane_shape(input_plane_shape); + + loco::FeatureShape output_feature_shape; + + // "COUNT" does not change + output_feature_shape.count() = input_feature_shape.count(); + // "DEPTH" depends on # of filters + output_feature_shape.depth() = filter_shape.count(); + // Update the height/width of output_feature_shape with that of output_plane_shape + update(output_feature_shape).with(output_plane_shape); + + return loco::NodeShape{output_feature_shape}; + } + + // CASE: DepthwiseConv2D + loco::NodeShape visit(const loco::DepthwiseConv2D *node) final + { + auto depthwise_filter_shape = node_shape(node->ker()).as<loco::DepthwiseFilterShape>(); + auto dpethwise_filter_window = window_of(depthwise_filter_shape); + + PlaneInference<Direction::Forward> infer_plane_shape; + + infer_plane_shape.pad(node->pad()); + infer_plane_shape.window(&dpethwise_filter_window); + infer_plane_shape.stride(node->stride()); + + auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>(); + auto input_plane_shape = make_plane_shape(input_feature_shape); + auto output_plane_shape = infer_plane_shape(input_plane_shape); + + loco::FeatureShape output_feature_shape; + + // "COUNT" does not change + output_feature_shape.count() = input_feature_shape.count(); + // "DEPTH" depends on [in_channels * channel_multiplier] of filters + output_feature_shape.depth() = loco::Dimension(depthwise_filter_shape.depth().value() * + depthwise_filter_shape.multiplier().value()); + // Update the height/width of output_feature_shape with that of output_plane_shape + update(output_feature_shape).with(output_plane_shape); + + return loco::NodeShape{output_feature_shape}; + } + + // CASE: DepthwiseFilterEncode + loco::NodeShape visit(const loco::DepthwiseFilterEncode *node) final + { + auto input_tensor_shape = node_shape(node->input()).as<loco::TensorShape>(); + return loco::NodeShape{node->encoder()->shape(input_tensor_shape)}; + } + + // CASE: DepthwiseFilterDecode + loco::NodeShape visit(const loco::DepthwiseFilterDecode *node) final + { + auto input_dw_filter_shape = node_shape(node->input()).as<loco::DepthwiseFilterShape>(); + return loco::NodeShape{node->decoder()->shape(input_dw_filter_shape)}; + } + + // CASE: EltwiseAdd + loco::NodeShape visit(const loco::EltwiseAdd *node) final + { + return eltwise_binary_node_shape(node); + } + + // CASE: EltwiseDiv + loco::NodeShape visit(const loco::EltwiseDiv *node) final + { + return eltwise_binary_node_shape(node); + } + + // CASE: EltwiseMax + loco::NodeShape visit(const loco::EltwiseMax *node) final + { + return eltwise_binary_node_shape(node); + } + + // CASE: EltwiseMul + loco::NodeShape visit(const loco::EltwiseMul *node) final + { + return eltwise_binary_node_shape(node); + } + + // CASE: EltwiseSqrt + loco::NodeShape visit(const loco::EltwiseSqrt *node) final { return node_shape(node->input()); } + + // CASE: EltwiseSub + loco::NodeShape visit(const loco::EltwiseSub *node) final + { + return eltwise_binary_node_shape(node); + } + + // CASE: Forward + loco::NodeShape visit(const loco::Forward *node) final { return node_shape(node->input()); } + + // CASE: FeatureBiasAdd + loco::NodeShape visit(const loco::FeatureBiasAdd *node) final + { + assert(node_shape(node->value()).domain() == loco::Domain::Feature); + assert(node_shape(node->bias()).domain() == loco::Domain::Bias); + + // Q. What to do when there is a mismatch between value's depth and bias's length? + + return node_shape(node->value()); + } + + // CASE: FeatureDecode + loco::NodeShape visit(const loco::FeatureDecode *node) final + { + auto input_node_shape = node_shape(node->input()); + return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::FeatureShape>())}; + } + + // CASE: FeatureEncode + loco::NodeShape visit(const loco::FeatureEncode *node) final + { + auto input_node_shape = node_shape(node->input()); + return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())}; + } + + // CASE: FilterDecode + loco::NodeShape visit(const loco::FilterDecode *node) final + { + auto input_filter_shape = node_shape(node->input()).as<loco::FilterShape>(); + return loco::NodeShape{node->decoder()->shape(input_filter_shape)}; + } + + // CASE: FilterEncode + loco::NodeShape visit(const loco::FilterEncode *node) final + { + auto input_tensor_shape = node_shape(node->input()).as<loco::TensorShape>(); + return loco::NodeShape{node->encoder()->shape(input_tensor_shape)}; + } + + // CASE: FixedReshape + loco::NodeShape visit(const loco::FixedReshape *node) final + { + loco::TensorShape tensor_shape; + + tensor_shape.rank(node->rank()); + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + tensor_shape.dim(axis) = node->dim(axis); + } + + return loco::NodeShape{tensor_shape}; + } + + // CASE: MatMul + loco::NodeShape visit(const loco::MatMul *node) final + { + assert(shape_known(node->lhs())); + assert(shape_known(node->rhs())); + auto const lhs_shape = node_shape(node->lhs()).as<loco::MatrixShape>(); + auto const rhs_shape = node_shape(node->rhs()).as<loco::MatrixShape>(); + + loco::MatrixShape out_shape; + + // Checking shape capability for multiplication + assert(lhs_shape.width() == rhs_shape.height()); + + out_shape.height() = lhs_shape.height(); + out_shape.width() = rhs_shape.width(); + + return out_shape; + } + + // CASE: MatrixDecode + loco::NodeShape visit(const loco::MatrixDecode *node) final + { + auto input_node_shape = node_shape(node->input()); + return loco::NodeShape{node->decoder()->shape(input_node_shape.as<loco::MatrixShape>())}; + } + + // CASE: MatrixEncode + loco::NodeShape visit(const loco::MatrixEncode *node) final + { + auto input_node_shape = node_shape(node->input()); + return loco::NodeShape{node->encoder()->shape(input_node_shape.as<loco::TensorShape>())}; + } + + // CASE: MaxPool2D + loco::NodeShape visit(const loco::MaxPool2D *node) final + { + PlaneInference<Direction::Forward> infer_plane_shape; + + infer_plane_shape.pad(node->pad()); + infer_plane_shape.window(node->window()); + infer_plane_shape.stride(node->stride()); + + auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>(); + auto input_plane_shape = make_plane_shape(input_feature_shape); + auto output_plane_shape = infer_plane_shape(input_plane_shape); + auto output_feature_shape = input_feature_shape; // MaxPool2D does not change count/depth + + // Update the height/width of output_feature_shape with that of output_plane_shape + update(output_feature_shape).with(output_plane_shape); + + return loco::NodeShape{output_feature_shape}; + } + + // CASE: Push + loco::NodeShape visit(const loco::Push *node) final + { + assert(shape_known(node->from())); + return node_shape(node->from()); + } + + // CASE: Pull + loco::NodeShape visit(const loco::Pull *node) final + { + // Build a tensor shape from "Pull" node + loco::TensorShape tensor_shape; + + tensor_shape.rank(node->rank()); + for (uint32_t axis = 0; axis < node->rank(); ++axis) + { + tensor_shape.dim(axis) = node->dim(axis); + } + + return loco::NodeShape{tensor_shape}; + } + + // CASE: ReLU + loco::NodeShape visit(const loco::ReLU *node) final { return node_shape(node->input()); } + + // CASE: ReLU6 + loco::NodeShape visit(const loco::ReLU6 *node) final { return node_shape(node->input()); } + + // CASE: Tanh + loco::NodeShape visit(const loco::Tanh *node) final { return node_shape(node->input()); } + + // CASE: TensorBiasAdd + loco::NodeShape visit(const loco::TensorBiasAdd *node) final + { + assert(node_shape(node->value()).domain() == loco::Domain::Tensor); + assert(node_shape(node->bias()).domain() == loco::Domain::Bias); + + // Q. What to do when there is a mismatch between value's dim and bias's length? + + return node_shape(node->value()); + } + + // CASE: TensorConcat + loco::NodeShape visit(const loco::TensorConcat *node) + { + auto const lhs_shape = node_shape(node->lhs()).as<loco::TensorShape>(); + auto const rhs_shape = node_shape(node->rhs()).as<loco::TensorShape>(); + + assert(lhs_shape.rank() == rhs_shape.rank()); + uint32_t const out_rank = lhs_shape.rank(); + + loco::TensorShape out_shape; + + out_shape.rank(out_rank); + + for (uint32_t axis = 0; axis < out_rank; ++axis) + { + if (axis == node->axis()) + { + out_shape.dim(axis) = lhs_shape.dim(axis).value() + rhs_shape.dim(axis).value(); + } + else + { + assert(lhs_shape.dim(axis) == rhs_shape.dim(axis)); + out_shape.dim(axis) = lhs_shape.dim(axis); + } + } + + return loco::NodeShape{out_shape}; + } + + // CASE: TensorBroadcast + loco::NodeShape visit(const loco::TensorBroadcast *node) final + { + auto tensor_shape = node_shape(node->input()).as<loco::TensorShape>(); + auto const tensor_rank = tensor_shape.rank(); + + for (uint32_t axis = 0; axis < tensor_rank; ++axis) + { + if (node->mapping()->defined(axis)) + { + tensor_shape.dim(axis) = node->mapping()->dim(axis); + } + } + + return loco::NodeShape{tensor_shape}; + } + + // CASE: TensorReduce + loco::NodeShape visit(const loco::TensorReduce *node) final + { + auto tensor_shape = node_shape(node->input()).as<loco::TensorShape>(); + auto const tensor_rank = tensor_shape.rank(); + + for (uint32_t d = 0; d < tensor_rank; ++d) + if (node->axes()->defined(d)) + tensor_shape.dim(d) = 1; + + return loco::NodeShape{tensor_shape}; + } + + // CASE: TensorSoftmax + loco::NodeShape visit(const loco::TensorSoftmax *node) final { return node_shape(node->input()); } + + // CASE: TensorTranspose + loco::NodeShape visit(const loco::TensorTranspose *node) final + { + loco::TensorShape output_shape; + + auto input_shape = node_shape(node->input()).as<loco::TensorShape>(); + assert(input_shape.rank() == node->perm()->size()); + + output_shape.rank(input_shape.rank()); + + for (uint32_t output_axis = 0; output_axis < output_shape.rank(); output_axis++) + { + auto new_dim = input_shape.dim(node->perm()->axis(output_axis)); + output_shape.dim(output_axis) = new_dim; + } + + return loco::NodeShape(output_shape); + } + + // CASE: TransposedConv2D + loco::NodeShape visit(const loco::TransposedConv2D *node) final + { + auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>(); + auto filter_window = window_of(filter_shape); + + PlaneInference<Direction::Backward> infer_plane_shape; + + infer_plane_shape.pad(node->pad()); + infer_plane_shape.window(&filter_window); + infer_plane_shape.stride(node->stride()); + + auto input_feature_shape = node_shape(node->ifm()).as<loco::FeatureShape>(); + auto input_plane_shape = make_plane_shape(input_feature_shape); + auto output_plane_shape = infer_plane_shape(input_plane_shape); + + loco::FeatureShape output_feature_shape; + + // "COUNT" does not change + output_feature_shape.count() = input_feature_shape.count(); + // Output "DEPTH" depends on count of filters + output_feature_shape.depth() = filter_shape.count(); + // Update the height/width of output_feature_shape with that of output_plane_shape + update(output_feature_shape).with(output_plane_shape); + + return loco::NodeShape{output_feature_shape}; + } + + // CASE: TensorConstantPad + loco::NodeShape visit(const loco::TensorConstantPad *node) final + { + auto const tensor_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto padding = node->padding(); + + loco::TensorShape out_shape; + out_shape.rank(tensor_shape.rank()); + for (uint32_t axis = 0; axis < out_shape.rank(); ++axis) + { + out_shape.dim(axis) = + tensor_shape.dim(axis).value() + padding->front(axis) + padding->back(axis); + } + + return loco::NodeShape{out_shape}; + } +}; + +} // namespace + +namespace +{ +namespace compat +{ + +struct Context final : public loco::ShapeInferenceRule::Context +{ + bool known(const loco::Node *node) const final { return loco::shape_known(node); } + loco::NodeShape get(const loco::Node *node) const final { return loco::shape_get(node); } +}; + +class Sink final : public loco::ShapeInferenceRule::Sink +{ +public: + enum Status + { + Unknown, + Okay, + Fail, + }; + +public: + const Status &status(void) const { return _status; } + const loco::NodeShape &shape(void) const { return _shape; } + +public: + void okay(const loco::NodeShape &shape) final + { + _status = Okay; + _shape = shape; + } + + void fail(void) final + { + // Notify failure + _status = Fail; + } + +private: + Status _status = Unknown; + loco::NodeShape _shape; +}; + +} // namespace compat +} // namespace + +namespace loco +{ + +bool CanonicalShapeInferenceRule::support(const API &api) const +{ + return api == API::V1 or api == API::V2; +} + +bool CanonicalShapeInferenceRule::recognize(const Dialect *d) const +{ + return CanonicalDialect::get() == d; +} + +bool CanonicalShapeInferenceRule::infer(const Node *node, NodeShape &shape) const +{ + ::compat::Context ctx; + ::compat::Sink sink; + + infer(&ctx, node, &sink); + + assert(sink.status() == ::compat::Sink::Okay or sink.status() == ::compat::Sink::Fail); + + if (sink.status() == ::compat::Sink::Fail) + { + return false; + } + + shape = sink.shape(); + return true; +} + +void CanonicalShapeInferenceRule::infer(const Context *ctx, const Node *node, Sink *sink) const +{ + assert(node->dialect() == loco::CanonicalDialect::get()); + assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr); + + ForwardShapeInferenceAlgorithm alg{ctx}; + auto shape = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg); + + sink->okay(shape); +} + +} // namespace loco diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp new file mode 100644 index 000000000..5cc8c3808 --- /dev/null +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.test.cpp @@ -0,0 +1,400 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/CanonicalShapeInferenceRule.h" +#include "loco/Service/ShapeInference.h" + +#include "GraphTestcase.h" + +#include <vector> + +#include <gtest/gtest.h> + +TEST(CanonicalShapeInferenceRuleTest, minimal) +{ + // Create a simple identity network, which takes Tensor<1x2x3x4> as input. + GraphTestcase<GraphCode::Identity> testcase{1, 2, 3, 4}; + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4); +} + +TEST(CanonicalShapeInferenceRuleTest, const_gen) +{ + // Create a sample network + GraphTestcase<GraphCode::ConstGen> testcase; + + testcase.const_node->dtype(loco::DataType::FLOAT32); + testcase.const_node->shape({1, 2}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2); +} + +TEST(CanonicalShapeInferenceRuleTest, relu) +{ + // Create a sample network + GraphTestcase<GraphCode::Relu> testcase; + + testcase.pull_node->shape({1, 2, 3, 4}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(3), 4); +} + +TEST(CanonicalShapeInferenceRuleTest, feature_codec) +{ + // Create a sample network + GraphTestcase<GraphCode::FeatureCodec> testcase; + + testcase.pull_node->shape({1, 2, 3, 4}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.encode_node)); + ASSERT_EQ(loco::shape_get(testcase.encode_node).domain(), loco::Domain::Feature); + + ASSERT_TRUE(loco::shape_known(testcase.decode_node)); + ASSERT_EQ(loco::shape_get(testcase.decode_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().rank(), 4); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(1), 2); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(2), 3); + ASSERT_EQ(loco::shape_get(testcase.decode_node).as<loco::TensorShape>().dim(3), 4); +} + +TEST(CanonicalShapeInferenceRuleTest, avgpool2d) +{ + using namespace loco; + + // Create a sample network + GraphTestcase<GraphCode::AvgPool2D> testcase; + + auto perm = make_NHWC_perm<Domain::Feature>(); + + testcase.pull_node->shape({1, 8, 4, 3}); + + testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm)); + + testcase.avgpool2d_node->window()->vertical(2); + testcase.avgpool2d_node->window()->horizontal(2); + + testcase.avgpool2d_node->stride()->vertical(2); + testcase.avgpool2d_node->stride()->horizontal(2); + + testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm)); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + // + // NOTE AvgPool2D testcase assumes NHWC layout + ASSERT_TRUE(loco::shape_known(testcase.avgpool2d_node)); + ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).domain(), loco::Domain::Feature); + ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().count(), 1); + ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().depth(), 3); + ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().height(), 4); + ASSERT_EQ(loco::shape_get(testcase.avgpool2d_node).as<FeatureShape>().width(), 2); +} + +TEST(CanonicalShapeInferenceRuleTest, depthwiseconv2d) +{ + using namespace loco; + + // Create a sample network + GraphTestcase<GraphCode::DepthwiseConv2D> testcase; + + testcase.pull_node->shape({1, 4, 4, 3}); + + testcase.const_node->dtype(loco::DataType::FLOAT32); + testcase.const_node->shape({2, 2, 3, 2}); + + testcase.depthwiseconv2d_node->stride()->vertical(1); + testcase.depthwiseconv2d_node->stride()->horizontal(1); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + // + // NOTE DepthwiseConv2D testcase assumes NHWC layout + ASSERT_TRUE(loco::shape_known(testcase.depthwiseconv2d_node)); + ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).domain(), loco::Domain::Feature); + ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().count(), 1); + ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().depth(), 6); + ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().height(), 3); + ASSERT_EQ(loco::shape_get(testcase.depthwiseconv2d_node).as<FeatureShape>().width(), 3); +} + +TEST(CanonicalShapeInferenceRuleTest, transposedconv2d) +{ + using namespace loco; + + // Create a sample network + GraphTestcase<GraphCode::TransposedConv2D> testcase; + + testcase.pull_node->shape({1, 270, 480, 24}); // NHWC + + testcase.const_node->dtype(loco::DataType::FLOAT32); + testcase.const_node->shape({3, 3, 24, 12}); // HWCN (or HWIO) + + testcase.tr_conv2d_node->stride()->vertical(2); + testcase.tr_conv2d_node->stride()->horizontal(2); + + testcase.tr_conv2d_node->pad()->top(0); + testcase.tr_conv2d_node->pad()->bottom(1); + testcase.tr_conv2d_node->pad()->left(0); + testcase.tr_conv2d_node->pad()->right(1); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.tr_conv2d_node)); + ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).domain(), loco::Domain::Feature); + ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().count(), 1); + ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().height(), 540); + ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().width(), 960); + ASSERT_EQ(loco::shape_get(testcase.tr_conv2d_node).as<FeatureShape>().depth(), 12); +} + +TEST(CanonicalShapeInferenceRuleTest, maxpool2d) +{ + using namespace loco; + + // Create a sample network + GraphTestcase<GraphCode::MaxPool2D> testcase; + + auto perm = make_NHWC_perm<Domain::Feature>(); + + testcase.pull_node->shape({1, 8, 4, 3}); + + testcase.encode_node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm)); + + testcase.maxpool2d_node->window()->vertical(2); + testcase.maxpool2d_node->window()->horizontal(2); + + testcase.maxpool2d_node->stride()->vertical(2); + testcase.maxpool2d_node->stride()->horizontal(2); + + testcase.decode_node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm)); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + // + // NOTE MaxPool2D testcase assumes NHWC layout + ASSERT_TRUE(loco::shape_known(testcase.maxpool2d_node)); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).domain(), loco::Domain::Feature); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().count(), 1); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().depth(), 3); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().height(), 4); + ASSERT_EQ(loco::shape_get(testcase.maxpool2d_node).as<FeatureShape>().width(), 2); +} + +TEST(CanonicalShapeInferenceRuleTest, tensor_concat) +{ + using namespace loco; + + // Create a sample network + GraphTestcase<GraphCode::TensorConcat> testcase; + + testcase.lhs_node->shape({1, 2, 3}); + testcase.rhs_node->shape({1, 4, 3}); + testcase.concat_node->axis(1); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.concat_node)); + ASSERT_EQ(loco::shape_get(testcase.concat_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().rank(), 3); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(0), 1); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(1), 6); + ASSERT_EQ(loco::shape_get(testcase.concat_node).as<TensorShape>().dim(2), 3); +} + +TEST(CanonicalShapeInferenceRuleTest, fixed_reshape) +{ + // Create a sample network + GraphTestcase<GraphCode::FixedReshape> testcase; + + testcase.pull_node->shape({6, 6}); + testcase.reshape_node->shape({4, 9}); + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 9); +} + +TEST(CanonicalShapeInferenceRuleTest, tensor_broadcast) +{ + // Create a sample network + GraphTestcase<GraphCode::TensorBroadcast> testcase{1, 2}; + + testcase.broadcast_node->mapping()->dim(0) = 4; + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(testcase.push_node)); + ASSERT_EQ(loco::shape_get(testcase.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(loco::shape_get(testcase.push_node).as<loco::TensorShape>().dim(1), 2); +} + +TEST(CanonicalShapeInferenceRuleTest, tensor_transpose) +{ + // Create a sample network + GraphTestcase<GraphCode::TensorTranspose> tc; + + tc.pull_node->shape({10, 20, 30, 40}); + + tc.transpose_node->perm()->size(4); + tc.transpose_node->perm()->axis(0) = 2; + tc.transpose_node->perm()->axis(1) = 3; + tc.transpose_node->perm()->axis(2) = 0; + tc.transpose_node->perm()->axis(3) = 1; + + // Run Inference + loco::CanonicalShapeInferenceRule rule; + + loco::apply(&rule).to(tc.graph()); + + // Verify! + ASSERT_TRUE(loco::shape_known(tc.push_node)); + ASSERT_EQ(loco::shape_get(tc.push_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().rank(), 4); + ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(0), 30); + ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(1), 40); + ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(2), 10); + ASSERT_EQ(loco::shape_get(tc.push_node).as<loco::TensorShape>().dim(3), 20); +} + +namespace +{ + +struct MockContext final : public loco::ShapeInferenceRule::Context +{ + bool known(const loco::Node *node) const final { return _content.find(node) != _content.end(); } + loco::NodeShape get(const loco::Node *node) const final { return _content.at(node); } + + std::map<const loco::Node *, loco::NodeShape> _content; +}; + +struct MockSink final : public loco::ShapeInferenceRule::Sink +{ + void okay(const loco::NodeShape &res) final { shape = res; } + void fail(void) final { return; } + + loco::NodeShape shape; +}; + +} // namespace + +TEST(CanonicalShapeInferenceRuleTest, infer_v2) +{ + auto g = loco::make_graph(); + + // Create an incomplete graph + auto relu_1 = g->nodes()->create<loco::ReLU>(); + auto relu_2 = g->nodes()->create<loco::ReLU>(); + + relu_2->input(relu_1); + + // Set up Context + MockContext ctx; + + loco::TensorShape tensor_shape; + + tensor_shape.rank(2); + tensor_shape.dim(0) = 4; + tensor_shape.dim(1) = 5; + + ctx._content[relu_1] = tensor_shape; + + // Create a Sink + MockSink sink; + + loco::CanonicalShapeInferenceRule rule; + + rule.infer(&ctx, relu_2, &sink); + + ASSERT_EQ(sink.shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(sink.shape.as<loco::TensorShape>().dim(1), 5); +} diff --git a/compiler/loco/src/Service/GraphBuilder.h b/compiler/loco/src/Service/GraphBuilder.h new file mode 100644 index 000000000..71084673c --- /dev/null +++ b/compiler/loco/src/Service/GraphBuilder.h @@ -0,0 +1,547 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __GRAPH_BUILDER_H__ +#define __GRAPH_BUILDER_H__ + +// loco-internal headers +#include "loco/IR/Graph.h" + +// repo-internal headers +#include <stdex/Memory.h> + +// C++ standard headers +#include <stack> + +// +// This file includes a stack-based loco graph builder +// +// HOW TO USE +// +// loco::Graph *g = ... +// auto builder = make_graph_builder(g); +// +// builder->push<YourAwesomeLayer>(...); +// + +class GraphBuilder final +{ +public: + class Stack final + { + public: + Stack() = default; + + public: + loco::Node *top(void) const { return _content.top(); } + + public: + loco::Node *pop(void) + { + auto ret = top(); + _content.pop(); + return ret; + } + + public: + void push(loco::Node *node) { _content.push(node); } + + private: + std::stack<loco::Node *> _content; + }; + + class Context final + { + public: + Context(loco::Graph *graph) : _graph{graph} + { + // DO NOTHING + } + + public: + loco::Graph *graph(void) { return _graph; } + Stack *stack(void) { return &_stack; } + + private: + loco::Graph *_graph = nullptr; + Stack _stack; + }; + +public: + GraphBuilder(loco::Graph *graph) : _context{graph} + { + // DO NOTHING + } + +public: + // "Layer" is in theory a subgraph builder. + template <typename Layer, typename... Args> + auto push(Args &&... args) + -> decltype(static_cast<Layer *>(nullptr)->operator()(static_cast<Context *>(nullptr))) + { + Layer layer{std::forward<Args>(args)...}; + return layer(ctx()); + } + +public: + loco::Node *pop(void) { return ctx()->stack()->pop(); } + +private: + Context *ctx(void) { return &_context; } + +private: + Context _context; +}; + +static inline std::unique_ptr<GraphBuilder> make_graph_builder(loco::Graph *g) +{ + return stdex::make_unique<GraphBuilder>(g); +} + +// "InputLayer" creates both GraphInput and Pull node at once +struct InputLayer final +{ + class Return + { + public: + Return(loco::GraphInput *input, loco::Pull *node) : _input{input}, _node{node} + { + // DO NOTHING + } + + public: + loco::Pull *node(void) { return _node; } + + public: + Return *name(const std::string &value) + { + _input->name(value); + return this; + } + + public: + Return *shape(std::initializer_list<uint32_t> dims) + { + // TODO Uncomment this line when GraphInput is ready + // _graph_input->shape(dims) + _node->shape(dims); + return this; + } + + private: + loco::GraphInput *_input = nullptr; + loco::Pull *_node = nullptr; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto input_index = ctx->graph()->inputs()->size(); + auto graph_input = ctx->graph()->inputs()->create(); + + auto pull_node = ctx->graph()->nodes()->create<loco::Pull>(); + + pull_node->index(input_index); + + loco::link(graph_input, pull_node); + + ctx->stack()->push(pull_node); + + return stdex::make_unique<Return>(graph_input, pull_node); + } +}; + +// "OutputLayer" creates both GraphOutput and Push node at once. +struct OutputLayer final +{ + class Return + { + public: + Return(loco::GraphOutput *output, loco::Push *node) : _output{output}, _node{node} + { + // DO NOTHING + } + + public: + loco::Push *node(void) { return _node; } + + public: + Return *name(const std::string &value) + { + // TODO Uncomment this line when GraphOutput is ready + // _graph_output->shape(dims) + _output->name(value); + return this; + } + + private: + loco::GraphOutput *_output = nullptr; + loco::Push *_node = nullptr; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto output_index = ctx->graph()->outputs()->size(); + auto graph_output = ctx->graph()->outputs()->create(); + + auto push_node = ctx->graph()->nodes()->create<loco::Push>(); + + push_node->from(ctx->stack()->pop()); + push_node->index(output_index); + + loco::link(graph_output, push_node); + + ctx->stack()->push(push_node); + + return stdex::make_unique<Return>(graph_output, push_node); + } +}; + +struct ReLULayer final +{ + // This "Return" is unnecessary for ReLU as ReLU has no attributes), but + // introduced for consistency. + class Return + { + public: + Return(loco::ReLU *node) : _node{node} + { + // DO NOTHING + } + + public: + loco::ReLU *node(void) { return _node; } + + private: + loco::ReLU *_node = nullptr; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto relu_node = ctx->graph()->nodes()->create<loco::ReLU>(); + + relu_node->input(ctx->stack()->pop()); + + ctx->stack()->push(relu_node); + + return stdex::make_unique<Return>(relu_node); + } +}; + +struct ConstGenLayer final +{ + class Return + { + public: + Return(loco::ConstGen *node) : _node{node} + { + // DO NOTHING + } + + public: + loco::ConstGen *node(void) { return _node; } + + private: + loco::ConstGen *_node = nullptr; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto const_node = ctx->graph()->nodes()->create<loco::ConstGen>(); + + ctx->stack()->push(const_node); + + return stdex::make_unique<Return>(const_node); + } +}; + +#include "loco/IR/PermutingCodec.h" + +struct FeatureEncodeLayer final +{ + class Return + { + public: + Return(loco::FeatureEncode *node) : _node{node} + { + // DO NOTHING + } + + public: + Return *perm(const loco::Permutation<loco::Domain::Feature> &perm) + { + using namespace loco; + _node->encoder(stdex::make_unique<PermutingEncoder<Domain::Feature>>(perm)); + return this; + } + + public: + loco::FeatureEncode *node(void) { return _node; } + + private: + loco::FeatureEncode *_node; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto encode_node = ctx->graph()->nodes()->create<loco::FeatureEncode>(); + + encode_node->input(ctx->stack()->pop()); + + ctx->stack()->push(encode_node); + + return stdex::make_unique<Return>(encode_node); + } +}; + +struct FeatureDecodeLayer final +{ + class Return + { + public: + Return(loco::FeatureDecode *node) : _node{node} + { + // DO NOTHING + } + + public: + Return *perm(const loco::Permutation<loco::Domain::Feature> &perm) + { + using namespace loco; + _node->decoder(stdex::make_unique<PermutingDecoder<Domain::Feature>>(perm)); + return this; + } + + public: + loco::FeatureDecode *node(void) { return _node; } + + private: + loco::FeatureDecode *_node; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + using namespace loco; + + auto decode_node = ctx->graph()->nodes()->create<FeatureDecode>(); + + decode_node->input(ctx->stack()->pop()); + + ctx->stack()->push(decode_node); + + return stdex::make_unique<Return>(decode_node); + } +}; + +struct FilterEncodeLayer final +{ + class Return + { + public: + Return(loco::FilterEncode *node) : _node{node} + { + // DO NOTHING + } + + public: + Return *perm(const loco::Permutation<loco::Domain::Filter> &perm) + { + auto encoder = stdex::make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(); + encoder->perm(perm); + _node->encoder(std::move(encoder)); + return this; + } + + public: + loco::FilterEncode *node(void) { return _node; } + + private: + loco::FilterEncode *_node; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto encode_node = ctx->graph()->nodes()->create<loco::FilterEncode>(); + + encode_node->input(ctx->stack()->pop()); + + ctx->stack()->push(encode_node); + + return stdex::make_unique<Return>(encode_node); + } +}; + +struct DepthwiseFilterEncodeLayer final +{ + class Return + { + public: + Return(loco::DepthwiseFilterEncode *node) : _node{node} + { + // DO NOTHING + } + + public: + Return *perm(const loco::Permutation<loco::Domain::DepthwiseFilter> &perm) + { + using namespace loco; + _node->encoder(stdex::make_unique<PermutingEncoder<Domain::DepthwiseFilter>>(perm)); + return this; + } + + public: + loco::DepthwiseFilterEncode *node(void) { return _node; } + + private: + loco::DepthwiseFilterEncode *_node; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto encode_node = ctx->graph()->nodes()->create<loco::DepthwiseFilterEncode>(); + + encode_node->input(ctx->stack()->pop()); + + ctx->stack()->push(encode_node); + + return stdex::make_unique<Return>(encode_node); + } +}; + +struct DepthwiseConv2DLayer final +{ + class Return + { + public: + Return(loco::DepthwiseConv2D *node) : _node{node} + { + // DO NOTHING + } + + public: + loco::DepthwiseConv2D *node(void) { return _node; } + + private: + loco::DepthwiseConv2D *_node; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto depthwiseconv2d_node = ctx->graph()->nodes()->create<loco::DepthwiseConv2D>(); + + depthwiseconv2d_node->ker(ctx->stack()->pop()); + depthwiseconv2d_node->ifm(ctx->stack()->pop()); + + ctx->stack()->push(depthwiseconv2d_node); + + return stdex::make_unique<Return>(depthwiseconv2d_node); + } +}; + +struct TransposedConv2DLayer final +{ + class Return + { + public: + Return(loco::TransposedConv2D *node) : _node{node} + { + // DO NOTHING + } + + public: + loco::TransposedConv2D *node(void) { return _node; } + + private: + loco::TransposedConv2D *_node; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto tr_conv2d_node = ctx->graph()->nodes()->create<loco::TransposedConv2D>(); + + tr_conv2d_node->ker(ctx->stack()->pop()); + tr_conv2d_node->ifm(ctx->stack()->pop()); + + ctx->stack()->push(tr_conv2d_node); + + return stdex::make_unique<Return>(tr_conv2d_node); + } +}; + +struct FixedReshapeLayer final +{ + class Return + { + public: + Return(loco::FixedReshape *node) : _node{node} + { + // DO NOTHING + } + + public: + Return *shape(std::initializer_list<uint32_t> dims) + { + _node->shape(dims); + return this; + } + + public: + loco::FixedReshape *node(void) { return _node; } + + private: + loco::FixedReshape *_node = nullptr; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto reshape_node = ctx->graph()->nodes()->create<loco::FixedReshape>(); + + reshape_node->input(ctx->stack()->pop()); + + ctx->stack()->push(reshape_node); + + return stdex::make_unique<Return>(reshape_node); + } +}; + +struct TensorBroadcastLayer final +{ + class Return + { + public: + Return(loco::TensorBroadcast *node) : _node{node} + { + // DO NOTHING + } + + public: + loco::TensorBroadcast *node(void) { return _node; } + + private: + loco::TensorBroadcast *_node = nullptr; + }; + + std::unique_ptr<Return> operator()(GraphBuilder::Context *ctx) + { + auto broadcast_node = ctx->graph()->nodes()->create<loco::TensorBroadcast>(); + + broadcast_node->input(ctx->stack()->pop()); + ctx->stack()->push(broadcast_node); + + return stdex::make_unique<Return>(broadcast_node); + } +}; + +#endif // __GRAPH_BUILDER_H__ diff --git a/compiler/loco/src/Service/GraphBuilder.test.cpp b/compiler/loco/src/Service/GraphBuilder.test.cpp new file mode 100644 index 000000000..7b2ea5198 --- /dev/null +++ b/compiler/loco/src/Service/GraphBuilder.test.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "GraphBuilder.h" + +#include "loco/IR/Nodes.h" +#include "loco/IR/CanonicalDialect.h" +#include "loco/IR/CanonicalOpcode.h" + +#include <gtest/gtest.h> + +TEST(GraphBuilderTest, Usecase_000) +{ + struct SampleLayer final + { + loco::Node *operator()(GraphBuilder::Context *ctx) + { + auto node = ctx->graph()->nodes()->create<loco::ConstGen>(); + ctx->stack()->push(node); + return node; + } + }; + + auto g = loco::make_graph(); + auto gbuilder = make_graph_builder(g.get()); + + gbuilder->push<SampleLayer>(); + + auto node = gbuilder->pop(); + + ASSERT_EQ(g->nodes()->size(), 1); + ASSERT_EQ(node->dialect(), loco::CanonicalDialect::get()); + ASSERT_EQ(node->opnum(), static_cast<uint32_t>(loco::CanonicalOpcode::ConstGen)); +} diff --git a/compiler/loco/src/Service/GraphTestcase.h b/compiler/loco/src/Service/GraphTestcase.h new file mode 100644 index 000000000..6743b9a14 --- /dev/null +++ b/compiler/loco/src/Service/GraphTestcase.h @@ -0,0 +1,541 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __GRAPH_TESTCASE_H__ +#define __GRAPH_TESTCASE_H__ + +#include "loco/IR/Graph.h" +#include "loco/IR/PermutingCodec.h" + +#include "GraphBuilder.h" + +#include <stdex/Memory.h> + +enum class GraphCode +{ + Identity, + ConstGen, + Relu, + FeatureCodec, + AvgPool2D, + DepthwiseConv2D, + TransposedConv2D, + MaxPool2D, + TensorBroadcast, + TensorConcat, + TensorTranspose, + FixedReshape, +}; + +namespace +{ + +template <loco::Domain D> loco::Permutation<D> make_NHWC_perm(void); + +template <> loco::Permutation<loco::Domain::Feature> make_NHWC_perm(void) +{ + loco::Permutation<loco::Domain::Feature> perm; + + perm[loco::FeatureAxis::Count] = 0; + perm[loco::FeatureAxis::Height] = 1; + perm[loco::FeatureAxis::Width] = 2; + perm[loco::FeatureAxis::Depth] = 3; + + return perm; +} + +template <loco::Domain D> loco::Permutation<D> make_HWCN_perm(void); + +// @note Also known as HWIO permutation +template <> loco::Permutation<loco::Domain::Filter> make_HWCN_perm(void) +{ + loco::Permutation<loco::Domain::Filter> perm; + + perm[loco::FilterAxis::Height] = 0; + perm[loco::FilterAxis::Width] = 1; + perm[loco::FilterAxis::Depth] = 2; + perm[loco::FilterAxis::Count] = 3; + + return perm; +} + +template <loco::Domain D> loco::Permutation<D> make_HWCM_perm(void); + +template <> loco::Permutation<loco::Domain::DepthwiseFilter> make_HWCM_perm(void) +{ + loco::Permutation<loco::Domain::DepthwiseFilter> perm; + + perm[loco::DepthwiseFilterAxis::Height] = 0; + perm[loco::DepthwiseFilterAxis::Width] = 1; + perm[loco::DepthwiseFilterAxis::Depth] = 2; + perm[loco::DepthwiseFilterAxis::Multiplier] = 3; + + return perm; +} + +} // namespace + +template <GraphCode Code> class GraphTestcase; + +template <> class GraphTestcase<GraphCode::Identity> final +{ +private: + void init(std::initializer_list<uint32_t> dims) + { + // Create a sample network + _graph = loco::make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + // NOTE This default constructor guarantees backward compatbility. + GraphTestcase() { init({1, 4, 8, 3}); } + GraphTestcase(std::initializer_list<uint32_t> dims) { init(dims); } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::ConstGen> final +{ +public: + GraphTestcase() + { + _graph = loco::make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + const_node = graph_builder->push<ConstGenLayer>()->node(); + + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::ConstGen *const_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::Relu> final +{ +public: + GraphTestcase() + { + // Create a sample network + _graph = loco::make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->node(); + relu_node = graph_builder->push<ReLULayer>()->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::ReLU *relu_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::FeatureCodec> final +{ +public: + GraphTestcase() + { + using namespace loco; + + Permutation<Domain::Feature> perm; + + perm[FeatureAxis::Count] = 0; + perm[FeatureAxis::Height] = 1; + perm[FeatureAxis::Width] = 2; + perm[FeatureAxis::Depth] = 3; + + // Create a sample network + _graph = make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->node(); + encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(perm)->node(); + decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(perm)->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::AvgPool2D> final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_input = _graph->inputs()->create(); + auto graph_output = _graph->outputs()->create(); + + graph_input->name("input"); + graph_output->name("output"); + + // Create and connect nodes + pull_node = _graph->nodes()->create<Pull>(); + pull_node->index(0); + + encode_node = _graph->nodes()->create<FeatureEncode>(); + encode_node->input(pull_node); + + avgpool2d_node = _graph->nodes()->create<AvgPool2D>(); + avgpool2d_node->ifm(encode_node); + + decode_node = _graph->nodes()->create<FeatureDecode>(); + decode_node->input(avgpool2d_node); + + push_node = _graph->nodes()->create<loco::Push>(); + push_node->index(0); + push_node->from(decode_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_input, pull_node); + loco::link(graph_output, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::AvgPool2D *avgpool2d_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::DepthwiseConv2D> final +{ +public: + GraphTestcase() + { + using namespace loco; + + _graph = make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + Permutation<Domain::Feature> perm = make_NHWC_perm<Domain::Feature>(); + Permutation<Domain::DepthwiseFilter> filter_perm = make_HWCM_perm<Domain::DepthwiseFilter>(); + + pull_node = graph_builder->push<InputLayer>()->name("input")->node(); + encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(perm)->node(); + + const_node = graph_builder->push<ConstGenLayer>()->node(); + + filter_encode_node = + graph_builder->push<DepthwiseFilterEncodeLayer>()->perm(filter_perm)->node(); + + depthwiseconv2d_node = graph_builder->push<DepthwiseConv2DLayer>()->node(); + + decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(perm)->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::ConstGen *const_node = nullptr; + loco::DepthwiseFilterEncode *filter_encode_node = nullptr; + loco::DepthwiseConv2D *depthwiseconv2d_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::TransposedConv2D> final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Prepare permutations + Permutation<Domain::Feature> feature_perm = make_NHWC_perm<Domain::Feature>(); + Permutation<Domain::Filter> filter_perm = make_HWCN_perm<Domain::Filter>(); + + // Build graph + _graph = make_graph(); + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->node(); + encode_node = graph_builder->push<FeatureEncodeLayer>()->perm(feature_perm)->node(); + const_node = graph_builder->push<ConstGenLayer>()->node(); + filter_encode_node = graph_builder->push<FilterEncodeLayer>()->perm(filter_perm)->node(); + tr_conv2d_node = graph_builder->push<TransposedConv2DLayer>()->node(); + decode_node = graph_builder->push<FeatureDecodeLayer>()->perm(feature_perm)->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::ConstGen *const_node = nullptr; + loco::FilterEncode *filter_encode_node = nullptr; + loco::TransposedConv2D *tr_conv2d_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::MaxPool2D> final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_input = _graph->inputs()->create(); + auto graph_output = _graph->outputs()->create(); + + graph_input->name("input"); + graph_output->name("output"); + + // Create and connect nodes + pull_node = _graph->nodes()->create<Pull>(); + pull_node->index(0); + + encode_node = _graph->nodes()->create<FeatureEncode>(); + encode_node->input(pull_node); + + maxpool2d_node = _graph->nodes()->create<MaxPool2D>(); + maxpool2d_node->ifm(encode_node); + + decode_node = _graph->nodes()->create<FeatureDecode>(); + decode_node->input(maxpool2d_node); + + push_node = _graph->nodes()->create<loco::Push>(); + push_node->index(0); + push_node->from(decode_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_input, pull_node); + loco::link(graph_output, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FeatureEncode *encode_node = nullptr; + loco::MaxPool2D *maxpool2d_node = nullptr; + loco::FeatureDecode *decode_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::TensorConcat> final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_lhs = _graph->inputs()->create(); + auto graph_rhs = _graph->inputs()->create(); + auto graph_out = _graph->outputs()->create(); + + graph_lhs->name("lhs"); + graph_rhs->name("rhs"); + graph_out->name("output"); + + // Create and connect nodes + lhs_node = _graph->nodes()->create<Pull>(); + lhs_node->index(0); + + rhs_node = _graph->nodes()->create<Pull>(); + rhs_node->index(1); + + concat_node = _graph->nodes()->create<TensorConcat>(); + concat_node->lhs(lhs_node); + concat_node->rhs(rhs_node); + + push_node = _graph->nodes()->create<loco::Push>(); + push_node->index(0); + push_node->from(concat_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_lhs, lhs_node); + loco::link(graph_rhs, rhs_node); + loco::link(graph_out, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *lhs_node = nullptr; + loco::Pull *rhs_node = nullptr; + loco::TensorConcat *concat_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::FixedReshape> final +{ +public: + GraphTestcase() + { + _graph = loco::make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->node(); + reshape_node = graph_builder->push<FixedReshapeLayer>()->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::FixedReshape *reshape_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::TensorBroadcast> final +{ +public: + GraphTestcase(std::initializer_list<uint32_t> dims) + { + _graph = loco::make_graph(); + + auto graph_builder = make_graph_builder(_graph.get()); + + pull_node = graph_builder->push<InputLayer>()->name("input")->shape(dims)->node(); + broadcast_node = graph_builder->push<TensorBroadcastLayer>()->node(); + push_node = graph_builder->push<OutputLayer>()->name("output")->node(); + } + +public: + loco::Graph *graph(void) { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::TensorBroadcast *broadcast_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +template <> class GraphTestcase<GraphCode::TensorTranspose> final +{ +public: + GraphTestcase() + { + using namespace loco; + + // Create a sample network + _graph = make_graph(); + + // Create Graph Input/Output + auto graph_input = _graph->inputs()->create(); + auto graph_output = _graph->outputs()->create(); + + graph_input->name("input"); + graph_output->name("output"); + + // Create and connect nodes + pull_node = _graph->nodes()->create<Pull>(); + pull_node->index(0); + + transpose_node = _graph->nodes()->create<TensorTranspose>(); + transpose_node->input(pull_node); + + push_node = _graph->nodes()->create<loco::Push>(); + push_node->index(0); + push_node->from(transpose_node); + + // Create a link between input/output and corresponding nodes + loco::link(graph_input, pull_node); + loco::link(graph_output, push_node); + } + +public: + loco::Graph *graph() { return _graph.get(); } + + loco::Pull *pull_node = nullptr; + loco::TensorTranspose *transpose_node = nullptr; + loco::Push *push_node = nullptr; + +private: + std::unique_ptr<loco::Graph> _graph; +}; + +#endif // __GRAPH_TESTCASE_H__ diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp new file mode 100644 index 000000000..2178f5d05 --- /dev/null +++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/MultiDialectShapeInferenceRule.h" +#include "loco/Service/ShapeInferenceRule.h" + +#include <loco/IR/Dialect.h> +#include <loco/IR/Node.h> +#include <loco/IR/NodeShape.h> + +#include <cassert> + +namespace loco +{ + +bool MultiDialectShapeInferenceRule::recognize(const Dialect *d) const +{ + const auto found = _rules.find(d); + + if (found == _rules.cend()) + return false; + + auto rule = found->second; + auto result = rule->recognize(d); + + return result; +} + +bool MultiDialectShapeInferenceRule::infer(const Node *node, NodeShape &shape) const +{ + const auto found = _rules.find(node->dialect()); + + if (found == _rules.cend()) + return false; + + auto rule = found->second; + if (rule->infer(node, shape)) + return true; + + return false; +} + +MultiDialectShapeInferenceRule &MultiDialectShapeInferenceRule::bind(const Dialect *d, + const ShapeInferenceRule *rule) +{ + assert(_rules.find(d) == _rules.end()); + assert(rule->recognize(d)); + + _rules[d] = rule; + + return (*this); +} + +} // namespace loco diff --git a/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp new file mode 100644 index 000000000..ffa9ee5ca --- /dev/null +++ b/compiler/loco/src/Service/MultiDialectShapeInferenceRule.test.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/CanonicalShapeInferenceRule.h" +#include "loco/Service/MultiDialectShapeInferenceRule.h" +#include "loco/Service/ShapeInference.h" + +#include <loco/IR/Dialect.h> +#include <loco/IR/CanonicalDialect.h> + +#include <gtest/gtest.h> + +#include <cassert> +#include <vector> + +// mockup for MultiDialectShapeInferenceRule +// Each class is dedicated for handling shape { D1, D2 } and D1, D2 are declared as a template +namespace +{ + +template <uint32_t D1, uint32_t D2> class TestDialect final : public loco::Dialect +{ +public: + static Dialect *get(void) + { + static TestDialect<D1, D2> d; + return &d; + } +}; + +template <uint32_t D1, uint32_t D2> +struct TestOpNode final : public loco::FixedArity<1>::Mixin<loco::Node>, + public loco::NodeMixin<loco::NodeTrait::TensorShape> +{ + void input(Node *node) { at(0)->node(node); } + const loco::Dialect *dialect(void) const final { return TestDialect<D1, D2>::get(); } + uint32_t opnum(void) const final { return static_cast<uint32_t>(D1); /* not used */ } +}; + +template <uint32_t D1, uint32_t D2> +struct TestShapeInferenceRule final : public loco::ShapeInferenceRule +{ +public: + bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<D1, D2>::get()); } + + bool infer(const loco::Node *node, loco::NodeShape &node_shape) const final + { + assert(recognize(node->dialect())); + auto test_node = dynamic_cast<const TestOpNode<D1, D2> *>(node); + assert(test_node != nullptr); + + loco::TensorShape ts; + { + ts.rank(2); + ts.dim(0) = D1; + ts.dim(1) = D2; // making shape : { D1, D2 } + } + + node_shape.set(ts); + + return true; + } +}; + +} // namespace + +TEST(MultiDialectShapeInferenceRuleTest, test1) +{ + // Create a simple network : Pull ------- t23<2,3> ------------ t45<4,5> ---------- Push + // TensorShape({2, 3}) TensorShape({4, 5}) + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create<loco::Pull>(); + auto t23_node = g->nodes()->create<TestOpNode<2, 3>>(); + auto t45_node = g->nodes()->create<TestOpNode<4, 5>>(); + auto push_node = g->nodes()->create<loco::Push>(); + + t23_node->input(pull_node); + t45_node->input(t23_node); + push_node->from(t45_node); + + auto graph_input = g->inputs()->create(); + graph_input->name("input"); + loco::link(graph_input, pull_node); + + auto graph_output = g->outputs()->create(); + graph_output->name("output"); + loco::link(graph_output, push_node); + + // initially they don't have shape info + ASSERT_FALSE(loco::shape_known(t23_node)); + ASSERT_FALSE(loco::shape_known(t45_node)); + + // Run Type Inference + loco::CanonicalShapeInferenceRule canonical_rule; + TestShapeInferenceRule<2, 3> t23_rule; + TestShapeInferenceRule<4, 5> t45_rule; + + loco::MultiDialectShapeInferenceRule rules; + + rules.bind(loco::CanonicalDialect::get(), &canonical_rule) + .bind(TestDialect<2, 3>::get(), &t23_rule) + .bind(TestDialect<4, 5>::get(), &t45_rule); + + loco::apply(&rules).to(g.get()); + + // Verify! + ASSERT_TRUE(loco::shape_known(t23_node)); + auto t23_shape = loco::shape_get(t23_node); + ASSERT_EQ(t23_shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(t23_shape.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(0), 2); + ASSERT_EQ(t23_shape.as<loco::TensorShape>().dim(1), 3); + + ASSERT_TRUE(loco::shape_known(t45_node)); + auto t45_shape = loco::shape_get(t45_node); + ASSERT_EQ(t45_shape.domain(), loco::Domain::Tensor); + ASSERT_EQ(t45_shape.as<loco::TensorShape>().rank(), 2); + ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(0), 4); + ASSERT_EQ(t45_shape.as<loco::TensorShape>().dim(1), 5); +} diff --git a/compiler/loco/src/Service/ShapeInference.cpp b/compiler/loco/src/Service/ShapeInference.cpp new file mode 100644 index 000000000..84eb10963 --- /dev/null +++ b/compiler/loco/src/Service/ShapeInference.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/ShapeInference.h" +#include "loco/IR/Algorithm.h" + +#include <cassert> + +#include <stdex/Memory.h> + +namespace +{ + +bool inputs_shape_ready(loco::Node *node) +{ + assert(node != nullptr); + + for (uint32_t arity = 0; arity < node->arity(); ++arity) + { + if (!loco::ShapeInference::known(node->arg(arity))) + { + return false; + } + } + return true; +} + +} // namespace + +// +// Infrastructure +// +namespace +{ + +struct ShapeAnnotation : public loco::NodeAnnotation +{ +public: + ShapeAnnotation(const loco::NodeShape &shape) : _shape{shape} + { + // DO NOTHING + } + +public: + const loco::NodeShape &shape(void) const { return _shape; } + +private: + loco::NodeShape _shape; +}; + +} // namespace + +namespace loco +{ + +bool ShapeInferenceSession::to(Graph *g) const +{ + assert(_rule->support(ShapeInferenceRule::API::V1) && "API v1 is unavailable"); + + bool changed = false; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + if (_rule->recognize(node->dialect())) + { + loco::NodeShape shape; + + if (!shape_known(node) && inputs_shape_ready(node)) + { + if (_rule->infer(node, shape)) + { + node->annot(stdex::make_unique<ShapeAnnotation>(shape)); + changed = true; + } + } + } + } + + return changed; +} + +bool ShapeInference::known(const Node *node) { return node->annot<ShapeAnnotation>() != nullptr; } + +NodeShape ShapeInference::get(const Node *node) +{ + assert(known(node)); + return node->annot<ShapeAnnotation>()->shape(); +} + +void ShapeInference::erase(Node *node) { node->annot<ShapeAnnotation>(nullptr); } + +} // namespace loco diff --git a/compiler/loco/src/Service/ShapeInference.test.cpp b/compiler/loco/src/Service/ShapeInference.test.cpp new file mode 100644 index 000000000..e10b98844 --- /dev/null +++ b/compiler/loco/src/Service/ShapeInference.test.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/ShapeInference.h" +#include "GraphTestcase.h" + +#include <vector> + +#include <gtest/gtest.h> + +// This test validates whether framework works as expected. +TEST(ShapeInferenceTest, framework) +{ + // Mock-up Shape Inference Rule + struct SampleShapeInferenceRule final : public loco::ShapeInferenceRule + { + public: + SampleShapeInferenceRule(std::vector<const loco::Node *> *nodes) : _nodes{nodes} + { + // DO NOTHING + } + + public: + // Accept all the dialects + bool recognize(const loco::Dialect *) const final { return true; } + + bool infer(const loco::Node *node, loco::NodeShape &shape) const final + { + // Record the order of inference + _nodes->emplace_back(node); + + if (_nodes->size() != 1) + { + return false; + } + + // Set the first node as Tensor<1> + loco::TensorShape tensor_shape; + + tensor_shape.rank(1); + tensor_shape.dim(0) = 4; + + shape.set(tensor_shape); + + return true; + } + + private: + std::vector<const loco::Node *> *_nodes; + }; + + GraphTestcase<GraphCode::Identity> testcase; + + std::vector<const loco::Node *> nodes; + + SampleShapeInferenceRule rule{&nodes}; + + loco::apply(&rule).to(testcase.graph()); + + // Framework SHOULD visit all the nodes + ASSERT_EQ(nodes.size(), 2); + // Framework SHOULD visit "pull" before "push" + ASSERT_EQ(nodes.at(0), testcase.pull_node); + ASSERT_EQ(nodes.at(1), testcase.push_node); + + // Framework SHOULD make an annotation if "rule" returns TRUE + ASSERT_TRUE(loco::shape_known(testcase.pull_node)); + ASSERT_EQ(loco::shape_get(testcase.pull_node).domain(), loco::Domain::Tensor); + ASSERT_EQ(loco::shape_get(testcase.pull_node).as<loco::TensorShape>().rank(), 1); + ASSERT_EQ(loco::shape_get(testcase.pull_node).as<loco::TensorShape>().dim(0), 4); + + // Framework SHOULD NOT make any annotation if "rule" returns FALSE + ASSERT_FALSE(loco::shape_known(testcase.push_node)); +} diff --git a/compiler/loco/src/Service/ShapeInferenceRule.cpp b/compiler/loco/src/Service/ShapeInferenceRule.cpp new file mode 100644 index 000000000..bed841260 --- /dev/null +++ b/compiler/loco/src/Service/ShapeInferenceRule.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/ShapeInferenceRule.h" + +#include <stdexcept> + +// This file validates "ShapeInferenceRule.h". Please DO NOT remove this file. + +namespace loco +{ + +void ShapeInferenceRule::infer(const Context *, const Node *, Sink *) const +{ + throw std::runtime_error{"API v2 is not supported"}; +} + +} // namespace loco diff --git a/compiler/loco/src/Service/TypeInference.cpp b/compiler/loco/src/Service/TypeInference.cpp new file mode 100644 index 000000000..fbf0033ee --- /dev/null +++ b/compiler/loco/src/Service/TypeInference.cpp @@ -0,0 +1,228 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/TypeInference.h" + +#include "loco/IR/Algorithm.h" + +#include <cassert> + +#include <stdex/Memory.h> + +namespace +{ + +struct DataTypeAnnotation : public loco::NodeAnnotation +{ +public: + DataTypeAnnotation(const loco::DataType &dtype) : _dtype{dtype} + { + // DO NOTHING + } + +public: + const loco::DataType &dtype(void) const { return _dtype; } + +private: + loco::DataType _dtype; +}; + +bool inputs_dtype_ready(loco::Node *node) +{ + assert(node != nullptr); + + for (uint32_t arity = 0; arity < node->arity(); ++arity) + { + if (!loco::TypeInference::known(node->arg(arity))) + { + return false; + } + } + return true; +} + +} // namespace + +namespace loco +{ + +bool TypeInferenceSession::to(Graph *g) const +{ + bool changed = false; + + for (auto node : postorder_traversal(output_nodes(g))) + { + if (_rule->recognize(node->dialect())) + { + DataType dtype = DataType::Unknown; + + if (!dtype_known(node) && inputs_dtype_ready(node)) + { + if (_rule->infer(node, dtype)) + { + node->annot(stdex::make_unique<DataTypeAnnotation>(dtype)); + changed = true; + } + } + } + } + + return changed; +} + +bool TypeInference::known(const Node *node) { return node->annot<DataTypeAnnotation>() != nullptr; } + +DataType TypeInference::get(const Node *node) +{ + assert(known(node)); + return node->annot<DataTypeAnnotation>()->dtype(); +} + +void TypeInference::erase(Node *node) { return node->annot<DataTypeAnnotation>(nullptr); } + +} // namespace loco + +// +// Canonical (Data) Type Inference Rule +// +#include <loco/IR/CanonicalDialect.h> +#include <loco/IR/CanonicalNode.h> +#include <loco/IR/CanonicalNodeVisitor.h> + +namespace +{ + +/** + * There are two possible maintenance policies. + * - Introduce a new canonical node first, and then extend this algorithm later + * - Introduce a new canonical node and extend this algorithm at the same time + * + * The current implementation assumes the former one (for historical reason). + * + * TODO Evaluate the impact of the latter one + */ +struct CanonicalTypeForwardAlgorithm final : public loco::CanonicalNodeVisitor<loco::DataType> +{ + loco::DataType visit(const loco::AvgPool2D *node) { return loco::dtype_get(node->ifm()); } + loco::DataType visit(const loco::BiasDecode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::BiasEncode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::ConstGen *node) { return node->dtype(); } + loco::DataType visit(const loco::Conv2D *node) { return loco::dtype_get(node->ifm()); } + loco::DataType visit(const loco::DepthwiseConv2D *node) { return loco::dtype_get(node->ifm()); } + loco::DataType visit(const loco::DepthwiseFilterEncode *node) + { + return loco::dtype_get(node->input()); + } + loco::DataType visit(const loco::DepthwiseFilterDecode *node) + { + return loco::dtype_get(node->input()); + } + loco::DataType visit(const loco::EltwiseAdd *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::EltwiseDiv *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::EltwiseMax *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::EltwiseMul *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::EltwiseSqrt *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::EltwiseSub *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::Forward *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::FeatureBiasAdd *node) { return loco::dtype_get(node->value()); } + loco::DataType visit(const loco::FeatureDecode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::FeatureEncode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::FilterDecode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::FilterEncode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::FixedReshape *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::MatrixDecode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::MatrixEncode *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::MatMul *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::MaxPool2D *node) { return loco::dtype_get(node->ifm()); } + loco::DataType visit(const loco::Push *node) { return loco::dtype_get(node->from()); } + loco::DataType visit(const loco::Pull *node) { return node->dtype(); } + loco::DataType visit(const loco::ReLU *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::ReLU6 *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::Tanh *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TensorConcat *node) { return loco::dtype_get(node->lhs()); } + loco::DataType visit(const loco::TensorConstantPad *node) + { + return loco::dtype_get(node->input()); + } + loco::DataType visit(const loco::TensorBiasAdd *node) { return loco::dtype_get(node->value()); } + loco::DataType visit(const loco::TensorBroadcast *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TensorReduce *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TensorSoftmax *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TensorTranspose *node) { return loco::dtype_get(node->input()); } + loco::DataType visit(const loco::TransposedConv2D *node) { return loco::dtype_get(node->ifm()); } +}; + +} // namespace + +namespace loco +{ + +bool CanonicalTypeInferenceRule::recognize(const Dialect *d) const +{ + // This rule recognizes only "loco.canonical" dialect! + return CanonicalDialect::get() == d; +} + +bool CanonicalTypeInferenceRule::infer(const Node *node, DataType &dtype) const +{ + assert(node->dialect() == loco::CanonicalDialect::get()); + assert(dynamic_cast<const loco::CanonicalNode *>(node) != nullptr); + + CanonicalTypeForwardAlgorithm alg; + dtype = dynamic_cast<const loco::CanonicalNode *>(node)->accept(&alg); + + return true; +} + +bool MultiDialectTypeInferenceRule::recognize(const Dialect *d) const +{ + const auto found = _rules.find(d); + + if (found == _rules.cend()) + return false; + + auto rule = found->second; + auto result = rule->recognize(d); + + return result; +} + +bool MultiDialectTypeInferenceRule::infer(const Node *node, DataType &dtype) const +{ + const auto found = _rules.find(node->dialect()); + + if (found == _rules.cend()) + return false; + + auto rule = found->second; + if (rule->infer(node, dtype)) + return true; + + return false; +} + +MultiDialectTypeInferenceRule &MultiDialectTypeInferenceRule::bind(const Dialect *d, + const TypeInferenceRule *rule) +{ + assert(_rules.find(d) == _rules.end()); + assert(rule->recognize(d)); + + _rules[d] = rule; + + return (*this); +} + +} // namespace loco diff --git a/compiler/loco/src/Service/TypeInference.test.cpp b/compiler/loco/src/Service/TypeInference.test.cpp new file mode 100644 index 000000000..4660401db --- /dev/null +++ b/compiler/loco/src/Service/TypeInference.test.cpp @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco/Service/TypeInference.h" + +#include "GraphTestcase.h" + +#include <loco/IR/CanonicalDialect.h> +#include <loco/Service/TypeInference.h> + +#include <vector> + +#include <gtest/gtest.h> + +// This test validates whether framework works as expected. +TEST(TypeInferenceTest, framework) +{ + // Create a sample network + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create<loco::Pull>(); + auto push_node = g->nodes()->create<loco::Push>(); + + push_node->from(pull_node); + + // Create Graph Input & Output + auto graph_input = g->inputs()->create(); + + graph_input->name("input"); + loco::link(graph_input, pull_node); + + auto graph_output = g->outputs()->create(); + + graph_output->name("output"); + loco::link(graph_output, push_node); + + // Mock-up Type Inference Rule + struct SampleTypeInferenceRule final : public loco::TypeInferenceRule + { + public: + SampleTypeInferenceRule(std::vector<const loco::Node *> *nodes) : _nodes{nodes} + { + // DO NOTHING + } + + public: + bool recognize(const loco::Dialect *) const final + { + // Accept all the dialects + return true; + } + + bool infer(const loco::Node *node, loco::DataType &dtype) const final + { + // Record the order of inference + _nodes->emplace_back(node); + + if (_nodes->size() != 1) + { + return false; + } + + // Annotate the first node as "U8" + dtype = loco::DataType::U8; + return true; + } + + private: + std::vector<const loco::Node *> *_nodes; + }; + + std::vector<const loco::Node *> nodes; + + SampleTypeInferenceRule rule{&nodes}; + + loco::apply(&rule).to(g.get()); + + ASSERT_EQ(nodes.size(), 2); // Framework SHOULD visit all the nodes + ASSERT_EQ(nodes.at(0), pull_node); // Framework SHOULD visit "pull" before "push" + ASSERT_EQ(nodes.at(1), push_node); + + // Framework SHOULD NOT make any annotation if "rule" returns FALSE + ASSERT_TRUE(loco::dtype_known(pull_node)); + // Framework SHOULD make an annotation if "rule" returns TRUE + ASSERT_EQ(loco::dtype_get(pull_node), loco::DataType::U8); + ASSERT_FALSE(loco::dtype_known(push_node)); +} + +TEST(CanonicalTypeInferenceRuleTest, minimal) +{ + // Create a simple network + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create<loco::Pull>(); + + pull_node->dtype(loco::DataType::U8); + + auto push_node = g->nodes()->create<loco::Push>(); + + push_node->from(pull_node); + + auto graph_input = g->inputs()->create(); + + graph_input->name("input"); + loco::link(graph_input, pull_node); + + auto graph_output = g->outputs()->create(); + + graph_output->name("output"); + loco::link(graph_output, push_node); + + // Run Type Inference + loco::CanonicalTypeInferenceRule rule; + + loco::apply(&rule).to(g.get()); + + // Verify! + ASSERT_TRUE(loco::dtype_known(push_node)); + ASSERT_EQ(loco::dtype_get(push_node), loco::DataType::U8); +} + +TEST(CanonicalTypeInferenceRuleTest, relu6) +{ + // Create a simple Relu6 network + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create<loco::Pull>(); + + pull_node->dtype(loco::DataType::FLOAT32); + + auto relu6_node = g->nodes()->create<loco::ReLU6>(); + + relu6_node->input(pull_node); + + auto push_node = g->nodes()->create<loco::Push>(); + + push_node->from(relu6_node); + + auto graph_input = g->inputs()->create(); + + graph_input->name("input"); + loco::link(graph_input, pull_node); + + auto graph_output = g->outputs()->create(); + + graph_output->name("output"); + loco::link(graph_output, push_node); + + // Run Type Inference + loco::CanonicalTypeInferenceRule rule; + + loco::apply(&rule).to(g.get()); + + // Verify! + ASSERT_TRUE(loco::dtype_known(relu6_node)); + ASSERT_EQ(loco::dtype_get(relu6_node), loco::DataType::FLOAT32); +} + +TEST(CanonicalTypeInferenceRuleTest, tensor_broadcast) +{ + // Create a sample network + GraphTestcase<GraphCode::TensorBroadcast> testcase{1, 2}; + + testcase.graph()->inputs()->at(0)->dtype(loco::DataType::U8); + + // Run Type Inference + loco::CanonicalTypeInferenceRule rule; + + loco::apply(&rule).to(testcase.graph()); + + // Verify! + ASSERT_TRUE(loco::dtype_known(testcase.push_node)); + ASSERT_EQ(loco::dtype_get(testcase.push_node), loco::DataType::U8); +} + +// mockup for MultiDialectTypeInferenceRule +// OpNode of a specific loco datatype (defined in template) will be used. +// And a Dialect for the OpNode and its inference rules are created. +#include <loco/IR/Dialect.h> + +namespace +{ + +template <loco::DataType N> class TestDialect final : public loco::Dialect +{ +public: + static Dialect *get(void) + { + static TestDialect<N> d; + return &d; + } +}; + +template <loco::DataType N> +struct TestOpNode final : public loco::FixedArity<1>::Mixin<loco::Node>, + public loco::NodeMixin<loco::NodeTrait::DataType> +{ + void input(Node *node) { at(0)->node(node); } + const loco::Dialect *dialect(void) const final { return TestDialect<N>::get(); } + uint32_t opnum(void) const final { return static_cast<uint32_t>(N); } +}; + +template <loco::DataType N> struct TestTypeInferenceRule final : public loco::TypeInferenceRule +{ +public: + bool recognize(const loco::Dialect *d) const final { return (d == TestDialect<N>::get()); } + + bool infer(const loco::Node *node, loco::DataType &dtype) const final + { + assert(node->dialect() == TestDialect<N>::get()); + auto test_node = dynamic_cast<const TestOpNode<N> *>(node); + assert(test_node != nullptr); + + dtype = N; + return true; + } +}; + +} // namespace + +TEST(MultiDialectTypeInferenceRuleTest, test1) +{ + // Create a simple network : Pull - S8 - U8 - Push + auto g = loco::make_graph(); + + auto pull_node = g->nodes()->create<loco::Pull>(); + pull_node->dtype(loco::DataType::FLOAT32); + + auto s8_node = g->nodes()->create<TestOpNode<loco::DataType::S8>>(); + s8_node->input(pull_node); + + auto u8_node = g->nodes()->create<TestOpNode<loco::DataType::U8>>(); + u8_node->input(s8_node); + + auto push_node = g->nodes()->create<loco::Push>(); + push_node->from(u8_node); + + auto graph_input = g->inputs()->create(); + graph_input->name("input"); + loco::link(graph_input, pull_node); + + auto graph_output = g->outputs()->create(); + graph_output->name("output"); + loco::link(graph_output, push_node); + + // initially they don't have type info + ASSERT_FALSE(loco::dtype_known(s8_node)); + ASSERT_FALSE(loco::dtype_known(u8_node)); + + // Run Type Inference + TestTypeInferenceRule<loco::DataType::U8> u8_rule; + TestTypeInferenceRule<loco::DataType::S8> s8_rule; + loco::CanonicalTypeInferenceRule canon_rule; + + loco::MultiDialectTypeInferenceRule rules; + + rules.bind(TestDialect<loco::DataType::S8>::get(), &s8_rule) + .bind(TestDialect<loco::DataType::U8>::get(), &u8_rule) + .bind(loco::CanonicalDialect::get(), &canon_rule); + + loco::apply(&rules).to(g.get()); + + // Verify! + ASSERT_TRUE(loco::dtype_known(s8_node)); + ASSERT_EQ(loco::dtype_get(s8_node), loco::DataType::S8); + + ASSERT_TRUE(loco::dtype_known(u8_node)); + ASSERT_EQ(loco::dtype_get(u8_node), loco::DataType::U8); +} diff --git a/compiler/loco/src/loco.test.cpp b/compiler/loco/src/loco.test.cpp new file mode 100644 index 000000000..4c4f51aa5 --- /dev/null +++ b/compiler/loco/src/loco.test.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "loco.h" + +#include <gtest/gtest.h> + +// This test shows how to create an "identity" network with loco. +// +// What is "identity" network? +// - A network simply passes its input as its output +// +// TODO Create "Ouput" first and then create "Push" later +TEST(LOCO, identity_network) +{ + auto g = loco::make_graph(); + + // Create a "pull" node as an input + auto pull_node = g->nodes()->create<loco::Pull>(); + + // Set "data type" + pull_node->dtype(loco::DataType::FLOAT32); + + // Set "data shape" + pull_node->rank(2); + pull_node->dim(0) = 3; + pull_node->dim(1) = 4; + + // Create a "push" node as an output + auto push_node = g->nodes()->create<loco::Push>(); + + // Set "source" + push_node->from(pull_node); + + // Create Graph Input & Output + auto graph_input = g->inputs()->create(); + + graph_input->name("input"); + loco::link(graph_input, pull_node); + graph_input->dtype(loco::DataType::FLOAT32); + + auto graph_output = g->outputs()->create(); + + graph_output->name("output"); + loco::link(graph_output, push_node); + + // loco::link SHOULD update "index" + ASSERT_EQ(pull_node->index(), 0); + ASSERT_EQ(graph_input->dtype(), loco::DataType::FLOAT32); + + // loco::link SHOULD update "index" + ASSERT_EQ(push_node->index(), 0); +} + +#if 0 +"identity_network_V2" test shows how to use loco when loco.core and loco.canonical are decoupled. + +NOTE "identity_network" test is left for backward compatiblity check +TODO Remove "identity_network" test once all the clients are migrated. +#endif +TEST(LOCO, identity_network_V2) +{ + auto g = loco::make_graph(); + + // Create Graph Input & Output + auto graph_input = g->inputs()->create(); + + graph_input->name("input"); + graph_input->dtype(loco::DataType::FLOAT32); + // TODO Set Shape + + auto graph_output = g->outputs()->create(); + + graph_output->name("output"); + graph_output->dtype(loco::DataType::FLOAT32); + // TODO Set Shape + + // Create a "pull" node as an input + auto pull_node = g->nodes()->create<loco::Pull>(); + + pull_node->index(0); + + // Create a "push" node as an output + auto push_node = g->nodes()->create<loco::Push>(); + + push_node->index(0); + push_node->from(pull_node); + + ASSERT_EQ(pull_node->dtype(), loco::DataType::FLOAT32); + // TODO Check Shape of pull_node + // TODO Check Shape of push_node + + ASSERT_EQ(loco::pull_node(g.get(), 0), pull_node); + ASSERT_EQ(loco::push_node(g.get(), 0), push_node); +} diff --git a/compiler/loco/src/tensorflow.test.cpp b/compiler/loco/src/tensorflow.test.cpp new file mode 100644 index 000000000..f534aee7b --- /dev/null +++ b/compiler/loco/src/tensorflow.test.cpp @@ -0,0 +1,386 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * @brief This file includes various tests that shows how to encode TensorFlow models using loco. + * + * @note All the python examples below assume TensorFlow v1.13 + */ +#include "loco.h" + +#include <gtest/gtest.h> + +#include <stdex/Memory.h> + +using stdex::make_unique; + +namespace +{ + +loco::Permutation<loco::Domain::Feature> make_NHWC_permutation(void) +{ + loco::Permutation<loco::Domain::Feature> NHWC; + + NHWC.axis(loco::FeatureAxis::Count) = 0; + NHWC.axis(loco::FeatureAxis::Height) = 1; + NHWC.axis(loco::FeatureAxis::Width) = 2; + NHWC.axis(loco::FeatureAxis::Depth) = 3; + + return NHWC; +} + +/** + * @brief Create a HxWxIxO (or HxWxCxN) permutation which tf.nn.conv2d uses + * + * Reference: [tf.nn.conv2d](https://www.tensorflow.org/api_docs/python/tf/nn/conv2d) + * > Given an input tensor of shape [batch, in_height, in_width, in_channels] and a filter / + * > kernel tensor of shape [filter_height, filter_width, in_channels, out_channels], ... + * + * NOTE "HWIO" is borrowed from TensorFlow Lite Converter + * + * https://github.com/tensorflow/tensorflow/blob/v1.13.1/tensorflow/lite/toco/model.h#L169 + */ +loco::Permutation<loco::Domain::Filter> make_HWIO_permutation(void) +{ + loco::Permutation<loco::Domain::Filter> HWIO; + + HWIO.axis(loco::FilterAxis::Height) = 0; // H + HWIO.axis(loco::FilterAxis::Width) = 1; // W + HWIO.axis(loco::FilterAxis::Depth) = 2; // I, a.k.a. C + HWIO.axis(loco::FilterAxis::Count) = 3; // O, a.k.a. N + + return HWIO; +} + +} // nemaspace + +#if 0 +>>> MaxPool_Float_000 testcase + +MaxPool_Float_000 test guarantees that loco is expressive enough to encode the following example. + +Python: +``` +import tensorflow as tf +value = tf.placeholder(dtype=tf.float32, shape=[1, 16, 16, 2], name="value") +maxpool = tf.nn.max_pool(value, [1, 3, 3, 1], [1, 1, 1, 1], 'VALID', name="maxpool") +tf.get_default_graph().as_graph_def() +``` + +The above code produces the following TensorFlow GraphDef: + +node { + name: "value" + op: "Placeholder" + attr { + key: "dtype" + value { type: DT_FLOAT } + } + attr { + key: "shape" + value { + shape { + dim { size: 1 } + dim { size: 16 } + dim { size: 16 } + dim { size: 2 } + } + } + } +} +node { + name: "maxpool" + op: "MaxPool" + input: "Placeholder" + attr { + key: "T" + value { type: DT_FLOAT } + } + attr { + key: "data_format" + value { s: "NHWC" } + } + attr { + key: "ksize" + value { list { i: 1 i: 3 i: 3 i: 1 } } + } + attr { + key: "padding" + value { s: "VALID" } + } + attr { + key: "strides" + value { list { i: 1 i: 1 i: 1 i: 1 } } + } +} + +Below test guarantees that loco is expressive enough to encode this example. +#endif +TEST(TensorFlowTest, MaxPool_Float_000) +{ + auto g = loco::make_graph(); + + // The first "value" node corresponds to the following "Pull" node. + // + // %value = Pull(dtype: FLOAT32, shape: [1, 16, 16, 2]) + auto value = g->nodes()->create<loco::Pull>(); + + value->dtype(loco::DataType::FLOAT32); + value->shape({1, 16, 16, 2}); + + // The next "maxpool" node corresponds to a sequence of the following loco nodes: + // - "FeatureEncode" + // - "MaxPool2D + // - "FeatureDecode" + // + // "maxpool.data_format" is 'NHWC' which corresponds to the following permutation + // Count <-> 0 + // Height <-> 1 + // Width <-> 2 + // Depth <-> 3 + loco::Permutation<loco::Domain::Feature> NHWC; + + NHWC.axis(loco::FeatureAxis::Count) = 0; + NHWC.axis(loco::FeatureAxis::Height) = 1; + NHWC.axis(loco::FeatureAxis::Width) = 2; + NHWC.axis(loco::FeatureAxis::Depth) = 3; + + auto encoder = make_unique<loco::PermutingEncoder<loco::Domain::Feature>>(); + + encoder->perm(NHWC); + + auto decoder = make_unique<loco::PermutingDecoder<loco::Domain::Feature>>(); + + decoder->perm(NHWC); + + // %node_0 = FeatureEncode(%value, perm { Count = 0, Height = 1, Width = 2, Depth = 3 }) + auto node_0 = g->nodes()->create<loco::FeatureEncode>(); + + node_0->input(value); + node_0->encoder(std::move(encoder)); + + // %node_1 = MaxPool(%node_0, window.H: 3, window.W: 3, stride.H: 1, stride.W : 1) + auto node_1 = g->nodes()->create<loco::MaxPool2D>(); + + node_1->ifm(node_0); + + // From "ksize" attributes + node_1->window()->horizontal(3); + node_1->window()->vertical(3); + + // From "strides" attributes + node_1->stride()->horizontal(1); + node_1->stride()->vertical(1); + + // %output = FeatureDecode(%node_1, perm { Count = 0, Height = 1, Width = 2, Depth = 3 }) + auto output = g->nodes()->create<loco::FeatureDecode>(); + + output->input(node_1); + output->decoder(std::move(decoder)); + + // %push = Push(%output) + auto push = g->nodes()->create<loco::Push>(); + + push->from(output); + + // + // Mark network-level input/output + // + auto input_0 = g->inputs()->create(); + loco::link(input_0, value); + + auto output_0 = g->outputs()->create(); + loco::link(output_0, push); + + // NOTE This example SHOULD BE valid. + ASSERT_TRUE(loco::valid(g.get())); +} + +#if 0 +>>> Conv2D_Float_000 testcase + +Conv2D_Float_000 test guarantees that loco is expressive enough to encode the following example. + +Python: +``` +import tensorflow as tf +inp = tf.placeholder(dtype=tf.float32, shape=[1, 16, 16, 2], name="inp") +ker = tf.constant(value=[1.0], dtype=tf.float32, shape=[7, 1, 2, 4], name="ker") +conv2d = tf.nn.conv2d(input=inp, filter=ker, strides=[1, 1, 1, 1], padding='VALID', name="conv2d") +tf.get_default_graph().as_graph_def() +``` + +TensorFlow GraphDef: +``` +node { + name: "inp" + op: "Placeholder" + attr { + key: "dtype" + value { type: DT_FLOAT } + } + attr { + key: "shape" + value { + shape { + dim { size: 1 } + dim { size: 16 } + dim { size: 16 } + dim { size: 2 } + } + } + } +} +node { + name: "ker" + op: "Const" + attr { + key: "dtype" + value { type: DT_FLOAT } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { size: 7 } + dim { size: 1 } + dim { size: 2 } + dim { size: 4 } + } + float_val: 1.0 + } + } + } +} +node { + name: "conv2d" + op: "Conv2D" + input: "inp" + input: "ker" + attr { + key: "T" + value { type: DT_FLOAT } + } + attr { + key: "data_format" + value { s: "NHWC" } + } + attr { + key: "dilations" + value { list { i: 1 i: 1 i: 1 i: 1 } } + } + attr { + key: "padding" + value { s: "VALID" } + } + attr { + key: "strides" + value { list { i: 1 i: 1 i: 1 i: 1 } } + } +} +``` +#endif +TEST(TensorFlowTest, Conv2D_Float_000) +{ + auto g = loco::make_graph(); + + // The first "inp" node corresponds to "Pull" + auto inp = g->nodes()->create<loco::Pull>(); + { + inp->dtype(loco::DataType::FLOAT32); + inp->shape({1, 16, 16, 2}); + } + + // The seoncd "ker" node corresponds to "ConstGen" + auto ker = g->nodes()->create<loco::ConstGen>(); + { + ker->dtype(loco::DataType::FLOAT32); + // 'I' denotes IFM DEPTH, and 'O' denotes OFM DEPTH + ker->shape({7 /*H*/, 1 /*W*/, 2 /*I*/, 3 /*O*/}); + ker->size<loco::DataType::FLOAT32>(7 * 1 * 2 * 3); + for (uint32_t n = 0; n < 7 * 1 * 2 * 3; ++n) + { + // NOTE TensorFlow uses the last value to fill unspecified region + ker->at<loco::DataType::FLOAT32>(n) = 1.0f; + } + } + + // The next "conv2d" node is decomposed into the following loco nodes + // - "FeatureEncode" + // - "FilterEncode" + // - "Conv2D" + // - "FeatureDecode" + auto encoded_ifm = g->nodes()->create<loco::FeatureEncode>(); + { + // From "conv2d.data_format" attribute + auto encoder = make_unique<loco::PermutingEncoder<loco::Domain::Feature>>(); + encoder->perm(make_NHWC_permutation()); + + encoded_ifm->input(inp); + encoded_ifm->encoder(std::move(encoder)); + } + + auto encoded_ker = g->nodes()->create<loco::FilterEncode>(); + { + // From "tf.nn.conv2d" specification + auto encoder = make_unique<loco::PermutingEncoder<loco::Domain::Filter>>(); + encoder->perm(make_HWIO_permutation()); + + encoded_ker->input(ker); + encoded_ker->encoder(std::move(encoder)); + } + + auto conv2d = g->nodes()->create<loco::Conv2D>(); + { + conv2d->ifm(encoded_ifm); + conv2d->ker(encoded_ker); + + // From "stride" attribute + conv2d->stride()->horizontal(1); + conv2d->stride()->vertical(1); + } + + // "decoded_ofm" corresponds to the output of "conv2d" node. + auto decoded_ofm = g->nodes()->create<loco::FeatureDecode>(); + { + // From "conv2d.data_format" attribute + auto decoder = make_unique<loco::PermutingDecoder<loco::Domain::Feature>>(); + decoder->perm(make_NHWC_permutation()); + + decoded_ofm->input(conv2d); + decoded_ofm->decoder(std::move(decoder)); + } + + // Makr "conv2d" as a network-level output with Push + auto push = g->nodes()->create<loco::Push>(); + { + push->from(decoded_ofm); + } + + // + // Mark network-level input/output + // + auto input_0 = g->inputs()->create(); + loco::link(input_0, inp); + + auto output_0 = g->outputs()->create(); + loco::link(output_0, push); + + // NOTE This example SHOULD BE valid. + ASSERT_TRUE(loco::valid(g.get())); +} |