diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2020-12-14 14:43:43 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2020-12-14 14:43:43 +0900 |
commit | 62529acabbafce7730601ed01d5709d7bc0d378a (patch) | |
tree | bf6912cfa8fac4a2997292bfcb3c82055734c97e /compiler/luci | |
parent | 6ea13af5257155ff993c205cf997b870cc627f73 (diff) | |
download | nnfw-62529acabbafce7730601ed01d5709d7bc0d378a.tar.gz nnfw-62529acabbafce7730601ed01d5709d7bc0d378a.tar.bz2 nnfw-62529acabbafce7730601ed01d5709d7bc0d378a.zip |
Imported Upstream version 1.12.0upstream/1.12.0
Diffstat (limited to 'compiler/luci')
88 files changed, 4224 insertions, 148 deletions
diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp index 860cebf6e..df7542797 100644 --- a/compiler/luci/export/src/CircleExporterImpl.cpp +++ b/compiler/luci/export/src/CircleExporterImpl.cpp @@ -16,7 +16,6 @@ #include "CircleExporterImpl.h" #include "Optimize.h" -#include "TypeBridge.h" #include "CircleTensorExporter.h" #include "CircleOperationExporter.h" #include "CircleExporterUtils.h" @@ -150,9 +149,6 @@ void CircleExporterImpl::exportGraph(loco::Graph *graph) // do graph optimization optimize(graph); - // copy shape/dtype inference data to CircleNode - copy_shape_dtype(graph); - _builder.Clear(); SerializedModelData md; @@ -223,9 +219,6 @@ void CircleExporterImpl::exportModule(Module *module) optimize(graph); - // copy shape/dtype inference data to CircleNode - copy_shape_dtype(graph); - SerializedGraphData gd; // set Subgraph name diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp index 1fdb40e51..3715513e0 100644 --- a/compiler/luci/export/src/CircleExporterUtils.cpp +++ b/compiler/luci/export/src/CircleExporterUtils.cpp @@ -87,6 +87,22 @@ circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode) } } +circle::FullyConnectedOptionsWeightsFormat +to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format) +{ + switch (format) + { + case luci::CircleFullyConnected::WeightsFormat::DEFAULT: + return circle::FullyConnectedOptionsWeightsFormat_DEFAULT; + case luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8: + return circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + case luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32: + return circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32; + default: + INTERNAL_EXN_V("trying to convert unsupported luci::WeightsFormat", oops::to_uint32(format)); + } +} + circle::DimensionType to_circle_dimensiontype(luci::DimensionType type) { switch (type) diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h index 7857213b2..95310b353 100644 --- a/compiler/luci/export/src/CircleExporterUtils.h +++ b/compiler/luci/export/src/CircleExporterUtils.h @@ -32,6 +32,8 @@ namespace luci circle::ActivationFunctionType to_circle_actfunc(luci::FusedActFunc func); circle::TensorType to_circle_tensortype(loco::DataType type); circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode); +circle::FullyConnectedOptionsWeightsFormat +to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format); circle::DimensionType to_circle_dimensiontype(luci::DimensionType type); flatbuffers::Offset<void> to_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb, const SparseIndexVector &sparse_idx_vec); diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp index c937109cd..4343cf3c9 100644 --- a/compiler/luci/export/src/CircleOperationExporter.cpp +++ b/compiler/luci/export/src/CircleOperationExporter.cpp @@ -21,7 +21,6 @@ #include <luci/IR/CircleNode.h> #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> -#include <luci/Service/CircleShapeInference.h> #include <luci/UserSettings.h> #include <luci/Log.h> @@ -930,7 +929,8 @@ void OperationExporter::visit(luci::CircleFullyConnected *node) { export_simple( node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions, - CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction())) + CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()), + to_circle_weightsformat(node->weights_format())) .Union()); } diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp index 1429d2810..9bdfa0079 100644 --- a/compiler/luci/export/src/CircleTensorExporter.cpp +++ b/compiler/luci/export/src/CircleTensorExporter.cpp @@ -111,10 +111,10 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx) CircleTensoInfo tensor_info; tensor_info.name(tensor_name); - tensor_info.dtype(to_circle_tensortype(luci::node_dtype(node))); + tensor_info.dtype(to_circle_tensortype(node->dtype())); tensor_info.shape_signature(node->shape_signature()); if (node->shape_status() == ShapeStatus::VALID) - tensor_info.shape(to_shape_description(luci::node_shape(node))); + tensor_info.shape(to_shape_description(node)); tensor_info.shape_status(node->shape_status()); tensor_info.content(dynamic_cast<luci::CircleConst *>(node)); @@ -243,6 +243,9 @@ flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder, flatbuffers::Offset<Vector<int32_t>> encodeShapeSignature(FlatBufferBuilder &builder, const ShapeSignature &shape_signature) { + if (shape_signature.rank() == 0) + return 0; + return builder.CreateVector(shape_signature.as_vector()); } diff --git a/compiler/luci/export/src/Optimize.cpp b/compiler/luci/export/src/Optimize.cpp index 6fa50b564..036a4a2f9 100644 --- a/compiler/luci/export/src/Optimize.cpp +++ b/compiler/luci/export/src/Optimize.cpp @@ -18,6 +18,7 @@ #include "ProgressReporter.h" #include <luci/Pass/ShapeInferencePass.h> +#include <luci/Pass/ShapeSignatureInferencePass.h> #include <luci/Pass/TypeInferencePass.h> #include <logo/Phase.h> @@ -34,6 +35,7 @@ void optimize(loco::Graph *g) // prepare type and shape before optimization phase.emplace_back(std::make_unique<TypeInferencePass>()); phase.emplace_back(std::make_unique<ShapeInferencePass>()); + phase.emplace_back(std::make_unique<ShapeSignatureInferencePass>()); // TODO add more optimization passes (with a knob) } diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h index 46b1ac2d5..c41f50edd 100644 --- a/compiler/luci/export/src/SerializedData.h +++ b/compiler/luci/export/src/SerializedData.h @@ -64,7 +64,7 @@ namespace luci { /** - * @breif Record the information of T/F Lite SubGraph and its mapping to loco + * @brief Record the information of T/F Lite SubGraph and its mapping to loco */ struct SubGraphContext { diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h index 8636b1d9a..8e210dd77 100644 --- a/compiler/luci/import/include/luci/Import/CircleReader.h +++ b/compiler/luci/import/include/luci/Import/CircleReader.h @@ -46,6 +46,8 @@ loco::DataType luci_datatype(circle::TensorType type); FusedActFunc luci_actfunc(const circle::ActivationFunctionType type); Padding luci_padding(const circle::Padding padding); MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode); +luci::CircleFullyConnected::WeightsFormat +luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format); std::unique_ptr<CircleQuantParam> luci_quantparam(const circle::QuantizationParametersT *quantization); diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp index 068de5239..b33c920b1 100644 --- a/compiler/luci/import/src/CircleReader.cpp +++ b/compiler/luci/import/src/CircleReader.cpp @@ -151,6 +151,22 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode) return MirrorPadMode::UNDEFINED; } +luci::CircleFullyConnected::WeightsFormat +luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format) +{ + switch (weights_format) + { + case circle::FullyConnectedOptionsWeightsFormat_DEFAULT: + return luci::CircleFullyConnected::WeightsFormat::DEFAULT; + case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: + return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8; + case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32: + return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32; + default: + throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat"); + } +} + DimensionType luci_dim_type(const circle::DimensionType dim_type) { switch (dim_type) diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp index 65a863bde..17293ad7a 100644 --- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp +++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp @@ -53,12 +53,7 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT const auto *options = op.builtin_options.AsFullyConnectedOptions(); node->fusedActivationFunction(luci_actfunc(options->fused_activation_function)); - if (options->weights_format != circle::FullyConnectedOptionsWeightsFormat_DEFAULT) - { - throw oops::UserExn( - "Unsupported weights format", - circle::EnumNameFullyConnectedOptionsWeightsFormat(options->weights_format)); - } + node->weights_format(luci_weights_format(options->weights_format)); return node; } diff --git a/compiler/luci/lang/include/luci/IR/AttrDilation.h b/compiler/luci/lang/include/luci/IR/AttrDilation.h index c2b28d77d..ed8232576 100644 --- a/compiler/luci/lang/include/luci/IR/AttrDilation.h +++ b/compiler/luci/lang/include/luci/IR/AttrDilation.h @@ -27,15 +27,17 @@ class Dilation final public: Dilation() : _w(1), _h(1) {} - int32_t w() const { return _w; } - void w(int32_t w) { _w = w; } + uint32_t w() const { return _w; } + void w(uint32_t w) { _w = w; } + void w(int32_t w); - int32_t h() const { return _h; } - void h(int32_t h) { _h = h; } + uint32_t h() const { return _h; } + void h(uint32_t h) { _h = h; } + void h(int32_t h); private: - int32_t _w; - int32_t _h; + uint32_t _w; + uint32_t _h; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/AttrFilter.h b/compiler/luci/lang/include/luci/IR/AttrFilter.h index 7909fa523..af9d7519f 100644 --- a/compiler/luci/lang/include/luci/IR/AttrFilter.h +++ b/compiler/luci/lang/include/luci/IR/AttrFilter.h @@ -27,15 +27,17 @@ class Filter final public: Filter() : _w(1), _h(1) {} - int32_t w() const { return _w; } - void w(int32_t w) { _w = w; } + uint32_t w() const { return _w; } + void w(uint32_t w) { _w = w; } + void w(int32_t w); - int32_t h() const { return _h; } - void h(int32_t h) { _h = h; } + uint32_t h() const { return _h; } + void h(uint32_t h) { _h = h; } + void h(int32_t h); private: - int32_t _w; - int32_t _h; + uint32_t _w; + uint32_t _h; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/AttrStride.h b/compiler/luci/lang/include/luci/IR/AttrStride.h index 654967d73..6be697975 100644 --- a/compiler/luci/lang/include/luci/IR/AttrStride.h +++ b/compiler/luci/lang/include/luci/IR/AttrStride.h @@ -27,15 +27,17 @@ class Stride final public: Stride() : _w(1), _h(1) {} - int32_t w() const { return _w; } - void w(int32_t w) { _w = w; } + uint32_t w() const { return _w; } + void w(uint32_t w) { _w = w; } + void w(int32_t w); - int32_t h() const { return _h; } - void h(int32_t h) { _h = h; } + uint32_t h() const { return _h; } + void h(uint32_t h) { _h = h; } + void h(int32_t h); private: - int32_t _w; - int32_t _h; + uint32_t _w; + uint32_t _h; }; } // namespace luci diff --git a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h index 970f1b521..18a260486 100644 --- a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h +++ b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h @@ -46,6 +46,8 @@ private: std::vector<int32_t> _shape_signature{}; }; +bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs); + } // namespace luci #endif // __LUCI_IR_SHAPE_SIGNATURE_H__ diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h index d78f39494..952befc87 100644 --- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h +++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h @@ -35,6 +35,16 @@ class CircleFullyConnected final public LuciNodeMixin<LuciNodeTrait::Bias> { public: + enum class WeightsFormat + { + UNDEFINED, // This is not defined by Circle. This was added to prevent programming error. + + DEFAULT, + SHUFFLED4x16INT8, + SHUFFLED16x1FLOAT32, + }; + +public: loco::Node *input(void) const { return at(0)->node(); } void input(loco::Node *node) { at(0)->node(node); } @@ -43,6 +53,13 @@ public: loco::Node *bias(void) const override { return at(2)->node(); } void bias(loco::Node *node) override { at(2)->node(node); } + +public: + WeightsFormat weights_format(void) const { return _weights_format; } + void weights_format(WeightsFormat weights_format) { _weights_format = weights_format; } + +private: + WeightsFormat _weights_format{WeightsFormat::DEFAULT}; }; } // namespace luci diff --git a/compiler/luci/lang/src/AttrDilation.cpp b/compiler/luci/lang/src/AttrDilation.cpp new file mode 100644 index 000000000..a9f479502 --- /dev/null +++ b/compiler/luci/lang/src/AttrDilation.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/AttrDilation.h" + +#include <cassert> + +namespace luci +{ + +void Dilation::w(int32_t w) +{ + assert(w >= 0); + _w = static_cast<uint32_t>(w); +} + +void Dilation::h(int32_t h) +{ + assert(h >= 0); + _h = static_cast<uint32_t>(h); +} + +} // namespace luci diff --git a/compiler/luci/lang/src/AttrDilation.test.cpp b/compiler/luci/lang/src/AttrDilation.test.cpp new file mode 100644 index 000000000..3e4658990 --- /dev/null +++ b/compiler/luci/lang/src/AttrDilation.test.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/AttrDilation.h" + +#include <gtest/gtest.h> + +TEST(CircleAttrDilationTest, set) +{ + auto d = luci::Dilation(); + + d.h(10u); + d.w(10u); + + ASSERT_EQ(d.h(), 10u); + ASSERT_EQ(d.w(), 10u); + + d.h(10); // int32_t + d.w(10); + + ASSERT_EQ(d.h(), 10u); + ASSERT_EQ(d.w(), 10u); +} diff --git a/compiler/luci/lang/src/AttrFilter.cpp b/compiler/luci/lang/src/AttrFilter.cpp new file mode 100644 index 000000000..9c571e7f5 --- /dev/null +++ b/compiler/luci/lang/src/AttrFilter.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/AttrFilter.h" + +#include <cassert> + +namespace luci +{ + +void Filter::w(int32_t w) +{ + assert(w >= 0); + _w = static_cast<uint32_t>(w); +} + +void Filter::h(int32_t h) +{ + assert(h >= 0); + _h = static_cast<uint32_t>(h); +} + +} // namespace luci diff --git a/compiler/luci/lang/src/AttrFilter.test.cpp b/compiler/luci/lang/src/AttrFilter.test.cpp new file mode 100644 index 000000000..06dbcacd5 --- /dev/null +++ b/compiler/luci/lang/src/AttrFilter.test.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/AttrFilter.h" + +#include <gtest/gtest.h> + +TEST(CircleAttrFilterTest, set) +{ + auto f = luci::Filter(); + + f.h(10u); + f.w(10u); + + ASSERT_EQ(f.h(), 10u); + ASSERT_EQ(f.w(), 10u); + + f.h(10); // int32_t + f.w(10); + + ASSERT_EQ(f.h(), 10u); + ASSERT_EQ(f.w(), 10u); +} diff --git a/compiler/luci/lang/src/AttrStride.cpp b/compiler/luci/lang/src/AttrStride.cpp new file mode 100644 index 000000000..9720d12b5 --- /dev/null +++ b/compiler/luci/lang/src/AttrStride.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/AttrStride.h" + +#include <cassert> + +namespace luci +{ + +void Stride::w(int32_t w) +{ + assert(w >= 0); + _w = static_cast<uint32_t>(w); +} + +void Stride::h(int32_t h) +{ + assert(h >= 0); + _h = static_cast<uint32_t>(h); +} + +} // namespace luci diff --git a/compiler/luci/lang/src/AttrStride.test.cpp b/compiler/luci/lang/src/AttrStride.test.cpp new file mode 100644 index 000000000..e91365bd5 --- /dev/null +++ b/compiler/luci/lang/src/AttrStride.test.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/AttrStride.h" + +#include <gtest/gtest.h> + +TEST(CircleAttrStrideTest, set) +{ + auto s = luci::Stride(); + + s.h(10u); + s.w(10u); + + ASSERT_EQ(s.h(), 10u); + ASSERT_EQ(s.w(), 10u); + + s.h(10); // int32_t + s.w(10); + + ASSERT_EQ(s.h(), 10u); + ASSERT_EQ(s.w(), 10u); +} diff --git a/compiler/luci/lang/src/CircleShapeSignature.cpp b/compiler/luci/lang/src/CircleShapeSignature.cpp new file mode 100644 index 000000000..970000203 --- /dev/null +++ b/compiler/luci/lang/src/CircleShapeSignature.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/IR/CircleShapeSignature.h" + +namespace luci +{ + +bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs) +{ + if (lhs.rank() != rhs.rank()) + return false; + + for (uint32_t i = 0; i < lhs.rank(); ++i) + if (lhs.dim(i) != rhs.dim(i)) + return false; + + return true; +} + +} // namespace luci diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index db5bdb501..906760e0a 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -19,6 +19,8 @@ #include <loco.h> +#include <luci/IR/Module.h> + #include <string> #include <vector> @@ -47,6 +49,10 @@ public: FusePreActivationBatchNorm, MakeBatchNormGammaPositive, FuseActivationFunction, + ShuffleWeightTo16x1Float32, + RemoveRedundantTranspose, + ReplaceMulAddWithDepthwiseConv, + SubstitutePackToReshape, }; enum AlgorithmParameters @@ -77,6 +83,8 @@ public: Options *options(void); public: + void optimize(luci::Module *) const; + void optimize(loco::Graph *) const; void quantize(loco::Graph *) const; diff --git a/compiler/luci/pass/include/luci/ModulePass.h b/compiler/luci/pass/include/luci/ModulePass.h new file mode 100644 index 000000000..1835f6e0c --- /dev/null +++ b/compiler/luci/pass/include/luci/ModulePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MODULE_PASS_H__ +#define __MODULE_PASS_H__ + +#include <loco.h> +#include <logo/Pass.h> + +#include <luci/IR/Module.h> + +namespace luci +{ + +class Pass : public logo::Pass +{ +public: + // Run module pass and return false if there was nothing changed + virtual bool run(luci::Module *) = 0; +}; + +} // namespace luci + +#endif // __MODULE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h b/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h new file mode 100644 index 000000000..379b44ccd --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__ +#define __LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__ + +#include <loco.h> + +#include <luci/ModulePass.h> + +namespace luci +{ + +/** + * @brief Pass to infer type of circle nodes + */ +class CircleTypeInferencePass : public luci::Pass +{ +public: + virtual const char *name(void) const { return "luci::CircleTypeInferencePass"; } + +public: + bool run(luci::Module *m); + bool run(loco::Graph *g); +}; + +} // namespace luci + +#endif //__LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h b/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h index 4404a9fc9..912ad4225 100644 --- a/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h +++ b/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h @@ -17,7 +17,7 @@ #ifndef __LUCI_FUSE_BCQ_PASS_H__ #define __LUCI_FUSE_BCQ_PASS_H__ -#include <logo/Pass.h> +#include <luci/ModulePass.h> namespace luci { @@ -26,10 +26,11 @@ namespace luci * @brief Class to fuse certain pattern of subgraph into CircleBCQFullyConnected or CircleBCQGather * */ -struct FuseBCQPass final : public logo::Pass +struct FuseBCQPass final : public luci::Pass { const char *name(void) const final { return "luci::FuseBCQPass"; } + bool run(luci::Module *m) final; bool run(loco::Graph *g) final; }; diff --git a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h new file mode 100644 index 000000000..c0ebc4e5d --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__ +#define __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__ + +#include <loco.h> + +#include <luci/ModulePass.h> + +namespace luci +{ + +/** + * @brief Pass to copy shape/dtype of loco to circle node + * + * CAUTION : This pass will be removed after refactoring is finished + */ +class MigrateLegacyShapeDtypePass : public luci::Pass +{ +public: + virtual const char *name(void) const { return "luci::MigrateLegacyShapeDtypePass"; } + +public: + bool run(luci::Module *m); + bool run(loco::Graph *graph); +}; + +} // namespace luci + +#endif //__LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h new file mode 100644 index 000000000..7e0c44b8c --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__ +#define __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to propagate quantization parameters of an operator's output to input + */ +struct PropagateQuantParamPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::PropagateQuantParamPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h new file mode 100644 index 000000000..ca20da5ac --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__ +#define __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief fuse or remove subsequent Transpose operators + */ +struct RemoveRedundantTransposePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::RemoveRedundantTransposePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h b/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h new file mode 100644 index 000000000..5dbcc8f5b --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__ +#define __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to replace channel-wise mul/add with CircleDepthwiseConv2D + */ +struct ReplaceMulAddWithDepthwiseConvPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ReplaceMulAddWithDepthwiseConvPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h b/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h index 86bb2ab42..e21ab4cce 100644 --- a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h +++ b/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h @@ -19,7 +19,7 @@ #include <loco.h> -#include <logo/Pass.h> +#include <luci/ModulePass.h> namespace luci { @@ -27,12 +27,13 @@ namespace luci /** * @brief Pass to infer shape of nodes */ -class ShapeInferencePass : public logo::Pass +class ShapeInferencePass : public luci::Pass { public: virtual const char *name(void) const { return "luci::ShapeInferencePass"; } public: + bool run(luci::Module *m); bool run(loco::Graph *graph); }; diff --git a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h b/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h new file mode 100644 index 000000000..2c6ffcf4e --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__ +#define __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__ + +#include <loco.h> + +#include <luci/ModulePass.h> + +namespace luci +{ + +/** + * @brief Pass to infer shape_signature of nodes + */ +class ShapeSignatureInferencePass : public luci::Pass +{ +public: + virtual const char *name(void) const { return "luci::ShapeSignatureInferencePass"; } + +public: + bool run(luci::Module *m); + bool run(loco::Graph *graph); +}; + +} // namespace luci + +#endif //__LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h b/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h new file mode 100644 index 000000000..3d84f5133 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__ +#define __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to convert weight format of FullyConnected to SHUFFLED16x1FLOAT32 + */ +struct ShuffleWeightTo16x1Float32Pass final : public logo::Pass +{ + const char *name(void) const final { return "luci::ShuffleWeightTo16x1Float32Pass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h new file mode 100644 index 000000000..36d13f19f --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__ +#define __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__ + +#include <logo/Pass.h> + +namespace luci +{ + +/** + * @brief Class to Substitute Pack with 1 input to single reshape node. + */ +struct SubstitutePackToReshapePass final : public logo::Pass +{ + const char *name(void) const final { return "luci::SubstitutePackToReshapePass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__ diff --git a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h b/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h index c607ac63f..9d964bdd6 100644 --- a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h +++ b/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h @@ -20,7 +20,7 @@ #include <loco.h> -#include <logo/Pass.h> +#include <luci/ModulePass.h> namespace luci { @@ -28,12 +28,13 @@ namespace luci /** * @brief Pass to infer type of nodes */ -class TypeInferencePass : public logo::Pass +class TypeInferencePass : public luci::Pass { public: virtual const char *name(void) const { return "luci::TypeInferencePass"; } public: + bool run(luci::Module *m); bool run(loco::Graph *graph); }; diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index 34f647301..cc9fe481c 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -24,6 +24,9 @@ #include "luci/Pass/FuseInstanceNormPass.h" #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/MakeBatchNormGammaPositivePass.h" +#include "luci/Pass/PropagateQuantParamPass.h" +#include "luci/Pass/RemoveRedundantTransposePass.h" +#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h" #include "luci/Pass/ResolveCustomOpAddPass.h" #include "luci/Pass/ResolveCustomOpBatchMatMulPass.h" #include "luci/Pass/ResolveCustomOpMatMulPass.h" @@ -31,14 +34,21 @@ #include "luci/Pass/QuantizeWithMinMaxPass.h" #include "luci/Pass/QuantizeDequantizeWeightsPass.h" #include "luci/Pass/SparsifyTensorPass.h" +#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" +#include "luci/Pass/SubstitutePackToReshapePass.h" // TODO add more passes #include "luci/Pass/ShapeInferencePass.h" +#include "luci/Pass/ShapeSignatureInferencePass.h" #include "luci/Pass/TypeInferencePass.h" +// Following passes will be removed after refactoring is finished +#include "luci/Pass/MigrateLegacyShapeDtypePass.h" + // logo passes #include <logo/RemoveDeadNodeWithQueryPass.h> +#include "ModulePhase.h" #include "ProgressReporter.h" #include "CircleOptimizerUtils.h" @@ -124,11 +134,44 @@ CircleOptimizer::Options *CircleOptimizer::options(void) return _options.get(); } +void CircleOptimizer::optimize(luci::Module *m) const +{ + luci::Phase phase; + + // Following passes will be deprecated after refactoring is finished. + phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>()); + + // Following passes are needed everytime when other passes create new node or modify some nodes. + phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>()); + phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); + + if (_options->query(Options::Algorithm::FuseBCQ)) + { + phase.emplace_back(std::make_unique<FuseBCQPass>()); + } + + ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart); + PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m}; + phase_runner.attach(&prog); + phase_runner.run(phase); +} + void CircleOptimizer::optimize(loco::Graph *g) const { logo::Phase phase; /* TRANSFORM DECLARATION BEGIN */ + phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); + + // Following passes will be deprecated after refactoring is finished. + phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>()); + + // Following passes are needed everytime when other passes create new node or modify some nodes. + phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); + phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>()); + if (_options->query(Options::Algorithm::ResolveCustomOpAdd)) { phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>()); @@ -145,10 +188,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<FuseInstanceNormPass>()); } - if (_options->query(Options::Algorithm::FuseBCQ)) - { - phase.emplace_back(std::make_unique<FuseBCQPass>()); - } if (_options->query(Options::Algorithm::FuseBatchNormWithTConv)) { phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>()); @@ -173,15 +212,27 @@ void CircleOptimizer::optimize(loco::Graph *g) const { phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>()); } + if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32)) + { + phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>()); + } + if (_options->query(Options::Algorithm::RemoveRedundantTranspose)) + { + phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>()); + } + if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv)) + { + phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>()); + } + if (_options->query(Options::Algorithm::SubstitutePackToReshape)) + { + phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>()); + } - // Shape inference is needed for added nodes doing above transformations - phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); - phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); - phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); /* TRANSFORM DECLARATION END */ - ProgressReporter prog(g, logo::PhaseStrategy::Saturate); - logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + ProgressReporter prog(g, logo::PhaseStrategy::Restart); + logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g}; phase_runner.attach(&prog); phase_runner.run(phase); } @@ -258,6 +309,20 @@ void CircleOptimizer::quantize(loco::Graph *g) const luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity)); quantizer.run(g); + + // Post-quantization optimizations + logo::Phase phase; + + phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>()); + + phase.emplace_back(std::make_unique<luci::ShapeInferencePass>()); + phase.emplace_back(std::make_unique<luci::TypeInferencePass>()); + phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>()); + + ProgressReporter prog(g, logo::PhaseStrategy::Saturate); + logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g}; + phase_runner.attach(&prog); + phase_runner.run(phase); } // Requantize diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.cpp new file mode 100644 index 000000000..67bd253e0 --- /dev/null +++ b/compiler/luci/pass/src/CircleTypeInferencePass.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/CircleTypeInferencePass.h" + +#include <luci/Service/CircleTypeInference.h> + +#include <loco.h> + +namespace luci +{ + +bool CircleTypeInferencePass::run(luci::Module *m) +{ + bool changed = false; + + for (size_t g = 0; g < m->size(); ++g) + { + if (run(m->graph(g))) + changed = true; + } + + return changed; +} + +bool CircleTypeInferencePass::run(loco::Graph *g) +{ + luci::tinf::Rule type_infer_rule; + bool changed = false; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + loco::DataType dtype; + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + + if (type_infer_rule.infer(circle_node, dtype) && circle_node->dtype() != dtype) + { + circle_node->dtype(dtype); + changed = true; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp index ebf28779b..c0583d848 100644 --- a/compiler/luci/pass/src/FuseBCQPass.cpp +++ b/compiler/luci/pass/src/FuseBCQPass.cpp @@ -25,6 +25,85 @@ namespace { +bool is_fusable_const(luci::CircleConst *before, luci::CircleConst *after, bool do_w_x) +{ + if (after->dtype() != loco::DataType::FLOAT32) + return false; + + if (after->rank() != 2) + return false; + + if (after->size<loco::DataType::FLOAT32>() != before->size<loco::DataType::FLOAT32>()) + return false; + + auto after_dim0 = after->dim(0).value(); + auto after_dim1 = after->dim(1).value(); + + if (before->rank() == 2) + { + if (do_w_x) + { + // Check for [dim0, dim1] --> [dim0, dim1] + if (!(after->dim(0) == before->dim(0) && after->dim(1) == before->dim(1))) + return false; + + for (uint32_t i = 0; i < after->size<loco::DataType::FLOAT32>(); ++i) + if (after->at<loco::DataType::FLOAT32>(i) != before->at<loco::DataType::FLOAT32>(i)) + return false; + } + else + { + // Check for [dim0, dim1] --> [dim1, dim0] + if (!(after->dim(0) == before->dim(1) && after->dim(1) == before->dim(0))) + return false; + + for (uint32_t i = 0; i < after_dim0; ++i) + for (uint32_t j = 0; j < after_dim1; ++j) + if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) != + before->at<loco::DataType::FLOAT32>(j * after_dim0 + i)) + return false; + } + + return true; + } + else if (before->rank() == 3) + { + if (do_w_x) + { + // This case is not found yet. + return false; + } + else + { + // When Einsum op is converted to FullyConnected, original rank can be 3. + auto before_dim0 = before->dim(0).value(); + auto before_dim1 = before->dim(1).value(); + auto before_dim2 = before->dim(2).value(); + + // Check if [dim0, dim1, dim2] --> [dim2, dim0 * dim1] or + // [dim0, dim1, dim2] --> [dim1 * dim2, dim0] + if ((after_dim0 == before_dim1 * before_dim2 && after_dim1 == before_dim0) || + (after_dim0 == before_dim2 && after_dim1 == before_dim0 * before_dim1)) + { + for (uint32_t i = 0; i < after_dim0; ++i) + for (uint32_t j = 0; j < after_dim1; ++j) + if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) != + before->at<loco::DataType::FLOAT32>(j * after_dim0 + i)) + return false; + } + } + + return true; + } + + return false; +} + +} // namespace + +namespace +{ + // V means the version of BCQ. template <int32_t V> class BCQFuser; @@ -38,11 +117,9 @@ public: } public: - bool fuseBCQ(loco::Graph *g) + void register_bcq_info(loco::Graph *g) { - - const auto output_nodes = loco::output_nodes(g); - for (auto node : output_nodes) + for (auto node : loco::output_nodes(g)) { auto output_node = loco::must_cast<luci::CircleOutput *>(node); @@ -61,28 +138,29 @@ public: add_BCQ_info_node(prefix, metadata_type, circle_node); } } + } + bool fuseBCQ(loco::Graph *g) + { if (!is_bcqinfo_valid()) return false; - for (auto f : _fusable_op) + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) { - auto prefix = f.first; - luci::CircleNode *node = f.second; - - if (!is_valid_prefix(prefix)) - continue; - // Fuse Gather to BCQGather if (auto gather = dynamic_cast<luci::CircleGather *>(node)) { if (auto params = dynamic_cast<luci::CircleConst *>(gather->params())) { + auto prefix = get_prefix_of_const(params); + if (prefix == -1 || !is_valid_prefix(prefix)) + continue; + auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>(); bcq_gather->op_version(1); - bcq_gather->input_scales(_alpha[prefix]); - bcq_gather->input_binary(_packed_binary_code[prefix]); + bcq_gather->input_scales(alpha(g, prefix)); + bcq_gather->input_binary(packed_binary_code(g, prefix)); bcq_gather->indices(gather->indices()); bcq_gather->input_clusters(packed_clusters(g, prefix)); @@ -122,29 +200,20 @@ public: } } - // Einsum is unpacked to FullyConnected, Pack and Reshape - if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) - { - node = dynamic_cast<luci::CircleNode *>(reshape->tensor()); - } - if (auto pack = dynamic_cast<luci::CirclePack *>(node)) - { - if (pack->values_count() == 1 && pack->rank() == 3) - { - node = dynamic_cast<luci::CircleNode *>(pack->values(0)); - } - } - // Fuse FullyConnected to BCQFullyConnected if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node)) { if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights())) { + auto prefix = get_prefix_of_const(weights); + if (prefix == -1 || !is_valid_prefix(prefix)) + continue; + auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); bcq_fc->op_version(1); - bcq_fc->weights_scales(_alpha[prefix]); - bcq_fc->weights_binary(_packed_binary_code[prefix]); + bcq_fc->weights_scales(alpha(g, prefix)); + bcq_fc->weights_binary(packed_binary_code(g, prefix)); bcq_fc->bias(fully_connected->bias()); bcq_fc->weights_clusters(packed_clusters(g, prefix)); bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction()); @@ -179,43 +248,69 @@ public: } // If x_w formation, we should insert Transpose in front and back of BCQFullyConnected - if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0)) - { - bcq_fc->weights_hidden_size(weights->dim(0).value()); - bcq_fc->input(bcq_input); - loco::replace(fully_connected).with(bcq_fc); - } - else - { - bcq_fc->weights_hidden_size(weights->dim(1).value()); + bcq_fc->weights_hidden_size(weights->dim(1).value()); - auto perm = g->nodes()->create<luci::CircleConst>(); - perm->dtype(loco::DataType::S32); - perm->size<loco::DataType::S32>(2); - perm->rank(1); - perm->dim(0) = 2; - perm->at<loco::DataType::S32>(0) = 1; - perm->at<loco::DataType::S32>(1) = 0; - perm->shape_status(luci::ShapeStatus::VALID); + auto perm = g->nodes()->create<luci::CircleConst>(); + perm->dtype(loco::DataType::S32); + perm->size<loco::DataType::S32>(2); + perm->rank(1); + perm->dim(0) = 2; + perm->at<loco::DataType::S32>(0) = 1; + perm->at<loco::DataType::S32>(1) = 0; + perm->shape_status(luci::ShapeStatus::VALID); - auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); - input_transpose->a(bcq_input); - input_transpose->perm(perm); + auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); + input_transpose->a(bcq_input); + input_transpose->perm(perm); - bcq_fc->input(input_transpose); + bcq_fc->input(input_transpose); - auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); - output_transpose->a(bcq_fc); - output_transpose->perm(perm); + auto output_transpose = g->nodes()->create<luci::CircleTranspose>(); + output_transpose->a(bcq_fc); + output_transpose->perm(perm); - loco::replace(fully_connected).with(output_transpose); - } + loco::replace(fully_connected).with(output_transpose); return true; } - else + else if (auto weights_as_input = + dynamic_cast<luci::CircleConst *>(fully_connected->input())) { - // TODO Is there any case that input() is constant, instead of weights()? + auto prefix = get_prefix_of_const(weights_as_input); + if (prefix == -1 || !is_valid_prefix(prefix)) + continue; + + assert(_do_w_x[prefix]->at<loco::DataType::BOOL>(0) == true); + + auto perm = g->nodes()->create<luci::CircleConst>(); + perm->dtype(loco::DataType::S32); + perm->size<loco::DataType::S32>(2); + perm->rank(1); + perm->dim(0) = 2; + perm->at<loco::DataType::S32>(0) = 1; + perm->at<loco::DataType::S32>(1) = 0; + perm->shape_status(luci::ShapeStatus::VALID); + + auto input_transpose = g->nodes()->create<luci::CircleTranspose>(); + input_transpose->a(fully_connected->weights()); + input_transpose->perm(perm); + + auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>(); + + assert(dynamic_cast<luci::CircleOutputExclude *>(fully_connected->bias()) != nullptr); + + bcq_fc->op_version(1); + bcq_fc->weights_scales(alpha(g, prefix)); + bcq_fc->weights_binary(packed_binary_code(g, prefix)); + bcq_fc->bias(fully_connected->bias()); + bcq_fc->weights_clusters(packed_clusters(g, prefix)); + bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction()); + + bcq_fc->weights_hidden_size(weights_as_input->dim(1).value()); + bcq_fc->input(input_transpose); + loco::replace(fully_connected).with(bcq_fc); + + return true; } } } @@ -268,6 +363,19 @@ private: _dequant_weight[prefix] = const_node; } + int32_t get_prefix_of_const(luci::CircleConst *w_after) + { + for (auto n : _fusable_op) + { + auto prefix = n.first; + auto w_before = loco::must_cast<luci::CircleConst *>(n.second); + if (is_fusable_const(w_before, w_after, _do_w_x[prefix]->at<loco::DataType::BOOL>(0))) + return prefix; + } + + return -1; + } + bool is_bcqinfo_valid() { LOGGER(l); @@ -332,6 +440,16 @@ private: } } + for (auto n : _fusable_op) + { + // fusable_op should be FLOAT32 type + if (n.second->dtype() != loco::DataType::FLOAT32) + { + WARN(l) << "FuseBCQPass : fusable_op has wrong type" << std::endl; + return false; + } + } + // As dequant_weight is not used for fusing, skip validation. return true; @@ -377,12 +495,50 @@ private: return false; } + if (_fusable_op.find(prefix) == _fusable_op.end()) + { + WARN(l) << "fusable_op is not found" << std::endl; + return false; + } + // As dequant_weight is not used for fusing, skip validation. return true; } private: + luci::CircleConst *alpha(loco::Graph *graph, int32_t prefix) + { + auto new_alpha = graph->nodes()->create<luci::CircleConst>(); + + new_alpha->dtype(loco::DataType::FLOAT32); + new_alpha->size<loco::DataType::FLOAT32>(_alpha[prefix]->size<loco::DataType::FLOAT32>()); + new_alpha->rank(1); + new_alpha->dim(0) = _alpha[prefix]->dim(0); + for (uint32_t i = 0; i < _alpha[prefix]->size<loco::DataType::FLOAT32>(); ++i) + new_alpha->at<loco::DataType::FLOAT32>(i) = _alpha[prefix]->at<loco::DataType::FLOAT32>(i); + new_alpha->shape_status(luci::ShapeStatus::VALID); + + return new_alpha; + } + + luci::CircleConst *packed_binary_code(loco::Graph *graph, int32_t prefix) + { + auto new_beta = graph->nodes()->create<luci::CircleConst>(); + + new_beta->dtype(loco::DataType::S32); + new_beta->size<loco::DataType::S32>(_packed_binary_code[prefix]->size<loco::DataType::S32>()); + new_beta->rank(2); + new_beta->dim(0) = _packed_binary_code[prefix]->dim(0); + new_beta->dim(1) = _packed_binary_code[prefix]->dim(1); + for (uint32_t i = 0; i < _packed_binary_code[prefix]->size<loco::DataType::S32>(); ++i) + new_beta->at<loco::DataType::S32>(i) = + _packed_binary_code[prefix]->at<loco::DataType::S32>(i); + new_beta->shape_status(luci::ShapeStatus::VALID); + + return new_beta; + } + luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix) { auto qbits_of_clusters = _qbits_of_clusters[prefix]; @@ -428,15 +584,17 @@ private: namespace luci { -bool FuseBCQPass::run(loco::Graph *g) +bool FuseBCQPass::run(luci::Module *m) { bool changed = false; const int32_t start_magicnum = -2e9 + 27; const int32_t end_magicnum = 2e9 - 27; + loco::Graph *main_graph = m->graph(0); + luci::CircleConst *metadata_node = nullptr; - for (auto node : loco::output_nodes(g)) + for (auto node : loco::output_nodes(main_graph)) { auto output_node = loco::must_cast<luci::CircleOutput *>(node); @@ -474,8 +632,11 @@ bool FuseBCQPass::run(loco::Graph *g) const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3); BCQFuser<1> fuser{original_output_cnt, bundle_cnt}; - if (fuser.fuseBCQ(g)) - changed = true; + fuser.register_bcq_info(main_graph); + + for (size_t g = 0; g < m->size(); ++g) + if (fuser.fuseBCQ(m->graph(g))) + changed = true; } else { @@ -486,12 +647,12 @@ bool FuseBCQPass::run(loco::Graph *g) // Remove all of BCQ information nodes iff there is no change if (changed == false) { - for (auto node : loco::output_nodes(g)) + for (auto node : loco::output_nodes(main_graph)) { auto output_node = loco::must_cast<luci::CircleOutput *>(node); if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt) { - auto noOp = g->nodes()->create<luci::CircleOutputExclude>(); + auto noOp = main_graph->nodes()->create<luci::CircleOutputExclude>(); noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting output_node->from(noOp); changed = true; @@ -503,4 +664,10 @@ bool FuseBCQPass::run(loco::Graph *g) return changed; } +bool FuseBCQPass::run(loco::Graph *) +{ + // Do nothing for graph + return false; +} + } // namespace luci diff --git a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp new file mode 100644 index 000000000..beb962a05 --- /dev/null +++ b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/MigrateLegacyShapeDtypePass.h" + +#include <loco/Service/ShapeInference.h> +#include <loco/Service/TypeInference.h> + +#include <luci/IR/CircleNodes.h> + +#include <loco.h> + +namespace +{ + +bool has_same_shape(luci::CircleNode *node, loco::TensorShape shape) +{ + if (node->rank() != shape.rank()) + return false; + + for (uint32_t i = 0; i < shape.rank(); ++i) + if (!(node->dim(i) == shape.dim(i))) + return false; + + return true; +} + +} // namespace + +namespace luci +{ + +bool MigrateLegacyShapeDtypePass::run(luci::Module *m) +{ + bool changed = false; + + for (size_t g = 0; g < m->size(); ++g) + { + if (run(m->graph(g))) + changed = true; + } + + return changed; +} + +bool MigrateLegacyShapeDtypePass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::all_nodes(g)) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (loco::shape_known(node)) + { + auto loco_shape = loco::shape_get(node).as<loco::TensorShape>(); + + assert(circle_node->shape_signature().rank() == 0 || + circle_node->shape_signature().rank() == loco_shape.rank()); + + // When shape of loco is copied to circle node, ShapeSignature should be applied. + loco::TensorShape new_shape; + new_shape.rank(loco_shape.rank()); + for (uint32_t i = 0; i < loco_shape.rank(); ++i) + { + if (circle_node->shape_signature().rank() > 0 && + circle_node->shape_signature().dim(i) == -1) + new_shape.dim(i) = 1; + else + new_shape.dim(i) = loco_shape.dim(i); + } + + if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED || + !has_same_shape(circle_node, new_shape)) + { + circle_node->rank(new_shape.rank()); + for (uint32_t i = 0; i < new_shape.rank(); ++i) + circle_node->dim(i) = new_shape.dim(i); + + if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED) + circle_node->shape_status(luci::ShapeStatus::VALID); + + changed = true; + } + } + + if (loco::dtype_known(node)) + { + if (loco::dtype_get(node) != circle_node->dtype()) + { + circle_node->dtype(loco::dtype_get(node)); + changed = true; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ModulePhase.cpp b/compiler/luci/pass/src/ModulePhase.cpp new file mode 100644 index 000000000..46819a0f7 --- /dev/null +++ b/compiler/luci/pass/src/ModulePhase.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ModulePhase.h" + +namespace luci +{ + +void PhaseRunner<logo::PhaseStrategy::Saturate>::run(const Phase &phase) const +{ + notifyPhaseBegin(); + + for (bool changed = true; changed;) + { + changed = false; + + for (auto &pass : phase) + { + notifyPassBegin(pass.get()); + + bool pass_changed = pass->run(_module); + changed = changed || pass_changed; + + notifyPassEnd(pass.get(), pass_changed); + } + } + + notifyPhaseEnd(); +} + +void PhaseRunner<logo::PhaseStrategy::Restart>::run(const Phase &phase) const +{ + notifyPhaseBegin(); + + for (bool changed = true; changed;) + { + changed = false; + + for (auto &pass : phase) + { + notifyPassBegin(pass.get()); + + bool pass_changed = pass->run(_module); + changed = changed || pass_changed; + + notifyPassEnd(pass.get(), pass_changed); + + if (changed) + { + break; + } + } + } + + notifyPhaseEnd(); +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ModulePhase.h b/compiler/luci/pass/src/ModulePhase.h new file mode 100644 index 000000000..05966cc29 --- /dev/null +++ b/compiler/luci/pass/src/ModulePhase.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __MODULE_PHASE_H__ +#define __MODULE_PHASE_H__ + +#include <luci/ModulePass.h> + +#include <logo/Phase.h> + +#include <vector> + +namespace luci +{ + +using Phase = std::vector<std::unique_ptr<Pass>>; + +template <logo::PhaseStrategy S> class PhaseRunner; + +template <> +class PhaseRunner<logo::PhaseStrategy::Saturate> final : public logo::PhaseRunnerMixinObservable +{ +public: + PhaseRunner(luci::Module *module) : _module{module} + { + // DO NOTHING + } + +public: + void run(const Phase &) const; + +private: + luci::Module *_module; +}; + +template <> +class PhaseRunner<logo::PhaseStrategy::Restart> final : public logo::PhaseRunnerMixinObservable +{ +public: + PhaseRunner(luci::Module *module) : _module{module} + { + // DO NOTHING + } + +public: + void run(const Phase &) const; + +private: + luci::Module *_module; +}; + +} // namespace luci + +#endif // __MODULE_PHASE_H__ diff --git a/compiler/luci/pass/src/ProgressReporter.cpp b/compiler/luci/pass/src/ProgressReporter.cpp index dcf47aba6..515739dc7 100644 --- a/compiler/luci/pass/src/ProgressReporter.cpp +++ b/compiler/luci/pass/src/ProgressReporter.cpp @@ -81,4 +81,46 @@ void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassE INFO(prime) << luci::fmt(graph()); } +void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) +{ + LOGGER(prime); + + INFO(prime) << "=============================================================="; + INFO(prime) << "ModulePhaseRunner<" << to_str(strategy()) << ">"; + INFO(prime) << "Initial graphs"; + for (size_t g = 0; g < module()->size(); ++g) + { + INFO(prime) << "graphs #" << g; + INFO(prime) << luci::fmt(module()->graph(g)); + } +} + +void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) +{ + LOGGER(prime); + + INFO(prime) << "ModulePhaseRunner<" << to_str(strategy()) << "> - done"; +} + +void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *info) +{ + LOGGER(prime); + + INFO(prime) << "--------------------------------------------------------------"; + INFO(prime) << "Before " << logo::pass_name(info->pass()); +} + +void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *info) +{ + LOGGER(prime); + + INFO(prime) << "After " << logo::pass_name(info->pass()) + << " (changed: " << to_char(info->changed()) << ")"; + for (size_t g = 0; g < module()->size(); ++g) + { + INFO(prime) << "graphs #" << g; + INFO(prime) << luci::fmt(module()->graph(g)); + } +} + } // namespace luci diff --git a/compiler/luci/pass/src/ProgressReporter.h b/compiler/luci/pass/src/ProgressReporter.h index bd2ba9849..cf30da735 100644 --- a/compiler/luci/pass/src/ProgressReporter.h +++ b/compiler/luci/pass/src/ProgressReporter.h @@ -21,6 +21,8 @@ #include <loco.h> +#include <luci/IR/Module.h> + namespace luci { @@ -48,6 +50,30 @@ private: logo::PhaseStrategy _strategy; }; +class ModuleProgressReporter : public logo::PhaseEventListener +{ +public: + ModuleProgressReporter(luci::Module *module, logo::PhaseStrategy strategy) + : _module{module}, _strategy{strategy} + { + // DO NOTHING + } + +public: + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) override; + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) override; + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *) override; + void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *) override; + +public: + luci::Module *module(void) const { return _module; } + logo::PhaseStrategy strategy(void) const { return _strategy; } + +private: + luci::Module *_module; + logo::PhaseStrategy _strategy; +}; + } // namespace luci #endif // __LUCI_PROGRESSREPORTER_H__ diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp new file mode 100644 index 000000000..af83cd83b --- /dev/null +++ b/compiler/luci/pass/src/PropagateQuantParamPass.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/PropagateQuantParamPass.h" + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Log.h> + +#include <iostream> + +namespace +{ + +bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst) +{ + assert(src->scale.size() == dst->scale.size()); + assert(src->zerop.size() == dst->zerop.size()); + + // src and dst have the same qparam + if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) && + std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) && + src->quantized_dimension == dst->quantized_dimension) + return false; + + dst->scale.assign(src->scale.begin(), src->scale.end()); + dst->zerop.assign(src->zerop.begin(), src->zerop.end()); + dst->quantized_dimension = src->quantized_dimension; + return true; +} + +bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst) +{ + // Skip nodes that do not have quantparams + auto src_qparam = src->quantparam(); + if (not src_qparam) + return false; + + auto dst_qparam = dst->quantparam(); + if (not dst_qparam) + return false; + + return copy_qparam(src_qparam, dst_qparam); +} + +// Visitor to propagate quantization parameters +struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool> +{ + PropagateQuantParam() = default; + + bool visit(luci::CircleNode *) { return false; } + + bool visit(luci::CircleReshape *node) + { + auto input = node->tensor(); + if (loco::succs(input).size() != 1) + return false; + + auto input_node = loco::must_cast<luci::CircleNode *>(input); + return copy_qparam(node, input_node); + } + + // TODO : Add more Ops (e.g., Transpose) +}; + +} // namespace + +namespace luci +{ + +bool PropagateQuantParamPass::run(loco::Graph *g) +{ + bool changed = false; + LOGGER(l); + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl; + + PropagateQuantParam pqp; + changed = circle_node->accept(&pqp); + if (changed) + break; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp new file mode 100644 index 000000000..15adbfc01 --- /dev/null +++ b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/PropagateQuantParamPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale, + const std::vector<int64_t> &zp) +{ + assert(node->quantparam() == nullptr); + + auto quantparam = std::make_unique<luci::CircleQuantParam>(); + quantparam->scale = scale; + quantparam->zerop = zp; + node->quantparam(std::move(quantparam)); +} + +/** + * Simple graph for test + * + * BEFORE + * + * [Conv] (qparam 1) + * | + * [Reshape] (qparam 2) + * + * AFTER + * + * [Conv] (qparam 2) + * | + * [Reshape] (qparam 2) + * + */ +class SimpleGraph +{ +public: + SimpleGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + conv = g.nodes()->create<luci::CircleConv2D>(); + reshape = g.nodes()->create<luci::CircleReshape>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20}); + addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10}); + + conv->input(input); + reshape->tensor(conv); + output->from(reshape); + } + +public: + loco::Graph g; + luci::CircleInput *input; + luci::CircleConv2D *conv; + luci::CircleReshape *reshape; + luci::CircleOutput *output; +}; + +} // namespace + +TEST(PropagateQuantParam, simple) +{ + SimpleGraph g; + + luci::PropagateQuantParamPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]); + EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]); + EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]); + EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]); + EXPECT_EQ(0, g.conv->quantparam()->zerop[1]); + EXPECT_EQ(10, g.conv->quantparam()->zerop[2]); +} + +TEST(PropagateQuantParam, wrong_op_NEG) +{ + SimpleGraph g; + g.output->from(g.conv); + g.reshape->drop(); + + luci::PropagateQuantParamPass pass; + while (pass.run(&g.g)) + ; + + EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]); + EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]); + EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]); + EXPECT_EQ(0, g.conv->quantparam()->zerop[0]); + EXPECT_EQ(10, g.conv->quantparam()->zerop[1]); + EXPECT_EQ(20, g.conv->quantparam()->zerop[2]); +} diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp index 0ecab008f..f6eebe3b9 100644 --- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp +++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp @@ -86,6 +86,100 @@ void quant_const_values(luci::CircleConst *const_node, float scaling_factor, flo } } +// Quantize const per channel +// +// The last dimension of const is the same as the dimension of channel +// And the rest of the const dimensions should be 1 +// So, a 'single value' is quantized per channel +// +// Quantization spec (f: fp value, q: quantized value) +// +// uint8 +// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0] +// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1] +// +// int16 +// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0] +// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0] +void quant_const_per_channel(CircleConst *node, loco::DataType quant_type) +{ + assert(node->dtype() == loco::DataType::FLOAT32); + assert(node->rank() > 0); + + for (uint32_t i = 0; i < node->rank() - 1; i++) + { + // Caller should call this function when the below condition is satisfied + if (node->dim(i).value() != 1) + throw std::runtime_error("Non-channel dimension of const node must be 1"); + } + + uint32_t size = node->size<loco::DataType::FLOAT32>(); + assert(size == node->dim(node->rank() - 1).value()); + + auto quantparam = std::make_unique<CircleQuantParam>(); + quantparam->quantized_dimension = node->rank() - 1; + std::vector<int32_t> quantized_data(size); + + for (uint32_t i = 0; i < size; ++i) + { + auto data = node->at<loco::DataType::FLOAT32>(i); + if (quant_type == loco::DataType::U8) + { + if (data >= 0) + { + quantparam->scale.push_back(data); + quantparam->zerop.push_back(0); + quantized_data[i] = 1; + } + else + { + quantparam->scale.push_back(-data); + quantparam->zerop.push_back(1); + quantized_data[i] = 0; + } + } + else if (quant_type == loco::DataType::S16) + { + if (data >= 0) + { + quantparam->scale.push_back(data); + quantized_data[i] = 1; + } + else + { + quantparam->scale.push_back(-data); + quantized_data[i] = -1; + } + quantparam->zerop.push_back(0); + } + } + node->quantparam(std::move(quantparam)); + + switch (quant_type) + { + case loco::DataType::U8: + node->dtype(loco::DataType::U8); + node->size<loco::DataType::U8>(size); + for (uint32_t i = 0; i < size; ++i) + { + assert(quantized_data[i] == 0 || quantized_data[i] == 1); + node->at<loco::DataType::U8>(i) = quantized_data[i]; + } + break; + case loco::DataType::S16: + node->dtype(loco::DataType::S16); + node->size<loco::DataType::S16>(size); + for (uint32_t i = 0; i < size; ++i) + { + assert(quantized_data[i] == -1 || quantized_data[i] == 1); + node->at<loco::DataType::S16>(i) = quantized_data[i]; + } + break; + default: + throw std::runtime_error("Unsupported data type"); + } +} + void quant_const(CircleConst *node, loco::DataType quant_type) { assert(node->dtype() == loco::DataType::FLOAT32); @@ -612,10 +706,51 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool> } }; +void quant_instnorm(luci::CircleInstanceNorm *node, loco::DataType output_type, + QuantizationGranularity granularity) +{ + auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma()); + auto beta = loco::must_cast<luci::CircleConst *>(node->beta()); + assert(gamma->dtype() == loco::DataType::FLOAT32); + assert(beta->dtype() == loco::DataType::FLOAT32); + + if (granularity == QuantizationGranularity::LayerWise) + { + quant_const(gamma, output_type); + quant_const(beta, output_type); + } + else if (granularity == QuantizationGranularity::ChannelWise) + { + quant_const_per_channel(gamma, output_type); + quant_const_per_channel(beta, output_type); + } + else + throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'"); +} + +void quant_prelu(luci::CirclePRelu *node, loco::DataType output_type, + QuantizationGranularity granularity) +{ + auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha()); + assert(alpha->dtype() == loco::DataType::FLOAT32); + + if (granularity == QuantizationGranularity::LayerWise) + { + quant_const(alpha, output_type); + } + else if (granularity == QuantizationGranularity::ChannelWise) + { + quant_const_per_channel(alpha, output_type); + } + else + throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'"); +} + /** * @brief Quantize const input tensors using min/max of const values */ -void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type) +void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type, + QuantizationGranularity granularity) { auto opcode = node->opcode(); auto arity = node->arity(); @@ -660,20 +795,26 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type) quant_const(const_node, output_type); break; + case luci::CircleOpcode::INSTANCE_NORM: + quant_instnorm(loco::must_cast<luci::CircleInstanceNorm *>(node), output_type, granularity); + break; + + case luci::CircleOpcode::PRELU: + quant_prelu(loco::must_cast<luci::CirclePRelu *>(node), output_type, granularity); + break; + case luci::CircleOpcode::ADD: case luci::CircleOpcode::ADD_N: case luci::CircleOpcode::DIV: case luci::CircleOpcode::EQUAL: case luci::CircleOpcode::GREATER: case luci::CircleOpcode::GREATER_EQUAL: - case luci::CircleOpcode::INSTANCE_NORM: case luci::CircleOpcode::LESS: case luci::CircleOpcode::LESS_EQUAL: case luci::CircleOpcode::MAXIMUM: case luci::CircleOpcode::MINIMUM: case luci::CircleOpcode::MUL: case luci::CircleOpcode::NOT_EQUAL: - case luci::CircleOpcode::PRELU: case luci::CircleOpcode::SUB: // Quantize all const inputs using their values for (uint32_t i = 0; i < arity; i++) @@ -817,7 +958,7 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g) for (auto node : loco::active_nodes(loco::output_nodes(g))) { auto circle_node = loco::must_cast<luci::CircleNode *>(node); - quantize_const_inputs(circle_node, _output_dtype); + quantize_const_inputs(circle_node, _output_dtype, _granularity); } // Propagate quantization parameters of concat Op diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp new file mode 100644 index 000000000..33cb76520 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/RemoveRedundantTransposePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +/// @brief Return true if first_perm[second_perm[i]] == i +bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *second_perm) +{ + assert(first_perm->rank() == 1); + assert(second_perm->rank() == 1); + assert(second_perm->size<loco::DataType::S32>() == first_perm->size<loco::DataType::S32>()); + for (int32_t i = 0; i < static_cast<int32_t>(first_perm->size<loco::DataType::S32>()); i++) + { + if (first_perm->at<loco::DataType::S32>(second_perm->at<loco::DataType::S32>(i)) != i) + return false; + } + return true; +} + +bool remove_consecutive_transpose_function(luci::CircleNode *node) +{ + auto target_node = dynamic_cast<luci::CircleTranspose *>(node); + if (target_node == nullptr) + return false; + auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a()); + if (pred_node == nullptr) + return false; + if (loco::succs(pred_node).size() != 1) + return false; + + auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm()); + if (pred_perm == nullptr) + return false; + + auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm()); + if (main_perm == nullptr) + return false; + + auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a()); + if (check_perm(pred_perm, main_perm)) + { + replace(node).with(main_node); + } + else + { + auto g = main_perm->graph(); + auto new_const_node = g->nodes()->create<luci::CircleConst>(); + + new_const_node->dtype(loco::DataType::S32); + new_const_node->rank(1); + new_const_node->dim(0) = main_perm->dim(0); + new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value()); + new_const_node->shape_status(luci::ShapeStatus::VALID); + for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++) + { + new_const_node->at<loco::DataType::S32>(i) = + pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i)); + } + pred_node->perm(new_const_node); + replace(node).with(pred_node); + } + return true; +} + +} // namespace + +namespace luci +{ +/** + * BEFORE + * | + * [CircleNode] [CircleConst] + * (main_node) (main_perm) + * \ / + * [CircleTranspose] [CircleConst] + * (pred_node) (pred_perm) + * \ / + * [CircleTranspose] + * (target_node) + * | + * + * AFTER + * <Optional Case> + * + * | | | + * [CircleNode] [CircleConst] | + * (main_node) (new_const_node) | + * \ / or [CircleNode] + * [CircleTranspose] (main_node) + * (pred_node) | + * | | + * + */ +bool RemoveRedundantTransposePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (remove_consecutive_transpose_function(circle_node)) + { + changed = true; + break; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp new file mode 100644 index 000000000..db608b674 --- /dev/null +++ b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/RemoveRedundantTransposePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <vector> + +#include <gtest/gtest.h> + +namespace +{ + +void setValue(luci::CircleConst *node, const std::vector<int> &v) +{ + node->dtype(loco::DataType::S32); + node->size<loco::DataType::S32>(v.size()); + node->rank(1); + node->dim(0).set(v.size()); + for (int i = 0; i < v.size(); ++i) + { + node->at<loco::DataType::S32>(i) = v[i]; + } +} + +/** + * Type1 + * BEFORE + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] [CircleConst] + * \ / + * [CircleTranspose] + * | + * + * AFTER + * | + * [CircleNode] + * | Remove Both + * + * -------------------------------------------- + * + * Type2 + * BEFORE + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] [CircleConst] + * \ / + * [CircleTranspose] + * | + * + * AFTER + * | | + * [CircleNode] [CircleConst] + * \ / + * [CircleTranspose] + * | + * + */ +void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1, + const std::vector<int32_t> &perm2) +{ + assert(g); + + auto input = g->nodes()->create<luci::CircleInput>(); + auto graph_input = g->inputs()->create(); + input->index(graph_input->index()); + + // Create perm1 + auto perm1_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm1_node, perm1); + + auto transpose1 = g->nodes()->create<luci::CircleTranspose>(); + transpose1->dtype(loco::DataType::FLOAT32); + transpose1->a(input); + transpose1->perm(perm1_node); + + // Create perm2 + auto perm2_node = g->nodes()->create<luci::CircleConst>(); + setValue(perm2_node, perm2); + + auto transpose2 = g->nodes()->create<luci::CircleTranspose>(); + transpose2->dtype(loco::DataType::FLOAT32); + transpose2->a(transpose1); + transpose2->perm(perm2_node); + + // Output + auto output = g->nodes()->create<luci::CircleOutput>(); + output->from(transpose2); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); +} + +} // namespace + +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + // No transpose node is in graph. + ASSERT_EQ(nullptr, transpose_node); +} + +TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2) +{ + auto graph = loco::make_graph(); + create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3}); + + luci::RemoveRedundantTransposePass pass; + while (pass.run(graph.get())) + ; + luci::CircleTranspose *transpose_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto trans = dynamic_cast<luci::CircleTranspose *>(node); + if (not trans) + continue; + transpose_node = trans; + break; + } + // Just one transpose node, with updated perm constant. + ASSERT_NE(nullptr, transpose_node); + auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm()); + ASSERT_EQ(1, perm->at<loco::DataType::S32>(0)); + ASSERT_EQ(0, perm->at<loco::DataType::S32>(1)); + ASSERT_EQ(3, perm->at<loco::DataType::S32>(2)); + ASSERT_EQ(2, perm->at<loco::DataType::S32>(3)); +} diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp new file mode 100644 index 000000000..7096c2591 --- /dev/null +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp @@ -0,0 +1,223 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma) +{ + assert(gamma->rank() == 1); + auto channel_size = gamma->dim(0).value(); + + // Channel-wise MUL is the same as DEPTHWISE_CONV2D with filter shape (1,1,1,channel_size) + auto weights = gamma->graph()->nodes()->create<luci::CircleConst>(); + weights->dtype(loco::DataType::FLOAT32); + weights->rank(4); + weights->dim(0).set(1); + weights->dim(1).set(1); + weights->dim(2).set(1); + weights->dim(3).set(channel_size); + weights->shape_status(luci::ShapeStatus::VALID); + weights->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + { + weights->at<loco::DataType::FLOAT32>(i) = gamma->at<loco::DataType::FLOAT32>(i); + } + + return weights; +} + +luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta) +{ + assert(beta->rank() == 1); + auto channel_size = beta->dim(0).value(); + + // Channel-wise ADD is the same as bias (shape = (channel_size)) of DEPTHWISE_CONV2D + auto bias = beta->graph()->nodes()->create<luci::CircleConst>(); + bias->dtype(loco::DataType::FLOAT32); + bias->rank(1); + bias->dim(0).set(channel_size); + bias->size<loco::DataType::FLOAT32>(channel_size); + bias->shape_status(luci::ShapeStatus::VALID); + for (uint32_t i = 0; i < channel_size; i++) + { + bias->at<loco::DataType::FLOAT32>(i) = beta->at<loco::DataType::FLOAT32>(i); + } + + return bias; +} + +bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta) +{ + auto x = loco::must_cast<luci::CircleNode *>(add->x()); + auto y = loco::must_cast<luci::CircleNode *>(add->y()); + + luci::CircleMul *pred = nullptr; + luci::CircleConst *constant = nullptr; + + if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL) + { + pred = loco::must_cast<luci::CircleMul *>(y); + constant = loco::must_cast<luci::CircleConst *>(x); + } + else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST) + { + pred = loco::must_cast<luci::CircleMul *>(x); + constant = loco::must_cast<luci::CircleConst *>(y); + } + else + { + return false; + } + + if (constant->rank() != 1) + return false; + + auto channel_dim = constant->dim(0); + // Assumption: Layout is channel-last + if (!(channel_dim == add->dim(add->rank() - 1))) + return false; + + mul = pred; + beta = constant; + return true; +} + +// Check if mul is batchnorm mul +bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node, + luci::CircleConst *&gamma) +{ + auto x = dynamic_cast<luci::CircleConst *>(mul->x()); + auto y = dynamic_cast<luci::CircleConst *>(mul->y()); + + luci::CircleNode *pred = nullptr; + luci::CircleConst *constant = nullptr; + + if (x != nullptr && y == nullptr) + { + pred = loco::must_cast<luci::CircleNode *>(mul->y()); + constant = x; + } + else if (x == nullptr && y != nullptr) + { + pred = loco::must_cast<luci::CircleNode *>(mul->x()); + constant = y; + } + else + { + return false; + } + + if (constant->rank() != 1) + return false; + + auto channel_dim = constant->dim(0); + if (!(channel_dim == mul->dim(mul->rank() - 1))) + return false; + + pred_node = pred; + gamma = constant; + return true; +} + +/** + * Replace channel-wise Mul/Add with DepthwiseConv2D + * + * BEFORE + * + * [Node] [gamma] + * | / + * [Mul] [beta] + * | / + * [Add] + * + * AFTER + * + * [Node] [weights] [bias] + * \ / / + * [DepthwiseConv2D] + */ +bool replace_mul_add_with_dwconv(luci::CircleAdd *add) +{ + luci::CircleNode *pred_node = nullptr; + luci::CircleMul *mul = nullptr; + luci::CircleConst *beta = nullptr; + luci::CircleConst *gamma = nullptr; + + if (!is_batchnorm_add(add, mul, beta)) + return false; + + if (loco::succs(mul).size() != 1) + return false; + + if (!is_batchnorm_mul(mul, pred_node, gamma)) + return false; + + if (pred_node->rank() != 4) + return false; + + if (pred_node->dtype() != loco::DataType::FLOAT32 || beta->dtype() != loco::DataType::FLOAT32 || + gamma->dtype() != loco::DataType::FLOAT32) + return false; + + auto weights = create_weights_from_gamma(gamma); + auto bias = create_bias_from_beta(beta); + + auto dwconv = add->graph()->nodes()->create<luci::CircleDepthwiseConv2D>(); + dwconv->input(pred_node); + dwconv->filter(weights); + dwconv->bias(bias); + dwconv->padding(luci::Padding::SAME); + dwconv->stride()->w(1); + dwconv->stride()->h(1); + dwconv->depthMultiplier(1); + dwconv->dilation()->w(1); + dwconv->dilation()->h(1); + dwconv->fusedActivationFunction(add->fusedActivationFunction()); + + loco::replace(add).with(dwconv); + return true; +} + +} // namespace + +namespace luci +{ + +bool ReplaceMulAddWithDepthwiseConvPass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto add = dynamic_cast<luci::CircleAdd *>(node); + if (not add) + continue; + + if (replace_mul_add_with_dwconv(add)) + { + changed = true; + break; + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp new file mode 100644 index 000000000..a90182aaa --- /dev/null +++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * Simple graph for test + * + * BEFORE + * + * [Node] [gamma] + * | / + * [Mul] [beta] + * | / + * [Add] + * + * AFTER + * + * [Node] [weights] [bias] + * \ / / + * [DepthwiseConv2D] + */ +class SimpleGraph +{ +public: + SimpleGraph() + { + input = g.nodes()->create<luci::CircleInput>(); + mul = g.nodes()->create<luci::CircleMul>(); + gamma = g.nodes()->create<luci::CircleConst>(); + add = g.nodes()->create<luci::CircleAdd>(); + beta = g.nodes()->create<luci::CircleConst>(); + output = g.nodes()->create<luci::CircleOutput>(); + + auto graph_input = g.inputs()->create(); + input->index(graph_input->index()); + auto graph_output = g.outputs()->create(); + output->index(graph_output->index()); + + input->dtype(loco::DataType::FLOAT32); + mul->dtype(loco::DataType::FLOAT32); + gamma->dtype(loco::DataType::FLOAT32); + add->dtype(loco::DataType::FLOAT32); + beta->dtype(loco::DataType::FLOAT32); + output->dtype(loco::DataType::FLOAT32); + + uint32_t channel_size = 16; + input->shape({1, 4, 4, channel_size}); + mul->shape({1, 4, 4, channel_size}); + gamma->shape({channel_size}); + add->shape({1, 4, 4, channel_size}); + beta->shape({channel_size}); + output->shape({1, 4, 4, channel_size}); + + gamma->size<loco::DataType::FLOAT32>(channel_size); + beta->size<loco::DataType::FLOAT32>(channel_size); + for (uint32_t i = 0; i < channel_size; i++) + { + gamma->at<loco::DataType::FLOAT32>(i) = i; + beta->at<loco::DataType::FLOAT32>(i) = i; + } + + mul->x(input); + mul->y(gamma); + add->x(mul); + add->y(beta); + output->from(add); + } + +public: + loco::Graph g; + luci::CircleInput *input = nullptr; + luci::CircleMul *mul = nullptr; + luci::CircleConst *gamma = nullptr; + luci::CircleAdd *add = nullptr; + luci::CircleConst *beta = nullptr; + luci::CircleOutput *output = nullptr; +}; + +} // namespace + +TEST(ReplaceMulAddWithDepthwiseConv, simple) +{ + SimpleGraph g; + + luci::ReplaceMulAddWithDepthwiseConvPass pass; + while (pass.run(&g.g)) + ; + + auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from()); + EXPECT_NE(nullptr, dwconv); + + uint32_t channel_size = 16; + auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter()); + auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias()); + EXPECT_NE(nullptr, weights); + EXPECT_EQ(4, weights->rank()); + EXPECT_EQ(channel_size, weights->dim(3).value()); + EXPECT_NE(nullptr, bias); + EXPECT_EQ(1, bias->rank()); + EXPECT_EQ(channel_size, bias->dim(0).value()); + + for (int i = 0; i < channel_size; i++) + { + EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i)); + EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i)); + } +} + +TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG) +{ + SimpleGraph g; + // swap mul/add (changed to add->mul) + g.add->x(g.input); + loco::replace(g.add).with(g.mul); + g.mul->x(g.add); + + luci::ReplaceMulAddWithDepthwiseConvPass pass; + auto changed = pass.run(&g.g); + + EXPECT_EQ(false, changed); +} diff --git a/compiler/luci/pass/src/ShapeInferencePass.cpp b/compiler/luci/pass/src/ShapeInferencePass.cpp index f681b3d5f..4bd0aaed4 100644 --- a/compiler/luci/pass/src/ShapeInferencePass.cpp +++ b/compiler/luci/pass/src/ShapeInferencePass.cpp @@ -28,6 +28,19 @@ namespace luci { +bool ShapeInferencePass::run(luci::Module *m) +{ + bool changed = false; + + for (size_t g = 0; g < m->size(); ++g) + { + if (run(m->graph(g))) + changed = true; + } + + return changed; +} + bool ShapeInferencePass::run(loco::Graph *g) { loco::CanonicalShapeInferenceRule canonical_rule; diff --git a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp b/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp new file mode 100644 index 000000000..115b77a96 --- /dev/null +++ b/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ShapeSignatureInferencePass.h" + +#include <luci/IR/CircleShapeSignature.h> +#include <luci/Service/CircleShapeSignatureInference.h> + +#include <loco.h> + +namespace luci +{ + +bool ShapeSignatureInferencePass::run(luci::Module *m) +{ + bool changed = false; + + for (size_t g = 0; g < m->size(); ++g) + { + if (run(m->graph(g))) + changed = true; + } + + return changed; +} + +bool ShapeSignatureInferencePass::run(loco::Graph *g) +{ + luci::ssinf::Rule signature_inference_rule; + bool changed = false; + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + luci::ShapeSignature shape_signature; + + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (signature_inference_rule.infer(circle_node, shape_signature)) + { + if (!(circle_node->shape_signature() == shape_signature)) + { + circle_node->shape_signature(shape_signature); + changed = true; + } + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp new file mode 100644 index 000000000..6a58f18c5 --- /dev/null +++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" + +#include <luci/IR/CircleNodes.h> + +#include <cassert> +#include <vector> + +namespace +{ + +bool satisfy_precondition(luci::CircleFullyConnected *fc) +{ + // check if it's already been shuffled + if (fc->weights_format() != luci::CircleFullyConnected::WeightsFormat::DEFAULT) + return false; + + // check if its data type is FLOAT32 + if (fc->dtype() != loco::DataType::FLOAT32) + return false; + + auto weights = loco::must_cast<luci::CircleConst *>(fc->weights()); + // rank must be 2 + if (weights->rank() != 2) + return false; + + // check if it has sparsity parameter + if (weights->sparsityparam()) + return false; + + // check if the number of row of FullyConnected's weight is a multiple of 16 + const uint32_t MULTIPLE = 16; + uint32_t rows = weights->dim(0).value(); + if (rows % MULTIPLE) + return false; + + return true; +} + +// get FullyConnected op vector that has same tensor +void get_FCs_having_same_tensor(std::vector<luci::CircleFullyConnected *> &fc_vec, loco::Graph *g, + luci::CircleFullyConnected *fc) +{ + auto the_tensor = fc->weights(); + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast<luci::CircleFullyConnected *>(node); + if (not fc) + continue; + + if (fc->weights() == the_tensor) + fc_vec.push_back(fc); + } +} + +luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc) +{ + auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights()); + + // create CircleConst where shuffled data will be stored + luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>(); + new_weights->dtype(loco::DataType::FLOAT32); + new_weights->size<loco::DataType::FLOAT32>(the_weights->size<loco::DataType::FLOAT32>()); + new_weights->rank(the_weights->rank()); + new_weights->shape_status(the_weights->shape_status()); + for (uint32_t r = 0; r < new_weights->rank(); r++) + { + new_weights->dim(r).set(the_weights->dim(r).value()); + } + + // suffle weight + const uint32_t MULTIPLE = 16; + const uint32_t rows = the_weights->dim(0).value(); + const uint32_t cols = the_weights->dim(1).value(); + const uint32_t r_step = rows / MULTIPLE; + uint32_t index = 0; + for (uint32_t r = 0; r < r_step; r++) + { + for (uint32_t c = 0; c < cols; c++) + { + for (uint32_t i = 0; i < MULTIPLE; i++) + { + new_weights->at<loco::DataType::FLOAT32>(index++) = + the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c); + } + } + } + + return new_weights; +} + +} // namespace + +namespace luci +{ + +bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto fc = dynamic_cast<luci::CircleFullyConnected *>(node); + if (not fc) + continue; + + if (not satisfy_precondition(fc)) + continue; + + std::vector<luci::CircleFullyConnected *> fc_vec; + get_FCs_having_same_tensor(fc_vec, g, fc); + auto new_weights = shuffle_weight(fc); + + // replace to new weights + for (const auto fc : fc_vec) + { + fc->weights(new_weights); + fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32); + } + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp new file mode 100644 index 000000000..9745e5754 --- /dev/null +++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +void create_fc_net(loco::Graph *g) +{ + assert(g); + + const uint32_t ROW = 16; + const uint32_t COL = 2; + const uint32_t elements_num = ROW * COL; + + // input + auto input = g->nodes()->create<luci::CircleInput>(); + auto graph_input = g->inputs()->create(); + input->index(graph_input->index()); + + // fc weights + auto weights = g->nodes()->create<luci::CircleConst>(); + weights->dtype(loco::DataType::FLOAT32); + weights->size<loco::DataType::FLOAT32>(elements_num); + weights->rank(2); + weights->dim(0).set(ROW); + weights->dim(1).set(COL); + for (uint32_t idx = 0; idx < elements_num; idx++) + { + weights->at<loco::DataType::FLOAT32>(idx) = idx; + } + + // fc + auto fc = g->nodes()->create<luci::CircleFullyConnected>(); + fc->dtype(loco::DataType::FLOAT32); + fc->input(input); + fc->weights(weights); + + // output + auto output = g->nodes()->create<luci::CircleOutput>(); + output->from(fc); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); +} + +TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1) +{ + auto graph = loco::make_graph(); + create_fc_net(graph.get()); + + luci::CircleFullyConnected *fc_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + auto fc = dynamic_cast<luci::CircleFullyConnected *>(node); + if (not fc) + continue; + + fc_node = fc; + break; + } + ASSERT_NE(fc_node, nullptr); + auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights()); + // before + ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0)); + ASSERT_EQ(1, weights->at<loco::DataType::FLOAT32>(1)); + ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(2)); + ASSERT_EQ(3, weights->at<loco::DataType::FLOAT32>(3)); + ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(4)); + ASSERT_EQ(5, weights->at<loco::DataType::FLOAT32>(5)); + ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(6)); + ASSERT_EQ(7, weights->at<loco::DataType::FLOAT32>(7)); + ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(8)); + ASSERT_EQ(9, weights->at<loco::DataType::FLOAT32>(9)); + ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(10)); + ASSERT_EQ(11, weights->at<loco::DataType::FLOAT32>(11)); + ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(12)); + ASSERT_EQ(13, weights->at<loco::DataType::FLOAT32>(13)); + ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(14)); + ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15)); + + luci::ShuffleWeightTo16x1Float32Pass pass; + while (pass.run(graph.get())) + ; + + weights = loco::must_cast<luci::CircleConst *>(fc_node->weights()); + // after + ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0)); + ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(1)); + ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(2)); + ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(3)); + ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(4)); + ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(5)); + ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(6)); + ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(7)); + ASSERT_EQ(16, weights->at<loco::DataType::FLOAT32>(8)); + ASSERT_EQ(18, weights->at<loco::DataType::FLOAT32>(9)); + ASSERT_EQ(20, weights->at<loco::DataType::FLOAT32>(10)); + ASSERT_EQ(22, weights->at<loco::DataType::FLOAT32>(11)); + ASSERT_EQ(24, weights->at<loco::DataType::FLOAT32>(12)); + ASSERT_EQ(26, weights->at<loco::DataType::FLOAT32>(13)); + ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14)); + ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15)); +} diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp new file mode 100644 index 000000000..44e974b91 --- /dev/null +++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/SubstitutePackToReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +namespace +{ + +bool substitute_pack_to_reshape(luci::CircleNode *node) +{ + auto target_node = dynamic_cast<luci::CirclePack *>(node); + if (target_node == nullptr) + return false; + if (target_node->values_count() != 1) + return false; + auto value_node = loco::must_cast<luci::CircleNode *>(target_node->values(0)); + if (value_node->shape_status() != luci::ShapeStatus::VALID) + return false; + int32_t axis = target_node->axis(); + if (axis < 0) + axis = axis + static_cast<int32_t>(value_node->rank()) + 1; + + auto graph = target_node->graph(); + auto reshape_node = graph->nodes()->create<luci::CircleReshape>(); + reshape_node->tensor(value_node); + + auto const_node = graph->nodes()->create<luci::CircleConst>(); + const_node->dtype(loco::DataType::S32); + const_node->size<loco::DataType::S32>(value_node->rank() + 1); + const_node->shape_status(luci::ShapeStatus::VALID); + const_node->rank(1); + const_node->dim(0).set(value_node->rank() + 1); + for (int32_t i = 0; i < static_cast<int32_t>(value_node->rank()) + 1; i++) + { + if (i == axis) + { + const_node->at<loco::DataType::S32>(i) = 1; + } + else if (i < axis) + { + const_node->at<loco::DataType::S32>(i) = value_node->dim(i).value(); + } + else + { + const_node->at<loco::DataType::S32>(i) = value_node->dim(i - 1).value(); + } + } + reshape_node->shape(const_node); + replace(target_node).with(reshape_node); + return true; +} + +} // namespace + +namespace luci +{ + +/** + * BEFORE + * | + * [CircleNode] + * | + * [CirclePack] + * | + * [CircleNode] + * | + * + * AFTER + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleReshape] + * | + * [CircleNode] + * | + * + */ +bool SubstitutePackToReshapePass::run(loco::Graph *g) +{ + bool changed = false; + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + if (substitute_pack_to_reshape(circle_node)) + { + changed = true; + } + } + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp new file mode 100644 index 000000000..143b88896 --- /dev/null +++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "luci/Pass/SubstitutePackToReshapePass.h" + +#include <luci/IR/CircleNodes.h> + +#include <gtest/gtest.h> + +namespace +{ + +/** + * BEFORE + * | + * [CircleNode] + * | + * [CirclePack] + * | + * [CircleNode] + * | + * + * AFTER + * | + * [CircleNode] [CircleConst] + * \ / + * [CircleReshape] + * | + * [CircleNode] + * | + * + */ +void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape, + int32_t axis) +{ + assert(g); + + // Input Create. + auto input = g->nodes()->create<luci::CircleInput>(); + auto graph_input = g->inputs()->create(); + input->index(graph_input->index()); + input->shape_status(luci::ShapeStatus::VALID); + input->rank(shape.size()); + input->shape(shape); + + // Pack Node create. + auto pack = g->nodes()->create<luci::CirclePack>(1); + pack->values(0, input); + pack->axis(axis); + + // Output Connect. + auto output = g->nodes()->create<luci::CircleOutput>(); + output->from(pack); + auto graph_output = g->outputs()->create(); + output->index(graph_output->index()); + + return; +} + +} // namespace + +TEST(SubstitutePackToReshapePass, simple_case) +{ + auto graph = loco::make_graph(); + create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, 0); + luci::SubstitutePackToReshapePass pass; + while (pass.run(graph.get())) + ; + luci::CircleReshape *reshape_node = nullptr; + luci::CirclePack *pack_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) + reshape_node = reshape; + else if (auto pack = dynamic_cast<luci::CirclePack *>(node)) + pack_node = pack; + } + ASSERT_NE(nullptr, reshape_node); + ASSERT_EQ(nullptr, pack_node); + auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape()); + ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0)); + ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(1)); + ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(2)); + ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(3)); + ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(4)); +} + +TEST(SubstitutePackToReshapePass, simple_case_neg_axis) +{ + auto graph = loco::make_graph(); + create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, -1); + luci::SubstitutePackToReshapePass pass; + while (pass.run(graph.get())) + ; + luci::CircleReshape *reshape_node = nullptr; + luci::CirclePack *pack_node = nullptr; + for (auto node : loco::active_nodes(loco::output_nodes(graph.get()))) + { + if (auto reshape = dynamic_cast<luci::CircleReshape *>(node)) + reshape_node = reshape; + else if (auto pack = dynamic_cast<luci::CirclePack *>(node)) + pack_node = pack; + } + ASSERT_NE(nullptr, reshape_node); + ASSERT_EQ(nullptr, pack_node); + auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape()); + ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0)); + ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(1)); + ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(2)); + ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(3)); + ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(4)); +} diff --git a/compiler/luci/pass/src/TypeInferencePass.cpp b/compiler/luci/pass/src/TypeInferencePass.cpp index 2c7b3a897..63744045c 100644 --- a/compiler/luci/pass/src/TypeInferencePass.cpp +++ b/compiler/luci/pass/src/TypeInferencePass.cpp @@ -26,6 +26,19 @@ namespace luci { +bool TypeInferencePass::run(luci::Module *m) +{ + bool changed = false; + + for (size_t g = 0; g < m->size(); ++g) + { + if (run(m->graph(g))) + changed = true; + } + + return changed; +} + bool TypeInferencePass::run(loco::Graph *g) { loco::CanonicalTypeInferenceRule canonical_rule; diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h index fb934c2cf..c301db5f4 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h @@ -21,6 +21,10 @@ #include <loco/IR/Nodes.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Service/CircleShapeInferenceHelper.h> + namespace luci { @@ -36,6 +40,155 @@ struct ShapeInference static ShapeDescription get(loco::Node *node); }; +namespace sinf // namespace for Shape Inference +{ + +struct Rule +{ + bool infer(const luci::CircleNode *, loco::TensorShape &) const; +}; + +class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape> +{ +public: + // TODO Remove this when all of visit function is implemented + loco::TensorShape visit(const luci::CircleNode *node) final { return sinf::circle_shape(node); } + + // loco::TensorShape visit(const luci::CircleAbs *node) final; + // loco::TensorShape visit(const luci::CircleAdd *node) final; + // loco::TensorShape visit(const luci::CircleAddN *node) final; + // loco::TensorShape visit(const luci::CircleArgMax *node) final; + // loco::TensorShape visit(const luci::CircleArgMin *node) final; + // loco::TensorShape visit(const luci::CircleAveragePool2D *node) final; + // loco::TensorShape visit(const luci::CircleBatchMatMul *node) final; + // loco::TensorShape visit(const luci::CircleBatchToSpaceND *node) final; + // loco::TensorShape visit(const luci::CircleCast *node) final; + // loco::TensorShape visit(const luci::CircleCeil *node) final; + // loco::TensorShape visit(const luci::CircleConcatenation *node) final; + // loco::TensorShape visit(const luci::CircleConst *node) final; + // loco::TensorShape visit(const luci::CircleConv2D *node) final; + // loco::TensorShape visit(const luci::CircleCos *node) final; + // loco::TensorShape visit(const luci::CircleCustom *node) final; + // loco::TensorShape visit(const luci::CircleDepthToSpace *node) final; + // loco::TensorShape visit(const luci::CircleDepthwiseConv2D *node) final; + // loco::TensorShape visit(const luci::CircleDequantize *node) final; + // loco::TensorShape visit(const luci::CircleDiv *node) final; + // loco::TensorShape visit(const luci::CircleElu *node) final; + // loco::TensorShape visit(const luci::CircleEqual *node) final; + // loco::TensorShape visit(const luci::CircleExp *node) final; + // loco::TensorShape visit(const luci::CircleExpandDims *node) final; + // loco::TensorShape visit(const luci::CircleFill *node) final; + // loco::TensorShape visit(const luci::CircleFloor *node) final; + // loco::TensorShape visit(const luci::CircleFloorDiv *node) final; + // loco::TensorShape visit(const luci::CircleFloorMod *node) final; + // loco::TensorShape visit(const luci::CircleFullyConnected *node) final; + // loco::TensorShape visit(const luci::CircleGather *node) final; + // loco::TensorShape visit(const luci::CircleGatherNd *node) final; + // loco::TensorShape visit(const luci::CircleGreater *node) final; + // loco::TensorShape visit(const luci::CircleGreaterEqual *node) final; + // loco::TensorShape visit(const luci::CircleIf *node) final; + // loco::TensorShape visit(const luci::CircleL2Normalize *node) final; + // loco::TensorShape visit(const luci::CircleL2Pool2D *node) final; + // loco::TensorShape visit(const luci::CircleLeakyRelu *node) final; + // loco::TensorShape visit(const luci::CircleLess *node) final; + // loco::TensorShape visit(const luci::CircleLessEqual *node) final; + // loco::TensorShape visit(const luci::CircleLocalResponseNormalization *node) final; + // loco::TensorShape visit(const luci::CircleLog *node) final; + // loco::TensorShape visit(const luci::CircleLogicalAnd *node) final; + // loco::TensorShape visit(const luci::CircleLogicalNot *node) final; + // loco::TensorShape visit(const luci::CircleLogicalOr *node) final; + // loco::TensorShape visit(const luci::CircleLogistic *node) final; + // loco::TensorShape visit(const luci::CircleLogSoftmax *node) final; + // loco::TensorShape visit(const luci::CircleMatrixDiag *node) final; + // loco::TensorShape visit(const luci::CircleMatrixSetDiag *node) final; + // loco::TensorShape visit(const luci::CircleMaximum *node) final; + // loco::TensorShape visit(const luci::CircleMaxPool2D *node) final; + // loco::TensorShape visit(const luci::CircleMean *node) final; + // loco::TensorShape visit(const luci::CircleMinimum *node) final; + // loco::TensorShape visit(const luci::CircleMirrorPad *node) final; + // loco::TensorShape visit(const luci::CircleNeg *node) final; + // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final; + // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final; + // loco::TensorShape visit(const luci::CircleNotEqual *node) final; + // loco::TensorShape visit(const luci::CirclePack *node) final; + // loco::TensorShape visit(const luci::CirclePad *node) final; + // loco::TensorShape visit(const luci::CirclePadV2 *node) final; + // loco::TensorShape visit(const luci::CirclePow *node) final; + // loco::TensorShape visit(const luci::CirclePRelu *node) final; + // loco::TensorShape visit(const luci::CircleRange *node) final; + // loco::TensorShape visit(const luci::CircleRank *node) final; + // loco::TensorShape visit(const luci::CircleMul *node) final; + // loco::TensorShape visit(const luci::CircleOneHot *node) final; + // loco::TensorShape visit(const luci::CircleReduceAny *node) final; + // loco::TensorShape visit(const luci::CircleReduceMax *node) final; + // loco::TensorShape visit(const luci::CircleReduceMin *node) final; + // loco::TensorShape visit(const luci::CircleReduceProd *node) final; + // loco::TensorShape visit(const luci::CircleRelu *node) final; + // loco::TensorShape visit(const luci::CircleRelu6 *node) final; + // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final; + // loco::TensorShape visit(const luci::CircleReshape *node) final; + // loco::TensorShape visit(const luci::CircleResizeBilinear *node) final; + // loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final; + // loco::TensorShape visit(const luci::CircleReverseSequence *node) final; + // loco::TensorShape visit(const luci::CircleReverseV2 *node) final; + // loco::TensorShape visit(const luci::CircleRound *node) final; + // loco::TensorShape visit(const luci::CircleRsqrt *node) final; + // loco::TensorShape visit(const luci::CircleScatterNd *node) final; + // loco::TensorShape visit(const luci::CircleSegmentSum *node) final; + // loco::TensorShape visit(const luci::CircleSelect *node) final; + // loco::TensorShape visit(const luci::CircleSelectV2 *node) final; + // loco::TensorShape visit(const luci::CircleShape *node) final; + // loco::TensorShape visit(const luci::CircleSin *node) final; + // loco::TensorShape visit(const luci::CircleSlice *node) final; + // loco::TensorShape visit(const luci::CircleSoftmax *node) final; + // loco::TensorShape visit(const luci::CircleSpaceToBatchND *node) final; + // loco::TensorShape visit(const luci::CircleSpaceToDepth *node) final; + // loco::TensorShape visit(const luci::CircleSparseToDense *node) final; + // loco::TensorShape visit(const luci::CircleSplit *node) final; + // loco::TensorShape visit(const luci::CircleSplitV *node) final; + // loco::TensorShape visit(const luci::CircleSqrt *node) final; + // loco::TensorShape visit(const luci::CircleSquare *node) final; + // loco::TensorShape visit(const luci::CircleSquaredDifference *node) final; + // loco::TensorShape visit(const luci::CircleSqueeze *node) final; + // loco::TensorShape visit(const luci::CircleStridedSlice *node) final; + // loco::TensorShape visit(const luci::CircleSub *node) final; + // loco::TensorShape visit(const luci::CircleSum *node) final; + // loco::TensorShape visit(const luci::CircleTanh *node) final; + // loco::TensorShape visit(const luci::CircleTile *node) final; + // loco::TensorShape visit(const luci::CircleTopKV2 *node) final; + // loco::TensorShape visit(const luci::CircleTranspose *node) final; + // loco::TensorShape visit(const luci::CircleTransposeConv *node) final; + // loco::TensorShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final; + // loco::TensorShape visit(const luci::CircleUnique *node) final; + // loco::TensorShape visit(const luci::CircleUnpack *node) final; + // loco::TensorShape visit(const luci::CircleWhere *node) final; + // loco::TensorShape visit(const luci::CircleWhile *node) final; + // loco::TensorShape visit(const luci::CircleZerosLike *node) final; + + // Circle Only + // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final; + // loco::TensorShape visit(const luci::CircleBCQGather *node) final; + // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final; + + // Virtual + // loco::TensorShape visit(const luci::CircleInput *node) final; + // loco::TensorShape visit(const luci::CircleOutput *node) final; + // loco::TensorShape visit(const luci::CircleOutputDummy *node) final; + // loco::TensorShape visit(const luci::CircleOutputExclude *node) final; + // loco::TensorShape visit(const luci::CircleCustomOut *node) final; + // loco::TensorShape visit(const luci::CircleIfOut *node) final; + // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final; + // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final; + // loco::TensorShape visit(const luci::CircleSplitOut *node) final; + // loco::TensorShape visit(const luci::CircleSplitVOut *node) final; + // loco::TensorShape visit(const luci::CircleTopKV2Out *node) final; + // loco::TensorShape visit(const luci::CircleUniqueOut *node) final; + // loco::TensorShape visit(const luci::CircleUnpackOut *node) final; + // loco::TensorShape visit(const luci::CircleWhileOut *node) final; +}; + +} // namespace sinf + } // namespace luci #endif // __LUCI_CIRCLE_SHAPE_INFERENCE_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h new file mode 100644 index 000000000..dd6a5a454 --- /dev/null +++ b/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__ +#define __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__ + +#include <loco/IR/TensorShape.h> + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleShapeSignature.h> + +namespace luci +{ +namespace sinf // Namespace for Shape Inference +{ + +// Return shape of circle node as loco::TensorShape +loco::TensorShape circle_shape(const luci::CircleNode *node); + +} // namespace sinf +} // namespace luci + +#endif // __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h index 4d1d83012..f7ea89bb8 100644 --- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h +++ b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h @@ -14,22 +14,26 @@ * limitations under the License. */ -#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__ -#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__ +#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__ +#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__ #include <luci/IR/CircleNodes.h> #include <luci/IR/CircleNodeVisitor.h> #include <luci/IR/CircleShapeSignature.h> +#include <luci/Service/CircleShapeSignatureInferenceHelper.h> namespace luci { -struct CircleShapeSignatureInferenceRule +namespace ssinf // namespace for Shape Signature Inference +{ + +struct Rule { bool infer(const luci::CircleNode *, ShapeSignature &) const; }; -class ShapeSignatureInferenceAlgorithm final : public luci::CircleNodeVisitor<ShapeSignature> +class Algorithm final : public luci::CircleNodeVisitor<ShapeSignature> { public: // TODO Remove this when visit function is implemented for all the operations. @@ -84,7 +88,7 @@ public: // ShapeSignature visit(const luci::CircleMatrixSetDiag *node) final; // ShapeSignature visit(const luci::CircleMaximum *node) final; // ShapeSignature visit(const luci::CircleMaxPool2D *node) final; - // ShapeSignature visit(const luci::CircleMean *node) final; + ShapeSignature visit(const luci::CircleMean *node) final; // ShapeSignature visit(const luci::CircleMinimum *node) final; // ShapeSignature visit(const luci::CircleMirrorPad *node) final; // ShapeSignature visit(const luci::CircleNeg *node) final; @@ -100,13 +104,13 @@ public: // ShapeSignature visit(const luci::CircleRank *node) final; // ShapeSignature visit(const luci::CircleMul *node) final; // ShapeSignature visit(const luci::CircleOneHot *node) final; - // ShapeSignature visit(const luci::CircleReduceAny *node) final; - // ShapeSignature visit(const luci::CircleReduceMax *node) final; - // ShapeSignature visit(const luci::CircleReduceMin *node) final; - // ShapeSignature visit(const luci::CircleReduceProd *node) final; - // ShapeSignature visit(const luci::CircleRelu *node) final; - // ShapeSignature visit(const luci::CircleRelu6 *node) final; - // ShapeSignature visit(const luci::CircleReluN1To1 *node) final; + ShapeSignature visit(const luci::CircleReduceAny *node) final; + ShapeSignature visit(const luci::CircleReduceMax *node) final; + ShapeSignature visit(const luci::CircleReduceMin *node) final; + ShapeSignature visit(const luci::CircleReduceProd *node) final; + ShapeSignature visit(const luci::CircleRelu *node) final; + ShapeSignature visit(const luci::CircleRelu6 *node) final; + ShapeSignature visit(const luci::CircleReluN1To1 *node) final; // ShapeSignature visit(const luci::CircleReshape *node) final; // ShapeSignature visit(const luci::CircleResizeBilinear *node) final; // ShapeSignature visit(const luci::CircleResizeNearestNeighbor *node) final; @@ -133,7 +137,7 @@ public: // ShapeSignature visit(const luci::CircleSqueeze *node) final; // ShapeSignature visit(const luci::CircleStridedSlice *node) final; // ShapeSignature visit(const luci::CircleSub *node) final; - // ShapeSignature visit(const luci::CircleSum *node) final; + ShapeSignature visit(const luci::CircleSum *node) final; // ShapeSignature visit(const luci::CircleTanh *node) final; // ShapeSignature visit(const luci::CircleTile *node) final; // ShapeSignature visit(const luci::CircleTopKV2 *node) final; @@ -152,10 +156,10 @@ public: // ShapeSignature visit(const luci::CircleInstanceNorm *node) final; // Virtual - // ShapeSignature visit(const luci::CircleInput *node) final; - // ShapeSignature visit(const luci::CircleOutput *node) final; - // ShapeSignature visit(const luci::CircleOutputDummy *node) final; - // ShapeSignature visit(const luci::CircleOutputExclude *node) final; + ShapeSignature visit(const luci::CircleInput *node) final; + ShapeSignature visit(const luci::CircleOutput *node) final; + ShapeSignature visit(const luci::CircleOutputDummy *node) final; + ShapeSignature visit(const luci::CircleOutputExclude *node) final; // ShapeSignature visit(const luci::CircleCustomOut *node) final; // ShapeSignature visit(const luci::CircleIfOut *node) final; // ShapeSignature visit(const luci::CircleNonMaxSuppressionV4Out *node) final; @@ -168,6 +172,8 @@ public: // ShapeSignature visit(const luci::CircleWhileOut *node) final; }; +} // namespace ssinf + } // namespace luci -#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__ +#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h new file mode 100644 index 000000000..fb5b3b302 --- /dev/null +++ b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__ +#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__ + +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleShapeSignature.h> + +namespace luci +{ + +namespace ssinf // Namespace for Shape Signature Inference +{ + +// Return empty signature if all of dimensions are known. +// If at least one of dimensions is unknown, return signature without change. +ShapeSignature legalized_signature(const luci::ShapeSignature &signature); + +// Return reduced input_signature with indices and keep_dims. +// - indices : reduction index +// - keep_dims : If true, rank is not changed. If false, rank is reduced along indices. +ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims); + +// Return signature of index-th argument of node. +ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index); + +} // namespace ssinf + +} // namespace luci + +#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h index ea7a3c5ed..342214887 100644 --- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h +++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h @@ -21,6 +21,10 @@ #include <mio/circle/schema_generated.h> +#include <luci/IR/CircleNodes.h> +#include <luci/IR/CircleNodeVisitor.h> +#include <luci/Service/CircleTypeInferenceHelper.h> + namespace luci { @@ -37,6 +41,155 @@ struct TypeInference static circle::TensorType get(loco::Node *node); }; +namespace tinf // namespace for Type Inference +{ + +struct Rule +{ + bool infer(const luci::CircleNode *, loco::DataType &) const; +}; + +class Algorithm final : public luci::CircleNodeVisitor<loco::DataType> +{ +public: + // TODO Remove this when all of visit function is implemented + loco::DataType visit(const luci::CircleNode *node) final { return node->dtype(); } + + // loco::DataType visit(const luci::CircleAbs *node) final; + // loco::DataType visit(const luci::CircleAdd *node) final; + // loco::DataType visit(const luci::CircleAddN *node) final; + // loco::DataType visit(const luci::CircleArgMax *node) final; + // loco::DataType visit(const luci::CircleArgMin *node) final; + // loco::DataType visit(const luci::CircleAveragePool2D *node) final; + // loco::DataType visit(const luci::CircleBatchMatMul *node) final; + // loco::DataType visit(const luci::CircleBatchToSpaceND *node) final; + // loco::DataType visit(const luci::CircleCast *node) final; + // loco::DataType visit(const luci::CircleCeil *node) final; + // loco::DataType visit(const luci::CircleConcatenation *node) final; + // loco::DataType visit(const luci::CircleConst *node) final; + // loco::DataType visit(const luci::CircleConv2D *node) final; + // loco::DataType visit(const luci::CircleCos *node) final; + // loco::DataType visit(const luci::CircleCustom *node) final; + // loco::DataType visit(const luci::CircleDepthToSpace *node) final; + // loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final; + // loco::DataType visit(const luci::CircleDequantize *node) final; + // loco::DataType visit(const luci::CircleDiv *node) final; + // loco::DataType visit(const luci::CircleElu *node) final; + // loco::DataType visit(const luci::CircleEqual *node) final; + // loco::DataType visit(const luci::CircleExp *node) final; + // loco::DataType visit(const luci::CircleExpandDims *node) final; + // loco::DataType visit(const luci::CircleFill *node) final; + // loco::DataType visit(const luci::CircleFloor *node) final; + // loco::DataType visit(const luci::CircleFloorDiv *node) final; + // loco::DataType visit(const luci::CircleFloorMod *node) final; + // loco::DataType visit(const luci::CircleFullyConnected *node) final; + // loco::DataType visit(const luci::CircleGather *node) final; + // loco::DataType visit(const luci::CircleGatherNd *node) final; + // loco::DataType visit(const luci::CircleGreater *node) final; + // loco::DataType visit(const luci::CircleGreaterEqual *node) final; + // loco::DataType visit(const luci::CircleIf *node) final; + // loco::DataType visit(const luci::CircleL2Normalize *node) final; + // loco::DataType visit(const luci::CircleL2Pool2D *node) final; + // loco::DataType visit(const luci::CircleLeakyRelu *node) final; + // loco::DataType visit(const luci::CircleLess *node) final; + // loco::DataType visit(const luci::CircleLessEqual *node) final; + // loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final; + // loco::DataType visit(const luci::CircleLog *node) final; + // loco::DataType visit(const luci::CircleLogicalAnd *node) final; + // loco::DataType visit(const luci::CircleLogicalNot *node) final; + // loco::DataType visit(const luci::CircleLogicalOr *node) final; + // loco::DataType visit(const luci::CircleLogistic *node) final; + // loco::DataType visit(const luci::CircleLogSoftmax *node) final; + // loco::DataType visit(const luci::CircleMatrixDiag *node) final; + // loco::DataType visit(const luci::CircleMatrixSetDiag *node) final; + // loco::DataType visit(const luci::CircleMaximum *node) final; + // loco::DataType visit(const luci::CircleMaxPool2D *node) final; + // loco::DataType visit(const luci::CircleMean *node) final; + // loco::DataType visit(const luci::CircleMinimum *node) final; + // loco::DataType visit(const luci::CircleMirrorPad *node) final; + // loco::DataType visit(const luci::CircleNeg *node) final; + // loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final; + // loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final; + // loco::DataType visit(const luci::CircleNotEqual *node) final; + // loco::DataType visit(const luci::CirclePack *node) final; + // loco::DataType visit(const luci::CirclePad *node) final; + // loco::DataType visit(const luci::CirclePadV2 *node) final; + // loco::DataType visit(const luci::CirclePow *node) final; + // loco::DataType visit(const luci::CirclePRelu *node) final; + // loco::DataType visit(const luci::CircleRange *node) final; + // loco::DataType visit(const luci::CircleRank *node) final; + // loco::DataType visit(const luci::CircleMul *node) final; + // loco::DataType visit(const luci::CircleOneHot *node) final; + // loco::DataType visit(const luci::CircleReduceAny *node) final; + // loco::DataType visit(const luci::CircleReduceMax *node) final; + // loco::DataType visit(const luci::CircleReduceMin *node) final; + // loco::DataType visit(const luci::CircleReduceProd *node) final; + // loco::DataType visit(const luci::CircleRelu *node) final; + // loco::DataType visit(const luci::CircleRelu6 *node) final; + // loco::DataType visit(const luci::CircleReluN1To1 *node) final; + // loco::DataType visit(const luci::CircleReshape *node) final; + // loco::DataType visit(const luci::CircleResizeBilinear *node) final; + // loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final; + // loco::DataType visit(const luci::CircleReverseSequence *node) final; + // loco::DataType visit(const luci::CircleReverseV2 *node) final; + // loco::DataType visit(const luci::CircleRound *node) final; + // loco::DataType visit(const luci::CircleRsqrt *node) final; + // loco::DataType visit(const luci::CircleScatterNd *node) final; + // loco::DataType visit(const luci::CircleSegmentSum *node) final; + // loco::DataType visit(const luci::CircleSelect *node) final; + // loco::DataType visit(const luci::CircleSelectV2 *node) final; + // loco::DataType visit(const luci::CircleShape *node) final; + // loco::DataType visit(const luci::CircleSin *node) final; + // loco::DataType visit(const luci::CircleSlice *node) final; + // loco::DataType visit(const luci::CircleSoftmax *node) final; + // loco::DataType visit(const luci::CircleSpaceToBatchND *node) final; + // loco::DataType visit(const luci::CircleSpaceToDepth *node) final; + // loco::DataType visit(const luci::CircleSparseToDense *node) final; + // loco::DataType visit(const luci::CircleSplit *node) final; + // loco::DataType visit(const luci::CircleSplitV *node) final; + // loco::DataType visit(const luci::CircleSqrt *node) final; + // loco::DataType visit(const luci::CircleSquare *node) final; + // loco::DataType visit(const luci::CircleSquaredDifference *node) final; + // loco::DataType visit(const luci::CircleSqueeze *node) final; + // loco::DataType visit(const luci::CircleStridedSlice *node) final; + // loco::DataType visit(const luci::CircleSub *node) final; + // loco::DataType visit(const luci::CircleSum *node) final; + // loco::DataType visit(const luci::CircleTanh *node) final; + // loco::DataType visit(const luci::CircleTile *node) final; + // loco::DataType visit(const luci::CircleTopKV2 *node) final; + // loco::DataType visit(const luci::CircleTranspose *node) final; + // loco::DataType visit(const luci::CircleTransposeConv *node) final; + // loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final; + // loco::DataType visit(const luci::CircleUnique *node) final; + // loco::DataType visit(const luci::CircleUnpack *node) final; + // loco::DataType visit(const luci::CircleWhere *node) final; + // loco::DataType visit(const luci::CircleWhile *node) final; + // loco::DataType visit(const luci::CircleZerosLike *node) final; + + // Circle Only + // loco::DataType visit(const luci::CircleBCQFullyConnected *node) final; + // loco::DataType visit(const luci::CircleBCQGather *node) final; + // loco::DataType visit(const luci::CircleInstanceNorm *node) final; + + // Virtual + // loco::DataType visit(const luci::CircleInput *node) final; + // loco::DataType visit(const luci::CircleOutput *node) final; + // loco::DataType visit(const luci::CircleOutputDummy *node) final; + // loco::DataType visit(const luci::CircleOutputExclude *node) final; + // loco::DataType visit(const luci::CircleCustomOut *node) final; + // loco::DataType visit(const luci::CircleIfOut *node) final; + // loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final; + // loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final; + // loco::DataType visit(const luci::CircleSplitOut *node) final; + // loco::DataType visit(const luci::CircleSplitVOut *node) final; + // loco::DataType visit(const luci::CircleTopKV2Out *node) final; + // loco::DataType visit(const luci::CircleUniqueOut *node) final; + // loco::DataType visit(const luci::CircleUnpackOut *node) final; + // loco::DataType visit(const luci::CircleWhileOut *node) final; +}; + +} // namespace tinf + } // namespace luci #endif // __LUCI_CIRCLE_TYPE_INFERENCE_H__ diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h new file mode 100644 index 000000000..296f99355 --- /dev/null +++ b/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__ +#define __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__ + +#include <luci/IR/CircleNodes.h> + +#include <loco/IR/DataType.h> + +namespace luci +{ +namespace tinf // Namespace for Type Inference +{ + +// Helper function will be added + +} // namespace tinf +} // namespace luci + +#endif // __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__ diff --git a/compiler/luci/service/include/luci/Service/ShapeDescription.h b/compiler/luci/service/include/luci/Service/ShapeDescription.h index 949cce535..4d92be13f 100644 --- a/compiler/luci/service/include/luci/Service/ShapeDescription.h +++ b/compiler/luci/service/include/luci/Service/ShapeDescription.h @@ -20,6 +20,8 @@ #include <loco/IR/PermutingCodec.h> #include <loco/IR/NodeShape.h> +#include <luci/IR/CircleNodes.h> + #include <cstdint> #include <vector> @@ -33,6 +35,7 @@ struct ShapeDescription }; // TODO remove these when CircleDialect is fully functioal +ShapeDescription to_shape_description(const luci::CircleNode *node); ShapeDescription to_shape_description(const loco::TensorShape &shape); ShapeDescription to_shape_description(const loco::FeatureShape &shape); ShapeDescription to_shape_description(const loco::FilterShape &shape); diff --git a/compiler/luci/service/src/CircleShapeInference.cpp b/compiler/luci/service/src/CircleShapeInference.cpp index 0732849db..db8ffd8ad 100644 --- a/compiler/luci/service/src/CircleShapeInference.cpp +++ b/compiler/luci/service/src/CircleShapeInference.cpp @@ -20,7 +20,10 @@ #include <loco.h> #include <loco/Service/ShapeInference.h> +#include <luci/Log.h> + #include <cassert> +#include <iostream> namespace luci { @@ -32,3 +35,60 @@ ShapeDescription ShapeInference::get(loco::Node *node) } } // namespace luci + +namespace +{ + +std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape) +{ + os << "["; + for (uint32_t r = 0; r < tensor_shape.rank(); ++r) + { + if (r) + os << ","; + os << tensor_shape.dim(r).value(); + } + os << "]"; + return os; +} + +bool inputs_shape_ready(const luci::CircleNode *node) +{ + for (uint32_t arity = 0; arity < node->arity(); ++arity) + { + auto node_input = loco::must_cast<luci::CircleNode *>(node->arg(arity)); + if (node_input->shape_status() == luci::ShapeStatus::UNDEFINED) + return false; + } + + return true; +} + +} // namespace + +namespace luci +{ +namespace sinf +{ + +bool Rule::infer(const luci::CircleNode *circle_node, loco::TensorShape &shape) const +{ + LOGGER(l); + VERBOSE(l, 1) << "[CircleShapeInference] " << circle_node->name(); + VERBOSE(l, 1) << " before: " << circle_shape(circle_node); + + if (!inputs_shape_ready(circle_node)) + { + VERBOSE(l, 1) << " after: Some inputs are not ready for inference"; + return false; + } + + Algorithm alg; + shape = circle_node->accept(&alg); + VERBOSE(l, 1) << " after: " << shape; + + return true; +} + +} // namespace ssinf +} // namespace luci diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp new file mode 100644 index 000000000..f7eb6c3ec --- /dev/null +++ b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleShapeInferenceHelper.h" + +namespace luci +{ +namespace sinf +{ + +loco::TensorShape circle_shape(const luci::CircleNode *node) +{ + loco::TensorShape shape; + shape.rank(node->rank()); + for (uint32_t r = 0; r < node->rank(); ++r) + shape.dim(r) = loco::Dimension(node->dim(r).value()); + return shape; +} + +} // namespace sinf +} // namespace luci diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index a55f50b19..38ff619ab 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -102,7 +102,7 @@ private: }; /** - * @breif Expand shape x and y to same rank by align right and filling with 1 + * @brief Expand shape x and y to same rank by align right and filling with 1 */ void expand_rank(loco::TensorShape &x, loco::TensorShape &y) { @@ -122,7 +122,7 @@ void expand_rank(loco::TensorShape &x, loco::TensorShape &y) } /** - * @breif Returns shape of expanded dimension of input x and y having same rank + * @brief Returns shape of expanded dimension of input x and y having same rank */ loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y) { diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp b/compiler/luci/service/src/CircleShapeSignatureInference.cpp index dc7df3e39..1ccaa19d5 100644 --- a/compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeSignatureInference.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "luci/Service/CircleShapeSignatureInferenceRule.h" +#include "luci/Service/CircleShapeSignatureInference.h" #include <luci/Log.h> @@ -39,14 +39,16 @@ std::ostream &operator<<(std::ostream &os, const luci::ShapeSignature &shape_sig namespace luci { -bool CircleShapeSignatureInferenceRule::infer(const luci::CircleNode *circle_node, - ShapeSignature &shape_signature) const +namespace ssinf +{ + +bool Rule::infer(const luci::CircleNode *circle_node, ShapeSignature &shape_signature) const { LOGGER(l); // There is nothing to check before ShapeSignatureInference. - ShapeSignatureInferenceAlgorithm alg; + Algorithm alg; shape_signature = circle_node->accept(&alg); @@ -57,4 +59,6 @@ bool CircleShapeSignatureInferenceRule::infer(const luci::CircleNode *circle_nod return true; } +} // namespace ssinf + } // namespace luci diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp new file mode 100644 index 000000000..d7d1a24e8 --- /dev/null +++ b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleShapeSignatureInferenceHelper.h" + +#include <loco.h> + +#include <luci/Log.h> + +#include <oops/InternalExn.h> + +namespace luci +{ + +namespace ssinf +{ + +luci::ShapeSignature legalized_signature(const luci::ShapeSignature &signature) +{ + // If shape signature has at least one -1, it is not static. + for (uint32_t i = 0; i < signature.rank(); ++i) + if (signature.dim(i) == -1) + return signature; + + // If all dimensions are static, return empty shape signature. + return luci::ShapeSignature(); +} + +ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims) +{ + LOGGER(l); + + ShapeSignature input_signature; + ShapeSignature output_signature; + + auto circle_node = loco::must_cast<const luci::CircleNode *>(node); + if (circle_node->shape_signature().rank() > 0) + input_signature = circle_node->shape_signature(); + else + { + input_signature.rank(circle_node->rank()); + for (uint32_t i = 0; i < circle_node->rank(); ++i) + input_signature.dim(i) = circle_node->dim(i).value(); + } + + // If input rank is 0, it means that one of following case is occurred. + // - Input is scalar : result is always scalar + // - Input shape signature is not inferenced : cannot infer output shape signauture + // Therefore, when input signature rank is 0, always return empty signature. + if (input_signature.rank() == 0) + return output_signature; + + // When reduction_indices is not constant + auto reduction_indices = dynamic_cast<const luci::CircleConst *>(indices); + if (reduction_indices == nullptr) + { + if (keep_dims) + { + // If keep_dims is true, rank is not changed. + output_signature.rank(input_signature.rank()); + for (uint32_t i = 0; i < output_signature.rank(); ++i) + output_signature.dim(i) = -1; + } + else + { + // There is no way to inference for this case. + // Do nothing to return empty signature. + INFO(l) << "[CircleShapeSignatureInferenceHelper] " << circle_node->name() << std::endl; + INFO(l) << " reduced_signature : cannot infer because of non-constant node" << std::endl; + } + + return output_signature; + } + + std::vector<int32_t> reduction_values; + if (reduction_indices->dtype() == loco::DataType::S32) + { + auto reduction_size = reduction_indices->size<loco::DataType::S32>(); + for (uint32_t i = 0; i < reduction_size; ++i) + { + int32_t axis = reduction_indices->at<loco::DataType::S32>(i); + if (axis < 0) + axis += input_signature.rank(); + + if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank()))) + INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis)); + + reduction_values.push_back(axis); + } + } + else if (reduction_indices->dtype() == loco::DataType::S64) + { + auto reduction_size = reduction_indices->size<loco::DataType::S64>(); + for (uint32_t i = 0; i < reduction_size; ++i) + { + int32_t axis = static_cast<int32_t>(reduction_indices->at<loco::DataType::S64>(i)); + if (axis < 0) + axis += input_signature.rank(); + + if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank()))) + INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis)); + + reduction_values.push_back(axis); + } + } + else + { + INTERNAL_EXN("Wrong reduction axis type, Only INT32, INT64 supported."); + } + + if (keep_dims) + { + output_signature.rank(input_signature.rank()); + for (uint32_t i = 0; i < input_signature.rank(); ++i) + output_signature.dim(i) = input_signature.dim(i); + for (uint32_t i = 0; i < reduction_values.size(); ++i) + output_signature.dim(reduction_values.at(i)) = 1; + } + else + { + std::vector<bool> check_reduce(input_signature.rank(), false); + for (uint32_t i = 0; i < reduction_values.size(); ++i) + check_reduce.at(reduction_values.at(i)) = true; + + uint32_t reduce_cnt = 0; + for (uint32_t i = 0; i < check_reduce.size(); ++i) + if (check_reduce.at(i)) + ++reduce_cnt; + + output_signature.rank(input_signature.rank() - reduce_cnt); + for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i) + if (check_reduce.at(i) == false) + output_signature.dim(j++) = input_signature.dim(i); + } + + return output_signature; +} + +ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index) +{ + auto circle_input = loco::must_cast<luci::CircleNode *>(node->arg(index)); + return circle_input->shape_signature(); +} + +} // namespace ssinf + +} // namespace luci diff --git a/compiler/luci/service/src/CircleTypeInference.cpp b/compiler/luci/service/src/CircleTypeInference.cpp index aa8524a55..b4755b51a 100644 --- a/compiler/luci/service/src/CircleTypeInference.cpp +++ b/compiler/luci/service/src/CircleTypeInference.cpp @@ -16,6 +16,8 @@ #include "luci/Service/CircleTypeInference.h" +#include <luci/Log.h> + #include <loco.h> #include <loco/Service/TypeInference.h> @@ -70,3 +72,47 @@ circle::TensorType TypeInference::get(loco::Node *node) } } // namespace luci + +namespace +{ + +bool inputs_dtype_ready(const luci::CircleNode *node) +{ + for (uint32_t arity = 0; arity < node->arity(); ++arity) + { + if (node->dtype() == loco::DataType::Unknown) + return false; + } + + return true; +} + +} // namespace + +namespace luci +{ +namespace tinf +{ + +bool Rule::infer(const luci::CircleNode *circle_node, loco::DataType &dtype) const +{ + LOGGER(l); + VERBOSE(l, 1) << "[CircleTypeInference] " << circle_node->name(); + VERBOSE(l, 1) << " before: " << static_cast<int>(circle_node->dtype()); + + if (!inputs_dtype_ready(circle_node)) + { + VERBOSE(l, 1) << " after: Some inputs are not ready for inference"; + return false; + } + + Algorithm alg; + dtype = circle_node->accept(&alg); + + VERBOSE(l, 1) << " after: " << static_cast<int>(dtype); + + return true; +} + +} // namespace tinf +} // namespace luci diff --git a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp new file mode 100644 index 000000000..75cd9f7b2 --- /dev/null +++ b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Service/CircleTypeInferenceHelper.h" + +namespace luci +{ +namespace tinf +{ + +// Helper function will be added + +} // namespace tinf +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleInput.cpp b/compiler/luci/service/src/Nodes/CircleInput.cpp new file mode 100644 index 000000000..24eab7bd6 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleInput.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleInput *node) +{ + return node->shape_signature(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleMean.cpp b/compiler/luci/service/src/Nodes/CircleMean.cpp new file mode 100644 index 000000000..a78713698 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleMean.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleMean *node) +{ + return legalized_signature( + reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleOutput.cpp b/compiler/luci/service/src/Nodes/CircleOutput.cpp new file mode 100644 index 000000000..d4c8da2d8 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOutput.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutput *node) +{ + return input_arg_signature(node, 0); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp new file mode 100644 index 000000000..e0f13c439 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputDummy *) { return ShapeSignature(); } + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp new file mode 100644 index 000000000..75bbbb3c0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputExclude *) +{ + return ShapeSignature(); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp new file mode 100644 index 000000000..27da81466 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceAny *node) +{ + return legalized_signature( + reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp new file mode 100644 index 000000000..48d9cb970 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMax *node) +{ + return legalized_signature( + reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp new file mode 100644 index 000000000..9a9997118 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMin *node) +{ + return legalized_signature( + reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp new file mode 100644 index 000000000..a9d381a74 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceProd *node) +{ + return legalized_signature( + reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRelu.cpp b/compiler/luci/service/src/Nodes/CircleRelu.cpp new file mode 100644 index 000000000..a7a7f6f0a --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRelu.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu *node) +{ + return input_arg_signature(node, 0); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.cpp new file mode 100644 index 000000000..92a596d08 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleRelu6.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu6 *node) +{ + return input_arg_signature(node, 0); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp new file mode 100644 index 000000000..1e8d9971d --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleReluN1To1 *node) +{ + return input_arg_signature(node, 0); +} + +} // namespace luci diff --git a/compiler/luci/service/src/Nodes/CircleSum.cpp b/compiler/luci/service/src/Nodes/CircleSum.cpp new file mode 100644 index 000000000..9ef90e8e0 --- /dev/null +++ b/compiler/luci/service/src/Nodes/CircleSum.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <luci/Service/CircleShapeSignatureInference.h> + +namespace luci +{ + +ShapeSignature ssinf::Algorithm::visit(const luci::CircleSum *node) +{ + return legalized_signature( + reduced_signature(node->input(), node->reduction_indices(), node->keep_dims())); +} + +} // namespace luci diff --git a/compiler/luci/service/src/ShapeDescription.cpp b/compiler/luci/service/src/ShapeDescription.cpp index cbc302f70..01a638f8f 100644 --- a/compiler/luci/service/src/ShapeDescription.cpp +++ b/compiler/luci/service/src/ShapeDescription.cpp @@ -23,6 +23,19 @@ namespace luci { +ShapeDescription to_shape_description(const luci::CircleNode *circle_node) +{ + ShapeDescription res; + + res._rank_known = true; + + res._dims.resize(circle_node->rank()); + for (uint32_t i = 0; i < circle_node->rank(); ++i) + res._dims.at(i) = circle_node->dim(i).value(); + + return res; +} + ShapeDescription to_shape_description(const loco::TensorShape &shape) { ShapeDescription res; diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp index d224fd172..3f732b6fe 100644 --- a/compiler/luci/service/src/Validate.cpp +++ b/compiler/luci/service/src/Validate.cpp @@ -42,6 +42,19 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape return os; } +std::ostream &operator<<(std::ostream &os, const luci::CircleNode *circle_node) +{ + os << "["; + for (uint32_t r = 0; r < circle_node->rank(); ++r) + { + if (r) + os << ","; + os << circle_node->dim(r).value(); + } + os << "]"; + return os; +} + /** * @brief returns a node that is CircleOutput with index is out_index in nodes */ @@ -80,23 +93,28 @@ bool validate_shape_dtype(loco::Graph *g) if (dynamic_cast<luci::CircleOutputExclude *>(circle_node)) continue; - assert(loco::shape_known(circle_node)); + assert(circle_node->shape_status() != luci::ShapeStatus::UNDEFINED); // check if output node shape is same as graph output shape - auto co_tensor_shape = loco::shape_get(circle_node).as<loco::TensorShape>(); auto go_tensor_shape = graph_out->shape(); assert(go_tensor_shape); - if (!(co_tensor_shape == *go_tensor_shape)) + + bool is_shape_valid = (circle_node->rank() == go_tensor_shape->rank()); + for (uint32_t i = 0; is_shape_valid && i < circle_node->rank(); ++i) + if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value()) + is_shape_valid = false; + + if (is_shape_valid == false) { INFO(l) << "[luci] Shape for output #" << out_index << " not same " << std::endl; - INFO(l) << "[luci] " << circle_node->name() << " " << co_tensor_shape << " vs " + INFO(l) << "[luci] " << circle_node->name() << " " << circle_node << " vs " << *go_tensor_shape << std::endl; return false; } // check if data type match - assert(loco::dtype_known(circle_node)); - if (graph_out->dtype() != loco::dtype_get(circle_node)) + assert(circle_node->dtype() != loco::DataType::Unknown); + if (graph_out->dtype() != circle_node->dtype()) { INFO(l) << "[luci] Type for output #" << out_index << " not same " << std::endl; return false; @@ -106,6 +124,55 @@ bool validate_shape_dtype(loco::Graph *g) return true; } +bool validate_shape_signature(loco::Graph *g) +{ + LOGGER(l); + + for (auto node : loco::postorder_traversal(loco::output_nodes(g))) + { + auto circle_node = loco::must_cast<luci::CircleNode *>(node); + const auto shape_signature = circle_node->shape_signature(); + + if (shape_signature.rank() == 0) + continue; + + // Rank of shape and shape signature should be same + if (circle_node->rank() != shape_signature.rank()) + { + INFO(l) << "[luci] Rank of shape signature for " << circle_node->name() << " do not match" + << std::endl; + return false; + } + + bool has_unknown = false; + + // If shape siganture is not -1, dimension value should be same + for (uint32_t d = 0; d < shape_signature.rank(); ++d) + { + if (shape_signature.dim(d) != -1 && + shape_signature.dim(d) != (int32_t)(circle_node->dim(d).value())) + { + INFO(l) << "[luci] Dimension " << d << "of shape signature for " << circle_node->name() + << " do not match" << std::endl; + return false; + } + + if (shape_signature.dim(d) == -1) + has_unknown = true; + } + + // Shape signature should have at least one -1 value. + if (!has_unknown) + { + INFO(l) << "[luci] Shape signature in " << circle_node->name() + << " do not have unknown dimension" << std::endl; + return false; + } + } + + return true; +} + } // namespace namespace luci @@ -119,6 +186,9 @@ bool validate(loco::Graph *g) if (!validate_shape_dtype(g)) return false; + if (!validate_shape_signature(g)) + return false; + // TODO add more validation return true; diff --git a/compiler/luci/tester/src/ReadTester.cpp b/compiler/luci/tester/src/ReadTester.cpp index a1aead1bd..f270a232c 100644 --- a/compiler/luci/tester/src/ReadTester.cpp +++ b/compiler/luci/tester/src/ReadTester.cpp @@ -21,6 +21,9 @@ #include <luci/Pass/ShapeInferencePass.h> #include <luci/Pass/TypeInferencePass.h> +// Following passes will be removed after refactoring is finished +#include <luci/Pass/MigrateLegacyShapeDtypePass.h> + #include <iostream> #include <map> #include <string> @@ -95,6 +98,12 @@ int entry(int argc, char **argv) while (pass.run(graph) == true) ; } + { + // This pass will be removed after refactoring is finished + luci::MigrateLegacyShapeDtypePass pass; + while (pass.run(graph) == true) + ; + } if (!luci::validate(graph)) return 255; diff --git a/compiler/luci/tester/src/WriteTester.cpp b/compiler/luci/tester/src/WriteTester.cpp index aa7085c77..9a6e8de05 100644 --- a/compiler/luci/tester/src/WriteTester.cpp +++ b/compiler/luci/tester/src/WriteTester.cpp @@ -23,6 +23,9 @@ #include <luci/CircleExporter.h> #include <oops/InternalExn.h> +// Following passes will be removed after refactoring is finished +#include <luci/Pass/MigrateLegacyShapeDtypePass.h> + #include <fstream> #include <iostream> #include <map> @@ -139,6 +142,12 @@ int entry(int argc, char **argv) while (pass.run(graph) == true) ; } + { + // This pass will be removed after refactoring is finished + luci::MigrateLegacyShapeDtypePass pass; + while (pass.run(graph) == true) + ; + } if (!luci::validate(graph)) return 255; |