diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2021-08-23 13:25:15 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2021-08-23 13:25:15 +0900 |
commit | f4cf19e579a19c5346ccb2aad55bfd251065e447 (patch) | |
tree | 5d436b11f89be0e8a8289ea82b773da6402c1add /compiler/luci/partition | |
parent | 589bb1db6db6784efe21b3fbbfbfdb79aaa5f14e (diff) | |
download | nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.tar.gz nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.tar.bz2 nnfw-f4cf19e579a19c5346ccb2aad55bfd251065e447.zip |
Imported Upstream version 1.17.0upstream/1.17.0submit/tizen/20210823.054833submit/tizen/20210823.045832submit/tizen/20210823.044411accepted/tizen/unified/20210823.124210
Diffstat (limited to 'compiler/luci/partition')
254 files changed, 16561 insertions, 140 deletions
diff --git a/compiler/luci/partition/CMakeLists.txt b/compiler/luci/partition/CMakeLists.txt index 838642b6e..236b689c4 100644 --- a/compiler/luci/partition/CMakeLists.txt +++ b/compiler/luci/partition/CMakeLists.txt @@ -11,9 +11,12 @@ target_link_libraries(luci_partition PRIVATE luci_log) target_link_libraries(luci_partition PRIVATE luci_logex) target_link_libraries(luci_partition PRIVATE mio_circle) target_link_libraries(luci_partition PRIVATE nncc_common) +target_link_libraries(luci_partition PRIVATE pepper_csv2vec) target_link_libraries(luci_partition PRIVATE oops) install(TARGETS luci_partition DESTINATION lib) +install(DIRECTORY include/ DESTINATION include + FILES_MATCHING PATTERN "*.h") if(NOT ENABLE_TEST) return() diff --git a/compiler/luci/partition/include/luci/Partition.h b/compiler/luci/partition/include/luci/Partition.h index cf90e448b..6189ed9f2 100644 --- a/compiler/luci/partition/include/luci/Partition.h +++ b/compiler/luci/partition/include/luci/Partition.h @@ -32,13 +32,22 @@ namespace luci */ struct PartitionTable { + enum class COMPLY + { + UNDEFINED, + OPCODE, + OPNAME, + }; + std::vector<std::string> groups; std::string default_group; + COMPLY comply = COMPLY::UNDEFINED; // assign by opcode name: OPCODENAME=group std::unordered_map<std::string /* OPCODENAME */, std::string /* group */> byopcodes; - // TODO add assign by OP name + // assign by op name: OPNAME=group + std::unordered_map<std::string /* OPNAME */, std::string /* group */> byopnames; }; /** diff --git a/compiler/luci/partition/include/luci/PartitionDump.h b/compiler/luci/partition/include/luci/PartitionDump.h new file mode 100644 index 000000000..f395e57bf --- /dev/null +++ b/compiler/luci/partition/include/luci/PartitionDump.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_DUMP_H__ +#define __LUCI_PARTITION_DUMP_H__ + +#include "luci/Partition.h" + +#include <iostream> + +std::ostream &operator<<(std::ostream &os, const luci::PartitionTable &table); + +#endif // __LUCI_PARTITION_DUMP_H__ diff --git a/compiler/luci/partition/include/luci/PartitionValidate.h b/compiler/luci/partition/include/luci/PartitionValidate.h new file mode 100644 index 000000000..9f910c8cc --- /dev/null +++ b/compiler/luci/partition/include/luci/PartitionValidate.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PARTITION_VALIDATE_H__ +#define __LUCI_PARTITION_VALIDATE_H__ + +#include "luci/Partition.h" + +#include <luci/IR/Module.h> + +namespace luci +{ + +bool validate(luci::PartitionTable &partition); + +} // namespace luci + +#endif // __LUCI_PARTITION_VALIDATE_H__ diff --git a/compiler/luci/partition/src/ConnectNode.h b/compiler/luci/partition/src/ConnectNode.h index 017c587e5..ebbff7a6a 100644 --- a/compiler/luci/partition/src/ConnectNode.h +++ b/compiler/luci/partition/src/ConnectNode.h @@ -50,6 +50,9 @@ struct CloneContext MapNode2Clone::iterator find(const CircleNode *org) { return node2clone.find(org); } MapNode2Clone::iterator end(void) { return node2clone.end(); } + MapNode2Clone::const_iterator find(const CircleNode *org) const { return node2clone.find(org); } + MapNode2Clone::const_iterator end(void) const { return node2clone.end(); } + MapNode2Clone node2clone; }; @@ -59,138 +62,142 @@ public: ConnectNode(luci::CloneContext &clonecontext) : _clonecontext(clonecontext){}; public: - // void visit(const luci::CircleAbs *) final; + void visit(const luci::CircleAbs *) final; void visit(const luci::CircleAdd *) final; - // void visit(const luci::CircleAddN *) final; - // void visit(const luci::CircleArgMax *) final; - // void visit(const luci::CircleArgMin *) final; - // void visit(const luci::CircleAveragePool2D *) final; - // void visit(const luci::CircleBatchMatMul *) final; - // void visit(const luci::CircleBatchToSpaceND *) final; - // void visit(const luci::CircleCast *) final; - // void visit(const luci::CircleCeil *) final; - // void visit(const luci::CircleConcatenation *) final; + void visit(const luci::CircleAddN *) final; + void visit(const luci::CircleArgMax *) final; + void visit(const luci::CircleArgMin *) final; + void visit(const luci::CircleAveragePool2D *) final; + void visit(const luci::CircleBatchMatMul *) final; + void visit(const luci::CircleBatchToSpaceND *) final; + void visit(const luci::CircleCast *) final; + void visit(const luci::CircleCeil *) final; + void visit(const luci::CircleConcatenation *) final; void visit(const luci::CircleConst *) final; - // void visit(const luci::CircleConv2D *) final; - // void visit(const luci::CircleCos *) final; - // void visit(const luci::CircleCustom *) final; - // void visit(const luci::CircleDepthToSpace *) final; - // void visit(const luci::CircleDepthwiseConv2D *) final; - // void visit(const luci::CircleDequantize *) final; + void visit(const luci::CircleConv2D *) final; + void visit(const luci::CircleCos *) final; + void visit(const luci::CircleCustom *) final; + void visit(const luci::CircleDepthToSpace *) final; + void visit(const luci::CircleDepthwiseConv2D *) final; + void visit(const luci::CircleDequantize *) final; void visit(const luci::CircleDiv *) final; - // void visit(const luci::CircleElu *) final; - // void visit(const luci::CircleEqual *) final; - // void visit(const luci::CircleExp *) final; - // void visit(const luci::CircleExpandDims *) final; - // void visit(const luci::CircleFakeQuant *) final; - // void visit(const luci::CircleFill *) final; - // void visit(const luci::CircleFloor *) final; - // void visit(const luci::CircleFloorDiv *) final; - // void visit(const luci::CircleFloorMod *) final; - // void visit(const luci::CircleFullyConnected *) final; - // void visit(const luci::CircleGather *) final; - // void visit(const luci::CircleGatherNd *) final; - // void visit(const luci::CircleGreater *) final; - // void visit(const luci::CircleGreaterEqual *) final; - // void visit(const luci::CircleIf *) final; - // void visit(const luci::CircleL2Normalize *) final; - // void visit(const luci::CircleL2Pool2D *) final; - // void visit(const luci::CircleLeakyRelu *) final; - // void visit(const luci::CircleLess *) final; - // void visit(const luci::CircleLessEqual *) final; - // void visit(const luci::CircleLocalResponseNormalization *) final; - // void visit(const luci::CircleLog *) final; - // void visit(const luci::CircleLogicalAnd *) final; - // void visit(const luci::CircleLogicalNot *) final; - // void visit(const luci::CircleLogicalOr *) final; - // void visit(const luci::CircleLogistic *) final; - // void visit(const luci::CircleLogSoftmax *) final; - // void visit(const luci::CircleMatrixDiag *) final; - // void visit(const luci::CircleMatrixSetDiag *) final; - // void visit(const luci::CircleMaximum *) final; - // void visit(const luci::CircleMaxPool2D *) final; + void visit(const luci::CircleElu *) final; + void visit(const luci::CircleEqual *) final; + void visit(const luci::CircleExp *) final; + void visit(const luci::CircleExpandDims *) final; + void visit(const luci::CircleFakeQuant *) final; + void visit(const luci::CircleFill *) final; + void visit(const luci::CircleFloor *) final; + void visit(const luci::CircleFloorDiv *) final; + void visit(const luci::CircleFloorMod *) final; + void visit(const luci::CircleFullyConnected *) final; + void visit(const luci::CircleGather *) final; + void visit(const luci::CircleGatherNd *) final; + void visit(const luci::CircleGreater *) final; + void visit(const luci::CircleGreaterEqual *) final; + void visit(const luci::CircleIf *) final; + void visit(const luci::CircleL2Normalize *) final; + void visit(const luci::CircleL2Pool2D *) final; + void visit(const luci::CircleLeakyRelu *) final; + void visit(const luci::CircleLess *) final; + void visit(const luci::CircleLessEqual *) final; + void visit(const luci::CircleLocalResponseNormalization *) final; + void visit(const luci::CircleLog *) final; + void visit(const luci::CircleLogicalAnd *) final; + void visit(const luci::CircleLogicalNot *) final; + void visit(const luci::CircleLogicalOr *) final; + void visit(const luci::CircleLogistic *) final; + void visit(const luci::CircleLogSoftmax *) final; + void visit(const luci::CircleMatrixDiag *) final; + void visit(const luci::CircleMatrixSetDiag *) final; + void visit(const luci::CircleMaximum *) final; + void visit(const luci::CircleMaxPool2D *) final; void visit(const luci::CircleMean *) final; - // void visit(const luci::CircleMinimum *) final; - // void visit(const luci::CircleMirrorPad *) final; + void visit(const luci::CircleMinimum *) final; + void visit(const luci::CircleMirrorPad *) final; void visit(const luci::CircleMul *) final; - // void visit(const luci::CircleNeg *) final; - // void visit(const luci::CircleNonMaxSuppressionV4 *) final; - // void visit(const luci::CircleNonMaxSuppressionV5 *) final; - // void visit(const luci::CircleNotEqual *) final; - // void visit(const luci::CircleOneHot *) final; - // void visit(const luci::CirclePack *) final; - // void visit(const luci::CirclePad *) final; - // void visit(const luci::CirclePadV2 *) final; + void visit(const luci::CircleNeg *) final; + void visit(const luci::CircleNonMaxSuppressionV4 *) final; + void visit(const luci::CircleNonMaxSuppressionV5 *) final; + void visit(const luci::CircleNotEqual *) final; + void visit(const luci::CircleOneHot *) final; + void visit(const luci::CirclePack *) final; + void visit(const luci::CirclePad *) final; + void visit(const luci::CirclePadV2 *) final; void visit(const luci::CirclePow *) final; - // void visit(const luci::CirclePRelu *) final; - // void visit(const luci::CircleRange *) final; - // void visit(const luci::CircleRank *) final; - // void visit(const luci::CircleReduceAny *) final; - // void visit(const luci::CircleReduceMax *) final; - // void visit(const luci::CircleReduceMin *) final; - // void visit(const luci::CircleReduceProd *) final; - // void visit(const luci::CircleRelu *) final; - // void visit(const luci::CircleRelu6 *) final; - // void visit(const luci::CircleReluN1To1 *) final; - // void visit(const luci::CircleReshape *) final; - // void visit(const luci::CircleResizeBilinear *) final; - // void visit(const luci::CircleResizeNearestNeighbor *) final; - // void visit(const luci::CircleReverseSequence *) final; - // void visit(const luci::CircleReverseV2 *) final; - // void visit(const luci::CircleRound *) final; + void visit(const luci::CirclePRelu *) final; + void visit(const luci::CircleQuantize *) final; + void visit(const luci::CircleRange *) final; + void visit(const luci::CircleRank *) final; + void visit(const luci::CircleReduceAny *) final; + void visit(const luci::CircleReduceMax *) final; + void visit(const luci::CircleReduceMin *) final; + void visit(const luci::CircleReduceProd *) final; + void visit(const luci::CircleRelu *) final; + void visit(const luci::CircleRelu6 *) final; + void visit(const luci::CircleReluN1To1 *) final; + void visit(const luci::CircleReshape *) final; + void visit(const luci::CircleResizeBilinear *) final; + void visit(const luci::CircleResizeNearestNeighbor *) final; + void visit(const luci::CircleReverseSequence *) final; + void visit(const luci::CircleReverseV2 *) final; + void visit(const luci::CircleRound *) final; void visit(const luci::CircleRsqrt *) final; - // void visit(const luci::CircleScatterNd *) final; - // void visit(const luci::CircleSegmentSum *) final; - // void visit(const luci::CircleSelect *) final; - // void visit(const luci::CircleSelectV2 *) final; - // void visit(const luci::CircleShape *) final; - // void visit(const luci::CircleSin *) final; - // void visit(const luci::CircleSlice *) final; - // void visit(const luci::CircleSoftmax *) final; - // void visit(const luci::CircleSpaceToBatchND *) final; - // void visit(const luci::CircleSpaceToDepth *) final; - // void visit(const luci::CircleSparseToDense *) final; - // void visit(const luci::CircleSplit *) final; - // void visit(const luci::CircleSplitV *) final; + void visit(const luci::CircleScatterNd *) final; + void visit(const luci::CircleSegmentSum *) final; + void visit(const luci::CircleSelect *) final; + void visit(const luci::CircleSelectV2 *) final; + void visit(const luci::CircleShape *) final; + void visit(const luci::CircleSin *) final; + void visit(const luci::CircleSlice *) final; + void visit(const luci::CircleSoftmax *) final; + void visit(const luci::CircleSpaceToBatchND *) final; + void visit(const luci::CircleSpaceToDepth *) final; + void visit(const luci::CircleSparseToDense *) final; + void visit(const luci::CircleSplit *) final; + void visit(const luci::CircleSplitV *) final; void visit(const luci::CircleSqrt *) final; - // void visit(const luci::CircleSquare *) final; + void visit(const luci::CircleSquare *) final; void visit(const luci::CircleSquaredDifference *) final; - // void visit(const luci::CircleSqueeze *) final; - // void visit(const luci::CircleStridedSlice *) final; + void visit(const luci::CircleSqueeze *) final; + void visit(const luci::CircleStridedSlice *) final; void visit(const luci::CircleSub *) final; - // void visit(const luci::CircleSum *) final; - // void visit(const luci::CircleTanh *) final; - // void visit(const luci::CircleTile *) final; - // void visit(const luci::CircleTopKV2 *) final; - // void visit(const luci::CircleTranspose *) final; - // void visit(const luci::CircleTransposeConv *) final; - // void visit(const luci::CircleUnidirectionalSequenceLSTM *) final; - // void visit(const luci::CircleUnique *) final; - // void visit(const luci::CircleUnpack *) final; - // void visit(const luci::CircleWhere *) final; - // void visit(const luci::CircleWhile *) final; - // void visit(const luci::CircleZerosLike *) final; + void visit(const luci::CircleSum *) final; + void visit(const luci::CircleTanh *) final; + void visit(const luci::CircleTile *) final; + void visit(const luci::CircleTopKV2 *) final; + void visit(const luci::CircleTranspose *) final; + void visit(const luci::CircleTransposeConv *) final; + void visit(const luci::CircleUnidirectionalSequenceLSTM *) final; + void visit(const luci::CircleUnique *) final; + void visit(const luci::CircleUnpack *) final; + void visit(const luci::CircleWhere *) final; + void visit(const luci::CircleWhile *) final; + void visit(const luci::CircleZerosLike *) final; // Circle Only - // void visit(const luci::CircleBCQFullyConnected *) final; - // void visit(const luci::CircleBCQGather *) final; - // void visit(const luci::CircleInstanceNorm *) final; + void visit(const luci::CircleBCQFullyConnected *) final; + void visit(const luci::CircleBCQGather *) final; + void visit(const luci::CircleInstanceNorm *) final; + + // NOTE CircleInput and CircleOutput are not handled here as these need + // link with graph I/O // Virtual - // void visit(const luci::CircleCustomOut *) final; - // void visit(const luci::CircleIfOut *) final; + void visit(const luci::CircleCustomOut *) final; + void visit(const luci::CircleIfOut *) final; // void visit(const luci::CircleInput *) final; - // void visit(const luci::CircleNonMaxSuppressionV4Out *) final; - // void visit(const luci::CircleNonMaxSuppressionV5Out *) final; + void visit(const luci::CircleNonMaxSuppressionV4Out *) final; + void visit(const luci::CircleNonMaxSuppressionV5Out *) final; // void visit(const luci::CircleOutput *) final; - // void visit(const luci::CircleOutputDummy *) final; - // void visit(const luci::CircleOutputExclude *) final; - // void visit(const luci::CircleSplitOut *) final; - // void visit(const luci::CircleSplitVOut *) final; - // void visit(const luci::CircleTopKV2Out *) final; - // void visit(const luci::CircleUniqueOut *) final; - // void visit(const luci::CircleUnpackOut *) final; - // void visit(const luci::CircleWhileOut *) final; + void visit(const luci::CircleOutputDummy *) final; + void visit(const luci::CircleOutputExclude *) final; + void visit(const luci::CircleSplitOut *) final; + void visit(const luci::CircleSplitVOut *) final; + void visit(const luci::CircleTopKV2Out *) final; + void visit(const luci::CircleUniqueOut *) final; + void visit(const luci::CircleUnpackOut *) final; + void visit(const luci::CircleWhileOut *) final; public: luci::CircleNode *find_clone(const luci::CircleNode *node); diff --git a/compiler/luci/partition/src/ConnectNode.test.h b/compiler/luci/partition/src/ConnectNode.test.h index f7333ff99..ac4878a15 100644 --- a/compiler/luci/partition/src/ConnectNode.test.h +++ b/compiler/luci/partition/src/ConnectNode.test.h @@ -45,8 +45,9 @@ public: if (shape_in.size() != N) throw std::runtime_error("Failed to init TestIsOGraph"); - TestIsGraphlet<N>::init(TestIsGraphlet<N>::g(), shape_in); - TestOGraphlet::init(TestIsGraphlet<N>::g(), shape_out); + auto g = TestIsGraphlet<N>::g(); + TestIsGraphlet<N>::init(g, shape_in); + TestOGraphlet::init(g, shape_out); } }; @@ -82,6 +83,43 @@ protected: T *_node{nullptr}; }; +template <class T> class NodeIsOsGraphletT +{ +public: + virtual void init(loco::Graph *g, uint32_t n, uint32_t m) + { + _node = g->nodes()->create<T>(n, m); + _node->dtype(loco::DataType::S32); + _node->name("node"); + } + + T *node(void) const { return _node; } + +protected: + T *_node{nullptr}; +}; + +template <unsigned N, unsigned M> +class TestIsOsGraph : public TestIsGraphlet<N>, public TestOsGraphlet<M> +{ +public: + TestIsOsGraph() = default; + +public: + virtual void init(const std::initializer_list<ShapeU32> shape_in, + const std::initializer_list<ShapeU32> shape_out) + { + if (shape_in.size() != N) + throw std::runtime_error("Failed to init TestIsOsGraph"); + if (shape_out.size() != M) + throw std::runtime_error("Failed to init TestIsOsGraph"); + + auto g = TestIsGraphlet<N>::g(); + TestIsGraphlet<N>::init(g, shape_in); + TestOsGraphlet<M>::init(g, shape_out); + } +}; + /** * @brief ConnectionTestHelper provides common framework for testing * cloned CircleNode connection @@ -105,6 +143,33 @@ public: } } + template <unsigned N, unsigned M> void prepare_inputs(TestIsOsGraph<N, M> *isosgraph) + { + assert(N == isosgraph->num_inputs()); + assert(M == isosgraph->num_outputs()); + + for (uint32_t i = 0; i < N; ++i) + { + auto *input = _graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(isosgraph->input(i), input); + _clonectx.emplace(isosgraph->input(i), input); + _inputs.push_back(input); + } + } + + /** + * @note although there is only one input, method name has 's' to make test simple + */ + void prepare_inputs(TestIOGraph *isograph) + { + assert(1 == isograph->num_inputs()); + + auto *input = _graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(isograph->input(), input); + _clonectx.emplace(isograph->input(), input); + _inputs.push_back(input); + } + /** * @note prepare_inputs_miss is for negative testing */ @@ -122,6 +187,31 @@ public: } } + template <unsigned N, unsigned M> void prepare_inputs_miss(TestIsOsGraph<N, M> *isograph) + { + assert(N == isograph->num_inputs()); + assert(M == isograph->num_outputs()); + + for (uint32_t i = 0; i < N; ++i) + { + auto *input = _graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(isograph->input(i), input); + if (i != 0) + _clonectx.emplace(isograph->input(i), input); + _inputs.push_back(input); + } + } + + void prepare_inputs_miss(TestIOGraph *isograph) + { + assert(1 == isograph->num_inputs()); + + auto *input = _graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(isograph->input(), input); + // _clonectx.emplace() is NOT called on purpose + _inputs.push_back(input); + } + void clone_connect(luci::CircleNode *node, luci::CircleNode *clone) { _clonectx.emplace(node, clone); diff --git a/compiler/luci/partition/src/Nodes/CircleAbs.cpp b/compiler/luci/partition/src/Nodes/CircleAbs.cpp new file mode 100644 index 000000000..a3fde4c45 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAbs.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +/** + * @note This method and all other connect() are just to reduce LOC of ConnectNode class + */ +void connect(luci::ConnectNode *cn, const luci::CircleAbs *node) +{ + auto *cloned = loco::must_cast<luci::CircleAbs *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleAbs *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleAbs.test.cpp b/compiler/luci/partition/src/Nodes/CircleAbs.test.cpp new file mode 100644 index 000000000..f3e721525 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAbs.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleAbs> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Abs) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Abs_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAbs *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleAddN.cpp b/compiler/luci/partition/src/Nodes/CircleAddN.cpp new file mode 100644 index 000000000..81e5e0949 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAddN.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleAddN *node) +{ + auto *cloned = loco::must_cast<luci::CircleAddN *>(cn->find_clone(node)); + + uint32_t num_inputs = cloned->arity(); + for (uint32_t i = 0; i < num_inputs; ++i) + { + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->inputs(i)); + + cloned->inputs(i, cn->find_clone(input)); + } +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleAddN *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleAddN.test.cpp b/compiler/luci/partition/src/Nodes/CircleAddN.test.cpp new file mode 100644 index 000000000..5d0a7489f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAddN.test.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeIsGraphletT<luci::CircleAddN> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g(), 3); + + for (uint32_t i = 0; i < 3; ++i) + { + node()->inputs(i, input(i)); + } + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_AddN) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_AddN_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAddN *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleArgMax.cpp b/compiler/luci/partition/src/Nodes/CircleArgMax.cpp new file mode 100644 index 000000000..1409586d7 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleArgMax.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleArgMax *node) +{ + auto *cloned = loco::must_cast<luci::CircleArgMax *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *dimension = loco::must_cast<luci::CircleNode *>(node->dimension()); + + cloned->input(cn->find_clone(input)); + cloned->dimension(cn->find_clone(dimension)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleArgMax *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleArgMax.test.cpp b/compiler/luci/partition/src/Nodes/CircleArgMax.test.cpp new file mode 100644 index 000000000..c816fbeb8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleArgMax.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleArgMax> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->dimension(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ArgMax) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ArgMax_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMax *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleArgMin.cpp b/compiler/luci/partition/src/Nodes/CircleArgMin.cpp new file mode 100644 index 000000000..6151aa98a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleArgMin.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleArgMin *node) +{ + auto *cloned = loco::must_cast<luci::CircleArgMin *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *dimension = loco::must_cast<luci::CircleNode *>(node->dimension()); + + cloned->input(cn->find_clone(input)); + cloned->dimension(cn->find_clone(dimension)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleArgMin *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleArgMin.test.cpp b/compiler/luci/partition/src/Nodes/CircleArgMin.test.cpp new file mode 100644 index 000000000..d150be4d6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleArgMin.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleArgMin> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->dimension(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ArgMin) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ArgMin_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleArgMin *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleAveragePool2D.cpp b/compiler/luci/partition/src/Nodes/CircleAveragePool2D.cpp new file mode 100644 index 000000000..547665771 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAveragePool2D.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleAveragePool2D *node) +{ + auto *cloned = loco::must_cast<luci::CircleAveragePool2D *>(cn->find_clone(node)); + + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value()); + + cloned->value(cn->find_clone(value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleAveragePool2D *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleAveragePool2D.test.cpp b/compiler/luci/partition/src/Nodes/CircleAveragePool2D.test.cpp new file mode 100644 index 000000000..fba2be835 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleAveragePool2D.test.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleAveragePool2D> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleAveragePool2D>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + _node->padding(luci::Padding::VALID); + } +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->value(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_AveragePool2D) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_AveragePool2D_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleAveragePool2D *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleBCQFullyConnected.cpp b/compiler/luci/partition/src/Nodes/CircleBCQFullyConnected.cpp new file mode 100644 index 000000000..5b1dd8543 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBCQFullyConnected.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleBCQFullyConnected *node) +{ + auto *cloned = loco::must_cast<luci::CircleBCQFullyConnected *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *weights_scales = loco::must_cast<luci::CircleNode *>(node->weights_scales()); + luci::CircleNode *weights_binary = loco::must_cast<luci::CircleNode *>(node->weights_binary()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + luci::CircleNode *weights_clusters = + loco::must_cast<luci::CircleNode *>(node->weights_clusters()); + + cloned->input(cn->find_clone(input)); + cloned->weights_scales(cn->find_clone(weights_scales)); + cloned->weights_binary(cn->find_clone(weights_binary)); + cloned->bias(cn->find_clone(bias)); + cloned->weights_clusters(cn->find_clone(weights_clusters)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleBCQFullyConnected *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleBCQFullyConnected.test.cpp b/compiler/luci/partition/src/Nodes/CircleBCQFullyConnected.test.cpp new file mode 100644 index 000000000..3d64f4b29 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBCQFullyConnected.test.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleBCQFullyConnected> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleBCQFullyConnected>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->weights_scales(input(1)); + node()->weights_binary(input(2)); + node()->bias(input(3)); + node()->weights_clusters(input(4)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_BCQFullyConnected) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(5, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); + ASSERT_EQ(cth.inputs(4), clone->arg(4)); +} + +TEST(ConnectNodeTest, connect_BCQFullyConnected_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQFullyConnected *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleBCQGather.cpp b/compiler/luci/partition/src/Nodes/CircleBCQGather.cpp new file mode 100644 index 000000000..90c4d9ef3 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBCQGather.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleBCQGather *node) +{ + auto *cloned = loco::must_cast<luci::CircleBCQGather *>(cn->find_clone(node)); + + luci::CircleNode *input_scales = loco::must_cast<luci::CircleNode *>(node->input_scales()); + luci::CircleNode *input_binary = loco::must_cast<luci::CircleNode *>(node->input_binary()); + luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices()); + luci::CircleNode *input_clusters = loco::must_cast<luci::CircleNode *>(node->input_clusters()); + + cloned->input_scales(cn->find_clone(input_scales)); + cloned->input_binary(cn->find_clone(input_binary)); + cloned->indices(cn->find_clone(indices)); + cloned->input_clusters(cn->find_clone(input_clusters)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleBCQGather *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleBCQGather.test.cpp b/compiler/luci/partition/src/Nodes/CircleBCQGather.test.cpp new file mode 100644 index 000000000..bbbd3f157 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBCQGather.test.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleBCQGather> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<4>::init({shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input_scales(input(0)); + node()->input_binary(input(1)); + node()->indices(input(2)); + node()->input_clusters(input(3)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_BCQGather) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(4, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); +} + +TEST(ConnectNodeTest, connect_BCQGather_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBCQGather *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleBatchMatMul.cpp b/compiler/luci/partition/src/Nodes/CircleBatchMatMul.cpp new file mode 100644 index 000000000..c3992a64e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBatchMatMul.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleBatchMatMul *node) +{ + auto *cloned = loco::must_cast<luci::CircleBatchMatMul *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleBatchMatMul *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleBatchMatMul.test.cpp b/compiler/luci/partition/src/Nodes/CircleBatchMatMul.test.cpp new file mode 100644 index 000000000..94336d36a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBatchMatMul.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleBatchMatMul> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_BatchMatMul) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_BatchMatMul_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchMatMul *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleBatchToSpaceND.cpp b/compiler/luci/partition/src/Nodes/CircleBatchToSpaceND.cpp new file mode 100644 index 000000000..2a463afb1 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBatchToSpaceND.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleBatchToSpaceND *node) +{ + auto *cloned = loco::must_cast<luci::CircleBatchToSpaceND *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *block_shape = loco::must_cast<luci::CircleNode *>(node->block_shape()); + luci::CircleNode *crops = loco::must_cast<luci::CircleNode *>(node->crops()); + + cloned->input(cn->find_clone(input)); + cloned->block_shape(cn->find_clone(block_shape)); + cloned->crops(cn->find_clone(crops)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleBatchToSpaceND *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleBatchToSpaceND.test.cpp b/compiler/luci/partition/src/Nodes/CircleBatchToSpaceND.test.cpp new file mode 100644 index 000000000..544f5e127 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleBatchToSpaceND.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleBatchToSpaceND> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->block_shape(input(1)); + node()->crops(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_BatchToSpaceND) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_BatchToSpaceND_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleBatchToSpaceND *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleCast.cpp b/compiler/luci/partition/src/Nodes/CircleCast.cpp new file mode 100644 index 000000000..f7630cd85 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCast.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleCast *node) +{ + auto *cloned = loco::must_cast<luci::CircleCast *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleCast *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleCast.test.cpp b/compiler/luci/partition/src/Nodes/CircleCast.test.cpp new file mode 100644 index 000000000..005119060 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCast.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleCast> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Cast) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Cast_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCast *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleCeil.cpp b/compiler/luci/partition/src/Nodes/CircleCeil.cpp new file mode 100644 index 000000000..a0c94033e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCeil.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleCeil *node) +{ + auto *cloned = loco::must_cast<luci::CircleCeil *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleCeil *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleCeil.test.cpp b/compiler/luci/partition/src/Nodes/CircleCeil.test.cpp new file mode 100644 index 000000000..dbd7e5390 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCeil.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleCeil> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Ceil) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Ceil_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCeil *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleConcatenation.cpp b/compiler/luci/partition/src/Nodes/CircleConcatenation.cpp new file mode 100644 index 000000000..fb24d21ca --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleConcatenation.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleConcatenation *node) +{ + auto *cloned = loco::must_cast<luci::CircleConcatenation *>(cn->find_clone(node)); + + uint32_t num_inputs = cloned->numValues(); + for (uint32_t i = 0; i < num_inputs; ++i) + { + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->values(i)); + + cloned->values(i, cn->find_clone(value)); + } +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleConcatenation *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleConcatenation.test.cpp b/compiler/luci/partition/src/Nodes/CircleConcatenation.test.cpp new file mode 100644 index 000000000..4d64b85a2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleConcatenation.test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeIsGraphletT<luci::CircleConcatenation> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g, uint32_t n) override + { + NodeIsGraphletT<luci::CircleConcatenation>::init(g, n); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g(), 3); + + for (uint32_t i = 0; i < 3; ++i) + { + node()->values(i, input(i)); + } + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Concatenation) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Concatenation_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConcatenation *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleConv2D.cpp b/compiler/luci/partition/src/Nodes/CircleConv2D.cpp new file mode 100644 index 000000000..46716f0ec --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleConv2D.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleConv2D *node) +{ + auto *cloned = loco::must_cast<luci::CircleConv2D *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *filter = loco::must_cast<luci::CircleNode *>(node->filter()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + + cloned->input(cn->find_clone(input)); + cloned->filter(cn->find_clone(filter)); + cloned->bias(cn->find_clone(bias)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleConv2D *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleConv2D.test.cpp b/compiler/luci/partition/src/Nodes/CircleConv2D.test.cpp new file mode 100644 index 000000000..829adec9b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleConv2D.test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleConv2D> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleConv2D>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + _node->padding(luci::Padding::VALID); + } +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->filter(input(1)); + node()->bias(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Conv2D) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Conv2D_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleConv2D *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleCos.cpp b/compiler/luci/partition/src/Nodes/CircleCos.cpp new file mode 100644 index 000000000..9dcf81e83 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCos.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleCos *node) +{ + auto *cloned = loco::must_cast<luci::CircleCos *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleCos *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleCos.test.cpp b/compiler/luci/partition/src/Nodes/CircleCos.test.cpp new file mode 100644 index 000000000..6c92b93fb --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCos.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleCos> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Cos) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Cos_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCos *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleCustom.cpp b/compiler/luci/partition/src/Nodes/CircleCustom.cpp new file mode 100644 index 000000000..ac16ebe40 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCustom.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleCustom *node) +{ + auto *cloned = loco::must_cast<luci::CircleCustom *>(cn->find_clone(node)); + + uint32_t numInputs = cloned->numInputs(); + for (uint32_t i = 0; i < numInputs; ++i) + { + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->inputs(i)); + + cloned->inputs(i, cn->find_clone(input)); + } +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleCustom *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleCustom.test.cpp b/compiler/luci/partition/src/Nodes/CircleCustom.test.cpp new file mode 100644 index 000000000..9f40b5220 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCustom.test.cpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +/** + * @note Does not use template like others as only Custom have both multiple in/out + */ +class NodeGraphlet +{ +public: + NodeGraphlet() = default; + +public: + virtual void init(loco::Graph *g, uint32_t in, uint32_t out) + { + _node = g->nodes()->create<luci::CircleCustom>(in, out); + _node->dtype(loco::DataType::S32); + _node->name("node"); + } + + luci::CircleCustom *node(void) const { return _node; } + +protected: + luci::CircleCustom *_node = nullptr; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g(), 3, 3); + + for (uint32_t i = 0; i < 3; ++i) + { + node()->inputs(i, input(i)); + } + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Custom) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Custom_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustom *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleCustomOut.cpp b/compiler/luci/partition/src/Nodes/CircleCustomOut.cpp new file mode 100644 index 000000000..fee1a1a8c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCustomOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleCustomOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleCustomOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleCustomOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleCustomOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleCustomOut.test.cpp new file mode 100644 index 000000000..0a293970e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleCustomOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleCustomOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_CustomOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_CustomOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleCustomOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleDepthToSpace.cpp b/compiler/luci/partition/src/Nodes/CircleDepthToSpace.cpp new file mode 100644 index 000000000..ade266e41 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDepthToSpace.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleDepthToSpace *node) +{ + auto *cloned = loco::must_cast<luci::CircleDepthToSpace *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleDepthToSpace *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDepthToSpace.test.cpp b/compiler/luci/partition/src/Nodes/CircleDepthToSpace.test.cpp new file mode 100644 index 000000000..997360a9b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDepthToSpace.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleDepthToSpace> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_DepthToSpace) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_DepthToSpace_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthToSpace *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleDepthwiseConv2D.cpp b/compiler/luci/partition/src/Nodes/CircleDepthwiseConv2D.cpp new file mode 100644 index 000000000..19d1d5f42 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDepthwiseConv2D.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleDepthwiseConv2D *node) +{ + auto *cloned = loco::must_cast<luci::CircleDepthwiseConv2D *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *filter = loco::must_cast<luci::CircleNode *>(node->filter()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + + cloned->input(cn->find_clone(input)); + cloned->filter(cn->find_clone(filter)); + cloned->bias(cn->find_clone(bias)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleDepthwiseConv2D *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDepthwiseConv2D.test.cpp b/compiler/luci/partition/src/Nodes/CircleDepthwiseConv2D.test.cpp new file mode 100644 index 000000000..681f98bdb --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDepthwiseConv2D.test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleDepthwiseConv2D> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleDepthwiseConv2D>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + _node->padding(luci::Padding::VALID); + } +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->filter(input(1)); + node()->bias(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_DepthwiseConv2D) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_DepthwiseConv2D_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDepthwiseConv2D *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleDequantize.cpp b/compiler/luci/partition/src/Nodes/CircleDequantize.cpp new file mode 100644 index 000000000..3a520d4e9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDequantize.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleDequantize *node) +{ + auto *cloned = loco::must_cast<luci::CircleDequantize *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleDequantize *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleDequantize.test.cpp b/compiler/luci/partition/src/Nodes/CircleDequantize.test.cpp new file mode 100644 index 000000000..7f6006c1d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleDequantize.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleDequantize> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Dequantize) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Dequantize_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleDequantize *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleElu.cpp b/compiler/luci/partition/src/Nodes/CircleElu.cpp new file mode 100644 index 000000000..d21cd4c01 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleElu.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleElu *node) +{ + auto *cloned = loco::must_cast<luci::CircleElu *>(cn->find_clone(node)); + + luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features()); + + cloned->features(cn->find_clone(features)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleElu *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleElu.test.cpp b/compiler/luci/partition/src/Nodes/CircleElu.test.cpp new file mode 100644 index 000000000..94774cca8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleElu.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleElu> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->features(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Elu) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Elu_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleElu *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleEqual.cpp b/compiler/luci/partition/src/Nodes/CircleEqual.cpp new file mode 100644 index 000000000..6a126c0e2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleEqual.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleEqual *node) +{ + auto *cloned = loco::must_cast<luci::CircleEqual *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleEqual *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleEqual.test.cpp b/compiler/luci/partition/src/Nodes/CircleEqual.test.cpp new file mode 100644 index 000000000..20b539199 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleEqual.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleEqual> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Equal) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Equal_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleEqual *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleExp.cpp b/compiler/luci/partition/src/Nodes/CircleExp.cpp new file mode 100644 index 000000000..95fb1cd67 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleExp.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleExp *node) +{ + auto *cloned = loco::must_cast<luci::CircleExp *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleExp *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleExp.test.cpp b/compiler/luci/partition/src/Nodes/CircleExp.test.cpp new file mode 100644 index 000000000..16d7244ab --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleExp.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleExp> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Exp) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Exp_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExp *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleExpandDims.cpp b/compiler/luci/partition/src/Nodes/CircleExpandDims.cpp new file mode 100644 index 000000000..6fccd6310 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleExpandDims.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleExpandDims *node) +{ + auto *cloned = loco::must_cast<luci::CircleExpandDims *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *axis = loco::must_cast<luci::CircleNode *>(node->axis()); + + cloned->input(cn->find_clone(input)); + cloned->axis(cn->find_clone(axis)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleExpandDims *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleExpandDims.test.cpp b/compiler/luci/partition/src/Nodes/CircleExpandDims.test.cpp new file mode 100644 index 000000000..8a5156509 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleExpandDims.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleExpandDims> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->axis(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ExpandDims) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ExpandDims_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleExpandDims *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleFakeQuant.cpp b/compiler/luci/partition/src/Nodes/CircleFakeQuant.cpp new file mode 100644 index 000000000..4855d80ae --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFakeQuant.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleFakeQuant *node) +{ + auto *cloned = loco::must_cast<luci::CircleFakeQuant *>(cn->find_clone(node)); + + luci::CircleNode *inputs = loco::must_cast<luci::CircleNode *>(node->inputs()); + + cloned->inputs(cn->find_clone(inputs)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleFakeQuant *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleFakeQuant.test.cpp b/compiler/luci/partition/src/Nodes/CircleFakeQuant.test.cpp new file mode 100644 index 000000000..3821d755a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFakeQuant.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleFakeQuant> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->inputs(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_FakeQuant) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_FakeQuant_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFakeQuant *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleFill.cpp b/compiler/luci/partition/src/Nodes/CircleFill.cpp new file mode 100644 index 000000000..06fca7b41 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFill.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleFill *node) +{ + auto *cloned = loco::must_cast<luci::CircleFill *>(cn->find_clone(node)); + + luci::CircleNode *dims = loco::must_cast<luci::CircleNode *>(node->dims()); + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value()); + + cloned->dims(cn->find_clone(dims)); + cloned->value(cn->find_clone(value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleFill *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleFill.test.cpp b/compiler/luci/partition/src/Nodes/CircleFill.test.cpp new file mode 100644 index 000000000..97a5a348d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFill.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleFill> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->dims(input(0)); + node()->value(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Fill) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Fill_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFill *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleFloor.cpp b/compiler/luci/partition/src/Nodes/CircleFloor.cpp new file mode 100644 index 000000000..7ad392461 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFloor.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleFloor *node) +{ + auto *cloned = loco::must_cast<luci::CircleFloor *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleFloor *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleFloor.test.cpp b/compiler/luci/partition/src/Nodes/CircleFloor.test.cpp new file mode 100644 index 000000000..1a964ea21 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFloor.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleFloor> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Floor) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Floor_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloor *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleFloorDiv.cpp b/compiler/luci/partition/src/Nodes/CircleFloorDiv.cpp new file mode 100644 index 000000000..3b92b00c6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFloorDiv.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleFloorDiv *node) +{ + auto *cloned = loco::must_cast<luci::CircleFloorDiv *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleFloorDiv *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleFloorDiv.test.cpp b/compiler/luci/partition/src/Nodes/CircleFloorDiv.test.cpp new file mode 100644 index 000000000..3d2801566 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFloorDiv.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleFloorDiv> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_FloorDiv) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_FloorDiv_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorDiv *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleFloorMod.cpp b/compiler/luci/partition/src/Nodes/CircleFloorMod.cpp new file mode 100644 index 000000000..9f868d0e5 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFloorMod.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleFloorMod *node) +{ + auto *cloned = loco::must_cast<luci::CircleFloorMod *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleFloorMod *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleFloorMod.test.cpp b/compiler/luci/partition/src/Nodes/CircleFloorMod.test.cpp new file mode 100644 index 000000000..89a09411b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFloorMod.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleFloorMod> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_FloorMod) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_FloorMod_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFloorMod *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/partition/src/Nodes/CircleFullyConnected.cpp new file mode 100644 index 000000000..da273037a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFullyConnected.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleFullyConnected *node) +{ + auto *cloned = loco::must_cast<luci::CircleFullyConnected *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *weights = loco::must_cast<luci::CircleNode *>(node->weights()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + + cloned->input(cn->find_clone(input)); + cloned->weights(cn->find_clone(weights)); + cloned->bias(cn->find_clone(bias)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleFullyConnected *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleFullyConnected.test.cpp b/compiler/luci/partition/src/Nodes/CircleFullyConnected.test.cpp new file mode 100644 index 000000000..fc88204bd --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleFullyConnected.test.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleFullyConnected> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleFullyConnected>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + _node->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT); + } +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->weights(input(1)); + node()->bias(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_FullyConnected) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_FullyConnected_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleFullyConnected *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleGather.cpp b/compiler/luci/partition/src/Nodes/CircleGather.cpp new file mode 100644 index 000000000..0ee458394 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGather.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleGather *node) +{ + auto *cloned = loco::must_cast<luci::CircleGather *>(cn->find_clone(node)); + + luci::CircleNode *params = loco::must_cast<luci::CircleNode *>(node->params()); + luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices()); + + cloned->params(cn->find_clone(params)); + cloned->indices(cn->find_clone(indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleGather *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGather.test.cpp b/compiler/luci/partition/src/Nodes/CircleGather.test.cpp new file mode 100644 index 000000000..7f4e08435 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGather.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleGather> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->params(input(0)); + node()->indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Gather) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Gather_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGather *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleGatherNd.cpp b/compiler/luci/partition/src/Nodes/CircleGatherNd.cpp new file mode 100644 index 000000000..4be05ca94 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGatherNd.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleGatherNd *node) +{ + auto *cloned = loco::must_cast<luci::CircleGatherNd *>(cn->find_clone(node)); + + luci::CircleNode *params = loco::must_cast<luci::CircleNode *>(node->params()); + luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices()); + + cloned->params(cn->find_clone(params)); + cloned->indices(cn->find_clone(indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleGatherNd *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGatherNd.test.cpp b/compiler/luci/partition/src/Nodes/CircleGatherNd.test.cpp new file mode 100644 index 000000000..d673698e1 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGatherNd.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleGatherNd> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->params(input(0)); + node()->indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_GatherNd) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_GatherNd_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGatherNd *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleGreater.cpp b/compiler/luci/partition/src/Nodes/CircleGreater.cpp new file mode 100644 index 000000000..7bc2a14c9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGreater.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleGreater *node) +{ + auto *cloned = loco::must_cast<luci::CircleGreater *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleGreater *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGreater.test.cpp b/compiler/luci/partition/src/Nodes/CircleGreater.test.cpp new file mode 100644 index 000000000..842370d42 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGreater.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleGreater> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Greater) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Greater_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreater *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleGreaterEqual.cpp b/compiler/luci/partition/src/Nodes/CircleGreaterEqual.cpp new file mode 100644 index 000000000..536a0aed6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGreaterEqual.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleGreaterEqual *node) +{ + auto *cloned = loco::must_cast<luci::CircleGreaterEqual *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleGreaterEqual *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleGreaterEqual.test.cpp b/compiler/luci/partition/src/Nodes/CircleGreaterEqual.test.cpp new file mode 100644 index 000000000..76dc770f8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleGreaterEqual.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleGreaterEqual> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_GreaterEqual) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_GreaterEqual_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleGreaterEqual *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleIf.cpp b/compiler/luci/partition/src/Nodes/CircleIf.cpp new file mode 100644 index 000000000..1672a136d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleIf.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleIf *node) +{ + auto *cloned = loco::must_cast<luci::CircleIf *>(cn->find_clone(node)); + + luci::CircleNode *cond = loco::must_cast<luci::CircleNode *>(node->cond()); + + cloned->cond(cn->find_clone(cond)); + + auto input_count = node->input_count(); + for (uint32_t in = 0; in < input_count; ++in) + { + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input(in)); + + cloned->input(in, cn->find_clone(input)); + } +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleIf *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleIf.test.cpp b/compiler/luci/partition/src/Nodes/CircleIf.test.cpp new file mode 100644 index 000000000..dbd25c822 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleIf.test.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeIsOsGraphletT<luci::CircleIf> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g, uint32_t n, uint32_t m) override + { + // cond() will take one input + NodeIsOsGraphletT::init(g, n - 1, m); + } +}; + +class TestNodeGraph : public TestIsOsGraph<3, 1>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOsGraph<3, 1>::init({shape, shape, shape}, {shape}); + NodeGraphlet::init(g(), 3, 1); + + node()->cond(input(0)); + node()->input(0, input(1)); + node()->input(1, input(2)); + + output(0)->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_If) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs<3, 1>(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(clone)); + + cth.clone_connect(node, clone); + + // aritiy(3) = cond + input(2) + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_If_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss<3, 1>(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIf *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleIfOut.cpp b/compiler/luci/partition/src/Nodes/CircleIfOut.cpp new file mode 100644 index 000000000..969bdd93c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleIfOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleIfOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleIfOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleIfOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleIfOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleIfOut.test.cpp new file mode 100644 index 000000000..9207654bc --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleIfOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleIfOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_IfOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_IfOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleIfOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleInstanceNorm.cpp b/compiler/luci/partition/src/Nodes/CircleInstanceNorm.cpp new file mode 100644 index 000000000..386652fb1 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleInstanceNorm.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleInstanceNorm *node) +{ + auto *cloned = loco::must_cast<luci::CircleInstanceNorm *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *gamma = loco::must_cast<luci::CircleNode *>(node->gamma()); + luci::CircleNode *beta = loco::must_cast<luci::CircleNode *>(node->beta()); + + cloned->input(cn->find_clone(input)); + cloned->gamma(cn->find_clone(gamma)); + cloned->beta(cn->find_clone(beta)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleInstanceNorm *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleInstanceNorm.test.cpp b/compiler/luci/partition/src/Nodes/CircleInstanceNorm.test.cpp new file mode 100644 index 000000000..b932223d0 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleInstanceNorm.test.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleInstanceNorm> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleInstanceNorm>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->gamma(input(1)); + node()->beta(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_InstanceNorm) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_InstanceNorm_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleInstanceNorm *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleL2Normalize.cpp b/compiler/luci/partition/src/Nodes/CircleL2Normalize.cpp new file mode 100644 index 000000000..61ddba264 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleL2Normalize.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleL2Normalize *node) +{ + auto *cloned = loco::must_cast<luci::CircleL2Normalize *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleL2Normalize *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleL2Normalize.test.cpp b/compiler/luci/partition/src/Nodes/CircleL2Normalize.test.cpp new file mode 100644 index 000000000..4fc23727a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleL2Normalize.test.cpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleL2Normalize> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleL2Normalize>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_L2Normalize) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_L2Normalize_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Normalize *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleL2Pool2D.cpp b/compiler/luci/partition/src/Nodes/CircleL2Pool2D.cpp new file mode 100644 index 000000000..24333d507 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleL2Pool2D.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleL2Pool2D *node) +{ + auto *cloned = loco::must_cast<luci::CircleL2Pool2D *>(cn->find_clone(node)); + + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value()); + + cloned->value(cn->find_clone(value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleL2Pool2D *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleL2Pool2D.test.cpp b/compiler/luci/partition/src/Nodes/CircleL2Pool2D.test.cpp new file mode 100644 index 000000000..40328488c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleL2Pool2D.test.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleL2Pool2D> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleL2Pool2D>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + _node->padding(luci::Padding::VALID); + } +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->value(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_L2Pool2D) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_L2Pool2D_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleL2Pool2D *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLeakyRelu.cpp b/compiler/luci/partition/src/Nodes/CircleLeakyRelu.cpp new file mode 100644 index 000000000..3da1ba287 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLeakyRelu.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLeakyRelu *node) +{ + auto *cloned = loco::must_cast<luci::CircleLeakyRelu *>(cn->find_clone(node)); + + luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features()); + + cloned->features(cn->find_clone(features)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLeakyRelu *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLeakyRelu.test.cpp b/compiler/luci/partition/src/Nodes/CircleLeakyRelu.test.cpp new file mode 100644 index 000000000..5a0d1dd87 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLeakyRelu.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLeakyRelu> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->features(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LeakyRelu) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_LeakyRelu_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLeakyRelu *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLess.cpp b/compiler/luci/partition/src/Nodes/CircleLess.cpp new file mode 100644 index 000000000..aab495fcc --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLess.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLess *node) +{ + auto *cloned = loco::must_cast<luci::CircleLess *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLess *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLess.test.cpp b/compiler/luci/partition/src/Nodes/CircleLess.test.cpp new file mode 100644 index 000000000..ab65e5d18 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLess.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLess> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Less) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Less_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLess *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLessEqual.cpp b/compiler/luci/partition/src/Nodes/CircleLessEqual.cpp new file mode 100644 index 000000000..ec129dbe8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLessEqual.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLessEqual *node) +{ + auto *cloned = loco::must_cast<luci::CircleLessEqual *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLessEqual *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLessEqual.test.cpp b/compiler/luci/partition/src/Nodes/CircleLessEqual.test.cpp new file mode 100644 index 000000000..0dd8986b6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLessEqual.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLessEqual> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LessEqual) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_LessEqual_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLessEqual *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLocalResponseNormalization.cpp b/compiler/luci/partition/src/Nodes/CircleLocalResponseNormalization.cpp new file mode 100644 index 000000000..6b0d1cd12 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLocalResponseNormalization.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLocalResponseNormalization *node) +{ + auto *cloned = loco::must_cast<luci::CircleLocalResponseNormalization *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLocalResponseNormalization *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLocalResponseNormalization.test.cpp b/compiler/luci/partition/src/Nodes/CircleLocalResponseNormalization.test.cpp new file mode 100644 index 000000000..e1973387d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLocalResponseNormalization.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLocalResponseNormalization> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LocalResponseNormalization) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_LocalResponseNormalization_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLocalResponseNormalization *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLog.cpp b/compiler/luci/partition/src/Nodes/CircleLog.cpp new file mode 100644 index 000000000..c43570fa2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLog.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLog *node) +{ + auto *cloned = loco::must_cast<luci::CircleLog *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLog *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLog.test.cpp b/compiler/luci/partition/src/Nodes/CircleLog.test.cpp new file mode 100644 index 000000000..8a43f6f01 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLog.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLog> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Log) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Log_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLog *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLogSoftmax.cpp b/compiler/luci/partition/src/Nodes/CircleLogSoftmax.cpp new file mode 100644 index 000000000..de582c80d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogSoftmax.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLogSoftmax *node) +{ + auto *cloned = loco::must_cast<luci::CircleLogSoftmax *>(cn->find_clone(node)); + + luci::CircleNode *logits = loco::must_cast<luci::CircleNode *>(node->logits()); + + cloned->logits(cn->find_clone(logits)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLogSoftmax *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLogSoftmax.test.cpp b/compiler/luci/partition/src/Nodes/CircleLogSoftmax.test.cpp new file mode 100644 index 000000000..1e60bf54c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogSoftmax.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLogSoftmax> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->logits(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LogSoftmax) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_LogSoftmax_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogSoftmax *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLogicalAnd.cpp b/compiler/luci/partition/src/Nodes/CircleLogicalAnd.cpp new file mode 100644 index 000000000..28e8f42e5 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogicalAnd.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLogicalAnd *node) +{ + auto *cloned = loco::must_cast<luci::CircleLogicalAnd *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLogicalAnd *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLogicalAnd.test.cpp b/compiler/luci/partition/src/Nodes/CircleLogicalAnd.test.cpp new file mode 100644 index 000000000..a1189f06f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogicalAnd.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLogicalAnd> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LogicalAnd) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_LogicalAnd_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalAnd *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLogicalNot.cpp b/compiler/luci/partition/src/Nodes/CircleLogicalNot.cpp new file mode 100644 index 000000000..e2657824c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogicalNot.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLogicalNot *node) +{ + auto *cloned = loco::must_cast<luci::CircleLogicalNot *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLogicalNot *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLogicalNot.test.cpp b/compiler/luci/partition/src/Nodes/CircleLogicalNot.test.cpp new file mode 100644 index 000000000..f6b34596e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogicalNot.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLogicalNot> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LogicalNot) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_LogicalNot_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalNot *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLogicalOr.cpp b/compiler/luci/partition/src/Nodes/CircleLogicalOr.cpp new file mode 100644 index 000000000..418dc023b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogicalOr.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLogicalOr *node) +{ + auto *cloned = loco::must_cast<luci::CircleLogicalOr *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLogicalOr *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLogicalOr.test.cpp b/compiler/luci/partition/src/Nodes/CircleLogicalOr.test.cpp new file mode 100644 index 000000000..fee3f4779 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogicalOr.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLogicalOr> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_LogicalOr) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_LogicalOr_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogicalOr *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleLogistic.cpp b/compiler/luci/partition/src/Nodes/CircleLogistic.cpp new file mode 100644 index 000000000..7d788512d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogistic.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleLogistic *node) +{ + auto *cloned = loco::must_cast<luci::CircleLogistic *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleLogistic *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleLogistic.test.cpp b/compiler/luci/partition/src/Nodes/CircleLogistic.test.cpp new file mode 100644 index 000000000..c4b3f7fe3 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleLogistic.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleLogistic> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Logistic) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Logistic_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleLogistic *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMatrixDiag.cpp b/compiler/luci/partition/src/Nodes/CircleMatrixDiag.cpp new file mode 100644 index 000000000..e92806aff --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMatrixDiag.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMatrixDiag *node) +{ + auto *cloned = loco::must_cast<luci::CircleMatrixDiag *>(cn->find_clone(node)); + + luci::CircleNode *diagonal = loco::must_cast<luci::CircleNode *>(node->diagonal()); + + cloned->diagonal(cn->find_clone(diagonal)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMatrixDiag *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMatrixDiag.test.cpp b/compiler/luci/partition/src/Nodes/CircleMatrixDiag.test.cpp new file mode 100644 index 000000000..03e3c3c3e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMatrixDiag.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMatrixDiag> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->diagonal(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_MatrixDiag) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_MatrixDiag_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixDiag *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMatrixSetDiag.cpp b/compiler/luci/partition/src/Nodes/CircleMatrixSetDiag.cpp new file mode 100644 index 000000000..29bb7fe5f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMatrixSetDiag.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMatrixSetDiag *node) +{ + auto *cloned = loco::must_cast<luci::CircleMatrixSetDiag *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *diagonal = loco::must_cast<luci::CircleNode *>(node->diagonal()); + + cloned->input(cn->find_clone(input)); + cloned->diagonal(cn->find_clone(diagonal)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMatrixSetDiag *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMatrixSetDiag.test.cpp b/compiler/luci/partition/src/Nodes/CircleMatrixSetDiag.test.cpp new file mode 100644 index 000000000..5503ea18f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMatrixSetDiag.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMatrixSetDiag> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->diagonal(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_MatrixSetDiag) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_MatrixSetDiag_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMatrixSetDiag *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMaxPool2D.cpp b/compiler/luci/partition/src/Nodes/CircleMaxPool2D.cpp new file mode 100644 index 000000000..75a665aee --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMaxPool2D.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMaxPool2D *node) +{ + auto *cloned = loco::must_cast<luci::CircleMaxPool2D *>(cn->find_clone(node)); + + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value()); + + cloned->value(cn->find_clone(value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMaxPool2D *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMaxPool2D.test.cpp b/compiler/luci/partition/src/Nodes/CircleMaxPool2D.test.cpp new file mode 100644 index 000000000..16996497a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMaxPool2D.test.cpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMaxPool2D> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleMaxPool2D>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + _node->padding(luci::Padding::VALID); + } +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->value(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_MaxPool2D) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_MaxPool2D_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaxPool2D *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMaximum.cpp b/compiler/luci/partition/src/Nodes/CircleMaximum.cpp new file mode 100644 index 000000000..2ba6055b4 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMaximum.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMaximum *node) +{ + auto *cloned = loco::must_cast<luci::CircleMaximum *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMaximum *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMaximum.test.cpp b/compiler/luci/partition/src/Nodes/CircleMaximum.test.cpp new file mode 100644 index 000000000..370174c37 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMaximum.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMaximum> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Maximum) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Maximum_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMaximum *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMean.test.cpp b/compiler/luci/partition/src/Nodes/CircleMean.test.cpp new file mode 100644 index 000000000..53435d9dc --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMean.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMean> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->reduction_indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Mean) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Mean_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMean *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMinimum.cpp b/compiler/luci/partition/src/Nodes/CircleMinimum.cpp new file mode 100644 index 000000000..cdf757583 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMinimum.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMinimum *node) +{ + auto *cloned = loco::must_cast<luci::CircleMinimum *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMinimum *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMinimum.test.cpp b/compiler/luci/partition/src/Nodes/CircleMinimum.test.cpp new file mode 100644 index 000000000..2fe6b0da6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMinimum.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMinimum> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Minimum) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Minimum_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMinimum *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleMirrorPad.cpp b/compiler/luci/partition/src/Nodes/CircleMirrorPad.cpp new file mode 100644 index 000000000..16a24abf7 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMirrorPad.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleMirrorPad *node) +{ + auto *cloned = loco::must_cast<luci::CircleMirrorPad *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings()); + + cloned->input(cn->find_clone(input)); + cloned->paddings(cn->find_clone(paddings)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleMirrorPad *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleMirrorPad.test.cpp b/compiler/luci/partition/src/Nodes/CircleMirrorPad.test.cpp new file mode 100644 index 000000000..605a126c9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleMirrorPad.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleMirrorPad> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleMirrorPad>::init(g); + + _node->mode(luci::MirrorPadMode::REFLECT); + } +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->paddings(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_MirrorPad) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_MirrorPad_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleMirrorPad *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleNeg.cpp b/compiler/luci/partition/src/Nodes/CircleNeg.cpp new file mode 100644 index 000000000..413ad4930 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNeg.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleNeg *node) +{ + auto *cloned = loco::must_cast<luci::CircleNeg *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleNeg *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleNeg.test.cpp b/compiler/luci/partition/src/Nodes/CircleNeg.test.cpp new file mode 100644 index 000000000..bd74a3665 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNeg.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleNeg> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Neg) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Neg_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNeg *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4.cpp new file mode 100644 index 000000000..63ff3f021 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4.cpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV4 *node) +{ + auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(cn->find_clone(node)); + + luci::CircleNode *boxes = loco::must_cast<luci::CircleNode *>(node->boxes()); + luci::CircleNode *scores = loco::must_cast<luci::CircleNode *>(node->scores()); + luci::CircleNode *max_output_size = loco::must_cast<luci::CircleNode *>(node->max_output_size()); + luci::CircleNode *iou_threshold = loco::must_cast<luci::CircleNode *>(node->iou_threshold()); + luci::CircleNode *score_threshold = loco::must_cast<luci::CircleNode *>(node->score_threshold()); + + cloned->boxes(cn->find_clone(boxes)); + cloned->scores(cn->find_clone(scores)); + cloned->max_output_size(cn->find_clone(max_output_size)); + cloned->iou_threshold(cn->find_clone(iou_threshold)); + cloned->score_threshold(cn->find_clone(score_threshold)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleNonMaxSuppressionV4 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4.test.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4.test.cpp new file mode 100644 index 000000000..2771aef49 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4.test.cpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV4> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<5>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<5>::init({shape, shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->boxes(input(0)); + node()->scores(input(1)); + node()->max_output_size(input(2)); + node()->iou_threshold(input(3)); + node()->score_threshold(input(4)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV4) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(5, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); + ASSERT_EQ(cth.inputs(4), clone->arg(4)); +} + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV4_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4Out.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4Out.cpp new file mode 100644 index 000000000..80e4704b9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4Out.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV4Out *node) +{ + auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleNonMaxSuppressionV4Out *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp new file mode 100644 index 000000000..5a0a8da8c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV4Out.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV4Out> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV4Out) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV4Out_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV4Out *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5.cpp new file mode 100644 index 000000000..c1f117724 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV5 *node) +{ + auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(cn->find_clone(node)); + + luci::CircleNode *boxes = loco::must_cast<luci::CircleNode *>(node->boxes()); + luci::CircleNode *scores = loco::must_cast<luci::CircleNode *>(node->scores()); + luci::CircleNode *max_output_size = loco::must_cast<luci::CircleNode *>(node->max_output_size()); + luci::CircleNode *iou_threshold = loco::must_cast<luci::CircleNode *>(node->iou_threshold()); + luci::CircleNode *score_threshold = loco::must_cast<luci::CircleNode *>(node->score_threshold()); + luci::CircleNode *soft_nms_sigma = loco::must_cast<luci::CircleNode *>(node->soft_nms_sigma()); + + cloned->boxes(cn->find_clone(boxes)); + cloned->scores(cn->find_clone(scores)); + cloned->max_output_size(cn->find_clone(max_output_size)); + cloned->iou_threshold(cn->find_clone(iou_threshold)); + cloned->score_threshold(cn->find_clone(score_threshold)); + cloned->soft_nms_sigma(cn->find_clone(soft_nms_sigma)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleNonMaxSuppressionV5 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5.test.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5.test.cpp new file mode 100644 index 000000000..1f20fbb0f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5.test.cpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV5> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<6>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<6>::init({shape, shape, shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->boxes(input(0)); + node()->scores(input(1)); + node()->max_output_size(input(2)); + node()->iou_threshold(input(3)); + node()->score_threshold(input(4)); + node()->soft_nms_sigma(input(5)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV5) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(6, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); + ASSERT_EQ(cth.inputs(4), clone->arg(4)); + ASSERT_EQ(cth.inputs(5), clone->arg(5)); +} + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV5_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5Out.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5Out.cpp new file mode 100644 index 000000000..69e3cc8e8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5Out.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleNonMaxSuppressionV5Out *node) +{ + auto *cloned = loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleNonMaxSuppressionV5Out *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp new file mode 100644 index 000000000..e001b0b0b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNonMaxSuppressionV5Out.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleNonMaxSuppressionV5Out> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV5Out) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_NonMaxSuppressionV5Out_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNonMaxSuppressionV5Out *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleNotEqual.cpp b/compiler/luci/partition/src/Nodes/CircleNotEqual.cpp new file mode 100644 index 000000000..c40c2a21a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNotEqual.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleNotEqual *node) +{ + auto *cloned = loco::must_cast<luci::CircleNotEqual *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + luci::CircleNode *y = loco::must_cast<luci::CircleNode *>(node->y()); + + cloned->x(cn->find_clone(x)); + cloned->y(cn->find_clone(y)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleNotEqual *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleNotEqual.test.cpp b/compiler/luci/partition/src/Nodes/CircleNotEqual.test.cpp new file mode 100644 index 000000000..360940ca7 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleNotEqual.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleNotEqual> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_NotEqual) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_NotEqual_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleNotEqual *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleOneHot.cpp b/compiler/luci/partition/src/Nodes/CircleOneHot.cpp new file mode 100644 index 000000000..d76f49255 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleOneHot.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleOneHot *node) +{ + auto *cloned = loco::must_cast<luci::CircleOneHot *>(cn->find_clone(node)); + + luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices()); + luci::CircleNode *depth = loco::must_cast<luci::CircleNode *>(node->depth()); + luci::CircleNode *on_value = loco::must_cast<luci::CircleNode *>(node->on_value()); + luci::CircleNode *off_value = loco::must_cast<luci::CircleNode *>(node->off_value()); + + cloned->indices(cn->find_clone(indices)); + cloned->depth(cn->find_clone(depth)); + cloned->on_value(cn->find_clone(on_value)); + cloned->off_value(cn->find_clone(off_value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleOneHot *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleOneHot.test.cpp b/compiler/luci/partition/src/Nodes/CircleOneHot.test.cpp new file mode 100644 index 000000000..3c555c290 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleOneHot.test.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleOneHot> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<4>::init({shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->indices(input(0)); + node()->depth(input(1)); + node()->on_value(input(2)); + node()->off_value(input(3)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_OneHot) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(4, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); +} + +TEST(ConnectNodeTest, connect_OneHot_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleOneHot *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleOutputDummy.cpp b/compiler/luci/partition/src/Nodes/CircleOutputDummy.cpp new file mode 100644 index 000000000..a033e80a8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleOutputDummy.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleOutputDummy *) +{ + // Nothing to do +} + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleOutputExclude.cpp b/compiler/luci/partition/src/Nodes/CircleOutputExclude.cpp new file mode 100644 index 000000000..106eb405d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleOutputExclude.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleOutputExclude *) +{ + // Nothing to do +} + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CirclePRelu.cpp b/compiler/luci/partition/src/Nodes/CirclePRelu.cpp new file mode 100644 index 000000000..b8a2341c8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePRelu.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CirclePRelu *node) +{ + auto *cloned = loco::must_cast<luci::CirclePRelu *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *alpha = loco::must_cast<luci::CircleNode *>(node->alpha()); + + cloned->input(cn->find_clone(input)); + cloned->alpha(cn->find_clone(alpha)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CirclePRelu *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CirclePRelu.test.cpp b/compiler/luci/partition/src/Nodes/CirclePRelu.test.cpp new file mode 100644 index 000000000..e5bcedcf6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePRelu.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CirclePRelu> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->alpha(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_PRelu) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_PRelu_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePRelu *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CirclePack.cpp b/compiler/luci/partition/src/Nodes/CirclePack.cpp new file mode 100644 index 000000000..326881067 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePack.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CirclePack *node) +{ + auto *cloned = loco::must_cast<luci::CirclePack *>(cn->find_clone(node)); + + uint32_t values_count = cloned->values_count(); + for (uint32_t i = 0; i < values_count; ++i) + { + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->values(i)); + + cloned->values(i, cn->find_clone(value)); + } +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CirclePack *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CirclePack.test.cpp b/compiler/luci/partition/src/Nodes/CirclePack.test.cpp new file mode 100644 index 000000000..68c513848 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePack.test.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeIsGraphletT<luci::CirclePack> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g(), 3); + + for (uint32_t i = 0; i < 3; ++i) + { + node()->values(i, input(i)); + } + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Pack) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Pack_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePack *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CirclePad.cpp b/compiler/luci/partition/src/Nodes/CirclePad.cpp new file mode 100644 index 000000000..eb2a89c85 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePad.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CirclePad *node) +{ + auto *cloned = loco::must_cast<luci::CirclePad *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings()); + + cloned->input(cn->find_clone(input)); + cloned->paddings(cn->find_clone(paddings)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CirclePad *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CirclePad.test.cpp b/compiler/luci/partition/src/Nodes/CirclePad.test.cpp new file mode 100644 index 000000000..24ea83fa3 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePad.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CirclePad> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->paddings(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Pad) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Pad_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePad *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CirclePadV2.cpp b/compiler/luci/partition/src/Nodes/CirclePadV2.cpp new file mode 100644 index 000000000..001fecbcb --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePadV2.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CirclePadV2 *node) +{ + auto *cloned = loco::must_cast<luci::CirclePadV2 *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings()); + luci::CircleNode *constant_values = loco::must_cast<luci::CircleNode *>(node->constant_values()); + + cloned->input(cn->find_clone(input)); + cloned->paddings(cn->find_clone(paddings)); + cloned->constant_values(cn->find_clone(constant_values)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CirclePadV2 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CirclePadV2.test.cpp b/compiler/luci/partition/src/Nodes/CirclePadV2.test.cpp new file mode 100644 index 000000000..aea8e0cce --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePadV2.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CirclePadV2> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->paddings(input(1)); + node()->constant_values(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_PadV2) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_PadV2_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePadV2 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CirclePow.test.cpp b/compiler/luci/partition/src/Nodes/CirclePow.test.cpp new file mode 100644 index 000000000..7a5be4d13 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CirclePow.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CirclePow> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Pow) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Pow_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CirclePow *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleQuantize.cpp b/compiler/luci/partition/src/Nodes/CircleQuantize.cpp new file mode 100644 index 000000000..340c1da42 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleQuantize.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleQuantize *node) +{ + auto *cloned = loco::must_cast<luci::CircleQuantize *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleQuantize *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleQuantize.test.cpp b/compiler/luci/partition/src/Nodes/CircleQuantize.test.cpp new file mode 100644 index 000000000..1f348b45c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleQuantize.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleQuantize> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Quantize) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Quantize_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleQuantize *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRange.cpp b/compiler/luci/partition/src/Nodes/CircleRange.cpp new file mode 100644 index 000000000..f295338d8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRange.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRange *node) +{ + auto *cloned = loco::must_cast<luci::CircleRange *>(cn->find_clone(node)); + + luci::CircleNode *start = loco::must_cast<luci::CircleNode *>(node->start()); + luci::CircleNode *limit = loco::must_cast<luci::CircleNode *>(node->limit()); + luci::CircleNode *delta = loco::must_cast<luci::CircleNode *>(node->delta()); + + cloned->start(cn->find_clone(start)); + cloned->limit(cn->find_clone(limit)); + cloned->delta(cn->find_clone(delta)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRange *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRange.test.cpp b/compiler/luci/partition/src/Nodes/CircleRange.test.cpp new file mode 100644 index 000000000..59a95f119 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRange.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleRange> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->start(input(0)); + node()->limit(input(1)); + node()->delta(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Range) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Range_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRange *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRank.cpp b/compiler/luci/partition/src/Nodes/CircleRank.cpp new file mode 100644 index 000000000..f7cce762b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRank.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRank *node) +{ + auto *cloned = loco::must_cast<luci::CircleRank *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRank *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRank.test.cpp b/compiler/luci/partition/src/Nodes/CircleRank.test.cpp new file mode 100644 index 000000000..74c520bee --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRank.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleRank> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Rank) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Rank_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRank *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReduceAny.cpp b/compiler/luci/partition/src/Nodes/CircleReduceAny.cpp new file mode 100644 index 000000000..ed762dbc6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceAny.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReduceAny *node) +{ + auto *cloned = loco::must_cast<luci::CircleReduceAny *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReduceAny *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReduceAny.test.cpp b/compiler/luci/partition/src/Nodes/CircleReduceAny.test.cpp new file mode 100644 index 000000000..792f51187 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceAny.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReduceAny> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->reduction_indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReduceAny) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ReduceAny_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceAny *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReduceMax.cpp b/compiler/luci/partition/src/Nodes/CircleReduceMax.cpp new file mode 100644 index 000000000..09586ecee --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceMax.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReduceMax *node) +{ + auto *cloned = loco::must_cast<luci::CircleReduceMax *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReduceMax *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReduceMax.test.cpp b/compiler/luci/partition/src/Nodes/CircleReduceMax.test.cpp new file mode 100644 index 000000000..8fbaf653e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceMax.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReduceMax> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->reduction_indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReduceMax) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ReduceMax_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMax *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReduceMin.cpp b/compiler/luci/partition/src/Nodes/CircleReduceMin.cpp new file mode 100644 index 000000000..105214d0b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceMin.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReduceMin *node) +{ + auto *cloned = loco::must_cast<luci::CircleReduceMin *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReduceMin *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReduceMin.test.cpp b/compiler/luci/partition/src/Nodes/CircleReduceMin.test.cpp new file mode 100644 index 000000000..c37d6248f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceMin.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReduceMin> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->reduction_indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReduceMin) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ReduceMin_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceMin *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReduceProd.cpp b/compiler/luci/partition/src/Nodes/CircleReduceProd.cpp new file mode 100644 index 000000000..2fb4e3e01 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceProd.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReduceProd *node) +{ + auto *cloned = loco::must_cast<luci::CircleReduceProd *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReduceProd *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReduceProd.test.cpp b/compiler/luci/partition/src/Nodes/CircleReduceProd.test.cpp new file mode 100644 index 000000000..cc1ac83ad --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReduceProd.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReduceProd> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->reduction_indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReduceProd) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ReduceProd_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReduceProd *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRelu.cpp b/compiler/luci/partition/src/Nodes/CircleRelu.cpp new file mode 100644 index 000000000..d3617bdbd --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRelu.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRelu *node) +{ + auto *cloned = loco::must_cast<luci::CircleRelu *>(cn->find_clone(node)); + + luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features()); + + cloned->features(cn->find_clone(features)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRelu *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRelu.test.cpp b/compiler/luci/partition/src/Nodes/CircleRelu.test.cpp new file mode 100644 index 000000000..ccaf5760b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRelu.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleRelu> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->features(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Relu) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Relu_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRelu6.cpp b/compiler/luci/partition/src/Nodes/CircleRelu6.cpp new file mode 100644 index 000000000..fb9ba6f36 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRelu6.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRelu6 *node) +{ + auto *cloned = loco::must_cast<luci::CircleRelu6 *>(cn->find_clone(node)); + + luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features()); + + cloned->features(cn->find_clone(features)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRelu6 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRelu6.test.cpp b/compiler/luci/partition/src/Nodes/CircleRelu6.test.cpp new file mode 100644 index 000000000..1341b0e06 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRelu6.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleRelu6> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->features(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Relu6) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Relu6_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRelu6 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/partition/src/Nodes/CircleReluN1To1.cpp new file mode 100644 index 000000000..476195b71 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReluN1To1.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReluN1To1 *node) +{ + auto *cloned = loco::must_cast<luci::CircleReluN1To1 *>(cn->find_clone(node)); + + luci::CircleNode *features = loco::must_cast<luci::CircleNode *>(node->features()); + + cloned->features(cn->find_clone(features)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReluN1To1 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReluN1To1.test.cpp b/compiler/luci/partition/src/Nodes/CircleReluN1To1.test.cpp new file mode 100644 index 000000000..7dc63c6ef --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReluN1To1.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReluN1To1> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->features(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReluN1To1) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_ReluN1To1_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReluN1To1 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReshape.cpp b/compiler/luci/partition/src/Nodes/CircleReshape.cpp new file mode 100644 index 000000000..e59670453 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReshape.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReshape *node) +{ + auto *cloned = loco::must_cast<luci::CircleReshape *>(cn->find_clone(node)); + + luci::CircleNode *tensor = loco::must_cast<luci::CircleNode *>(node->tensor()); + luci::CircleNode *shape = loco::must_cast<luci::CircleNode *>(node->shape()); + + cloned->tensor(cn->find_clone(tensor)); + cloned->shape(cn->find_clone(shape)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReshape *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReshape.test.cpp b/compiler/luci/partition/src/Nodes/CircleReshape.test.cpp new file mode 100644 index 000000000..73cbbdfcc --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReshape.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReshape> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->tensor(input(0)); + node()->shape(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Reshape) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Reshape_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReshape *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleResizeBilinear.cpp b/compiler/luci/partition/src/Nodes/CircleResizeBilinear.cpp new file mode 100644 index 000000000..0f504015b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleResizeBilinear.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleResizeBilinear *node) +{ + auto *cloned = loco::must_cast<luci::CircleResizeBilinear *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *size = loco::must_cast<luci::CircleNode *>(node->size()); + + cloned->input(cn->find_clone(input)); + cloned->size(cn->find_clone(size)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleResizeBilinear *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleResizeBilinear.test.cpp b/compiler/luci/partition/src/Nodes/CircleResizeBilinear.test.cpp new file mode 100644 index 000000000..c2d8b714b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleResizeBilinear.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleResizeBilinear> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->size(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ResizeBilinear) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ResizeBilinear_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeBilinear *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleResizeNearestNeighbor.cpp b/compiler/luci/partition/src/Nodes/CircleResizeNearestNeighbor.cpp new file mode 100644 index 000000000..c985b7f51 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleResizeNearestNeighbor.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleResizeNearestNeighbor *node) +{ + auto *cloned = loco::must_cast<luci::CircleResizeNearestNeighbor *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *size = loco::must_cast<luci::CircleNode *>(node->size()); + + cloned->input(cn->find_clone(input)); + cloned->size(cn->find_clone(size)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleResizeNearestNeighbor *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleResizeNearestNeighbor.test.cpp b/compiler/luci/partition/src/Nodes/CircleResizeNearestNeighbor.test.cpp new file mode 100644 index 000000000..9cc2e558e --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleResizeNearestNeighbor.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleResizeNearestNeighbor> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->size(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ResizeNearestNeighbor) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ResizeNearestNeighbor_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleResizeNearestNeighbor *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReverseSequence.cpp b/compiler/luci/partition/src/Nodes/CircleReverseSequence.cpp new file mode 100644 index 000000000..225d29ea5 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReverseSequence.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReverseSequence *node) +{ + auto *cloned = loco::must_cast<luci::CircleReverseSequence *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *seq_lengths = loco::must_cast<luci::CircleNode *>(node->seq_lengths()); + + cloned->input(cn->find_clone(input)); + cloned->seq_lengths(cn->find_clone(seq_lengths)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReverseSequence *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReverseSequence.test.cpp b/compiler/luci/partition/src/Nodes/CircleReverseSequence.test.cpp new file mode 100644 index 000000000..408fc0c9c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReverseSequence.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReverseSequence> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->seq_lengths(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReverseSequence) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ReverseSequence_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseSequence *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleReverseV2.cpp b/compiler/luci/partition/src/Nodes/CircleReverseV2.cpp new file mode 100644 index 000000000..d59a7de93 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReverseV2.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleReverseV2 *node) +{ + auto *cloned = loco::must_cast<luci::CircleReverseV2 *>(cn->find_clone(node)); + + luci::CircleNode *tensor = loco::must_cast<luci::CircleNode *>(node->tensor()); + luci::CircleNode *axis = loco::must_cast<luci::CircleNode *>(node->axis()); + + cloned->tensor(cn->find_clone(tensor)); + cloned->axis(cn->find_clone(axis)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleReverseV2 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleReverseV2.test.cpp b/compiler/luci/partition/src/Nodes/CircleReverseV2.test.cpp new file mode 100644 index 000000000..d41ad8e66 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleReverseV2.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleReverseV2> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->tensor(input(0)); + node()->axis(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ReverseV2) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_ReverseV2_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleReverseV2 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRound.cpp b/compiler/luci/partition/src/Nodes/CircleRound.cpp new file mode 100644 index 000000000..9170bcdd9 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRound.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleRound *node) +{ + auto *cloned = loco::must_cast<luci::CircleRound *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleRound *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleRound.test.cpp b/compiler/luci/partition/src/Nodes/CircleRound.test.cpp new file mode 100644 index 000000000..fad090476 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRound.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleRound> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Round) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Round_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRound *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleRsqrt.test.cpp b/compiler/luci/partition/src/Nodes/CircleRsqrt.test.cpp new file mode 100644 index 000000000..d76b96e14 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleRsqrt.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleRsqrt> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Rsqrt) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Rsqrt_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleRsqrt *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleScatterNd.cpp b/compiler/luci/partition/src/Nodes/CircleScatterNd.cpp new file mode 100644 index 000000000..62912b791 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleScatterNd.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleScatterNd *node) +{ + auto *cloned = loco::must_cast<luci::CircleScatterNd *>(cn->find_clone(node)); + + luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices()); + luci::CircleNode *updates = loco::must_cast<luci::CircleNode *>(node->updates()); + luci::CircleNode *shape = loco::must_cast<luci::CircleNode *>(node->shape()); + + cloned->indices(cn->find_clone(indices)); + cloned->updates(cn->find_clone(updates)); + cloned->shape(cn->find_clone(shape)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleScatterNd *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleScatterNd.test.cpp b/compiler/luci/partition/src/Nodes/CircleScatterNd.test.cpp new file mode 100644 index 000000000..f271f8843 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleScatterNd.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleScatterNd> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->indices(input(0)); + node()->updates(input(1)); + node()->shape(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ScatterNd) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_ScatterNd_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleScatterNd *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSegmentSum.cpp b/compiler/luci/partition/src/Nodes/CircleSegmentSum.cpp new file mode 100644 index 000000000..5fc320a16 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSegmentSum.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSegmentSum *node) +{ + auto *cloned = loco::must_cast<luci::CircleSegmentSum *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *segment_ids = loco::must_cast<luci::CircleNode *>(node->segment_ids()); + + cloned->input(cn->find_clone(input)); + cloned->segment_ids(cn->find_clone(segment_ids)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSegmentSum *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSegmentSum.test.cpp b/compiler/luci/partition/src/Nodes/CircleSegmentSum.test.cpp new file mode 100644 index 000000000..a6bcff20a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSegmentSum.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSegmentSum> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->segment_ids(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SegmentSum) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_SegmentSum_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSegmentSum *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSelect.cpp b/compiler/luci/partition/src/Nodes/CircleSelect.cpp new file mode 100644 index 000000000..dbe1dd48f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSelect.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSelect *node) +{ + auto *cloned = loco::must_cast<luci::CircleSelect *>(cn->find_clone(node)); + + luci::CircleNode *condition = loco::must_cast<luci::CircleNode *>(node->condition()); + luci::CircleNode *t = loco::must_cast<luci::CircleNode *>(node->t()); + luci::CircleNode *e = loco::must_cast<luci::CircleNode *>(node->e()); + + cloned->condition(cn->find_clone(condition)); + cloned->t(cn->find_clone(t)); + cloned->e(cn->find_clone(e)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSelect *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSelect.test.cpp b/compiler/luci/partition/src/Nodes/CircleSelect.test.cpp new file mode 100644 index 000000000..912934b8b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSelect.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSelect> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->condition(input(0)); + node()->t(input(1)); + node()->e(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Select) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Select_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelect *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSelectV2.cpp b/compiler/luci/partition/src/Nodes/CircleSelectV2.cpp new file mode 100644 index 000000000..28072c860 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSelectV2.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSelectV2 *node) +{ + auto *cloned = loco::must_cast<luci::CircleSelectV2 *>(cn->find_clone(node)); + + luci::CircleNode *condition = loco::must_cast<luci::CircleNode *>(node->condition()); + luci::CircleNode *t = loco::must_cast<luci::CircleNode *>(node->t()); + luci::CircleNode *e = loco::must_cast<luci::CircleNode *>(node->e()); + + cloned->condition(cn->find_clone(condition)); + cloned->t(cn->find_clone(t)); + cloned->e(cn->find_clone(e)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSelectV2 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSelectV2.test.cpp b/compiler/luci/partition/src/Nodes/CircleSelectV2.test.cpp new file mode 100644 index 000000000..e8d128e93 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSelectV2.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSelectV2> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->condition(input(0)); + node()->t(input(1)); + node()->e(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SelectV2) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_SelectV2_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSelectV2 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleShape.cpp b/compiler/luci/partition/src/Nodes/CircleShape.cpp new file mode 100644 index 000000000..f93cf1458 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleShape.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleShape *node) +{ + auto *cloned = loco::must_cast<luci::CircleShape *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleShape *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleShape.test.cpp b/compiler/luci/partition/src/Nodes/CircleShape.test.cpp new file mode 100644 index 000000000..9b4afdcc2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleShape.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleShape> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Shape) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Shape_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleShape *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSin.cpp b/compiler/luci/partition/src/Nodes/CircleSin.cpp new file mode 100644 index 000000000..62c776ef6 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSin.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSin *node) +{ + auto *cloned = loco::must_cast<luci::CircleSin *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSin *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSin.test.cpp b/compiler/luci/partition/src/Nodes/CircleSin.test.cpp new file mode 100644 index 000000000..fbee6f662 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSin.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSin> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Sin) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Sin_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSin *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSlice.cpp b/compiler/luci/partition/src/Nodes/CircleSlice.cpp new file mode 100644 index 000000000..7895d9ece --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSlice.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSlice *node) +{ + auto *cloned = loco::must_cast<luci::CircleSlice *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *begin = loco::must_cast<luci::CircleNode *>(node->begin()); + luci::CircleNode *size = loco::must_cast<luci::CircleNode *>(node->size()); + + cloned->input(cn->find_clone(input)); + cloned->begin(cn->find_clone(begin)); + cloned->size(cn->find_clone(size)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSlice *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSlice.test.cpp b/compiler/luci/partition/src/Nodes/CircleSlice.test.cpp new file mode 100644 index 000000000..3c666ad6c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSlice.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSlice> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->begin(input(1)); + node()->size(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Slice) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_Slice_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSlice *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSoftmax.cpp b/compiler/luci/partition/src/Nodes/CircleSoftmax.cpp new file mode 100644 index 000000000..0a93787e7 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSoftmax.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSoftmax *node) +{ + auto *cloned = loco::must_cast<luci::CircleSoftmax *>(cn->find_clone(node)); + + luci::CircleNode *logits = loco::must_cast<luci::CircleNode *>(node->logits()); + + cloned->logits(cn->find_clone(logits)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSoftmax *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSoftmax.test.cpp b/compiler/luci/partition/src/Nodes/CircleSoftmax.test.cpp new file mode 100644 index 000000000..b25629863 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSoftmax.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSoftmax> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->logits(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Softmax) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Softmax_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSoftmax *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSpaceToBatchND.cpp b/compiler/luci/partition/src/Nodes/CircleSpaceToBatchND.cpp new file mode 100644 index 000000000..b94948bee --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSpaceToBatchND.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSpaceToBatchND *node) +{ + auto *cloned = loco::must_cast<luci::CircleSpaceToBatchND *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *block_shape = loco::must_cast<luci::CircleNode *>(node->block_shape()); + luci::CircleNode *paddings = loco::must_cast<luci::CircleNode *>(node->paddings()); + + cloned->input(cn->find_clone(input)); + cloned->block_shape(cn->find_clone(block_shape)); + cloned->paddings(cn->find_clone(paddings)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSpaceToBatchND *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSpaceToBatchND.test.cpp b/compiler/luci/partition/src/Nodes/CircleSpaceToBatchND.test.cpp new file mode 100644 index 000000000..279e9b232 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSpaceToBatchND.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSpaceToBatchND> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->block_shape(input(1)); + node()->paddings(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SpaceToBatchND) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_SpaceToBatchND_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToBatchND *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSpaceToDepth.cpp b/compiler/luci/partition/src/Nodes/CircleSpaceToDepth.cpp new file mode 100644 index 000000000..bd4523ca8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSpaceToDepth.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSpaceToDepth *node) +{ + auto *cloned = loco::must_cast<luci::CircleSpaceToDepth *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSpaceToDepth *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSpaceToDepth.test.cpp b/compiler/luci/partition/src/Nodes/CircleSpaceToDepth.test.cpp new file mode 100644 index 000000000..207163d08 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSpaceToDepth.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSpaceToDepth> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SpaceToDepth) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_SpaceToDepth_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSpaceToDepth *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSparseToDense.cpp b/compiler/luci/partition/src/Nodes/CircleSparseToDense.cpp new file mode 100644 index 000000000..d1ed18818 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSparseToDense.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSparseToDense *node) +{ + auto *cloned = loco::must_cast<luci::CircleSparseToDense *>(cn->find_clone(node)); + + luci::CircleNode *indices = loco::must_cast<luci::CircleNode *>(node->indices()); + luci::CircleNode *output_shape = loco::must_cast<luci::CircleNode *>(node->output_shape()); + luci::CircleNode *values = loco::must_cast<luci::CircleNode *>(node->values()); + luci::CircleNode *default_value = loco::must_cast<luci::CircleNode *>(node->default_value()); + + cloned->indices(cn->find_clone(indices)); + cloned->output_shape(cn->find_clone(output_shape)); + cloned->values(cn->find_clone(values)); + cloned->default_value(cn->find_clone(default_value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSparseToDense *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSparseToDense.test.cpp b/compiler/luci/partition/src/Nodes/CircleSparseToDense.test.cpp new file mode 100644 index 000000000..2257186e8 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSparseToDense.test.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSparseToDense> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<4>::init({shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->indices(input(0)); + node()->output_shape(input(1)); + node()->values(input(2)); + node()->default_value(input(3)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SparseToDense) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(4, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); +} + +TEST(ConnectNodeTest, connect_SparseToDense_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSparseToDense *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSplit.cpp b/compiler/luci/partition/src/Nodes/CircleSplit.cpp new file mode 100644 index 000000000..d6d62a8ed --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplit.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSplit *node) +{ + auto *cloned = loco::must_cast<luci::CircleSplit *>(cn->find_clone(node)); + + luci::CircleNode *split_dim = loco::must_cast<luci::CircleNode *>(node->split_dim()); + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->split_dim(cn->find_clone(split_dim)); + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSplit *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSplit.test.cpp b/compiler/luci/partition/src/Nodes/CircleSplit.test.cpp new file mode 100644 index 000000000..d8d0953e0 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplit.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSplit> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->split_dim(input(0)); + node()->input(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Split) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Split_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplit *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSplitOut.cpp b/compiler/luci/partition/src/Nodes/CircleSplitOut.cpp new file mode 100644 index 000000000..4021f2042 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplitOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSplitOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleSplitOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSplitOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSplitOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleSplitOut.test.cpp new file mode 100644 index 000000000..85fe2685b --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplitOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSplitOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SplitOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_SplitOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSplitV.cpp b/compiler/luci/partition/src/Nodes/CircleSplitV.cpp new file mode 100644 index 000000000..f13205725 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplitV.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSplitV *node) +{ + auto *cloned = loco::must_cast<luci::CircleSplitV *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *size_splits = loco::must_cast<luci::CircleNode *>(node->size_splits()); + luci::CircleNode *split_dim = loco::must_cast<luci::CircleNode *>(node->split_dim()); + + cloned->input(cn->find_clone(input)); + cloned->size_splits(cn->find_clone(size_splits)); + cloned->split_dim(cn->find_clone(split_dim)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSplitV *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSplitV.test.cpp b/compiler/luci/partition/src/Nodes/CircleSplitV.test.cpp new file mode 100644 index 000000000..3ac1d6c27 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplitV.test.cpp @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSplitV> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<3>::init({shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->size_splits(input(1)); + node()->split_dim(input(2)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SplitV) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(3, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); +} + +TEST(ConnectNodeTest, connect_SplitV_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitV *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSplitVOut.cpp b/compiler/luci/partition/src/Nodes/CircleSplitVOut.cpp new file mode 100644 index 000000000..2034805cd --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplitVOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSplitVOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleSplitVOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSplitVOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSplitVOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleSplitVOut.test.cpp new file mode 100644 index 000000000..434dfb0ad --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSplitVOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSplitVOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SplitVOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_SplitVOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSplitVOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSqrt.test.cpp b/compiler/luci/partition/src/Nodes/CircleSqrt.test.cpp new file mode 100644 index 000000000..fa7f7fe2a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSqrt.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSqrt> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Sqrt) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Sqrt_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqrt *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSquare.cpp b/compiler/luci/partition/src/Nodes/CircleSquare.cpp new file mode 100644 index 000000000..1476a8694 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSquare.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSquare *node) +{ + auto *cloned = loco::must_cast<luci::CircleSquare *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSquare *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSquare.test.cpp b/compiler/luci/partition/src/Nodes/CircleSquare.test.cpp new file mode 100644 index 000000000..bb6a7c33f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSquare.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSquare> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Square) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Square_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquare *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSquaredDifference.test.cpp b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.test.cpp new file mode 100644 index 000000000..9cfe9eefb --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSquaredDifference.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSquaredDifference> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->x(input(0)); + node()->y(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_SquaredDifference) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_SquaredDifference_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSquaredDifference *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSqueeze.cpp b/compiler/luci/partition/src/Nodes/CircleSqueeze.cpp new file mode 100644 index 000000000..bc9fda296 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSqueeze.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSqueeze *node) +{ + auto *cloned = loco::must_cast<luci::CircleSqueeze *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSqueeze *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSqueeze.test.cpp b/compiler/luci/partition/src/Nodes/CircleSqueeze.test.cpp new file mode 100644 index 000000000..1f0971043 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSqueeze.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSqueeze> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Squeeze) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Squeeze_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSqueeze *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleStridedSlice.cpp b/compiler/luci/partition/src/Nodes/CircleStridedSlice.cpp new file mode 100644 index 000000000..3bdca8a8a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleStridedSlice.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleStridedSlice *node) +{ + auto *cloned = loco::must_cast<luci::CircleStridedSlice *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *begin = loco::must_cast<luci::CircleNode *>(node->begin()); + luci::CircleNode *end = loco::must_cast<luci::CircleNode *>(node->end()); + luci::CircleNode *strides = loco::must_cast<luci::CircleNode *>(node->strides()); + + cloned->input(cn->find_clone(input)); + cloned->begin(cn->find_clone(begin)); + cloned->end(cn->find_clone(end)); + cloned->strides(cn->find_clone(strides)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleStridedSlice *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleStridedSlice.test.cpp b/compiler/luci/partition/src/Nodes/CircleStridedSlice.test.cpp new file mode 100644 index 000000000..130ff9159 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleStridedSlice.test.cpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleStridedSlice> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<4>::init({shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->begin(input(1)); + node()->end(input(2)); + node()->strides(input(3)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_StridedSlice) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(4, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); +} + +TEST(ConnectNodeTest, connect_StridedSlice_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleStridedSlice *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleSum.cpp b/compiler/luci/partition/src/Nodes/CircleSum.cpp new file mode 100644 index 000000000..bef1d4676 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSum.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleSum *node) +{ + auto *cloned = loco::must_cast<luci::CircleSum *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *reduction_indices = + loco::must_cast<luci::CircleNode *>(node->reduction_indices()); + + cloned->input(cn->find_clone(input)); + cloned->reduction_indices(cn->find_clone(reduction_indices)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleSum *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleSum.test.cpp b/compiler/luci/partition/src/Nodes/CircleSum.test.cpp new file mode 100644 index 000000000..1ed65c04f --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleSum.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleSum> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->reduction_indices(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Sum) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Sum_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleSum *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTanh.cpp b/compiler/luci/partition/src/Nodes/CircleTanh.cpp new file mode 100644 index 000000000..e6c56ebf7 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTanh.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleTanh *node) +{ + auto *cloned = loco::must_cast<luci::CircleTanh *>(cn->find_clone(node)); + + luci::CircleNode *x = loco::must_cast<luci::CircleNode *>(node->x()); + + cloned->x(cn->find_clone(x)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleTanh *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleTanh.test.cpp b/compiler/luci/partition/src/Nodes/CircleTanh.test.cpp new file mode 100644 index 000000000..17cd48731 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTanh.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleTanh> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->x(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Tanh) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Tanh_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTanh *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTile.cpp b/compiler/luci/partition/src/Nodes/CircleTile.cpp new file mode 100644 index 000000000..0381b4dac --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTile.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleTile *node) +{ + auto *cloned = loco::must_cast<luci::CircleTile *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *multiples = loco::must_cast<luci::CircleNode *>(node->multiples()); + + cloned->input(cn->find_clone(input)); + cloned->multiples(cn->find_clone(multiples)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleTile *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleTile.test.cpp b/compiler/luci/partition/src/Nodes/CircleTile.test.cpp new file mode 100644 index 000000000..79d1ba16c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTile.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleTile> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->multiples(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Tile) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Tile_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTile *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTopKV2.cpp b/compiler/luci/partition/src/Nodes/CircleTopKV2.cpp new file mode 100644 index 000000000..ce8a6f5df --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTopKV2.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleTopKV2 *node) +{ + auto *cloned = loco::must_cast<luci::CircleTopKV2 *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + luci::CircleNode *k = loco::must_cast<luci::CircleNode *>(node->k()); + + cloned->input(cn->find_clone(input)); + cloned->k(cn->find_clone(k)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleTopKV2 *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleTopKV2.test.cpp b/compiler/luci/partition/src/Nodes/CircleTopKV2.test.cpp new file mode 100644 index 000000000..f08f3f315 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTopKV2.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleTopKV2> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + node()->k(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_TopKV2) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_TopKV2_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2 *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTopKV2Out.cpp b/compiler/luci/partition/src/Nodes/CircleTopKV2Out.cpp new file mode 100644 index 000000000..6ca6e3d29 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTopKV2Out.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleTopKV2Out *node) +{ + auto *cloned = loco::must_cast<luci::CircleTopKV2Out *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleTopKV2Out *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleTopKV2Out.test.cpp b/compiler/luci/partition/src/Nodes/CircleTopKV2Out.test.cpp new file mode 100644 index 000000000..a5c1c43f7 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTopKV2Out.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleTopKV2Out> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_TopKV2Out) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_TopKV2Out_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTopKV2Out *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTranspose.cpp b/compiler/luci/partition/src/Nodes/CircleTranspose.cpp new file mode 100644 index 000000000..1cbb54666 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTranspose.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleTranspose *node) +{ + auto *cloned = loco::must_cast<luci::CircleTranspose *>(cn->find_clone(node)); + + luci::CircleNode *a = loco::must_cast<luci::CircleNode *>(node->a()); + luci::CircleNode *perm = loco::must_cast<luci::CircleNode *>(node->perm()); + + cloned->a(cn->find_clone(a)); + cloned->perm(cn->find_clone(perm)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleTranspose *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleTranspose.test.cpp b/compiler/luci/partition/src/Nodes/CircleTranspose.test.cpp new file mode 100644 index 000000000..b3b16307c --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTranspose.test.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleTranspose> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIsOGraph<2>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<2>::init({shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->a(input(0)); + node()->perm(input(1)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Transpose) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(2, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); +} + +TEST(ConnectNodeTest, connect_Transpose_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTranspose *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleTransposeConv.cpp b/compiler/luci/partition/src/Nodes/CircleTransposeConv.cpp new file mode 100644 index 000000000..469cc9a1a --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTransposeConv.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleTransposeConv *node) +{ + auto *cloned = loco::must_cast<luci::CircleTransposeConv *>(cn->find_clone(node)); + + luci::CircleNode *inputSizes = loco::must_cast<luci::CircleNode *>(node->inputSizes()); + luci::CircleNode *filter = loco::must_cast<luci::CircleNode *>(node->filter()); + luci::CircleNode *outBackprop = loco::must_cast<luci::CircleNode *>(node->outBackprop()); + luci::CircleNode *bias = loco::must_cast<luci::CircleNode *>(node->bias()); + + cloned->inputSizes(cn->find_clone(inputSizes)); + cloned->filter(cn->find_clone(filter)); + cloned->outBackprop(cn->find_clone(outBackprop)); + cloned->bias(cn->find_clone(bias)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleTransposeConv *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp b/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp new file mode 100644 index 000000000..ee9fb0e78 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleTransposeConv.test.cpp @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleTransposeConv> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleTransposeConv>::init(g); + + _node->padding(luci::Padding::VALID); + } +}; + +class TestNodeGraph : public TestIsOGraph<4>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<4>::init({shape, shape, shape, shape}, shape); + NodeGraphlet::init(g()); + + node()->inputSizes(input(0)); + node()->filter(input(1)); + node()->outBackprop(input(2)); + node()->bias(input(3)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_TransposeConv) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(4, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); + ASSERT_EQ(cth.inputs(1), clone->arg(1)); + ASSERT_EQ(cth.inputs(2), clone->arg(2)); + ASSERT_EQ(cth.inputs(3), clone->arg(3)); +} + +TEST(ConnectNodeTest, connect_TransposeConv_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleTransposeConv *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp b/compiler/luci/partition/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp new file mode 100644 index 000000000..3f0374aac --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnidirectionalSequenceLSTM.cpp @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleUnidirectionalSequenceLSTM *node) +{ + auto *cloned = loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + luci::CircleNode *input_to_input_weights = + loco::must_cast<luci::CircleNode *>(node->input_to_input_weights()); + luci::CircleNode *input_to_forget_weights = + loco::must_cast<luci::CircleNode *>(node->input_to_forget_weights()); + luci::CircleNode *input_to_cell_weights = + loco::must_cast<luci::CircleNode *>(node->input_to_cell_weights()); + luci::CircleNode *input_to_output_weights = + loco::must_cast<luci::CircleNode *>(node->input_to_output_weights()); + + luci::CircleNode *recurrent_to_input_weights = + loco::must_cast<luci::CircleNode *>(node->recurrent_to_input_weights()); + luci::CircleNode *recurrent_to_forget_weights = + loco::must_cast<luci::CircleNode *>(node->recurrent_to_forget_weights()); + luci::CircleNode *recurrent_to_cell_weights = + loco::must_cast<luci::CircleNode *>(node->recurrent_to_cell_weights()); + luci::CircleNode *recurrent_to_output_weights = + loco::must_cast<luci::CircleNode *>(node->recurrent_to_output_weights()); + + luci::CircleNode *cell_to_input_weights = + loco::must_cast<luci::CircleNode *>(node->cell_to_input_weights()); + luci::CircleNode *cell_to_forget_weights = + loco::must_cast<luci::CircleNode *>(node->cell_to_forget_weights()); + luci::CircleNode *cell_to_output_weights = + loco::must_cast<luci::CircleNode *>(node->cell_to_output_weights()); + + luci::CircleNode *input_gate_bias = loco::must_cast<luci::CircleNode *>(node->input_gate_bias()); + luci::CircleNode *forget_gate_bias = + loco::must_cast<luci::CircleNode *>(node->forget_gate_bias()); + luci::CircleNode *cell_gate_bias = loco::must_cast<luci::CircleNode *>(node->cell_gate_bias()); + luci::CircleNode *output_gate_bias = + loco::must_cast<luci::CircleNode *>(node->output_gate_bias()); + + luci::CircleNode *projection_weights = + loco::must_cast<luci::CircleNode *>(node->projection_weights()); + luci::CircleNode *projection_bias = loco::must_cast<luci::CircleNode *>(node->projection_bias()); + + luci::CircleNode *activation_state = + loco::must_cast<luci::CircleNode *>(node->activation_state()); + luci::CircleNode *cell_state = loco::must_cast<luci::CircleNode *>(node->cell_state()); + + luci::CircleNode *input_layer_norm_coefficients = + loco::must_cast<luci::CircleNode *>(node->input_layer_norm_coefficients()); + luci::CircleNode *forget_layer_norm_coefficients = + loco::must_cast<luci::CircleNode *>(node->forget_layer_norm_coefficients()); + luci::CircleNode *cell_layer_norm_coefficients = + loco::must_cast<luci::CircleNode *>(node->cell_layer_norm_coefficients()); + luci::CircleNode *output_layer_norm_coefficients = + loco::must_cast<luci::CircleNode *>(node->output_layer_norm_coefficients()); + + cloned->input(cn->find_clone(input)); + + cloned->input_to_input_weights(cn->find_clone(input_to_input_weights)); + cloned->input_to_forget_weights(cn->find_clone(input_to_forget_weights)); + cloned->input_to_cell_weights(cn->find_clone(input_to_cell_weights)); + cloned->input_to_output_weights(cn->find_clone(input_to_output_weights)); + + cloned->recurrent_to_input_weights(cn->find_clone(recurrent_to_input_weights)); + cloned->recurrent_to_forget_weights(cn->find_clone(recurrent_to_forget_weights)); + cloned->recurrent_to_cell_weights(cn->find_clone(recurrent_to_cell_weights)); + cloned->recurrent_to_output_weights(cn->find_clone(recurrent_to_output_weights)); + + cloned->cell_to_input_weights(cn->find_clone(cell_to_input_weights)); + cloned->cell_to_forget_weights(cn->find_clone(cell_to_forget_weights)); + cloned->cell_to_output_weights(cn->find_clone(cell_to_output_weights)); + + cloned->input_gate_bias(cn->find_clone(input_gate_bias)); + cloned->forget_gate_bias(cn->find_clone(forget_gate_bias)); + cloned->cell_gate_bias(cn->find_clone(cell_gate_bias)); + cloned->output_gate_bias(cn->find_clone(output_gate_bias)); + + cloned->projection_weights(cn->find_clone(projection_weights)); + cloned->projection_bias(cn->find_clone(projection_bias)); + + cloned->activation_state(cn->find_clone(activation_state)); + cloned->cell_state(cn->find_clone(cell_state)); + + cloned->input_layer_norm_coefficients(cn->find_clone(input_layer_norm_coefficients)); + cloned->forget_layer_norm_coefficients(cn->find_clone(forget_layer_norm_coefficients)); + cloned->cell_layer_norm_coefficients(cn->find_clone(cell_layer_norm_coefficients)); + cloned->output_layer_norm_coefficients(cn->find_clone(output_layer_norm_coefficients)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleUnidirectionalSequenceLSTM *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp b/compiler/luci/partition/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp new file mode 100644 index 000000000..aeefef093 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnidirectionalSequenceLSTM.test.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleUnidirectionalSequenceLSTM> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g) override + { + NodeGraphletT<luci::CircleUnidirectionalSequenceLSTM>::init(g); + + _node->fusedActivationFunction(luci::FusedActFunc::RELU); + } +}; + +class TestNodeGraph : public TestIsOGraph<24>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOGraph<24>::init({shape, shape, shape, shape, shape, shape, shape, shape, + shape, shape, shape, shape, shape, shape, shape, shape, + shape, shape, shape, shape, shape, shape, shape, shape}, + shape); + NodeGraphlet::init(g()); + + node()->input(input(0)); + + node()->input_to_input_weights(input(1)); + node()->input_to_forget_weights(input(2)); + node()->input_to_cell_weights(input(3)); + node()->input_to_output_weights(input(4)); + + node()->recurrent_to_input_weights(input(5)); + node()->recurrent_to_forget_weights(input(6)); + node()->recurrent_to_cell_weights(input(7)); + node()->recurrent_to_output_weights(input(8)); + + node()->cell_to_input_weights(input(9)); + node()->cell_to_forget_weights(input(10)); + node()->cell_to_output_weights(input(11)); + + node()->input_gate_bias(input(12)); + node()->forget_gate_bias(input(13)); + node()->cell_gate_bias(input(14)); + node()->output_gate_bias(input(15)); + + node()->projection_weights(input(16)); + node()->projection_bias(input(17)); + + node()->activation_state(input(18)); + node()->cell_state(input(19)); + + node()->input_layer_norm_coefficients(input(20)); + node()->forget_layer_norm_coefficients(input(21)); + node()->cell_layer_norm_coefficients(input(22)); + node()->output_layer_norm_coefficients(input(23)); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_UnidirectionalSequenceLSTM) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(24, clone->arity()); + // 24 separate checks is too much + for (uint32_t i = 0; i < 24; ++i) + ASSERT_EQ(cth.inputs(i), clone->arg(i)); +} + +TEST(ConnectNodeTest, connect_UnidirectionalSequenceLSTM_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnidirectionalSequenceLSTM *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleUnique.cpp b/compiler/luci/partition/src/Nodes/CircleUnique.cpp new file mode 100644 index 000000000..79ca59466 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnique.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleUnique *node) +{ + auto *cloned = loco::must_cast<luci::CircleUnique *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleUnique *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleUnique.test.cpp b/compiler/luci/partition/src/Nodes/CircleUnique.test.cpp new file mode 100644 index 000000000..23f299840 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnique.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleUnique> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Unique) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Unique_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnique *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleUniqueOut.cpp b/compiler/luci/partition/src/Nodes/CircleUniqueOut.cpp new file mode 100644 index 000000000..f244dd6eb --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUniqueOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleUniqueOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleUniqueOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleUniqueOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleUniqueOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleUniqueOut.test.cpp new file mode 100644 index 000000000..887640790 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUniqueOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleUniqueOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_UniqueOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_UniqueOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUniqueOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleUnpack.cpp b/compiler/luci/partition/src/Nodes/CircleUnpack.cpp new file mode 100644 index 000000000..f83c5d810 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnpack.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleUnpack *node) +{ + auto *cloned = loco::must_cast<luci::CircleUnpack *>(cn->find_clone(node)); + + luci::CircleNode *value = loco::must_cast<luci::CircleNode *>(node->value()); + + cloned->value(cn->find_clone(value)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleUnpack *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleUnpack.test.cpp b/compiler/luci/partition/src/Nodes/CircleUnpack.test.cpp new file mode 100644 index 000000000..b164cc3bc --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnpack.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleUnpack> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->value(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Unpack) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Unpack_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpack *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleUnpackOut.cpp b/compiler/luci/partition/src/Nodes/CircleUnpackOut.cpp new file mode 100644 index 000000000..b8982fff5 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnpackOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleUnpackOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleUnpackOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleUnpackOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleUnpackOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleUnpackOut.test.cpp new file mode 100644 index 000000000..9ed440966 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleUnpackOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleUnpackOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_UnpackOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_UnpackOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleUnpackOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleWhere.cpp b/compiler/luci/partition/src/Nodes/CircleWhere.cpp new file mode 100644 index 000000000..8ef274268 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleWhere.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleWhere *node) +{ + auto *cloned = loco::must_cast<luci::CircleWhere *>(cn->find_clone(node)); + + luci::CircleNode *condition = loco::must_cast<luci::CircleNode *>(node->condition()); + + cloned->condition(cn->find_clone(condition)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleWhere *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleWhere.test.cpp b/compiler/luci/partition/src/Nodes/CircleWhere.test.cpp new file mode 100644 index 000000000..942f804c2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleWhere.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleWhere> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->condition(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_Where) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(node)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_Where_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhere *>(node)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleWhile.cpp b/compiler/luci/partition/src/Nodes/CircleWhile.cpp new file mode 100644 index 000000000..7820aca01 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleWhile.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleWhile *node) +{ + auto *cloned = loco::must_cast<luci::CircleWhile *>(cn->find_clone(node)); + + auto input_count = node->input_count(); + for (uint32_t in = 0; in < input_count; ++in) + { + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input(in)); + + cloned->input(in, cn->find_clone(input)); + } +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleWhile *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleWhile.test.cpp b/compiler/luci/partition/src/Nodes/CircleWhile.test.cpp new file mode 100644 index 000000000..bffb7869d --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleWhile.test.cpp @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeIsOsGraphletT<luci::CircleWhile> +{ +public: + NodeGraphlet() = default; + +public: + void init(loco::Graph *g, uint32_t n, uint32_t m) override { NodeIsOsGraphletT::init(g, n, m); } +}; + +class TestNodeGraph : public TestIsOsGraph<1, 1>, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIsOsGraph<1, 1>::init({shape}, {shape}); + NodeGraphlet::init(g(), 1, 1); + + node()->input(0, input(0)); + + output(0)->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_While) +{ + TestNodeGraph tng; + tng.init({1}); + + ConnectionTestHelper cth; + cth.prepare_inputs<1, 1>(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_While_NEG) +{ + TestNodeGraph tng; + tng.init({1}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss<1, 1>(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhile *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleWhileOut.cpp b/compiler/luci/partition/src/Nodes/CircleWhileOut.cpp new file mode 100644 index 000000000..1cb4419db --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleWhileOut.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleWhileOut *node) +{ + auto *cloned = loco::must_cast<luci::CircleWhileOut *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleWhileOut *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleWhileOut.test.cpp b/compiler/luci/partition/src/Nodes/CircleWhileOut.test.cpp new file mode 100644 index 000000000..901f31b01 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleWhileOut.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleWhileOut> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_WhileOut) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_WhileOut_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleWhileOut *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Nodes/CircleZerosLike.cpp b/compiler/luci/partition/src/Nodes/CircleZerosLike.cpp new file mode 100644 index 000000000..715042d86 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleZerosLike.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +namespace +{ + +void connect(luci::ConnectNode *cn, const luci::CircleZerosLike *node) +{ + auto *cloned = loco::must_cast<luci::CircleZerosLike *>(cn->find_clone(node)); + + luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(node->input()); + + cloned->input(cn->find_clone(input)); +} + +} // namespace + +namespace luci +{ + +void ConnectNode::visit(const luci::CircleZerosLike *node) { connect(this, node); } + +} // namespace luci diff --git a/compiler/luci/partition/src/Nodes/CircleZerosLike.test.cpp b/compiler/luci/partition/src/Nodes/CircleZerosLike.test.cpp new file mode 100644 index 000000000..74c873cb2 --- /dev/null +++ b/compiler/luci/partition/src/Nodes/CircleZerosLike.test.cpp @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ConnectNode.h" + +#include "ConnectNode.test.h" + +#include <luci/Service/CircleNodeClone.h> + +#include <gtest/gtest.h> + +namespace +{ + +using namespace luci::test; + +class NodeGraphlet : public NodeGraphletT<luci::CircleZerosLike> +{ +public: + NodeGraphlet() = default; +}; + +class TestNodeGraph : public TestIOGraph, public NodeGraphlet +{ +public: + TestNodeGraph() = default; + +public: + void init(const ShapeU32 shape) + { + TestIOGraph::init(shape, shape); + NodeGraphlet::init(g()); + + node()->input(input()); + + output()->from(node()); + } +}; + +} // namespace + +TEST(ConnectNodeTest, connect_ZerosLike) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(clone)); + + cth.clone_connect(node, clone); + + ASSERT_EQ(1, clone->arity()); + ASSERT_EQ(cth.inputs(0), clone->arg(0)); +} + +TEST(ConnectNodeTest, connect_ZerosLike_NEG) +{ + TestNodeGraph tng; + tng.init({2, 3}); + + ConnectionTestHelper cth; + cth.prepare_inputs_miss(&tng); + + auto *node = tng.node(); + ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(node)); + + auto *clone = luci::clone_node(node, cth.graph_clone()); + ASSERT_NO_THROW(loco::must_cast<luci::CircleZerosLike *>(clone)); + + EXPECT_ANY_THROW(cth.clone_connect(node, clone)); +} diff --git a/compiler/luci/partition/src/Partition.test.cpp b/compiler/luci/partition/src/Partition.test.cpp index 9e24c441c..42fcc5189 100644 --- a/compiler/luci/partition/src/Partition.test.cpp +++ b/compiler/luci/partition/src/Partition.test.cpp @@ -73,6 +73,7 @@ TEST(PartitionTest, simple_apply) luci::PartitionTable pt; pt.default_group = "A"; + pt.comply = luci::PartitionTable::COMPLY::OPCODE; auto pms = apply(&module, pt); diff --git a/compiler/luci/partition/src/PartitionCleanup.cpp b/compiler/luci/partition/src/PartitionCleanup.cpp index 6545295df..7bf51518a 100644 --- a/compiler/luci/partition/src/PartitionCleanup.cpp +++ b/compiler/luci/partition/src/PartitionCleanup.cpp @@ -71,9 +71,6 @@ void remove_unused_inputoutputs(luci::PGroups *pgroups, const luci::Module *sour LOGGER(l); - // TODO support multiple subgraph - assert(source->size() == 1); - INFO(l) << "--- Cleanup unused inputs/outputs"; // remove input within same pgroup diff --git a/compiler/luci/partition/src/PartitionDump.cpp b/compiler/luci/partition/src/PartitionDump.cpp new file mode 100644 index 000000000..69aec610d --- /dev/null +++ b/compiler/luci/partition/src/PartitionDump.cpp @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/PartitionDump.h" + +namespace +{ + +void dump(std::ostream &os, const luci::PartitionTable &table) +{ + os << "Backends:"; + for (auto &group : table.groups) + { + os << " " << group; + if (table.default_group == group) + os << "(default)"; + } + os << std::endl; + + os << "Assign by OPCODE: " << std::endl; + for (auto &item : table.byopcodes) + os << " " << item.first << "=" << item.second << std::endl; + + os << "Assign by OPNAME: " << std::endl; + for (auto &item : table.byopnames) + os << " " << item.first << "=" << item.second << std::endl; +} + +} // namespace + +std::ostream &operator<<(std::ostream &os, const luci::PartitionTable &table) +{ + dump(os, table); + return os; +} diff --git a/compiler/luci/partition/src/PartitionIR.cpp b/compiler/luci/partition/src/PartitionIR.cpp index ebd6b25fa..60dc74f89 100644 --- a/compiler/luci/partition/src/PartitionIR.cpp +++ b/compiler/luci/partition/src/PartitionIR.cpp @@ -67,7 +67,7 @@ std::unique_ptr<PGroups> PGroups::make_copy(void) const return std::move(d_pgroups); } -std::string PGroups::group_of(luci::CircleNode *node) const +GroupKey PGroups::group_of(luci::CircleNode *node) const { assert(node != nullptr); diff --git a/compiler/luci/partition/src/PartitionIR.h b/compiler/luci/partition/src/PartitionIR.h index 852e38cc0..c91b2f2ab 100644 --- a/compiler/luci/partition/src/PartitionIR.h +++ b/compiler/luci/partition/src/PartitionIR.h @@ -29,6 +29,8 @@ namespace luci struct PGroup; +using GroupKey = std::string; + /** * @brief Partition Node with CircleNode with group name * @note node just points to source luci::CircleNode, NOT the cloned node @@ -37,7 +39,7 @@ struct PGroup; struct PNode { const luci::CircleNode *node = nullptr; - std::string group; + GroupKey group; const PGroup *pgroup = nullptr; }; @@ -48,7 +50,7 @@ struct PNode struct PGroup { std::vector<std::unique_ptr<PNode>> pnodes; - std::string group; + GroupKey group; uint32_t id = 0; // I/O while partitioning @@ -61,13 +63,13 @@ struct PGroups std::vector<std::unique_ptr<PGroup>> pgroups; // node2group is to find group key from source node - std::map<const luci::CircleNode *, std::string> node2group; + std::map<const luci::CircleNode *, GroupKey> node2group; // id2pngroup is to find *pngroup from pngroup id std::map<uint32_t, PGroup *> id2pgroup; // default group key for reference - std::string default_group; + GroupKey default_group; public: /** @@ -78,7 +80,7 @@ public: /** * @brief return group key of node, empty string if not found */ - std::string group_of(luci::CircleNode *node) const; + GroupKey group_of(luci::CircleNode *node) const; /** * @brief return holding pgroup of node, nullptr if not found diff --git a/compiler/luci/partition/src/PartitionMerge.cpp b/compiler/luci/partition/src/PartitionMerge.cpp index 038fc2a0c..b767c77ae 100644 --- a/compiler/luci/partition/src/PartitionMerge.cpp +++ b/compiler/luci/partition/src/PartitionMerge.cpp @@ -50,9 +50,18 @@ bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups) std::string group; for (auto &input : pgroup->inputs) { + // We ignore below logic for CircleConst. + // CircleConst will be cloned if they are not found in pgroup as an input. + // Refer build_graph(), "add CircleConst for inputs" + // Reason: CircleConst can be shared as input to multiple nodes + // where each node can be placed in different groups. For this case + // we need to clone this CircleConst for each graph of the group. + if (dynamic_cast<const luci::CircleConst *>(input) != nullptr) + continue; + auto input_group = pgroups->group_of(input); // NOTE: all the nodes should be registered and return should be valid group. - // convert_to_proups() should ensure this. + // produce_pgroups() should ensure this, except CircleConst, Input, Outputs. // assert here to find if there is any problem with this. assert(not input_group.empty()); if (input_group.empty()) diff --git a/compiler/luci/partition/src/PartitionPGroups.cpp b/compiler/luci/partition/src/PartitionPGroups.cpp index 594ed6c40..e0b4e8e0d 100644 --- a/compiler/luci/partition/src/PartitionPGroups.cpp +++ b/compiler/luci/partition/src/PartitionPGroups.cpp @@ -67,8 +67,8 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, const luci::PartitionTable &partition) { assert(source != nullptr); - // TODO support multiple subgraphs - assert(source->size() == 1); + // NOTE Only main graph (subgraph index 0) will be partitioned. + // Other subgraphs will follow the owner (IF/WHILE/...) group LOGGER(l); @@ -86,13 +86,36 @@ std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source, // check if node is normal node that we are interested if (check_allocate_partition(node)) { - auto opcodename = luci::opcode_name(node); - assert(!opcodename.empty()); - auto group = partition.default_group; - auto it = partition.byopcodes.find(opcodename); - if (it != partition.byopcodes.end()) - group = it->second; + + std::string opcodename; // opcodename or opname + + switch (partition.comply) + { + case luci::PartitionTable::COMPLY::OPCODE: + { + opcodename = luci::opcode_name(node); + assert(!opcodename.empty()); + + auto it = partition.byopcodes.find(opcodename); + if (it != partition.byopcodes.end()) + group = it->second; + break; + } + case luci::PartitionTable::COMPLY::OPNAME: + { + opcodename = node->name(); + assert(!opcodename.empty()); + + auto it = partition.byopnames.find(opcodename); + if (it != partition.byopnames.end()) + group = it->second; + break; + } + + default: + throw std::runtime_error("Unsupported partition.comply"); + } INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group << std::endl; diff --git a/compiler/luci/partition/src/PartitionPGroups.test.cpp b/compiler/luci/partition/src/PartitionPGroups.test.cpp index 960f3cde9..f31641be4 100644 --- a/compiler/luci/partition/src/PartitionPGroups.test.cpp +++ b/compiler/luci/partition/src/PartitionPGroups.test.cpp @@ -73,6 +73,7 @@ TEST(PartitionPGroupsTest, simple_produce) luci::PartitionTable pt; pt.default_group = "A"; + pt.comply = luci::PartitionTable::COMPLY::OPCODE; auto pgs = produce_pgroups(&module, pt); diff --git a/compiler/luci/partition/src/PartitionPModules.cpp b/compiler/luci/partition/src/PartitionPModules.cpp index 36f4d47a4..beaaf6093 100644 --- a/compiler/luci/partition/src/PartitionPModules.cpp +++ b/compiler/luci/partition/src/PartitionPModules.cpp @@ -25,6 +25,12 @@ namespace { +// forward declare +void clone_ifnode_subgraphs(luci::PartedModule &pm, const luci::CircleIf *if_node, + const luci::CloneContext &clonectx); +void clone_whilenode_subgraphs(luci::PartedModule &pm, const luci::CircleWhile *while_node, + const luci::CloneContext &clonectx); + void add_graph_input(loco::Graph *graph, luci::CircleInput *input_node) { assert(graph != nullptr); @@ -76,9 +82,177 @@ void add_graph_output(loco::Graph *graph, luci::CircleOutput *output_node) } /** + * @brief make a clone of graph + */ +std::unique_ptr<loco::Graph> clone_graph(loco::Graph *graph_org, luci::CloneContext &clonectx) +{ + auto graph = loco::make_graph(); + auto graph_clone = graph.get(); + + graph_clone->name(graph_org->name()); + + // clone inputs + for (uint32_t n = 0; n < graph_org->inputs()->size(); ++n) + { + auto input_org = luci::input_node(graph_org, n); + assert(input_org != nullptr); + + auto *input_clone = graph_clone->nodes()->create<luci::CircleInput>(); + luci::copy_common_attributes(input_org, input_clone); + + add_graph_input(graph_clone, input_clone); + clonectx.emplace(input_org, input_clone); + } + + // clone nodes + auto nodes = graph_org->nodes(); + for (uint32_t n = 0; n < nodes->size(); ++n) + { + auto node = nodes->at(n); + + // skip for CircleInput, CircleOutput + if (dynamic_cast<luci::CircleInput *>(node) != nullptr) + continue; + if (dynamic_cast<luci::CircleOutput *>(node) != nullptr) + continue; + + auto node_org = loco::must_cast<luci::CircleNode *>(node); + assert(clonectx.find(node_org) == clonectx.end()); + + auto *node_clone = clone_node(node_org, graph_clone); + clonectx.emplace(node_org, node_clone); + } + + // connect nodes + for (uint32_t n = 0; n < nodes->size(); ++n) + { + auto node = nodes->at(n); + + // skip for CircleInput, CircleOutput + if (dynamic_cast<luci::CircleInput *>(node) != nullptr) + continue; + if (dynamic_cast<luci::CircleOutput *>(node) != nullptr) + continue; + + auto node_org = loco::must_cast<luci::CircleNode *>(node); + clone_connect(node_org, clonectx); + } + + // clone outputs + for (uint32_t n = 0; n < graph_org->outputs()->size(); ++n) + { + auto output_org = luci::output_node(graph_org, n); + assert(output_org != nullptr); + + auto *output_clone = graph_clone->nodes()->create<luci::CircleOutput>(); + luci::copy_common_attributes(output_org, output_clone); + // note: we don't add output_clone to clonectx. + // logically, output is not used as an input to any other nodes. + auto output_from = loco::must_cast<luci::CircleNode *>(output_org->from()); + auto it = clonectx.find(output_from); + assert(it != clonectx.end()); + output_clone->from(it->second); + + add_graph_output(graph_clone, output_clone); + } + + return std::move(graph); +} + +void clone_recursive_subgraphs(luci::PartedModule &pm, loco::Graph *graph, + const luci::CloneContext &clonectx) +{ + auto nodes = graph->nodes(); + for (uint32_t n = 0; n < nodes->size(); ++n) + { + { + auto if_node = dynamic_cast<luci::CircleIf *>(nodes->at(n)); + if (if_node != nullptr) + { + clone_ifnode_subgraphs(pm, if_node, clonectx); + } + } + { + auto while_node = dynamic_cast<luci::CircleWhile *>(nodes->at(n)); + if (while_node != nullptr) + { + clone_whilenode_subgraphs(pm, while_node, clonectx); + } + } + // TODO handle others + } +} + +void clone_ifnode_subgraphs(luci::PartedModule &pm, const luci::CircleIf *if_node, + const luci::CloneContext &clonectx) +{ + assert(if_node != nullptr); + + auto it = clonectx.find(if_node); + assert(it != clonectx.end()); + auto if_clone = loco::must_cast<luci::CircleIf *>(it->second); + + luci::CloneContext then_clonectx; + luci::CloneContext else_clonectx; + + auto then_graph = if_node->then_graph(); + auto else_graph = if_node->else_graph(); + + auto then_clone = clone_graph(then_graph, then_clonectx); + auto else_clone = clone_graph(else_graph, else_clonectx); + if_clone->then_graph(then_clone.get()); + if_clone->else_graph(else_clone.get()); + + pm.module->add(std::move(then_clone)); + int32_t then_index = pm.module->size() - 1; + pm.module->add(std::move(else_clone)); + int32_t else_index = pm.module->size() - 1; + if_clone->then_branch(then_index); + if_clone->else_branch(else_index); + + // do recursive copy subgraphs of CircleIf if there are any, + // inside then_graph or else_graph. + clone_recursive_subgraphs(pm, then_graph, then_clonectx); + clone_recursive_subgraphs(pm, else_graph, else_clonectx); +} + +void clone_whilenode_subgraphs(luci::PartedModule &pm, const luci::CircleWhile *while_node, + const luci::CloneContext &clonectx) +{ + assert(while_node != nullptr); + + auto it = clonectx.find(while_node); + assert(it != clonectx.end()); + auto while_clone = loco::must_cast<luci::CircleWhile *>(it->second); + + luci::CloneContext cond_clonectx; + luci::CloneContext body_clonectx; + + auto cond_graph = while_node->cond_graph(); + auto body_graph = while_node->body_graph(); + + auto cond_clone = clone_graph(cond_graph, cond_clonectx); + auto body_clone = clone_graph(body_graph, body_clonectx); + while_clone->cond_graph(cond_clone.get()); + while_clone->body_graph(body_clone.get()); + + pm.module->add(std::move(cond_clone)); + int32_t cond_index = pm.module->size() - 1; + pm.module->add(std::move(body_clone)); + int32_t body_index = pm.module->size() - 1; + while_clone->cond_branch(cond_index); + while_clone->body_branch(body_index); + + // do recursive copy subgraphs of CircleWhile if there are any, + // inside cond_graph or body_graph. + clone_recursive_subgraphs(pm, cond_graph, cond_clonectx); + clone_recursive_subgraphs(pm, body_graph, body_clonectx); +} + +/** * @brief Build loco::graph from pgroup into graph */ -void build_graph(loco::Graph *graph, const luci::PGroup *pgroup) +void build_graph(luci::PartedModule &pm, loco::Graph *graph, const luci::PGroup *pgroup) { LOGGER(l); @@ -153,6 +327,27 @@ void build_graph(loco::Graph *graph, const luci::PGroup *pgroup) << "output(" << output << ") -> " << output_clone << "(" << output_clone->name() << ")" << ": from " << it->second << "(" << it->second->name() << ")"; } + + // TODO relocate this if needed + // subgraphs for IF/WHILE/... nodes + for (auto &pnode : pgroup->pnodes) + { + { + auto if_node = dynamic_cast<const luci::CircleIf *>(pnode->node); + if (if_node != nullptr) + { + clone_ifnode_subgraphs(pm, if_node, clonectx); + } + } + { + auto while_node = dynamic_cast<const luci::CircleWhile *>(pnode->node); + if (while_node != nullptr) + { + clone_whilenode_subgraphs(pm, while_node, clonectx); + } + } + // TODO handle others + } } std::string make_name(const luci::PGroup *pgroup) @@ -184,16 +379,20 @@ luci::PartedModules produce_pmodules(const luci::PGroups *pgroups) pm.module = std::make_unique<luci::Module>(); pm.group = pgroup->group; + // the main graph for this module auto graph = loco::make_graph(); + auto graph_ptr = graph.get(); auto graph_name = make_name(pgroup.get()); graph->name(graph_name); + // Add main graph so that other subgraphs can be added inside build_graph + pm.module->add(std::move(graph)); + INFO(l) << "--- Partition Graph build----------------------"; INFO(l) << "--- name: " << graph_name; - build_graph(graph.get(), pgroup.get()); + build_graph(pm, graph_ptr, pgroup.get()); - pm.module->add(std::move(graph)); pms.pmodules.emplace_back(std::move(pm)); } diff --git a/compiler/luci/partition/src/PartitionPModules.test.cpp b/compiler/luci/partition/src/PartitionPModules.test.cpp index 99c39e839..9b949c2de 100644 --- a/compiler/luci/partition/src/PartitionPModules.test.cpp +++ b/compiler/luci/partition/src/PartitionPModules.test.cpp @@ -74,6 +74,7 @@ TEST(PartitionPModulesTest, simple_convert) luci::PartitionTable pt; pt.default_group = "A"; + pt.comply = luci::PartitionTable::COMPLY::OPCODE; auto pgs = produce_pgroups(&module, pt); auto pms = produce_pmodules(pgs.get()); diff --git a/compiler/luci/partition/src/PartitionValidate.cpp b/compiler/luci/partition/src/PartitionValidate.cpp new file mode 100644 index 000000000..5aceb98ca --- /dev/null +++ b/compiler/luci/partition/src/PartitionValidate.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/PartitionValidate.h" + +#include <luci/Service/Validate.h> + +#include <pepper/csv2vec.h> + +#include <iostream> + +namespace luci +{ + +bool validate(luci::PartitionTable &partition) +{ + if (partition.groups.size() == 0) + { + std::cerr << "There is no 'backends' information"; + return false; + } + if (partition.default_group.empty()) + { + std::cerr << "There is no 'default' backend information"; + return false; + } + if (!pepper::is_one_of<std::string>(partition.default_group, partition.groups)) + { + std::cerr << "'default' backend is not one of 'backends' item"; + return false; + } + for (auto &byopcode : partition.byopcodes) + { + if (!pepper::is_one_of<std::string>(byopcode.second, partition.groups)) + { + std::cerr << "OPCODE " << byopcode.first << " is not assigned to one of 'backends' items"; + return false; + } + } + for (auto &byopname : partition.byopnames) + { + if (!pepper::is_one_of<std::string>(byopname.second, partition.groups)) + { + std::cerr << "OPNAME " << byopname.first << " is not assigned to one of 'backends' items"; + return false; + } + } + return true; +} + +} // namespace luci |