summaryrefslogtreecommitdiff
path: root/compiler/luci
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-12-14 14:43:43 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-12-14 14:43:43 +0900
commit62529acabbafce7730601ed01d5709d7bc0d378a (patch)
treebf6912cfa8fac4a2997292bfcb3c82055734c97e /compiler/luci
parent6ea13af5257155ff993c205cf997b870cc627f73 (diff)
downloadnnfw-62529acabbafce7730601ed01d5709d7bc0d378a.tar.gz
nnfw-62529acabbafce7730601ed01d5709d7bc0d378a.tar.bz2
nnfw-62529acabbafce7730601ed01d5709d7bc0d378a.zip
Imported Upstream version 1.12.0upstream/1.12.0
Diffstat (limited to 'compiler/luci')
-rw-r--r--compiler/luci/export/src/CircleExporterImpl.cpp7
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.cpp16
-rw-r--r--compiler/luci/export/src/CircleExporterUtils.h2
-rw-r--r--compiler/luci/export/src/CircleOperationExporter.cpp4
-rw-r--r--compiler/luci/export/src/CircleTensorExporter.cpp7
-rw-r--r--compiler/luci/export/src/Optimize.cpp2
-rw-r--r--compiler/luci/export/src/SerializedData.h2
-rw-r--r--compiler/luci/import/include/luci/Import/CircleReader.h2
-rw-r--r--compiler/luci/import/src/CircleReader.cpp16
-rw-r--r--compiler/luci/import/src/Nodes/CircleFullyConnected.cpp7
-rw-r--r--compiler/luci/lang/include/luci/IR/AttrDilation.h14
-rw-r--r--compiler/luci/lang/include/luci/IR/AttrFilter.h14
-rw-r--r--compiler/luci/lang/include/luci/IR/AttrStride.h14
-rw-r--r--compiler/luci/lang/include/luci/IR/CircleShapeSignature.h2
-rw-r--r--compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h17
-rw-r--r--compiler/luci/lang/src/AttrDilation.cpp36
-rw-r--r--compiler/luci/lang/src/AttrDilation.test.cpp36
-rw-r--r--compiler/luci/lang/src/AttrFilter.cpp36
-rw-r--r--compiler/luci/lang/src/AttrFilter.test.cpp36
-rw-r--r--compiler/luci/lang/src/AttrStride.cpp36
-rw-r--r--compiler/luci/lang/src/AttrStride.test.cpp36
-rw-r--r--compiler/luci/lang/src/CircleShapeSignature.cpp34
-rw-r--r--compiler/luci/pass/include/luci/CircleOptimizer.h8
-rw-r--r--compiler/luci/pass/include/luci/ModulePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h42
-rw-r--r--compiler/luci/pass/include/luci/Pass/FuseBCQPass.h5
-rw-r--r--compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h44
-rw-r--r--compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h5
-rw-r--r--compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h42
-rw-r--r--compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h37
-rw-r--r--compiler/luci/pass/include/luci/Pass/TypeInferencePass.h5
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp85
-rw-r--r--compiler/luci/pass/src/CircleTypeInferencePass.cpp59
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.cpp291
-rw-r--r--compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp112
-rw-r--r--compiler/luci/pass/src/ModulePhase.cpp71
-rw-r--r--compiler/luci/pass/src/ModulePhase.h67
-rw-r--r--compiler/luci/pass/src/ProgressReporter.cpp42
-rw-r--r--compiler/luci/pass/src/ProgressReporter.h26
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.cpp102
-rw-r--r--compiler/luci/pass/src/PropagateQuantParamPass.test.cpp118
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp149
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTranspose.cpp127
-rw-r--r--compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp156
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp223
-rw-r--r--compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp142
-rw-r--r--compiler/luci/pass/src/ShapeInferencePass.cpp13
-rw-r--r--compiler/luci/pass/src/ShapeSignatureInferencePass.cpp63
-rw-r--r--compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp139
-rw-r--r--compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp118
-rw-r--r--compiler/luci/pass/src/SubstitutePackToReshapePass.cpp107
-rw-r--r--compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp124
-rw-r--r--compiler/luci/pass/src/TypeInferencePass.cpp13
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeInference.h153
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h36
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h (renamed from compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h)42
-rw-r--r--compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h45
-rw-r--r--compiler/luci/service/include/luci/Service/CircleTypeInference.h153
-rw-r--r--compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h34
-rw-r--r--compiler/luci/service/include/luci/Service/ShapeDescription.h3
-rw-r--r--compiler/luci/service/src/CircleShapeInference.cpp60
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceHelper.cpp34
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.cpp4
-rw-r--r--compiler/luci/service/src/CircleShapeSignatureInference.cpp (renamed from compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp)12
-rw-r--r--compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp160
-rw-r--r--compiler/luci/service/src/CircleTypeInference.cpp46
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceHelper.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleInput.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleMean.cpp28
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutput.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutputDummy.cpp24
-rw-r--r--compiler/luci/service/src/Nodes/CircleOutputExclude.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceAny.cpp28
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceMax.cpp28
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceMin.cpp28
-rw-r--r--compiler/luci/service/src/Nodes/CircleReduceProd.cpp28
-rw-r--r--compiler/luci/service/src/Nodes/CircleRelu.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleRelu6.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleReluN1To1.cpp27
-rw-r--r--compiler/luci/service/src/Nodes/CircleSum.cpp28
-rw-r--r--compiler/luci/service/src/ShapeDescription.cpp13
-rw-r--r--compiler/luci/service/src/Validate.cpp82
-rw-r--r--compiler/luci/tester/src/ReadTester.cpp9
-rw-r--r--compiler/luci/tester/src/WriteTester.cpp9
88 files changed, 4224 insertions, 148 deletions
diff --git a/compiler/luci/export/src/CircleExporterImpl.cpp b/compiler/luci/export/src/CircleExporterImpl.cpp
index 860cebf6e..df7542797 100644
--- a/compiler/luci/export/src/CircleExporterImpl.cpp
+++ b/compiler/luci/export/src/CircleExporterImpl.cpp
@@ -16,7 +16,6 @@
#include "CircleExporterImpl.h"
#include "Optimize.h"
-#include "TypeBridge.h"
#include "CircleTensorExporter.h"
#include "CircleOperationExporter.h"
#include "CircleExporterUtils.h"
@@ -150,9 +149,6 @@ void CircleExporterImpl::exportGraph(loco::Graph *graph)
// do graph optimization
optimize(graph);
- // copy shape/dtype inference data to CircleNode
- copy_shape_dtype(graph);
-
_builder.Clear();
SerializedModelData md;
@@ -223,9 +219,6 @@ void CircleExporterImpl::exportModule(Module *module)
optimize(graph);
- // copy shape/dtype inference data to CircleNode
- copy_shape_dtype(graph);
-
SerializedGraphData gd;
// set Subgraph name
diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp
index 1fdb40e51..3715513e0 100644
--- a/compiler/luci/export/src/CircleExporterUtils.cpp
+++ b/compiler/luci/export/src/CircleExporterUtils.cpp
@@ -87,6 +87,22 @@ circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode)
}
}
+circle::FullyConnectedOptionsWeightsFormat
+to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format)
+{
+ switch (format)
+ {
+ case luci::CircleFullyConnected::WeightsFormat::DEFAULT:
+ return circle::FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8:
+ return circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ case luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32:
+ return circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32;
+ default:
+ INTERNAL_EXN_V("trying to convert unsupported luci::WeightsFormat", oops::to_uint32(format));
+ }
+}
+
circle::DimensionType to_circle_dimensiontype(luci::DimensionType type)
{
switch (type)
diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h
index 7857213b2..95310b353 100644
--- a/compiler/luci/export/src/CircleExporterUtils.h
+++ b/compiler/luci/export/src/CircleExporterUtils.h
@@ -32,6 +32,8 @@ namespace luci
circle::ActivationFunctionType to_circle_actfunc(luci::FusedActFunc func);
circle::TensorType to_circle_tensortype(loco::DataType type);
circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode);
+circle::FullyConnectedOptionsWeightsFormat
+to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format);
circle::DimensionType to_circle_dimensiontype(luci::DimensionType type);
flatbuffers::Offset<void> to_circle_sparse_index_vector(flatbuffers::FlatBufferBuilder &fb,
const SparseIndexVector &sparse_idx_vec);
diff --git a/compiler/luci/export/src/CircleOperationExporter.cpp b/compiler/luci/export/src/CircleOperationExporter.cpp
index c937109cd..4343cf3c9 100644
--- a/compiler/luci/export/src/CircleOperationExporter.cpp
+++ b/compiler/luci/export/src/CircleOperationExporter.cpp
@@ -21,7 +21,6 @@
#include <luci/IR/CircleNode.h>
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
-#include <luci/Service/CircleShapeInference.h>
#include <luci/UserSettings.h>
#include <luci/Log.h>
@@ -930,7 +929,8 @@ void OperationExporter::visit(luci::CircleFullyConnected *node)
{
export_simple(
node, circle::BuiltinOperator_FULLY_CONNECTED, circle::BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()))
+ CreateFullyConnectedOptions(_ctx.builder, to_circle_actfunc(node->fusedActivationFunction()),
+ to_circle_weightsformat(node->weights_format()))
.Union());
}
diff --git a/compiler/luci/export/src/CircleTensorExporter.cpp b/compiler/luci/export/src/CircleTensorExporter.cpp
index 1429d2810..9bdfa0079 100644
--- a/compiler/luci/export/src/CircleTensorExporter.cpp
+++ b/compiler/luci/export/src/CircleTensorExporter.cpp
@@ -111,10 +111,10 @@ void allocateCircleTensorInfo(CircleNode *node, CircleTensorContext &ctx)
CircleTensoInfo tensor_info;
tensor_info.name(tensor_name);
- tensor_info.dtype(to_circle_tensortype(luci::node_dtype(node)));
+ tensor_info.dtype(to_circle_tensortype(node->dtype()));
tensor_info.shape_signature(node->shape_signature());
if (node->shape_status() == ShapeStatus::VALID)
- tensor_info.shape(to_shape_description(luci::node_shape(node)));
+ tensor_info.shape(to_shape_description(node));
tensor_info.shape_status(node->shape_status());
tensor_info.content(dynamic_cast<luci::CircleConst *>(node));
@@ -243,6 +243,9 @@ flatbuffers::Offset<Vector<int32_t>> encodeShape(FlatBufferBuilder &builder,
flatbuffers::Offset<Vector<int32_t>> encodeShapeSignature(FlatBufferBuilder &builder,
const ShapeSignature &shape_signature)
{
+ if (shape_signature.rank() == 0)
+ return 0;
+
return builder.CreateVector(shape_signature.as_vector());
}
diff --git a/compiler/luci/export/src/Optimize.cpp b/compiler/luci/export/src/Optimize.cpp
index 6fa50b564..036a4a2f9 100644
--- a/compiler/luci/export/src/Optimize.cpp
+++ b/compiler/luci/export/src/Optimize.cpp
@@ -18,6 +18,7 @@
#include "ProgressReporter.h"
#include <luci/Pass/ShapeInferencePass.h>
+#include <luci/Pass/ShapeSignatureInferencePass.h>
#include <luci/Pass/TypeInferencePass.h>
#include <logo/Phase.h>
@@ -34,6 +35,7 @@ void optimize(loco::Graph *g)
// prepare type and shape before optimization
phase.emplace_back(std::make_unique<TypeInferencePass>());
phase.emplace_back(std::make_unique<ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<ShapeSignatureInferencePass>());
// TODO add more optimization passes (with a knob)
}
diff --git a/compiler/luci/export/src/SerializedData.h b/compiler/luci/export/src/SerializedData.h
index 46b1ac2d5..c41f50edd 100644
--- a/compiler/luci/export/src/SerializedData.h
+++ b/compiler/luci/export/src/SerializedData.h
@@ -64,7 +64,7 @@ namespace luci
{
/**
- * @breif Record the information of T/F Lite SubGraph and its mapping to loco
+ * @brief Record the information of T/F Lite SubGraph and its mapping to loco
*/
struct SubGraphContext
{
diff --git a/compiler/luci/import/include/luci/Import/CircleReader.h b/compiler/luci/import/include/luci/Import/CircleReader.h
index 8636b1d9a..8e210dd77 100644
--- a/compiler/luci/import/include/luci/Import/CircleReader.h
+++ b/compiler/luci/import/include/luci/Import/CircleReader.h
@@ -46,6 +46,8 @@ loco::DataType luci_datatype(circle::TensorType type);
FusedActFunc luci_actfunc(const circle::ActivationFunctionType type);
Padding luci_padding(const circle::Padding padding);
MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode);
+luci::CircleFullyConnected::WeightsFormat
+luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format);
std::unique_ptr<CircleQuantParam>
luci_quantparam(const circle::QuantizationParametersT *quantization);
diff --git a/compiler/luci/import/src/CircleReader.cpp b/compiler/luci/import/src/CircleReader.cpp
index 068de5239..b33c920b1 100644
--- a/compiler/luci/import/src/CircleReader.cpp
+++ b/compiler/luci/import/src/CircleReader.cpp
@@ -151,6 +151,22 @@ MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode)
return MirrorPadMode::UNDEFINED;
}
+luci::CircleFullyConnected::WeightsFormat
+luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format)
+{
+ switch (weights_format)
+ {
+ case circle::FullyConnectedOptionsWeightsFormat_DEFAULT:
+ return luci::CircleFullyConnected::WeightsFormat::DEFAULT;
+ case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ return luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8;
+ case circle::FullyConnectedOptionsWeightsFormat_SHUFFLED16x1FLOAT32:
+ return luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32;
+ default:
+ throw std::runtime_error("Invalid FullyConnectedOptionsWeightsFormat");
+ }
+}
+
DimensionType luci_dim_type(const circle::DimensionType dim_type)
{
switch (dim_type)
diff --git a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
index 65a863bde..17293ad7a 100644
--- a/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
+++ b/compiler/luci/import/src/Nodes/CircleFullyConnected.cpp
@@ -53,12 +53,7 @@ CircleNode *CircleFullyConnectedGraphBuilder::build_node(const circle::OperatorT
const auto *options = op.builtin_options.AsFullyConnectedOptions();
node->fusedActivationFunction(luci_actfunc(options->fused_activation_function));
- if (options->weights_format != circle::FullyConnectedOptionsWeightsFormat_DEFAULT)
- {
- throw oops::UserExn(
- "Unsupported weights format",
- circle::EnumNameFullyConnectedOptionsWeightsFormat(options->weights_format));
- }
+ node->weights_format(luci_weights_format(options->weights_format));
return node;
}
diff --git a/compiler/luci/lang/include/luci/IR/AttrDilation.h b/compiler/luci/lang/include/luci/IR/AttrDilation.h
index c2b28d77d..ed8232576 100644
--- a/compiler/luci/lang/include/luci/IR/AttrDilation.h
+++ b/compiler/luci/lang/include/luci/IR/AttrDilation.h
@@ -27,15 +27,17 @@ class Dilation final
public:
Dilation() : _w(1), _h(1) {}
- int32_t w() const { return _w; }
- void w(int32_t w) { _w = w; }
+ uint32_t w() const { return _w; }
+ void w(uint32_t w) { _w = w; }
+ void w(int32_t w);
- int32_t h() const { return _h; }
- void h(int32_t h) { _h = h; }
+ uint32_t h() const { return _h; }
+ void h(uint32_t h) { _h = h; }
+ void h(int32_t h);
private:
- int32_t _w;
- int32_t _h;
+ uint32_t _w;
+ uint32_t _h;
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/AttrFilter.h b/compiler/luci/lang/include/luci/IR/AttrFilter.h
index 7909fa523..af9d7519f 100644
--- a/compiler/luci/lang/include/luci/IR/AttrFilter.h
+++ b/compiler/luci/lang/include/luci/IR/AttrFilter.h
@@ -27,15 +27,17 @@ class Filter final
public:
Filter() : _w(1), _h(1) {}
- int32_t w() const { return _w; }
- void w(int32_t w) { _w = w; }
+ uint32_t w() const { return _w; }
+ void w(uint32_t w) { _w = w; }
+ void w(int32_t w);
- int32_t h() const { return _h; }
- void h(int32_t h) { _h = h; }
+ uint32_t h() const { return _h; }
+ void h(uint32_t h) { _h = h; }
+ void h(int32_t h);
private:
- int32_t _w;
- int32_t _h;
+ uint32_t _w;
+ uint32_t _h;
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/AttrStride.h b/compiler/luci/lang/include/luci/IR/AttrStride.h
index 654967d73..6be697975 100644
--- a/compiler/luci/lang/include/luci/IR/AttrStride.h
+++ b/compiler/luci/lang/include/luci/IR/AttrStride.h
@@ -27,15 +27,17 @@ class Stride final
public:
Stride() : _w(1), _h(1) {}
- int32_t w() const { return _w; }
- void w(int32_t w) { _w = w; }
+ uint32_t w() const { return _w; }
+ void w(uint32_t w) { _w = w; }
+ void w(int32_t w);
- int32_t h() const { return _h; }
- void h(int32_t h) { _h = h; }
+ uint32_t h() const { return _h; }
+ void h(uint32_t h) { _h = h; }
+ void h(int32_t h);
private:
- int32_t _w;
- int32_t _h;
+ uint32_t _w;
+ uint32_t _h;
};
} // namespace luci
diff --git a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
index 970f1b521..18a260486 100644
--- a/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
+++ b/compiler/luci/lang/include/luci/IR/CircleShapeSignature.h
@@ -46,6 +46,8 @@ private:
std::vector<int32_t> _shape_signature{};
};
+bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs);
+
} // namespace luci
#endif // __LUCI_IR_SHAPE_SIGNATURE_H__
diff --git a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
index d78f39494..952befc87 100644
--- a/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
+++ b/compiler/luci/lang/include/luci/IR/Nodes/CircleFullyConnected.h
@@ -35,6 +35,16 @@ class CircleFullyConnected final
public LuciNodeMixin<LuciNodeTrait::Bias>
{
public:
+ enum class WeightsFormat
+ {
+ UNDEFINED, // This is not defined by Circle. This was added to prevent programming error.
+
+ DEFAULT,
+ SHUFFLED4x16INT8,
+ SHUFFLED16x1FLOAT32,
+ };
+
+public:
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }
@@ -43,6 +53,13 @@ public:
loco::Node *bias(void) const override { return at(2)->node(); }
void bias(loco::Node *node) override { at(2)->node(node); }
+
+public:
+ WeightsFormat weights_format(void) const { return _weights_format; }
+ void weights_format(WeightsFormat weights_format) { _weights_format = weights_format; }
+
+private:
+ WeightsFormat _weights_format{WeightsFormat::DEFAULT};
};
} // namespace luci
diff --git a/compiler/luci/lang/src/AttrDilation.cpp b/compiler/luci/lang/src/AttrDilation.cpp
new file mode 100644
index 000000000..a9f479502
--- /dev/null
+++ b/compiler/luci/lang/src/AttrDilation.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrDilation.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+void Dilation::w(int32_t w)
+{
+ assert(w >= 0);
+ _w = static_cast<uint32_t>(w);
+}
+
+void Dilation::h(int32_t h)
+{
+ assert(h >= 0);
+ _h = static_cast<uint32_t>(h);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/AttrDilation.test.cpp b/compiler/luci/lang/src/AttrDilation.test.cpp
new file mode 100644
index 000000000..3e4658990
--- /dev/null
+++ b/compiler/luci/lang/src/AttrDilation.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrDilation.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleAttrDilationTest, set)
+{
+ auto d = luci::Dilation();
+
+ d.h(10u);
+ d.w(10u);
+
+ ASSERT_EQ(d.h(), 10u);
+ ASSERT_EQ(d.w(), 10u);
+
+ d.h(10); // int32_t
+ d.w(10);
+
+ ASSERT_EQ(d.h(), 10u);
+ ASSERT_EQ(d.w(), 10u);
+}
diff --git a/compiler/luci/lang/src/AttrFilter.cpp b/compiler/luci/lang/src/AttrFilter.cpp
new file mode 100644
index 000000000..9c571e7f5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrFilter.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrFilter.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+void Filter::w(int32_t w)
+{
+ assert(w >= 0);
+ _w = static_cast<uint32_t>(w);
+}
+
+void Filter::h(int32_t h)
+{
+ assert(h >= 0);
+ _h = static_cast<uint32_t>(h);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/AttrFilter.test.cpp b/compiler/luci/lang/src/AttrFilter.test.cpp
new file mode 100644
index 000000000..06dbcacd5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrFilter.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrFilter.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleAttrFilterTest, set)
+{
+ auto f = luci::Filter();
+
+ f.h(10u);
+ f.w(10u);
+
+ ASSERT_EQ(f.h(), 10u);
+ ASSERT_EQ(f.w(), 10u);
+
+ f.h(10); // int32_t
+ f.w(10);
+
+ ASSERT_EQ(f.h(), 10u);
+ ASSERT_EQ(f.w(), 10u);
+}
diff --git a/compiler/luci/lang/src/AttrStride.cpp b/compiler/luci/lang/src/AttrStride.cpp
new file mode 100644
index 000000000..9720d12b5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrStride.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrStride.h"
+
+#include <cassert>
+
+namespace luci
+{
+
+void Stride::w(int32_t w)
+{
+ assert(w >= 0);
+ _w = static_cast<uint32_t>(w);
+}
+
+void Stride::h(int32_t h)
+{
+ assert(h >= 0);
+ _h = static_cast<uint32_t>(h);
+}
+
+} // namespace luci
diff --git a/compiler/luci/lang/src/AttrStride.test.cpp b/compiler/luci/lang/src/AttrStride.test.cpp
new file mode 100644
index 000000000..e91365bd5
--- /dev/null
+++ b/compiler/luci/lang/src/AttrStride.test.cpp
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/AttrStride.h"
+
+#include <gtest/gtest.h>
+
+TEST(CircleAttrStrideTest, set)
+{
+ auto s = luci::Stride();
+
+ s.h(10u);
+ s.w(10u);
+
+ ASSERT_EQ(s.h(), 10u);
+ ASSERT_EQ(s.w(), 10u);
+
+ s.h(10); // int32_t
+ s.w(10);
+
+ ASSERT_EQ(s.h(), 10u);
+ ASSERT_EQ(s.w(), 10u);
+}
diff --git a/compiler/luci/lang/src/CircleShapeSignature.cpp b/compiler/luci/lang/src/CircleShapeSignature.cpp
new file mode 100644
index 000000000..970000203
--- /dev/null
+++ b/compiler/luci/lang/src/CircleShapeSignature.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/IR/CircleShapeSignature.h"
+
+namespace luci
+{
+
+bool operator==(const ShapeSignature &lhs, const ShapeSignature &rhs)
+{
+ if (lhs.rank() != rhs.rank())
+ return false;
+
+ for (uint32_t i = 0; i < lhs.rank(); ++i)
+ if (lhs.dim(i) != rhs.dim(i))
+ return false;
+
+ return true;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index db5bdb501..906760e0a 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -19,6 +19,8 @@
#include <loco.h>
+#include <luci/IR/Module.h>
+
#include <string>
#include <vector>
@@ -47,6 +49,10 @@ public:
FusePreActivationBatchNorm,
MakeBatchNormGammaPositive,
FuseActivationFunction,
+ ShuffleWeightTo16x1Float32,
+ RemoveRedundantTranspose,
+ ReplaceMulAddWithDepthwiseConv,
+ SubstitutePackToReshape,
};
enum AlgorithmParameters
@@ -77,6 +83,8 @@ public:
Options *options(void);
public:
+ void optimize(luci::Module *) const;
+
void optimize(loco::Graph *) const;
void quantize(loco::Graph *) const;
diff --git a/compiler/luci/pass/include/luci/ModulePass.h b/compiler/luci/pass/include/luci/ModulePass.h
new file mode 100644
index 000000000..1835f6e0c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/ModulePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MODULE_PASS_H__
+#define __MODULE_PASS_H__
+
+#include <loco.h>
+#include <logo/Pass.h>
+
+#include <luci/IR/Module.h>
+
+namespace luci
+{
+
+class Pass : public logo::Pass
+{
+public:
+ // Run module pass and return false if there was nothing changed
+ virtual bool run(luci::Module *) = 0;
+};
+
+} // namespace luci
+
+#endif // __MODULE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h b/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h
new file mode 100644
index 000000000..379b44ccd
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/CircleTypeInferencePass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__
+#define __LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__
+
+#include <loco.h>
+
+#include <luci/ModulePass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to infer type of circle nodes
+ */
+class CircleTypeInferencePass : public luci::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::CircleTypeInferencePass"; }
+
+public:
+ bool run(luci::Module *m);
+ bool run(loco::Graph *g);
+};
+
+} // namespace luci
+
+#endif //__LUCI_CIRCLE_TYPE_INFERENCE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h b/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h
index 4404a9fc9..912ad4225 100644
--- a/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h
+++ b/compiler/luci/pass/include/luci/Pass/FuseBCQPass.h
@@ -17,7 +17,7 @@
#ifndef __LUCI_FUSE_BCQ_PASS_H__
#define __LUCI_FUSE_BCQ_PASS_H__
-#include <logo/Pass.h>
+#include <luci/ModulePass.h>
namespace luci
{
@@ -26,10 +26,11 @@ namespace luci
* @brief Class to fuse certain pattern of subgraph into CircleBCQFullyConnected or CircleBCQGather
*
*/
-struct FuseBCQPass final : public logo::Pass
+struct FuseBCQPass final : public luci::Pass
{
const char *name(void) const final { return "luci::FuseBCQPass"; }
+ bool run(luci::Module *m) final;
bool run(loco::Graph *g) final;
};
diff --git a/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h
new file mode 100644
index 000000000..c0ebc4e5d
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/MigrateLegacyShapeDtypePass.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
+#define __LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
+
+#include <loco.h>
+
+#include <luci/ModulePass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to copy shape/dtype of loco to circle node
+ *
+ * CAUTION : This pass will be removed after refactoring is finished
+ */
+class MigrateLegacyShapeDtypePass : public luci::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::MigrateLegacyShapeDtypePass"; }
+
+public:
+ bool run(luci::Module *m);
+ bool run(loco::Graph *graph);
+};
+
+} // namespace luci
+
+#endif //__LUCI_MIGRATE_LEGACY_SHAPE_DTYPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h b/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h
new file mode 100644
index 000000000..7e0c44b8c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/PropagateQuantParamPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+#define __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to propagate quantization parameters of an operator's output to input
+ */
+struct PropagateQuantParamPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::PropagateQuantParamPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_PROPAGATE_QUANT_PARAM_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h
new file mode 100644
index 000000000..ca20da5ac
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantTransposePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__
+#define __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief fuse or remove subsequent Transpose operators
+ */
+struct RemoveRedundantTransposePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveRedundantTransposePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_REDUNDANT_TRANSPOSE_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h b/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h
new file mode 100644
index 000000000..5dbcc8f5b
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__
+#define __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to replace channel-wise mul/add with CircleDepthwiseConv2D
+ */
+struct ReplaceMulAddWithDepthwiseConvPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ReplaceMulAddWithDepthwiseConvPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REPLACE_MUL_ADD_WITH_DEPTHWISE_CONV_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h b/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
index 86bb2ab42..e21ab4cce 100644
--- a/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
+++ b/compiler/luci/pass/include/luci/Pass/ShapeInferencePass.h
@@ -19,7 +19,7 @@
#include <loco.h>
-#include <logo/Pass.h>
+#include <luci/ModulePass.h>
namespace luci
{
@@ -27,12 +27,13 @@ namespace luci
/**
* @brief Pass to infer shape of nodes
*/
-class ShapeInferencePass : public logo::Pass
+class ShapeInferencePass : public luci::Pass
{
public:
virtual const char *name(void) const { return "luci::ShapeInferencePass"; }
public:
+ bool run(luci::Module *m);
bool run(loco::Graph *graph);
};
diff --git a/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h b/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h
new file mode 100644
index 000000000..2c6ffcf4e
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ShapeSignatureInferencePass.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
+#define __LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
+
+#include <loco.h>
+
+#include <luci/ModulePass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Pass to infer shape_signature of nodes
+ */
+class ShapeSignatureInferencePass : public luci::Pass
+{
+public:
+ virtual const char *name(void) const { return "luci::ShapeSignatureInferencePass"; }
+
+public:
+ bool run(luci::Module *m);
+ bool run(loco::Graph *graph);
+};
+
+} // namespace luci
+
+#endif //__LUCI_SHAPE_SIGNATURE_INFERENCE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h b/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h
new file mode 100644
index 000000000..3d84f5133
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ShuffleWeightTo16x1Float32Pass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__
+#define __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to convert weight format of FullyConnected to SHUFFLED16x1FLOAT32
+ */
+struct ShuffleWeightTo16x1Float32Pass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ShuffleWeightTo16x1Float32Pass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SHUFFLE_WEIGHT_TO_16X1_FLOAT32_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h b/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h
new file mode 100644
index 000000000..36d13f19f
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/SubstitutePackToReshapePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__
+#define __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Substitute Pack with 1 input to single reshape node.
+ */
+struct SubstitutePackToReshapePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::SubstitutePackToReshapePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_SUBSTITUTE_PACK_TO_RESHAPE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h b/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
index c607ac63f..9d964bdd6 100644
--- a/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
+++ b/compiler/luci/pass/include/luci/Pass/TypeInferencePass.h
@@ -20,7 +20,7 @@
#include <loco.h>
-#include <logo/Pass.h>
+#include <luci/ModulePass.h>
namespace luci
{
@@ -28,12 +28,13 @@ namespace luci
/**
* @brief Pass to infer type of nodes
*/
-class TypeInferencePass : public logo::Pass
+class TypeInferencePass : public luci::Pass
{
public:
virtual const char *name(void) const { return "luci::TypeInferencePass"; }
public:
+ bool run(luci::Module *m);
bool run(loco::Graph *graph);
};
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 34f647301..cc9fe481c 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -24,6 +24,9 @@
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
+#include "luci/Pass/PropagateQuantParamPass.h"
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
@@ -31,14 +34,21 @@
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
+#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
+#include "luci/Pass/SubstitutePackToReshapePass.h"
// TODO add more passes
#include "luci/Pass/ShapeInferencePass.h"
+#include "luci/Pass/ShapeSignatureInferencePass.h"
#include "luci/Pass/TypeInferencePass.h"
+// Following passes will be removed after refactoring is finished
+#include "luci/Pass/MigrateLegacyShapeDtypePass.h"
+
// logo passes
#include <logo/RemoveDeadNodeWithQueryPass.h>
+#include "ModulePhase.h"
#include "ProgressReporter.h"
#include "CircleOptimizerUtils.h"
@@ -124,11 +134,44 @@ CircleOptimizer::Options *CircleOptimizer::options(void)
return _options.get();
}
+void CircleOptimizer::optimize(luci::Module *m) const
+{
+ luci::Phase phase;
+
+ // Following passes will be deprecated after refactoring is finished.
+ phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>());
+
+ // Following passes are needed everytime when other passes create new node or modify some nodes.
+ phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());
+ phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+
+ if (_options->query(Options::Algorithm::FuseBCQ))
+ {
+ phase.emplace_back(std::make_unique<FuseBCQPass>());
+ }
+
+ ModuleProgressReporter prog(m, logo::PhaseStrategy::Restart);
+ PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{m};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+}
+
void CircleOptimizer::optimize(loco::Graph *g) const
{
logo::Phase phase;
/* TRANSFORM DECLARATION BEGIN */
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+
+ // Following passes will be deprecated after refactoring is finished.
+ phase.emplace_back(std::make_unique<luci::MigrateLegacyShapeDtypePass>());
+
+ // Following passes are needed everytime when other passes create new node or modify some nodes.
+ phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::ShapeSignatureInferencePass>());
+
if (_options->query(Options::Algorithm::ResolveCustomOpAdd))
{
phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
@@ -145,10 +188,6 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
}
- if (_options->query(Options::Algorithm::FuseBCQ))
- {
- phase.emplace_back(std::make_unique<FuseBCQPass>());
- }
if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
{
phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
@@ -173,15 +212,27 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::MakeBatchNormGammaPositivePass>());
}
+ if (_options->query(Options::Algorithm::ShuffleWeightTo16x1Float32))
+ {
+ phase.emplace_back(std::make_unique<luci::ShuffleWeightTo16x1Float32Pass>());
+ }
+ if (_options->query(Options::Algorithm::RemoveRedundantTranspose))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantTransposePass>());
+ }
+ if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
+ {
+ phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
+ }
+ if (_options->query(Options::Algorithm::SubstitutePackToReshape))
+ {
+ phase.emplace_back(std::make_unique<luci::SubstitutePackToReshapePass>());
+ }
- // Shape inference is needed for added nodes doing above transformations
- phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
- phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
- phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
/* TRANSFORM DECLARATION END */
- ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
- logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ ProgressReporter prog(g, logo::PhaseStrategy::Restart);
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
phase_runner.attach(&prog);
phase_runner.run(phase);
}
@@ -258,6 +309,20 @@ void CircleOptimizer::quantize(loco::Graph *g) const
luci::QuantizeWithMinMaxPass quantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype),
str_to_granularity(granularity));
quantizer.run(g);
+
+ // Post-quantization optimizations
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::PropagateQuantParamPass>());
+
+ phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
+ phase.emplace_back(std::make_unique<luci::TypeInferencePass>());
+ phase.emplace_back(std::make_unique<logo::RemoveDeadNodeWithQueryPass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
}
// Requantize
diff --git a/compiler/luci/pass/src/CircleTypeInferencePass.cpp b/compiler/luci/pass/src/CircleTypeInferencePass.cpp
new file mode 100644
index 000000000..67bd253e0
--- /dev/null
+++ b/compiler/luci/pass/src/CircleTypeInferencePass.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/CircleTypeInferencePass.h"
+
+#include <luci/Service/CircleTypeInference.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool CircleTypeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool CircleTypeInferencePass::run(loco::Graph *g)
+{
+ luci::tinf::Rule type_infer_rule;
+ bool changed = false;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ loco::DataType dtype;
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+
+ if (type_infer_rule.infer(circle_node, dtype) && circle_node->dtype() != dtype)
+ {
+ circle_node->dtype(dtype);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index ebf28779b..c0583d848 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -25,6 +25,85 @@
namespace
{
+bool is_fusable_const(luci::CircleConst *before, luci::CircleConst *after, bool do_w_x)
+{
+ if (after->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ if (after->rank() != 2)
+ return false;
+
+ if (after->size<loco::DataType::FLOAT32>() != before->size<loco::DataType::FLOAT32>())
+ return false;
+
+ auto after_dim0 = after->dim(0).value();
+ auto after_dim1 = after->dim(1).value();
+
+ if (before->rank() == 2)
+ {
+ if (do_w_x)
+ {
+ // Check for [dim0, dim1] --> [dim0, dim1]
+ if (!(after->dim(0) == before->dim(0) && after->dim(1) == before->dim(1)))
+ return false;
+
+ for (uint32_t i = 0; i < after->size<loco::DataType::FLOAT32>(); ++i)
+ if (after->at<loco::DataType::FLOAT32>(i) != before->at<loco::DataType::FLOAT32>(i))
+ return false;
+ }
+ else
+ {
+ // Check for [dim0, dim1] --> [dim1, dim0]
+ if (!(after->dim(0) == before->dim(1) && after->dim(1) == before->dim(0)))
+ return false;
+
+ for (uint32_t i = 0; i < after_dim0; ++i)
+ for (uint32_t j = 0; j < after_dim1; ++j)
+ if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) !=
+ before->at<loco::DataType::FLOAT32>(j * after_dim0 + i))
+ return false;
+ }
+
+ return true;
+ }
+ else if (before->rank() == 3)
+ {
+ if (do_w_x)
+ {
+ // This case is not found yet.
+ return false;
+ }
+ else
+ {
+ // When Einsum op is converted to FullyConnected, original rank can be 3.
+ auto before_dim0 = before->dim(0).value();
+ auto before_dim1 = before->dim(1).value();
+ auto before_dim2 = before->dim(2).value();
+
+ // Check if [dim0, dim1, dim2] --> [dim2, dim0 * dim1] or
+ // [dim0, dim1, dim2] --> [dim1 * dim2, dim0]
+ if ((after_dim0 == before_dim1 * before_dim2 && after_dim1 == before_dim0) ||
+ (after_dim0 == before_dim2 && after_dim1 == before_dim0 * before_dim1))
+ {
+ for (uint32_t i = 0; i < after_dim0; ++i)
+ for (uint32_t j = 0; j < after_dim1; ++j)
+ if (after->at<loco::DataType::FLOAT32>(i * after_dim1 + j) !=
+ before->at<loco::DataType::FLOAT32>(j * after_dim0 + i))
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace
+
+namespace
+{
+
// V means the version of BCQ.
template <int32_t V> class BCQFuser;
@@ -38,11 +117,9 @@ public:
}
public:
- bool fuseBCQ(loco::Graph *g)
+ void register_bcq_info(loco::Graph *g)
{
-
- const auto output_nodes = loco::output_nodes(g);
- for (auto node : output_nodes)
+ for (auto node : loco::output_nodes(g))
{
auto output_node = loco::must_cast<luci::CircleOutput *>(node);
@@ -61,28 +138,29 @@ public:
add_BCQ_info_node(prefix, metadata_type, circle_node);
}
}
+ }
+ bool fuseBCQ(loco::Graph *g)
+ {
if (!is_bcqinfo_valid())
return false;
- for (auto f : _fusable_op)
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
{
- auto prefix = f.first;
- luci::CircleNode *node = f.second;
-
- if (!is_valid_prefix(prefix))
- continue;
-
// Fuse Gather to BCQGather
if (auto gather = dynamic_cast<luci::CircleGather *>(node))
{
if (auto params = dynamic_cast<luci::CircleConst *>(gather->params()))
{
+ auto prefix = get_prefix_of_const(params);
+ if (prefix == -1 || !is_valid_prefix(prefix))
+ continue;
+
auto bcq_gather = g->nodes()->create<luci::CircleBCQGather>();
bcq_gather->op_version(1);
- bcq_gather->input_scales(_alpha[prefix]);
- bcq_gather->input_binary(_packed_binary_code[prefix]);
+ bcq_gather->input_scales(alpha(g, prefix));
+ bcq_gather->input_binary(packed_binary_code(g, prefix));
bcq_gather->indices(gather->indices());
bcq_gather->input_clusters(packed_clusters(g, prefix));
@@ -122,29 +200,20 @@ public:
}
}
- // Einsum is unpacked to FullyConnected, Pack and Reshape
- if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
- {
- node = dynamic_cast<luci::CircleNode *>(reshape->tensor());
- }
- if (auto pack = dynamic_cast<luci::CirclePack *>(node))
- {
- if (pack->values_count() == 1 && pack->rank() == 3)
- {
- node = dynamic_cast<luci::CircleNode *>(pack->values(0));
- }
- }
-
// Fuse FullyConnected to BCQFullyConnected
if (auto fully_connected = dynamic_cast<luci::CircleFullyConnected *>(node))
{
if (auto weights = dynamic_cast<luci::CircleConst *>(fully_connected->weights()))
{
+ auto prefix = get_prefix_of_const(weights);
+ if (prefix == -1 || !is_valid_prefix(prefix))
+ continue;
+
auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
bcq_fc->op_version(1);
- bcq_fc->weights_scales(_alpha[prefix]);
- bcq_fc->weights_binary(_packed_binary_code[prefix]);
+ bcq_fc->weights_scales(alpha(g, prefix));
+ bcq_fc->weights_binary(packed_binary_code(g, prefix));
bcq_fc->bias(fully_connected->bias());
bcq_fc->weights_clusters(packed_clusters(g, prefix));
bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
@@ -179,43 +248,69 @@ public:
}
// If x_w formation, we should insert Transpose in front and back of BCQFullyConnected
- if (_do_w_x[prefix]->at<loco::DataType::BOOL>(0))
- {
- bcq_fc->weights_hidden_size(weights->dim(0).value());
- bcq_fc->input(bcq_input);
- loco::replace(fully_connected).with(bcq_fc);
- }
- else
- {
- bcq_fc->weights_hidden_size(weights->dim(1).value());
+ bcq_fc->weights_hidden_size(weights->dim(1).value());
- auto perm = g->nodes()->create<luci::CircleConst>();
- perm->dtype(loco::DataType::S32);
- perm->size<loco::DataType::S32>(2);
- perm->rank(1);
- perm->dim(0) = 2;
- perm->at<loco::DataType::S32>(0) = 1;
- perm->at<loco::DataType::S32>(1) = 0;
- perm->shape_status(luci::ShapeStatus::VALID);
+ auto perm = g->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(2);
+ perm->rank(1);
+ perm->dim(0) = 2;
+ perm->at<loco::DataType::S32>(0) = 1;
+ perm->at<loco::DataType::S32>(1) = 0;
+ perm->shape_status(luci::ShapeStatus::VALID);
- auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
- input_transpose->a(bcq_input);
- input_transpose->perm(perm);
+ auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
+ input_transpose->a(bcq_input);
+ input_transpose->perm(perm);
- bcq_fc->input(input_transpose);
+ bcq_fc->input(input_transpose);
- auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
- output_transpose->a(bcq_fc);
- output_transpose->perm(perm);
+ auto output_transpose = g->nodes()->create<luci::CircleTranspose>();
+ output_transpose->a(bcq_fc);
+ output_transpose->perm(perm);
- loco::replace(fully_connected).with(output_transpose);
- }
+ loco::replace(fully_connected).with(output_transpose);
return true;
}
- else
+ else if (auto weights_as_input =
+ dynamic_cast<luci::CircleConst *>(fully_connected->input()))
{
- // TODO Is there any case that input() is constant, instead of weights()?
+ auto prefix = get_prefix_of_const(weights_as_input);
+ if (prefix == -1 || !is_valid_prefix(prefix))
+ continue;
+
+ assert(_do_w_x[prefix]->at<loco::DataType::BOOL>(0) == true);
+
+ auto perm = g->nodes()->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->size<loco::DataType::S32>(2);
+ perm->rank(1);
+ perm->dim(0) = 2;
+ perm->at<loco::DataType::S32>(0) = 1;
+ perm->at<loco::DataType::S32>(1) = 0;
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ auto input_transpose = g->nodes()->create<luci::CircleTranspose>();
+ input_transpose->a(fully_connected->weights());
+ input_transpose->perm(perm);
+
+ auto bcq_fc = g->nodes()->create<luci::CircleBCQFullyConnected>();
+
+ assert(dynamic_cast<luci::CircleOutputExclude *>(fully_connected->bias()) != nullptr);
+
+ bcq_fc->op_version(1);
+ bcq_fc->weights_scales(alpha(g, prefix));
+ bcq_fc->weights_binary(packed_binary_code(g, prefix));
+ bcq_fc->bias(fully_connected->bias());
+ bcq_fc->weights_clusters(packed_clusters(g, prefix));
+ bcq_fc->fusedActivationFunction(fully_connected->fusedActivationFunction());
+
+ bcq_fc->weights_hidden_size(weights_as_input->dim(1).value());
+ bcq_fc->input(input_transpose);
+ loco::replace(fully_connected).with(bcq_fc);
+
+ return true;
}
}
}
@@ -268,6 +363,19 @@ private:
_dequant_weight[prefix] = const_node;
}
+ int32_t get_prefix_of_const(luci::CircleConst *w_after)
+ {
+ for (auto n : _fusable_op)
+ {
+ auto prefix = n.first;
+ auto w_before = loco::must_cast<luci::CircleConst *>(n.second);
+ if (is_fusable_const(w_before, w_after, _do_w_x[prefix]->at<loco::DataType::BOOL>(0)))
+ return prefix;
+ }
+
+ return -1;
+ }
+
bool is_bcqinfo_valid()
{
LOGGER(l);
@@ -332,6 +440,16 @@ private:
}
}
+ for (auto n : _fusable_op)
+ {
+ // fusable_op should be FLOAT32 type
+ if (n.second->dtype() != loco::DataType::FLOAT32)
+ {
+ WARN(l) << "FuseBCQPass : fusable_op has wrong type" << std::endl;
+ return false;
+ }
+ }
+
// As dequant_weight is not used for fusing, skip validation.
return true;
@@ -377,12 +495,50 @@ private:
return false;
}
+ if (_fusable_op.find(prefix) == _fusable_op.end())
+ {
+ WARN(l) << "fusable_op is not found" << std::endl;
+ return false;
+ }
+
// As dequant_weight is not used for fusing, skip validation.
return true;
}
private:
+ luci::CircleConst *alpha(loco::Graph *graph, int32_t prefix)
+ {
+ auto new_alpha = graph->nodes()->create<luci::CircleConst>();
+
+ new_alpha->dtype(loco::DataType::FLOAT32);
+ new_alpha->size<loco::DataType::FLOAT32>(_alpha[prefix]->size<loco::DataType::FLOAT32>());
+ new_alpha->rank(1);
+ new_alpha->dim(0) = _alpha[prefix]->dim(0);
+ for (uint32_t i = 0; i < _alpha[prefix]->size<loco::DataType::FLOAT32>(); ++i)
+ new_alpha->at<loco::DataType::FLOAT32>(i) = _alpha[prefix]->at<loco::DataType::FLOAT32>(i);
+ new_alpha->shape_status(luci::ShapeStatus::VALID);
+
+ return new_alpha;
+ }
+
+ luci::CircleConst *packed_binary_code(loco::Graph *graph, int32_t prefix)
+ {
+ auto new_beta = graph->nodes()->create<luci::CircleConst>();
+
+ new_beta->dtype(loco::DataType::S32);
+ new_beta->size<loco::DataType::S32>(_packed_binary_code[prefix]->size<loco::DataType::S32>());
+ new_beta->rank(2);
+ new_beta->dim(0) = _packed_binary_code[prefix]->dim(0);
+ new_beta->dim(1) = _packed_binary_code[prefix]->dim(1);
+ for (uint32_t i = 0; i < _packed_binary_code[prefix]->size<loco::DataType::S32>(); ++i)
+ new_beta->at<loco::DataType::S32>(i) =
+ _packed_binary_code[prefix]->at<loco::DataType::S32>(i);
+ new_beta->shape_status(luci::ShapeStatus::VALID);
+
+ return new_beta;
+ }
+
luci::CircleConst *packed_clusters(loco::Graph *graph, int32_t prefix)
{
auto qbits_of_clusters = _qbits_of_clusters[prefix];
@@ -428,15 +584,17 @@ private:
namespace luci
{
-bool FuseBCQPass::run(loco::Graph *g)
+bool FuseBCQPass::run(luci::Module *m)
{
bool changed = false;
const int32_t start_magicnum = -2e9 + 27;
const int32_t end_magicnum = 2e9 - 27;
+ loco::Graph *main_graph = m->graph(0);
+
luci::CircleConst *metadata_node = nullptr;
- for (auto node : loco::output_nodes(g))
+ for (auto node : loco::output_nodes(main_graph))
{
auto output_node = loco::must_cast<luci::CircleOutput *>(node);
@@ -474,8 +632,11 @@ bool FuseBCQPass::run(loco::Graph *g)
const auto bundle_cnt = metadata_node->at<loco::DataType::S32>(3);
BCQFuser<1> fuser{original_output_cnt, bundle_cnt};
- if (fuser.fuseBCQ(g))
- changed = true;
+ fuser.register_bcq_info(main_graph);
+
+ for (size_t g = 0; g < m->size(); ++g)
+ if (fuser.fuseBCQ(m->graph(g)))
+ changed = true;
}
else
{
@@ -486,12 +647,12 @@ bool FuseBCQPass::run(loco::Graph *g)
// Remove all of BCQ information nodes iff there is no change
if (changed == false)
{
- for (auto node : loco::output_nodes(g))
+ for (auto node : loco::output_nodes(main_graph))
{
auto output_node = loco::must_cast<luci::CircleOutput *>(node);
if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt)
{
- auto noOp = g->nodes()->create<luci::CircleOutputExclude>();
+ auto noOp = main_graph->nodes()->create<luci::CircleOutputExclude>();
noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting
output_node->from(noOp);
changed = true;
@@ -503,4 +664,10 @@ bool FuseBCQPass::run(loco::Graph *g)
return changed;
}
+bool FuseBCQPass::run(loco::Graph *)
+{
+ // Do nothing for graph
+ return false;
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp
new file mode 100644
index 000000000..beb962a05
--- /dev/null
+++ b/compiler/luci/pass/src/MigrateLegacyShapeDtypePass.cpp
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/MigrateLegacyShapeDtypePass.h"
+
+#include <loco/Service/ShapeInference.h>
+#include <loco/Service/TypeInference.h>
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco.h>
+
+namespace
+{
+
+bool has_same_shape(luci::CircleNode *node, loco::TensorShape shape)
+{
+ if (node->rank() != shape.rank())
+ return false;
+
+ for (uint32_t i = 0; i < shape.rank(); ++i)
+ if (!(node->dim(i) == shape.dim(i)))
+ return false;
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool MigrateLegacyShapeDtypePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool MigrateLegacyShapeDtypePass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::all_nodes(g))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (loco::shape_known(node))
+ {
+ auto loco_shape = loco::shape_get(node).as<loco::TensorShape>();
+
+ assert(circle_node->shape_signature().rank() == 0 ||
+ circle_node->shape_signature().rank() == loco_shape.rank());
+
+ // When shape of loco is copied to circle node, ShapeSignature should be applied.
+ loco::TensorShape new_shape;
+ new_shape.rank(loco_shape.rank());
+ for (uint32_t i = 0; i < loco_shape.rank(); ++i)
+ {
+ if (circle_node->shape_signature().rank() > 0 &&
+ circle_node->shape_signature().dim(i) == -1)
+ new_shape.dim(i) = 1;
+ else
+ new_shape.dim(i) = loco_shape.dim(i);
+ }
+
+ if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED ||
+ !has_same_shape(circle_node, new_shape))
+ {
+ circle_node->rank(new_shape.rank());
+ for (uint32_t i = 0; i < new_shape.rank(); ++i)
+ circle_node->dim(i) = new_shape.dim(i);
+
+ if (circle_node->shape_status() == luci::ShapeStatus::UNDEFINED)
+ circle_node->shape_status(luci::ShapeStatus::VALID);
+
+ changed = true;
+ }
+ }
+
+ if (loco::dtype_known(node))
+ {
+ if (loco::dtype_get(node) != circle_node->dtype())
+ {
+ circle_node->dtype(loco::dtype_get(node));
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ModulePhase.cpp b/compiler/luci/pass/src/ModulePhase.cpp
new file mode 100644
index 000000000..46819a0f7
--- /dev/null
+++ b/compiler/luci/pass/src/ModulePhase.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ModulePhase.h"
+
+namespace luci
+{
+
+void PhaseRunner<logo::PhaseStrategy::Saturate>::run(const Phase &phase) const
+{
+ notifyPhaseBegin();
+
+ for (bool changed = true; changed;)
+ {
+ changed = false;
+
+ for (auto &pass : phase)
+ {
+ notifyPassBegin(pass.get());
+
+ bool pass_changed = pass->run(_module);
+ changed = changed || pass_changed;
+
+ notifyPassEnd(pass.get(), pass_changed);
+ }
+ }
+
+ notifyPhaseEnd();
+}
+
+void PhaseRunner<logo::PhaseStrategy::Restart>::run(const Phase &phase) const
+{
+ notifyPhaseBegin();
+
+ for (bool changed = true; changed;)
+ {
+ changed = false;
+
+ for (auto &pass : phase)
+ {
+ notifyPassBegin(pass.get());
+
+ bool pass_changed = pass->run(_module);
+ changed = changed || pass_changed;
+
+ notifyPassEnd(pass.get(), pass_changed);
+
+ if (changed)
+ {
+ break;
+ }
+ }
+ }
+
+ notifyPhaseEnd();
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ModulePhase.h b/compiler/luci/pass/src/ModulePhase.h
new file mode 100644
index 000000000..05966cc29
--- /dev/null
+++ b/compiler/luci/pass/src/ModulePhase.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __MODULE_PHASE_H__
+#define __MODULE_PHASE_H__
+
+#include <luci/ModulePass.h>
+
+#include <logo/Phase.h>
+
+#include <vector>
+
+namespace luci
+{
+
+using Phase = std::vector<std::unique_ptr<Pass>>;
+
+template <logo::PhaseStrategy S> class PhaseRunner;
+
+template <>
+class PhaseRunner<logo::PhaseStrategy::Saturate> final : public logo::PhaseRunnerMixinObservable
+{
+public:
+ PhaseRunner(luci::Module *module) : _module{module}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void run(const Phase &) const;
+
+private:
+ luci::Module *_module;
+};
+
+template <>
+class PhaseRunner<logo::PhaseStrategy::Restart> final : public logo::PhaseRunnerMixinObservable
+{
+public:
+ PhaseRunner(luci::Module *module) : _module{module}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void run(const Phase &) const;
+
+private:
+ luci::Module *_module;
+};
+
+} // namespace luci
+
+#endif // __MODULE_PHASE_H__
diff --git a/compiler/luci/pass/src/ProgressReporter.cpp b/compiler/luci/pass/src/ProgressReporter.cpp
index dcf47aba6..515739dc7 100644
--- a/compiler/luci/pass/src/ProgressReporter.cpp
+++ b/compiler/luci/pass/src/ProgressReporter.cpp
@@ -81,4 +81,46 @@ void ProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassE
INFO(prime) << luci::fmt(graph());
}
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "==============================================================";
+ INFO(prime) << "ModulePhaseRunner<" << to_str(strategy()) << ">";
+ INFO(prime) << "Initial graphs";
+ for (size_t g = 0; g < module()->size(); ++g)
+ {
+ INFO(prime) << "graphs #" << g;
+ INFO(prime) << luci::fmt(module()->graph(g));
+ }
+}
+
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "ModulePhaseRunner<" << to_str(strategy()) << "> - done";
+}
+
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *info)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "--------------------------------------------------------------";
+ INFO(prime) << "Before " << logo::pass_name(info->pass());
+}
+
+void ModuleProgressReporter::notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *info)
+{
+ LOGGER(prime);
+
+ INFO(prime) << "After " << logo::pass_name(info->pass())
+ << " (changed: " << to_char(info->changed()) << ")";
+ for (size_t g = 0; g < module()->size(); ++g)
+ {
+ INFO(prime) << "graphs #" << g;
+ INFO(prime) << luci::fmt(module()->graph(g));
+ }
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/ProgressReporter.h b/compiler/luci/pass/src/ProgressReporter.h
index bd2ba9849..cf30da735 100644
--- a/compiler/luci/pass/src/ProgressReporter.h
+++ b/compiler/luci/pass/src/ProgressReporter.h
@@ -21,6 +21,8 @@
#include <loco.h>
+#include <luci/IR/Module.h>
+
namespace luci
{
@@ -48,6 +50,30 @@ private:
logo::PhaseStrategy _strategy;
};
+class ModuleProgressReporter : public logo::PhaseEventListener
+{
+public:
+ ModuleProgressReporter(luci::Module *module, logo::PhaseStrategy strategy)
+ : _module{module}, _strategy{strategy}
+ {
+ // DO NOTHING
+ }
+
+public:
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseBegin> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PhaseEnd> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassBegin> *) override;
+ void notify(const logo::PhaseEventInfo<logo::PhaseEvent::PassEnd> *) override;
+
+public:
+ luci::Module *module(void) const { return _module; }
+ logo::PhaseStrategy strategy(void) const { return _strategy; }
+
+private:
+ luci::Module *_module;
+ logo::PhaseStrategy _strategy;
+};
+
} // namespace luci
#endif // __LUCI_PROGRESSREPORTER_H__
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
new file mode 100644
index 000000000..af83cd83b
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQuantParamPass.cpp
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <iostream>
+
+namespace
+{
+
+bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
+{
+ assert(src->scale.size() == dst->scale.size());
+ assert(src->zerop.size() == dst->zerop.size());
+
+ // src and dst have the same qparam
+ if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
+ std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
+ src->quantized_dimension == dst->quantized_dimension)
+ return false;
+
+ dst->scale.assign(src->scale.begin(), src->scale.end());
+ dst->zerop.assign(src->zerop.begin(), src->zerop.end());
+ dst->quantized_dimension = src->quantized_dimension;
+ return true;
+}
+
+bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
+{
+ // Skip nodes that do not have quantparams
+ auto src_qparam = src->quantparam();
+ if (not src_qparam)
+ return false;
+
+ auto dst_qparam = dst->quantparam();
+ if (not dst_qparam)
+ return false;
+
+ return copy_qparam(src_qparam, dst_qparam);
+}
+
+// Visitor to propagate quantization parameters
+struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
+{
+ PropagateQuantParam() = default;
+
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleReshape *node)
+ {
+ auto input = node->tensor();
+ if (loco::succs(input).size() != 1)
+ return false;
+
+ auto input_node = loco::must_cast<luci::CircleNode *>(input);
+ return copy_qparam(node, input_node);
+ }
+
+ // TODO : Add more Ops (e.g., Transpose)
+};
+
+} // namespace
+
+namespace luci
+{
+
+bool PropagateQuantParamPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ LOGGER(l);
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
+
+ PropagateQuantParam pqp;
+ changed = circle_node->accept(&pqp);
+ if (changed)
+ break;
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
new file mode 100644
index 000000000..15adbfc01
--- /dev/null
+++ b/compiler/luci/pass/src/PropagateQuantParamPass.test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/PropagateQuantParamPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void addQuantParam(luci::CircleNode *node, const std::vector<float> &scale,
+ const std::vector<int64_t> &zp)
+{
+ assert(node->quantparam() == nullptr);
+
+ auto quantparam = std::make_unique<luci::CircleQuantParam>();
+ quantparam->scale = scale;
+ quantparam->zerop = zp;
+ node->quantparam(std::move(quantparam));
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Conv] (qparam 1)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ * AFTER
+ *
+ * [Conv] (qparam 2)
+ * |
+ * [Reshape] (qparam 2)
+ *
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ conv = g.nodes()->create<luci::CircleConv2D>();
+ reshape = g.nodes()->create<luci::CircleReshape>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ addQuantParam(conv, {0.1, 0.2, 0.3}, {0, 10, 20});
+ addQuantParam(reshape, {0.2, 0.4, 0.6}, {-10, 0, 10});
+
+ conv->input(input);
+ reshape->tensor(conv);
+ output->from(reshape);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input;
+ luci::CircleConv2D *conv;
+ luci::CircleReshape *reshape;
+ luci::CircleOutput *output;
+};
+
+} // namespace
+
+TEST(PropagateQuantParam, simple)
+{
+ SimpleGraph g;
+
+ luci::PropagateQuantParamPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.4, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.6, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(-10, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[2]);
+}
+
+TEST(PropagateQuantParam, wrong_op_NEG)
+{
+ SimpleGraph g;
+ g.output->from(g.conv);
+ g.reshape->drop();
+
+ luci::PropagateQuantParamPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ EXPECT_FLOAT_EQ(0.1, g.conv->quantparam()->scale[0]);
+ EXPECT_FLOAT_EQ(0.2, g.conv->quantparam()->scale[1]);
+ EXPECT_FLOAT_EQ(0.3, g.conv->quantparam()->scale[2]);
+ EXPECT_EQ(0, g.conv->quantparam()->zerop[0]);
+ EXPECT_EQ(10, g.conv->quantparam()->zerop[1]);
+ EXPECT_EQ(20, g.conv->quantparam()->zerop[2]);
+}
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index 0ecab008f..f6eebe3b9 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -86,6 +86,100 @@ void quant_const_values(luci::CircleConst *const_node, float scaling_factor, flo
}
}
+// Quantize const per channel
+//
+// The last dimension of const is the same as the dimension of channel
+// And the rest of the const dimensions should be 1
+// So, a 'single value' is quantized per channel
+//
+// Quantization spec (f: fp value, q: quantized value)
+//
+// uint8
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 1) [q = 0, scale = -f, zp = 1]
+//
+// int16
+// Positive f: f = f * (q - 0) [q = 1, scale = f, zp = 0]
+// Negative f: f = (-f) * (q - 0) [q = -1, scale = -f, zp = 0]
+void quant_const_per_channel(CircleConst *node, loco::DataType quant_type)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ assert(node->rank() > 0);
+
+ for (uint32_t i = 0; i < node->rank() - 1; i++)
+ {
+ // Caller should call this function when the below condition is satisfied
+ if (node->dim(i).value() != 1)
+ throw std::runtime_error("Non-channel dimension of const node must be 1");
+ }
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ assert(size == node->dim(node->rank() - 1).value());
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->quantized_dimension = node->rank() - 1;
+ std::vector<int32_t> quantized_data(size);
+
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ auto data = node->at<loco::DataType::FLOAT32>(i);
+ if (quant_type == loco::DataType::U8)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantparam->zerop.push_back(0);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantparam->zerop.push_back(1);
+ quantized_data[i] = 0;
+ }
+ }
+ else if (quant_type == loco::DataType::S16)
+ {
+ if (data >= 0)
+ {
+ quantparam->scale.push_back(data);
+ quantized_data[i] = 1;
+ }
+ else
+ {
+ quantparam->scale.push_back(-data);
+ quantized_data[i] = -1;
+ }
+ quantparam->zerop.push_back(0);
+ }
+ }
+ node->quantparam(std::move(quantparam));
+
+ switch (quant_type)
+ {
+ case loco::DataType::U8:
+ node->dtype(loco::DataType::U8);
+ node->size<loco::DataType::U8>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == 0 || quantized_data[i] == 1);
+ node->at<loco::DataType::U8>(i) = quantized_data[i];
+ }
+ break;
+ case loco::DataType::S16:
+ node->dtype(loco::DataType::S16);
+ node->size<loco::DataType::S16>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(quantized_data[i] == -1 || quantized_data[i] == 1);
+ node->at<loco::DataType::S16>(i) = quantized_data[i];
+ }
+ break;
+ default:
+ throw std::runtime_error("Unsupported data type");
+ }
+}
+
void quant_const(CircleConst *node, loco::DataType quant_type)
{
assert(node->dtype() == loco::DataType::FLOAT32);
@@ -612,10 +706,51 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
}
};
+void quant_instnorm(luci::CircleInstanceNorm *node, loco::DataType output_type,
+ QuantizationGranularity granularity)
+{
+ auto gamma = loco::must_cast<luci::CircleConst *>(node->gamma());
+ auto beta = loco::must_cast<luci::CircleConst *>(node->beta());
+ assert(gamma->dtype() == loco::DataType::FLOAT32);
+ assert(beta->dtype() == loco::DataType::FLOAT32);
+
+ if (granularity == QuantizationGranularity::LayerWise)
+ {
+ quant_const(gamma, output_type);
+ quant_const(beta, output_type);
+ }
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ quant_const_per_channel(gamma, output_type);
+ quant_const_per_channel(beta, output_type);
+ }
+ else
+ throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
+}
+
+void quant_prelu(luci::CirclePRelu *node, loco::DataType output_type,
+ QuantizationGranularity granularity)
+{
+ auto alpha = loco::must_cast<luci::CircleConst *>(node->alpha());
+ assert(alpha->dtype() == loco::DataType::FLOAT32);
+
+ if (granularity == QuantizationGranularity::LayerWise)
+ {
+ quant_const(alpha, output_type);
+ }
+ else if (granularity == QuantizationGranularity::ChannelWise)
+ {
+ quant_const_per_channel(alpha, output_type);
+ }
+ else
+ throw std::runtime_error("Quantization granularity must be either 'layer' or 'channel'");
+}
+
/**
* @brief Quantize const input tensors using min/max of const values
*/
-void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
+void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type,
+ QuantizationGranularity granularity)
{
auto opcode = node->opcode();
auto arity = node->arity();
@@ -660,20 +795,26 @@ void quantize_const_inputs(luci::CircleNode *node, loco::DataType output_type)
quant_const(const_node, output_type);
break;
+ case luci::CircleOpcode::INSTANCE_NORM:
+ quant_instnorm(loco::must_cast<luci::CircleInstanceNorm *>(node), output_type, granularity);
+ break;
+
+ case luci::CircleOpcode::PRELU:
+ quant_prelu(loco::must_cast<luci::CirclePRelu *>(node), output_type, granularity);
+ break;
+
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::ADD_N:
case luci::CircleOpcode::DIV:
case luci::CircleOpcode::EQUAL:
case luci::CircleOpcode::GREATER:
case luci::CircleOpcode::GREATER_EQUAL:
- case luci::CircleOpcode::INSTANCE_NORM:
case luci::CircleOpcode::LESS:
case luci::CircleOpcode::LESS_EQUAL:
case luci::CircleOpcode::MAXIMUM:
case luci::CircleOpcode::MINIMUM:
case luci::CircleOpcode::MUL:
case luci::CircleOpcode::NOT_EQUAL:
- case luci::CircleOpcode::PRELU:
case luci::CircleOpcode::SUB:
// Quantize all const inputs using their values
for (uint32_t i = 0; i < arity; i++)
@@ -817,7 +958,7 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
- quantize_const_inputs(circle_node, _output_dtype);
+ quantize_const_inputs(circle_node, _output_dtype, _granularity);
}
// Propagate quantization parameters of concat Op
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
new file mode 100644
index 000000000..33cb76520
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantTranspose.cpp
@@ -0,0 +1,127 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+/// @brief Return true if first_perm[second_perm[i]] == i
+bool check_perm(const luci::CircleConst *first_perm, const luci::CircleConst *second_perm)
+{
+ assert(first_perm->rank() == 1);
+ assert(second_perm->rank() == 1);
+ assert(second_perm->size<loco::DataType::S32>() == first_perm->size<loco::DataType::S32>());
+ for (int32_t i = 0; i < static_cast<int32_t>(first_perm->size<loco::DataType::S32>()); i++)
+ {
+ if (first_perm->at<loco::DataType::S32>(second_perm->at<loco::DataType::S32>(i)) != i)
+ return false;
+ }
+ return true;
+}
+
+bool remove_consecutive_transpose_function(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CircleTranspose *>(node);
+ if (target_node == nullptr)
+ return false;
+ auto pred_node = dynamic_cast<luci::CircleTranspose *>(target_node->a());
+ if (pred_node == nullptr)
+ return false;
+ if (loco::succs(pred_node).size() != 1)
+ return false;
+
+ auto pred_perm = dynamic_cast<luci::CircleConst *>(target_node->perm());
+ if (pred_perm == nullptr)
+ return false;
+
+ auto main_perm = dynamic_cast<luci::CircleConst *>(pred_node->perm());
+ if (main_perm == nullptr)
+ return false;
+
+ auto main_node = loco::must_cast<luci::CircleNode *>(pred_node->a());
+ if (check_perm(pred_perm, main_perm))
+ {
+ replace(node).with(main_node);
+ }
+ else
+ {
+ auto g = main_perm->graph();
+ auto new_const_node = g->nodes()->create<luci::CircleConst>();
+
+ new_const_node->dtype(loco::DataType::S32);
+ new_const_node->rank(1);
+ new_const_node->dim(0) = main_perm->dim(0);
+ new_const_node->size<loco::DataType::S32>(main_perm->dim(0).value());
+ new_const_node->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < main_perm->size<loco::DataType::S32>(); i++)
+ {
+ new_const_node->at<loco::DataType::S32>(i) =
+ pred_perm->at<loco::DataType::S32>(main_perm->at<loco::DataType::S32>(i));
+ }
+ pred_node->perm(new_const_node);
+ replace(node).with(pred_node);
+ }
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * (main_node) (main_perm)
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * (pred_node) (pred_perm)
+ * \ /
+ * [CircleTranspose]
+ * (target_node)
+ * |
+ *
+ * AFTER
+ * <Optional Case>
+ *
+ * | | |
+ * [CircleNode] [CircleConst] |
+ * (main_node) (new_const_node) |
+ * \ / or [CircleNode]
+ * [CircleTranspose] (main_node)
+ * (pred_node) |
+ * | |
+ *
+ */
+bool RemoveRedundantTransposePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (remove_consecutive_transpose_function(circle_node))
+ {
+ changed = true;
+ break;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp
new file mode 100644
index 000000000..db608b674
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantTranspose.test.cpp
@@ -0,0 +1,156 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveRedundantTransposePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+void setValue(luci::CircleConst *node, const std::vector<int> &v)
+{
+ node->dtype(loco::DataType::S32);
+ node->size<loco::DataType::S32>(v.size());
+ node->rank(1);
+ node->dim(0).set(v.size());
+ for (int i = 0; i < v.size(); ++i)
+ {
+ node->at<loco::DataType::S32>(i) = v[i];
+ }
+}
+
+/**
+ * Type1
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode]
+ * | Remove Both
+ *
+ * --------------------------------------------
+ *
+ * Type2
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ * AFTER
+ * | |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleTranspose]
+ * |
+ *
+ */
+void create_redundunt_transpose(loco::Graph *g, const std::vector<int32_t> &perm1,
+ const std::vector<int32_t> &perm2)
+{
+ assert(g);
+
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+
+ // Create perm1
+ auto perm1_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm1_node, perm1);
+
+ auto transpose1 = g->nodes()->create<luci::CircleTranspose>();
+ transpose1->dtype(loco::DataType::FLOAT32);
+ transpose1->a(input);
+ transpose1->perm(perm1_node);
+
+ // Create perm2
+ auto perm2_node = g->nodes()->create<luci::CircleConst>();
+ setValue(perm2_node, perm2);
+
+ auto transpose2 = g->nodes()->create<luci::CircleTranspose>();
+ transpose2->dtype(loco::DataType::FLOAT32);
+ transpose2->a(transpose1);
+ transpose2->perm(perm2_node);
+
+ // Output
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(transpose2);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+}
+
+} // namespace
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type1)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {1, 0, 2, 3}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ // No transpose node is in graph.
+ ASSERT_EQ(nullptr, transpose_node);
+}
+
+TEST(RemoveRedundantTransposePass, remove_consecutive_transpose_function_type2)
+{
+ auto graph = loco::make_graph();
+ create_redundunt_transpose(graph.get(), {0, 1, 3, 2}, {1, 0, 2, 3});
+
+ luci::RemoveRedundantTransposePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleTranspose *transpose_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto trans = dynamic_cast<luci::CircleTranspose *>(node);
+ if (not trans)
+ continue;
+ transpose_node = trans;
+ break;
+ }
+ // Just one transpose node, with updated perm constant.
+ ASSERT_NE(nullptr, transpose_node);
+ auto perm = loco::must_cast<luci::CircleConst *>(transpose_node->perm());
+ ASSERT_EQ(1, perm->at<loco::DataType::S32>(0));
+ ASSERT_EQ(0, perm->at<loco::DataType::S32>(1));
+ ASSERT_EQ(3, perm->at<loco::DataType::S32>(2));
+ ASSERT_EQ(2, perm->at<loco::DataType::S32>(3));
+}
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
new file mode 100644
index 000000000..7096c2591
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.cpp
@@ -0,0 +1,223 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+luci::CircleConst *create_weights_from_gamma(luci::CircleConst *gamma)
+{
+ assert(gamma->rank() == 1);
+ auto channel_size = gamma->dim(0).value();
+
+ // Channel-wise MUL is the same as DEPTHWISE_CONV2D with filter shape (1,1,1,channel_size)
+ auto weights = gamma->graph()->nodes()->create<luci::CircleConst>();
+ weights->dtype(loco::DataType::FLOAT32);
+ weights->rank(4);
+ weights->dim(0).set(1);
+ weights->dim(1).set(1);
+ weights->dim(2).set(1);
+ weights->dim(3).set(channel_size);
+ weights->shape_status(luci::ShapeStatus::VALID);
+ weights->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ weights->at<loco::DataType::FLOAT32>(i) = gamma->at<loco::DataType::FLOAT32>(i);
+ }
+
+ return weights;
+}
+
+luci::CircleConst *create_bias_from_beta(luci::CircleConst *beta)
+{
+ assert(beta->rank() == 1);
+ auto channel_size = beta->dim(0).value();
+
+ // Channel-wise ADD is the same as bias (shape = (channel_size)) of DEPTHWISE_CONV2D
+ auto bias = beta->graph()->nodes()->create<luci::CircleConst>();
+ bias->dtype(loco::DataType::FLOAT32);
+ bias->rank(1);
+ bias->dim(0).set(channel_size);
+ bias->size<loco::DataType::FLOAT32>(channel_size);
+ bias->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ bias->at<loco::DataType::FLOAT32>(i) = beta->at<loco::DataType::FLOAT32>(i);
+ }
+
+ return bias;
+}
+
+bool is_batchnorm_add(const luci::CircleAdd *add, luci::CircleMul *&mul, luci::CircleConst *&beta)
+{
+ auto x = loco::must_cast<luci::CircleNode *>(add->x());
+ auto y = loco::must_cast<luci::CircleNode *>(add->y());
+
+ luci::CircleMul *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x->opcode() == luci::CircleOpcode::CIRCLECONST && y->opcode() == luci::CircleOpcode::MUL)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(y);
+ constant = loco::must_cast<luci::CircleConst *>(x);
+ }
+ else if (x->opcode() == luci::CircleOpcode::MUL && y->opcode() == luci::CircleOpcode::CIRCLECONST)
+ {
+ pred = loco::must_cast<luci::CircleMul *>(x);
+ constant = loco::must_cast<luci::CircleConst *>(y);
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ // Assumption: Layout is channel-last
+ if (!(channel_dim == add->dim(add->rank() - 1)))
+ return false;
+
+ mul = pred;
+ beta = constant;
+ return true;
+}
+
+// Check if mul is batchnorm mul
+bool is_batchnorm_mul(const luci::CircleMul *mul, luci::CircleNode *&pred_node,
+ luci::CircleConst *&gamma)
+{
+ auto x = dynamic_cast<luci::CircleConst *>(mul->x());
+ auto y = dynamic_cast<luci::CircleConst *>(mul->y());
+
+ luci::CircleNode *pred = nullptr;
+ luci::CircleConst *constant = nullptr;
+
+ if (x != nullptr && y == nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->y());
+ constant = x;
+ }
+ else if (x == nullptr && y != nullptr)
+ {
+ pred = loco::must_cast<luci::CircleNode *>(mul->x());
+ constant = y;
+ }
+ else
+ {
+ return false;
+ }
+
+ if (constant->rank() != 1)
+ return false;
+
+ auto channel_dim = constant->dim(0);
+ if (!(channel_dim == mul->dim(mul->rank() - 1)))
+ return false;
+
+ pred_node = pred;
+ gamma = constant;
+ return true;
+}
+
+/**
+ * Replace channel-wise Mul/Add with DepthwiseConv2D
+ *
+ * BEFORE
+ *
+ * [Node] [gamma]
+ * | /
+ * [Mul] [beta]
+ * | /
+ * [Add]
+ *
+ * AFTER
+ *
+ * [Node] [weights] [bias]
+ * \ / /
+ * [DepthwiseConv2D]
+ */
+bool replace_mul_add_with_dwconv(luci::CircleAdd *add)
+{
+ luci::CircleNode *pred_node = nullptr;
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleConst *gamma = nullptr;
+
+ if (!is_batchnorm_add(add, mul, beta))
+ return false;
+
+ if (loco::succs(mul).size() != 1)
+ return false;
+
+ if (!is_batchnorm_mul(mul, pred_node, gamma))
+ return false;
+
+ if (pred_node->rank() != 4)
+ return false;
+
+ if (pred_node->dtype() != loco::DataType::FLOAT32 || beta->dtype() != loco::DataType::FLOAT32 ||
+ gamma->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto weights = create_weights_from_gamma(gamma);
+ auto bias = create_bias_from_beta(beta);
+
+ auto dwconv = add->graph()->nodes()->create<luci::CircleDepthwiseConv2D>();
+ dwconv->input(pred_node);
+ dwconv->filter(weights);
+ dwconv->bias(bias);
+ dwconv->padding(luci::Padding::SAME);
+ dwconv->stride()->w(1);
+ dwconv->stride()->h(1);
+ dwconv->depthMultiplier(1);
+ dwconv->dilation()->w(1);
+ dwconv->dilation()->h(1);
+ dwconv->fusedActivationFunction(add->fusedActivationFunction());
+
+ loco::replace(add).with(dwconv);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ReplaceMulAddWithDepthwiseConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<luci::CircleAdd *>(node);
+ if (not add)
+ continue;
+
+ if (replace_mul_add_with_dwconv(add))
+ {
+ changed = true;
+ break;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
new file mode 100644
index 000000000..a90182aaa
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceMulAddWithDepthwiseConvPass.test.cpp
@@ -0,0 +1,142 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [Node] [gamma]
+ * | /
+ * [Mul] [beta]
+ * | /
+ * [Add]
+ *
+ * AFTER
+ *
+ * [Node] [weights] [bias]
+ * \ / /
+ * [DepthwiseConv2D]
+ */
+class SimpleGraph
+{
+public:
+ SimpleGraph()
+ {
+ input = g.nodes()->create<luci::CircleInput>();
+ mul = g.nodes()->create<luci::CircleMul>();
+ gamma = g.nodes()->create<luci::CircleConst>();
+ add = g.nodes()->create<luci::CircleAdd>();
+ beta = g.nodes()->create<luci::CircleConst>();
+ output = g.nodes()->create<luci::CircleOutput>();
+
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+
+ input->dtype(loco::DataType::FLOAT32);
+ mul->dtype(loco::DataType::FLOAT32);
+ gamma->dtype(loco::DataType::FLOAT32);
+ add->dtype(loco::DataType::FLOAT32);
+ beta->dtype(loco::DataType::FLOAT32);
+ output->dtype(loco::DataType::FLOAT32);
+
+ uint32_t channel_size = 16;
+ input->shape({1, 4, 4, channel_size});
+ mul->shape({1, 4, 4, channel_size});
+ gamma->shape({channel_size});
+ add->shape({1, 4, 4, channel_size});
+ beta->shape({channel_size});
+ output->shape({1, 4, 4, channel_size});
+
+ gamma->size<loco::DataType::FLOAT32>(channel_size);
+ beta->size<loco::DataType::FLOAT32>(channel_size);
+ for (uint32_t i = 0; i < channel_size; i++)
+ {
+ gamma->at<loco::DataType::FLOAT32>(i) = i;
+ beta->at<loco::DataType::FLOAT32>(i) = i;
+ }
+
+ mul->x(input);
+ mul->y(gamma);
+ add->x(mul);
+ add->y(beta);
+ output->from(add);
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleMul *mul = nullptr;
+ luci::CircleConst *gamma = nullptr;
+ luci::CircleAdd *add = nullptr;
+ luci::CircleConst *beta = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST(ReplaceMulAddWithDepthwiseConv, simple)
+{
+ SimpleGraph g;
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ while (pass.run(&g.g))
+ ;
+
+ auto dwconv = dynamic_cast<luci::CircleDepthwiseConv2D *>(g.output->from());
+ EXPECT_NE(nullptr, dwconv);
+
+ uint32_t channel_size = 16;
+ auto weights = dynamic_cast<luci::CircleConst *>(dwconv->filter());
+ auto bias = dynamic_cast<luci::CircleConst *>(dwconv->bias());
+ EXPECT_NE(nullptr, weights);
+ EXPECT_EQ(4, weights->rank());
+ EXPECT_EQ(channel_size, weights->dim(3).value());
+ EXPECT_NE(nullptr, bias);
+ EXPECT_EQ(1, bias->rank());
+ EXPECT_EQ(channel_size, bias->dim(0).value());
+
+ for (int i = 0; i < channel_size; i++)
+ {
+ EXPECT_FLOAT_EQ(i, weights->at<loco::DataType::FLOAT32>(i));
+ EXPECT_FLOAT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
+ }
+}
+
+TEST(ReplaceMulAddWithDepthwiseConv, wrong_op_NEG)
+{
+ SimpleGraph g;
+ // swap mul/add (changed to add->mul)
+ g.add->x(g.input);
+ loco::replace(g.add).with(g.mul);
+ g.mul->x(g.add);
+
+ luci::ReplaceMulAddWithDepthwiseConvPass pass;
+ auto changed = pass.run(&g.g);
+
+ EXPECT_EQ(false, changed);
+}
diff --git a/compiler/luci/pass/src/ShapeInferencePass.cpp b/compiler/luci/pass/src/ShapeInferencePass.cpp
index f681b3d5f..4bd0aaed4 100644
--- a/compiler/luci/pass/src/ShapeInferencePass.cpp
+++ b/compiler/luci/pass/src/ShapeInferencePass.cpp
@@ -28,6 +28,19 @@
namespace luci
{
+bool ShapeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
bool ShapeInferencePass::run(loco::Graph *g)
{
loco::CanonicalShapeInferenceRule canonical_rule;
diff --git a/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp b/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp
new file mode 100644
index 000000000..115b77a96
--- /dev/null
+++ b/compiler/luci/pass/src/ShapeSignatureInferencePass.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ShapeSignatureInferencePass.h"
+
+#include <luci/IR/CircleShapeSignature.h>
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+#include <loco.h>
+
+namespace luci
+{
+
+bool ShapeSignatureInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
+bool ShapeSignatureInferencePass::run(loco::Graph *g)
+{
+ luci::ssinf::Rule signature_inference_rule;
+ bool changed = false;
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ luci::ShapeSignature shape_signature;
+
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (signature_inference_rule.infer(circle_node, shape_signature))
+ {
+ if (!(circle_node->shape_signature() == shape_signature))
+ {
+ circle_node->shape_signature(shape_signature);
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
new file mode 100644
index 000000000..6a58f18c5
--- /dev/null
+++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.cpp
@@ -0,0 +1,139 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <cassert>
+#include <vector>
+
+namespace
+{
+
+bool satisfy_precondition(luci::CircleFullyConnected *fc)
+{
+ // check if it's already been shuffled
+ if (fc->weights_format() != luci::CircleFullyConnected::WeightsFormat::DEFAULT)
+ return false;
+
+ // check if its data type is FLOAT32
+ if (fc->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto weights = loco::must_cast<luci::CircleConst *>(fc->weights());
+ // rank must be 2
+ if (weights->rank() != 2)
+ return false;
+
+ // check if it has sparsity parameter
+ if (weights->sparsityparam())
+ return false;
+
+ // check if the number of row of FullyConnected's weight is a multiple of 16
+ const uint32_t MULTIPLE = 16;
+ uint32_t rows = weights->dim(0).value();
+ if (rows % MULTIPLE)
+ return false;
+
+ return true;
+}
+
+// get FullyConnected op vector that has same tensor
+void get_FCs_having_same_tensor(std::vector<luci::CircleFullyConnected *> &fc_vec, loco::Graph *g,
+ luci::CircleFullyConnected *fc)
+{
+ auto the_tensor = fc->weights();
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+ if (not fc)
+ continue;
+
+ if (fc->weights() == the_tensor)
+ fc_vec.push_back(fc);
+ }
+}
+
+luci::CircleConst *shuffle_weight(luci::CircleFullyConnected *fc)
+{
+ auto the_weights = loco::must_cast<luci::CircleConst *>(fc->weights());
+
+ // create CircleConst where shuffled data will be stored
+ luci::CircleConst *new_weights = fc->graph()->nodes()->create<luci::CircleConst>();
+ new_weights->dtype(loco::DataType::FLOAT32);
+ new_weights->size<loco::DataType::FLOAT32>(the_weights->size<loco::DataType::FLOAT32>());
+ new_weights->rank(the_weights->rank());
+ new_weights->shape_status(the_weights->shape_status());
+ for (uint32_t r = 0; r < new_weights->rank(); r++)
+ {
+ new_weights->dim(r).set(the_weights->dim(r).value());
+ }
+
+ // suffle weight
+ const uint32_t MULTIPLE = 16;
+ const uint32_t rows = the_weights->dim(0).value();
+ const uint32_t cols = the_weights->dim(1).value();
+ const uint32_t r_step = rows / MULTIPLE;
+ uint32_t index = 0;
+ for (uint32_t r = 0; r < r_step; r++)
+ {
+ for (uint32_t c = 0; c < cols; c++)
+ {
+ for (uint32_t i = 0; i < MULTIPLE; i++)
+ {
+ new_weights->at<loco::DataType::FLOAT32>(index++) =
+ the_weights->at<loco::DataType::FLOAT32>((r * MULTIPLE + i) * cols + c);
+ }
+ }
+ }
+
+ return new_weights;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ShuffleWeightTo16x1Float32Pass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+ if (not fc)
+ continue;
+
+ if (not satisfy_precondition(fc))
+ continue;
+
+ std::vector<luci::CircleFullyConnected *> fc_vec;
+ get_FCs_having_same_tensor(fc_vec, g, fc);
+ auto new_weights = shuffle_weight(fc);
+
+ // replace to new weights
+ for (const auto fc : fc_vec)
+ {
+ fc->weights(new_weights);
+ fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED16x1FLOAT32);
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
new file mode 100644
index 000000000..9745e5754
--- /dev/null
+++ b/compiler/luci/pass/src/ShuffleWeightTo16x1Float32Pass.test.cpp
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+void create_fc_net(loco::Graph *g)
+{
+ assert(g);
+
+ const uint32_t ROW = 16;
+ const uint32_t COL = 2;
+ const uint32_t elements_num = ROW * COL;
+
+ // input
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+
+ // fc weights
+ auto weights = g->nodes()->create<luci::CircleConst>();
+ weights->dtype(loco::DataType::FLOAT32);
+ weights->size<loco::DataType::FLOAT32>(elements_num);
+ weights->rank(2);
+ weights->dim(0).set(ROW);
+ weights->dim(1).set(COL);
+ for (uint32_t idx = 0; idx < elements_num; idx++)
+ {
+ weights->at<loco::DataType::FLOAT32>(idx) = idx;
+ }
+
+ // fc
+ auto fc = g->nodes()->create<luci::CircleFullyConnected>();
+ fc->dtype(loco::DataType::FLOAT32);
+ fc->input(input);
+ fc->weights(weights);
+
+ // output
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(fc);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+}
+
+TEST(ShuffleWeightTo16x1Float32PassTest, SimpleTest1)
+{
+ auto graph = loco::make_graph();
+ create_fc_net(graph.get());
+
+ luci::CircleFullyConnected *fc_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ auto fc = dynamic_cast<luci::CircleFullyConnected *>(node);
+ if (not fc)
+ continue;
+
+ fc_node = fc;
+ break;
+ }
+ ASSERT_NE(fc_node, nullptr);
+ auto weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
+ // before
+ ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
+ ASSERT_EQ(1, weights->at<loco::DataType::FLOAT32>(1));
+ ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(2));
+ ASSERT_EQ(3, weights->at<loco::DataType::FLOAT32>(3));
+ ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(4));
+ ASSERT_EQ(5, weights->at<loco::DataType::FLOAT32>(5));
+ ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(6));
+ ASSERT_EQ(7, weights->at<loco::DataType::FLOAT32>(7));
+ ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(8));
+ ASSERT_EQ(9, weights->at<loco::DataType::FLOAT32>(9));
+ ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(10));
+ ASSERT_EQ(11, weights->at<loco::DataType::FLOAT32>(11));
+ ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(12));
+ ASSERT_EQ(13, weights->at<loco::DataType::FLOAT32>(13));
+ ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(14));
+ ASSERT_EQ(15, weights->at<loco::DataType::FLOAT32>(15));
+
+ luci::ShuffleWeightTo16x1Float32Pass pass;
+ while (pass.run(graph.get()))
+ ;
+
+ weights = loco::must_cast<luci::CircleConst *>(fc_node->weights());
+ // after
+ ASSERT_EQ(0, weights->at<loco::DataType::FLOAT32>(0));
+ ASSERT_EQ(2, weights->at<loco::DataType::FLOAT32>(1));
+ ASSERT_EQ(4, weights->at<loco::DataType::FLOAT32>(2));
+ ASSERT_EQ(6, weights->at<loco::DataType::FLOAT32>(3));
+ ASSERT_EQ(8, weights->at<loco::DataType::FLOAT32>(4));
+ ASSERT_EQ(10, weights->at<loco::DataType::FLOAT32>(5));
+ ASSERT_EQ(12, weights->at<loco::DataType::FLOAT32>(6));
+ ASSERT_EQ(14, weights->at<loco::DataType::FLOAT32>(7));
+ ASSERT_EQ(16, weights->at<loco::DataType::FLOAT32>(8));
+ ASSERT_EQ(18, weights->at<loco::DataType::FLOAT32>(9));
+ ASSERT_EQ(20, weights->at<loco::DataType::FLOAT32>(10));
+ ASSERT_EQ(22, weights->at<loco::DataType::FLOAT32>(11));
+ ASSERT_EQ(24, weights->at<loco::DataType::FLOAT32>(12));
+ ASSERT_EQ(26, weights->at<loco::DataType::FLOAT32>(13));
+ ASSERT_EQ(28, weights->at<loco::DataType::FLOAT32>(14));
+ ASSERT_EQ(30, weights->at<loco::DataType::FLOAT32>(15));
+}
diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
new file mode 100644
index 000000000..44e974b91
--- /dev/null
+++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.cpp
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/SubstitutePackToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool substitute_pack_to_reshape(luci::CircleNode *node)
+{
+ auto target_node = dynamic_cast<luci::CirclePack *>(node);
+ if (target_node == nullptr)
+ return false;
+ if (target_node->values_count() != 1)
+ return false;
+ auto value_node = loco::must_cast<luci::CircleNode *>(target_node->values(0));
+ if (value_node->shape_status() != luci::ShapeStatus::VALID)
+ return false;
+ int32_t axis = target_node->axis();
+ if (axis < 0)
+ axis = axis + static_cast<int32_t>(value_node->rank()) + 1;
+
+ auto graph = target_node->graph();
+ auto reshape_node = graph->nodes()->create<luci::CircleReshape>();
+ reshape_node->tensor(value_node);
+
+ auto const_node = graph->nodes()->create<luci::CircleConst>();
+ const_node->dtype(loco::DataType::S32);
+ const_node->size<loco::DataType::S32>(value_node->rank() + 1);
+ const_node->shape_status(luci::ShapeStatus::VALID);
+ const_node->rank(1);
+ const_node->dim(0).set(value_node->rank() + 1);
+ for (int32_t i = 0; i < static_cast<int32_t>(value_node->rank()) + 1; i++)
+ {
+ if (i == axis)
+ {
+ const_node->at<loco::DataType::S32>(i) = 1;
+ }
+ else if (i < axis)
+ {
+ const_node->at<loco::DataType::S32>(i) = value_node->dim(i).value();
+ }
+ else
+ {
+ const_node->at<loco::DataType::S32>(i) = value_node->dim(i - 1).value();
+ }
+ }
+ reshape_node->shape(const_node);
+ replace(target_node).with(reshape_node);
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CirclePack]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ * |
+ *
+ */
+bool SubstitutePackToReshapePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (substitute_pack_to_reshape(circle_node))
+ {
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
new file mode 100644
index 000000000..143b88896
--- /dev/null
+++ b/compiler/luci/pass/src/SubstitutePackToReshapePass.test.cpp
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/SubstitutePackToReshapePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode]
+ * |
+ * [CirclePack]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ * |
+ * [CircleNode] [CircleConst]
+ * \ /
+ * [CircleReshape]
+ * |
+ * [CircleNode]
+ * |
+ *
+ */
+void create_substitute_pack_to_reshape(loco::Graph *g, const std::initializer_list<uint32_t> shape,
+ int32_t axis)
+{
+ assert(g);
+
+ // Input Create.
+ auto input = g->nodes()->create<luci::CircleInput>();
+ auto graph_input = g->inputs()->create();
+ input->index(graph_input->index());
+ input->shape_status(luci::ShapeStatus::VALID);
+ input->rank(shape.size());
+ input->shape(shape);
+
+ // Pack Node create.
+ auto pack = g->nodes()->create<luci::CirclePack>(1);
+ pack->values(0, input);
+ pack->axis(axis);
+
+ // Output Connect.
+ auto output = g->nodes()->create<luci::CircleOutput>();
+ output->from(pack);
+ auto graph_output = g->outputs()->create();
+ output->index(graph_output->index());
+
+ return;
+}
+
+} // namespace
+
+TEST(SubstitutePackToReshapePass, simple_case)
+{
+ auto graph = loco::make_graph();
+ create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, 0);
+ luci::SubstitutePackToReshapePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleReshape *reshape_node = nullptr;
+ luci::CirclePack *pack_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ reshape_node = reshape;
+ else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
+ pack_node = pack;
+ }
+ ASSERT_NE(nullptr, reshape_node);
+ ASSERT_EQ(nullptr, pack_node);
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(1));
+ ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(2));
+ ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(3));
+ ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(4));
+}
+
+TEST(SubstitutePackToReshapePass, simple_case_neg_axis)
+{
+ auto graph = loco::make_graph();
+ create_substitute_pack_to_reshape(graph.get(), {1, 2, 3, 4}, -1);
+ luci::SubstitutePackToReshapePass pass;
+ while (pass.run(graph.get()))
+ ;
+ luci::CircleReshape *reshape_node = nullptr;
+ luci::CirclePack *pack_node = nullptr;
+ for (auto node : loco::active_nodes(loco::output_nodes(graph.get())))
+ {
+ if (auto reshape = dynamic_cast<luci::CircleReshape *>(node))
+ reshape_node = reshape;
+ else if (auto pack = dynamic_cast<luci::CirclePack *>(node))
+ pack_node = pack;
+ }
+ ASSERT_NE(nullptr, reshape_node);
+ ASSERT_EQ(nullptr, pack_node);
+ auto new_shape = loco::must_cast<luci::CircleConst *>(reshape_node->shape());
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(0));
+ ASSERT_EQ(2, new_shape->at<loco::DataType::S32>(1));
+ ASSERT_EQ(3, new_shape->at<loco::DataType::S32>(2));
+ ASSERT_EQ(4, new_shape->at<loco::DataType::S32>(3));
+ ASSERT_EQ(1, new_shape->at<loco::DataType::S32>(4));
+}
diff --git a/compiler/luci/pass/src/TypeInferencePass.cpp b/compiler/luci/pass/src/TypeInferencePass.cpp
index 2c7b3a897..63744045c 100644
--- a/compiler/luci/pass/src/TypeInferencePass.cpp
+++ b/compiler/luci/pass/src/TypeInferencePass.cpp
@@ -26,6 +26,19 @@
namespace luci
{
+bool TypeInferencePass::run(luci::Module *m)
+{
+ bool changed = false;
+
+ for (size_t g = 0; g < m->size(); ++g)
+ {
+ if (run(m->graph(g)))
+ changed = true;
+ }
+
+ return changed;
+}
+
bool TypeInferencePass::run(loco::Graph *g)
{
loco::CanonicalTypeInferenceRule canonical_rule;
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInference.h b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
index fb934c2cf..c301db5f4 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInference.h
@@ -21,6 +21,10 @@
#include <loco/IR/Nodes.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/CircleShapeInferenceHelper.h>
+
namespace luci
{
@@ -36,6 +40,155 @@ struct ShapeInference
static ShapeDescription get(loco::Node *node);
};
+namespace sinf // namespace for Shape Inference
+{
+
+struct Rule
+{
+ bool infer(const luci::CircleNode *, loco::TensorShape &) const;
+};
+
+class Algorithm final : public luci::CircleNodeVisitor<loco::TensorShape>
+{
+public:
+ // TODO Remove this when all of visit function is implemented
+ loco::TensorShape visit(const luci::CircleNode *node) final { return sinf::circle_shape(node); }
+
+ // loco::TensorShape visit(const luci::CircleAbs *node) final;
+ // loco::TensorShape visit(const luci::CircleAdd *node) final;
+ // loco::TensorShape visit(const luci::CircleAddN *node) final;
+ // loco::TensorShape visit(const luci::CircleArgMax *node) final;
+ // loco::TensorShape visit(const luci::CircleArgMin *node) final;
+ // loco::TensorShape visit(const luci::CircleAveragePool2D *node) final;
+ // loco::TensorShape visit(const luci::CircleBatchMatMul *node) final;
+ // loco::TensorShape visit(const luci::CircleBatchToSpaceND *node) final;
+ // loco::TensorShape visit(const luci::CircleCast *node) final;
+ // loco::TensorShape visit(const luci::CircleCeil *node) final;
+ // loco::TensorShape visit(const luci::CircleConcatenation *node) final;
+ // loco::TensorShape visit(const luci::CircleConst *node) final;
+ // loco::TensorShape visit(const luci::CircleConv2D *node) final;
+ // loco::TensorShape visit(const luci::CircleCos *node) final;
+ // loco::TensorShape visit(const luci::CircleCustom *node) final;
+ // loco::TensorShape visit(const luci::CircleDepthToSpace *node) final;
+ // loco::TensorShape visit(const luci::CircleDepthwiseConv2D *node) final;
+ // loco::TensorShape visit(const luci::CircleDequantize *node) final;
+ // loco::TensorShape visit(const luci::CircleDiv *node) final;
+ // loco::TensorShape visit(const luci::CircleElu *node) final;
+ // loco::TensorShape visit(const luci::CircleEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleExp *node) final;
+ // loco::TensorShape visit(const luci::CircleExpandDims *node) final;
+ // loco::TensorShape visit(const luci::CircleFill *node) final;
+ // loco::TensorShape visit(const luci::CircleFloor *node) final;
+ // loco::TensorShape visit(const luci::CircleFloorDiv *node) final;
+ // loco::TensorShape visit(const luci::CircleFloorMod *node) final;
+ // loco::TensorShape visit(const luci::CircleFullyConnected *node) final;
+ // loco::TensorShape visit(const luci::CircleGather *node) final;
+ // loco::TensorShape visit(const luci::CircleGatherNd *node) final;
+ // loco::TensorShape visit(const luci::CircleGreater *node) final;
+ // loco::TensorShape visit(const luci::CircleGreaterEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleIf *node) final;
+ // loco::TensorShape visit(const luci::CircleL2Normalize *node) final;
+ // loco::TensorShape visit(const luci::CircleL2Pool2D *node) final;
+ // loco::TensorShape visit(const luci::CircleLeakyRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleLess *node) final;
+ // loco::TensorShape visit(const luci::CircleLessEqual *node) final;
+ // loco::TensorShape visit(const luci::CircleLocalResponseNormalization *node) final;
+ // loco::TensorShape visit(const luci::CircleLog *node) final;
+ // loco::TensorShape visit(const luci::CircleLogicalAnd *node) final;
+ // loco::TensorShape visit(const luci::CircleLogicalNot *node) final;
+ // loco::TensorShape visit(const luci::CircleLogicalOr *node) final;
+ // loco::TensorShape visit(const luci::CircleLogistic *node) final;
+ // loco::TensorShape visit(const luci::CircleLogSoftmax *node) final;
+ // loco::TensorShape visit(const luci::CircleMatrixDiag *node) final;
+ // loco::TensorShape visit(const luci::CircleMatrixSetDiag *node) final;
+ // loco::TensorShape visit(const luci::CircleMaximum *node) final;
+ // loco::TensorShape visit(const luci::CircleMaxPool2D *node) final;
+ // loco::TensorShape visit(const luci::CircleMean *node) final;
+ // loco::TensorShape visit(const luci::CircleMinimum *node) final;
+ // loco::TensorShape visit(const luci::CircleMirrorPad *node) final;
+ // loco::TensorShape visit(const luci::CircleNeg *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4 *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5 *node) final;
+ // loco::TensorShape visit(const luci::CircleNotEqual *node) final;
+ // loco::TensorShape visit(const luci::CirclePack *node) final;
+ // loco::TensorShape visit(const luci::CirclePad *node) final;
+ // loco::TensorShape visit(const luci::CirclePadV2 *node) final;
+ // loco::TensorShape visit(const luci::CirclePow *node) final;
+ // loco::TensorShape visit(const luci::CirclePRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleRange *node) final;
+ // loco::TensorShape visit(const luci::CircleRank *node) final;
+ // loco::TensorShape visit(const luci::CircleMul *node) final;
+ // loco::TensorShape visit(const luci::CircleOneHot *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceAny *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceMax *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceMin *node) final;
+ // loco::TensorShape visit(const luci::CircleReduceProd *node) final;
+ // loco::TensorShape visit(const luci::CircleRelu *node) final;
+ // loco::TensorShape visit(const luci::CircleRelu6 *node) final;
+ // loco::TensorShape visit(const luci::CircleReluN1To1 *node) final;
+ // loco::TensorShape visit(const luci::CircleReshape *node) final;
+ // loco::TensorShape visit(const luci::CircleResizeBilinear *node) final;
+ // loco::TensorShape visit(const luci::CircleResizeNearestNeighbor *node) final;
+ // loco::TensorShape visit(const luci::CircleReverseSequence *node) final;
+ // loco::TensorShape visit(const luci::CircleReverseV2 *node) final;
+ // loco::TensorShape visit(const luci::CircleRound *node) final;
+ // loco::TensorShape visit(const luci::CircleRsqrt *node) final;
+ // loco::TensorShape visit(const luci::CircleScatterNd *node) final;
+ // loco::TensorShape visit(const luci::CircleSegmentSum *node) final;
+ // loco::TensorShape visit(const luci::CircleSelect *node) final;
+ // loco::TensorShape visit(const luci::CircleSelectV2 *node) final;
+ // loco::TensorShape visit(const luci::CircleShape *node) final;
+ // loco::TensorShape visit(const luci::CircleSin *node) final;
+ // loco::TensorShape visit(const luci::CircleSlice *node) final;
+ // loco::TensorShape visit(const luci::CircleSoftmax *node) final;
+ // loco::TensorShape visit(const luci::CircleSpaceToBatchND *node) final;
+ // loco::TensorShape visit(const luci::CircleSpaceToDepth *node) final;
+ // loco::TensorShape visit(const luci::CircleSparseToDense *node) final;
+ // loco::TensorShape visit(const luci::CircleSplit *node) final;
+ // loco::TensorShape visit(const luci::CircleSplitV *node) final;
+ // loco::TensorShape visit(const luci::CircleSqrt *node) final;
+ // loco::TensorShape visit(const luci::CircleSquare *node) final;
+ // loco::TensorShape visit(const luci::CircleSquaredDifference *node) final;
+ // loco::TensorShape visit(const luci::CircleSqueeze *node) final;
+ // loco::TensorShape visit(const luci::CircleStridedSlice *node) final;
+ // loco::TensorShape visit(const luci::CircleSub *node) final;
+ // loco::TensorShape visit(const luci::CircleSum *node) final;
+ // loco::TensorShape visit(const luci::CircleTanh *node) final;
+ // loco::TensorShape visit(const luci::CircleTile *node) final;
+ // loco::TensorShape visit(const luci::CircleTopKV2 *node) final;
+ // loco::TensorShape visit(const luci::CircleTranspose *node) final;
+ // loco::TensorShape visit(const luci::CircleTransposeConv *node) final;
+ // loco::TensorShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final;
+ // loco::TensorShape visit(const luci::CircleUnique *node) final;
+ // loco::TensorShape visit(const luci::CircleUnpack *node) final;
+ // loco::TensorShape visit(const luci::CircleWhere *node) final;
+ // loco::TensorShape visit(const luci::CircleWhile *node) final;
+ // loco::TensorShape visit(const luci::CircleZerosLike *node) final;
+
+ // Circle Only
+ // loco::TensorShape visit(const luci::CircleBCQFullyConnected *node) final;
+ // loco::TensorShape visit(const luci::CircleBCQGather *node) final;
+ // loco::TensorShape visit(const luci::CircleInstanceNorm *node) final;
+
+ // Virtual
+ // loco::TensorShape visit(const luci::CircleInput *node) final;
+ // loco::TensorShape visit(const luci::CircleOutput *node) final;
+ // loco::TensorShape visit(const luci::CircleOutputDummy *node) final;
+ // loco::TensorShape visit(const luci::CircleOutputExclude *node) final;
+ // loco::TensorShape visit(const luci::CircleCustomOut *node) final;
+ // loco::TensorShape visit(const luci::CircleIfOut *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
+ // loco::TensorShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
+ // loco::TensorShape visit(const luci::CircleSplitOut *node) final;
+ // loco::TensorShape visit(const luci::CircleSplitVOut *node) final;
+ // loco::TensorShape visit(const luci::CircleTopKV2Out *node) final;
+ // loco::TensorShape visit(const luci::CircleUniqueOut *node) final;
+ // loco::TensorShape visit(const luci::CircleUnpackOut *node) final;
+ // loco::TensorShape visit(const luci::CircleWhileOut *node) final;
+};
+
+} // namespace sinf
+
} // namespace luci
#endif // __LUCI_CIRCLE_SHAPE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h
new file mode 100644
index 000000000..dd6a5a454
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleShapeInferenceHelper.h
@@ -0,0 +1,36 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
+#define __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
+
+#include <loco/IR/TensorShape.h>
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleShapeSignature.h>
+
+namespace luci
+{
+namespace sinf // Namespace for Shape Inference
+{
+
+// Return shape of circle node as loco::TensorShape
+loco::TensorShape circle_shape(const luci::CircleNode *node);
+
+} // namespace sinf
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_SHAPE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h
index 4d1d83012..f7ea89bb8 100644
--- a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceRule.h
+++ b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInference.h
@@ -14,22 +14,26 @@
* limitations under the License.
*/
-#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__
-#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__
+#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
+#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/IR/CircleShapeSignature.h>
+#include <luci/Service/CircleShapeSignatureInferenceHelper.h>
namespace luci
{
-struct CircleShapeSignatureInferenceRule
+namespace ssinf // namespace for Shape Signature Inference
+{
+
+struct Rule
{
bool infer(const luci::CircleNode *, ShapeSignature &) const;
};
-class ShapeSignatureInferenceAlgorithm final : public luci::CircleNodeVisitor<ShapeSignature>
+class Algorithm final : public luci::CircleNodeVisitor<ShapeSignature>
{
public:
// TODO Remove this when visit function is implemented for all the operations.
@@ -84,7 +88,7 @@ public:
// ShapeSignature visit(const luci::CircleMatrixSetDiag *node) final;
// ShapeSignature visit(const luci::CircleMaximum *node) final;
// ShapeSignature visit(const luci::CircleMaxPool2D *node) final;
- // ShapeSignature visit(const luci::CircleMean *node) final;
+ ShapeSignature visit(const luci::CircleMean *node) final;
// ShapeSignature visit(const luci::CircleMinimum *node) final;
// ShapeSignature visit(const luci::CircleMirrorPad *node) final;
// ShapeSignature visit(const luci::CircleNeg *node) final;
@@ -100,13 +104,13 @@ public:
// ShapeSignature visit(const luci::CircleRank *node) final;
// ShapeSignature visit(const luci::CircleMul *node) final;
// ShapeSignature visit(const luci::CircleOneHot *node) final;
- // ShapeSignature visit(const luci::CircleReduceAny *node) final;
- // ShapeSignature visit(const luci::CircleReduceMax *node) final;
- // ShapeSignature visit(const luci::CircleReduceMin *node) final;
- // ShapeSignature visit(const luci::CircleReduceProd *node) final;
- // ShapeSignature visit(const luci::CircleRelu *node) final;
- // ShapeSignature visit(const luci::CircleRelu6 *node) final;
- // ShapeSignature visit(const luci::CircleReluN1To1 *node) final;
+ ShapeSignature visit(const luci::CircleReduceAny *node) final;
+ ShapeSignature visit(const luci::CircleReduceMax *node) final;
+ ShapeSignature visit(const luci::CircleReduceMin *node) final;
+ ShapeSignature visit(const luci::CircleReduceProd *node) final;
+ ShapeSignature visit(const luci::CircleRelu *node) final;
+ ShapeSignature visit(const luci::CircleRelu6 *node) final;
+ ShapeSignature visit(const luci::CircleReluN1To1 *node) final;
// ShapeSignature visit(const luci::CircleReshape *node) final;
// ShapeSignature visit(const luci::CircleResizeBilinear *node) final;
// ShapeSignature visit(const luci::CircleResizeNearestNeighbor *node) final;
@@ -133,7 +137,7 @@ public:
// ShapeSignature visit(const luci::CircleSqueeze *node) final;
// ShapeSignature visit(const luci::CircleStridedSlice *node) final;
// ShapeSignature visit(const luci::CircleSub *node) final;
- // ShapeSignature visit(const luci::CircleSum *node) final;
+ ShapeSignature visit(const luci::CircleSum *node) final;
// ShapeSignature visit(const luci::CircleTanh *node) final;
// ShapeSignature visit(const luci::CircleTile *node) final;
// ShapeSignature visit(const luci::CircleTopKV2 *node) final;
@@ -152,10 +156,10 @@ public:
// ShapeSignature visit(const luci::CircleInstanceNorm *node) final;
// Virtual
- // ShapeSignature visit(const luci::CircleInput *node) final;
- // ShapeSignature visit(const luci::CircleOutput *node) final;
- // ShapeSignature visit(const luci::CircleOutputDummy *node) final;
- // ShapeSignature visit(const luci::CircleOutputExclude *node) final;
+ ShapeSignature visit(const luci::CircleInput *node) final;
+ ShapeSignature visit(const luci::CircleOutput *node) final;
+ ShapeSignature visit(const luci::CircleOutputDummy *node) final;
+ ShapeSignature visit(const luci::CircleOutputExclude *node) final;
// ShapeSignature visit(const luci::CircleCustomOut *node) final;
// ShapeSignature visit(const luci::CircleIfOut *node) final;
// ShapeSignature visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
@@ -168,6 +172,8 @@ public:
// ShapeSignature visit(const luci::CircleWhileOut *node) final;
};
+} // namespace ssinf
+
} // namespace luci
-#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_RULE_H__
+#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h
new file mode 100644
index 000000000..fb5b3b302
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleShapeSignatureInferenceHelper.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
+#define __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleShapeSignature.h>
+
+namespace luci
+{
+
+namespace ssinf // Namespace for Shape Signature Inference
+{
+
+// Return empty signature if all of dimensions are known.
+// If at least one of dimensions is unknown, return signature without change.
+ShapeSignature legalized_signature(const luci::ShapeSignature &signature);
+
+// Return reduced input_signature with indices and keep_dims.
+// - indices : reduction index
+// - keep_dims : If true, rank is not changed. If false, rank is reduced along indices.
+ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims);
+
+// Return signature of index-th argument of node.
+ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index);
+
+} // namespace ssinf
+
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_SHAPE_SIGNATURE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInference.h b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
index ea7a3c5ed..342214887 100644
--- a/compiler/luci/service/include/luci/Service/CircleTypeInference.h
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInference.h
@@ -21,6 +21,10 @@
#include <mio/circle/schema_generated.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Service/CircleTypeInferenceHelper.h>
+
namespace luci
{
@@ -37,6 +41,155 @@ struct TypeInference
static circle::TensorType get(loco::Node *node);
};
+namespace tinf // namespace for Type Inference
+{
+
+struct Rule
+{
+ bool infer(const luci::CircleNode *, loco::DataType &) const;
+};
+
+class Algorithm final : public luci::CircleNodeVisitor<loco::DataType>
+{
+public:
+ // TODO Remove this when all of visit function is implemented
+ loco::DataType visit(const luci::CircleNode *node) final { return node->dtype(); }
+
+ // loco::DataType visit(const luci::CircleAbs *node) final;
+ // loco::DataType visit(const luci::CircleAdd *node) final;
+ // loco::DataType visit(const luci::CircleAddN *node) final;
+ // loco::DataType visit(const luci::CircleArgMax *node) final;
+ // loco::DataType visit(const luci::CircleArgMin *node) final;
+ // loco::DataType visit(const luci::CircleAveragePool2D *node) final;
+ // loco::DataType visit(const luci::CircleBatchMatMul *node) final;
+ // loco::DataType visit(const luci::CircleBatchToSpaceND *node) final;
+ // loco::DataType visit(const luci::CircleCast *node) final;
+ // loco::DataType visit(const luci::CircleCeil *node) final;
+ // loco::DataType visit(const luci::CircleConcatenation *node) final;
+ // loco::DataType visit(const luci::CircleConst *node) final;
+ // loco::DataType visit(const luci::CircleConv2D *node) final;
+ // loco::DataType visit(const luci::CircleCos *node) final;
+ // loco::DataType visit(const luci::CircleCustom *node) final;
+ // loco::DataType visit(const luci::CircleDepthToSpace *node) final;
+ // loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final;
+ // loco::DataType visit(const luci::CircleDequantize *node) final;
+ // loco::DataType visit(const luci::CircleDiv *node) final;
+ // loco::DataType visit(const luci::CircleElu *node) final;
+ // loco::DataType visit(const luci::CircleEqual *node) final;
+ // loco::DataType visit(const luci::CircleExp *node) final;
+ // loco::DataType visit(const luci::CircleExpandDims *node) final;
+ // loco::DataType visit(const luci::CircleFill *node) final;
+ // loco::DataType visit(const luci::CircleFloor *node) final;
+ // loco::DataType visit(const luci::CircleFloorDiv *node) final;
+ // loco::DataType visit(const luci::CircleFloorMod *node) final;
+ // loco::DataType visit(const luci::CircleFullyConnected *node) final;
+ // loco::DataType visit(const luci::CircleGather *node) final;
+ // loco::DataType visit(const luci::CircleGatherNd *node) final;
+ // loco::DataType visit(const luci::CircleGreater *node) final;
+ // loco::DataType visit(const luci::CircleGreaterEqual *node) final;
+ // loco::DataType visit(const luci::CircleIf *node) final;
+ // loco::DataType visit(const luci::CircleL2Normalize *node) final;
+ // loco::DataType visit(const luci::CircleL2Pool2D *node) final;
+ // loco::DataType visit(const luci::CircleLeakyRelu *node) final;
+ // loco::DataType visit(const luci::CircleLess *node) final;
+ // loco::DataType visit(const luci::CircleLessEqual *node) final;
+ // loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final;
+ // loco::DataType visit(const luci::CircleLog *node) final;
+ // loco::DataType visit(const luci::CircleLogicalAnd *node) final;
+ // loco::DataType visit(const luci::CircleLogicalNot *node) final;
+ // loco::DataType visit(const luci::CircleLogicalOr *node) final;
+ // loco::DataType visit(const luci::CircleLogistic *node) final;
+ // loco::DataType visit(const luci::CircleLogSoftmax *node) final;
+ // loco::DataType visit(const luci::CircleMatrixDiag *node) final;
+ // loco::DataType visit(const luci::CircleMatrixSetDiag *node) final;
+ // loco::DataType visit(const luci::CircleMaximum *node) final;
+ // loco::DataType visit(const luci::CircleMaxPool2D *node) final;
+ // loco::DataType visit(const luci::CircleMean *node) final;
+ // loco::DataType visit(const luci::CircleMinimum *node) final;
+ // loco::DataType visit(const luci::CircleMirrorPad *node) final;
+ // loco::DataType visit(const luci::CircleNeg *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final;
+ // loco::DataType visit(const luci::CircleNotEqual *node) final;
+ // loco::DataType visit(const luci::CirclePack *node) final;
+ // loco::DataType visit(const luci::CirclePad *node) final;
+ // loco::DataType visit(const luci::CirclePadV2 *node) final;
+ // loco::DataType visit(const luci::CirclePow *node) final;
+ // loco::DataType visit(const luci::CirclePRelu *node) final;
+ // loco::DataType visit(const luci::CircleRange *node) final;
+ // loco::DataType visit(const luci::CircleRank *node) final;
+ // loco::DataType visit(const luci::CircleMul *node) final;
+ // loco::DataType visit(const luci::CircleOneHot *node) final;
+ // loco::DataType visit(const luci::CircleReduceAny *node) final;
+ // loco::DataType visit(const luci::CircleReduceMax *node) final;
+ // loco::DataType visit(const luci::CircleReduceMin *node) final;
+ // loco::DataType visit(const luci::CircleReduceProd *node) final;
+ // loco::DataType visit(const luci::CircleRelu *node) final;
+ // loco::DataType visit(const luci::CircleRelu6 *node) final;
+ // loco::DataType visit(const luci::CircleReluN1To1 *node) final;
+ // loco::DataType visit(const luci::CircleReshape *node) final;
+ // loco::DataType visit(const luci::CircleResizeBilinear *node) final;
+ // loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final;
+ // loco::DataType visit(const luci::CircleReverseSequence *node) final;
+ // loco::DataType visit(const luci::CircleReverseV2 *node) final;
+ // loco::DataType visit(const luci::CircleRound *node) final;
+ // loco::DataType visit(const luci::CircleRsqrt *node) final;
+ // loco::DataType visit(const luci::CircleScatterNd *node) final;
+ // loco::DataType visit(const luci::CircleSegmentSum *node) final;
+ // loco::DataType visit(const luci::CircleSelect *node) final;
+ // loco::DataType visit(const luci::CircleSelectV2 *node) final;
+ // loco::DataType visit(const luci::CircleShape *node) final;
+ // loco::DataType visit(const luci::CircleSin *node) final;
+ // loco::DataType visit(const luci::CircleSlice *node) final;
+ // loco::DataType visit(const luci::CircleSoftmax *node) final;
+ // loco::DataType visit(const luci::CircleSpaceToBatchND *node) final;
+ // loco::DataType visit(const luci::CircleSpaceToDepth *node) final;
+ // loco::DataType visit(const luci::CircleSparseToDense *node) final;
+ // loco::DataType visit(const luci::CircleSplit *node) final;
+ // loco::DataType visit(const luci::CircleSplitV *node) final;
+ // loco::DataType visit(const luci::CircleSqrt *node) final;
+ // loco::DataType visit(const luci::CircleSquare *node) final;
+ // loco::DataType visit(const luci::CircleSquaredDifference *node) final;
+ // loco::DataType visit(const luci::CircleSqueeze *node) final;
+ // loco::DataType visit(const luci::CircleStridedSlice *node) final;
+ // loco::DataType visit(const luci::CircleSub *node) final;
+ // loco::DataType visit(const luci::CircleSum *node) final;
+ // loco::DataType visit(const luci::CircleTanh *node) final;
+ // loco::DataType visit(const luci::CircleTile *node) final;
+ // loco::DataType visit(const luci::CircleTopKV2 *node) final;
+ // loco::DataType visit(const luci::CircleTranspose *node) final;
+ // loco::DataType visit(const luci::CircleTransposeConv *node) final;
+ // loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final;
+ // loco::DataType visit(const luci::CircleUnique *node) final;
+ // loco::DataType visit(const luci::CircleUnpack *node) final;
+ // loco::DataType visit(const luci::CircleWhere *node) final;
+ // loco::DataType visit(const luci::CircleWhile *node) final;
+ // loco::DataType visit(const luci::CircleZerosLike *node) final;
+
+ // Circle Only
+ // loco::DataType visit(const luci::CircleBCQFullyConnected *node) final;
+ // loco::DataType visit(const luci::CircleBCQGather *node) final;
+ // loco::DataType visit(const luci::CircleInstanceNorm *node) final;
+
+ // Virtual
+ // loco::DataType visit(const luci::CircleInput *node) final;
+ // loco::DataType visit(const luci::CircleOutput *node) final;
+ // loco::DataType visit(const luci::CircleOutputDummy *node) final;
+ // loco::DataType visit(const luci::CircleOutputExclude *node) final;
+ // loco::DataType visit(const luci::CircleCustomOut *node) final;
+ // loco::DataType visit(const luci::CircleIfOut *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final;
+ // loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final;
+ // loco::DataType visit(const luci::CircleSplitOut *node) final;
+ // loco::DataType visit(const luci::CircleSplitVOut *node) final;
+ // loco::DataType visit(const luci::CircleTopKV2Out *node) final;
+ // loco::DataType visit(const luci::CircleUniqueOut *node) final;
+ // loco::DataType visit(const luci::CircleUnpackOut *node) final;
+ // loco::DataType visit(const luci::CircleWhileOut *node) final;
+};
+
+} // namespace tinf
+
} // namespace luci
#endif // __LUCI_CIRCLE_TYPE_INFERENCE_H__
diff --git a/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h b/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h
new file mode 100644
index 000000000..296f99355
--- /dev/null
+++ b/compiler/luci/service/include/luci/Service/CircleTypeInferenceHelper.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__
+#define __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__
+
+#include <luci/IR/CircleNodes.h>
+
+#include <loco/IR/DataType.h>
+
+namespace luci
+{
+namespace tinf // Namespace for Type Inference
+{
+
+// Helper function will be added
+
+} // namespace tinf
+} // namespace luci
+
+#endif // __LUCI_CIRCLE_TYPE_INFERENCE_HELPER_H__
diff --git a/compiler/luci/service/include/luci/Service/ShapeDescription.h b/compiler/luci/service/include/luci/Service/ShapeDescription.h
index 949cce535..4d92be13f 100644
--- a/compiler/luci/service/include/luci/Service/ShapeDescription.h
+++ b/compiler/luci/service/include/luci/Service/ShapeDescription.h
@@ -20,6 +20,8 @@
#include <loco/IR/PermutingCodec.h>
#include <loco/IR/NodeShape.h>
+#include <luci/IR/CircleNodes.h>
+
#include <cstdint>
#include <vector>
@@ -33,6 +35,7 @@ struct ShapeDescription
};
// TODO remove these when CircleDialect is fully functioal
+ShapeDescription to_shape_description(const luci::CircleNode *node);
ShapeDescription to_shape_description(const loco::TensorShape &shape);
ShapeDescription to_shape_description(const loco::FeatureShape &shape);
ShapeDescription to_shape_description(const loco::FilterShape &shape);
diff --git a/compiler/luci/service/src/CircleShapeInference.cpp b/compiler/luci/service/src/CircleShapeInference.cpp
index 0732849db..db8ffd8ad 100644
--- a/compiler/luci/service/src/CircleShapeInference.cpp
+++ b/compiler/luci/service/src/CircleShapeInference.cpp
@@ -20,7 +20,10 @@
#include <loco.h>
#include <loco/Service/ShapeInference.h>
+#include <luci/Log.h>
+
#include <cassert>
+#include <iostream>
namespace luci
{
@@ -32,3 +35,60 @@ ShapeDescription ShapeInference::get(loco::Node *node)
}
} // namespace luci
+
+namespace
+{
+
+std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
+{
+ os << "[";
+ for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
+ {
+ if (r)
+ os << ",";
+ os << tensor_shape.dim(r).value();
+ }
+ os << "]";
+ return os;
+}
+
+bool inputs_shape_ready(const luci::CircleNode *node)
+{
+ for (uint32_t arity = 0; arity < node->arity(); ++arity)
+ {
+ auto node_input = loco::must_cast<luci::CircleNode *>(node->arg(arity));
+ if (node_input->shape_status() == luci::ShapeStatus::UNDEFINED)
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+namespace sinf
+{
+
+bool Rule::infer(const luci::CircleNode *circle_node, loco::TensorShape &shape) const
+{
+ LOGGER(l);
+ VERBOSE(l, 1) << "[CircleShapeInference] " << circle_node->name();
+ VERBOSE(l, 1) << " before: " << circle_shape(circle_node);
+
+ if (!inputs_shape_ready(circle_node))
+ {
+ VERBOSE(l, 1) << " after: Some inputs are not ready for inference";
+ return false;
+ }
+
+ Algorithm alg;
+ shape = circle_node->accept(&alg);
+ VERBOSE(l, 1) << " after: " << shape;
+
+ return true;
+}
+
+} // namespace ssinf
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
new file mode 100644
index 000000000..f7eb6c3ec
--- /dev/null
+++ b/compiler/luci/service/src/CircleShapeInferenceHelper.cpp
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleShapeInferenceHelper.h"
+
+namespace luci
+{
+namespace sinf
+{
+
+loco::TensorShape circle_shape(const luci::CircleNode *node)
+{
+ loco::TensorShape shape;
+ shape.rank(node->rank());
+ for (uint32_t r = 0; r < node->rank(); ++r)
+ shape.dim(r) = loco::Dimension(node->dim(r).value());
+ return shape;
+}
+
+} // namespace sinf
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index a55f50b19..38ff619ab 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -102,7 +102,7 @@ private:
};
/**
- * @breif Expand shape x and y to same rank by align right and filling with 1
+ * @brief Expand shape x and y to same rank by align right and filling with 1
*/
void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
{
@@ -122,7 +122,7 @@ void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
}
/**
- * @breif Returns shape of expanded dimension of input x and y having same rank
+ * @brief Returns shape of expanded dimension of input x and y having same rank
*/
loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
{
diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp b/compiler/luci/service/src/CircleShapeSignatureInference.cpp
index dc7df3e39..1ccaa19d5 100644
--- a/compiler/luci/service/src/CircleShapeSignatureInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeSignatureInference.cpp
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "luci/Service/CircleShapeSignatureInferenceRule.h"
+#include "luci/Service/CircleShapeSignatureInference.h"
#include <luci/Log.h>
@@ -39,14 +39,16 @@ std::ostream &operator<<(std::ostream &os, const luci::ShapeSignature &shape_sig
namespace luci
{
-bool CircleShapeSignatureInferenceRule::infer(const luci::CircleNode *circle_node,
- ShapeSignature &shape_signature) const
+namespace ssinf
+{
+
+bool Rule::infer(const luci::CircleNode *circle_node, ShapeSignature &shape_signature) const
{
LOGGER(l);
// There is nothing to check before ShapeSignatureInference.
- ShapeSignatureInferenceAlgorithm alg;
+ Algorithm alg;
shape_signature = circle_node->accept(&alg);
@@ -57,4 +59,6 @@ bool CircleShapeSignatureInferenceRule::infer(const luci::CircleNode *circle_nod
return true;
}
+} // namespace ssinf
+
} // namespace luci
diff --git a/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp
new file mode 100644
index 000000000..d7d1a24e8
--- /dev/null
+++ b/compiler/luci/service/src/CircleShapeSignatureInferenceHelper.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleShapeSignatureInferenceHelper.h"
+
+#include <loco.h>
+
+#include <luci/Log.h>
+
+#include <oops/InternalExn.h>
+
+namespace luci
+{
+
+namespace ssinf
+{
+
+luci::ShapeSignature legalized_signature(const luci::ShapeSignature &signature)
+{
+ // If shape signature has at least one -1, it is not static.
+ for (uint32_t i = 0; i < signature.rank(); ++i)
+ if (signature.dim(i) == -1)
+ return signature;
+
+ // If all dimensions are static, return empty shape signature.
+ return luci::ShapeSignature();
+}
+
+ShapeSignature reduced_signature(const loco::Node *node, const loco::Node *indices, bool keep_dims)
+{
+ LOGGER(l);
+
+ ShapeSignature input_signature;
+ ShapeSignature output_signature;
+
+ auto circle_node = loco::must_cast<const luci::CircleNode *>(node);
+ if (circle_node->shape_signature().rank() > 0)
+ input_signature = circle_node->shape_signature();
+ else
+ {
+ input_signature.rank(circle_node->rank());
+ for (uint32_t i = 0; i < circle_node->rank(); ++i)
+ input_signature.dim(i) = circle_node->dim(i).value();
+ }
+
+ // If input rank is 0, it means that one of following case is occurred.
+ // - Input is scalar : result is always scalar
+ // - Input shape signature is not inferenced : cannot infer output shape signauture
+ // Therefore, when input signature rank is 0, always return empty signature.
+ if (input_signature.rank() == 0)
+ return output_signature;
+
+ // When reduction_indices is not constant
+ auto reduction_indices = dynamic_cast<const luci::CircleConst *>(indices);
+ if (reduction_indices == nullptr)
+ {
+ if (keep_dims)
+ {
+ // If keep_dims is true, rank is not changed.
+ output_signature.rank(input_signature.rank());
+ for (uint32_t i = 0; i < output_signature.rank(); ++i)
+ output_signature.dim(i) = -1;
+ }
+ else
+ {
+ // There is no way to inference for this case.
+ // Do nothing to return empty signature.
+ INFO(l) << "[CircleShapeSignatureInferenceHelper] " << circle_node->name() << std::endl;
+ INFO(l) << " reduced_signature : cannot infer because of non-constant node" << std::endl;
+ }
+
+ return output_signature;
+ }
+
+ std::vector<int32_t> reduction_values;
+ if (reduction_indices->dtype() == loco::DataType::S32)
+ {
+ auto reduction_size = reduction_indices->size<loco::DataType::S32>();
+ for (uint32_t i = 0; i < reduction_size; ++i)
+ {
+ int32_t axis = reduction_indices->at<loco::DataType::S32>(i);
+ if (axis < 0)
+ axis += input_signature.rank();
+
+ if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank())))
+ INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
+
+ reduction_values.push_back(axis);
+ }
+ }
+ else if (reduction_indices->dtype() == loco::DataType::S64)
+ {
+ auto reduction_size = reduction_indices->size<loco::DataType::S64>();
+ for (uint32_t i = 0; i < reduction_size; ++i)
+ {
+ int32_t axis = static_cast<int32_t>(reduction_indices->at<loco::DataType::S64>(i));
+ if (axis < 0)
+ axis += input_signature.rank();
+
+ if (!(0 <= axis && axis < static_cast<int32_t>(input_signature.rank())))
+ INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
+
+ reduction_values.push_back(axis);
+ }
+ }
+ else
+ {
+ INTERNAL_EXN("Wrong reduction axis type, Only INT32, INT64 supported.");
+ }
+
+ if (keep_dims)
+ {
+ output_signature.rank(input_signature.rank());
+ for (uint32_t i = 0; i < input_signature.rank(); ++i)
+ output_signature.dim(i) = input_signature.dim(i);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ output_signature.dim(reduction_values.at(i)) = 1;
+ }
+ else
+ {
+ std::vector<bool> check_reduce(input_signature.rank(), false);
+ for (uint32_t i = 0; i < reduction_values.size(); ++i)
+ check_reduce.at(reduction_values.at(i)) = true;
+
+ uint32_t reduce_cnt = 0;
+ for (uint32_t i = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i))
+ ++reduce_cnt;
+
+ output_signature.rank(input_signature.rank() - reduce_cnt);
+ for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
+ if (check_reduce.at(i) == false)
+ output_signature.dim(j++) = input_signature.dim(i);
+ }
+
+ return output_signature;
+}
+
+ShapeSignature input_arg_signature(const luci::CircleNode *node, uint32_t index)
+{
+ auto circle_input = loco::must_cast<luci::CircleNode *>(node->arg(index));
+ return circle_input->shape_signature();
+}
+
+} // namespace ssinf
+
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleTypeInference.cpp b/compiler/luci/service/src/CircleTypeInference.cpp
index aa8524a55..b4755b51a 100644
--- a/compiler/luci/service/src/CircleTypeInference.cpp
+++ b/compiler/luci/service/src/CircleTypeInference.cpp
@@ -16,6 +16,8 @@
#include "luci/Service/CircleTypeInference.h"
+#include <luci/Log.h>
+
#include <loco.h>
#include <loco/Service/TypeInference.h>
@@ -70,3 +72,47 @@ circle::TensorType TypeInference::get(loco::Node *node)
}
} // namespace luci
+
+namespace
+{
+
+bool inputs_dtype_ready(const luci::CircleNode *node)
+{
+ for (uint32_t arity = 0; arity < node->arity(); ++arity)
+ {
+ if (node->dtype() == loco::DataType::Unknown)
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+namespace tinf
+{
+
+bool Rule::infer(const luci::CircleNode *circle_node, loco::DataType &dtype) const
+{
+ LOGGER(l);
+ VERBOSE(l, 1) << "[CircleTypeInference] " << circle_node->name();
+ VERBOSE(l, 1) << " before: " << static_cast<int>(circle_node->dtype());
+
+ if (!inputs_dtype_ready(circle_node))
+ {
+ VERBOSE(l, 1) << " after: Some inputs are not ready for inference";
+ return false;
+ }
+
+ Algorithm alg;
+ dtype = circle_node->accept(&alg);
+
+ VERBOSE(l, 1) << " after: " << static_cast<int>(dtype);
+
+ return true;
+}
+
+} // namespace tinf
+} // namespace luci
diff --git a/compiler/luci/service/src/CircleTypeInferenceHelper.cpp b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
new file mode 100644
index 000000000..75cd9f7b2
--- /dev/null
+++ b/compiler/luci/service/src/CircleTypeInferenceHelper.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Service/CircleTypeInferenceHelper.h"
+
+namespace luci
+{
+namespace tinf
+{
+
+// Helper function will be added
+
+} // namespace tinf
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleInput.cpp b/compiler/luci/service/src/Nodes/CircleInput.cpp
new file mode 100644
index 000000000..24eab7bd6
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleInput.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleInput *node)
+{
+ return node->shape_signature();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleMean.cpp b/compiler/luci/service/src/Nodes/CircleMean.cpp
new file mode 100644
index 000000000..a78713698
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleMean.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleMean *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutput.cpp b/compiler/luci/service/src/Nodes/CircleOutput.cpp
new file mode 100644
index 000000000..d4c8da2d8
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutput.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutput *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
new file mode 100644
index 000000000..e0f13c439
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutputDummy.cpp
@@ -0,0 +1,24 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputDummy *) { return ShapeSignature(); }
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
new file mode 100644
index 000000000..75bbbb3c0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleOutputExclude.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleOutputExclude *)
+{
+ return ShapeSignature();
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceAny.cpp b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
new file mode 100644
index 000000000..27da81466
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceAny.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceAny *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMax.cpp b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
new file mode 100644
index 000000000..48d9cb970
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceMax.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMax *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceMin.cpp b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
new file mode 100644
index 000000000..9a9997118
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceMin.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceMin *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReduceProd.cpp b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
new file mode 100644
index 000000000..a9d381a74
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReduceProd.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReduceProd *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu.cpp b/compiler/luci/service/src/Nodes/CircleRelu.cpp
new file mode 100644
index 000000000..a7a7f6f0a
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleRelu6.cpp b/compiler/luci/service/src/Nodes/CircleRelu6.cpp
new file mode 100644
index 000000000..92a596d08
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleRelu6.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleRelu6 *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
new file mode 100644
index 000000000..1e8d9971d
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleReluN1To1.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleReluN1To1 *node)
+{
+ return input_arg_signature(node, 0);
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/Nodes/CircleSum.cpp b/compiler/luci/service/src/Nodes/CircleSum.cpp
new file mode 100644
index 000000000..9ef90e8e0
--- /dev/null
+++ b/compiler/luci/service/src/Nodes/CircleSum.cpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/Service/CircleShapeSignatureInference.h>
+
+namespace luci
+{
+
+ShapeSignature ssinf::Algorithm::visit(const luci::CircleSum *node)
+{
+ return legalized_signature(
+ reduced_signature(node->input(), node->reduction_indices(), node->keep_dims()));
+}
+
+} // namespace luci
diff --git a/compiler/luci/service/src/ShapeDescription.cpp b/compiler/luci/service/src/ShapeDescription.cpp
index cbc302f70..01a638f8f 100644
--- a/compiler/luci/service/src/ShapeDescription.cpp
+++ b/compiler/luci/service/src/ShapeDescription.cpp
@@ -23,6 +23,19 @@
namespace luci
{
+ShapeDescription to_shape_description(const luci::CircleNode *circle_node)
+{
+ ShapeDescription res;
+
+ res._rank_known = true;
+
+ res._dims.resize(circle_node->rank());
+ for (uint32_t i = 0; i < circle_node->rank(); ++i)
+ res._dims.at(i) = circle_node->dim(i).value();
+
+ return res;
+}
+
ShapeDescription to_shape_description(const loco::TensorShape &shape)
{
ShapeDescription res;
diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp
index d224fd172..3f732b6fe 100644
--- a/compiler/luci/service/src/Validate.cpp
+++ b/compiler/luci/service/src/Validate.cpp
@@ -42,6 +42,19 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
return os;
}
+std::ostream &operator<<(std::ostream &os, const luci::CircleNode *circle_node)
+{
+ os << "[";
+ for (uint32_t r = 0; r < circle_node->rank(); ++r)
+ {
+ if (r)
+ os << ",";
+ os << circle_node->dim(r).value();
+ }
+ os << "]";
+ return os;
+}
+
/**
* @brief returns a node that is CircleOutput with index is out_index in nodes
*/
@@ -80,23 +93,28 @@ bool validate_shape_dtype(loco::Graph *g)
if (dynamic_cast<luci::CircleOutputExclude *>(circle_node))
continue;
- assert(loco::shape_known(circle_node));
+ assert(circle_node->shape_status() != luci::ShapeStatus::UNDEFINED);
// check if output node shape is same as graph output shape
- auto co_tensor_shape = loco::shape_get(circle_node).as<loco::TensorShape>();
auto go_tensor_shape = graph_out->shape();
assert(go_tensor_shape);
- if (!(co_tensor_shape == *go_tensor_shape))
+
+ bool is_shape_valid = (circle_node->rank() == go_tensor_shape->rank());
+ for (uint32_t i = 0; is_shape_valid && i < circle_node->rank(); ++i)
+ if (circle_node->dim(i).value() != go_tensor_shape->dim(i).value())
+ is_shape_valid = false;
+
+ if (is_shape_valid == false)
{
INFO(l) << "[luci] Shape for output #" << out_index << " not same " << std::endl;
- INFO(l) << "[luci] " << circle_node->name() << " " << co_tensor_shape << " vs "
+ INFO(l) << "[luci] " << circle_node->name() << " " << circle_node << " vs "
<< *go_tensor_shape << std::endl;
return false;
}
// check if data type match
- assert(loco::dtype_known(circle_node));
- if (graph_out->dtype() != loco::dtype_get(circle_node))
+ assert(circle_node->dtype() != loco::DataType::Unknown);
+ if (graph_out->dtype() != circle_node->dtype())
{
INFO(l) << "[luci] Type for output #" << out_index << " not same " << std::endl;
return false;
@@ -106,6 +124,55 @@ bool validate_shape_dtype(loco::Graph *g)
return true;
}
+bool validate_shape_signature(loco::Graph *g)
+{
+ LOGGER(l);
+
+ for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ const auto shape_signature = circle_node->shape_signature();
+
+ if (shape_signature.rank() == 0)
+ continue;
+
+ // Rank of shape and shape signature should be same
+ if (circle_node->rank() != shape_signature.rank())
+ {
+ INFO(l) << "[luci] Rank of shape signature for " << circle_node->name() << " do not match"
+ << std::endl;
+ return false;
+ }
+
+ bool has_unknown = false;
+
+ // If shape siganture is not -1, dimension value should be same
+ for (uint32_t d = 0; d < shape_signature.rank(); ++d)
+ {
+ if (shape_signature.dim(d) != -1 &&
+ shape_signature.dim(d) != (int32_t)(circle_node->dim(d).value()))
+ {
+ INFO(l) << "[luci] Dimension " << d << "of shape signature for " << circle_node->name()
+ << " do not match" << std::endl;
+ return false;
+ }
+
+ if (shape_signature.dim(d) == -1)
+ has_unknown = true;
+ }
+
+ // Shape signature should have at least one -1 value.
+ if (!has_unknown)
+ {
+ INFO(l) << "[luci] Shape signature in " << circle_node->name()
+ << " do not have unknown dimension" << std::endl;
+ return false;
+ }
+ }
+
+ return true;
+}
+
} // namespace
namespace luci
@@ -119,6 +186,9 @@ bool validate(loco::Graph *g)
if (!validate_shape_dtype(g))
return false;
+ if (!validate_shape_signature(g))
+ return false;
+
// TODO add more validation
return true;
diff --git a/compiler/luci/tester/src/ReadTester.cpp b/compiler/luci/tester/src/ReadTester.cpp
index a1aead1bd..f270a232c 100644
--- a/compiler/luci/tester/src/ReadTester.cpp
+++ b/compiler/luci/tester/src/ReadTester.cpp
@@ -21,6 +21,9 @@
#include <luci/Pass/ShapeInferencePass.h>
#include <luci/Pass/TypeInferencePass.h>
+// Following passes will be removed after refactoring is finished
+#include <luci/Pass/MigrateLegacyShapeDtypePass.h>
+
#include <iostream>
#include <map>
#include <string>
@@ -95,6 +98,12 @@ int entry(int argc, char **argv)
while (pass.run(graph) == true)
;
}
+ {
+ // This pass will be removed after refactoring is finished
+ luci::MigrateLegacyShapeDtypePass pass;
+ while (pass.run(graph) == true)
+ ;
+ }
if (!luci::validate(graph))
return 255;
diff --git a/compiler/luci/tester/src/WriteTester.cpp b/compiler/luci/tester/src/WriteTester.cpp
index aa7085c77..9a6e8de05 100644
--- a/compiler/luci/tester/src/WriteTester.cpp
+++ b/compiler/luci/tester/src/WriteTester.cpp
@@ -23,6 +23,9 @@
#include <luci/CircleExporter.h>
#include <oops/InternalExn.h>
+// Following passes will be removed after refactoring is finished
+#include <luci/Pass/MigrateLegacyShapeDtypePass.h>
+
#include <fstream>
#include <iostream>
#include <map>
@@ -139,6 +142,12 @@ int entry(int argc, char **argv)
while (pass.run(graph) == true)
;
}
+ {
+ // This pass will be removed after refactoring is finished
+ luci::MigrateLegacyShapeDtypePass pass;
+ while (pass.run(graph) == true)
+ ;
+ }
if (!luci::validate(graph))
return 255;