diff options
Diffstat (limited to 'compiler/luci/partition/src/Nodes')
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleAdd.cpp | 40 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleAdd.test.cpp | 100 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleConst.cpp | 27 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleDiv.cpp | 40 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleDiv.test.cpp | 100 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleMean.cpp | 41 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleMul.cpp | 40 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleMul.test.cpp | 100 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CirclePow.cpp | 40 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleRsqrt.cpp | 38 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleSqrt.cpp | 38 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp | 40 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleSub.cpp | 40 | ||||
-rw-r--r-- | compiler/luci/partition/src/Nodes/CircleSub.test.cpp | 100 |
14 files changed, 784 insertions, 0 deletions
diff --git a/compiler/luci/partition/src/Nodes/CircleAdd.cpp b/compiler/luci/partition/src/Nodes/CircleAdd.cpp new file mode 100644 index 000000000..d393997e9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAdd.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleAdd *node) +{ + auto *cloned = loco::must_cast<luci::CircleAdd *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleAdd *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp b/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp new file mode 100644 index 000000000..e457b83d2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAdd.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleAdd> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleAdd>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Add) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Add_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAdd *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleConst.cpp b/compiler/luci/partition/src/Nodes/CircleConst.cpp new file mode 100644 index 000000000..118cd8de2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleConst.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleConst *) +{ + // Nothing to do +} + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDiv.cpp b/compiler/luci/partition/src/Nodes/CircleDiv.cpp new file mode 100644 index 000000000..480338542 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDiv.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleDiv *node) +{ + auto *cloned = loco::must_cast<luci::CircleDiv *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleDiv *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp b/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp new file mode 100644 index 000000000..226932337 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDiv.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleDiv> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleDiv>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Div) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Div_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDiv *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMean.cpp b/compiler/luci/partition/src/Nodes/CircleMean.cpp new file mode 100644 index 000000000..b634e5838 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMean.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMean *node) +{ + auto *cloned = loco::must_cast<luci::CircleMean *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMean *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMul.cpp b/compiler/luci/partition/src/Nodes/CircleMul.cpp new file mode 100644 index 000000000..2cd2b4038 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMul.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMul *node) +{ + auto *cloned = loco::must_cast<luci::CircleMul *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMul *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMul.test.cpp b/compiler/luci/partition/src/Nodes/CircleMul.test.cpp new file mode 100644 index 000000000..99cf0824d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMul.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMul> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + NodeGraphletT<luci::CircleMul>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Mul) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Mul_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMul *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CirclePow.cpp b/compiler/luci/partition/src/Nodes/CirclePow.cpp new file mode 100644 index 000000000..fb180ee69 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePow.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CirclePow *node) +{ + auto *cloned = loco::must_cast<luci::CirclePow *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CirclePow *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp b/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp new file mode 100644 index 000000000..03e64aad0 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRsqrt.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRsqrt *node) +{ + auto *cloned = loco::must_cast<luci::CircleRsqrt *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRsqrt *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSqrt.cpp b/compiler/luci/partition/src/Nodes/CircleSqrt.cpp new file mode 100644 index 000000000..f737aac8d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSqrt.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSqrt *node) +{ + auto *cloned = loco::must_cast<luci::CircleSqrt *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSqrt *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp new file mode 100644 index 000000000..40dd31706 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSquaredDifference *node) +{ + auto *cloned = loco::must_cast<luci::CircleSquaredDifference *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSquaredDifference *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSub.cpp b/compiler/luci/partition/src/Nodes/CircleSub.cpp new file mode 100644 index 000000000..8ac294b7b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSub.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSub *node) +{ + auto *cloned = loco::must_cast<luci::CircleSub *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSub *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSub.test.cpp b/compiler/luci/partition/src/Nodes/CircleSub.test.cpp new file mode 100644 index 000000000..7c0d83745 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSub.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSub> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) + { + NodeGraphletT<luci::CircleSub>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Sub) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Sub_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSub *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} |