summaryrefslogtreecommitdiff
path: root/compiler/luci/pass
diff options
context:
space:
mode:
authorHyeongseok Oh <hseok82.oh@samsung.com>2023-04-12 15:42:02 +0900
committerHyeongseok Oh <hseok82.oh@samsung.com>2023-04-12 15:42:02 +0900
commit323663bb115ef625642391a5a8e9b35fee8b2ae3 (patch)
tree17e2a6b91535e6f53f4cacda5e4db6aa0303dd22 /compiler/luci/pass
parentc690d52bdd137ed6a17353aa7af35e8141ece77b (diff)
downloadnnfw-323663bb115ef625642391a5a8e9b35fee8b2ae3.tar.gz
nnfw-323663bb115ef625642391a5a8e9b35fee8b2ae3.tar.bz2
nnfw-323663bb115ef625642391a5a8e9b35fee8b2ae3.zip
Imported Upstream version 1.22.0upstream/1.22.0
Diffstat (limited to 'compiler/luci/pass')
-rw-r--r--compiler/luci/pass/CMakeLists.txt2
-rw-r--r--compiler/luci/pass/include/luci/CircleOptimizer.h5
-rw-r--r--compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h38
-rw-r--r--compiler/luci/pass/include/luci/Pass/ForwardTransposeOpPass.h (renamed from compiler/luci/pass/src/test/TestIOGraph.test.cpp)24
-rw-r--r--compiler/luci/pass/include/luci/Pass/FusePReluPass.h40
-rw-r--r--compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h21
-rw-r--r--compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h45
-rw-r--r--compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h37
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp25
-rw-r--r--compiler/luci/pass/src/CircleQuantizer.cpp191
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp180
-rw-r--r--compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp252
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp17
-rw-r--r--compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp2
-rw-r--r--compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp2
-rw-r--r--compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp2
-rw-r--r--compiler/luci/pass/src/FoldFullyConnectedPass.cpp198
-rw-r--r--compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp160
-rw-r--r--compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp28
-rw-r--r--compiler/luci/pass/src/ForwardTransposeOpPass.cpp366
-rw-r--r--compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp524
-rw-r--r--compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp8
-rw-r--r--compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp20
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.cpp1
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp53
-rw-r--r--compiler/luci/pass/src/FusePReluPass.cpp202
-rw-r--r--compiler/luci/pass/src/FusePReluPass.test.cpp187
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.cpp6
-rw-r--r--compiler/luci/pass/src/QuantizationUtils.h3
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.cpp17
-rw-r--r--compiler/luci/pass/src/QuantizeActivation.h1
-rw-r--r--compiler/luci/pass/src/QuantizeWeights.cpp173
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp76
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp39
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.h17
-rw-r--r--compiler/luci/pass/src/QuantizedModelVerifier.test.cpp279
-rw-r--r--compiler/luci/pass/src/RemoveDuplicateConstPass.cpp225
-rw-r--r--compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp159
-rw-r--r--compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp64
-rw-r--r--compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp4
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp1
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp1
-rw-r--r--compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp2
-rw-r--r--compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp672
-rw-r--r--compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp211
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h8
-rw-r--r--compiler/luci/pass/src/VerifyQuantizedNodeType.cpp8
-rw-r--r--compiler/luci/pass/src/helpers/NodeFiller.h26
-rw-r--r--compiler/luci/pass/src/helpers/SparsityFormatConverter.h1
-rw-r--r--compiler/luci/pass/src/helpers/Strings.cpp9
-rw-r--r--compiler/luci/pass/src/helpers/Strings.h2
-rw-r--r--compiler/luci/pass/src/helpers/Strings.test.cpp23
-rw-r--r--compiler/luci/pass/src/test/TestIOGraph.h161
53 files changed, 4375 insertions, 443 deletions
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt
index d9d004db9..ac18a5f8d 100644
--- a/compiler/luci/pass/CMakeLists.txt
+++ b/compiler/luci/pass/CMakeLists.txt
@@ -31,7 +31,7 @@ target_link_libraries(luci_pass PRIVATE luci_log)
target_link_libraries(luci_pass PRIVATE luci_service)
target_link_libraries(luci_pass PRIVATE luci_logex)
target_link_libraries(luci_pass PRIVATE luci_profile)
-target_link_libraries(luci_pass PRIVATE mio_tflite280_inc)
+target_link_libraries(luci_pass PRIVATE luci_compute)
target_link_libraries(luci_pass PRIVATE nncc_common)
target_link_libraries(luci_pass PRIVATE pepper_csv2vec)
target_link_libraries(luci_pass PRIVATE oops)
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index b94822c35..d77e89db1 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -52,14 +52,17 @@ public:
FoldCast,
FoldDensify,
FoldDepthwiseConv2D,
+ FoldFullyConnected,
FoldDequantize,
FoldGather,
FoldSparseToDense,
ForwardReshapeToUnaryOp,
+ ForwardTransposeOp,
SparsifyTensorPass,
FusePreActivationBatchNorm,
MakeBatchNormGammaPositive,
FuseActivationFunction,
+ FusePRelu,
ShuffleWeightTo16x1Float32,
RemoveRedundantTranspose,
ReplaceMulAddWithDepthwiseConv,
@@ -83,6 +86,8 @@ public:
RemoveRedundantReshape,
RemoveFakeQuant,
RemoveQuantDequantSeq,
+ RemoveDuplicateConst,
+ UnrollUnidirSeqLSTM,
};
enum AlgorithmParameters
diff --git a/compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h b/compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h
new file mode 100644
index 000000000..bd36ff149
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldFullyConnectedPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_FULLY_CONNECTED_PASS_H__
+#define __LUCI_FOLD_FULLY_CONNECTED_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold FullyConnected with constant input and filter into a
+ * constant tensor
+ */
+struct FoldFullyConnectedPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldFullyConnectedPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_FULLY_CONNECTED_PASS_H__
diff --git a/compiler/luci/pass/src/test/TestIOGraph.test.cpp b/compiler/luci/pass/include/luci/Pass/ForwardTransposeOpPass.h
index e58a13f2b..b44b1bde1 100644
--- a/compiler/luci/pass/src/test/TestIOGraph.test.cpp
+++ b/compiler/luci/pass/include/luci/Pass/ForwardTransposeOpPass.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -14,6 +14,24 @@
* limitations under the License.
*/
-#include "TestIOGraph.h"
+#ifndef __LUCI_FORWARD_TRANSPOSE_OP_PASS_H__
+#define __LUCI_FORWARD_TRANSPOSE_OP_PASS_H__
-// This file validates "TestIOGraph.h". Pleaes DO NOT remove this file.
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Forward Transpose Ops for further optimization.
+ */
+struct ForwardTransposeOpPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ForwardTransposeOpPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FORWARD_TRANSPOSE_OP_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/FusePReluPass.h b/compiler/luci/pass/include/luci/Pass/FusePReluPass.h
new file mode 100644
index 000000000..a21acf49d
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FusePReluPass.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FUSE_PRELU_PASS_H__
+#define __LUCI_FUSE_PRELU_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fuse certain pattern of subgraph into CirclePRelu
+ * with auxiliary nodes
+ *
+ * For detailed subgraph pattern to be fused, please check its implementation.
+ */
+struct FusePReluPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FusePReluPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FUSE_PRELU_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
index ea6db85d1..6874046f0 100644
--- a/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
+++ b/compiler/luci/pass/include/luci/Pass/QuantizeWithMinMaxPass.h
@@ -39,29 +39,12 @@ public:
loco::DataType input_model_dtype = loco::DataType::Unknown;
loco::DataType output_model_dtype = loco::DataType::Unknown;
QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
- loco::DataType input_type = loco::DataType::Unknown;
- loco::DataType output_type = loco::DataType::Unknown;
+ std::vector<loco::DataType> input_types;
+ std::vector<loco::DataType> output_types;
bool TF_style_maxpool = false;
std::vector<LayerInfo> layers_info;
};
- // For backward-compatibility
- // TODO Remove this constructor
-public:
- QuantizeWithMinMaxPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
- QuantizationGranularity granularity)
- {
- _ctx = std::make_unique<Context>();
- {
- _ctx->input_model_dtype = input_model_dtype;
- _ctx->output_model_dtype = output_model_dtype;
- _ctx->granularity = granularity;
- _ctx->input_type = output_model_dtype;
- _ctx->output_type = output_model_dtype;
- _ctx->TF_style_maxpool = false;
- }
- }
-
public:
QuantizeWithMinMaxPass(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
{
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h b/compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h
new file mode 100644
index 000000000..000cdcc43
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveDuplicateConstPass.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_DUPLICATE_CONST_PASS_H__
+#define __LUCI_REMOVE_DUPLICATE_CONST_PASS_H__
+
+#include <luci/IR/CircleNodes.h>
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to remove duplicate Const nodes.
+ */
+struct RemoveDuplicateConstPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveDuplicateConstPass"; }
+
+ bool run(loco::Graph *g) final;
+
+private:
+ bool remove_duplicate_const();
+
+ template <loco::DataType DT> void add_to_map(luci::CircleConst *const_node);
+
+ std::map<float, std::vector<CircleConst *>> _sum_to_const;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_DUPLICATE_CONST_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h b/compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h
new file mode 100644
index 000000000..fd5a708e8
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_UNROLL_UNIDIRECTIONALSEQUENCELSTM_PASS_H__
+#define __LUCI_UNROLL_UNIDIRECTIONALSEQUENCELSTM_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to Unroll UnidirectionalSequenceLSTM
+ */
+struct UnrollUnidirectionalSequenceLSTMPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::UnrollUnidirectionalSequenceLSTMPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_UNROLL_UNIDIRECTIONALSEQUENCELSTM_PASS_H__
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 74c569d20..5e1613ad9 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -23,9 +23,11 @@
#include "luci/Pass/FoldDensifyPass.h"
#include "luci/Pass/FoldDepthwiseConv2DPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/FoldFullyConnectedPass.h"
#include "luci/Pass/FoldGatherPass.h"
#include "luci/Pass/FoldSparseToDensePass.h"
#include "luci/Pass/ForwardReshapeToUnaryOpPass.h"
+#include "luci/Pass/ForwardTransposeOpPass.h"
#include "luci/Pass/FuseActivationFunctionPass.h"
#include "luci/Pass/FuseAddWithFullyConnectedPass.h"
#include "luci/Pass/FuseAddWithTConvPass.h"
@@ -36,8 +38,10 @@
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/FuseMeanWithMeanPass.h"
#include "luci/Pass/FusePreActivationBatchNormPass.h"
+#include "luci/Pass/FusePReluPass.h"
#include "luci/Pass/FuseTransposeWithMeanPass.h"
#include "luci/Pass/MakeBatchNormGammaPositivePass.h"
+#include "luci/Pass/RemoveDuplicateConstPass.h"
#include "luci/Pass/RemoveFakeQuantPass.h"
#include "luci/Pass/RemoveQuantDequantSeqPass.h"
#include "luci/Pass/RemoveRedundantReshapePass.h"
@@ -66,6 +70,7 @@
#include "luci/Pass/SubstituteTransposeToReshapePass.h"
#include "luci/Pass/TransformMinMaxToRelu6Pass.h"
#include "luci/Pass/TransformMinReluToRelu6Pass.h"
+#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
// TODO add more passes
#include "luci/Pass/CircleShapeInferencePass.h"
@@ -274,6 +279,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseActivationFunctionPass>());
}
+ if (_options->query(Options::Algorithm::FusePRelu))
+ {
+ phase.emplace_back(std::make_unique<FusePReluPass>());
+ }
if (_options->query(Options::Algorithm::FuseTransposeWithMean))
{
phase.emplace_back(std::make_unique<FuseTransposeWithMeanPass>());
@@ -298,6 +307,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
}
+ if (_options->query(Options::Algorithm::FoldFullyConnected))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldFullyConnectedPass>());
+ }
if (_options->query(Options::Algorithm::FoldGather))
{
phase.emplace_back(std::make_unique<luci::FoldGatherPass>());
@@ -310,6 +323,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::ForwardReshapeToUnaryOpPass>());
}
+ if (_options->query(Options::Algorithm::ForwardTransposeOp))
+ {
+ phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
+ }
if (_options->query(Options::Algorithm::FusePreActivationBatchNorm))
{
phase.emplace_back(std::make_unique<luci::FusePreActivationBatchNormPass>());
@@ -326,6 +343,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::ExpandBroadcastConstPass>());
}
+ if (_options->query(Options::Algorithm::RemoveDuplicateConst))
+ {
+ phase.emplace_back(std::make_unique<luci::RemoveDuplicateConstPass>());
+ }
if (_options->query(Options::Algorithm::RemoveFakeQuant))
{
phase.emplace_back(std::make_unique<luci::RemoveFakeQuantPass>());
@@ -407,6 +428,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::TransformMinReluToRelu6Pass>());
}
+ if (_options->query(Options::Algorithm::UnrollUnidirSeqLSTM))
+ {
+ phase.emplace_back(std::make_unique<luci::UnrollUnidirectionalSequenceLSTMPass>());
+ }
/* TRANSFORM DECLARATION END */
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
index 9a6550b9f..3ffa1180c 100644
--- a/compiler/luci/pass/src/CircleQuantizer.cpp
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -40,6 +40,7 @@
#include <luci/IR/CircleNode.h>
#include <logo/Phase.h>
+#include <pepper/csv2vec.h>
#include <memory>
@@ -49,6 +50,154 @@ namespace
using namespace luci;
using LayerParam = luci::CircleQuantizer::Options::LayerParam;
+// This function updates user-given input_type to match with the input signature of graph
+// If user gives only one input_type, it will be expanded to the number of graph inputs
+void canonicalize_input_type(loco::Graph *g, std::vector<loco::DataType> &input_type)
+{
+ if (g == nullptr)
+ return;
+
+ const auto inputs = g->inputs();
+
+ assert(inputs); // FIX_CALLER_UNLESS
+
+ // Check validity of the number of input dtype given by a user
+ if (input_type.size() != 1 and input_type.size() != inputs->size())
+ {
+ throw std::runtime_error(
+ "Invalid number of input dtype. The number of input dtype should be 1 or "
+ "the same as the number of graph inputs.");
+ }
+
+ // Handle the case when a user gives only one input dtype
+ if (input_type.size() == 1)
+ {
+ const auto user_given_dtype = input_type[0];
+ input_type.clear();
+
+ // Expand input dtype to the number of graph inputs
+ // Since quantizer can only quantize float32, user_given_dtype is set only for float32 inputs
+ auto input_nodes = loco::input_nodes(g);
+ for (uint32_t i = 0; i < input_nodes.size(); i++)
+ {
+ auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]);
+
+ if (input->dtype() == loco::DataType::FLOAT32)
+ input_type.push_back(user_given_dtype);
+ else
+ input_type.push_back(input->dtype());
+ }
+ }
+
+ // Finally, check validity of input_type
+ // input_type is valid if
+ // C1. for non-float32 model input, input_type == model's input dtype
+ // or
+ // C2. for float32 model input, input_type == uint8, int16, or float32
+ auto input_nodes = loco::input_nodes(g);
+ for (uint32_t i = 0; i < input_nodes.size(); i++)
+ {
+ auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]);
+ assert(i == input->index()); // FIX_ME_UNLESS
+
+ if (input->dtype() != loco::DataType::FLOAT32)
+ {
+ // C1
+ if (input->dtype() != input_type[i])
+ throw std::runtime_error(
+ "Input dtype of " + input->name() +
+ " is invalid. It has to be the same with the model's input dtype.");
+ }
+ else
+ {
+ // C2
+ if (input_type[i] != loco::DataType::FLOAT32 and input_type[i] != loco::DataType::U8 and
+ input_type[i] != loco::DataType::S16)
+ {
+ throw std::runtime_error("Input dtype of " + input->name() +
+ " is invalid. For float32 input, the input dtype after "
+ "quantization must be one of uint8, int16, or float32.");
+ }
+ }
+ }
+}
+
+// This function updates user-given output_type to match with the output signature of graph
+// If user gives only one output_type, it will be expanded to the number of graph outputs
+// NOTE This function is almost same with canonicalize_input_type, but it is written as a
+// separate function for more precise error messaging.
+// TODO Find a way to reduce duplicate codes
+void canonicalize_output_type(loco::Graph *g, std::vector<loco::DataType> &output_type)
+{
+ if (g == nullptr)
+ return;
+
+ const auto outputs = g->outputs();
+
+ assert(outputs); // FIX_CALLER_UNLESS
+
+ // Check validity of the number of output dtype given by a user
+ if (output_type.size() != 1 and output_type.size() != outputs->size())
+ {
+ throw std::runtime_error(
+ "Invalid number of output dtype. The number of output dtype should be 1 or "
+ "the same as the number of graph outputs.");
+ }
+
+ // Handle the case when a user gives only one output dtype
+ if (output_type.size() == 1)
+ {
+ const auto user_given_dtype = output_type[0];
+ output_type.clear();
+
+ // Expand output dtype to the number of graph outputs
+ // If dtype of graph output is float32, it will be replaced with user_given_dtype
+ // Otherwise, it will not change
+ auto output_nodes = loco::output_nodes(g);
+ for (uint32_t i = 0; i < output_nodes.size(); i++)
+ {
+ auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]);
+
+ if (output->dtype() == loco::DataType::FLOAT32)
+ output_type.push_back(user_given_dtype);
+ else
+ output_type.push_back(output->dtype());
+ }
+ }
+
+ // Finally, check validity of output_type
+ // output_type is valid if
+ // C1. for non-float32 model output, output_type == model's output dtype
+ // or
+ // C2. for float32 model output, output_type == uint8, int16, or float32
+ auto output_nodes = loco::output_nodes(g);
+ for (uint32_t i = 0; i < output_nodes.size(); i++)
+ {
+ auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]);
+ assert(i == output->index()); // FIX_ME_UNLESS
+
+ if (output->dtype() != loco::DataType::FLOAT32)
+ {
+ // C1
+ if (output->dtype() != output_type[i])
+ throw std::runtime_error(
+ "Output dtype of " + output->name() +
+ " is invalid. It has to be the same with the model's output dtype.");
+ }
+ else
+ {
+ // C2
+ if (output_type[i] != loco::DataType::FLOAT32 and output_type[i] != loco::DataType::U8 and
+ output_type[i] != loco::DataType::S16)
+ {
+ throw std::runtime_error("Output dtype of " + output->name() +
+ " is invalid. For float32 output, the output dtype after "
+ "quantization must be one of uint8, int16, or float32.");
+ }
+ }
+ }
+}
+
template <typename T> T lexical_cast(const std::string &str)
{
std::istringstream ss;
@@ -253,8 +402,10 @@ void CircleQuantizer::quantize(loco::Graph *g) const
static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
- static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "float32"};
- static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "float32"};
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "int32",
+ "int64", "float32", "bool"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "int32",
+ "int64", "float32", "bool"};
auto input_model_dtype =
_options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
@@ -268,6 +419,9 @@ void CircleQuantizer::quantize(loco::Graph *g) const
if (output_type.empty())
output_type = output_model_dtype;
+ auto input_type_vec = pepper::csv_to_vector<std::string>(input_type);
+ auto output_type_vec = pepper::csv_to_vector<std::string>(output_type);
+
bool TF_style_maxpool =
_options->param(Options::AlgorithmParameters::Quantize_TF_style_maxpool) == "True";
@@ -285,13 +439,19 @@ void CircleQuantizer::quantize(loco::Graph *g) const
throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
to_string(qwmm_supported_granularity));
- if (!in_array(to_lower_case(input_type), qwmm_supported_input_type))
- throw std::runtime_error("Unsupported input type. List of supported input types: " +
- to_string(qwmm_supported_input_type));
+ for (auto dtype : input_type_vec)
+ {
+ if (!in_array(to_lower_case(dtype), qwmm_supported_input_type))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(qwmm_supported_input_type));
+ }
- if (!in_array(to_lower_case(output_type), qwmm_supported_output_type))
- throw std::runtime_error("Unsupported output type. List of supported output types: " +
- to_string(qwmm_supported_output_type));
+ for (auto dtype : output_type_vec)
+ {
+ if (!in_array(to_lower_case(dtype), qwmm_supported_output_type))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(qwmm_supported_output_type));
+ }
if (str_to_granularity(granularity) == QuantizationGranularity::LayerWise &&
str_to_dtype(output_model_dtype) != loco::DataType::U8)
@@ -314,6 +474,13 @@ void CircleQuantizer::quantize(loco::Graph *g) const
}
}
+ auto input_types = str_vec_to_dtype_vec(input_type_vec);
+ auto output_types = str_vec_to_dtype_vec(output_type_vec);
+
+ // Canonicalize user-given input/output_type (match with # of inputs/outputs)
+ canonicalize_input_type(g, input_types);
+ canonicalize_output_type(g, output_types);
+
// Input model checker for quantization
luci::QuantizePreCheckerPass input_model_checker{};
input_model_checker.run(g);
@@ -323,8 +490,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const
ctx->input_model_dtype = str_to_dtype(input_model_dtype);
ctx->output_model_dtype = str_to_dtype(output_model_dtype);
ctx->granularity = str_to_granularity(granularity);
- ctx->input_type = str_to_dtype(input_type);
- ctx->output_type = str_to_dtype(output_type);
+ ctx->input_types = input_types;
+ ctx->output_types = output_types;
ctx->TF_style_maxpool = TF_style_maxpool;
for (auto layer_param : layer_params)
@@ -347,8 +514,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const
{
verify_ctx->output_model_dtype = str_to_dtype(output_model_dtype);
verify_ctx->granularity = str_to_granularity(granularity);
- verify_ctx->input_type = str_to_dtype(input_type);
- verify_ctx->output_type = str_to_dtype(output_type);
+ verify_ctx->input_types = input_types;
+ verify_ctx->output_types = output_types;
verify_ctx->TF_style_maxpool = TF_style_maxpool;
for (auto layer_param : layer_params)
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index 55a29d105..99e1e2939 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -503,43 +503,30 @@ bool is_NCHW(const luci::CirclePadV2 *node)
return true;
}
-// NOTE Following conditions can be extended later
-// NOTE Used for Maximum, Miminum as ReLU/ReLU6
-//
-// Find T with an NCHW pattern described below
-// - Input (non-constant) shape : [N, C, H, W]
-// - Input (constant) shape : [1] or []
-// - Output shape : [N, C, H, W]
-template <class T>
-bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node,
- luci::CircleConst *&comp_const)
+bool is_const(const loco::Node *node)
{
- auto x = dynamic_cast<luci::CircleConst *>(node->x());
- auto y = dynamic_cast<luci::CircleConst *>(node->y());
-
- if (x != nullptr && y == nullptr)
- {
- pred_node = loco::must_cast<luci::CircleNode *>(node->y());
- comp_const = x;
- }
- else if (x == nullptr && y != nullptr)
- {
- pred_node = loco::must_cast<luci::CircleNode *>(node->x());
- comp_const = y;
- }
- else
- {
- // Ignore if T does not have a comp_const input.
+ if (not dynamic_cast<const luci::CircleConst *>(node))
return false;
- }
- if (pred_node->rank() != 4)
+ return true;
+}
+
+bool is_scalar_const(const loco::Node *node)
+{
+ auto const_node = dynamic_cast<const luci::CircleConst *>(node);
+ if (not const_node)
return false;
- // Check if scalar
- const auto const_rank = comp_const->rank();
- if (const_rank == 0 || (const_rank == 1 && comp_const->dim(0).value() == 1))
+ const auto const_rank = const_node->rank();
+ // shape of scalar
+ // 1. rank = 0
+ // 2. rank = 1, dimension = 1
+ if (const_rank == 0)
+ return true;
+
+ if (const_rank == 1 && const_node->dim(0).value() == 1)
return true;
+
return false;
}
@@ -854,22 +841,30 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
- bool visit(luci::CircleLogSoftmax *node)
- {
- return convert_unary_logits<luci::CircleLogSoftmax>(node);
- }
-
bool visit(luci::CircleMaximum *node)
{
- luci::CircleNode *pred_node = nullptr;
- luci::CircleConst *comp_constant = nullptr;
-
- if (is_NCHW_with_s_const<luci::CircleMaximum>(node, pred_node, comp_constant))
+ if ((not is_const(node->x())) and is_scalar_const(node->y()))
{
auto pre_trans = create_pre_transpose(node);
- pre_trans->a(pred_node);
+ pre_trans->a(node->x());
node->x(pre_trans);
}
+ else if (is_scalar_const(node->x()) and (not is_const(node->y())))
+ {
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(node->y());
+ node->y(pre_trans);
+ }
+ else if ((not is_const(node->x())) and (not is_const(node->y())))
+ {
+ auto pre_trans_x = create_pre_transpose(node);
+ pre_trans_x->a(node->x());
+ node->x(pre_trans_x);
+
+ auto pre_trans_y = create_pre_transpose(node);
+ pre_trans_y->a(node->y());
+ node->y(pre_trans_y);
+ }
else
{
// TODO support other cases
@@ -963,15 +958,18 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
bool visit(luci::CircleMinimum *node)
{
- luci::CircleNode *pred_node = nullptr;
- luci::CircleConst *comp_constant = nullptr;
-
- if (is_NCHW_with_s_const<luci::CircleMinimum>(node, pred_node, comp_constant))
+ if ((not is_const(node->x())) and is_scalar_const(node->y()))
{
auto pre_trans = create_pre_transpose(node);
- pre_trans->a(pred_node);
+ pre_trans->a(node->x());
node->x(pre_trans);
}
+ else if (is_scalar_const(node->x()) and (not is_const(node->y())))
+ {
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(node->y());
+ node->y(pre_trans);
+ }
else
{
// TODO support other cases
@@ -1168,14 +1166,88 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}
+ // TODO Reduce duplicate codes with CircleReduceMax
+ bool visit(luci::CircleReduceMin *node)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ if (input->rank() != 4)
+ return false;
+
+ auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
+ if (not rindices)
+ return false;
+
+ auto nhwc_rindices = create_NHWC_rindices(rindices);
+ if (not nhwc_rindices)
+ return false;
+
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(input);
+ node->input(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ node->reduction_indices(nhwc_rindices);
+
+ if (node->keep_dims())
+ {
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+ // The below codes handle the cases where node->keep_dims() == false
+ // 1D output never needs a transpose
+ if (node->rank() <= 1)
+ return true;
+
+ std::vector<bool> reduced_dims_nhwc(4, false);
+ uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
+
+ for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
+ {
+ reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
+ }
+
+ // if channel dimension has been reduced, we don't need a transpose
+ if (reduced_dims_nhwc[3])
+ return true;
+
+ // likewise, if both space dimensions are reduced, no transpose is needed
+ if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
+ return true;
+
+ std::vector<int32_t> post_trans_ind;
+ // case 1: only N is reduced
+ if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
+ post_trans_ind = {2, 0, 1};
+
+ // case 2: only H or W is reduced
+ if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
+ post_trans_ind = {0, 2, 1};
+
+ // case 3: N and either H or W are reduced
+ if (num_reduced_indices == 2)
+ post_trans_ind = {1, 0};
+
+ auto post_trans = create_Nd_transpose(node, post_trans_ind);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
- bool visit(luci::CircleSoftmax *node) { return convert_unary_logits<luci::CircleSoftmax>(node); }
-
bool visit(luci::CircleSplitV *node)
{
// Change split dimension
@@ -1375,6 +1447,10 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
collect_intermediate = [&](loco::Node *n) {
for (auto succ : loco::succs(n))
{
+ // Skip unnecessary traversal
+ if (intermediate.find(succ) != intermediate.end())
+ continue;
+
// Exit condition
if (is_post_transpose(succ) || is_post_reshape(succ))
continue;
@@ -1429,12 +1505,13 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
set_data_format(node, DataFormat::NCHW);
}
break;
+ // SOFTMAX, LOG_SOFTMAX are not converted, because
+ // tflite/circle assumes the last channel is always axis
case luci::CircleOpcode::ADD:
case luci::CircleOpcode::CONCATENATION:
case luci::CircleOpcode::ELU:
case luci::CircleOpcode::LEAKY_RELU:
case luci::CircleOpcode::LOGISTIC:
- case luci::CircleOpcode::LOG_SOFTMAX:
case luci::CircleOpcode::MAXIMUM:
case luci::CircleOpcode::MEAN:
case luci::CircleOpcode::MINIMUM:
@@ -1443,10 +1520,10 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
case luci::CircleOpcode::PAD:
case luci::CircleOpcode::PADV2:
case luci::CircleOpcode::REDUCE_MAX:
+ case luci::CircleOpcode::REDUCE_MIN:
case luci::CircleOpcode::RELU:
case luci::CircleOpcode::RELU6:
case luci::CircleOpcode::RSQRT:
- case luci::CircleOpcode::SOFTMAX:
case luci::CircleOpcode::SPLIT_V:
case luci::CircleOpcode::SQUARED_DIFFERENCE:
case luci::CircleOpcode::SUB:
@@ -1487,7 +1564,8 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
{
// TODO replace the check above with the input rank check, and remove the condition below
if (not dynamic_cast<luci::CircleMean *>(node) and
- not dynamic_cast<luci::CircleReduceMax *>(node))
+ not dynamic_cast<luci::CircleReduceMax *>(node) and
+ not dynamic_cast<luci::CircleReduceMin *>(node))
continue;
}
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index 6bb3d3268..fd326518e 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -483,22 +483,6 @@ public:
luci::CircleLogistic *logistic = nullptr;
};
-class LogSoftmaxGraph final : public SimpleGraph
-{
-protected:
- loco::Node *insertGraphBody(loco::Node *input) override
- {
- log_softmax = g.nodes()->create<luci::CircleLogSoftmax>();
- log_softmax->logits(input);
- log_softmax->name("log_softmax");
-
- return log_softmax;
- }
-
-public:
- luci::CircleLogSoftmax *log_softmax = nullptr;
-};
-
class MaximumGraph final : public SimpleGraph
{
protected:
@@ -530,6 +514,27 @@ public:
luci::CircleConst *limit = nullptr;
};
+class MaximumNonConstGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ max = g.nodes()->create<luci::CircleMaximum>();
+ max->dtype(loco::DataType::FLOAT32);
+ max->shape({1, 16, 4, 4});
+
+ max->x(input);
+ max->y(input);
+
+ max->name("max");
+
+ return max;
+ }
+
+public:
+ luci::CircleMaximum *max = nullptr;
+};
+
class MeanGraph final : public SimpleGraph
{
protected:
@@ -874,6 +879,51 @@ private:
std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
};
+class ReduceMinGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ rm = g.nodes()->create<luci::CircleReduceMin>();
+ rindices = g.nodes()->create<luci::CircleConst>();
+
+ rm->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ rm->shape(_shape);
+ rindices->shape({static_cast<uint32_t>(_axes.size())});
+
+ rindices->size<loco::DataType::S32>(_axes.size());
+ for (uint32_t i = 0; i < _axes.size(); ++i)
+ {
+ rindices->at<loco::DataType::S32>(i) = _axes[i];
+ }
+
+ rm->input(input);
+ rm->reduction_indices(rindices);
+ rm->keep_dims(_keep_dims);
+
+    rm->name("reduce_min");
+ rindices->name("rindices");
+
+ return rm;
+ }
+
+public:
+ void keep_dims(bool val) { _keep_dims = val; }
+ void axes(std::vector<int32_t> val) { _axes = val; }
+ void shape(std::initializer_list<uint32_t> val) { _shape = val; }
+
+public:
+ luci::CircleReduceMin *rm = nullptr;
+ luci::CircleConst *rindices = nullptr;
+
+private:
+ bool _keep_dims = true;
+ std::vector<int32_t> _axes = {2, 3};
+ std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+};
+
class ReluGraph final : public SimpleGraph
{
protected:
@@ -922,22 +972,6 @@ public:
luci::CircleRsqrt *rsqrt = nullptr;
};
-class SoftmaxGraph final : public SimpleGraph
-{
-protected:
- loco::Node *insertGraphBody(loco::Node *input) override
- {
- softmax = g.nodes()->create<luci::CircleSoftmax>();
- softmax->logits(input);
- softmax->name("softmax");
-
- return softmax;
- }
-
-public:
- luci::CircleSoftmax *softmax = nullptr;
-};
-
class SplitVGraphlet
{
public:
@@ -1357,44 +1391,50 @@ TEST(ConvertNCHWToNHWC, Logistic)
EXPECT_EQ(16, g.logistic->dim(3).value());
}
-TEST(ConvertNCHWToNHWC, LogSoftmax)
+TEST(ConvertNCHWToNHWC, Maximum)
{
- LogSoftmaxGraph g;
+ MaximumGraph g;
g.init();
- run_phase(&g.g, true, true);
+ run_phase(&g.g, false, false);
+
+ auto input_succs = loco::succs(g.input);
+ EXPECT_EQ(1, input_succs.size());
+ check_post_trans(*input_succs.begin());
- check_pre_trans(g.log_softmax->logits());
+ check_pre_trans(g.max->x());
- auto log_softmax_succs = loco::succs(g.log_softmax);
- EXPECT_EQ(1, log_softmax_succs.size());
- check_post_trans(*log_softmax_succs.begin());
+ auto max_succs = loco::succs(g.max);
+ EXPECT_EQ(1, max_succs.size());
+ check_post_trans(*max_succs.begin());
- // Check log_softmax shape
- EXPECT_EQ(1, g.log_softmax->dim(0).value());
- EXPECT_EQ(4, g.log_softmax->dim(1).value());
- EXPECT_EQ(4, g.log_softmax->dim(2).value());
- EXPECT_EQ(16, g.log_softmax->dim(3).value());
+ check_pre_trans(g.output->from());
}
-TEST(ConvertNCHWToNHWC, Maximum)
+TEST(ConvertNCHWToNHWC, Maximum_non_scalar_NEG)
{
MaximumGraph g;
g.init();
- run_phase(&g.g, false, false);
+ g.limit->shape({3});
- auto input_succs = loco::succs(g.input);
- EXPECT_EQ(1, input_succs.size());
- check_post_trans(*input_succs.begin());
+ luci::ConvertNCHWToNHWCPass pass(true, true);
+ EXPECT_FALSE(pass.run(&g.g));
+}
+
+TEST(ConvertNCHWToNHWC, MaximumNonConst)
+{
+ MaximumNonConstGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
check_pre_trans(g.max->x());
+ check_pre_trans(g.max->y());
auto max_succs = loco::succs(g.max);
EXPECT_EQ(1, max_succs.size());
check_post_trans(*max_succs.begin());
-
- check_pre_trans(g.output->from());
}
TEST(ConvertNCHWToNHWC, Mean)
@@ -1553,6 +1593,17 @@ TEST(ConvertNCHWToNHWC, Minimum)
check_pre_trans(g.output->from());
}
+TEST(ConvertNCHWToNHWC, Minimum_non_scalar_NEG)
+{
+ MinimumGraph g;
+ g.init();
+
+ g.limit->shape({3});
+
+ luci::ConvertNCHWToNHWCPass pass(true, true);
+ EXPECT_FALSE(pass.run(&g.g));
+}
+
TEST(ConvertNCHWToNHWC, Mul)
{
MulGraph g;
@@ -1893,6 +1944,85 @@ TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
}
}
+TEST(ConvertNCHWToNHWC, ReduceMin)
+{
+ ReduceMinGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.rm->input());
+
+ auto rm_succs = loco::succs(g.rm);
+ EXPECT_EQ(1, rm_succs.size());
+ check_post_trans(*rm_succs.begin());
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, ReduceMin_keep_dims_false)
+{
+ struct TC
+ {
+ std::vector<int32_t> nchw_ind;
+ std::vector<int32_t> nhwc_ind;
+ std::initializer_list<uint32_t> shape;
+ bool needs_transpose = false;
+ };
+
+ uint32_t n = 1;
+ uint32_t c = 16;
+ uint32_t h = 4;
+ uint32_t w = 4;
+
+ std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
+ {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
+ {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
+ {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
+ {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
+ {{0, 1, 2}, {0, 3, 1}, {w}, false}};
+
+ for (auto &tc : test_cases)
+ {
+ ReduceMinGraph g;
+ g.keep_dims(false);
+ g.axes(tc.nchw_ind);
+ g.shape(tc.shape);
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.rm->input());
+
+ auto rm_succs = loco::succs(g.rm);
+ EXPECT_EQ(1, rm_succs.size());
+ if (tc.needs_transpose)
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
+ }
+ else
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
+ }
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
+ for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
+ {
+ EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
+ }
+ }
+}
+
TEST(ConvertNCHWToNHWC, Relu)
{
ReluGraph g;
@@ -1953,26 +2083,6 @@ TEST(ConvertNCHWToNHWC, Rsqrt)
EXPECT_EQ(16, g.rsqrt->dim(3).value());
}
-TEST(ConvertNCHWToNHWC, Softmax)
-{
- SoftmaxGraph g;
- g.init();
-
- run_phase(&g.g, true, true);
-
- check_pre_trans(g.softmax->logits());
-
- auto softmax_succs = loco::succs(g.softmax);
- EXPECT_EQ(1, softmax_succs.size());
- check_post_trans(*softmax_succs.begin());
-
- // Check softmax shape
- EXPECT_EQ(1, g.softmax->dim(0).value());
- EXPECT_EQ(4, g.softmax->dim(1).value());
- EXPECT_EQ(4, g.softmax->dim(2).value());
- EXPECT_EQ(16, g.softmax->dim(3).value());
-}
-
TEST(ConvertNCHWToNHWC, SplitV)
{
SplitVGraph g;
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
index 72f590135..aacfce3d0 100644
--- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -31,7 +31,10 @@ namespace
luci::CircleQuantize *create_quantize(luci::CircleNode *node)
{
auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
- quantize->name(node->name() + "_Quantize");
+ // DESIGN NOTE: Why use '_FQ_Quantize' instead of '_Quantize'?
+ // '_Quantize' is used in mixed-precision quantization
+ // We add '_FQ' to distinguish Op from mixed-precision quantization
+ quantize->name(node->name() + "_FQ_Quantize");
quantize->dtype(node->dtype());
quantize->rank(node->rank());
for (uint32_t i = 0; i < node->rank(); i++)
@@ -50,7 +53,10 @@ luci::CircleQuantize *create_quantize(luci::CircleNode *node)
luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
{
auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
- dequantize->name(node->name() + "_Dequantize");
+ // DESIGN NOTE: Why use '_FQ_Dequantize' instead of '_Dequantize'?
+ // '_Dequantize' is used in mixed-precision quantization
+ // We add '_FQ' to distinguish Op from mixed-precision quantization
+ dequantize->name(node->name() + "_FQ_Dequantize");
dequantize->dtype(loco::DataType::FLOAT32);
dequantize->rank(node->rank());
for (uint32_t i = 0; i < node->rank(); i++)
@@ -184,6 +190,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
// For non-const activation, insert Quantize-Dequantize Ops
// and dequantize the node
+ void visit(luci::CircleAbs *node) { fq_activation(node); }
void visit(luci::CircleAdd *node) { fq_activation(node); }
void visit(luci::CircleAveragePool2D *node) { fq_activation(node); }
void visit(luci::CircleBatchMatMul *node) { fq_activation(node); }
@@ -201,6 +208,7 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
void visit(luci::CirclePad *node) { fq_activation(node); }
void visit(luci::CirclePRelu *node) { fq_activation(node); }
void visit(luci::CircleMean *node) { fq_activation(node); }
+ void visit(luci::CircleReduceProd *node) { fq_activation(node); }
void visit(luci::CircleReduceMax *node) { fq_activation(node); }
void visit(luci::CircleRelu *node) { fq_activation(node); }
void visit(luci::CircleRelu6 *node) { fq_activation(node); }
@@ -216,15 +224,20 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
// (dtype will be automatically updated by type inference)
void visit(luci::CircleCast *) {}
void visit(luci::CircleConcatenation *) {}
+ void visit(luci::CircleDepthToSpace *) {}
void visit(luci::CircleGather *) {}
void visit(luci::CircleSlice *) {}
void visit(luci::CircleStridedSlice *) {}
void visit(luci::CircleReshape *) {}
+ void visit(luci::CircleSpaceToDepth *) {}
void visit(luci::CircleSplit *) {}
void visit(luci::CircleSplitOut *) {}
void visit(luci::CircleSplitV *) {}
void visit(luci::CircleSplitVOut *) {}
void visit(luci::CircleTranspose *) {}
+ void visit(luci::CirclePack *) {}
+ void visit(luci::CircleUnpack *) {}
+ void visit(luci::CircleUnpackOut *) {}
// For Ops that return index, fake quantization is unnecessary
void visit(luci::CircleArgMax *) {}
diff --git a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
index 0734e0778..5df1b72dc 100644
--- a/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
+++ b/compiler/luci/pass/src/ExpandBroadcastConstPass.test.cpp
@@ -19,6 +19,8 @@
#include <luci/IR/CircleNodes.h>
+#include <limits> // std::numeric_limits
+
#include <gtest/gtest.h>
namespace
diff --git a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp
index 6e423e3d9..33f9f1d77 100644
--- a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp
+++ b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.cpp
@@ -23,6 +23,8 @@
#include <luci/Log.h>
+#include <limits> // std::numeric_limits
+
namespace
{
diff --git a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp
index b1ef56833..36cae0437 100644
--- a/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp
+++ b/compiler/luci/pass/src/FoldDepthwiseConv2DPass.test.cpp
@@ -19,6 +19,8 @@
#include <luci/IR/CircleNodes.h>
+#include <limits> // std::numeric_limits
+
#include <gtest/gtest.h>
namespace
diff --git a/compiler/luci/pass/src/FoldFullyConnectedPass.cpp b/compiler/luci/pass/src/FoldFullyConnectedPass.cpp
new file mode 100644
index 000000000..a3bca7eda
--- /dev/null
+++ b/compiler/luci/pass/src/FoldFullyConnectedPass.cpp
@@ -0,0 +1,198 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldFullyConnectedPass.h"
+
+#include <tensorflow/lite/kernels/internal/reference/fully_connected.h>
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/AttrFusedActFunc.h>
+
+#include <luci/Log.h>
+
+#include <limits> // std::numeric_limits
+
+namespace
+{
+
+bool set_kernel_parameters(tflite::FullyConnectedParams *params, luci::CircleFullyConnected *node)
+{
+ switch (node->fusedActivationFunction())
+ {
+ case luci::FusedActFunc::NONE:
+ case luci::FusedActFunc::TANH:
+ params->float_activation_min = std::numeric_limits<float>::lowest();
+ params->float_activation_max = std::numeric_limits<float>::max();
+ break;
+ case luci::FusedActFunc::RELU:
+ params->float_activation_min = 0;
+ params->float_activation_max = std::numeric_limits<float>::max();
+ break;
+ case luci::FusedActFunc::RELU_N1_TO_1:
+ params->float_activation_min = -1;
+ params->float_activation_max = 1;
+ break;
+ case luci::FusedActFunc::RELU6:
+ params->float_activation_min = 0;
+ params->float_activation_max = 6;
+ break;
+ default:
+ {
+ LOGGER(l);
+ WARN(l) << "Unsupported activation: " << uint32_t(node->fusedActivationFunction());
+ return false;
+ }
+ }
+
+ assert(node->weights_format() ==
+ luci::CircleFullyConnected::WeightsFormat::DEFAULT); // FIX_CALLER_UNLESS
+ params->weights_format = tflite::FullyConnectedWeightsFormat::kDefault;
+
+ return true;
+}
+
+#define RETURN_FALSE_UNLESS(cond) \
+ if (not(cond)) \
+ return false;
+
+/**
+ * Fold FullyConnected with constant input and filter into a constant tensor
+ *
+ * BEFORE
+ *
+ * [CircleConst] [CircleConst]
+ * | |
+ * [CircleFullyConnected]
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ */
+bool fold_fully_connected(luci::CircleFullyConnected *node)
+{
+ RETURN_FALSE_UNLESS(node != nullptr);
+
+ LOGGER(l);
+
+ auto const input = dynamic_cast<luci::CircleConst *>(node->input());
+ auto const weights = dynamic_cast<luci::CircleConst *>(node->weights());
+ auto const bias = dynamic_cast<luci::CircleConst *>(node->bias());
+ auto const no_bias = dynamic_cast<luci::CircleOutputExclude *>(node->bias());
+
+ RETURN_FALSE_UNLESS(input != nullptr);
+ RETURN_FALSE_UNLESS(weights != nullptr);
+ RETURN_FALSE_UNLESS(node->weights_format() == luci::CircleFullyConnected::WeightsFormat::DEFAULT);
+ RETURN_FALSE_UNLESS(bias != nullptr or no_bias != nullptr);
+
+ RETURN_FALSE_UNLESS(input->dtype() == loco::DataType::FLOAT32);
+ RETURN_FALSE_UNLESS(weights->dtype() == loco::DataType::FLOAT32);
+ if (bias)
+ RETURN_FALSE_UNLESS(bias->dtype() == loco::DataType::FLOAT32);
+
+ auto const input_elems = input->size<loco::DataType::FLOAT32>();
+
+ RETURN_FALSE_UNLESS(weights->rank() == 2);
+ RETURN_FALSE_UNLESS(input_elems % weights->dim(1).value() == 0);
+ auto const batch_size = input_elems / weights->dim(1).value();
+ auto const num_units = weights->dim(0).value();
+
+ if (bias)
+ RETURN_FALSE_UNLESS(bias->size<loco::DataType::FLOAT32>() == num_units);
+
+ tflite::FullyConnectedParams params{};
+ if (!set_kernel_parameters(&params, node))
+ return false; // Unsupported kernel parameter values
+
+ std::vector<uint32_t> output_shape;
+ if (node->keep_num_dims() == false)
+ {
+ output_shape.push_back(batch_size);
+ output_shape.push_back(num_units);
+ }
+ else
+ {
+ output_shape.resize(input->rank());
+ for (uint32_t i = 0; i < input->rank(); i++)
+ output_shape[i] = input->dim(i).value();
+ output_shape[input->rank() - 1] = num_units;
+ }
+
+ auto constant = node->graph()->nodes()->create<luci::CircleConst>();
+ {
+ constant->name(node->name());
+ constant->dtype(node->dtype());
+ constant->rank(node->rank());
+ constant->shape_status(luci::ShapeStatus::VALID);
+ uint32_t num_elem = 1;
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ {
+ constant->dim(i).set(node->dim(i).value());
+ num_elem *= node->dim(i).value();
+ }
+ constant->size<loco::DataType::FLOAT32>(num_elem);
+ }
+
+ auto tensor_shape = [](luci::CircleNode *node) {
+ if (node == nullptr)
+ return tflite::RuntimeShape();
+
+ tflite::RuntimeShape runtime_shape(node->rank());
+ for (uint32_t i = 0; i < node->rank(); ++i)
+ runtime_shape.SetDim(i, node->dim(i).value());
+ return runtime_shape;
+ };
+
+ auto tensor_data = [](luci::CircleConst *node) -> float * {
+ if (node == nullptr)
+ return nullptr;
+
+ return &node->at<loco::DataType::FLOAT32>(0);
+ };
+
+ tflite::reference_ops::FullyConnected(
+ params, tensor_shape(input), tensor_data(input), tensor_shape(weights), tensor_data(weights),
+ tensor_shape(bias), tensor_data(bias), tensor_shape(constant), tensor_data(constant));
+
+ loco::replace(node).with(constant);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * Constant Folding for FullyConnected Op
+ **/
+bool FoldFullyConnectedPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto fc = dynamic_cast<CircleFullyConnected *>(node);
+
+ if (fold_fully_connected(fc))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
+
+#undef RETURN_FALSE_UNLESS
diff --git a/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp
new file mode 100644
index 000000000..a8e64a24b
--- /dev/null
+++ b/compiler/luci/pass/src/FoldFullyConnectedPass.test.cpp
@@ -0,0 +1,160 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldFullyConnectedPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <limits> // std::numeric_limits
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+/**
+ * Graph has an FullyConnected Op with constant inputs
+ *
+ * BEFORE
+ *
+ * [CircleConst] [CircleConst]
+ * | |
+ * [CircleFullyConnected]
+ *
+ * AFTER
+ *
+ * [CircleConst]
+ */
+class FoldFullyConnectedTest : public luci::ConstantFoldingTestGraph, public ::testing::Test
+{
+#define INPUT_DIM 80
+#define NUM_UNITS 32
+
+public:
+ FoldFullyConnectedTest() : luci::ConstantFoldingTestGraph({INPUT_DIM}, loco::DataType::FLOAT32)
+ {
+ _fc = _g.nodes()->create<luci::CircleFullyConnected>();
+ _fc_input = _g.nodes()->create<luci::CircleConst>();
+ _fc_weights = _g.nodes()->create<luci::CircleConst>();
+ _fc_bias = _g.nodes()->create<luci::CircleConst>();
+
+ _fc->dtype(loco::DataType::FLOAT32);
+ _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _fc->input(_fc_input);
+ _fc->weights(_fc_weights);
+ _fc->bias(_fc_bias);
+ _fc->shape({NUM_UNITS});
+ _fc->weights_format(luci::CircleFullyConnected::WeightsFormat::DEFAULT);
+ _fc->keep_num_dims(true);
+
+ _fc_input->dtype(loco::DataType::FLOAT32);
+ _fc_input->shape({INPUT_DIM});
+ _fc_input->size<loco::DataType::FLOAT32>(INPUT_DIM);
+
+ _fc_weights->dtype(loco::DataType::FLOAT32);
+ _fc_weights->shape({NUM_UNITS, INPUT_DIM});
+ _fc_weights->size<loco::DataType::FLOAT32>(NUM_UNITS * INPUT_DIM);
+
+ _fc_bias->dtype(loco::DataType::FLOAT32);
+ _fc_bias->shape({1, NUM_UNITS});
+ _fc_bias->size<loco::DataType::FLOAT32>(NUM_UNITS);
+
+ for (uint32_t i = 0; i < INPUT_DIM; ++i)
+ _fc_input->at<loco::DataType::FLOAT32>(i) = 1.0;
+
+ for (uint32_t i = 0; i < INPUT_DIM * NUM_UNITS; ++i)
+ _fc_weights->at<loco::DataType::FLOAT32>(i) = 1.0;
+
+ for (uint32_t i = 0; i < NUM_UNITS; ++i)
+ _fc_bias->at<loco::DataType::FLOAT32>(i) = 0.0;
+
+ _output->from(_fc);
+ }
+
+protected:
+ void init() final {}
+
+protected:
+ loco::Node *createFoldedPattern() final { return nullptr; }
+
+protected:
+ luci::CircleConst *getFoldedPattern() final
+ {
+ return loco::must_cast<luci::CircleConst *>(_output->from());
+ }
+
+protected:
+ luci::CircleFullyConnected *_fc = nullptr;
+ luci::CircleConst *_fc_input = nullptr;
+ luci::CircleConst *_fc_weights = nullptr;
+ luci::CircleConst *_fc_bias = nullptr;
+#undef INPUT_DIM
+#undef NUM_UNITS
+};
+
+} // namespace
+
+TEST_F(FoldFullyConnectedTest, fold_fc)
+{
+ luci::FoldFullyConnectedPass pass;
+ ASSERT_TRUE(pass.run(&_g));
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(folded_const->dtype(), loco::DataType::FLOAT32);
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(32, folded_const->dim(0));
+ EXPECT_EQ(32, folded_const->size<loco::DataType::FLOAT32>());
+ for (uint32_t i = 0; i < 32; ++i)
+ EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(i), 80,
+ std::numeric_limits<float>::min());
+}
+
+TEST_F(FoldFullyConnectedTest, fold_fc_no_bias)
+{
+ auto no_bias = _g.nodes()->create<luci::CircleOutputExclude>();
+ _fc->bias(no_bias);
+
+ luci::FoldFullyConnectedPass pass;
+ ASSERT_TRUE(pass.run(&_g));
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(1, folded_const->rank());
+ EXPECT_EQ(32, folded_const->dim(0));
+ EXPECT_EQ(32, folded_const->size<loco::DataType::FLOAT32>());
+ for (uint32_t i = 0; i < 32; ++i)
+ EXPECT_NEAR(folded_const->at<loco::DataType::FLOAT32>(i), 80,
+ std::numeric_limits<float>::min());
+}
+
+TEST_F(FoldFullyConnectedTest, fold_fc_NEG)
+{
+ auto new_fc = _g.nodes()->create<luci::CircleFullyConnected>();
+ _fc->input(new_fc);
+
+ luci::FoldFullyConnectedPass pass;
+ ASSERT_FALSE(pass.run(&_g));
+}
+
+TEST_F(FoldFullyConnectedTest, fold_fc_weight_format_NEG)
+{
+ auto new_fc = _g.nodes()->create<luci::CircleFullyConnected>();
+ _fc->weights_format(luci::CircleFullyConnected::WeightsFormat::SHUFFLED4x16INT8);
+
+ luci::FoldFullyConnectedPass pass;
+ ASSERT_FALSE(pass.run(&_g));
+}
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
index bc09abee2..3494a6e60 100644
--- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
@@ -76,6 +76,26 @@ luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape)
return new_reshape;
}
+bool forward_reshape(luci::CircleReshape *reshape, luci::CircleAbs *abs)
+{
+ assert(reshape != nullptr); // FIX_CALLER_UNLESS
+ assert(abs != nullptr); // FIX_CALLER_UNLESS
+
+ auto new_reshape = create_cloned_reshape(reshape);
+ if (not new_reshape)
+ return false;
+
+ // reconnect network
+ loco::replace(abs).with(new_reshape);
+ abs->x(reshape->tensor());
+ new_reshape->tensor(abs);
+
+ // Do shape inference for this node again.
+ abs->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+}
+
bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg)
{
assert(reshape != nullptr);
@@ -136,6 +156,14 @@ protected:
return false;
}
+ bool visit(luci::CircleAbs *node)
+ {
+ auto reshape = as_reshape(node->x());
+ if (reshape == nullptr)
+ return false;
+ return forward_reshape(reshape, node);
+ }
+
bool visit(luci::CircleNeg *node)
{
auto reshape = as_reshape(node->x());
diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp
new file mode 100644
index 000000000..c76d73344
--- /dev/null
+++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp
@@ -0,0 +1,366 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ForwardTransposeOpPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Service/CircleNodeClone.h>
+
+using namespace luci;
+
+namespace
+{
+
+// Create new Transpose Op including perm
+// Return nullptr if failed
+CircleTranspose *create_cloned_transpose(CircleTranspose *transpose)
+{
+ assert(transpose != nullptr); // FIX_CALLER_UNLESS
+
+ auto perm = dynamic_cast<CircleConst *>(transpose->perm());
+ if (not perm)
+ return nullptr;
+
+ CircleConst *cloned_perm = clone(perm);
+ if (cloned_perm == nullptr)
+ return nullptr;
+
+ cloned_perm->name(perm->name() + "_C");
+ luci::add_origin(cloned_perm, luci::get_origin(perm));
+
+ auto cloned_node = clone_node(transpose, transpose->graph());
+ if (cloned_node == nullptr)
+ return nullptr;
+
+ auto new_transpose = loco::must_cast<luci::CircleTranspose *>(cloned_node);
+ new_transpose->perm(cloned_perm);
+ new_transpose->name(transpose->name() + "_C");
+ luci::add_origin(new_transpose, luci::get_origin(transpose));
+
+ return new_transpose;
+}
+
+uint32_t cal_offset(const std::vector<uint32_t> &shape, const std::vector<uint32_t> &indices)
+{
+ assert(shape.size() == indices.size()); // FIX_CALLER_UNLESS
+
+ uint32_t offset = 0;
+ for (uint32_t i = 0; i < indices.size(); i++)
+ {
+ uint32_t index = indices[i];
+ for (uint32_t j = shape.size() - 1; j > i; j--)
+ {
+ index *= shape[j];
+ }
+ offset += index;
+ }
+ return offset;
+}
+
+// Return reverse-transpose of 'node'
+// i.e., Transpose(return value) = node
+CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t)
+{
+ assert(node->rank() == t.size()); // FIX_CALLER_UNLESS
+ assert(node->rank() == 4); // FIX_CALLER_UNLESS
+
+ std::vector<uint32_t> orig_shape(node->rank());
+ std::vector<uint32_t> new_shape(node->rank());
+
+ for (uint32_t i = 0; i < node->rank(); i++)
+ {
+ assert(t[i] < node->rank()); // FIX_CALLER_UNLESS
+
+ orig_shape[i] = node->dim(i).value();
+ new_shape[t[i]] = node->dim(i).value();
+ }
+
+ auto clone_const = clone(node);
+ for (uint32_t i = 0; i < node->rank(); i++)
+ clone_const->dim(i).set(new_shape[i]);
+
+ clone_const->name(clone_const->name() + "_r_transposed");
+ add_origin(clone_const, luci::get_origin(node));
+
+ for (uint32_t n = 0; n < clone_const->dim(0).value(); n++)
+ {
+ for (uint32_t h = 0; h < clone_const->dim(1).value(); h++)
+ {
+ for (uint32_t w = 0; w < clone_const->dim(2).value(); w++)
+ {
+ for (uint32_t c = 0; c < clone_const->dim(3).value(); c++)
+ {
+ std::vector<uint32_t> new_indices{n, h, w, c};
+ std::vector<uint32_t> orig_indices{new_indices[t[0]], new_indices[t[1]],
+ new_indices[t[2]], new_indices[t[3]]};
+
+ const auto data = node->at<loco::DataType::FLOAT32>(cal_offset(orig_shape, orig_indices));
+ clone_const->at<loco::DataType::FLOAT32>(cal_offset(new_shape, new_indices)) = data;
+ }
+ }
+ }
+ }
+
+ return clone_const;
+}
+
+bool check_rank_four(const CircleConst *c) { return c->rank() == 4; }
+
+// Return true if below conditions are met
+// 1. t->perm() is CircleConst
+// 2. t->perm() is S32
+bool check_perm(const CircleTranspose *t)
+{
+ auto perm = dynamic_cast<CircleConst *>(t->perm());
+ if (not perm)
+ return false;
+
+ switch (perm->dtype())
+ {
+ case loco::DataType::S32:
+ for (uint32_t i = 0; i < perm->size<loco::DataType::S32>(); i++)
+ {
+ auto data = perm->at<loco::DataType::S32>(i);
+ // TODO Support not normalized index
+ if (data < 0 or data >= static_cast<int32_t>(t->rank()))
+ return false;
+ }
+ break;
+ // TODO Support S64 data type
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+#define RETURN_FALSE_UNLESS(COND) \
+ if (not(COND)) \
+ return false;
+
+// Elementwise Binary Operator with const
+class EBOWithConstPattern final : public CircleNodeMutableVisitor<bool>
+{
+private:
+ template <typename CIRCLE_OP_PTR> bool has_pattern(CIRCLE_OP_PTR node)
+ {
+ if (auto x = dynamic_cast<luci::CircleConst *>(node->x()))
+ {
+ if (auto y = dynamic_cast<luci::CircleTranspose *>(node->y()))
+ {
+ RETURN_FALSE_UNLESS(check_rank_four(x));
+ RETURN_FALSE_UNLESS(check_perm(y));
+
+ auto new_const = gen_new_const(y, x);
+ assert(new_const); // FIX_ME_UNLESS
+
+ auto new_transpose = create_cloned_transpose(y);
+ assert(new_transpose); // FIX_ME_UNLESS
+
+ // Reconnect network
+ node->x(new_const);
+ node->y(y->a());
+ loco::replace(node).with(new_transpose);
+ new_transpose->a(node);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+ }
+ }
+
+ if (auto y = dynamic_cast<luci::CircleConst *>(node->y()))
+ {
+ if (auto x = dynamic_cast<luci::CircleTranspose *>(node->x()))
+ {
+ RETURN_FALSE_UNLESS(check_rank_four(y));
+ RETURN_FALSE_UNLESS(check_perm(x));
+
+ auto new_const = gen_new_const(x, y);
+ assert(new_const); // FIX_ME_UNLESS
+
+ auto new_transpose = create_cloned_transpose(x);
+ assert(new_transpose); // FIX_ME_UNLESS
+
+ // Reconnect network
+ node->y(new_const);
+ node->x(x->a());
+ loco::replace(node).with(new_transpose);
+ new_transpose->a(node);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+public:
+ // Default
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleAdd *node) { return has_pattern(node); }
+
+ bool visit(luci::CircleMul *node) { return has_pattern(node); }
+
+private:
+ // Return a new const node after Tranpose Op is forwarded
+ // Return nullptr if unsupported cases
+ CircleConst *gen_new_const(CircleTranspose *t, CircleConst *c)
+ {
+ const auto perm = dynamic_cast<CircleConst *>(t->perm());
+
+ // Only support constant perm
+ if (not perm)
+ return nullptr;
+
+ std::vector<uint32_t> perm_data;
+ switch (perm->dtype())
+ {
+ case loco::DataType::S32:
+ for (uint32_t i = 0; i < perm->size<loco::DataType::S32>(); i++)
+ {
+ auto data = perm->at<loco::DataType::S32>(i);
+ assert(data >= 0 and data < static_cast<int32_t>(t->rank()));
+ perm_data.emplace_back(static_cast<uint32_t>(data));
+ }
+ break;
+ // TODO Support S64 data type
+ default:
+ return nullptr;
+ }
+
+ assert(perm_data.size() == t->rank()); // FIX_CALLER_UNLESS
+
+ return reverse_transposed(c, perm_data);
+ }
+};
+
+// Elementwise Unary Operator
+class EwUnaryPattern final : public CircleNodeMutableVisitor<bool>
+{
+private:
+ // input is 'x'
+ template <typename CIRCLE_OP_PTR> bool has_pattern_x(CIRCLE_OP_PTR node)
+ {
+ if (auto x = dynamic_cast<luci::CircleTranspose *>(node->x()))
+ {
+ RETURN_FALSE_UNLESS(check_perm(x));
+
+ auto new_transpose = create_cloned_transpose(x);
+ assert(new_transpose); // FIX_ME_UNLESS
+
+ // Reconnect network
+ node->x(x->a());
+ loco::replace(node).with(new_transpose);
+ new_transpose->a(node);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+ }
+
+ return false;
+ }
+
+public:
+ // Default
+ bool visit(luci::CircleNode *) { return false; }
+
+ bool visit(luci::CircleAbs *node) { return has_pattern_x(node); }
+};
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ * |
+ * [CircleNode] [CircleConst]
+ * | /
+ * [CircleTranspose] [CircleConst]
+ * / | /
+ * [CircleNode] [(BinaryOp)]
+ * | | \
+ * | | [CircleNode]
+ * | | |
+ *
+ * BinaryOp: CircleAdd, CircleMul, ...
+ *
+ * |
+ * [CircleNode] [CircleConst]
+ * | /
+ * [CircleTranspose]
+ * / |
+ * [CircleNode] [(UnaryOp)]
+ * | | \
+ * | | [CircleNode]
+ * | | |
+ *
+ * UnaryOp: CircleAbs, ...
+ *
+ * AFTER
+ * |
+ * [CircleConst] [CircleNode] [CircleConst(updated)]
+ * | / | /
+ * [CircleTranspose] [(BinaryOp)] [CircleConst]
+ * | | /
+ * [CircleNode] [CircleTranspose]
+ * | | \
+ * | | [CircleNode]
+ * | | |
+ *
+ * |
+ * [CircleConst] [CircleNode]
+ * | / |
+ * [CircleTranspose] [(UnaryOp)] [CircleConst]
+ * | | /
+ * [CircleNode] [CircleTranspose]
+ * | | \
+ * | | [CircleNode]
+ * | | |
+ *
+ * Note: new [CircleTranspose] is added after [(BinaryOp)]
+ */
+bool ForwardTransposeOpPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ EBOWithConstPattern eboc;
+ EwUnaryPattern ewu;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->accept(&eboc))
+ changed = true;
+ else if (circle_node->accept(&ewu))
+ changed = true;
+ }
+ return changed;
+}
+
+#undef RETURN_FALSE_UNLESS
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp
new file mode 100644
index 000000000..2d061c2a3
--- /dev/null
+++ b/compiler/luci/pass/src/ForwardTransposeOpPass.test.cpp
@@ -0,0 +1,524 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ForwardTransposeOpPass.h"
+#include "luci/Pass/CircleShapeInferencePass.h"
+
+#include <logo/Phase.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+#include <vector>
+
+namespace
+{
+
+using namespace luci::test;
+
+template <typename T> class TransposeBinaryOpGraphlet
+{
+public:
+ TransposeBinaryOpGraphlet() = default;
+
+public:
+ virtual ~TransposeBinaryOpGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 perm)
+ {
+ std::vector<uint32_t> shape_in_v = shape_in;
+ std::vector<uint32_t> perm_v = perm;
+
+ assert(shape_in_v.size() == perm_v.size()); // FIX_CALLER_UNLESS
+
+ _perm = g->nodes()->create<luci::CircleConst>();
+ _const = g->nodes()->create<luci::CircleConst>();
+ _transpose = g->nodes()->create<luci::CircleTranspose>();
+ _binary = g->nodes()->create<T>();
+
+ _perm->dtype(loco::DataType::S32);
+ _perm->rank(1);
+ _perm->dim(0).set(perm_v.size());
+ _perm->shape_status(luci::ShapeStatus::VALID);
+
+ _const->dtype(loco::DataType::FLOAT32);
+ _const->rank(shape_in_v.size());
+ for (uint32_t i = 0; i < shape_in_v.size(); i++)
+ _const->dim(i).set(shape_in_v[perm_v[i]]);
+ _const->shape_status(luci::ShapeStatus::VALID);
+
+ // values
+ const auto size = perm_v.size();
+ _perm->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ _perm->at<loco::DataType::S32>(i) = perm_v[i];
+
+ uint32_t elems = 1;
+ for (uint32_t i = 0; i < size; i++)
+ elems *= shape_in_v[i];
+
+ _const->size<loco::DataType::FLOAT32>(elems);
+ for (uint32_t i = 0; i < elems; i++)
+ _const->at<loco::DataType::FLOAT32>(i) = i;
+
+ _perm->name("transpose_perm");
+ _transpose->name("transpose");
+ _binary->name("binary");
+ }
+
+ luci::CircleTranspose *transpose(void) { return _transpose; }
+
+ void switch_xy(void)
+ {
+ assert(_binary); // FIX_CALLER_UNLESS
+ auto temp = _binary->x();
+ _binary->x(_binary->y());
+ _binary->y(temp);
+ }
+
+protected:
+ luci::CircleTranspose *_transpose = nullptr;
+ T *_binary = nullptr;
+ luci::CircleConst *_perm = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+using TransposeAddGraphlet = TransposeBinaryOpGraphlet<luci::CircleAdd>;
+using TransposeMulGraphlet = TransposeBinaryOpGraphlet<luci::CircleMul>;
+
+class ForwardTransposeToAddGraph : public TestIOGraph, public TransposeAddGraphlet
+{
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ TransposeAddGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _transpose->a(input());
+ _transpose->perm(_perm);
+ _binary->x(_transpose);
+ _binary->y(_const);
+
+ output()->from(_binary);
+ }
+};
+
+class ForwardTransposeToAddInvalidGraph : public TestIOGraph, public TransposeAddGraphlet
+{
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ TransposeAddGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _transpose->a(input());
+ _transpose->perm(_perm);
+ _binary->x(_transpose);
+ _binary->y(_transpose);
+
+ output()->from(_binary);
+ }
+};
+
+class ForwardTransposeToMulGraph : public TestIOGraph, public TransposeMulGraphlet
+{
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ TransposeMulGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _transpose->a(input());
+ _transpose->perm(_perm);
+ _binary->x(_transpose);
+ _binary->y(_const);
+
+ output()->from(_binary);
+ }
+};
+
+void run_phase(loco::Graph *g)
+{
+ logo::Phase phase;
+
+ // Default passes.
+ phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
+
+ // Pass to test
+ phase.emplace_back(std::make_unique<luci::ForwardTransposeOpPass>());
+
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.run(phase);
+}
+
+class ForwardTransposeToAddGraphTest : public ::testing::Test
+{
+public:
+ void run_pass(void) { run_phase(_graph.g()); }
+
+protected:
+ ForwardTransposeToAddGraph _graph;
+};
+
+class ForwardTransposeToAddGraphNegTest : public ::testing::Test
+{
+public:
+ void run_pass(void) { run_phase(_graph.g()); }
+
+protected:
+ ForwardTransposeToAddInvalidGraph _graph;
+};
+
+class ForwardTransposeToMulGraphTest : public ::testing::Test
+{
+public:
+ void run_pass(void) { run_phase(_graph.g()); }
+
+protected:
+ ForwardTransposeToMulGraph _graph;
+};
+
+} // namespace
+
+TEST_F(ForwardTransposeToAddGraphTest, forward_add_xy)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ run_pass();
+
+ auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
+ EXPECT_NE(nullptr, transpose);
+ EXPECT_EQ(4, transpose->rank());
+ EXPECT_EQ(1, transpose->dim(0).value());
+ EXPECT_EQ(1, transpose->dim(1).value());
+ EXPECT_EQ(51, transpose->dim(2).value());
+ EXPECT_EQ(64, transpose->dim(3).value());
+
+ auto add = dynamic_cast<luci::CircleAdd *>(transpose->a());
+ EXPECT_NE(nullptr, add);
+ EXPECT_EQ(4, add->rank());
+ EXPECT_EQ(1, add->dim(0).value());
+ EXPECT_EQ(64, add->dim(1).value());
+ EXPECT_EQ(51, add->dim(2).value());
+ EXPECT_EQ(1, add->dim(3).value());
+
+ auto add_const = dynamic_cast<luci::CircleConst *>(add->y());
+ EXPECT_NE(nullptr, add_const);
+ EXPECT_EQ(4, add_const->rank());
+ EXPECT_EQ(1, add_const->dim(0).value());
+ EXPECT_EQ(64, add_const->dim(1).value());
+ EXPECT_EQ(51, add_const->dim(2).value());
+ EXPECT_EQ(1, add_const->dim(3).value());
+}
+
+TEST_F(ForwardTransposeToAddGraphTest, forward_add_yx)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+ _graph.switch_xy();
+
+ run_pass();
+
+ auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
+ EXPECT_NE(nullptr, transpose);
+ EXPECT_EQ(4, transpose->rank());
+ EXPECT_EQ(1, transpose->dim(0).value());
+ EXPECT_EQ(1, transpose->dim(1).value());
+ EXPECT_EQ(51, transpose->dim(2).value());
+ EXPECT_EQ(64, transpose->dim(3).value());
+
+ auto mul = dynamic_cast<luci::CircleAdd *>(transpose->a());
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(4, mul->rank());
+ EXPECT_EQ(1, mul->dim(0).value());
+ EXPECT_EQ(64, mul->dim(1).value());
+ EXPECT_EQ(51, mul->dim(2).value());
+ EXPECT_EQ(1, mul->dim(3).value());
+
+ auto mul_const = dynamic_cast<luci::CircleConst *>(mul->x());
+ EXPECT_NE(nullptr, mul_const);
+ EXPECT_EQ(4, mul_const->rank());
+ EXPECT_EQ(1, mul_const->dim(0).value());
+ EXPECT_EQ(64, mul_const->dim(1).value());
+ EXPECT_EQ(51, mul_const->dim(2).value());
+ EXPECT_EQ(1, mul_const->dim(3).value());
+}
+
+TEST_F(ForwardTransposeToMulGraphTest, forward_mul_xy)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ run_pass();
+
+ auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
+ EXPECT_NE(nullptr, transpose);
+ EXPECT_EQ(4, transpose->rank());
+ EXPECT_EQ(1, transpose->dim(0).value());
+ EXPECT_EQ(1, transpose->dim(1).value());
+ EXPECT_EQ(51, transpose->dim(2).value());
+ EXPECT_EQ(64, transpose->dim(3).value());
+
+ auto mul = dynamic_cast<luci::CircleMul *>(transpose->a());
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(4, mul->rank());
+ EXPECT_EQ(1, mul->dim(0).value());
+ EXPECT_EQ(64, mul->dim(1).value());
+ EXPECT_EQ(51, mul->dim(2).value());
+ EXPECT_EQ(1, mul->dim(3).value());
+
+ auto mul_const = dynamic_cast<luci::CircleConst *>(mul->y());
+ EXPECT_NE(nullptr, mul_const);
+ EXPECT_EQ(4, mul_const->rank());
+ EXPECT_EQ(1, mul_const->dim(0).value());
+ EXPECT_EQ(64, mul_const->dim(1).value());
+ EXPECT_EQ(51, mul_const->dim(2).value());
+ EXPECT_EQ(1, mul_const->dim(3).value());
+}
+
+TEST_F(ForwardTransposeToMulGraphTest, forward_mul_yx)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+ _graph.switch_xy();
+
+ run_pass();
+
+ auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
+ EXPECT_NE(nullptr, transpose);
+ EXPECT_EQ(4, transpose->rank());
+ EXPECT_EQ(1, transpose->dim(0).value());
+ EXPECT_EQ(1, transpose->dim(1).value());
+ EXPECT_EQ(51, transpose->dim(2).value());
+ EXPECT_EQ(64, transpose->dim(3).value());
+
+ auto mul = dynamic_cast<luci::CircleMul *>(transpose->a());
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(4, mul->rank());
+ EXPECT_EQ(1, mul->dim(0).value());
+ EXPECT_EQ(64, mul->dim(1).value());
+ EXPECT_EQ(51, mul->dim(2).value());
+ EXPECT_EQ(1, mul->dim(3).value());
+
+ auto mul_const = dynamic_cast<luci::CircleConst *>(mul->x());
+ EXPECT_NE(nullptr, mul_const);
+ EXPECT_EQ(4, mul_const->rank());
+ EXPECT_EQ(1, mul_const->dim(0).value());
+ EXPECT_EQ(64, mul_const->dim(1).value());
+ EXPECT_EQ(51, mul_const->dim(2).value());
+ EXPECT_EQ(1, mul_const->dim(3).value());
+}
+
+TEST_F(ForwardTransposeToAddGraphTest, forward_transpose_add_NEG)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ // Remove add
+ _graph.output()->from(_graph.transpose());
+
+ luci::ForwardTransposeOpPass pass;
+ EXPECT_FALSE(pass.run(_graph.g()));
+}
+
+TEST_F(ForwardTransposeToAddGraphNegTest, forward_transpose_add_non_const_NEG)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ luci::ForwardTransposeOpPass pass;
+ EXPECT_FALSE(pass.run(_graph.g()));
+}
+
+TEST_F(ForwardTransposeToMulGraphTest, forward_transpose_mul_NEG)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ // Remove mul
+ _graph.output()->from(_graph.transpose());
+
+ luci::ForwardTransposeOpPass pass;
+ EXPECT_FALSE(pass.run(_graph.g()));
+}
+
+// Unary
+
+namespace
+{
+
+template <typename T> class TransposeUnaryOpGraphlet
+{
+public:
+ TransposeUnaryOpGraphlet() = default;
+
+public:
+ virtual ~TransposeUnaryOpGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 perm)
+ {
+ std::vector<uint32_t> shape_in_v = shape_in;
+ std::vector<uint32_t> perm_v = perm;
+
+ assert(shape_in_v.size() == perm_v.size()); // FIX_CALLER_UNLESS
+
+ _perm = g->nodes()->create<luci::CircleConst>();
+ _const = g->nodes()->create<luci::CircleConst>();
+ _transpose = g->nodes()->create<luci::CircleTranspose>();
+ _unary = g->nodes()->create<T>();
+
+ _perm->dtype(loco::DataType::S32);
+ _perm->rank(1);
+ _perm->dim(0).set(perm_v.size());
+ _perm->shape_status(luci::ShapeStatus::VALID);
+
+ _const->dtype(loco::DataType::FLOAT32);
+ _const->rank(shape_in_v.size());
+ for (uint32_t i = 0; i < shape_in_v.size(); i++)
+ _const->dim(i).set(shape_in_v[perm_v[i]]);
+ _const->shape_status(luci::ShapeStatus::VALID);
+
+ // values
+ const auto size = perm_v.size();
+ _perm->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ _perm->at<loco::DataType::S32>(i) = perm_v[i];
+
+ uint32_t elems = 1;
+ for (uint32_t i = 0; i < size; i++)
+ elems *= shape_in_v[i];
+
+ _const->size<loco::DataType::FLOAT32>(elems);
+ for (uint32_t i = 0; i < elems; i++)
+ _const->at<loco::DataType::FLOAT32>(i) = i;
+
+ _perm->name("transpose_perm");
+ _transpose->name("transpose");
+ _unary->name("_unary");
+ }
+
+ luci::CircleTranspose *transpose(void) { return _transpose; }
+
+protected:
+ luci::CircleTranspose *_transpose = nullptr;
+ T *_unary = nullptr;
+ luci::CircleConst *_perm = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
+using TransposeAbsGraphlet = TransposeUnaryOpGraphlet<luci::CircleAbs>;
+
+class ForwardTransposeToAbsGraph : public TestIOGraph, public TransposeAbsGraphlet
+{
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ TransposeAbsGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _transpose->a(input());
+ _transpose->perm(_perm);
+ _unary->x(_transpose);
+
+ output()->from(_unary);
+ }
+};
+
+class ForwardTransposeToAbsInvalidGraph : public TestIOGraph, public TransposeAbsGraphlet
+{
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ TransposeAbsGraphlet::init(g(), shape_in, shape_out);
+
+ _relu = g()->nodes()->create<luci::CircleRelu>();
+ _relu->dtype(loco::DataType::FLOAT32);
+ _relu->name("relu");
+
+ // connect network
+ _relu->features(input());
+ _unary->x(_relu);
+
+ output()->from(_unary);
+ }
+
+protected:
+ luci::CircleRelu *_relu = nullptr;
+};
+
+class ForwardTransposeToAbsGraphTest : public ::testing::Test
+{
+public:
+ void run_pass(void) { run_phase(_graph.g()); }
+
+protected:
+ ForwardTransposeToAbsGraph _graph;
+};
+
+class ForwardTransposeToAbsGraphNegTest : public ::testing::Test
+{
+public:
+ void run_pass(void) { run_phase(_graph.g()); }
+
+protected:
+ ForwardTransposeToAbsInvalidGraph _graph;
+};
+
+} // namespace
+
+TEST_F(ForwardTransposeToAbsGraphTest, forward_abs_x)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ run_pass();
+
+ auto transpose = dynamic_cast<luci::CircleTranspose *>(_graph.output()->from());
+ EXPECT_NE(nullptr, transpose);
+ EXPECT_EQ(4, transpose->rank());
+ EXPECT_EQ(1, transpose->dim(0).value());
+ EXPECT_EQ(1, transpose->dim(1).value());
+ EXPECT_EQ(51, transpose->dim(2).value());
+ EXPECT_EQ(64, transpose->dim(3).value());
+
+ auto abs = dynamic_cast<luci::CircleAbs *>(transpose->a());
+ EXPECT_NE(nullptr, abs);
+ EXPECT_EQ(4, abs->rank());
+ EXPECT_EQ(1, abs->dim(0).value());
+ EXPECT_EQ(64, abs->dim(1).value());
+ EXPECT_EQ(51, abs->dim(2).value());
+ EXPECT_EQ(1, abs->dim(3).value());
+}
+
+TEST_F(ForwardTransposeToAbsGraphTest, forward_transpose_abs_NEG)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ // Remove abs
+ _graph.output()->from(_graph.transpose());
+
+ luci::ForwardTransposeOpPass pass;
+ EXPECT_FALSE(pass.run(_graph.g()));
+}
+
+TEST_F(ForwardTransposeToAbsGraphNegTest, forward_transpose_abs_non_transpose_NEG)
+{
+ _graph.init({1, 64, 51, 1}, {0, 3, 2, 1});
+
+ luci::ForwardTransposeOpPass pass;
+ EXPECT_FALSE(pass.run(_graph.g()));
+}
diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
index 3cf31ed10..1d4a2e3bf 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
@@ -86,6 +86,14 @@ bool fuse_add_with_fc(luci::CircleFullyConnected *fc)
if (not(addition->dim(rank - 1) == weights->dim(0)))
return false;
+ auto bias = loco::must_cast<luci::CircleNode *>(fc->bias());
+
+ // We only support (1) constant bias (2) no bias
+ // If bias is neither (1) nor (2), it would be a feature map
+ if (bias->opcode() != luci::CircleOpcode::CIRCLECONST and
+ bias->opcode() != luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
+ return false;
+
auto fused_bias = luci::clone(addition);
// Add existing bias values
diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
index 4cc2eb599..300796594 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.test.cpp
@@ -125,6 +125,15 @@ public:
public:
luci::CircleFullyConnected *fc() { return _fc; }
+public:
+ void to_fm_bias(void)
+ {
+ assert(_fc != nullptr); // FIX_ME_UNLESS
+
+ auto new_fc = _fc->graph()->nodes()->create<luci::CircleFullyConnected>();
+ _fc->bias(new_fc);
+ }
+
protected:
luci::CircleFullyConnected *_fc = nullptr;
luci::CircleAdd *_add = nullptr;
@@ -174,3 +183,14 @@ TEST_F(FuseAddWithFullyConnectedPassTest, simple_test)
EXPECT_EQ(i, bias->at<loco::DataType::FLOAT32>(i));
}
}
+
+TEST_F(FuseAddWithFullyConnectedPassTest, fm_bias_NEG)
+{
+ g.init();
+
+ // Bias is a feature map. Add is not fused.
+ g.to_fm_bias();
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(false, ret);
+}
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index 09180d8c1..3f8f700a9 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -679,7 +679,6 @@ bool FuseBCQPass::run(luci::Module *m)
if (output_node->index() == 0 || (int)output_node->index() > original_output_cnt)
{
auto noOp = main_graph->nodes()->create<luci::CircleOutputExclude>();
- noOp->dtype(loco::DataType::FLOAT32); // TODO Remove this setting
output_node->from(noOp);
changed = true;
}
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
index e6b54df36..265a8398b 100644
--- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
@@ -23,6 +23,26 @@
namespace
{
+
+template <class CIRCLENODE>
+void replace_with_relu(luci::CircleNode *target, luci::CircleNode *feature,
+ const std::string &relu_name)
+{
+ assert(target != nullptr);
+ assert(feature != nullptr);
+
+ auto relu = target->graph()->nodes()->create<CIRCLENODE>();
+ relu->features(feature);
+ relu->name(relu_name);
+ luci::add_origin(relu, luci::get_origin(target));
+
+ replace(target).with(relu);
+}
+
+} // namespace
+
+namespace
+{
/**
* Fuse Mul-Add to TransposeConv if possible.
*
@@ -49,10 +69,10 @@ namespace
* | / / | /
* [CircleTransposeConv] [CircleAdd]
* |
- * ([CircleRelu6])
+ * ([CircleRelu]/[CircleRelu6])
* |
*
- * Note: CircleRelu6 is inserted if Add activation is ReLU6
+ * Note: CircleRelu or CircleRelu6 is inserted if Add activation is ReLU/ReLU6
*/
bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
{
@@ -80,7 +100,8 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
if (add->dtype() != loco::DataType::FLOAT32)
return false;
if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
- add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;
// tconv bias is optional
@@ -202,19 +223,23 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
luci::add_origin(fused_tconv, luci::get_origin(bias));
}
- if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
+ switch (add->fusedActivationFunction())
{
- // separate relu op from add op
- auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
- relu->features(fused_tconv);
- relu->name(name + "/Relu6");
- luci::add_origin(relu, luci::get_origin(add));
+ case luci::FusedActFunc::RELU6:
+ replace_with_relu<luci::CircleRelu6>(add, fused_tconv, name + "/Relu6");
+ break;
- replace(add).with(relu);
- }
- else
- {
- replace(add).with(fused_tconv);
+ case luci::FusedActFunc::RELU:
+ replace_with_relu<luci::CircleRelu>(add, fused_tconv, name + "/Relu");
+ break;
+
+ case luci::FusedActFunc::NONE:
+ replace(add).with(fused_tconv);
+ break;
+
+ default:
+ assert(false);
+ break;
}
return true;
diff --git a/compiler/luci/pass/src/FusePReluPass.cpp b/compiler/luci/pass/src/FusePReluPass.cpp
new file mode 100644
index 000000000..a5ce60ebf
--- /dev/null
+++ b/compiler/luci/pass/src/FusePReluPass.cpp
@@ -0,0 +1,202 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FusePReluPass.h"
+#include "helpers/NodeFiller.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/CircleNodeClone.h>
+
+#include <cassert>
+
+// Helper to fuse PRelu
+namespace
+{
+
+/**
+ * Below diagram shows PRelu pattern to fuse.
+ * - this pattern will be replaced with one PRelu
+ *
+ * [In]
+ * |
+ * V
+ * +---- ifm ----+
+ * | | |
+ * | | V
+ * | | abs
+ * | V |
+ * | sub <---+
+ * | |
+ * | V
+ * | mul_alpha (alpha of PRelu)
+ * | |
+ * V V
+ * relu mul_half (0.5)
+ * | |
+ * | V
+ * +---> add
+ * |
+ * V
+ * [Out]
+ *
+ */
+class PReluPattern final
+{
+public:
+ PReluPattern(luci::CircleAdd *candidate)
+ {
+ assert(candidate);
+ _add_ofm = candidate;
+ }
+
+public:
+ bool matched();
+
+public:
+ luci::CircleNode *_ifm = nullptr;
+ luci::CircleRelu *_relu = nullptr;
+ luci::CircleAbs *_abs = nullptr;
+ luci::CircleSub *_sub = nullptr;
+ luci::CircleMul *_mul_alpha = nullptr;
+ luci::CircleMul *_mul_half = nullptr;
+ luci::CircleAdd *_add_ofm = nullptr;
+ luci::CircleConst *_const_alpha = nullptr;
+ luci::CircleConst *_const_half = nullptr;
+};
+
+#define CHECK_OR_FALSE(condition) \
+ if (not(condition)) \
+ return false;
+
+bool PReluPattern::matched()
+{
+ // check pattern
+ CHECK_OR_FALSE(luci::fill(&_relu, &_mul_half).with_commutative_args_of(_add_ofm));
+ CHECK_OR_FALSE(luci::fill(&_mul_alpha, &_const_half).with_commutative_args_of(_mul_half));
+ CHECK_OR_FALSE(luci::fill(&_sub, &_const_alpha).with_commutative_args_of(_mul_alpha));
+
+ CHECK_OR_FALSE(luci::fill(&_ifm, &_abs).with_args_of(_sub));
+
+ CHECK_OR_FALSE(_relu->features() == _ifm);
+ CHECK_OR_FALSE(_abs->x() == _ifm);
+
+ // Check Activation to be NONE
+ CHECK_OR_FALSE(_sub->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_mul_alpha->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_mul_half->fusedActivationFunction() == luci::FusedActFunc::NONE);
+ CHECK_OR_FALSE(_add_ofm->fusedActivationFunction() == luci::FusedActFunc::NONE);
+
+ // TODO support other types?
+ // check if _const_half is really FLOAT32 & 0.5
+ CHECK_OR_FALSE(_const_half->dtype() == loco::DataType::FLOAT32);
+ CHECK_OR_FALSE(_const_half->size<loco::DataType::FLOAT32>() == 1);
+ CHECK_OR_FALSE(_const_half->at<loco::DataType::FLOAT32>(0) == 0.5);
+
+ // check _const_alpha condition
+ CHECK_OR_FALSE(_const_alpha->dtype() == loco::DataType::FLOAT32);
+ // TODO add more if needed
+
+ return true;
+}
+
+#undef CHECK_OR_FALSE
+
+class FusePRelu final
+{
+public:
+ FusePRelu(const PReluPattern &p) : _p(p) {}
+
+public:
+ void apply(void);
+
+private:
+ luci::CirclePRelu *create_prelu(loco::Graph *graph);
+
+private:
+ const PReluPattern &_p;
+};
+
+luci::CirclePRelu *FusePRelu::create_prelu(loco::Graph *graph)
+{
+ assert(graph);
+
+ auto prelu = graph->nodes()->create<luci::CirclePRelu>();
+ prelu->input(_p._ifm);
+ prelu->alpha(_p._const_alpha);
+ prelu->name(_p._add_ofm->name() + "_prelu");
+ return prelu;
+}
+
+void FusePRelu::apply()
+{
+ auto graph = _p._add_ofm->graph();
+
+ auto prelu = create_prelu(graph);
+
+ // set origin
+ std::vector<std::shared_ptr<luci::CircleNodeOrigin>> origin_vec{
+ luci::get_origin(_p._relu), luci::get_origin(_p._abs), luci::get_origin(_p._sub),
+ luci::get_origin(_p._mul_alpha), luci::get_origin(_p._mul_half), luci::get_origin(_p._add_ofm)};
+
+ luci::add_origin(prelu, luci::composite_origin(origin_vec));
+
+ replace(_p._add_ofm).with(prelu);
+}
+
+} // namespace
+
+namespace
+{
+
+bool fuse_prelu(luci::CircleAdd *add)
+{
+ assert(add);
+
+ PReluPattern pattern(add);
+ if (pattern.matched())
+ {
+ FusePRelu fuse(pattern);
+ fuse.apply();
+ return true;
+ }
+ return false;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FusePReluPass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto add = dynamic_cast<luci::CircleAdd *>(node);
+ if (not add)
+ continue;
+
+ if (fuse_prelu(add))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/FusePReluPass.test.cpp b/compiler/luci/pass/src/FusePReluPass.test.cpp
new file mode 100644
index 000000000..209fe3911
--- /dev/null
+++ b/compiler/luci/pass/src/FusePReluPass.test.cpp
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FusePReluPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class PReluGraphlet
+{
+public:
+ PReluGraphlet() = default;
+
+ void init(loco::Graph *g)
+ {
+ _abs = g->nodes()->create<luci::CircleAbs>();
+ _sub = g->nodes()->create<luci::CircleSub>();
+ _mul_alpha = g->nodes()->create<luci::CircleMul>();
+ _mul_half = g->nodes()->create<luci::CircleMul>();
+ _relu = g->nodes()->create<luci::CircleRelu>();
+ _add = g->nodes()->create<luci::CircleAdd>();
+ _const_alpha = g->nodes()->create<luci::CircleConst>();
+ _const_half = g->nodes()->create<luci::CircleConst>();
+
+ _sub->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul_alpha->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul_half->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+
+ _abs->name("abs");
+ _sub->name("sub");
+ _mul_alpha->name("mul_alpha");
+ _mul_half->name("mul_half");
+ _relu->name("relu");
+ _add->name("add");
+ _const_alpha->name("const_alpha");
+ _const_half->name("const_half");
+
+ _const_alpha->dtype(loco::DataType::FLOAT32);
+ _const_alpha->size<loco::DataType::FLOAT32>(1);
+ _const_alpha->shape({1});
+ _const_alpha->at<loco::DataType::FLOAT32>(0) = 0.1;
+ _const_alpha->shape_status(luci::ShapeStatus::VALID);
+
+ _const_half->dtype(loco::DataType::FLOAT32);
+ _const_half->size<loco::DataType::FLOAT32>(1);
+ _const_half->shape({1});
+ _const_half->at<loco::DataType::FLOAT32>(0) = 0.5;
+ _const_half->shape_status(luci::ShapeStatus::VALID);
+ }
+
+ void invalid_half() { _const_half->at<loco::DataType::FLOAT32>(0) = 0.1; }
+ void invalid_act() { _add->fusedActivationFunction(luci::FusedActFunc::RELU); }
+
+protected:
+ luci::CircleAbs *_abs = nullptr;
+ luci::CircleSub *_sub = nullptr;
+ luci::CircleMul *_mul_alpha = nullptr;
+ luci::CircleMul *_mul_half = nullptr;
+ luci::CircleRelu *_relu = nullptr;
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_const_alpha = nullptr;
+ luci::CircleConst *_const_half = nullptr;
+};
+
+class FusePReluTestGraph : public TestIOGraph, public PReluGraphlet
+{
+public:
+ FusePReluTestGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ PReluGraphlet::init(g());
+
+ _relu->features(input());
+ _abs->x(input());
+ _sub->x(input());
+ _sub->y(_abs);
+ _mul_alpha->x(_sub);
+ _mul_alpha->y(_const_alpha);
+ _mul_half->x(_mul_alpha);
+ _mul_half->y(_const_half);
+ _add->x(_relu);
+ _add->y(_mul_half);
+
+ output()->from(_add);
+ }
+};
+
+class FusePReluTestNegGraph : public TestIOGraph, public PReluGraphlet
+{
+public:
+ FusePReluTestNegGraph() = default;
+
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ PReluGraphlet::init(g());
+
+ _relu->features(input());
+ _abs->x(input());
+    // NOTE x and y are intentionally swapped so the PRelu pattern does not match
+ _sub->x(_abs);
+ _sub->y(input());
+ _mul_alpha->x(_sub);
+ _mul_alpha->y(_const_alpha);
+ _mul_half->x(_mul_alpha);
+ _mul_half->y(_const_half);
+ _add->x(_relu);
+ _add->y(_mul_half);
+
+ output()->from(_add);
+ }
+};
+
+} // namespace
+
+TEST(FusePReluPassTest, name)
+{
+ luci::FusePReluPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(FusePReluPassTest, fuse)
+{
+ FusePReluTestGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+}
+
+TEST(FusePReluPassTest, fuse_invalid_half_NEG)
+{
+ FusePReluTestNegGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+ g.invalid_half();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FusePReluPassTest, fuse_invalid_act_NEG)
+{
+ FusePReluTestNegGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+ g.invalid_act();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
+
+TEST(FusePReluPassTest, fuse_NEG)
+{
+ FusePReluTestNegGraph g;
+ luci::FusePReluPass pass;
+
+ g.init();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index 06a4ae9f6..45d229a0b 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -34,6 +34,8 @@ bool is_quantized(const CircleNode *node)
node->dtype() == loco::DataType::S64); // bias (int16 quant)
}
+bool is_fp32(const CircleNode *node) { return node->dtype() == loco::DataType::FLOAT32; }
+
uint8_t fp32_to_uint8_cast(float f)
{
assert(std::numeric_limits<uint8_t>::min() <= f);
@@ -124,8 +126,8 @@ void compute_sym_scale_zp(float min, float max, float &scaling_factor, int64_t &
: scale_factor_from_max_side;
// protect scale from being very low to avoid overflow/underflow
- if (scaling_factor < 1e-9)
- scaling_factor = 1e-9;
+ if (scaling_factor < 1e-8)
+ scaling_factor = 1e-8;
zp = 0;
nudged_min = static_cast<float>(qmin_double * scaling_factor);
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index 4d5316ccb..0720c9839 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -60,6 +60,9 @@ void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2);
// Return true if the node is quantized
bool is_quantized(const CircleNode *node);
+// Return true if the node is fp32
+bool is_fp32(const CircleNode *node);
+
enum ActivationQType
{
MinMax, // Quantize using recorded min/max
diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp
index 95251a82c..214e61c1e 100644
--- a/compiler/luci/pass/src/QuantizeActivation.cpp
+++ b/compiler/luci/pass/src/QuantizeActivation.cpp
@@ -44,12 +44,8 @@ void QuantizeActivation::visit(luci::CircleNode *node)
LOGGER(l);
INFO(l) << "QuantizeActivation visit node: " << node->name() << std::endl;
- // Check if this is already quantized
- if (is_quantized(node))
- return;
-
- // Check if this is bool type (bool type is not quantized)
- if (node->dtype() == loco::DataType::BOOL)
+ // Check if node is fp32
+ if (not is_fp32(node))
return;
// Check if this is const (const activation is handled by QuantizeConstInputActivation)
@@ -185,7 +181,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node)
{ \
auto input = node->INPUT_NAME(); \
auto const_node = dynamic_cast<luci::CircleConst *>(input); \
- if (const_node && !is_quantized(const_node)) \
+ if (const_node && is_fp32(const_node)) \
{ \
auto new_const = luci::clone(const_node); \
quant_const(new_const, _output_type); \
@@ -199,7 +195,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node)
{ \
auto input1 = node->INPUT_NAME1(); \
auto const_node1 = dynamic_cast<luci::CircleConst *>(input1); \
- if (const_node1 && !is_quantized(const_node1)) \
+ if (const_node1 && is_fp32(const_node1)) \
{ \
auto new_const1 = luci::clone(const_node1); \
quant_const(new_const1, _output_type); \
@@ -207,7 +203,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node)
} \
auto input2 = node->INPUT_NAME2(); \
auto const_node2 = dynamic_cast<luci::CircleConst *>(input2); \
- if (const_node2 && !is_quantized(const_node2)) \
+ if (const_node2 && is_fp32(const_node2)) \
{ \
auto new_const2 = luci::clone(const_node2); \
quant_const(new_const2, _output_type); \
@@ -216,6 +212,7 @@ void QuantizeConstInputActivation::visit(luci::CircleNode *node)
}
// Ops that receive a single activation as an input
+QUANTIZE_SINGLE_CONST_INPUT(luci::CircleAbs, x)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMax, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleArgMin, input)
QUANTIZE_SINGLE_CONST_INPUT(luci::CircleBatchToSpaceND, input)
@@ -278,7 +275,7 @@ void QuantizeConstInputActivation::visit(luci::CircleAddN *node)
{
auto input_node = node->inputs(i);
auto const_node = dynamic_cast<luci::CircleConst *>(input_node);
- if (const_node && !is_quantized(const_node))
+ if (const_node && is_fp32(const_node))
{
auto new_const = luci::clone(const_node);
quant_const(new_const, _output_type);
diff --git a/compiler/luci/pass/src/QuantizeActivation.h b/compiler/luci/pass/src/QuantizeActivation.h
index fc32d1cde..c6c991a76 100644
--- a/compiler/luci/pass/src/QuantizeActivation.h
+++ b/compiler/luci/pass/src/QuantizeActivation.h
@@ -102,6 +102,7 @@ private:
void visit(luci::CircleNode *node);
// Ops that receive a single activation as an input
+ void visit(luci::CircleAbs *node);
void visit(luci::CircleArgMax *node);
void visit(luci::CircleArgMin *node);
void visit(luci::CircleBatchToSpaceND *node);
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
index 500ae12ed..29cdaffff 100644
--- a/compiler/luci/pass/src/QuantizeWeights.cpp
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -90,6 +90,118 @@ void asym_wquant_per_channel(CircleConst *node, std::vector<float> &min,
}
}
+// TODO Reduce duplicate code with QuantizeDequantizeWeights
+void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
+ std::vector<float> &scaling_factor, std::vector<int64_t> &zp,
+ std::vector<float> &nudged_min, std::vector<float> &nudged_max,
+ int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+ const int32_t kMaxScale = std::numeric_limits<int16_t>::max();
+ const int32_t kMinScale = -kMaxScale;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ for (size_t i = 0; i < min.size(); ++i)
+ {
+ compute_sym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]);
+ }
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
+ data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round(data * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::S16); // change the type of tensor
+ node->size<loco::DataType::S16>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::S16>(i) =
+ std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
+void cal_minmax_per_channel(CircleConst *node, std::vector<float> &min, std::vector<float> &max,
+ int32_t &channel_dim_index)
+{
+ loco::TensorShape dimension;
+ dimension.rank(4);
+
+ if (!get_channel_dim_index(node, dimension, channel_dim_index))
+ {
+ throw std::runtime_error("Failed to find channel index in " + node->name());
+ }
+ auto size = dimension.dim(channel_dim_index).value();
+
+ std::vector<bool> has_min_max_value(size, false);
+ min.resize(size);
+ max.resize(size);
+
+ auto cal_minmax = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ if (has_min_max_value[channel_idx])
+ {
+ min[channel_idx] = data < min[channel_idx] ? data : min[channel_idx];
+ max[channel_idx] = data > max[channel_idx] ? data : max[channel_idx];
+ }
+ else
+ {
+ min[channel_idx] = data;
+ max[channel_idx] = data;
+ has_min_max_value[channel_idx] = true;
+ }
+ };
+
+ iterate_per_channel(node, channel_dim_index, cal_minmax);
+}
+
+void asymmetric_wquant_per_channel(CircleConst *node, std::vector<float> &min,
+ std::vector<float> &max, std::vector<float> &scaling_factor,
+ std::vector<int64_t> &zp, std::vector<float> &nudged_min,
+ std::vector<float> &nudged_max, int32_t &channel_dim_index)
+{
+ assert(node->dtype() == loco::DataType::FLOAT32);
+
+ const int32_t kMinScale = 0;
+ const int32_t kMaxScale = 255;
+
+ uint32_t size = node->size<loco::DataType::FLOAT32>();
+ std::vector<int32_t> quantized_values(size);
+
+ for (size_t i = 0; i < min.size(); ++i)
+ {
+ compute_asym_scale_zp(min[i], max[i], scaling_factor[i], zp[i], nudged_min[i], nudged_max[i]);
+ }
+
+ auto quantize = [&](uint32_t *indices, loco::TensorShape &dimension, int channel_dim_index) {
+ int channel_idx = indices[channel_dim_index];
+ const float scaling_factor_inv = 1.0 / scaling_factor[channel_idx];
+ auto data = node->at<loco::DataType::FLOAT32>(cal_offset(dimension, indices));
+ data = data < nudged_min[channel_idx] ? nudged_min[channel_idx] : data;
+ data = data > nudged_max[channel_idx] ? nudged_max[channel_idx] : data;
+ quantized_values[cal_offset(dimension, indices)] =
+ static_cast<int32_t>(std::round((data - nudged_min[channel_idx]) * scaling_factor_inv));
+ };
+
+ iterate_per_channel(node, channel_dim_index, quantize);
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size); // resize tensor
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ node->at<loco::DataType::U8>(i) = std::min(kMaxScale, std::max(kMinScale, quantized_values[i]));
+ }
+}
+
void sym_wquant_per_channel(CircleConst *node, std::vector<float> &scaling_factor,
int32_t &channel_dim_index)
{
@@ -250,7 +362,37 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
auto quantparam = weights->quantparam();
if (quantparam == nullptr)
{
- assert(false && "quantparam is nullptr");
+ // Find min/max on the fly
+ // NOTE This is for the case when QuantizeDequantizeWeights is skipped
+ // TODO Reduce duplicate codes
+ std::vector<float> min;
+ std::vector<float> max;
+ int32_t channel_dim_index = 0;
+
+ cal_minmax_per_channel(weights, min, max, channel_dim_index);
+
+ std::vector<float> nudged_min(min.size());
+ std::vector<float> nudged_max(min.size());
+ std::vector<float> scaling_factor(min.size());
+ std::vector<int64_t> zp(min.size());
+
+ if (output_type == loco::DataType::U8)
+ {
+ asymmetric_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max,
+ channel_dim_index);
+ }
+ else
+ {
+ sym_wquant_per_channel(weights, min, max, scaling_factor, zp, nudged_min, nudged_max,
+ channel_dim_index);
+ }
+
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale = scaling_factor;
+ quantparam->zerop = zp;
+ quantparam->quantized_dimension = channel_dim_index;
+ weights->quantparam(std::move(quantparam));
+
return;
}
@@ -273,8 +415,35 @@ void QuantizeWeights::quantize_weights(luci::CircleConst *weights)
// Find min/max per layer-wise
else
{
- // Quantize using recorded quantparam
auto quantparam = weights->quantparam();
+ if (quantparam == nullptr)
+ {
+ // Find min/max on the fly
+ // NOTE This is for the case when QuantizeDequantizeWeights is skipped
+ // TODO Reduce duplicate codes
+ float min = std::numeric_limits<float>::max();
+ float max = std::numeric_limits<float>::lowest();
+ for (uint32_t i = 0; i < weights->size<loco::DataType::FLOAT32>(); i++)
+ {
+ auto data = weights->at<loco::DataType::FLOAT32>(i);
+ min = data < min ? data : min;
+ max = data > max ? data : max;
+ }
+ float scaling_factor{0};
+ int64_t zp{0};
+ float nudged_min{0};
+ float nudged_max{0};
+
+ asymmetric_wquant_with_minmax_per_layer(weights, min, max, scaling_factor, zp, nudged_min,
+ nudged_max);
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(scaling_factor);
+ quantparam->zerop.push_back(zp);
+ weights->quantparam(std::move(quantparam));
+ return;
+ }
+
+ // Quantize using recorded quantparam
assert(quantparam != nullptr);
assert(quantparam->min.size() == 1); // only support layer-wise quant
assert(quantparam->scale.size() == 1); // only support layer-wise quant
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index 005144516..c68e06712 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -32,8 +32,6 @@
#include <luci/Log.h>
#include <logo/Phase.h>
-#include <oops/UserExn.h>
-
#include <iostream>
#include <cmath>
@@ -154,8 +152,8 @@ namespace
* 2. After output feature map
*
* For example, if default_dtype = U8 and op_dtype = S16,
- * 1. Quantize Op for U8->S16 is inserted before ifm
- * 2. Quantize Op for S16->U8 is inserted after ofm
+ * 1. Quantize (U8->S16) is inserted before ifm
+ * 2. Quantize (S16->U8) is inserted after ofm
*
* Why not insert Quantize Op for const ifm?
* We quantize const tensor at once to preserve precision.
@@ -181,6 +179,10 @@ private:
if (input->opcode() == luci::CircleOpcode::CIRCLECONST)
return nullptr;
+ // input is not quantizable (ex: index)
+ if (input->quantparam() == nullptr)
+ return nullptr;
+
auto input_quant = create_quantize_op(input, _op_dtype);
input_quant->input(input);
auto origin_node = loco::must_cast<luci::CircleNode *>(origin);
@@ -192,6 +194,11 @@ private:
{
auto output = loco::must_cast<luci::CircleNode *>(node);
assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS
+
+ // output is not quantizable (ex: index)
+ if (output->quantparam() == nullptr)
+ return;
+
auto output_quant = create_quantize_op(output, _default_dtype);
luci::add_origin(output_quant, luci::get_origin(output));
@@ -253,6 +260,7 @@ private:
void visit(luci::CircleUnpackOut *) {}
// Ops that receive a single activation as an input
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAbs, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input)
@@ -365,10 +373,20 @@ private:
void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
{
auto inputs = g->inputs();
- for (auto node : loco::input_nodes(g))
+
+ assert(inputs); // FIX_CALLER_UNLESS
+ assert(inputs->size() == _ctx->input_types.size()); // FIX_CALLER_UNLESS
+
+ // NOTE loco::input_nodes returns input nodes following the order of InputIndex
+ auto input_nodes = loco::input_nodes(g);
+ for (uint32_t i = 0; i < input_nodes.size(); i++)
{
- auto input = loco::must_cast<luci::CircleInput *>(node);
- if (input->dtype() == _ctx->input_type)
+ auto input = loco::must_cast<luci::CircleInput *>(input_nodes[i]);
+ assert(i == input->index()); // Fix input_type logic
+
+ const auto user_given_dtype = _ctx->input_types[i];
+
+ if (input->dtype() == user_given_dtype)
continue;
// Bool type is not quantizable
@@ -394,7 +412,7 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
// Update qparam of input
// This step is skipped if input_type is float32
- if (_ctx->input_type != loco::DataType::FLOAT32)
+ if (user_given_dtype != loco::DataType::FLOAT32)
{
auto quantparam = input->quantparam();
assert(quantparam);
@@ -408,13 +426,13 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
float nudged_min{0};
float nudged_max{0};
- if (_ctx->input_type == loco::DataType::U8)
+ if (user_given_dtype == loco::DataType::U8)
{
compute_asym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
else
{
- assert(_ctx->input_type == loco::DataType::S16);
+ assert(user_given_dtype == loco::DataType::S16);
compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
input->quantparam()->scale[0] = scaling_factor;
@@ -422,20 +440,29 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
}
// Update dtype of input
- input->dtype(_ctx->input_type);
+ input->dtype(user_given_dtype);
auto graph_input = inputs->at(input->index());
- graph_input->dtype(_ctx->input_type);
+ graph_input->dtype(user_given_dtype);
}
}
void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
{
auto outputs = g->outputs();
- for (auto node : loco::output_nodes(g))
+ assert(outputs); // FIX_CALLER_UNLESS
+  assert(outputs->size() == _ctx->output_types.size()); // FIX_CALLER_UNLESS
+
+ // NOTE loco::output_nodes returns output nodes following the order of OutputIndex
+ auto output_nodes = loco::output_nodes(g);
+ for (uint32_t i = 0; i < output_nodes.size(); i++)
{
- auto output = loco::must_cast<luci::CircleOutput *>(node);
- if (output->dtype() == _ctx->output_type)
+ auto output = loco::must_cast<luci::CircleOutput *>(output_nodes[i]);
+ assert(i == output->index()); // Fix output_type logic
+
+ const auto user_given_dtype = _ctx->output_types[i];
+
+ if (output->dtype() == user_given_dtype)
continue;
// Bool type is not quantizable
@@ -444,12 +471,12 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
auto from = loco::must_cast<luci::CircleNode *>(output->from());
- // The last Op is not quantizable Op (ex: ArgMax)
+ // The last Op is not quantizable (ex: ArgMax)
if (not from->quantparam())
continue;
// Insert Dequantize Op for float32 output_type
- if (_ctx->output_type == loco::DataType::FLOAT32)
+ if (user_given_dtype == loco::DataType::FLOAT32)
{
auto dequant_op = create_dequantize(from);
loco::replace(from).with(dequant_op);
@@ -458,7 +485,7 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
else
{
// Insert Quantize Op for non-float32 output_type
- auto quant_op = create_quantize_op(from, _ctx->output_type);
+ auto quant_op = create_quantize_op(from, user_given_dtype);
loco::replace(from).with(quant_op);
quant_op->input(from);
@@ -467,10 +494,10 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
}
// Update dtype of output
- output->dtype(_ctx->output_type);
+ output->dtype(user_given_dtype);
auto graph_output = outputs->at(output->index());
- graph_output->dtype(_ctx->output_type);
+ graph_output->dtype(user_given_dtype);
}
}
@@ -493,9 +520,9 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
* Weights is quantized using min/max of its value
*
* Bias is quantized using input scale (s_i) and weights scale (s_w)
- * - Activation and weights should be quantized earlier than bias
+ * - Therefore, activation and weights should be quantized earlier than bias
*
- * Quantization Steps
+ * Overall Quantization Steps
* 1. Quantize Activation
* - Quantize using recorded min/max (QuantizeActivation)
* - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp)
@@ -550,7 +577,10 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
};
// Quantize activation
- for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ // Why all_nodes?
+ // Models can have inactive (unused) inputs.
+ // We do not reject such models, but quantize them too
+ for (auto node : loco::all_nodes(g))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node));
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
index d5fa21ffd..49c2d4652 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.test.cpp
@@ -53,8 +53,14 @@ public:
TEST(QuantizeWithMinMaxPassTest, name)
{
- luci::QuantizeWithMinMaxPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
- luci::QuantizationGranularity::LayerWise);
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::U8;
+ ctx->granularity = luci::QuantizationGranularity::LayerWise;
+ }
+
+ luci::QuantizeWithMinMaxPass pass(std::move(ctx));
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
@@ -65,8 +71,14 @@ TEST(QuantizeWithMinMaxPassTest, int_concat)
{
SimpleConcatGraph g(loco::DataType::S32);
- luci::QuantizeWithMinMaxPass qwmm(loco::DataType::FLOAT32, loco::DataType::U8,
- luci::QuantizationGranularity::LayerWise);
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::U8;
+ ctx->granularity = luci::QuantizationGranularity::LayerWise;
+ }
+
+ luci::QuantizeWithMinMaxPass qwmm(std::move(ctx));
qwmm.run(&g.g);
@@ -74,3 +86,22 @@ TEST(QuantizeWithMinMaxPassTest, int_concat)
EXPECT_EQ(nullptr, g.input_1->quantparam());
EXPECT_EQ(nullptr, g.input_2->quantparam());
}
+
+TEST(QuantizeWithMinMaxPassTest, inactive_input)
+{
+ SimpleConcatGraph g(loco::DataType::FLOAT32);
+
+ // Unused input
+ g.g.nodes()->create<luci::CircleInput>();
+
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = loco::DataType::U8;
+ ctx->granularity = luci::QuantizationGranularity::LayerWise;
+ }
+
+ luci::QuantizeWithMinMaxPass qwmm(std::move(ctx));
+
+ EXPECT_NO_THROW(qwmm.run(&g.g));
+}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.h b/compiler/luci/pass/src/QuantizedModelVerifier.h
index 7409a51d7..d9bea434d 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.h
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.h
@@ -38,26 +38,13 @@ public:
{
loco::DataType output_model_dtype = loco::DataType::Unknown;
QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
- loco::DataType input_type = loco::DataType::Unknown;
- loco::DataType output_type = loco::DataType::Unknown;
+ std::vector<loco::DataType> input_types;
+ std::vector<loco::DataType> output_types;
bool TF_style_maxpool = false;
std::vector<LayerInfo> layers_info;
};
public:
- QuantizedModelVerifier(loco::DataType quantized_dtype, QuantizationGranularity granularity)
- {
- _ctx = std::make_unique<Context>();
- {
- _ctx->output_model_dtype = quantized_dtype;
- _ctx->granularity = granularity;
- _ctx->input_type = quantized_dtype;
- _ctx->output_type = quantized_dtype;
- _ctx->TF_style_maxpool = false;
- }
- }
-
-public:
QuantizedModelVerifier(std::unique_ptr<Context> &&ctx) : _ctx{std::move(ctx)}
{
// DO NOTHING
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
index 21b4fe1c6..05ec31727 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -18,7 +18,9 @@
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizationParameters.h"
+#include "luci/Pass/CircleTypeInferencePass.h"
+#include <logo/Phase.h>
#include <luci/test/TestIOGraph.h>
#include <gtest/gtest.h>
@@ -104,12 +106,56 @@ void insert_scale_zp(luci::CircleNode *node, float scale, int64_t zp)
qparam->zerop.push_back(zp);
}
+void run_phase(loco::Graph *g, Type quantized_dtype, Granularity granularity)
+{
+ logo::Phase phase;
+
+ // Default passes.
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ auto ctx = std::make_unique<luci::QuantizeWithMinMaxPass::Context>();
+ {
+ ctx->input_model_dtype = loco::DataType::FLOAT32;
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ // Test graph has only one input/output
+ ctx->input_types = {quantized_dtype};
+ ctx->output_types = {quantized_dtype};
+ }
+
+ phase.emplace_back(std::make_unique<luci::QuantizeWithMinMaxPass>(std::move(ctx)));
+
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.run(phase);
+}
+
+void run_phase(loco::Graph *g, std::unique_ptr<luci::QuantizeWithMinMaxPass::Context> &&ctx)
+{
+ logo::Phase phase;
+
+ // Default passes.
+ phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+
+ phase.emplace_back(std::make_unique<luci::QuantizeWithMinMaxPass>(std::move(ctx)));
+
+ logo::PhaseRunner<logo::PhaseStrategy::Restart> phase_runner{g};
+ phase_runner.run(phase);
+}
+
void quantize_and_verify(loco::Graph *g, Type quantized_dtype, Granularity granularity)
{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g);
+ run_phase(g, quantized_dtype, granularity);
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ // Test graph has only one input/output
+ ctx->input_types = {quantized_dtype};
+ ctx->output_types = {quantized_dtype};
+ }
+
+ luci::QuantizedModelVerifier verifier(std::move(ctx));
verifier.verify(g);
}
@@ -132,14 +178,14 @@ void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype,
ctx->input_model_dtype = Type::FLOAT32;
ctx->output_model_dtype = quantized_dtype;
ctx->granularity = granularity;
- ctx->input_type = quantized_dtype;
- ctx->output_type = quantized_dtype;
+ // Test graph has only one input/output
+ ctx->input_types = {quantized_dtype};
+ ctx->output_types = {quantized_dtype};
ctx->TF_style_maxpool = false;
ctx->layers_info.push_back(info);
}
- luci::QuantizeWithMinMaxPass pass(std::move(ctx));
- pass.run(g);
+ run_phase(g, std::move(ctx));
}
// Do verification
@@ -148,8 +194,8 @@ void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype,
{
ctx->output_model_dtype = quantized_dtype;
ctx->granularity = granularity;
- ctx->input_type = quantized_dtype;
- ctx->output_type = quantized_dtype;
+ ctx->input_types = {quantized_dtype};
+ ctx->output_types = {quantized_dtype};
ctx->TF_style_maxpool = false;
ctx->layers_info.push_back(info);
}
@@ -164,13 +210,21 @@ void quantize_and_verify_with_layer_info(loco::Graph *g, Type quantized_dtype,
void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quantized_dtype,
Granularity granularity, Type wrong_dtype)
{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
+ run_phase(g->g(), quantized_dtype, granularity);
auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
node->dtype(wrong_dtype);
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ // Test graph has only one input/output
+ ctx->input_types = {quantized_dtype};
+ ctx->output_types = {quantized_dtype};
+ }
+
+ luci::QuantizedModelVerifier verifier(std::move(ctx));
verifier.verify(g->g());
}
@@ -179,13 +233,21 @@ void quantize_and_verify_with_wrong_type(luci::test::TestIOGraph *g, Type quanti
void quantize_and_verify_with_wrong_granularity(luci::test::TestIOGraph *g, Type quantized_dtype,
Granularity granularity)
{
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, quantized_dtype, granularity);
- pass.run(g->g());
+ run_phase(g->g(), quantized_dtype, granularity);
auto node = loco::must_cast<luci::CircleNode *>(g->output()->from());
insert_scale_zp(node, 1.0, 1);
- luci::QuantizedModelVerifier verifier(quantized_dtype, granularity);
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>();
+ {
+ ctx->output_model_dtype = quantized_dtype;
+ ctx->granularity = granularity;
+ // Test graph has only one input/output
+ ctx->input_types = {quantized_dtype};
+ ctx->output_types = {quantized_dtype};
+ }
+
+ luci::QuantizedModelVerifier verifier(std::move(ctx));
verifier.verify(g->g());
}
@@ -238,6 +300,24 @@ public:
virtual void init(void) = 0;
};
+class TypedTestGraph : public luci::test::TestIOGraph
+{
+protected:
+ void init(Type T, const luci::test::ShapeU32 shape_in, const luci::test::ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+
+ input()->dtype(T);
+ output()->dtype(T);
+
+ g()->inputs()->at(0)->dtype(T);
+ g()->outputs()->at(0)->dtype(T);
+ }
+
+public:
+ virtual void init(void) = 0;
+};
+
class InstanceNormTestGraph final : public SimpleTestGraph
{
public:
@@ -603,6 +683,9 @@ public:
output()->from(_argmax);
set_minmax_to_non_const(g(), -1, 1);
+
+ // Sync output dtype with graph's output dtype
+ g()->outputs()->at(0)->dtype(output()->dtype());
}
public:
@@ -904,6 +987,9 @@ public:
output()->from(_op);
set_minmax_to_non_const(g(), -1, 1);
+
+ // Sync output dtype with graph's output dtype
+ g()->outputs()->at(0)->dtype(output()->dtype());
}
loco::Node *x(void) const { return _op->x(); }
@@ -934,6 +1020,9 @@ public:
output()->from(_op);
set_minmax_to_non_const(g(), -1, 1);
+
+ // Sync output dtype with graph's output dtype
+ g()->outputs()->at(0)->dtype(output()->dtype());
}
loco::Node *x(void) const { return _op->x(); }
@@ -1218,6 +1307,33 @@ private:
luci::CircleConst *_const = nullptr;
};
+template <Type T> class IntMulTestGraph final : public TypedTestGraph
+{
+public:
+ void init(void) override
+ {
+ TypedTestGraph::init(T, {32}, {32});
+
+ _const = create_dummy_const<T>(g(), {32});
+ _mul = g()->nodes()->create<luci::CircleMul>();
+ {
+ _mul->x(input());
+ _mul->y(_const);
+ _mul->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _mul->name("test");
+ _mul->dtype(T);
+ }
+ output()->from(_mul);
+ }
+
+ loco::Node *x() { return _mul->x(); }
+ loco::Node *y() { return _mul->y(); }
+
+private:
+ luci::CircleMul *_mul = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
class AddTestGraph final : public SimpleTestGraph
{
public:
@@ -1246,6 +1362,33 @@ private:
luci::CircleConst *_const = nullptr;
};
+template <Type T> class IntAddTestGraph final : public TypedTestGraph
+{
+public:
+ void init(void) override
+ {
+ TypedTestGraph::init(T, {32}, {32});
+
+ _const = create_dummy_const<T>(g(), {32});
+ _add = g()->nodes()->create<luci::CircleAdd>();
+ {
+ _add->x(input());
+ _add->y(_const);
+ _add->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _add->name("test");
+ _add->dtype(T);
+ }
+ output()->from(_add);
+ }
+
+ loco::Node *x() { return _add->x(); }
+ loco::Node *y() { return _add->y(); }
+
+private:
+ luci::CircleAdd *_add = nullptr;
+ luci::CircleConst *_const = nullptr;
+};
+
} // namespace
// Quantize and verify with given configurations
@@ -1286,34 +1429,46 @@ private:
// Quantize and verify with wrong type
// Users can specify the test target
-#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity, wrong_dtype, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
- pass.run(g.g()); \
- auto after_node = loco::must_cast<luci::CircleNode *>(target); \
- after_node->dtype(wrong_dtype); \
- luci::QuantizedModelVerifier verifier(type, granularity); \
- EXPECT_ANY_THROW(verifier.verify(g.g())); \
+#define TEST_WITH_WRONG_TYPE_TARGET(graph, type, granularity_, wrong_dtype, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ run_phase(g.g(), type, granularity_); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ after_node->dtype(wrong_dtype); \
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); \
+ { \
+ ctx->output_model_dtype = type; \
+ ctx->granularity = granularity_; \
+ ctx->input_types = {type}; \
+ ctx->output_types = {type}; \
+ } \
+ luci::QuantizedModelVerifier verifier(std::move(ctx)); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Quantize and verify with wrong granularity
// Users can specify the test target
-#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity, target) \
- do \
- { \
- graph g; \
- g.init(); \
- auto node = loco::must_cast<luci::CircleNode *>(target); \
- luci::QuantizeWithMinMaxPass pass(Type::FLOAT32, type, granularity); \
- pass.run(g.g()); \
- auto after_node = loco::must_cast<luci::CircleNode *>(target); \
- insert_scale_zp(after_node, 1.0, 1); \
- luci::QuantizedModelVerifier verifier(type, granularity); \
- EXPECT_ANY_THROW(verifier.verify(g.g())); \
+#define TEST_WITH_WRONG_GRANULARITY_TARGET(graph, type, granularity_, target) \
+ do \
+ { \
+ graph g; \
+ g.init(); \
+ auto node = loco::must_cast<luci::CircleNode *>(target); \
+ run_phase(g.g(), type, granularity_); \
+ auto after_node = loco::must_cast<luci::CircleNode *>(target); \
+ insert_scale_zp(after_node, 1.0, 1); \
+ auto ctx = std::make_unique<luci::QuantizedModelVerifier::Context>(); \
+ { \
+ ctx->output_model_dtype = type; \
+ ctx->granularity = granularity_; \
+ ctx->input_types = {type}; \
+ ctx->output_types = {type}; \
+ } \
+ luci::QuantizedModelVerifier verifier(std::move(ctx)); \
+ EXPECT_ANY_THROW(verifier.verify(g.g())); \
} while (0)
// Test a local helper function
@@ -2512,6 +2667,29 @@ TEST(QuantizedModelVerifierTest, Add_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Add_inttype)
+{
+ // Tests for S32
+ TEST_WITH_GRAPH(IntAddTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(IntAddTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(IntAddTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ // Tests for S64
+ TEST_WITH_GRAPH(IntAddTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(IntAddTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(IntAddTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(IntAddTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, Mul)
{
TEST_WITH_GRAPH(MulTestGraph, Type::U8, Granularity::LayerWise);
@@ -2544,6 +2722,29 @@ TEST(QuantizedModelVerifierTest, Mul_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, Mul_inttype)
+{
+ // Tests for S32
+ TEST_WITH_GRAPH(IntMulTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(IntMulTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(IntMulTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S32>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S32>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S32>, Type::S16, Granularity::ChannelWise);
+
+ // Tests for S64
+ TEST_WITH_GRAPH(IntMulTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(IntMulTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(IntMulTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S64>, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S64>, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(IntMulTestGraph<Type::S64>, Type::S16, Granularity::ChannelWise);
+
+ SUCCEED();
+}
+
// TODO Add following testcases
//
// CircleConv2D
diff --git a/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp b/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp
new file mode 100644
index 000000000..e50dda9e0
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveDuplicateConstPass.cpp
@@ -0,0 +1,225 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveDuplicateConstPass.h"
+
+#include <luci/Log.h>
+
+namespace
+{
+
+bool compare_quant_params(luci::CircleConst *left, luci::CircleConst *right)
+{
+ const auto left_quant_param = left->quantparam();
+ const auto right_quant_param = right->quantparam();
+
+ if (left_quant_param == right_quant_param)
+ return true;
+
+ if (left_quant_param != nullptr and right_quant_param != nullptr)
+ {
+ if (left_quant_param->scale == right_quant_param->scale and
+ left_quant_param->quantized_dimension == right_quant_param->quantized_dimension and
+ left_quant_param->zerop == right_quant_param->zerop and
+ left_quant_param->min == right_quant_param->min and
+ left_quant_param->max == right_quant_param->max)
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool compare_dim_values(luci::CircleConst *left, luci::CircleConst *right)
+{
+ const auto left_rank = left->rank();
+ const auto right_rank = right->rank();
+
+ if (left_rank != right_rank)
+ return false;
+
+ for (uint32_t i = 0; i < left_rank; ++i)
+ {
+ if (left->dim(i).value() != right->dim(i).value())
+ return false;
+ }
+
+ return true;
+}
+
+template <loco::DataType DT> bool is_equal_consts(luci::CircleConst *left, luci::CircleConst *right)
+{
+ if (not compare_quant_params(left, right))
+ return false;
+
+ if (not compare_dim_values(left, right))
+ return false;
+
+ for (uint32_t i = 0; i < left->size<DT>(); ++i)
+ {
+ if (left->at<DT>(i) != right->at<DT>(i))
+ return false;
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool RemoveDuplicateConstPass::remove_duplicate_const()
+{
+ bool changed = false;
+
+ for (auto &cur_pair : _sum_to_const)
+ {
+ // if single const - continue
+ if (cur_pair.second.size() == 1)
+ continue;
+
+ for (auto reference_const : cur_pair.second)
+ {
+ if (reference_const == nullptr)
+ continue;
+
+ for (uint32_t i = 0; i < cur_pair.second.size(); ++i)
+ {
+ auto cur_const = cur_pair.second.at(i);
+ if (cur_const == nullptr or cur_const == reference_const)
+ continue;
+
+ if (cur_const->dtype() != reference_const->dtype())
+ continue;
+
+ bool is_equal = false;
+
+ switch (cur_const->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ is_equal = is_equal_consts<loco::DataType::FLOAT32>(reference_const, cur_const);
+ break;
+ case loco::DataType::S32:
+ is_equal = is_equal_consts<loco::DataType::S32>(reference_const, cur_const);
+ break;
+ case loco::DataType::S16:
+ is_equal = is_equal_consts<loco::DataType::S16>(reference_const, cur_const);
+ break;
+ case loco::DataType::S8:
+ is_equal = is_equal_consts<loco::DataType::S8>(reference_const, cur_const);
+ break;
+ case loco::DataType::U8:
+ is_equal = is_equal_consts<loco::DataType::U8>(reference_const, cur_const);
+ break;
+ default:
+ continue;
+ }
+
+ if (not is_equal)
+ continue;
+
+ loco::replace(cur_const).with(reference_const);
+
+ // Remove from next checking
+ cur_pair.second[i] = nullptr;
+
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+template <loco::DataType DT>
+void RemoveDuplicateConstPass::add_to_map(luci::CircleConst *const_node)
+{
+ const auto const_size = const_node->size<DT>();
+ float sum = 0.0;
+
+ for (uint32_t i = 0; i < const_size; ++i)
+ {
+ sum += const_node->at<DT>(i);
+ }
+
+ if (_sum_to_const.find(sum) == _sum_to_const.end())
+ {
+ _sum_to_const[sum] = {const_node};
+ }
+ else
+ {
+ _sum_to_const.at(sum).push_back(const_node);
+ }
+}
+
+/**
+ * Remove duplicate Const nodes.
+ *
+ * BEFORE
+ * [CircleNode] [CircleConst]
+ * | /
+ * | /
+ * [CircleNode] [CircleConst]
+ * | /
+ * | /
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode] [CircleConst]
+ * | / /
+ * | / /
+ * [CircleNode] /
+ * | /
+ * | /
+ * [CircleNode]
+ *
+ */
+bool RemoveDuplicateConstPass::run(loco::Graph *g)
+{
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto const_node = dynamic_cast<luci::CircleConst *>(node);
+ if (const_node == nullptr)
+ continue;
+
+ switch (const_node->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ add_to_map<loco::DataType::FLOAT32>(const_node);
+ break;
+ case loco::DataType::S32:
+ add_to_map<loco::DataType::S32>(const_node);
+ break;
+ case loco::DataType::S16:
+ add_to_map<loco::DataType::S16>(const_node);
+ break;
+ case loco::DataType::S8:
+ add_to_map<loco::DataType::S8>(const_node);
+ break;
+ case loco::DataType::U8:
+ add_to_map<loco::DataType::U8>(const_node);
+ break;
+ default:
+ continue;
+ }
+ }
+
+ return remove_duplicate_const();
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp b/compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp
new file mode 100644
index 000000000..5052a3e01
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveDuplicateConstPass.test.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveDuplicateConstPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/test/TestIOGraph.h>
+#include <gtest/gtest.h>
+
+namespace
+{
+using namespace luci::test;
+
+class DuplicateConstsGraphlet
+{
+public:
+ DuplicateConstsGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, bool is_duplicate)
+ {
+ _reshape_shape = g->nodes()->create<luci::CircleConst>();
+ _reshape_shape->rank(1);
+ _reshape_shape->dim(0).set(1);
+ _reshape_shape->shape_status(luci::ShapeStatus::VALID);
+ _reshape_shape->dtype(loco::DataType::S32);
+
+ _reshape_shape->size<loco::DataType::S32>(1);
+ _reshape_shape->at<loco::DataType::S32>(0) = 5;
+ _reshape_shape->name("reshape_shape_1");
+
+ _reshape_shape_duplicate = g->nodes()->create<luci::CircleConst>();
+ _reshape_shape_duplicate->rank(1);
+ _reshape_shape_duplicate->dim(0).set(1);
+ _reshape_shape_duplicate->shape_status(luci::ShapeStatus::VALID);
+ _reshape_shape_duplicate->dtype(loco::DataType::S32);
+ if (is_duplicate)
+ {
+ _reshape_shape_duplicate->size<loco::DataType::S32>(1);
+ _reshape_shape_duplicate->at<loco::DataType::S32>(0) = 5;
+ }
+ else
+ {
+ _reshape_shape_duplicate->size<loco::DataType::S32>(2);
+ _reshape_shape_duplicate->at<loco::DataType::S32>(0) = 1;
+ _reshape_shape_duplicate->at<loco::DataType::S32>(1) = 5;
+ }
+ _reshape_shape_duplicate->name("reshape_shape_2");
+
+ _reshape_f = g->nodes()->create<luci::CircleReshape>();
+ _reshape_f->newShape()->rank(1);
+ _reshape_f->newShape()->dim(0) = 5;
+ _reshape_f->name("reshape_f");
+
+ _reshape_s = g->nodes()->create<luci::CircleReshape>();
+ if (is_duplicate)
+ {
+ _reshape_s->newShape()->rank(1);
+ _reshape_s->newShape()->dim(0) = 5;
+ }
+ else
+ {
+ _reshape_s->newShape()->rank(2);
+ _reshape_s->newShape()->dim(0) = 1;
+ _reshape_s->newShape()->dim(1) = 5;
+ }
+ _reshape_s->name("reshape_s");
+ }
+
+protected:
+ luci::CircleReshape *_reshape_f = nullptr;
+ luci::CircleReshape *_reshape_s = nullptr;
+ luci::CircleConst *_reshape_shape = nullptr;
+ luci::CircleConst *_reshape_shape_duplicate = nullptr;
+};
+
+class DuplicateConstsGraph : public TestIOGraph, public DuplicateConstsGraphlet
+{
+public:
+ DuplicateConstsGraph() = default;
+
+public:
+ void init(const ShapeU32 in_shape, const ShapeU32 out_shape, bool is_duplicate)
+ {
+ TestIOGraph::init(in_shape, out_shape);
+
+ DuplicateConstsGraphlet::init(g(), is_duplicate);
+
+ // connect graph
+ _reshape_f->tensor(input());
+ _reshape_f->shape(_reshape_shape);
+
+ _reshape_s->tensor(_reshape_f);
+ _reshape_s->shape(_reshape_shape_duplicate);
+
+ output()->from(_reshape_s);
+ }
+};
+} // namespace
+
+TEST(RemoveDuplicateConstPass, name)
+{
+ luci::RemoveDuplicateConstPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST(RemoveDuplicateConstPass, remove_duplicate)
+{
+ DuplicateConstsGraph g;
+ g.init({1, 5}, {5}, true);
+
+ luci::RemoveDuplicateConstPass pass;
+ while (pass.run(g.g()))
+ ;
+
+ uint32_t const_num = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ auto target_node = dynamic_cast<luci::CircleConst *>(node);
+ if (target_node != nullptr)
+ const_num++;
+ }
+
+ ASSERT_EQ(const_num, 1);
+}
+
+TEST(RemoveDuplicateConstPass, remove_duplicate_NEG)
+{
+ DuplicateConstsGraph g;
+ g.init({1, 5}, {1, 5}, false);
+
+ luci::RemoveDuplicateConstPass pass;
+ while (pass.run(g.g()))
+ ;
+
+ uint32_t const_num = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ auto target_node = dynamic_cast<luci::CircleConst *>(node);
+ if (target_node != nullptr)
+ const_num++;
+ }
+
+ ASSERT_EQ(const_num, 2);
+}
diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp
index 741b70956..07457c1e8 100644
--- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp
+++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp
@@ -64,6 +64,40 @@ luci::CircleNode *fromActivation(luci::CircleNode *inp, luci::FusedActFunc act)
}
}
+// Create CircleReshape where
+// - dtype is same with node
+// - shape is same with node
+// NOTE: User should set input(tensor) of the returned Op.
+luci::CircleReshape *create_reshape(luci::CircleFullyConnected *node)
+{
+ assert(node); // FIX_CALLER_UNLESS
+
+ auto g = node->graph();
+
+ auto reshape = g->nodes()->create<luci::CircleReshape>();
+ reshape->name(node->name() + "/reshape");
+ reshape->dtype(node->dtype());
+ luci::add_origin(reshape, luci::get_origin(node));
+
+ auto shape_const = g->nodes()->create<luci::CircleConst>();
+ shape_const->dtype(loco::DataType::S32);
+ shape_const->rank(1);
+ shape_const->dim(0).set(node->rank());
+ shape_const->size<loco::DataType::S32>(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ {
+ assert(node->dim(i).known()); // FIX_CALLER_UNLESS
+ shape_const->at<loco::DataType::S32>(i) = node->dim(i).value();
+ }
+ shape_const->shape_status(luci::ShapeStatus::VALID);
+ shape_const->name(node->name() + "/shape");
+ luci::add_origin(shape_const, luci::get_origin(node));
+
+ reshape->shape(shape_const);
+
+ return reshape;
+}
+
/**
* Replace Fully Connected with Batched MatMul
*
@@ -79,19 +113,23 @@ luci::CircleNode *fromActivation(luci::CircleNode *inp, luci::FusedActFunc act)
*
* [Node1] [Node2]
* \ /
- * [BatchMatMul] [BiasValue]?
+ * [BatchMatMul]
+ * |
+ * [Reshape] [BiasValue]?
* \ /
* [Add]?
* |
* [Activation]?
*
* Nodes with "?" denote optional elements
+ * NOTE Reshape Op is inserted to keep the original shape of FullyConnected Op
+ * Reshape Op can be redundant (input shape == output shape). This can be removed
+ * by RemoveUnnecessaryReshapePass.
*/
bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
{
luci::CircleNode *x = nullptr;
luci::CircleNode *y = nullptr;
- luci::CircleNode *b = nullptr;
luci::CircleTranspose *ty = nullptr;
luci::CircleTranspose *tx = nullptr;
bool adj_x = false;
@@ -122,10 +160,13 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
x = loco::must_cast<luci::CircleNode *>(fc->input());
}
- b = loco::must_cast<luci::CircleNode *>(fc->bias());
+ if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32)
+ return false;
- if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32 ||
- b->dtype() != loco::DataType::FLOAT32)
+ auto bc = dynamic_cast<luci::CircleConst *>(fc->bias());
+ // NOTE bias can be empty as CircleOutputExclude type
+ // NOTE we can only handle bias as FLOAT32 type as of now
+ if (nullptr != bc && bc->dtype() != loco::DataType::FLOAT32)
return false;
auto name = fc->name();
@@ -141,6 +182,9 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
luci::add_origin(matmul, luci::get_origin(fc));
+ auto reshape = create_reshape(fc);
+ reshape->tensor(matmul);
+
auto all_zero = [](const luci::CircleConst *c) {
bool ac = true;
for (uint32_t i = 0; i < c->size<loco::DataType::FLOAT32>() && ac; i++)
@@ -150,12 +194,11 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
return ac;
};
- auto bc = dynamic_cast<luci::CircleConst *>(b);
- if ((nullptr != bc) && !all_zero(bc))
+ if (nullptr != bc && !all_zero(bc))
{
auto bias_add = fc->graph()->nodes()->create<luci::CircleAdd>();
- bias_add->x(matmul);
- bias_add->y(b);
+ bias_add->x(reshape);
+ bias_add->y(bc);
bias_add->name(fc->name() + "/bias_add");
bias_add->dtype(fc->dtype());
add_origin(bias_add, get_origin(fc));
@@ -164,7 +207,8 @@ bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
}
else
{
- auto n = fromActivation(matmul, fc->fusedActivationFunction());
+ // NOTE bias doesn't exist or bias is all zero
+ auto n = fromActivation(reshape, fc->fusedActivationFunction());
add_origin(n, luci::get_origin(fc));
n->name(fc->name() + "fusedActivation");
n->dtype(fc->dtype());
diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
index 7606a6125..93024f3f7 100644
--- a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
+++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
@@ -159,8 +159,8 @@ TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
auto ret = pass.run(g.g());
EXPECT_EQ(true, ret);
- auto mm = dynamic_cast<luci::CircleBatchMatMul *>(g.output()->from());
- EXPECT_NE(nullptr, mm);
+ auto res = dynamic_cast<luci::CircleReshape *>(g.output()->from());
+ EXPECT_NE(nullptr, res);
}
TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
diff --git a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
index 1e8f681c8..f61882796 100644
--- a/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpMatMulPass.cpp
@@ -153,7 +153,6 @@ bool resolve_matmul(luci::CircleCustom *cop)
}
auto empty_bias = graph->nodes()->create<luci::CircleOutputExclude>();
- empty_bias->dtype(loco::DataType::FLOAT32); // Needed for type inference
auto fc_node = graph->nodes()->create<luci::CircleFullyConnected>();
fc_node->input(lhs);
diff --git a/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp b/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp
index f37f27742..7c038d56d 100644
--- a/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpMaxPoolWithArgmaxPass.cpp
@@ -23,6 +23,7 @@
#include <loco.h>
#include <oops/InternalExn.h>
+#include <limits> // std::numeric_limits
#include <flatbuffers/flexbuffers.h>
diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp
index a65065800..5a09e3930 100644
--- a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp
+++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp
@@ -20,6 +20,8 @@
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/Nodes/CircleConst.h>
+#include <limits> // std::numeric_limits
+
namespace
{
diff --git a/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp
new file mode 100644
index 000000000..b73efafa5
--- /dev/null
+++ b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.cpp
@@ -0,0 +1,672 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
+
+#include "helpers/NodeFiller.h"
+#include "helpers/TypeMapper.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include <string>
+#include <vector>
+
+/**
+ * BEFORE
+ * [CircleNode]
+ * |
+ * [UnidirectionalSequenceLSTM]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * |
+ * [CircleTranspose]
+ * |
+ * [CircleUnpack]
+ * |
+ * [CircleUnpackOut]
+ * |
+ * (Unrolled sub network)
+ * |
+ * [CirclePack]
+ * | |
+ * [CircleTranspose] [UnidirectionalSequenceLSTM]
+ * | |
+ * [CircleNode]
+ *
+ * NOTE for timesteps = 1,
+ * first [CircleTranspose] is not added and
+ * last [CirclePack] + [CircleTranspose] is replaced with [CircleReshape]
+ *
+ * First unrolled sub network is as follows
+ * - [] and 'Circle' are omitted
+ * - all FC has one or two Const for Weight/Bias
+ *
+ * (input)
+ * |
+ * FC
+ * |
+ * Split
+ * +---------+----------+----------+
+ * | | | |
+ * | Logistic Logistic Tanh
+ * | Const | | |
+ * | | | | |
+ * | +-- Mul +-- Mul ---+
+ * | | |
+ * | +---- Add ------+
+ * | |
+ * | +----+----+
+ * | | |
+ * Logistic Tanh |
+ * | | |
+ * +-- Mul ----+ |
+ * | |
+ * (output) (A)
+ *
+ * and following unrolled sub networks are;
+ *
+ * (prev-output) (input)
+ * | |
+ * FC FC
+ * | |
+ * +--- Add --+
+ * Const |
+ * | |
+ * +------ Add
+ * |
+ * Split
+ * |
+ * +---------+----------+----------+
+ * SplitOut SplitOut SplitOut SplitOut
+ * | | | |
+ * | Logistic Logistic Tanh
+ * | (A') | | |
+ * | | | | |
+ * | +--- Mul +-- Mul ---+
+ * | | |
+ * | +---- Add ------+
+ * | |
+ * | +----+----+
+ * | | |
+ * Logistic Tanh |
+ * | | |
+ * +-- Mul ----+ |
+ * | |
+ * (output) (next)
+ *
+ * where (A) and (A') are connected
+ *
+ */
+
+namespace
+{
+
+struct UnrollLSTM
+{
+ luci::CircleConst *transpose_perm(void);
+ luci::CircleTranspose *first_transpose(luci::CircleNode *input);
+ std::vector<luci::CircleUnpackOut *> input_unpacks(luci::CircleNode *input);
+ luci::CircleConst *merged_weights(luci::CircleConst *iw, luci::CircleConst *fw,
+ luci::CircleConst *cw, luci::CircleConst *ow);
+ luci::CircleFullyConnected *create_input_matmul(luci::CircleNode *input);
+ luci::CircleAdd *create_input_matmul(luci::CircleNode *input, luci::CircleMul *mul,
+ uint32_t step);
+ std::vector<luci::CircleSplitOut *> matmul_splits(luci::CircleNode *input, uint32_t step);
+ luci::CircleConst *forget_zero(void);
+ luci::CircleMul *forget_gate_cell(std::vector<luci::CircleSplitOut *> &splits,
+ luci::CircleNode *prev, uint32_t step,
+ luci::CircleNode **retadd);
+ luci::CircleReshape *last_reshape(luci::CircleNode *input);
+ luci::CircleTranspose *last_transpose(std::vector<luci::CircleMul *> &output_muls);
+
+ luci::CircleUnidirectionalSequenceLSTM *_lstm{nullptr};
+ loco::Graph::NodeContext *_nctx{nullptr};
+ std::string _name;
+ uint32_t _batch{0};
+ uint32_t _timesteps{0};
+ uint32_t _units{0}; // output space dim
+};
+
+luci::CircleConst *UnrollLSTM::transpose_perm(void)
+{
+ auto perm = _nctx->create<luci::CircleConst>();
+ perm->dtype(loco::DataType::S32);
+ perm->rank(1);
+ perm->dim(0) = 3;
+ perm->size<loco::DataType::S32>(3);
+ perm->at<loco::DataType::S32>(0) = 1;
+ perm->at<loco::DataType::S32>(1) = 0;
+ perm->at<loco::DataType::S32>(2) = 2;
+ perm->shape_status(luci::ShapeStatus::VALID);
+
+ return perm;
+}
+
+luci::CircleTranspose *UnrollLSTM::first_transpose(luci::CircleNode *input)
+{
+ assert(input != nullptr);
+
+ auto perm = transpose_perm();
+ perm->name(_name + "_perm1");
+ luci::add_origin(perm, luci::get_origin(_lstm));
+
+ auto transpose = _nctx->create<luci::CircleTranspose>();
+ transpose->a(input);
+ transpose->perm(perm);
+ transpose->name(_name + "_trans1");
+ luci::add_origin(transpose, luci::get_origin(_lstm));
+
+ return transpose;
+}
+
+std::vector<luci::CircleUnpackOut *> UnrollLSTM::input_unpacks(luci::CircleNode *input)
+{
+ assert(input != nullptr);
+
+ // NOTE unpack input can be LSTM or Transpose
+ auto unpack = _nctx->create<luci::CircleUnpack>();
+ unpack->num(_timesteps);
+ unpack->axis(0);
+ unpack->value(input);
+ unpack->name(_name + "_unpack");
+ luci::add_origin(unpack, luci::get_origin(_lstm));
+
+ std::vector<luci::CircleUnpackOut *> outs;
+ for (uint32_t idx = 0; idx < _timesteps; ++idx)
+ {
+ auto unpackout = _nctx->create<luci::CircleUnpackOut>();
+ unpackout->input(unpack);
+ unpackout->index(idx);
+ unpackout->name(_name + "_unpackout_" + std::to_string(idx));
+ luci::add_origin(unpackout, luci::get_origin(_lstm));
+ outs.push_back(unpackout);
+ }
+
+ return outs;
+}
+
+luci::CircleConst *UnrollLSTM::merged_weights(luci::CircleConst *iw, luci::CircleConst *fw,
+ luci::CircleConst *cw, luci::CircleConst *ow)
+{
+ assert(iw != nullptr);
+ assert(fw != nullptr);
+ assert(cw != nullptr);
+ assert(ow != nullptr);
+
+ auto iw_rank = iw->rank();
+ assert(iw_rank == fw->rank());
+ assert(iw_rank == cw->rank());
+ assert(iw_rank == ow->rank());
+
+ uint32_t ne_w = 1;
+ for (uint32_t i = 0; i < iw_rank; i++)
+ ne_w *= iw->dim(i).value();
+
+ assert(iw->dtype() == loco::DataType::FLOAT32);
+ assert(fw->dtype() == loco::DataType::FLOAT32);
+ assert(cw->dtype() == loco::DataType::FLOAT32);
+ assert(ow->dtype() == loco::DataType::FLOAT32);
+
+ // merged weights
+ auto mw = _nctx->create<luci::CircleConst>();
+ mw->dtype(iw->dtype());
+ mw->rank(iw_rank);
+ mw->dim(0) = 4u * iw->dim(0).value();
+ for (uint32_t i = 1; i < iw_rank; i++)
+ mw->dim(i) = iw->dim(i);
+ mw->size<loco::DataType::FLOAT32>(4 * ne_w);
+ mw->shape_status(luci::ShapeStatus::VALID);
+ for (uint32_t i = 0; i < ne_w; ++i)
+ {
+ mw->at<loco::DataType::FLOAT32>(i + ne_w * 0) = iw->at<loco::DataType::FLOAT32>(i);
+ mw->at<loco::DataType::FLOAT32>(i + ne_w * 1) = fw->at<loco::DataType::FLOAT32>(i);
+ mw->at<loco::DataType::FLOAT32>(i + ne_w * 2) = cw->at<loco::DataType::FLOAT32>(i);
+ mw->at<loco::DataType::FLOAT32>(i + ne_w * 3) = ow->at<loco::DataType::FLOAT32>(i);
+ }
+ return mw;
+}
+
+luci::CircleFullyConnected *UnrollLSTM::create_input_matmul(luci::CircleNode *input)
+{
+ assert(input != nullptr);
+
+ // weights
+ auto iw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_input_weights());
+ auto fw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_forget_weights());
+ auto cw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_cell_weights());
+ auto ow = loco::must_cast<luci::CircleConst *>(_lstm->input_to_output_weights());
+
+ auto fcw = merged_weights(iw, fw, cw, ow);
+ fcw->name(_name + "_fc_w");
+ luci::add_origin(fcw, luci::get_origin(_lstm));
+
+ // bias
+ auto ib = loco::must_cast<luci::CircleConst *>(_lstm->input_gate_bias());
+ auto fb = loco::must_cast<luci::CircleConst *>(_lstm->forget_gate_bias());
+ auto cb = loco::must_cast<luci::CircleConst *>(_lstm->cell_gate_bias());
+ auto ob = loco::must_cast<luci::CircleConst *>(_lstm->output_gate_bias());
+
+ auto fcb = merged_weights(ib, fb, cb, ob);
+ fcb->name(_name + "_fc_b");
+ luci::add_origin(fcb, luci::get_origin(_lstm));
+
+ auto fc = _nctx->create<luci::CircleFullyConnected>();
+ fc->input(input);
+ fc->weights(fcw);
+ fc->bias(fcb);
+ fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ fc->name(_name + "_fc");
+ luci::add_origin(fc, luci::get_origin(_lstm));
+
+ return fc;
+}
+
+luci::CircleAdd *UnrollLSTM::create_input_matmul(luci::CircleNode *input, luci::CircleMul *mul,
+ uint32_t step)
+{
+ assert(input != nullptr);
+ assert(mul != nullptr);
+ assert(step < _timesteps);
+
+ auto base_name = _name + "_matmul" + std::to_string(step);
+
+ // input weights
+ auto iw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_input_weights());
+ auto fw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_forget_weights());
+ auto cw = loco::must_cast<luci::CircleConst *>(_lstm->input_to_cell_weights());
+ auto ow = loco::must_cast<luci::CircleConst *>(_lstm->input_to_output_weights());
+
+ auto fcw = merged_weights(iw, fw, cw, ow);
+ fcw->name(base_name + "_fc_w");
+ luci::add_origin(fcw, luci::get_origin(_lstm));
+
+ auto fcb = _nctx->create<luci::CircleOutputExclude>();
+
+ auto fc = _nctx->create<luci::CircleFullyConnected>();
+ fc->input(input);
+ fc->weights(fcw);
+ fc->bias(fcb);
+ fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ fc->name(base_name + "_fc");
+ luci::add_origin(fc, luci::get_origin(_lstm));
+
+ // recurrent weights
+ auto ri = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_input_weights());
+ auto rf = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_forget_weights());
+ auto rc = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_cell_weights());
+ auto ro = loco::must_cast<luci::CircleConst *>(_lstm->recurrent_to_output_weights());
+
+ auto fcrw = merged_weights(ri, rf, rc, ro);
+ fcrw->name(base_name + "_fcr_w");
+ luci::add_origin(fcrw, luci::get_origin(_lstm));
+
+ auto fcrb = _nctx->create<luci::CircleOutputExclude>();
+
+ auto fcr = _nctx->create<luci::CircleFullyConnected>();
+ fcr->input(mul);
+ fcr->weights(fcrw);
+ fcr->bias(fcrb);
+ fcr->fusedActivationFunction(luci::FusedActFunc::NONE);
+ fcr->name(base_name + "_fcr");
+ luci::add_origin(fcr, luci::get_origin(_lstm));
+
+ auto add_fc = _nctx->create<luci::CircleAdd>();
+ add_fc->x(fcr);
+ add_fc->y(fc);
+ add_fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ add_fc->name(base_name + "_addfc");
+ luci::add_origin(add_fc, luci::get_origin(_lstm));
+
+ // bias
+ auto ib = loco::must_cast<luci::CircleConst *>(_lstm->input_gate_bias());
+ auto fb = loco::must_cast<luci::CircleConst *>(_lstm->forget_gate_bias());
+ auto cb = loco::must_cast<luci::CircleConst *>(_lstm->cell_gate_bias());
+ auto ob = loco::must_cast<luci::CircleConst *>(_lstm->output_gate_bias());
+
+ auto bias = merged_weights(ib, fb, cb, ob);
+ bias->name(base_name + "_bias");
+
+ auto add_bias = _nctx->create<luci::CircleAdd>();
+ add_bias->x(add_fc);
+ add_bias->y(bias);
+ add_bias->fusedActivationFunction(luci::FusedActFunc::NONE);
+ add_bias->name(base_name + "_addbias");
+ luci::add_origin(add_bias, luci::get_origin(_lstm));
+
+ return add_bias;
+}
+
+std::vector<luci::CircleSplitOut *> UnrollLSTM::matmul_splits(luci::CircleNode *input,
+ uint32_t step)
+{
+ assert(input != nullptr);
+ assert(step < _timesteps);
+
+ std::string split_name = _name + "_sp" + std::to_string(step);
+
+ auto split_dim = _nctx->create<luci::CircleConst>();
+ split_dim->dtype(loco::DataType::S32);
+ split_dim->rank(1);
+ split_dim->dim(0) = 1;
+ split_dim->size<loco::DataType::S32>(1);
+ split_dim->at<loco::DataType::S32>(0) = 1;
+ split_dim->shape_status(luci::ShapeStatus::VALID);
+ split_dim->name(split_name + "_dim");
+ luci::add_origin(split_dim, luci::get_origin(_lstm));
+
+ auto split = _nctx->create<luci::CircleSplit>();
+ split->num_split(4);
+ split->split_dim(split_dim);
+ split->input(input);
+ split->name(split_name);
+ luci::add_origin(split, luci::get_origin(_lstm));
+
+ auto split_o0 = _nctx->create<luci::CircleSplitOut>();
+ split_o0->input(split);
+ split_o0->index(0);
+ split_o0->name(split_name + "_spo0");
+ luci::add_origin(split_o0, luci::get_origin(_lstm));
+
+ auto split_o1 = _nctx->create<luci::CircleSplitOut>();
+ split_o1->input(split);
+ split_o1->index(1);
+ split_o1->name(split_name + "_spo1");
+ luci::add_origin(split_o1, luci::get_origin(_lstm));
+
+ auto split_o2 = _nctx->create<luci::CircleSplitOut>();
+ split_o2->input(split);
+ split_o2->index(2);
+ split_o2->name(split_name + "_spo2");
+ luci::add_origin(split_o2, luci::get_origin(_lstm));
+
+ auto split_o3 = _nctx->create<luci::CircleSplitOut>();
+ split_o3->input(split);
+ split_o3->index(3);
+ split_o3->name(split_name + "_spo3");
+ luci::add_origin(split_o3, luci::get_origin(_lstm));
+
+ std::vector<luci::CircleSplitOut *> outs;
+ outs.push_back(split_o0);
+ outs.push_back(split_o1);
+ outs.push_back(split_o2);
+ outs.push_back(split_o3);
+ return outs;
+}
+
+luci::CircleConst *UnrollLSTM::forget_zero(void)
+{
+ uint32_t amount = _batch * _units;
+
+ auto zero = _nctx->create<luci::CircleConst>();
+ zero->dtype(loco::DataType::FLOAT32);
+ zero->rank(2);
+ zero->dim(0) = _batch;
+ zero->dim(1) = _units;
+ zero->size<loco::DataType::FLOAT32>(amount);
+ for (uint32_t idx = 0; idx < amount; ++idx)
+ zero->at<loco::DataType::FLOAT32>(idx) = 0.0;
+ zero->shape_status(luci::ShapeStatus::VALID);
+ zero->name(_name + "_zero");
+ luci::add_origin(zero, luci::get_origin(_lstm));
+ return zero;
+}
+
+luci::CircleMul *UnrollLSTM::forget_gate_cell(std::vector<luci::CircleSplitOut *> &splits,
+ luci::CircleNode *prev, uint32_t step,
+ luci::CircleNode **retadd)
+{
+ assert(splits.size() > 0);
+ assert(prev != nullptr);
+ assert(step < _timesteps);
+
+ std::string net_name = _name + "_net" + std::to_string(step);
+
+ auto split_0 = splits[0]; // input-input : Logistic - Mul(c) - Add - Tanh - Mul
+ auto split_1 = splits[1]; // input-forget : Logistic - Mul(p) - Add - Tanh - Mul
+ auto split_2 = splits[2]; // input-cell : Tanh - Mul(c) - Add - Tanh - Mul
+ auto split_3 = splits[3]; // input-output : Logistic - Mul
+
+ auto logis_0 = _nctx->create<luci::CircleLogistic>();
+ logis_0->x(split_0);
+ logis_0->name(net_name + "_log0");
+ luci::add_origin(logis_0, luci::get_origin(_lstm));
+
+ auto logis_1 = _nctx->create<luci::CircleLogistic>();
+ logis_1->x(split_1);
+ logis_1->name(net_name + "_log1");
+ luci::add_origin(logis_1, luci::get_origin(_lstm));
+
+ auto tanh_2 = _nctx->create<luci::CircleTanh>();
+ tanh_2->x(split_2);
+ tanh_2->name(net_name + "_tanh2");
+ luci::add_origin(tanh_2, luci::get_origin(_lstm));
+
+ auto logis_3 = _nctx->create<luci::CircleLogistic>();
+ logis_3->x(split_3);
+ logis_3->name(net_name + "_log3");
+ luci::add_origin(logis_3, luci::get_origin(_lstm));
+
+ auto mul_c = _nctx->create<luci::CircleMul>();
+ mul_c->x(logis_0);
+ mul_c->y(tanh_2);
+ mul_c->fusedActivationFunction(luci::FusedActFunc::NONE);
+ mul_c->name(net_name + "_mul1");
+ luci::add_origin(mul_c, luci::get_origin(_lstm));
+
+ auto mul_p = _nctx->create<luci::CircleMul>();
+ mul_p->x(logis_1);
+ mul_p->y(prev);
+ mul_p->fusedActivationFunction(luci::FusedActFunc::NONE);
+ mul_p->name(net_name + "_mul2");
+ luci::add_origin(mul_p, luci::get_origin(_lstm));
+
+ auto add_cp = _nctx->create<luci::CircleAdd>();
+ add_cp->x(mul_c);
+ add_cp->y(mul_p);
+ add_cp->fusedActivationFunction(luci::FusedActFunc::NONE);
+ add_cp->name(net_name + "_add1");
+ luci::add_origin(add_cp, luci::get_origin(_lstm));
+
+ if (retadd != nullptr)
+ *retadd = add_cp;
+
+ auto tanh_cp = _nctx->create<luci::CircleTanh>();
+ tanh_cp->x(add_cp);
+ tanh_cp->name(net_name + "_tanh3");
+ luci::add_origin(tanh_cp, luci::get_origin(_lstm));
+
+ auto mul_out = _nctx->create<luci::CircleMul>();
+ mul_out->x(logis_3);
+ mul_out->y(tanh_cp);
+ mul_out->fusedActivationFunction(luci::FusedActFunc::NONE);
+ mul_out->name(net_name + "_mul3");
+ luci::add_origin(mul_out, luci::get_origin(_lstm));
+
+ return mul_out;
+}
+
+luci::CircleReshape *UnrollLSTM::last_reshape(luci::CircleNode *input)
+{
+ assert(input != nullptr);
+
+ auto reshape_s = _nctx->create<luci::CircleConst>();
+ reshape_s->dtype(loco::DataType::S32);
+ reshape_s->rank(1);
+ reshape_s->dim(0) = 3;
+ reshape_s->size<loco::DataType::S32>(3);
+ reshape_s->at<loco::DataType::S32>(0) = _batch;
+ reshape_s->at<loco::DataType::S32>(1) = _timesteps;
+ reshape_s->at<loco::DataType::S32>(2) = _units;
+ reshape_s->shape_status(luci::ShapeStatus::VALID);
+ reshape_s->name(_name + "_reshape_s");
+ luci::add_origin(reshape_s, luci::get_origin(_lstm));
+
+ auto reshape = _nctx->create<luci::CircleReshape>();
+ reshape->tensor(input);
+ reshape->shape(reshape_s);
+ reshape->newShape()->rank(3);
+ reshape->newShape()->dim(0) = _batch;
+ reshape->newShape()->dim(1) = _timesteps;
+ reshape->newShape()->dim(2) = _units;
+ reshape->name(_name + "_reshape");
+ luci::add_origin(reshape, luci::get_origin(_lstm));
+
+ return reshape;
+}
+
+luci::CircleTranspose *UnrollLSTM::last_transpose(std::vector<luci::CircleMul *> &output_muls)
+{
+ assert(output_muls.size() == _timesteps);
+
+ auto pack = _nctx->create<luci::CirclePack>(_timesteps);
+ pack->axis(0);
+ for (uint32_t idx = 0; idx < _timesteps; ++idx)
+ pack->values(idx, output_muls[idx]);
+ pack->name(_name + "_pack");
+ luci::add_origin(pack, luci::get_origin(_lstm));
+
+ auto perm = transpose_perm();
+ perm->name(_name + "_perm2");
+ luci::add_origin(perm, luci::get_origin(_lstm));
+
+ auto transpose = _nctx->create<luci::CircleTranspose>();
+ transpose->a(pack);
+ transpose->perm(perm);
+ transpose->name(_name + "_trans2");
+ luci::add_origin(transpose, luci::get_origin(_lstm));
+
+ return transpose;
+}
+
+bool unroll_lstm(luci::CircleUnidirectionalSequenceLSTM *lstm)
+{
+ // NOTE shape of input of lstm is interpreted as [batch, timesteps, feature]
+ // shape of output of lstm is interpreted as [batch, timesteps, units]
+ // TODO add more conditions to check LSTM
+ assert(lstm != nullptr);
+ assert(lstm->rank() == 3); // use assert to find out when this happens
+ if (lstm->rank() != 3)
+ return false;
+ if (!(lstm->dim(0).known() and lstm->dim(1).known() and lstm->dim(2).known()))
+ return false;
+
+ UnrollLSTM ulstm;
+ ulstm._lstm = lstm;
+ ulstm._nctx = lstm->graph()->nodes();
+ ulstm._name = lstm->name();
+ ulstm._batch = lstm->dim(0).value();
+ ulstm._timesteps = lstm->dim(1).value();
+ ulstm._units = lstm->dim(2).value(); // output space dim
+
+ luci::CircleNode *input = loco::must_cast<luci::CircleNode *>(lstm->input());
+ assert(input->rank() == 3); // use assert to find out when this happens
+ if (input->rank() != 3)
+ return false;
+ assert(input->dim(0).value() == ulstm._batch);
+ assert(input->dim(1).value() == ulstm._timesteps);
+
+ if (ulstm._timesteps > 1)
+ {
+ // Transpose to switch batch <-> timesteps
+ // NOTE TF uses Reshape when batch is 1 but as there is Transpose->Reshape
+ // Pass, we can just use Transpose for both cases
+ auto transpose = ulstm.first_transpose(input);
+ input = transpose;
+ }
+
+ auto unpacks = ulstm.input_unpacks(input);
+ assert(unpacks.size() == ulstm._timesteps);
+ uint32_t step = 0;
+ auto unpackout = unpacks[step];
+
+ // First FC
+ auto fc_1 = ulstm.create_input_matmul(unpackout);
+ assert(fc_1 != nullptr);
+ auto splits = ulstm.matmul_splits(fc_1, step);
+ assert(splits.size() == 4);
+
+ luci::CircleNode *prev = nullptr; // prev step CircleAdd
+ luci::CircleNode *this_add = nullptr;
+
+ prev = ulstm.forget_zero(); // provide all zero constant for first step
+
+ std::vector<luci::CircleMul *> output_muls;
+ auto mul_gc = ulstm.forget_gate_cell(splits, prev, step, &this_add);
+ assert(mul_gc != nullptr);
+ assert(this_add != nullptr);
+ // gather all Muls for last Pack
+ output_muls.push_back(mul_gc);
+
+ for (step = 1; step < ulstm._timesteps; ++step)
+ {
+ auto unpackout = unpacks[step];
+ auto add_n = ulstm.create_input_matmul(unpackout, mul_gc, step);
+
+ auto splits = ulstm.matmul_splits(add_n, step);
+ assert(splits.size() == 4);
+
+ prev = this_add;
+ mul_gc = ulstm.forget_gate_cell(splits, prev, step, &this_add);
+ assert(mul_gc != nullptr);
+ assert(this_add != nullptr);
+
+ output_muls.push_back(mul_gc);
+ }
+ assert(output_muls.size() == ulstm._timesteps);
+
+ if (ulstm._timesteps == 1)
+ {
+ // Reshape for single step
+ auto reshape = ulstm.last_reshape(mul_gc);
+ loco::replace(lstm).with(reshape);
+ }
+ else
+ {
+ // Pack + Transpose for two or more steps
+ auto transpose = ulstm.last_transpose(output_muls);
+ loco::replace(lstm).with(transpose);
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool UnrollUnidirectionalSequenceLSTMPass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto lstm = dynamic_cast<luci::CircleUnidirectionalSequenceLSTM *>(node))
+ {
+ if (unroll_lstm(lstm))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp
new file mode 100644
index 000000000..3f273cbd3
--- /dev/null
+++ b/compiler/luci/pass/src/UnrollUnidirectionalSequenceLSTMPass.test.cpp
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/UnrollUnidirectionalSequenceLSTMPass.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/Nodes/CircleUnidirectionalSequenceLSTM.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class UniSeqLSTMGraphlet
+{
+public:
+ UniSeqLSTMGraphlet() = default;
+
+ void init(loco::Graph *g, const ShapeU32 oshape)
+ {
+ _uslstm = g->nodes()->create<luci::CircleUnidirectionalSequenceLSTM>();
+ _uslstm->dtype(loco::DataType::FLOAT32);
+ _uslstm->shape(oshape);
+ _uslstm->name("uslstm");
+
+ _uslstm->fusedActivationFunction(luci::FusedActFunc::TANH);
+ _uslstm->cell_clip(0.0);
+ _uslstm->proj_clip(0.0);
+ _uslstm->time_major(false);
+ _uslstm->asymmetric_quantize_inputs(false);
+
+ _iw = weight_1x1(g);
+ _rw = weight_1x1(g);
+ _gb = weight_1(g);
+ _ex = g->nodes()->create<luci::CircleOutputExclude>();
+ }
+
+protected:
+ luci::CircleConst *weight_1x1(loco::Graph *g)
+ {
+ auto w = g->nodes()->create<luci::CircleConst>();
+ w->dtype(loco::DataType::FLOAT32);
+ w->rank(2);
+ w->dim(0) = 1;
+ w->dim(1) = 1;
+ w->size<loco::DataType::FLOAT32>(1);
+ w->at<loco::DataType::FLOAT32>(0) = 1.0;
+ w->shape_status(luci::ShapeStatus::VALID);
+ return w;
+ }
+
+ luci::CircleConst *weight_1(loco::Graph *g)
+ {
+ auto w = g->nodes()->create<luci::CircleConst>();
+ w->dtype(loco::DataType::FLOAT32);
+ w->rank(1);
+ w->dim(0) = 1;
+ w->size<loco::DataType::FLOAT32>(1);
+ w->at<loco::DataType::FLOAT32>(0) = 1.0;
+ w->shape_status(luci::ShapeStatus::VALID);
+ return w;
+ }
+
+protected:
+ luci::CircleUnidirectionalSequenceLSTM *_uslstm = nullptr;
+ luci::CircleConst *_iw = nullptr;
+ luci::CircleConst *_rw = nullptr;
+ luci::CircleConst *_gb = nullptr;
+ luci::CircleOutputExclude *_ex = nullptr;
+};
+
+class UnrollUniSeqLSTMPassTestGraph : public TestIOGraph, public UniSeqLSTMGraphlet
+{
+public:
+ UnrollUniSeqLSTMPassTestGraph() = default;
+
+ void init(const ShapeU32 ishape, const ShapeU32 oshape)
+ {
+ TestIOGraph::init(ishape, oshape);
+ UniSeqLSTMGraphlet::init(g(), oshape);
+
+ auto inode = input();
+ _uslstm->input(inode);
+
+ _uslstm->input_to_input_weights(_iw);
+ _uslstm->input_to_forget_weights(_iw);
+ _uslstm->input_to_cell_weights(_iw);
+ _uslstm->input_to_output_weights(_iw);
+
+ _uslstm->recurrent_to_input_weights(_rw);
+ _uslstm->recurrent_to_forget_weights(_rw);
+ _uslstm->recurrent_to_cell_weights(_rw);
+ _uslstm->recurrent_to_output_weights(_rw);
+
+ _uslstm->cell_to_input_weights(_ex);
+ _uslstm->cell_to_forget_weights(_ex);
+ _uslstm->cell_to_output_weights(_ex);
+
+ _uslstm->input_gate_bias(_gb);
+ _uslstm->forget_gate_bias(_gb);
+ _uslstm->cell_gate_bias(_gb);
+ _uslstm->output_gate_bias(_gb);
+
+ _uslstm->projection_weights(_ex);
+ _uslstm->projection_bias(_ex);
+
+ _uslstm->output_state(_ex);
+ _uslstm->cell_state(_ex);
+
+ _uslstm->input_layer_norm_coefficients(_ex);
+ _uslstm->forget_layer_norm_coefficients(_ex);
+ _uslstm->cell_layer_norm_coefficients(_ex);
+ _uslstm->output_layer_norm_coefficients(_ex);
+
+ output()->from(_uslstm);
+ }
+};
+
+} // namespace
+
+namespace
+{
+
+using namespace luci::test;
+
+// FakeQuantGraphlet is for a simple negative test
+class FakeQuantGraphlet
+{
+public:
+ FakeQuantGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _fq = g->nodes()->create<luci::CircleFakeQuant>();
+ _fq->name("fq");
+ }
+
+protected:
+ luci::CircleFakeQuant *_fq = nullptr;
+};
+
+class FakeQuantGraph : public TestIOGraph, public FakeQuantGraphlet
+{
+public:
+ FakeQuantGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1, 1, 1}, {1, 1, 1});
+ FakeQuantGraphlet::init(g());
+
+ _fq->inputs(input());
+
+ output()->from(_fq);
+ }
+};
+
+} // namespace
+
+TEST(UnrollUnidirectionalSequenceLSTMPassTestName, name)
+{
+ luci::UnrollUnidirectionalSequenceLSTMPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+class UnrollUnidirectionalSequenceLSTMPassTest : public ::testing::Test
+{
+public:
+ UnrollUniSeqLSTMPassTestGraph g;
+ luci::UnrollUnidirectionalSequenceLSTMPass pass;
+};
+
+TEST_F(UnrollUnidirectionalSequenceLSTMPassTest, simple_run)
+{
+ g.init({1, 1, 1}, {1, 1, 1});
+
+ EXPECT_TRUE(pass.run(g.g()));
+}
+
+class UnrollUnidirectionalSequenceLSTMPassTestN : public ::testing::Test
+{
+public:
+ FakeQuantGraph g;
+ luci::UnrollUnidirectionalSequenceLSTMPass pass;
+};
+
+TEST_F(UnrollUnidirectionalSequenceLSTMPassTestN, simple_run_NEG)
+{
+ g.init();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
index 408e6b8d9..6bf7ff698 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
@@ -133,6 +133,10 @@ private:
bool visit(const luci::CircleAdd *node)
{
+ // Skip granularity check for indices
+ if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
+ return true;
+
RETURN_FALSE_UNLESS(is_lwq(node));
RETURN_FALSE_UNLESS(is_lwq(node->x()));
RETURN_FALSE_UNLESS(is_lwq(node->y()));
@@ -176,6 +180,10 @@ private:
bool visit(const luci::CircleMul *node)
{
+ // Skip granularity check for indices
+ if (node->dtype() == loco::DataType::S32 or node->dtype() == loco::DataType::S64)
+ return true;
+
RETURN_FALSE_UNLESS(is_lwq(node));
RETURN_FALSE_UNLESS(is_lwq(node->x()));
RETURN_FALSE_UNLESS(is_lwq(node->y()));
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
index cf86acabe..3ce32555b 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
@@ -47,6 +47,10 @@ namespace luci
template <loco::DataType Qtype, loco::DataType Btype>
bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleAdd *node)
{
+ // Allow add of indices
+ if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
+ return true;
+
return group_has_type(node, Qtype);
}
@@ -240,6 +244,10 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMirrorPa
template <loco::DataType Qtype, loco::DataType Btype>
bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleMul *node)
{
+ // Allow mul of indices
+ if (group_has_type(node, loco::DataType::S32) or group_has_type(node, loco::DataType::S64))
+ return true;
+
return group_has_type(node, Qtype);
}
diff --git a/compiler/luci/pass/src/helpers/NodeFiller.h b/compiler/luci/pass/src/helpers/NodeFiller.h
index b80f085b0..10113e8dd 100644
--- a/compiler/luci/pass/src/helpers/NodeFiller.h
+++ b/compiler/luci/pass/src/helpers/NodeFiller.h
@@ -57,6 +57,12 @@ public:
*/
template <class COMM_NODE> bool with_commutative_args_of(const COMM_NODE *node);
+ /**
+ * @note Similar as with_commutative_args_of but not commutative.
+ * _arg_1 and _arg_2 must match that of ARG_TYPE_1 and ARG_TYPE_2.
+ */
+ template <class COMM_NODE> bool with_args_of(const COMM_NODE *node);
+
private:
ARG_TYPE_1 **_arg_1;
ARG_TYPE_2 **_arg_2;
@@ -101,4 +107,24 @@ bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_commutative_args_of(const COMM_NOD
return false;
}
+template <class ARG_TYPE_1, class ARG_TYPE_2>
+template <class COMM_NODE>
+bool NodeFiller<ARG_TYPE_1, ARG_TYPE_2>::with_args_of(const COMM_NODE *node)
+{
+ // X == ARG_TYPE_1 / Y == ARG_TYPE_2
+ {
+ auto x = dynamic_cast<ARG_TYPE_1 *>(node->x());
+ auto y = dynamic_cast<ARG_TYPE_2 *>(node->y());
+
+ if (x && y)
+ {
+ *_arg_1 = x;
+ *_arg_2 = y;
+ return true;
+ }
+ }
+
+ return false;
+}
+
} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.h b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h
index fcd9bbcd0..e01430489 100644
--- a/compiler/luci/pass/src/helpers/SparsityFormatConverter.h
+++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h
@@ -18,6 +18,7 @@
#ifndef __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__
#define __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__
+#include <cstddef>
#include <cstdint>
#include <vector>
diff --git a/compiler/luci/pass/src/helpers/Strings.cpp b/compiler/luci/pass/src/helpers/Strings.cpp
index d020f6ddc..2628726c1 100644
--- a/compiler/luci/pass/src/helpers/Strings.cpp
+++ b/compiler/luci/pass/src/helpers/Strings.cpp
@@ -77,6 +77,15 @@ loco::DataType str_to_dtype(const std::string &str)
return loco::DataType::Unknown;
}
+// Convert string to a vector of loco::DataType
+std::vector<loco::DataType> str_vec_to_dtype_vec(std::vector<std::string> &vec)
+{
+ std::vector<loco::DataType> res;
+ std::transform(vec.begin(), vec.end(), std::back_inserter(res),
+ [](std::string s) -> loco::DataType { return str_to_dtype(to_lower_case(s)); });
+ return res;
+}
+
QuantizationGranularity str_to_granularity(const std::string &str)
{
if (to_lower_case(str).compare("layer") == 0)
diff --git a/compiler/luci/pass/src/helpers/Strings.h b/compiler/luci/pass/src/helpers/Strings.h
index 0e7818517..485f37948 100644
--- a/compiler/luci/pass/src/helpers/Strings.h
+++ b/compiler/luci/pass/src/helpers/Strings.h
@@ -36,6 +36,8 @@ std::string to_lower_case(std::string);
loco::DataType str_to_dtype(const std::string &);
+std::vector<loco::DataType> str_vec_to_dtype_vec(std::vector<std::string> &);
+
QuantizationGranularity str_to_granularity(const std::string &);
} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/Strings.test.cpp b/compiler/luci/pass/src/helpers/Strings.test.cpp
index d77b65038..6d854ad4f 100644
--- a/compiler/luci/pass/src/helpers/Strings.test.cpp
+++ b/compiler/luci/pass/src/helpers/Strings.test.cpp
@@ -48,3 +48,26 @@ TEST(StringsTest, str_to_granularity)
EXPECT_THROW(luci::str_to_granularity("foo"), std::runtime_error);
}
+
+TEST(StringsTest, str_vec_to_dtype_vec)
+{
+ std::vector<std::string> input1 = {"uint8", "int16", "float32"};
+ auto result1 = luci::str_vec_to_dtype_vec(input1);
+ ASSERT_EQ(3, result1.size());
+ ASSERT_EQ(loco::DataType::U8, result1[0]);
+ ASSERT_EQ(loco::DataType::S16, result1[1]);
+ ASSERT_EQ(loco::DataType::FLOAT32, result1[2]);
+
+ std::vector<std::string> input2 = {"uint8", "int16", "float32", ""};
+ auto result2 = luci::str_vec_to_dtype_vec(input2);
+ ASSERT_EQ(4, result2.size());
+ ASSERT_EQ(loco::DataType::U8, result2[0]);
+ ASSERT_EQ(loco::DataType::S16, result2[1]);
+ ASSERT_EQ(loco::DataType::FLOAT32, result2[2]);
+ ASSERT_EQ(loco::DataType::Unknown, result2[3]);
+
+ std::vector<std::string> input3 = {"uint8"};
+ auto result3 = luci::str_vec_to_dtype_vec(input3);
+ ASSERT_EQ(1, result3.size());
+ ASSERT_EQ(loco::DataType::U8, result3[0]);
+}
diff --git a/compiler/luci/pass/src/test/TestIOGraph.h b/compiler/luci/pass/src/test/TestIOGraph.h
deleted file mode 100644
index b1fc41f90..000000000
--- a/compiler/luci/pass/src/test/TestIOGraph.h
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef __LUCI_PASS_TEST_IO_GRAPH_H__
-#define __LUCI_PASS_TEST_IO_GRAPH_H__
-
-#include "TestShape.h"
-
-#include <luci/IR/CircleNodes.h>
-
-namespace luci
-{
-namespace test
-{
-
-/**
- * @brief Graphlet with Inputs and loco::Graph for multiple inputs
- * @note Every Graph will have Input(s) and Output(s)
- * We put loco::Graph only in IsGraphlet not to declare separate
- * class for loco::Graph
- */
-template <unsigned N> class TestIsGraphlet
-{
-public:
- TestIsGraphlet()
- {
- for (uint32_t n = 0; n < N; ++n)
- {
- _graph_inputs[n] = nullptr;
- _inputs[n] = nullptr;
- }
- }
-
-public:
- virtual void init(loco::Graph *g, const ShapeU32 shape_in)
- {
- for (uint32_t n = 0; n < N; ++n)
- {
- _graph_inputs[n] = g->inputs()->create();
-
- _inputs[n] = g->nodes()->create<luci::CircleInput>();
- _inputs[n]->shape(shape_in);
- _inputs[n]->shape_status(luci::ShapeStatus::VALID);
- _inputs[n]->dtype(loco::DataType::FLOAT32);
- _inputs[n]->name("input_" + std::to_string(n));
-
- _inputs[n]->index(_graph_inputs[n]->index());
-
- auto input_shape = std::make_unique<loco::TensorShape>();
- set_shape_vector(input_shape.get(), shape_in);
- _graph_inputs[n]->shape(std::move(input_shape));
- _graph_inputs[n]->dtype(loco::DataType::FLOAT32);
- }
- }
-
-public:
- loco::Graph *g(void) { return &_g; }
- luci::CircleInput *input(int idx) { return _inputs[idx]; }
-
-protected:
- loco::Graph _g;
- std::array<loco::GraphInput *, N> _graph_inputs;
- std::array<luci::CircleInput *, N> _inputs;
-};
-
-/**
- * @brief Graphlet with one Input
- */
-class TestIGraphlet : public TestIsGraphlet<1>
-{
-public:
- luci::CircleInput *input() { return _inputs[0]; }
-};
-
-/**
- * @brief Graphlet with Outputs for multiple outputs
- */
-template <unsigned N> class TestOsGraphlet
-{
-public:
- TestOsGraphlet()
- {
- for (uint32_t n = 0; n < N; ++n)
- {
- _graph_outputs[n] = nullptr;
- _outputs[n] = nullptr;
- }
- }
-
-public:
- virtual void init(loco::Graph *g, const ShapeU32 shape_out)
- {
- for (uint32_t n = 0; n < N; ++n)
- {
- _graph_outputs[n] = g->outputs()->create();
-
- _outputs[n] = g->nodes()->create<luci::CircleOutput>();
- _outputs[n]->shape(shape_out);
- _outputs[n]->shape_status(luci::ShapeStatus::VALID);
- _outputs[n]->dtype(loco::DataType::FLOAT32);
- _outputs[n]->name("output_" + std::to_string(n));
-
- _outputs[n]->index(_graph_outputs[n]->index());
-
- auto output_shape = std::make_unique<loco::TensorShape>();
- set_shape_vector(output_shape.get(), shape_out);
- _graph_outputs[n]->shape(std::move(output_shape));
- _graph_outputs[n]->dtype(loco::DataType::FLOAT32);
- }
- }
-
-public:
- luci::CircleOutput *output(int idx) { return _outputs[idx]; }
-
-protected:
- std::array<loco::GraphOutput *, N> _graph_outputs;
- std::array<luci::CircleOutput *, N> _outputs;
-};
-
-/**
- * @brief Graphlet with one Output
- */
-class TestOGraphlet : public TestOsGraphlet<1>
-{
-public:
- luci::CircleOutput *output() { return _outputs[0]; }
-};
-
-/**
- * @brief Graph with Input and Output
- */
-class TestIOGraph : public TestIGraphlet, public TestOGraphlet
-{
-public:
- TestIOGraph() = default;
-
-public:
- virtual void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
- {
- TestIsGraphlet<1>::init(g(), shape_in);
- TestOsGraphlet<1>::init(g(), shape_out);
- }
-};
-
-} // namespace test
-} // namespace luci
-
-#endif // __LUCI_PASS_TEST_IO_GRAPH_H__