author    Chunseok Lee <chunseok.lee@samsung.com> 2022-09-07 19:04:21 +0900
committer Chunseok Lee <chunseok.lee@samsung.com> 2022-09-07 19:04:21 +0900
commit    c690d52bdd137ed6a17353aa7af35e8141ece77b (patch)
tree      dbb7dd99133132dfbffcb8c9e9af4f1ffc2f4808 /compiler/luci/pass
parent    3ad689f0803519e343c36d5700646e86059df961 (diff)
Diffstat (limited to 'compiler/luci/pass')
-rw-r--r-- compiler/luci/pass/CMakeLists.txt | 8
-rw-r--r-- compiler/luci/pass/include/luci/CircleOptimizer.h | 3
-rw-r--r-- compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h | 38
-rw-r--r-- compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h | 39
-rw-r--r-- compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h | 37
-rw-r--r-- compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h | 37
-rw-r--r-- compiler/luci/pass/src/CircleOptimizer.cpp | 39
-rw-r--r-- compiler/luci/pass/src/CircleQuantizer.cpp | 7
-rw-r--r-- compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp | 329
-rw-r--r-- compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp | 525
-rw-r--r-- compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp | 57
-rw-r--r-- compiler/luci/pass/src/FoldDensifyPass.cpp | 180
-rw-r--r-- compiler/luci/pass/src/FoldDensifyPass.test.cpp | 158
-rw-r--r-- compiler/luci/pass/src/FoldDequantizePass.cpp | 96
-rw-r--r-- compiler/luci/pass/src/FoldDequantizePass.test.cpp | 377
-rw-r--r-- compiler/luci/pass/src/FoldSparseToDensePass.cpp | 2
-rw-r--r-- compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp | 49
-rw-r--r-- compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp | 86
-rw-r--r-- compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp | 6
-rw-r--r-- compiler/luci/pass/src/FuseAddWithTConvPass.cpp | 20
-rw-r--r-- compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp | 53
-rw-r--r-- compiler/luci/pass/src/FuseInstanceNormPass.cpp | 186
-rw-r--r-- compiler/luci/pass/src/PropagateQParamBackwardPass.cpp | 1
-rw-r--r-- compiler/luci/pass/src/PropagateQParamForwardPass.cpp | 9
-rw-r--r-- compiler/luci/pass/src/QuantizationUtils.cpp | 126
-rw-r--r-- compiler/luci/pass/src/QuantizationUtils.h | 16
-rw-r--r-- compiler/luci/pass/src/QuantizeActivation.cpp | 11
-rw-r--r-- compiler/luci/pass/src/QuantizeBias.cpp | 14
-rw-r--r-- compiler/luci/pass/src/QuantizeBias.test.cpp | 189
-rw-r--r-- compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp | 7
-rw-r--r-- compiler/luci/pass/src/QuantizeWeights.cpp | 1
-rw-r--r-- compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp | 91
-rw-r--r-- compiler/luci/pass/src/QuantizedModelVerifier.test.cpp | 53
-rw-r--r-- compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp | 80
-rw-r--r-- compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp | 114
-rw-r--r-- compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp | 172
-rw-r--r-- compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp | 123
-rw-r--r-- compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp | 196
-rw-r--r-- compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp | 189
-rw-r--r-- compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp | 172
-rw-r--r-- compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp | 175
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h | 7
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedNodeType.cpp | 9
-rw-r--r-- compiler/luci/pass/src/VerifyQuantizedNodeType.h | 1
-rw-r--r-- compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp | 312
-rw-r--r-- compiler/luci/pass/src/helpers/SparsityFormatConverter.h | 129
47 files changed, 4300 insertions(+), 266 deletions(-)
diff --git a/compiler/luci/pass/CMakeLists.txt b/compiler/luci/pass/CMakeLists.txt
index 5237c6d3f..d9d004db9 100644
--- a/compiler/luci/pass/CMakeLists.txt
+++ b/compiler/luci/pass/CMakeLists.txt
@@ -1,9 +1,16 @@
nnas_find_package(FlatBuffers EXACT 2.0 QUIET)
+nnas_find_package(Fp16Source QUIET)
+
if(NOT FlatBuffers_FOUND)
message(STATUS "FlatBuffers NOT FOUND")
return()
endif(NOT FlatBuffers_FOUND)
+if(NOT Fp16Source_FOUND)
+ message(STATUS "Fp16Source NOT FOUND")
+ return()
+endif(NOT Fp16Source_FOUND)
+
file(GLOB_RECURSE SOURCES "src/*.cpp")
file(GLOB_RECURSE TESTS "src/*.test.cpp")
list(REMOVE_ITEM SOURCES ${TESTS})
@@ -14,6 +21,7 @@ endif(NOT LUCI_LIBRARY_TYPE)
add_library(luci_pass ${LUCI_LIBRARY_TYPE} ${SOURCES})
target_include_directories(luci_pass PRIVATE src)
+target_include_directories(luci_pass PRIVATE ${Fp16Source_DIR}/include)
target_include_directories(luci_pass PUBLIC include)
target_link_libraries(luci_pass PUBLIC loco)
target_link_libraries(luci_pass PUBLIC logo_core)
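
The new Fp16Source dependency backs the FLOAT16 path of the densify folding added below (FLOAT16 constants are stored as uint16_t and converted through the FP16 headers). A minimal sketch of decoding one half-precision value, assuming the fp16.h header and fp16_ieee_to_fp32_value() of the FP16 library that Fp16Source provides:

    #include <fp16.h> // found via ${Fp16Source_DIR}/include

    #include <cstdint>
    #include <cstdio>

    int main(void)
    {
      uint16_t half = 0x3C00;                       // IEEE fp16 encoding of 1.0
      float single = fp16_ieee_to_fp32_value(half); // widen to fp32
      std::printf("%f\n", single);                  // prints 1.000000
      return 0;
    }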
diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h
index c803898f6..b94822c35 100644
--- a/compiler/luci/pass/include/luci/CircleOptimizer.h
+++ b/compiler/luci/pass/include/luci/CircleOptimizer.h
@@ -47,8 +47,10 @@ public:
ResolveCustomOpBatchMatMul,
ResolveCustomOpMatMul,
ResolveCustomOpMaxPoolWithArgmax,
+ ResolveCustomOpSplitV,
FoldAddV2,
FoldCast,
+ FoldDensify,
FoldDepthwiseConv2D,
FoldDequantize,
FoldGather,
@@ -61,6 +63,7 @@ public:
ShuffleWeightTo16x1Float32,
RemoveRedundantTranspose,
ReplaceMulAddWithDepthwiseConv,
+ ReplaceNonConstFCWithBatchMatMul,
ReplaceSubWithAdd,
SubstitutePackToReshape,
SubstitutePadV2ToPad,
diff --git a/compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h b/compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h
new file mode 100644
index 000000000..8ec81b1d4
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/FoldDensifyPass.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_FOLD_DENSIFY_PASS_H__
+#define __LUCI_FOLD_DENSIFY_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to fold Densify when its input is a sparse constant
+ *
+ */
+struct FoldDensifyPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::FoldDensifyPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_FOLD_DENSIFY_PASS_H__
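
Like every logo::Pass, FoldDensifyPass reports through run() whether it changed the graph, so callers iterate it to a fixed point. A minimal sketch of driving the pass directly, the same idiom the unit tests later in this commit use:

    #include "luci/Pass/FoldDensifyPass.h"

    #include <loco.h>

    void fold_densify_to_fixpoint(loco::Graph *g)
    {
      luci::FoldDensifyPass pass;
      // run() returns true as long as it folded something; repeat until stable
      while (pass.run(g))
        ;
    }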
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h b/compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h
new file mode 100644
index 000000000..2deb75297
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveRedundantDequantizePass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_REDUNDANT_DEQUANTIZE_PASS_H__
+#define __LUCI_REMOVE_REDUNDANT_DEQUANTIZE_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to remove redundant dequantize operations
+ */
+struct RemoveRedundantDequantizePass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveRedundantDequantizePass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_REDUNDANT_DEQUANTIZE_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h
new file mode 100644
index 000000000..19948a31c
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/RemoveUnnecessaryReshapeNetPass.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REMOVE_UNNECESSARY_RESHAPE_NET_PASS_H__
+#define __LUCI_REMOVE_UNNECESSARY_RESHAPE_NET_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to remove unnecessary Reshape nodes.
+ * @details This class will remove unnecessary pre/post-Reshape nodes.
+ * See https://github.com/Samsung/ONE/issues/9600 for more details.
+ */
+struct RemoveUnnecessaryReshapeNetPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::RemoveUnnecessaryReshapeNetPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REMOVE_UNNECESSARY_RESHAPE_NET_PASS_H__
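
The "Reshape net" targeted here is a pre-Reshape/post-Reshape pair whose effects cancel out across the intermediate Ops. A hedged sketch of the intended rewrite; the exact matching conditions live in RemoveUnnecessaryReshapeNetPass.cpp and the linked issue, not in this header:

    BEFORE                         AFTER

    [CircleNode] (shape S)         [CircleNode] (shape S)
         |                              |
    [CircleReshape] (S -> S')      [intermediate Op(s)]
         |
    [intermediate Op(s)]
         |
    [CircleReshape] (S' -> S)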
diff --git a/compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h b/compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h
new file mode 100644
index 000000000..24e16ec49
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_REPLACE_NONCONST_FC_WITH_BATCH_MATMUL_PASS_H__
+#define __LUCI_REPLACE_NONCONST_FC_WITH_BATCH_MATMUL_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to replace "FC with non-const weight" with Batched MatMul
+ */
+struct ReplaceNonConstFCWithBatchMatMulPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ReplaceNonConstFCWithBatchMatMulPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_REPLACE_NONCONST_FC_WITH_BATCH_MATMUL_PASS_H__
diff --git a/compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h b/compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h
new file mode 100644
index 000000000..d4f0147e8
--- /dev/null
+++ b/compiler/luci/pass/include/luci/Pass/ResolveCustomOpSplitVPass.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_RESOLVE_CUSTOM_OP_SPLIT_V_PASS_H__
+#define __LUCI_RESOLVE_CUSTOM_OP_SPLIT_V_PASS_H__
+
+#include <logo/Pass.h>
+
+namespace luci
+{
+
+/**
+ * @brief Class to resolve certain custom ops in a subgraph into the SplitV op of the circle schema.
+ */
+struct ResolveCustomOpSplitVPass final : public logo::Pass
+{
+ const char *name(void) const final { return "luci::ResolveCustomOpSplitVPass"; }
+
+ bool run(loco::Graph *g) final;
+};
+
+} // namespace luci
+
+#endif // __LUCI_RESOLVE_CUSTOM_OP_SPLIT_V_PASS_H__
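
All of the new passes above are exposed through the Algorithm enum extended in CircleOptimizer.h. A sketch of enabling them from client code, assuming the options()->enable() interface that luci::CircleOptimizer already exposes for its other algorithms:

    #include "luci/CircleOptimizer.h"

    void optimize_with_new_passes(loco::Graph *g)
    {
      luci::CircleOptimizer optimizer;
      auto options = optimizer.options();

      // Algorithms introduced by this commit
      options->enable(luci::CircleOptimizer::Options::Algorithm::FoldDensify);
      options->enable(luci::CircleOptimizer::Options::Algorithm::ResolveCustomOpSplitV);
      options->enable(luci::CircleOptimizer::Options::Algorithm::ReplaceNonConstFCWithBatchMatMul);

      optimizer.optimize(g);
    }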
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 6dbb22d7c..74c569d20 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -20,6 +20,7 @@
#include "luci/Pass/ExpandBroadcastConstPass.h"
#include "luci/Pass/FoldAddV2Pass.h"
#include "luci/Pass/FoldCastPass.h"
+#include "luci/Pass/FoldDensifyPass.h"
#include "luci/Pass/FoldDepthwiseConv2DPass.h"
#include "luci/Pass/FoldDequantizePass.h"
#include "luci/Pass/FoldGatherPass.h"
@@ -43,15 +44,18 @@
#include "luci/Pass/RemoveRedundantTransposePass.h"
#include "luci/Pass/RemoveRedundantQuantizePass.h"
#include "luci/Pass/RemoveUnnecessaryReshapePass.h"
+#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
#include "luci/Pass/RemoveUnnecessarySlicePass.h"
#include "luci/Pass/RemoveUnnecessaryStridedSlicePass.h"
#include "luci/Pass/RemoveUnnecessarySplitPass.h"
+#include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
#include "luci/Pass/ReplaceMulAddWithDepthwiseConvPass.h"
#include "luci/Pass/ReplaceSubWithAddPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMaxPoolWithArgmaxPass.h"
+#include "luci/Pass/ResolveCustomOpSplitVPass.h"
#include "luci/Pass/SparsifyTensorPass.h"
#include "luci/Pass/ShuffleWeightTo16x1Float32Pass.h"
#include "luci/Pass/SubstitutePackToReshapePass.h"
@@ -127,7 +131,8 @@ bool OptimizeOptionsImpl::query(Algorithm algo)
return true;
}
-void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output)
+// TODO Make a struct for args
+void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_output, bool fuse_fc)
{
logo::Phase phase;
@@ -135,6 +140,21 @@ void convert_nchw_to_nhwc(loco::Graph *g, bool preserve_input, bool preserve_out
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+ // Resolve custom Ops
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpAddPass>());
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpBatchMatMulPass>());
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpMatMulPass>());
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
+
+ // Fuse FullyConnected with Add
+  // Why do we perform FuseAddWithFullyConnectedPass before ConvertNCHWToNHWCPass?
+  // FullyConnected Op's layout is not changed by ConvertNCHWToNHWCPass, while
+  // Add Op's layout is changed from NCHW to NHWC.
+  // This makes fusing Add and FullyConnected impossible after ConvertNCHWToNHWC.
+ if (fuse_fc)
+ phase.emplace_back(std::make_unique<luci::FuseAddWithFullyConnectedPass>());
+
phase.emplace_back(
std::make_unique<luci::ConvertNCHWToNHWCPass>(preserve_input, preserve_output));
@@ -190,7 +210,9 @@ void CircleOptimizer::optimize(loco::Graph *g) const
bool preserve_output =
_options->param(Options::AlgorithmParameters::NCHW_to_NHWC_output_shape) != "true";
- convert_nchw_to_nhwc(g, preserve_input, preserve_output);
+ bool fuse_fc = _options->query(Options::Algorithm::FuseAddWithFullyConnected);
+
+ convert_nchw_to_nhwc(g, preserve_input, preserve_output, fuse_fc);
}
/* TRANSFORM DECLARATION BEGIN */
@@ -220,6 +242,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::ResolveCustomOpMaxPoolWithArgmaxPass>());
}
+ if (_options->query(Options::Algorithm::ResolveCustomOpSplitV))
+ {
+ phase.emplace_back(std::make_unique<luci::ResolveCustomOpSplitVPass>());
+ }
if (_options->query(Options::Algorithm::FuseInstanceNorm))
{
phase.emplace_back(std::make_unique<FuseInstanceNormPass>());
@@ -260,6 +286,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::FoldCastPass>());
}
+ if (_options->query(Options::Algorithm::FoldDensify))
+ {
+ phase.emplace_back(std::make_unique<luci::FoldDensifyPass>());
+ }
if (_options->query(Options::Algorithm::FoldDepthwiseConv2D))
{
phase.emplace_back(std::make_unique<luci::FoldDepthwiseConv2DPass>());
@@ -307,6 +337,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const
if (_options->query(Options::Algorithm::RemoveUnnecessaryReshape))
{
phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapePass>());
+ phase.emplace_back(std::make_unique<luci::RemoveUnnecessaryReshapeNetPass>());
}
if (_options->query(Options::Algorithm::RemoveUnnecessarySlice))
{
@@ -332,6 +363,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
}
+ if (_options->query(Options::Algorithm::ReplaceNonConstFCWithBatchMatMul))
+ {
+ phase.emplace_back(std::make_unique<luci::ReplaceNonConstFCWithBatchMatMulPass>());
+ }
if (_options->query(Options::Algorithm::ReplaceMulAddWithDepthwiseConv))
{
phase.emplace_back(std::make_unique<luci::ReplaceMulAddWithDepthwiseConvPass>());
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
index ce38a90b9..9a6550b9f 100644
--- a/compiler/luci/pass/src/CircleQuantizer.cpp
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -22,6 +22,7 @@
#include "luci/Pass/RequantizePass.h"
#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/RemoveRedundantDequantizePass.h"
#include "luci/Pass/QuantizePreCheckerPass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
@@ -252,8 +253,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const
static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
- static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "float32"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "float32"};
auto input_model_dtype =
_options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
@@ -434,6 +435,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+ // Remove redundant Dequantize Ops generated during fake quantization
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantDequantizePass>());
// Fold Dequantize Ops generated during fake quantization
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());
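
RemoveRedundantDequantizePass is scheduled just before FoldDequantizePass so that no-op Dequantize nodes left behind by fake quantization disappear first. A hedged sketch of the assumed pattern (the exact condition is in RemoveRedundantDequantizePass.cpp, which is not part of this diff): a Dequantize applied to a tensor that is already fp32 is an identity and can be bypassed.

    BEFORE                    AFTER

    [CircleNode] (fp32)       [CircleNode] (fp32)
         |                         |
    [CircleDequantize]        [CircleNode]
         |
    [CircleNode]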
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
index ce4f54035..55a29d105 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.cpp
@@ -28,6 +28,69 @@
namespace
{
+// Return true if `from` can be broadcast to `to`
+// `to`'s shape must be [N, C, H, W]
+bool broadcastable(const luci::CircleConst *from, const luci::CircleNode *to)
+{
+ assert(to->rank() == 4); // FIX_CALLER_UNLESS
+
+ const auto from_rank = from->rank();
+ if (from_rank > 4)
+ return false;
+
+ // Scalar is always broadcastable
+ if (from_rank == 0)
+ return true;
+
+ for (uint32_t i = 1; i <= from_rank; i++)
+ {
+ auto to_index = 4 - i;
+ auto from_index = from_rank - i;
+
+ if (from->dim(from_index).value() != to->dim(to_index).value() and
+ from->dim(from_index).value() != 1)
+ return false;
+ }
+
+ return true;
+}
+
+// Expand node to rank 4
+// node should have rank less than or equal to 4
+void expand_to_rank_4(luci::CircleConst *node)
+{
+ auto original_rank = node->rank();
+
+ assert(original_rank <= 4); // FIX_CALLER_UNLESS
+
+ if (original_rank == 4)
+ return;
+
+ std::vector<uint32_t> original_shape;
+ for (uint32_t i = 0; i < original_rank; i++)
+ {
+ original_shape.emplace_back(node->dim(i).value());
+ }
+
+ node->rank(4);
+ for (uint32_t i = 0; i < (4 - original_rank); i++)
+ node->dim(i) = 1;
+
+ for (uint32_t i = 0; i < original_rank; i++)
+ node->dim(i + (4 - original_rank)) = original_shape.at(i);
+}
+
+bool is_output(const loco::Node *node)
+{
+ auto cnode = loco::must_cast<const luci::CircleNode *>(node);
+ auto opcode = cnode->opcode();
+ if (opcode == luci::CircleOpcode::CIRCLEOUTPUT ||
+ opcode == luci::CircleOpcode::CIRCLEOUTPUTEXCLUDE)
+ return true;
+
+ return false;
+}
+
bool is_same_shape(const luci::CircleNode *node, const std::vector<loco::Dimension> &shape)
{
if (not node)
@@ -484,7 +547,7 @@ bool is_NCHW_with_s_const(const T *node, luci::CircleNode *&pred_node,
//
// Find MUL with an NCHW pattern described below
// - Input (non-constant) shape : [N, C, H, W]
-// - Input (constant) shape : [1, C, 1, 1], [N, C, H, W] or a scalar (1)
+// - Input (constant) shape : broadcastable to [N, C, H, W]
// - Output shape : [N, C, H, W]
bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_node,
luci::CircleConst *&multiplier)
@@ -511,32 +574,12 @@ bool is_NCHW_with_const(const luci::CircleMul *node, luci::CircleNode *&pred_nod
if (pred_node->rank() != 4)
return false;
- const auto const_rank = multiplier->rank();
- // Support Rank 4 or scalar (rank 0 or 1)
- if (const_rank != 4 && const_rank != 0 && const_rank != 1)
+ if (not broadcastable(multiplier, node))
return false;
- const auto input_cdim = pred_node->dim(1);
- const auto output_cdim = node->dim(1);
-
- if (const_rank == 4)
- {
- bool supported_shape = false;
-
- // Check multiplier is (1, C, 1, 1)
- if (is_same_shape(multiplier, {1, node->dim(1), 1, 1}))
- supported_shape = true;
-
- // Check multiplier is (N, C, H, W)
- if (is_same_shape(multiplier, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
- supported_shape = true;
+ expand_to_rank_4(multiplier);
- return supported_shape;
- }
- if (input_cdim == output_cdim)
- return true;
- else
- return false;
+ return true;
}
// We assume ADD with const input is NCHW if,
@@ -569,32 +612,12 @@ bool is_NCHW_with_const(const luci::CircleAdd *node, luci::CircleNode *&pred_nod
if (pred_node->rank() != 4)
return false;
- const auto const_rank = beta->rank();
- // Support Rank 4 or scalar (rank 0 or 1)
- if (const_rank != 4 && const_rank != 0 && const_rank != 1)
+ if (not broadcastable(beta, node))
return false;
- const auto input_cdim = pred_node->dim(1);
- const auto output_cdim = node->dim(1);
-
- if (const_rank == 4)
- {
- bool supported_shape = false;
-
- // Check beta is (1, C, 1, 1)
- if (is_same_shape(beta, {1, node->dim(1), 1, 1}))
- supported_shape = true;
-
- // Check beta is (N, C, H, W)
- if (is_same_shape(beta, {node->dim(0), node->dim(1), node->dim(2), node->dim(3)}))
- supported_shape = true;
+ expand_to_rank_4(beta);
- return supported_shape;
- }
- if (input_cdim == output_cdim)
- return true;
- else
- return false;
+ return true;
}
// We assume SUB with const input is NCHW if,
@@ -675,6 +698,24 @@ template <class T> bool convert_unary_x(T *node)
return true;
}
+template <class T> bool convert_unary_logits(T *node)
+{
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->logits());
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->logits(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+}
+
class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
{
// Default
@@ -742,17 +783,14 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
if (is_NCHW_with_const(node, pred_node, beta))
{
+ assert(beta->rank() == 4); // FIX is_NCHW_with_const unless
+ auto nhwc_const = create_NHWC_from_NCHW(beta);
+ if (nhwc_const == nullptr)
+ return false;
+ node->y(nhwc_const);
+
auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
-
- if (beta->rank() == 4)
- {
- auto nhwc_const = create_NHWC_from_NCHW(beta);
- if (nhwc_const == nullptr)
- return false;
- node->y(nhwc_const);
- }
-
node->x(pre_trans);
}
else if (beta == nullptr)
@@ -816,6 +854,11 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
bool visit(luci::CircleLogistic *node) { return convert_unary_x<luci::CircleLogistic>(node); }
+ bool visit(luci::CircleLogSoftmax *node)
+ {
+ return convert_unary_logits<luci::CircleLogSoftmax>(node);
+ }
+
bool visit(luci::CircleMaximum *node)
{
luci::CircleNode *pred_node = nullptr;
@@ -954,15 +997,15 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
if (is_NCHW_with_const(node, pred_node, multiplier))
{
+ assert(multiplier->rank() == 4); // FIX is_NCHW_with_const unless
+ auto nhwc_const = create_NHWC_from_NCHW(multiplier);
+ if (nhwc_const == nullptr)
+ return false;
+ node->y(nhwc_const);
+
auto pre_trans = create_pre_transpose(node);
pre_trans->a(pred_node);
node->x(pre_trans);
-
- if (multiplier->rank() == 4)
- {
- auto nhwc_const = create_NHWC_from_NCHW(multiplier);
- node->y(nhwc_const);
- }
}
else if (multiplier == nullptr)
{
@@ -1049,12 +1092,127 @@ class ConvertNCHWToNHWC final : public luci::CircleNodeMutableVisitor<bool>
return true;
}
+ // TODO Reduce duplicate code with CircleMean
+ bool visit(luci::CircleReduceMax *node)
+ {
+ auto input = loco::must_cast<luci::CircleNode *>(node->input());
+ if (input->rank() != 4)
+ return false;
+
+ auto rindices = dynamic_cast<luci::CircleConst *>(node->reduction_indices());
+ if (not rindices)
+ return false;
+
+ auto nhwc_rindices = create_NHWC_rindices(rindices);
+ if (not nhwc_rindices)
+ return false;
+
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(input);
+ node->input(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ node->reduction_indices(nhwc_rindices);
+
+ if (node->keep_dims())
+ {
+ auto post_trans = create_post_transpose(node);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
+    // The code below handles the case where node->keep_dims() == false
+ // 1D output never needs a transpose
+ if (node->rank() <= 1)
+ return true;
+
+ std::vector<bool> reduced_dims_nhwc(4, false);
+ uint32_t num_reduced_indices = nhwc_rindices->size<loco::DataType::S32>();
+
+ for (uint32_t ri = 0; ri < num_reduced_indices; ++ri)
+ {
+ reduced_dims_nhwc[nhwc_rindices->at<loco::DataType::S32>(ri)] = true;
+ }
+
+ // if channel dimension has been reduced, we don't need a transpose
+ if (reduced_dims_nhwc[3])
+ return true;
+
+ // likewise, if both space dimensions are reduced, no transpose is needed
+ if (reduced_dims_nhwc[1] && reduced_dims_nhwc[2])
+ return true;
+
+ std::vector<int32_t> post_trans_ind;
+ // case 1: only N is reduced
+ if (num_reduced_indices == 1 && reduced_dims_nhwc[0])
+ post_trans_ind = {2, 0, 1};
+
+ // case 2: only H or W is reduced
+ if (num_reduced_indices == 1 && (reduced_dims_nhwc[1] || reduced_dims_nhwc[2]))
+ post_trans_ind = {0, 2, 1};
+
+ // case 3: N and either H or W are reduced
+ if (num_reduced_indices == 2)
+ post_trans_ind = {1, 0};
+
+ auto post_trans = create_Nd_transpose(node, post_trans_ind);
+ loco::replace(node).with(post_trans);
+
+ post_trans->a(node);
+
+ return true;
+ }
+
bool visit(luci::CircleRelu *node) { return convert_unary_features<luci::CircleRelu>(node); }
bool visit(luci::CircleRelu6 *node) { return convert_unary_features<luci::CircleRelu6>(node); }
bool visit(luci::CircleRsqrt *node) { return convert_unary_x<luci::CircleRsqrt>(node); }
+ bool visit(luci::CircleSoftmax *node) { return convert_unary_logits<luci::CircleSoftmax>(node); }
+
+ bool visit(luci::CircleSplitV *node)
+ {
+ // Change split dimension
+ auto axis = dynamic_cast<luci::CircleConst *>(node->split_dim());
+ if (not axis)
+ return false;
+
+ if (axis->dtype() != loco::DataType::S32)
+ return false;
+
+ if (axis->size<loco::DataType::S32>() != 1)
+ return false;
+
+ axis->at<loco::DataType::S32>(0) = nchw_axis_to_nhwc(axis->at<loco::DataType::S32>(0));
+
+ // Insert pre-transpose
+ const auto pred_node = loco::must_cast<luci::CircleNode *>(node->input());
+ auto pre_trans = create_pre_transpose(node);
+ pre_trans->a(pred_node);
+ node->input(pre_trans);
+
+ // Do shape inference for this node again.
+ node->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ // Insert post-transposes
+ for (auto succ : loco::succs(node))
+ {
+ auto svo = loco::must_cast<luci::CircleSplitVOut *>(succ);
+
+ auto post_trans = create_post_transpose(svo);
+ loco::replace(svo).with(post_trans);
+ post_trans->a(svo);
+ }
+
+ return true;
+ }
+
bool visit(luci::CircleSquaredDifference *node)
{
// TODO support CircleConst input
@@ -1195,6 +1353,8 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
// pre-Transpose --- [intermediate Ops] --- post-Transpose
// |
// +--[intermediate Ops] --- post-Transpose
+ //
+ // NOTE Intermediate Ops SHOULD NOT contain pre-Transpose/Reshape
for (auto node : loco::postorder_traversal(loco::output_nodes(g)))
{
if (has_data_format(node))
@@ -1202,25 +1362,51 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
if (is_pre_transpose(node) || is_pre_reshape(node))
{
+ std::set<loco::Node *> intermediate;
+
+ // Variable to check intermediate Ops contain pre-Transpose/Reshape
+ bool has_pre = false;
+
+ // Variable to check the pattern is closed with post-Transpose/Reshape
+ bool is_closed = true;
+
// For recursive call of lambda
- std::function<void(loco::Node *)> set_data_format_to_succs;
- set_data_format_to_succs = [&](loco::Node *n) {
+ std::function<void(loco::Node *)> collect_intermediate;
+ collect_intermediate = [&](loco::Node *n) {
for (auto succ : loco::succs(n))
{
// Exit condition
if (is_post_transpose(succ) || is_post_reshape(succ))
continue;
- if (not has_data_format(succ))
+ if (is_pre_transpose(succ) || is_pre_reshape(succ))
+ {
+ has_pre = true;
+ break;
+ }
+
+ if (is_output(succ))
{
- set_data_format(succ, DataFormat::NHWC);
+ is_closed = false;
+ break;
}
- set_data_format_to_succs(succ);
+ intermediate.emplace(succ);
+
+ collect_intermediate(succ);
}
};
- set_data_format_to_succs(node);
+ collect_intermediate(node);
+
+ if (has_pre or not is_closed)
+ continue;
+
+ for (auto inter : intermediate)
+ {
+ if (not has_data_format(inter))
+ set_data_format(inter, DataFormat::NHWC);
+ }
}
}
@@ -1248,6 +1434,7 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
case luci::CircleOpcode::ELU:
case luci::CircleOpcode::LEAKY_RELU:
case luci::CircleOpcode::LOGISTIC:
+ case luci::CircleOpcode::LOG_SOFTMAX:
case luci::CircleOpcode::MAXIMUM:
case luci::CircleOpcode::MEAN:
case luci::CircleOpcode::MINIMUM:
@@ -1255,9 +1442,12 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
case luci::CircleOpcode::NEG:
case luci::CircleOpcode::PAD:
case luci::CircleOpcode::PADV2:
+ case luci::CircleOpcode::REDUCE_MAX:
case luci::CircleOpcode::RELU:
case luci::CircleOpcode::RELU6:
case luci::CircleOpcode::RSQRT:
+ case luci::CircleOpcode::SOFTMAX:
+ case luci::CircleOpcode::SPLIT_V:
case luci::CircleOpcode::SQUARED_DIFFERENCE:
case luci::CircleOpcode::SUB:
if (!has_data_format(node))
@@ -1296,7 +1486,8 @@ bool ConvertNCHWToNHWCPass::run(loco::Graph *g)
if (circle_node->rank() != 4)
{
// TODO replace the check above with the input rank check, and remove the condition below
- if (not dynamic_cast<luci::CircleMean *>(node))
+ if (not dynamic_cast<luci::CircleMean *>(node) and
+ not dynamic_cast<luci::CircleReduceMax *>(node))
continue;
}
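
The broadcastable()/expand_to_rank_4() pair introduced above implements the usual right-aligned broadcasting rule: trailing dimensions must match or be 1, and missing leading dimensions are padded with 1. A standalone illustration of the same rule in plain C++ (not the luci API):

    #include <cstdint>
    #include <vector>

    // Right-aligned broadcast check against a rank-4 target shape,
    // mirroring the logic of broadcastable() above.
    bool broadcastable_to_rank4(const std::vector<uint32_t> &from,
                                const std::vector<uint32_t> &to) // to.size() == 4
    {
      if (from.size() > 4)
        return false;
      for (size_t i = 1; i <= from.size(); ++i)
      {
        uint32_t f = from[from.size() - i];
        uint32_t t = to[4 - i];
        if (f != t && f != 1)
          return false;
      }
      return true; // a scalar (empty shape) is always broadcastable
    }

    // {16, 1, 1} -> {1, 16, 4, 4}: true (expand_to_rank_4 pads it to {1, 16, 1, 1})
    // {16, 1}    -> {1, 16, 4, 4}: false (16 clashes with H == 4)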
diff --git a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
index dd81d1380..6bb3d3268 100644
--- a/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
+++ b/compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
@@ -16,6 +16,8 @@
#include <logo/Phase.h>
+#include <luci/test/TestIOGraph.h>
+
#include "luci/Pass/ConvertNCHWToNHWCPass.h"
#include "luci/Pass/CircleShapeInferencePass.h"
@@ -23,6 +25,8 @@
#include <gtest/gtest.h>
+using namespace luci::test;
+
namespace
{
@@ -202,6 +206,173 @@ public:
luci::CircleConst *post_shape = nullptr;
};
+/**
+ * Graph with pre-Reshape but no post-Transpose/Reshape.
+ *
+ * BEFORE
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Relu]
+ * |
+ * [Output]
+ *
+ * AFTER
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Pre-Transpose]
+ * |
+ * [Relu]
+ * |
+ * [Post-Transpose]
+ * |
+ * [Output]
+ */
+class NoPostReshapeGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ pre_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_shape = g.nodes()->create<luci::CircleConst>();
+
+ pre_shape->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ auto in = loco::must_cast<luci::CircleNode *>(input);
+ in->shape({1, channel_size, 4, 4});
+ pre_shape->shape({4});
+
+ pre_shape->size<loco::DataType::S32>(4);
+ pre_shape->at<loco::DataType::S32>(0) = 1;
+ pre_shape->at<loco::DataType::S32>(1) = 4;
+ pre_shape->at<loco::DataType::S32>(2) = 4;
+ pre_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ pre_reshape->tensor(input);
+ pre_reshape->shape(pre_shape);
+ relu->features(pre_reshape);
+
+ relu->name("Relu");
+ pre_reshape->name("pre-reshape");
+
+ return relu;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+ luci::CircleReshape *pre_reshape = nullptr;
+ luci::CircleConst *pre_shape = nullptr;
+};
+
+/**
+ * Graph with two pre-Reshapes
+ *
+ * BEFORE
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Relu]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Post-Reshape]
+ * |
+ * [Output]
+ *
+ * AFTER
+ * [Input]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Pre-Transpose]
+ * |
+ * [Relu]
+ * |
+ * [Post-Transpose]
+ * |
+ * [Pre-Reshape]
+ * |
+ * [Post-Reshape]
+ * |
+ * [Output]
+ */
+class ReluNotClosedGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ relu = g.nodes()->create<luci::CircleRelu>();
+ pre_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_reshape_2 = g.nodes()->create<luci::CircleReshape>();
+ post_reshape = g.nodes()->create<luci::CircleReshape>();
+ pre_shape = g.nodes()->create<luci::CircleConst>();
+ pre_shape_2 = g.nodes()->create<luci::CircleConst>();
+ post_shape = g.nodes()->create<luci::CircleConst>();
+
+ pre_shape->dtype(loco::DataType::S32);
+ pre_shape_2->dtype(loco::DataType::S32);
+ post_shape->dtype(loco::DataType::S32);
+
+ uint32_t channel_size = 16;
+ auto in = loco::must_cast<luci::CircleNode *>(input);
+ in->shape({1, channel_size, 4, 4});
+ pre_shape->shape({4});
+ pre_shape_2->shape({4});
+ post_shape->shape({4});
+
+ pre_shape->size<loco::DataType::S32>(4);
+ pre_shape->at<loco::DataType::S32>(0) = 1;
+ pre_shape->at<loco::DataType::S32>(1) = 4;
+ pre_shape->at<loco::DataType::S32>(2) = 4;
+ pre_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ pre_shape_2->size<loco::DataType::S32>(4);
+ pre_shape_2->at<loco::DataType::S32>(0) = 1;
+ pre_shape_2->at<loco::DataType::S32>(1) = 4;
+ pre_shape_2->at<loco::DataType::S32>(2) = channel_size;
+ pre_shape_2->at<loco::DataType::S32>(3) = 4;
+
+ post_shape->size<loco::DataType::S32>(4);
+ post_shape->at<loco::DataType::S32>(0) = 1;
+ post_shape->at<loco::DataType::S32>(1) = 4;
+ post_shape->at<loco::DataType::S32>(2) = 4;
+ post_shape->at<loco::DataType::S32>(3) = channel_size;
+
+ pre_reshape->tensor(input);
+ pre_reshape->shape(pre_shape);
+
+ relu->features(pre_reshape);
+
+ pre_reshape_2->tensor(relu);
+ pre_reshape_2->shape(pre_shape_2);
+
+ post_reshape->tensor(pre_reshape_2);
+ post_reshape->shape(post_shape);
+
+ relu->name("Relu");
+ pre_reshape->name("pre-reshape");
+    pre_reshape_2->name("pre-reshape-2");
+ post_reshape->name("post-reshape");
+
+ return post_reshape;
+ }
+
+public:
+ luci::CircleRelu *relu = nullptr;
+ luci::CircleReshape *pre_reshape = nullptr;
+ luci::CircleReshape *pre_reshape_2 = nullptr;
+ luci::CircleReshape *post_reshape = nullptr;
+ luci::CircleConst *pre_shape = nullptr;
+ luci::CircleConst *pre_shape_2 = nullptr;
+ luci::CircleConst *post_shape = nullptr;
+};
+
class AddScalarGraph final : public SimpleGraph
{
protected:
@@ -312,6 +483,22 @@ public:
luci::CircleLogistic *logistic = nullptr;
};
+class LogSoftmaxGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ log_softmax = g.nodes()->create<luci::CircleLogSoftmax>();
+ log_softmax->logits(input);
+ log_softmax->name("log_softmax");
+
+ return log_softmax;
+ }
+
+public:
+ luci::CircleLogSoftmax *log_softmax = nullptr;
+};
+
class MaximumGraph final : public SimpleGraph
{
protected:
@@ -642,6 +829,51 @@ public:
luci::CircleConst *const_value = nullptr;
};
+class ReduceMaxGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ rm = g.nodes()->create<luci::CircleReduceMax>();
+ rindices = g.nodes()->create<luci::CircleConst>();
+
+ rm->dtype(loco::DataType::FLOAT32);
+ rindices->dtype(loco::DataType::S32);
+
+ rm->shape(_shape);
+ rindices->shape({static_cast<uint32_t>(_axes.size())});
+
+ rindices->size<loco::DataType::S32>(_axes.size());
+ for (uint32_t i = 0; i < _axes.size(); ++i)
+ {
+ rindices->at<loco::DataType::S32>(i) = _axes[i];
+ }
+
+ rm->input(input);
+ rm->reduction_indices(rindices);
+ rm->keep_dims(_keep_dims);
+
+ rm->name("reduce_max");
+ rindices->name("rindices");
+
+ return rm;
+ }
+
+public:
+ void keep_dims(bool val) { _keep_dims = val; }
+ void axes(std::vector<int32_t> val) { _axes = val; }
+ void shape(std::initializer_list<uint32_t> val) { _shape = val; }
+
+public:
+ luci::CircleReduceMax *rm = nullptr;
+ luci::CircleConst *rindices = nullptr;
+
+private:
+ bool _keep_dims = true;
+ std::vector<int32_t> _axes = {2, 3};
+ std::initializer_list<uint32_t> _shape = {1, 16, 1, 1};
+};
+
class ReluGraph final : public SimpleGraph
{
protected:
@@ -690,6 +922,111 @@ public:
luci::CircleRsqrt *rsqrt = nullptr;
};
+class SoftmaxGraph final : public SimpleGraph
+{
+protected:
+ loco::Node *insertGraphBody(loco::Node *input) override
+ {
+ softmax = g.nodes()->create<luci::CircleSoftmax>();
+ softmax->logits(input);
+ softmax->name("softmax");
+
+ return softmax;
+ }
+
+public:
+ luci::CircleSoftmax *softmax = nullptr;
+};
+
+class SplitVGraphlet
+{
+public:
+ SplitVGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+    // CircleSplitV
+ _splitv = g->nodes()->create<luci::CircleSplitV>();
+ _splitv->shape({1, 2, 2, 192});
+ _splitv->dtype(loco::DataType::FLOAT32);
+ _splitv->name("splitv");
+
+ // CircleConst
+ auto size_splits = g->nodes()->create<luci::CircleConst>();
+ size_splits->dtype(loco::DataType::S32);
+ size_splits->shape({3});
+ size_splits->size<loco::DataType::S32>(3);
+ size_splits->at<loco::DataType::S32>(0) = 32;
+ size_splits->at<loco::DataType::S32>(1) = 32;
+ size_splits->at<loco::DataType::S32>(2) = 128;
+
+ // CircleConst
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+ split_dim->dtype(loco::DataType::S32);
+ split_dim->rank(0);
+ split_dim->size<loco::DataType::S32>(1);
+ split_dim->scalar<loco::DataType::S32>() = 3;
+
+ _splitv->size_splits(size_splits);
+ _splitv->split_dim(split_dim);
+ _splitv->num_split(3);
+
+ // CircleSplitVOut
+ _splitv_out1 = g->nodes()->create<luci::CircleSplitVOut>();
+ _splitv_out1->shape({1, 2, 2, 32});
+ _splitv_out1->dtype(loco::DataType::FLOAT32);
+ _splitv_out1->index(0);
+ _splitv_out1->input(_splitv);
+ _splitv_out1->name("splitv_out1");
+
+ // CircleSplitVOut
+ _splitv_out2 = g->nodes()->create<luci::CircleSplitVOut>();
+ _splitv_out2->shape({1, 2, 2, 32});
+ _splitv_out2->dtype(loco::DataType::FLOAT32);
+ _splitv_out2->index(1);
+ _splitv_out2->input(_splitv);
+ _splitv_out2->name("splitv_out2");
+
+ // CircleSplitVOut
+ _splitv_out3 = g->nodes()->create<luci::CircleSplitVOut>();
+ _splitv_out3->shape({1, 2, 2, 128});
+ _splitv_out3->dtype(loco::DataType::FLOAT32);
+ _splitv_out3->index(2);
+ _splitv_out3->input(_splitv);
+ _splitv_out3->name("splitv_out3");
+ }
+
+public:
+ luci::CircleSplitV *splitv() { return _splitv; }
+
+protected:
+ luci::CircleSplitV *_splitv = nullptr;
+ luci::CircleSplitVOut *_splitv_out1 = nullptr;
+ luci::CircleSplitVOut *_splitv_out2 = nullptr;
+ luci::CircleSplitVOut *_splitv_out3 = nullptr;
+};
+
+class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
+{
+public:
+ SplitVGraph() = default;
+
+ void init(void)
+ {
+ TestIGraphlet::init(g(), {1, 2, 2, 192});
+ TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
+ SplitVGraphlet::init(g());
+
+ // connect graph
+ _splitv->input(input());
+
+ output(0)->from(_splitv_out1);
+ output(1)->from(_splitv_out2);
+ output(2)->from(_splitv_out3);
+ }
+};
+
class SquaredDifferenceGraph final : public SimpleGraph
{
protected:
@@ -929,8 +1266,11 @@ TEST(ConvertNCHWToNHWC, AddScalar)
auto new_beta = dynamic_cast<luci::CircleConst *>(g.add->y());
EXPECT_NE(nullptr, new_beta);
- EXPECT_EQ(1, new_beta->rank());
+ EXPECT_EQ(4, new_beta->rank());
EXPECT_EQ(1, new_beta->dim(0).value());
+ EXPECT_EQ(1, new_beta->dim(1).value());
+ EXPECT_EQ(1, new_beta->dim(2).value());
+ EXPECT_EQ(1, new_beta->dim(3).value());
check_pre_trans(g.output->from());
}
@@ -1017,6 +1357,26 @@ TEST(ConvertNCHWToNHWC, Logistic)
EXPECT_EQ(16, g.logistic->dim(3).value());
}
+TEST(ConvertNCHWToNHWC, LogSoftmax)
+{
+ LogSoftmaxGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.log_softmax->logits());
+
+ auto log_softmax_succs = loco::succs(g.log_softmax);
+ EXPECT_EQ(1, log_softmax_succs.size());
+ check_post_trans(*log_softmax_succs.begin());
+
+ // Check log_softmax shape
+ EXPECT_EQ(1, g.log_softmax->dim(0).value());
+ EXPECT_EQ(4, g.log_softmax->dim(1).value());
+ EXPECT_EQ(4, g.log_softmax->dim(2).value());
+ EXPECT_EQ(16, g.log_softmax->dim(3).value());
+}
+
TEST(ConvertNCHWToNHWC, Maximum)
{
MaximumGraph g;
@@ -1265,8 +1625,11 @@ TEST(ConvertNCHWToNHWC, MulScalar)
auto new_multiplier = dynamic_cast<luci::CircleConst *>(g.mul->y());
EXPECT_NE(nullptr, new_multiplier);
- EXPECT_EQ(1, new_multiplier->rank());
+ EXPECT_EQ(4, new_multiplier->rank());
EXPECT_EQ(1, new_multiplier->dim(0).value());
+ EXPECT_EQ(1, new_multiplier->dim(1).value());
+ EXPECT_EQ(1, new_multiplier->dim(2).value());
+ EXPECT_EQ(1, new_multiplier->dim(3).value());
check_pre_trans(g.output->from());
}
@@ -1451,6 +1814,85 @@ TEST(ConvertNCHWToNHWC, Preserve_Input_Output)
}
}
+TEST(ConvertNCHWToNHWC, ReduceMax)
+{
+ ReduceMaxGraph g;
+ g.init();
+
+ run_phase(&g.g, false, false);
+
+ check_pre_trans(g.rm->input());
+
+ auto rm_succs = loco::succs(g.rm);
+ EXPECT_EQ(1, rm_succs.size());
+ check_post_trans(*rm_succs.begin());
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(2, new_rindices->dim(0).value());
+ EXPECT_EQ(2, new_rindices->size<loco::DataType::S32>());
+ EXPECT_EQ(1, new_rindices->at<loco::DataType::S32>(0));
+ EXPECT_EQ(2, new_rindices->at<loco::DataType::S32>(1));
+}
+
+TEST(ConvertNCHWToNHWC, ReduceMax_keep_dims_false)
+{
+ struct TC
+ {
+ std::vector<int32_t> nchw_ind;
+ std::vector<int32_t> nhwc_ind;
+ std::initializer_list<uint32_t> shape;
+ bool needs_transpose = false;
+ };
+
+ uint32_t n = 1;
+ uint32_t c = 16;
+ uint32_t h = 4;
+ uint32_t w = 4;
+
+ std::vector<TC> test_cases{{{0}, {0}, {c, h, w}, true}, {{1}, {3}, {n, h, w}, false},
+ {{2}, {1}, {n, c, w}, true}, {{3}, {2}, {n, c, h}, true},
+ {{0, 1}, {0, 3}, {h, w}, false}, {{0, 2}, {0, 1}, {c, w}, true},
+ {{0, 3}, {0, 2}, {c, h}, true}, {{1, 2}, {3, 1}, {n, w}, false},
+ {{1, 3}, {3, 2}, {n, h}, false}, {{2, 3}, {1, 2}, {n, c}, false},
+ {{0, 1, 2}, {0, 3, 1}, {w}, false}};
+
+ for (auto &tc : test_cases)
+ {
+ ReduceMaxGraph g;
+ g.keep_dims(false);
+ g.axes(tc.nchw_ind);
+ g.shape(tc.shape);
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.rm->input());
+
+ auto rm_succs = loco::succs(g.rm);
+ EXPECT_EQ(1, rm_succs.size());
+ if (tc.needs_transpose)
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleTranspose *>(*rm_succs.begin()));
+ }
+ else
+ {
+ EXPECT_NE(nullptr, dynamic_cast<luci::CircleOutput *>(*rm_succs.begin()));
+ }
+
+ auto new_rindices = dynamic_cast<luci::CircleConst *>(g.rm->reduction_indices());
+ EXPECT_NE(nullptr, new_rindices);
+ EXPECT_EQ(1, new_rindices->rank());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->dim(0).value());
+ EXPECT_EQ(tc.nhwc_ind.size(), new_rindices->size<loco::DataType::S32>());
+ for (uint32_t i = 0; i < tc.nhwc_ind.size(); ++i)
+ {
+ EXPECT_EQ(tc.nhwc_ind[i], new_rindices->at<loco::DataType::S32>(i));
+ }
+ }
+}
+
TEST(ConvertNCHWToNHWC, Relu)
{
ReluGraph g;
@@ -1511,6 +1953,57 @@ TEST(ConvertNCHWToNHWC, Rsqrt)
EXPECT_EQ(16, g.rsqrt->dim(3).value());
}
+TEST(ConvertNCHWToNHWC, Softmax)
+{
+ SoftmaxGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.softmax->logits());
+
+ auto softmax_succs = loco::succs(g.softmax);
+ EXPECT_EQ(1, softmax_succs.size());
+ check_post_trans(*softmax_succs.begin());
+
+ // Check softmax shape
+ EXPECT_EQ(1, g.softmax->dim(0).value());
+ EXPECT_EQ(4, g.softmax->dim(1).value());
+ EXPECT_EQ(4, g.softmax->dim(2).value());
+ EXPECT_EQ(16, g.softmax->dim(3).value());
+}
+
+TEST(ConvertNCHWToNHWC, SplitV)
+{
+ SplitVGraph g;
+ g.init();
+
+ run_phase(g.g(), true, true);
+
+ check_pre_trans(g.splitv()->input());
+
+  auto splitv_succs = loco::succs(g.splitv());
+  for (auto svo : splitv_succs)
+ {
+ for (auto succ : loco::succs(svo))
+ {
+ check_post_trans(succ);
+ }
+ }
+
+ // Check splitv() shape
+ EXPECT_EQ(1, g.splitv()->dim(0).value());
+ EXPECT_EQ(2, g.splitv()->dim(1).value());
+ EXPECT_EQ(192, g.splitv()->dim(2).value());
+ EXPECT_EQ(2, g.splitv()->dim(3).value());
+
+ // Check axis
+ auto axis = dynamic_cast<luci::CircleConst *>(g.splitv()->split_dim());
+ EXPECT_NE(nullptr, axis);
+ EXPECT_EQ(1, axis->size<loco::DataType::S32>());
+ EXPECT_EQ(2, axis->at<loco::DataType::S32>(0));
+}
+
TEST(ConvertNCHWToNHWC, SquaredDifference)
{
SquaredDifferenceGraph g;
@@ -1602,3 +2095,31 @@ TEST(ConvertNCHWToNHWC, SubScalar)
check_pre_trans(g.output->from());
}
+
+TEST(ConvertNCHWToNHWC, Not_Closed_Case1_NEG)
+{
+ NoPostReshapeGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ check_post_trans(*relu_succs.begin());
+}
+
+TEST(ConvertNCHWToNHWC, Not_Closed_Case2_NEG)
+{
+ ReluNotClosedGraph g;
+ g.init();
+
+ run_phase(&g.g, true, true);
+
+ check_pre_trans(g.relu->features());
+
+ auto relu_succs = loco::succs(g.relu);
+ EXPECT_EQ(1, relu_succs.size());
+ check_post_trans(*relu_succs.begin());
+}
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
index 11970fff5..72f590135 100644
--- a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -184,8 +184,63 @@ struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
// For non-const activation, insert Quantize-Dequantize Ops
// and dequantize the node
- void visit(luci::CircleConv2D *node) { fq_activation(node); }
void visit(luci::CircleAdd *node) { fq_activation(node); }
+ void visit(luci::CircleAveragePool2D *node) { fq_activation(node); }
+ void visit(luci::CircleBatchMatMul *node) { fq_activation(node); }
+ void visit(luci::CircleConv2D *node) { fq_activation(node); }
+ void visit(luci::CircleDepthwiseConv2D *node) { fq_activation(node); }
+ void visit(luci::CircleDiv *node) { fq_activation(node); }
+ void visit(luci::CircleFullyConnected *node) { fq_activation(node); }
+ void visit(luci::CircleInstanceNorm *node) { fq_activation(node); }
+ void visit(luci::CircleLeakyRelu *node) { fq_activation(node); }
+ void visit(luci::CircleLogistic *node) { fq_activation(node); }
+ void visit(luci::CircleLogSoftmax *node) { fq_activation(node); }
+ void visit(luci::CircleMaxPool2D *node) { fq_activation(node); }
+ void visit(luci::CircleMul *node) { fq_activation(node); }
+ void visit(luci::CircleNeg *node) { fq_activation(node); }
+ void visit(luci::CirclePad *node) { fq_activation(node); }
+ void visit(luci::CirclePRelu *node) { fq_activation(node); }
+ void visit(luci::CircleMean *node) { fq_activation(node); }
+ void visit(luci::CircleReduceMax *node) { fq_activation(node); }
+ void visit(luci::CircleRelu *node) { fq_activation(node); }
+ void visit(luci::CircleRelu6 *node) { fq_activation(node); }
+ void visit(luci::CircleResizeBilinear *node) { fq_activation(node); }
+ void visit(luci::CircleResizeNearestNeighbor *node) { fq_activation(node); }
+ void visit(luci::CircleRsqrt *node) { fq_activation(node); }
+ void visit(luci::CircleSoftmax *node) { fq_activation(node); }
+ void visit(luci::CircleSqrt *node) { fq_activation(node); }
+ void visit(luci::CircleTanh *node) { fq_activation(node); }
+ void visit(luci::CircleTransposeConv *node) { fq_activation(node); }
+
+ // For Ops that do not change the value of input, do nothing
+ // (dtype will be automatically updated by type inference)
+ void visit(luci::CircleCast *) {}
+ void visit(luci::CircleConcatenation *) {}
+ void visit(luci::CircleGather *) {}
+ void visit(luci::CircleSlice *) {}
+ void visit(luci::CircleStridedSlice *) {}
+ void visit(luci::CircleReshape *) {}
+ void visit(luci::CircleSplit *) {}
+ void visit(luci::CircleSplitOut *) {}
+ void visit(luci::CircleSplitV *) {}
+ void visit(luci::CircleSplitVOut *) {}
+ void visit(luci::CircleTranspose *) {}
+
+  // For Ops that return indices, fake quantization is unnecessary
+ void visit(luci::CircleArgMax *) {}
+
+ // Virtual node
+ void visit(luci::CircleOutputExclude *) {}
+
+ void visit(luci::CircleQuantize *node)
+ {
+ RETURN_UNLESS(is_quant_act(node));
+
+ insert_dequantize(node);
+ }
+
+ // Dequantize Op does nothing in fp32 model
+ void visit(luci::CircleDequantize *) {}
};
#undef RETURN_UNLESS
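
fq_activation() itself is defined earlier in this file and is not part of the hunk; conceptually it wraps the activation output in a Quantize-Dequantize pair and keeps the rest of the graph consuming fp32 values. A hypothetical sketch of that rewiring, using the same luci IR idioms seen elsewhere in this diff (qparam propagation omitted):

    void fq_activation_sketch(luci::CircleNode *node)
    {
      auto g = node->graph();

      auto quantize = g->nodes()->create<luci::CircleQuantize>();
      quantize->name(node->name() + "_Quantize");

      auto dequantize = g->nodes()->create<luci::CircleDequantize>();
      dequantize->name(node->name() + "_Dequantize");

      // Rewire: node -> Quantize -> Dequantize -> (original successors)
      loco::replace(node).with(dequantize);
      quantize->input(node);
      dequantize->input(quantize);
    }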
diff --git a/compiler/luci/pass/src/FoldDensifyPass.cpp b/compiler/luci/pass/src/FoldDensifyPass.cpp
new file mode 100644
index 000000000..5ddc743e5
--- /dev/null
+++ b/compiler/luci/pass/src/FoldDensifyPass.cpp
@@ -0,0 +1,180 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldDensifyPass.h"
+#include "helpers/SparsityFormatConverter.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+
+#include <cassert>
+#include <vector>
+
+namespace
+{
+
+bool is_foldable_const(luci::CircleConst *node)
+{
+ if (node->sparsityparam() == nullptr)
+ return false;
+
+ if (node->dtype() == loco::DataType::FLOAT32)
+ return true;
+ if (node->dtype() == loco::DataType::FLOAT16)
+ return true;
+
+ return false;
+}
+
+luci::CircleConst *densified_const_node(luci::CircleConst *const_node)
+{
+ assert(const_node->sparsityparam());
+
+ auto name = const_node->name();
+ assert(name.length() > 0);
+ auto g = const_node->graph();
+ auto new_const_node = g->nodes()->create<luci::CircleConst>();
+
+ new_const_node->dtype(const_node->dtype());
+ new_const_node->rank(const_node->rank());
+
+ uint32_t dim_size = 1;
+ std::vector<int> dense_shape;
+ for (uint32_t i = 0; i < new_const_node->rank(); ++i)
+ {
+ assert(const_node->dim(i).known());
+ new_const_node->dim(i) = const_node->dim(i);
+
+ uint32_t value = const_node->dim(i).value();
+ dim_size *= value;
+ dense_shape.emplace_back(static_cast<int32_t>(value));
+ }
+
+ if (const_node->dtype() == loco::DataType::FLOAT32)
+ new_const_node->size<loco::DataType::FLOAT32>(dim_size);
+ else
+ {
+ assert(const_node->dtype() == loco::DataType::FLOAT16);
+ new_const_node->size<loco::DataType::FLOAT16>(dim_size);
+ }
+
+ new_const_node->shape_status(luci::ShapeStatus::VALID);
+ new_const_node->name(name + "_DS");
+
+ if (const_node->dtype() == loco::DataType::FLOAT32)
+ {
+ auto const_items = const_node->size<loco::DataType::FLOAT32>();
+ auto f_data = std::make_unique<float[]>(const_items);
+ for (size_t i = 0; i < const_items; ++i)
+ f_data[i] = const_node->at<loco::DataType::FLOAT32>(i);
+
+ sparsity::TfLiteSparsity sp = to_tflite_sparsity(const_node->sparsityparam());
+ sparsity::FormatConverter<float> converter(dense_shape, sp);
+ converter.SparseToDense(f_data.get());
+ const auto &data_dense = converter.GetData();
+ assert(data_dense.size() == dim_size);
+
+ for (uint32_t i = 0; i < dim_size; ++i)
+ new_const_node->at<loco::DataType::FLOAT32>(i) = data_dense[i];
+
+ luci::freeTfLiteSparsity(sp);
+ }
+ else
+ {
+ assert(const_node->dtype() == loco::DataType::FLOAT16);
+
+ auto const_items = const_node->size<loco::DataType::FLOAT16>();
+ auto f_data = std::make_unique<uint16_t[]>(const_items);
+ for (size_t i = 0; i < const_items; ++i)
+ f_data[i] = const_node->at<loco::DataType::FLOAT16>(i);
+
+ // Primitive type for FLOAT16 is UINT16
+ sparsity::TfLiteSparsity sp = to_tflite_sparsity(const_node->sparsityparam());
+ sparsity::FormatConverter<uint16_t> converter(dense_shape, sp);
+ converter.SparseToDense(f_data.get());
+ const auto &data_dense = converter.GetData();
+ assert(data_dense.size() == dim_size);
+ for (uint32_t i = 0; i < dim_size; ++i)
+ new_const_node->at<loco::DataType::FLOAT16>(i) = data_dense[i];
+
+ luci::freeTfLiteSparsity(sp);
+ }
+
+ return new_const_node;
+}
+
+/**
+ * @brief Fold Densify when its input is a sparse constant
+ */
+bool fold_densify(luci::CircleDensify *densify)
+{
+ auto const_input = dynamic_cast<luci::CircleConst *>(densify->input());
+ if (not const_input)
+ return false;
+
+ if (not is_foldable_const(const_input))
+ return false;
+
+ auto dense_const = densified_const_node(const_input);
+ assert(dense_const);
+
+ loco::replace(densify).with(dense_const);
+ luci::add_origin(dense_const, luci::composite_origin(
+ {luci::get_origin(densify), luci::get_origin(const_input)}));
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ *
+ * [CircleConst](sparse)
+ * |
+ * [CircleDensify]
+ * |
+ * [CircleNode]
+ * |
+ *
+ * AFTER
+ *
+ * [CircleConst](dense) [CircleConst](sparse)
+ * | |
+ * [CircleNode] [CircleDensify]
+ * |
+ */
+bool FoldDensifyPass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto densify = dynamic_cast<luci::CircleDensify *>(node))
+ {
+ if (fold_densify(densify))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
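
For intuition about what SparseToDense undoes: the DENSE + SPARSE_CSR dim_metadata used in the test below (array_segments {0, 1, 2, 3, 4}, array_indices {0, 1, 2, 3}) is classic CSR. A minimal standalone decoder, assuming this two-dimensional DENSE-rows/CSR-columns layout only:

    #include <cstdint>
    #include <vector>

    // Decode a 2-D tensor stored as dense rows + CSR columns.
    std::vector<float> csr_to_dense(uint32_t rows, uint32_t cols,
                                    const std::vector<int32_t> &segments, // rows + 1 entries
                                    const std::vector<int32_t> &indices,  // column of each value
                                    const std::vector<float> &values)
    {
      std::vector<float> dense(rows * cols, 0.0f);
      for (uint32_t r = 0; r < rows; ++r)
      {
        // segments[r] .. segments[r + 1] delimit the values belonging to row r
        for (int32_t s = segments[r]; s < segments[r + 1]; ++s)
          dense[r * cols + indices[s]] = values[s];
      }
      return dense;
    }

    // csr_to_dense(4, 4, {0, 1, 2, 3, 4}, {0, 1, 2, 3}, {1, 2, 3, 4}) yields the
    // 4x4 diagonal matrix that FoldDensifyPass.test.cpp constructs below.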
diff --git a/compiler/luci/pass/src/FoldDensifyPass.test.cpp b/compiler/luci/pass/src/FoldDensifyPass.test.cpp
new file mode 100644
index 000000000..2f9736f49
--- /dev/null
+++ b/compiler/luci/pass/src/FoldDensifyPass.test.cpp
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FoldDensifyPass.h"
+#include "PassTestGraphs.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class FoldDensifyPassGraph : public luci::ConstantFoldingAddTestGraph
+{
+public:
+ FoldDensifyPassGraph(std::initializer_list<uint32_t> shape)
+ : luci::ConstantFoldingAddTestGraph(shape, loco::DataType::FLOAT32)
+ {
+ _densify = _g.nodes()->create<luci::CircleDensify>();
+ _x = _g.nodes()->create<luci::CircleConst>();
+
+ _densify->dtype(loco::DataType::FLOAT32);
+ _x->dtype(loco::DataType::FLOAT32);
+
+ _densify->shape(shape);
+ _x->shape(shape);
+
+ _densify->input(_x);
+
+ _densify->name("densify");
+ _x->name("x");
+ }
+
+ loco::Node *createFoldedPattern() override { return _densify; }
+
+public:
+ void fill_const_dense(void)
+ {
+ uint32_t num_elems = 1;
+ for (uint32_t r = 0; r < _x->rank(); ++r)
+ num_elems *= _x->dim(r).value();
+
+ _x->size<loco::DataType::FLOAT32>(num_elems);
+ for (uint32_t i = 0; i < num_elems; i++)
+ _x->at<loco::DataType::FLOAT32>(i) = static_cast<float>(i + 1);
+ }
+
+ void fill_const_sparse(void)
+ {
+ // fill 4x4 of
+ // [[1 0 0 0]
+ // [0 2 0 0]
+ // [0 0 3 0]
+ // [0 0 0 4]]
+
+ // values of 1.0, 2.0, 3.0, 4.0
+ uint32_t udata[] = {0x3f800000, 0x40000000, 0x40400000, 0x40800000};
+ float *fdata = reinterpret_cast<float *>(udata);
+
+ _x->size<loco::DataType::FLOAT32>(4);
+ for (uint32_t i = 0; i < 4; i++)
+ _x->at<loco::DataType::FLOAT32>(i) = fdata[i];
+
+ auto sparsityparam = std::make_unique<luci::SparsityParam>();
+ sparsityparam->traversal_order = std::vector<int32_t>({0, 1});
+ sparsityparam->block_map = std::vector<int32_t>({});
+
+ auto dm0 = luci::DimMetaData(luci::DimensionType::DENSE, 4);
+
+ std::vector<int32_t> as_vec = {0, 1, 2, 3, 4};
+ std::vector<int32_t> ai_vec = {0, 1, 2, 3};
+ auto as = luci::SparseIndexVector(luci::SparseIndexVectorType::I32, as_vec);
+ auto ai = luci::SparseIndexVector(luci::SparseIndexVectorType::I32, ai_vec);
+ auto dm1 = luci::DimMetaData(luci::DimensionType::SPARSE_CSR, 0, as, ai);
+ sparsityparam->dim_metadata.emplace_back(dm0);
+ sparsityparam->dim_metadata.emplace_back(dm1);
+
+ _x->sparsityparam(std::move(sparsityparam));
+ }
+
+protected:
+ luci::CircleDensify *_densify = nullptr;
+ luci::CircleConst *_x = nullptr;
+};
+
+class FoldDensifyPassGraphTest : public FoldDensifyPassGraph, public ::testing::Test
+{
+public:
+ FoldDensifyPassGraphTest() : FoldDensifyPassGraph({4, 4}) {}
+
+ virtual void SetUp() { init(); }
+};
+
+} // namespace
+
+TEST(FoldDensifyPassGraph, name)
+{
+ luci::FoldDensifyPass pass;
+ auto const name = pass.name();
+ ASSERT_NE(nullptr, name);
+}
+
+TEST_F(FoldDensifyPassGraphTest, no_sparsity_param_NEG)
+{
+ fill_const_dense();
+
+ luci::FoldDensifyPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(FoldDensifyPassGraphTest, sparsity_param)
+{
+ fill_const_sparse();
+
+ luci::FoldDensifyPass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ EXPECT_EQ(2, folded_const->rank());
+ EXPECT_EQ(4, folded_const->dim(0).value());
+ EXPECT_EQ(4, folded_const->dim(1).value());
+ EXPECT_EQ(16, folded_const->size<loco::DataType::FLOAT32>());
+ for (int y = 0; y < 4; ++y)
+ {
+ for (int x = 0; x < 4; ++x)
+ {
+ float ovalue = folded_const->at<loco::DataType::FLOAT32>(y * 4 + x);
+ float fvalue = 0.0;
+ if (x == y)
+ {
+ // diagonal position
+ fvalue = static_cast<float>(y + 1);
+ }
+ EXPECT_EQ(fvalue, ovalue);
+ }
+ }
+}
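
The CSR metadata built in fill_const_sparse() maps onto that walk directly: the array-segments vector {0, 1, 2, 3, 4} is the row pointer, and the array-indices vector {0, 1, 2, 3} holds the column of each stored value. A hypothetical check, reusing the csr_to_dense sketch above:

    // One stored value per row, each on the diagonal:
    auto dense = csr_to_dense(4, 4, {0, 1, 2, 3, 4}, {0, 1, 2, 3},
                              {1.0f, 2.0f, 3.0f, 4.0f});
    // dense == {1,0,0,0, 0,2,0,0, 0,0,3,0, 0,0,0,4}, which is exactly
    // what the sparsity_param test asserts element by element.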
diff --git a/compiler/luci/pass/src/FoldDequantizePass.cpp b/compiler/luci/pass/src/FoldDequantizePass.cpp
index 3dd4f8cea..b6526deb0 100644
--- a/compiler/luci/pass/src/FoldDequantizePass.cpp
+++ b/compiler/luci/pass/src/FoldDequantizePass.cpp
@@ -19,6 +19,8 @@
#include <luci/IR/CircleNodes.h>
#include <luci/Profile/CircleNodeOrigin.h>
+#include <fp16.h>
+
namespace
{
@@ -32,6 +34,9 @@ bool is_hybrid_kernel_supported(loco::Node *node)
bool is_foldable_const(luci::CircleConst *node)
{
+ if (node->dtype() == loco::DataType::FLOAT16)
+ return true;
+
if (node->quantparam() == nullptr)
return false;
@@ -39,17 +44,18 @@ bool is_foldable_const(luci::CircleConst *node)
return true;
if (node->dtype() == loco::DataType::U8)
return true;
+ if (node->dtype() == loco::DataType::S16)
+ return true;
+ if (node->dtype() == loco::DataType::S32)
+ return true;
+ if (node->dtype() == loco::DataType::S64)
+ return true;
return false;
}
luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
{
- if (const_node->quantparam() == nullptr)
- {
- throw std::runtime_error("Given constant node has no quantization parameter");
- }
-
auto name = const_node->name();
assert(name.length() > 0);
auto g = const_node->graph();
@@ -67,38 +73,70 @@ luci::CircleConst *dequantized_const_node(luci::CircleConst *const_node)
new_const_node->shape_status(luci::ShapeStatus::VALID);
new_const_node->name(name + "_DQ");
+ if (const_node->dtype() == loco::DataType::FLOAT16)
+ {
+ for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i)
+ {
+ auto raw = const_node->at<loco::DataType::FLOAT16>(i);
+ new_const_node->at<loco::DataType::FLOAT32>(i) = fp16_ieee_to_fp32_value(raw);
+ }
+ return new_const_node;
+ }
+
+ if (const_node->quantparam() == nullptr)
+ {
+ throw std::runtime_error("Given constant node has no quantization parameter");
+ }
+
const int32_t q_dim = const_node->quantparam()->quantized_dimension;
- const int32_t q_dim_value = const_node->dim(q_dim).value();
+ // For scalar, q_dim_value is 1
+  // For non-scalar, q_dim_value is the size of the quantized dimension
+ const int32_t q_dim_value = const_node->rank() == 0 ? 1 : const_node->dim(q_dim).value();
int32_t right_count = q_dim_value;
for (uint32_t i = q_dim + 1; i < const_node->rank(); ++i)
right_count *= const_node->dim(i).value();
- if (const_node->dtype() == loco::DataType::S8)
+ for (uint32_t i = 0; i < new_const_node->size<loco::DataType::FLOAT32>(); ++i)
{
- for (uint32_t i = 0; i < const_node->size<loco::DataType::S8>(); ++i)
- {
- uint32_t qd = (i % right_count) / (right_count / q_dim_value);
- if (qd >= const_node->quantparam()->zerop.size())
- qd = 0;
+ uint32_t qd = (i % right_count) / (right_count / q_dim_value);
+ if (qd >= const_node->quantparam()->zerop.size())
+ qd = 0;
- new_const_node->at<loco::DataType::FLOAT32>(i) =
- (float)(const_node->at<loco::DataType::S8>(i) - const_node->quantparam()->zerop.at(qd)) *
- const_node->quantparam()->scale.at(qd);
- }
- }
- else
- {
- for (uint32_t i = 0; i < const_node->size<loco::DataType::U8>(); ++i)
+ switch (const_node->dtype())
{
- uint32_t qd = (i % right_count) / (right_count / q_dim_value);
- if (qd >= const_node->quantparam()->zerop.size())
- qd = 0;
-
- new_const_node->at<loco::DataType::FLOAT32>(i) =
- (float)((int)const_node->at<loco::DataType::U8>(i) -
- const_node->quantparam()->zerop.at(qd)) *
- const_node->quantparam()->scale.at(qd);
+ case loco::DataType::S8:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S8>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::S16:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S16>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::S32:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S32>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::S64:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::S64>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ case loco::DataType::U8:
+ new_const_node->at<loco::DataType::FLOAT32>(i) =
+ static_cast<float>(const_node->at<loco::DataType::U8>(i) -
+ const_node->quantparam()->zerop.at(qd)) *
+ const_node->quantparam()->scale.at(qd);
+ break;
+ default:
+ throw std::runtime_error("Not supported dtype for FoldDequantizePass");
}
}
@@ -160,7 +198,7 @@ bool FoldDequantizePass::run(loco::Graph *g)
{
bool changed = false;
- for (auto node : loco::all_nodes(g))
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
if (auto circle_dequant = dynamic_cast<luci::CircleDequantize *>(node))
{
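
All of the integer branches above share the same per-channel index arithmetic. With right_count = q_dim_value * inner, where inner is the product of the dimensions after the quantized one, (i % right_count) / (right_count / q_dim_value) is the coordinate of flat index i along the quantized dimension. A minimal sketch of the equivalent computation, assuming row-major layout (hypothetical helper, not part of the pass):

    #include <cstdint>
    #include <vector>

    // Dequantize one element of a per-channel quantized tensor.
    // inner = product of dims after q_dim; q_dim_size = dim(q_dim).
    float dequantize_one(int64_t q, uint32_t i, uint32_t inner, uint32_t q_dim_size,
                         const std::vector<float> &scale,
                         const std::vector<int64_t> &zerop)
    {
      uint32_t ch = (i / inner) % q_dim_size; // same value the pass computes
      return static_cast<float>(q - zerop[ch]) * scale[ch];
    }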
diff --git a/compiler/luci/pass/src/FoldDequantizePass.test.cpp b/compiler/luci/pass/src/FoldDequantizePass.test.cpp
index d82a7bc87..fb5b6adc0 100644
--- a/compiler/luci/pass/src/FoldDequantizePass.test.cpp
+++ b/compiler/luci/pass/src/FoldDequantizePass.test.cpp
@@ -15,12 +15,389 @@
*/
#include "luci/Pass/FoldDequantizePass.h"
+#include "PassTestGraphs.h"
#include <gtest/gtest.h>
+namespace
+{
+
+template <loco::DataType DT>
+class FoldDequantizeTest : public luci::ConstantFoldingAddTestGraph, public ::testing::Test
+{
+public:
+ FoldDequantizeTest() : luci::ConstantFoldingAddTestGraph({2, 2, 2}, DT) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ _dequantize = _g.nodes()->create<luci::CircleDequantize>();
+ _input = _g.nodes()->create<luci::CircleConst>();
+
+ _dequantize->dtype(loco::DataType::FLOAT32);
+ _input->dtype(DT);
+
+ _input->shape({2, 2, 2});
+
+ _input->size<DT>(8);
+ _input->at<DT>(0) = 0;
+ _input->at<DT>(1) = 1;
+ _input->at<DT>(2) = 2;
+ _input->at<DT>(3) = 3;
+ _input->at<DT>(4) = 4;
+ _input->at<DT>(5) = 5;
+ _input->at<DT>(6) = 6;
+ _input->at<DT>(7) = 7;
+
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->quantized_dimension = 1;
+ qparam->scale.push_back(5.0);
+ qparam->scale.push_back(10.0);
+ qparam->zerop.push_back(1);
+ qparam->zerop.push_back(2);
+ _input->quantparam(std::move(qparam));
+
+ _dequantize->input(_input);
+
+ _dequantize->name("dequantize");
+ _input->name("input");
+
+ return _dequantize;
+ }
+
+ void createScalarPattern()
+ {
+ _input->rank(0);
+ _input->size<DT>(1);
+ _input->at<DT>(0) = 1;
+
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->quantized_dimension = 0;
+ qparam->scale.push_back(1.0);
+ qparam->zerop.push_back(0);
+ _input->quantparam(std::move(qparam));
+ }
+
+ void createNotFoldablePattern() { _input->quantparam(nullptr); }
+
+protected:
+ luci::CircleDequantize *_dequantize = nullptr;
+ luci::CircleConst *_input = nullptr;
+};
+
+class S8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S8>
+{
+};
+
+class S16FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S16>
+{
+};
+
+class S32FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S32>
+{
+};
+
+class S64FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::S64>
+{
+};
+
+class U8FoldDequantizeTest : public FoldDequantizeTest<loco::DataType::U8>
+{
+};
+
+class F16FoldDequantizeTest : public luci::ConstantFoldingTestGraph, public ::testing::Test
+{
+public:
+ F16FoldDequantizeTest() : ConstantFoldingTestGraph({2, 2}, loco::DataType::FLOAT16) {}
+
+ virtual void SetUp() { init(); }
+
+ loco::Node *createFoldedPattern() override
+ {
+ const auto DT = loco::DataType::FLOAT16;
+ _dequantize = _g.nodes()->create<luci::CircleDequantize>();
+ _f16const = _g.nodes()->create<luci::CircleConst>();
+
+ _dequantize->dtype(loco::DataType::FLOAT32);
+ _f16const->dtype(DT);
+
+ _f16const->shape({2, 2});
+
+ _f16const->size<loco::DataType::FLOAT16>(4);
+ _f16const->at<DT>(0) = 49408; // -2.5f
+ _f16const->at<DT>(1) = 47104; // -0.5f
+ _f16const->at<DT>(2) = 0; // 0.0f
+ _f16const->at<DT>(3) = 15872; // 1.5f
+    // NOTE To get the uint16_t bit pattern of a float16 value, see
+    // compiler/souschef/src/Gaussian.cpp GaussianFloat16DataChef::generate():
+    //   uint16_t value = fp16_ieee_from_fp32_value(-2.5);
+    //   printf("-2.5 = %u\r\n", value);
+
+ _dequantize->input(_f16const);
+
+ _dequantize->name("dequantize");
+ _f16const->name("input");
+
+ _output->from(_dequantize);
+
+ return _dequantize;
+ }
+
+ void createNotFoldablePattern() { _dequantize->input(_input); }
+
+protected:
+ luci::CircleConst *getFoldedPattern() override
+ {
+ return dynamic_cast<luci::CircleConst *>(_output->from());
+ }
+
+ void init() override { createFoldedPattern(); }
+
+protected:
+ luci::CircleDequantize *_dequantize = nullptr;
+ luci::CircleConst *_f16const = nullptr;
+};
+
+} // namespace
+
TEST(FoldDequantizePassTest, name)
{
luci::FoldDequantizePass pass;
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
+
+TEST_F(U8FoldDequantizeTest, fold_dequant_basic)
+{
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(3, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+ EXPECT_EQ(2, folded_const->dim(2).value());
+ EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
+ EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
+ EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
+ EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
+ EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
+ EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
+}
+
+TEST_F(U8FoldDequantizeTest, fold_dequant_basic_NEG)
+{
+ createNotFoldablePattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(S8FoldDequantizeTest, fold_dequant_basic)
+{
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(3, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+ EXPECT_EQ(2, folded_const->dim(2).value());
+ EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
+ EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
+ EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
+ EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
+ EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
+ EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
+}
+
+TEST_F(S8FoldDequantizeTest, fold_dequant_basic_NEG)
+{
+ createNotFoldablePattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(S16FoldDequantizeTest, fold_dequant_basic)
+{
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(3, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+ EXPECT_EQ(2, folded_const->dim(2).value());
+ EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
+ EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
+ EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
+ EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
+ EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
+ EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
+}
+
+TEST_F(S16FoldDequantizeTest, fold_dequant_basic_NEG)
+{
+ createNotFoldablePattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(S32FoldDequantizeTest, fold_dequant_basic)
+{
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(3, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+ EXPECT_EQ(2, folded_const->dim(2).value());
+ EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
+ EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
+ EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
+ EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
+ EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
+ EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
+}
+
+TEST_F(S32FoldDequantizeTest, fold_dequant_basic_NEG)
+{
+ createNotFoldablePattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(S64FoldDequantizeTest, fold_dequant_basic)
+{
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(3, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+ EXPECT_EQ(2, folded_const->dim(2).value());
+ EXPECT_EQ(-5.0, folded_const->at<loco::DataType::FLOAT32>(0));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(1));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
+ EXPECT_EQ(10.0, folded_const->at<loco::DataType::FLOAT32>(3));
+ EXPECT_EQ(15.0, folded_const->at<loco::DataType::FLOAT32>(4));
+ EXPECT_EQ(20.0, folded_const->at<loco::DataType::FLOAT32>(5));
+ EXPECT_EQ(40.0, folded_const->at<loco::DataType::FLOAT32>(6));
+ EXPECT_EQ(50.0, folded_const->at<loco::DataType::FLOAT32>(7));
+}
+
+TEST_F(S64FoldDequantizeTest, fold_dequant_basic_NEG)
+{
+ createNotFoldablePattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
+
+TEST_F(U8FoldDequantizeTest, fold_dequant_scalar)
+{
+ createScalarPattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+ // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(0, folded_const->rank());
+ EXPECT_EQ(1.0, folded_const->at<loco::DataType::FLOAT32>(0));
+}
+
+TEST_F(F16FoldDequantizeTest, fold_dequant_basic)
+{
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_NE(nullptr, folded_const);
+
+  // Check type, shape, values of folded const
+ EXPECT_EQ(loco::DataType::FLOAT32, folded_const->dtype());
+ EXPECT_EQ(2, folded_const->rank());
+ EXPECT_EQ(2, folded_const->dim(0).value());
+ EXPECT_EQ(2, folded_const->dim(1).value());
+ EXPECT_EQ(-2.5, folded_const->at<loco::DataType::FLOAT32>(0));
+ EXPECT_EQ(-0.5, folded_const->at<loco::DataType::FLOAT32>(1));
+ EXPECT_EQ(0.0, folded_const->at<loco::DataType::FLOAT32>(2));
+ EXPECT_EQ(1.5, folded_const->at<loco::DataType::FLOAT32>(3));
+}
+
+TEST_F(F16FoldDequantizeTest, fold_dequant_basic_NEG)
+{
+ createNotFoldablePattern();
+
+ luci::FoldDequantizePass pass;
+ while (pass.run(graph()))
+ ;
+
+ auto folded_const = getFoldedPattern();
+ EXPECT_EQ(nullptr, folded_const);
+}
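
The raw uint16_t constants in F16FoldDequantizeTest are IEEE 754 half-precision bit patterns. Assuming the same fp16 header the pass itself includes, they can be generated and round-tripped like this:

    #include <fp16.h>
    #include <cstdint>
    #include <cstdio>

    int main()
    {
      uint16_t bits = fp16_ieee_from_fp32_value(-2.5f); // 49408 (0xC100)
      float back = fp16_ieee_to_fp32_value(bits);       // -2.5f again
      std::printf("%u -> %f\n", bits, back);
      return 0;
    }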
diff --git a/compiler/luci/pass/src/FoldSparseToDensePass.cpp b/compiler/luci/pass/src/FoldSparseToDensePass.cpp
index 0c6fc43ed..ed60d8899 100644
--- a/compiler/luci/pass/src/FoldSparseToDensePass.cpp
+++ b/compiler/luci/pass/src/FoldSparseToDensePass.cpp
@@ -19,6 +19,8 @@
#include <luci/IR/CircleNodes.h>
+#include <limits>
+
namespace
{
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
index 2c990f0a5..bc09abee2 100644
--- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
@@ -22,6 +22,7 @@
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/CircleShapeInference.h>
#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Service/CircleNodeClone.h>
namespace
{
@@ -55,6 +56,26 @@ void copy_shape(luci::CircleReshape *reshape, luci::CircleReshape *new_reshape)
new_reshape->newShape()->dim(r) = reshape->newShape()->dim(r);
}
+luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape)
+{
+ assert(reshape != nullptr); // FIX_CALLER_UNLESS
+
+ luci::CircleConst *cloned_shape = clone_shape(reshape);
+ if (cloned_shape == nullptr)
+ return nullptr;
+
+ auto cloned_node = luci::clone_node(reshape, reshape->graph());
+ if (cloned_node == nullptr)
+ return nullptr;
+
+ auto new_reshape = loco::must_cast<luci::CircleReshape *>(cloned_node);
+ new_reshape->shape(cloned_shape);
+ new_reshape->name(reshape->name() + "_C");
+ luci::add_origin(new_reshape, luci::get_origin(reshape));
+
+ return new_reshape;
+}
+
bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg)
{
assert(reshape != nullptr);
@@ -85,6 +106,26 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg)
return true;
}
+bool forward_reshape(luci::CircleReshape *reshape, luci::CircleLogistic *logit)
+{
+ assert(reshape != nullptr); // FIX_CALLER_UNLESS
+ assert(logit != nullptr); // FIX_CALLER_UNLESS
+
+ auto new_reshape = create_cloned_reshape(reshape);
+ if (not new_reshape)
+ return false;
+
+ // reconnect network
+ loco::replace(logit).with(new_reshape);
+ logit->x(reshape->tensor());
+ new_reshape->tensor(logit);
+
+ // Do shape inference for this node again.
+ logit->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+}
+
class ForwardReshape final : public luci::CircleNodeMutableVisitor<bool>
{
protected:
@@ -103,6 +144,14 @@ protected:
return forward_reshape(reshape, node);
}
+ bool visit(luci::CircleLogistic *node)
+ {
+ auto reshape = as_reshape(node->x());
+ if (reshape == nullptr)
+ return false;
+
+ return forward_reshape(reshape, node);
+ }
// TODO add more unary operators
};
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
index 2593a014c..373513270 100644
--- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.test.cpp
@@ -65,6 +65,42 @@ protected:
luci::CircleConst *_reshape_shape = nullptr;
};
+// TODO Reduce duplicate code with ReshapeNegGraphlet
+class ReshapeLogisticGraphlet
+{
+public:
+ ReshapeLogisticGraphlet() = default;
+
+public:
+ void init(loco::Graph *g, const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ std::vector<uint32_t> shape_out_v = shape_out;
+
+ _reshape_shape = g->nodes()->create<luci::CircleConst>();
+ _reshape = g->nodes()->create<luci::CircleReshape>();
+ _logistic = g->nodes()->create<luci::CircleLogistic>();
+
+ _reshape_shape->dtype(loco::DataType::S32);
+ _reshape_shape->rank(1);
+ _reshape_shape->dim(0).set(shape_out_v.size());
+ _reshape_shape->shape_status(luci::ShapeStatus::VALID);
+ // values
+ const auto size = shape_out_v.size();
+ _reshape_shape->size<loco::DataType::S32>(size);
+ for (uint32_t i = 0; i < size; i++)
+ _reshape_shape->at<loco::DataType::S32>(i) = shape_out_v[i];
+
+ _reshape_shape->name("reshape_shape");
+ _reshape->name("reshape");
+ _logistic->name("logistic");
+ }
+
+protected:
+ luci::CircleReshape *_reshape = nullptr;
+ luci::CircleLogistic *_logistic = nullptr;
+ luci::CircleConst *_reshape_shape = nullptr;
+};
+
class ForwardReshapeToNegGraph : public TestIOGraph, public ReshapeNegGraphlet
{
public:
@@ -85,6 +121,26 @@ public:
}
};
+class ForwardReshapeToLogisticGraph : public TestIOGraph, public ReshapeLogisticGraphlet
+{
+public:
+ ForwardReshapeToLogisticGraph() = default;
+
+public:
+ void init(const ShapeU32 shape_in, const ShapeU32 shape_out)
+ {
+ TestIOGraph::init(shape_in, shape_out);
+ ReshapeLogisticGraphlet::init(g(), shape_in, shape_out);
+
+ // connect network
+ _reshape->tensor(input());
+ _reshape->shape(_reshape_shape);
+ _logistic->x(_reshape);
+
+ output()->from(_logistic);
+ }
+};
+
class ForwardReshapeToNegGraphTest : public ::testing::Test
{
public:
@@ -101,6 +157,22 @@ protected:
luci::ForwardReshapeToUnaryOpPass _pass;
};
+class ForwardReshapeToLogisticGraphTest : public ::testing::Test
+{
+public:
+ ForwardReshapeToLogisticGraphTest() = default;
+
+ void run_pass(void)
+ {
+ while (_pass.run(_graph.g()))
+ ;
+ }
+
+protected:
+ ForwardReshapeToLogisticGraph _graph;
+ luci::ForwardReshapeToUnaryOpPass _pass;
+};
+
} // namespace
TEST(ForwardReshapeToUnaryOpPassTest, name)
@@ -123,3 +195,17 @@ TEST_F(ForwardReshapeToNegGraphTest, simple_forward)
neg = dynamic_cast<luci::CircleNeg *>(reshape->tensor());
ASSERT_NE(nullptr, neg);
}
+
+TEST_F(ForwardReshapeToLogisticGraphTest, forward)
+{
+ _graph.init({2, 2, 2}, {2, 4});
+
+ run_pass();
+
+ auto reshape = dynamic_cast<luci::CircleReshape *>(_graph.output()->from());
+ auto log = dynamic_cast<luci::CircleLogistic *>(_graph.output()->from());
+ ASSERT_NE(nullptr, reshape);
+ ASSERT_EQ(nullptr, log);
+ log = dynamic_cast<luci::CircleLogistic *>(reshape->tensor());
+ ASSERT_NE(nullptr, log);
+}
diff --git a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
index 97a962cb6..3cf31ed10 100644
--- a/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithFullyConnectedPass.cpp
@@ -99,6 +99,12 @@ bool fuse_add_with_fc(luci::CircleFullyConnected *fc)
fused_bias->at<loco::DataType::FLOAT32>(i) += const_bias->at<loco::DataType::FLOAT32>(i);
}
+  // At this point, it is guaranteed that fused_bias's shape is [1, 1, ..., N] or [N]
+ // where N is weights->dim(0).
+ // The shape is normalized to [N] to become the bias of FC
+ fused_bias->rank(1);
+ fused_bias->dim(0) = weights->dim(0);
+
fc->bias(fused_bias);
fc->fusedActivationFunction(add->fusedActivationFunction());
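
The normalization matters because a Circle FullyConnected expects a rank-1 bias of length N, the number of output units (weights->dim(0)), while the Add constant may carry a broadcastable shape such as [1, 1, N] holding the same N elements. Only the rank/dimension metadata change; the buffer is untouched. A hypothetical sanity check mirroring the code above:

    // fused_bias holds exactly N floats regardless of its broadcast shape
    assert(fused_bias->size<loco::DataType::FLOAT32>() == weights->dim(0).value());
    fused_bias->rank(1);
    fused_bias->dim(0) = weights->dim(0); // now a plain [N] bias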
diff --git a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
index 2bca57014..852bc8b63 100644
--- a/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseAddWithTConvPass.cpp
@@ -37,10 +37,10 @@ namespace
* \ |
* [CircleTransposeConv] [CircleAdd]
* |
- * ([CircleRelu6])
+ * ([CircleRelu/Relu6])
* |
*
- * Note: CircleRelu6 is inserted if Add activation is ReLU6
+ * Note: CircleRelu/Relu6 is inserted if Add activation is ReLU/ReLU6
*/
bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
{
@@ -65,7 +65,8 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
if (add->dtype() != loco::DataType::FLOAT32)
return false;
if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
- add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6 &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU)
return false;
// get addition
@@ -102,6 +103,19 @@ bool fuse_add_with_tconv(luci::CircleTransposeConv *tconv)
// remove add node
replace(add).with(relu);
}
+ else if (add->fusedActivationFunction() == luci::FusedActFunc::RELU)
+ {
+ auto name = addition->name();
+ assert(name.length() > 0);
+ // separate relu op from add op
+ auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
+ relu->features(tconv);
+ relu->name(name + "/Relu");
+ luci::add_origin(relu, luci::get_origin(add));
+
+ // remove add node
+ replace(add).with(relu);
+ }
else
{
replace(add).with(tconv);
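
The new RELU branch mirrors the existing RELU6 one. The activation cannot simply stay fused because, in this IR revision, CircleTransposeConv carries no fused-activation attribute, so the Add's activation has to be re-materialized as a standalone node that reads the fused TConv. In sketch form:

    // after folding the addition into tconv's bias:
    auto relu = add->graph()->nodes()->create<luci::CircleRelu>();
    relu->features(tconv);   // ReLU consumes the fused TransposeConv
    replace(add).with(relu); // former consumers of Add now read the ReLU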
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
index 337954960..e6b54df36 100644
--- a/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConvPass.cpp
@@ -29,7 +29,7 @@ namespace
* NOTE TF's BatchNormalization is converted to Mul and Add.
*
* BEFORE
- * | [CircleOutputExclude]
+ * | [CircleConst]/[CircleOutputExclude]
* | / [CircleConst]
* | / /
* [CircleTransposeConv] [CircleConst]
@@ -40,7 +40,7 @@ namespace
* |
*
* AFTER
- * | [CircleOutputExclude]
+ * | [CircleConst]/[CircleOutputExclude]
* +-------------------------------------+ / [CircleConst]
* | | / /
* | [CircleTransposeConv] [CircleConst]
@@ -69,9 +69,10 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
return false;
// check scale and shift constant attributes
- if (scale->rank() != 1)
+ // TODO maybe rank check is not needed
+ if (scale->rank() != 1 && scale->rank() != 4)
return false;
- if (shift->rank() != 1)
+ if (shift->rank() != 1 && shift->rank() != 4)
return false;
// check mul, add attributes
if (mul->dtype() != loco::DataType::FLOAT32)
@@ -82,9 +83,8 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
return false;
- // tconv bias should be not set
- if (not dynamic_cast<luci::CircleOutputExclude *>(tconv->bias()))
- return false;
+ // tconv bias is optional
+ auto bias = dynamic_cast<luci::CircleConst *>(tconv->bias());
// get weight of tconv
auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
@@ -96,10 +96,36 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
return false;
auto filter_out_chn = filter->dim(0).value();
- if (filter_out_chn != scale->dim(0).value())
+  // allow scale/shift and bias shapes of [N] or [1,1,1,N]; BatchNorm works channel-wise
+ auto srank = scale->rank() - 1;
+ if (filter_out_chn != scale->dim(srank).value())
return false;
- if (filter_out_chn != shift->dim(0).value())
+ for (uint32_t d = 0; d < srank; ++d)
+ {
+ if (1 != scale->dim(d).value())
+ return false;
+ }
+ srank = shift->rank() - 1;
+ if (filter_out_chn != shift->dim(srank).value())
return false;
+ for (uint32_t d = 0; d < srank; ++d)
+ {
+ if (1 != shift->dim(d).value())
+ return false;
+ }
+ if (bias)
+ {
+ if (bias->dtype() != loco::DataType::FLOAT32)
+ return false;
+ srank = bias->rank() - 1;
+ if (filter_out_chn != bias->dim(srank).value())
+ return false;
+ for (uint32_t d = 0; d < srank; ++d)
+ {
+ if (1 != bias->dim(d).value())
+ return false;
+ }
+ }
auto name = add->name();
assert(name.length() > 0);
@@ -151,6 +177,11 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
for (uint32_t c = 0; c < filter_out_chn; ++c)
{
fused_bias->at<loco::DataType::FLOAT32>(c) = shift->at<loco::DataType::FLOAT32>(c);
+ if (bias != nullptr)
+ {
+ fused_bias->at<loco::DataType::FLOAT32>(c) +=
+ bias->at<loco::DataType::FLOAT32>(c) * scale->at<loco::DataType::FLOAT32>(c);
+ }
}
fused_bias->name(name + "/TransposeConv/bias");
@@ -166,6 +197,10 @@ bool fused_batch_norm_with_tconv(luci::CircleAdd *add)
luci::add_origin(fused_tconv,
luci::composite_origin(
{luci::get_origin(add), luci::get_origin(mul), luci::get_origin(tconv)}));
+ if (bias != nullptr)
+ {
+ luci::add_origin(fused_tconv, luci::get_origin(bias));
+ }
if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
{
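
The bias handling added above follows from folding the channel-wise Mul (scale $s_c$) and Add (shift $t_c$) into the TransposeConv. Per output channel $c$, writing $b_c$ for the optional pre-existing bias:

    $y_c = s_c\,(\mathrm{tconv}(x, W)_c + b_c) + t_c = \mathrm{tconv}(x, s_c W)_c + (t_c + s_c b_c)$

so the fused bias is shift plus bias times scale, exactly the loop in the hunk; the matching per-channel scaling of the filter is performed elsewhere in the same function, outside this hunk.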
diff --git a/compiler/luci/pass/src/FuseInstanceNormPass.cpp b/compiler/luci/pass/src/FuseInstanceNormPass.cpp
index f3ec6cd9e..10a651e35 100644
--- a/compiler/luci/pass/src/FuseInstanceNormPass.cpp
+++ b/compiler/luci/pass/src/FuseInstanceNormPass.cpp
@@ -325,6 +325,10 @@ public:
}
private:
+ bool condition_common_1_5(uint32_t ifm_channel_depth);
+ bool condition_common_3_4();
+
+private:
template <enum PatternVersion> bool match();
public:
@@ -368,21 +372,8 @@ private:
if (not(condition)) \
return false;
-template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
+bool InstanceNormPattern::condition_common_1_5(uint32_t ifm_channel_depth)
{
- CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
- CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
-
- auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
- CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
- CHECK_OR_FALSE(ifm_circle->rank() == 4);
- CHECK_OR_FALSE(ifm_circle->dim(3).known());
- uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
-
- CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
-
- CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
-
add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
CHECK_OR_FALSE(add_as_variance);
@@ -408,6 +399,70 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(const_as_beta);
CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
+ return true;
+}
+
+bool InstanceNormPattern::condition_common_3_4()
+{
+ // check left sub
+ ifm = sub->x();
+ CHECK_OR_FALSE(ifm);
+
+ luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_node->rank() == 4);
+ CHECK_OR_FALSE(ifm_node->dim(3).known());
+
+ mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
+ CHECK_OR_FALSE(mean_of_ifm);
+ CHECK_OR_FALSE(ifm == mean_of_ifm->input());
+
+ // continue search from add_as_variance
+ CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
+ CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
+ // TODO Support regarding broadcast
+ CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
+
+ mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
+ CHECK_OR_FALSE(mean_as_variance);
+
+ square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
+ CHECK_OR_FALSE(square);
+
+ sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
+ CHECK_OR_FALSE(sub_2);
+ CHECK_OR_FALSE(ifm == sub_2->x());
+
+ mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
+ CHECK_OR_FALSE(mean_of_ifm_2);
+ CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
+
+ loco::Node *ifm_should_be = nullptr;
+ luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
+ CHECK_OR_FALSE(
+ luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
+ CHECK_OR_FALSE(ifm == ifm_should_be);
+ CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
+
+ return true;
+}
+
+template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_1>()
+{
+ CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
+ CHECK_OR_FALSE(luci::fill(&ifm, &mul_gamma).with_commutative_args_of(mul_as_scaled_ifm));
+
+ auto ifm_circle = loco::must_cast<luci::CircleNode *>(ifm);
+ CHECK_OR_FALSE(ifm_circle->shape_status() == luci::ShapeStatus::VALID);
+ CHECK_OR_FALSE(ifm_circle->rank() == 4);
+ CHECK_OR_FALSE(ifm_circle->dim(3).known());
+ uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
+
+ CHECK_OR_FALSE(luci::fill(&rsqrt, &const_as_gamma).with_commutative_args_of(mul_gamma));
+
+ CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_gamma, ifm_channel_depth));
+
+ CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
+
luci::CircleMul *mul_gamma_should_be = nullptr;
luci::CircleMean *mean_of_ifm_should_be = nullptr;
@@ -488,44 +543,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(luci::fill(&div, &const_as_gamma).with_commutative_args_of(mul_gamma));
CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
- // check left sub
- ifm = sub->x();
- CHECK_OR_FALSE(ifm);
-
- luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
- CHECK_OR_FALSE(ifm_node->rank() == 4);
- CHECK_OR_FALSE(ifm_node->dim(3).known());
-
- mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
- CHECK_OR_FALSE(mean_of_ifm);
- CHECK_OR_FALSE(ifm == mean_of_ifm->input());
-
- // continue search from add_as_variance
- CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
- CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
- // TODO Support regarding broadcast
- CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
-
- mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
- CHECK_OR_FALSE(mean_as_variance);
-
- square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
- CHECK_OR_FALSE(square);
-
- sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
- CHECK_OR_FALSE(sub_2);
- CHECK_OR_FALSE(ifm == sub_2->x());
-
- mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
- CHECK_OR_FALSE(mean_of_ifm_2);
- CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
-
- loco::Node *ifm_should_be = nullptr;
- luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
- CHECK_OR_FALSE(
- luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
- CHECK_OR_FALSE(ifm == ifm_should_be);
- CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
+ CHECK_OR_FALSE(condition_common_3_4());
_matched = true;
return true;
@@ -546,44 +564,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(div);
CHECK_OR_FALSE(luci::fill(&sub, &add_as_variance).with_commutative_args_of(div));
- // check left sub
- ifm = sub->x();
- CHECK_OR_FALSE(ifm);
-
- luci::CircleNode *ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
- CHECK_OR_FALSE(ifm_node->rank() == 4);
- CHECK_OR_FALSE(ifm_node->dim(3).known());
-
- mean_of_ifm = dynamic_cast<luci::CircleMean *>(sub->y());
- CHECK_OR_FALSE(mean_of_ifm);
- CHECK_OR_FALSE(ifm == mean_of_ifm->input());
-
- // continue search from add_as_variance
- CHECK_OR_FALSE(luci::fill(&sqrt, &const_as_epsilon).with_commutative_args_of(add_as_variance));
- CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
- // TODO Support regarding broadcast
- CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
-
- mean_as_variance = dynamic_cast<luci::CircleMean *>(sqrt->x());
- CHECK_OR_FALSE(mean_as_variance);
-
- square = dynamic_cast<luci::CircleSquare *>(mean_as_variance->input());
- CHECK_OR_FALSE(square);
-
- sub_2 = dynamic_cast<luci::CircleSub *>(square->x());
- CHECK_OR_FALSE(sub_2);
- CHECK_OR_FALSE(ifm == sub_2->x());
-
- mean_of_ifm_2 = dynamic_cast<luci::CircleMean *>(sub_2->y());
- CHECK_OR_FALSE(mean_of_ifm_2);
- CHECK_OR_FALSE(ifm == mean_of_ifm_2->input());
-
- loco::Node *ifm_should_be = nullptr;
- luci::CircleMean *mean_of_ifm_2_should_be = nullptr;
- CHECK_OR_FALSE(
- luci::fill(&ifm_should_be, &mean_of_ifm_2_should_be).with_commutative_args_of(sub_2));
- CHECK_OR_FALSE(ifm == ifm_should_be);
- CHECK_OR_FALSE(mean_of_ifm_2 == mean_of_ifm_2_should_be);
+ CHECK_OR_FALSE(condition_common_3_4());
assert(const_as_gamma == nullptr);
assert(const_as_beta == nullptr);
@@ -612,30 +593,7 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
CHECK_OR_FALSE(ifm_circle->dim(3).known());
uint32_t ifm_channel_depth = ifm_circle->dim(3).value();
- add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
- CHECK_OR_FALSE(add_as_variance);
-
- CHECK_OR_FALSE(
- luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));
-
- CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
- // TODO Support regarding broadcast
- CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);
-
- CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));
-
- sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
- CHECK_OR_FALSE(sqdiff);
-
- loco::Node *ifm_should_be = nullptr;
- CHECK_OR_FALSE(luci::fill(&ifm_should_be, &mean_of_ifm).with_commutative_args_of(sqdiff));
- CHECK_OR_FALSE(ifm == ifm_should_be);
- CHECK_OR_FALSE(is_instance_mean_v1(mean_of_ifm));
- CHECK_OR_FALSE(ifm == mean_of_ifm->input());
-
- const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
- CHECK_OR_FALSE(const_as_beta);
- CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, ifm_channel_depth));
+ CHECK_OR_FALSE(condition_common_1_5(ifm_channel_depth));
luci::CircleRsqrt *rsqrt_should_be = nullptr;
luci::CircleMean *mean_of_ifm_should_be = nullptr;
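
All of these pattern versions are decompositions of the same instance-normalization formula, with per-channel mean $\mu$ and variance $\sigma^2$ taken over the spatial dimensions:

    $\mathrm{ofm} = \gamma \cdot \dfrac{\mathrm{ifm} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$

The factored-out helpers reflect the two common spellings of this in exported graphs: condition_common_3_4 matches the explicit (ifm - mean) / sqrt(var + eps) form, while condition_common_1_5 matches the rsqrt-and-multiply form.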
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
index b4975486d..e8fa2a478 100644
--- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
@@ -23,6 +23,7 @@
#include <luci/Log.h>
#include <cmath>
+#include <limits>
namespace
{
diff --git a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
index 003e4c293..aaadb2864 100644
--- a/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
+++ b/compiler/luci/pass/src/PropagateQParamForwardPass.cpp
@@ -138,13 +138,18 @@ struct PropagateQParamForward final : public luci::CircleNodeMutableVisitor<bool
auto qtype = luci::activation_qtype(input_node);
switch (qtype)
{
- case luci::ActivationQType::PreDefinedValue:
- node->quantparam(luci::make_predefined_qparam(input_node->opcode(), node->dtype()));
+ case luci::ActivationQType::PreDefinedLogistic:
+ case luci::ActivationQType::PreDefinedTanh:
+ case luci::ActivationQType::PreDefinedSoftmax:
+ node->quantparam(luci::make_predefined_qparam(qtype, node->dtype()));
break;
case luci::ActivationQType::IntScale:
luci::set_int_scale(node);
break;
default:
+      // This assert ensures this switch-statement handles all ActivationQTypes
+ // TODO Find a better design to remove coupling with ActivationQType
+ assert(qtype == luci::ActivationQType::MinMax);
break;
}
diff --git a/compiler/luci/pass/src/QuantizationUtils.cpp b/compiler/luci/pass/src/QuantizationUtils.cpp
index ad86cedf4..06a4ae9f6 100644
--- a/compiler/luci/pass/src/QuantizationUtils.cpp
+++ b/compiler/luci/pass/src/QuantizationUtils.cpp
@@ -20,6 +20,7 @@
#include <iostream>
#include <cmath>
+#include <limits>
namespace luci
{
@@ -276,31 +277,70 @@ uint32_t cal_offset(loco::TensorShape &dimension, uint32_t *indices)
indices[2] * dimension.dim(3).value() + indices[3];
}
+// Activation (ofm) qtype is determined in different ways.
+// 1. Pre-defined values: Some Ops have pre-defined qparams (ex: LOGISTIC, TANH)
+// 2. Integer scale: Output of some Ops should be integers (ex: FLOOR, CEIL)
+// 3. Activation qtype of input: Some Ops propagate qparam from input to output (ex: QUANTIZE,
+// TRANSPOSE, etc. See PropagateQParamForwardPass.cpp for more details).
ActivationQType activation_qtype(const CircleNode *node)
{
auto fused_act_node = dynamic_cast<const CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
if (fused_act_node && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
- return ActivationQType::PreDefinedValue;
+ return ActivationQType::PreDefinedTanh;
+
+#define RETURN_INPUT_ACTIVATION_QTYPE(CLASS, INPUT) \
+ { \
+ auto n = loco::must_cast<const CLASS *>(node); \
+ auto input = loco::must_cast<CircleNode *>(n->INPUT()); \
+ return activation_qtype(input); \
+ }
switch (node->opcode())
{
case CircleOpcode::LOGISTIC:
+ return ActivationQType::PreDefinedLogistic;
case CircleOpcode::TANH:
+ return ActivationQType::PreDefinedTanh;
case CircleOpcode::SOFTMAX:
- return ActivationQType::PreDefinedValue;
+ return ActivationQType::PreDefinedSoftmax;
case CircleOpcode::FLOOR:
case CircleOpcode::FLOOR_DIV:
case CircleOpcode::FLOOR_MOD:
case CircleOpcode::CEIL:
return ActivationQType::IntScale;
+ case CircleOpcode::GATHER:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleGather, params);
+ case CircleOpcode::RESHAPE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleReshape, tensor);
+ case CircleOpcode::TRANSPOSE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleTranspose, a);
+ case CircleOpcode::STRIDED_SLICE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleStridedSlice, input);
+ case CircleOpcode::SPLIT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplit, input);
+ case CircleOpcode::CIRCLESPLITOUT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitOut, input);
+ case CircleOpcode::SPLIT_V:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitV, input);
+ case CircleOpcode::CIRCLESPLITVOUT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleSplitVOut, input);
+ case CircleOpcode::UNPACK:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpack, value);
+ case CircleOpcode::CIRCLEUNPACKOUT:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleUnpackOut, input);
+ case CircleOpcode::QUANTIZE:
+ RETURN_INPUT_ACTIVATION_QTYPE(CircleQuantize, input);
default:
break;
}
+#undef RETURN_INPUT_ACTIVATION_QTYPE
+
return ActivationQType::MinMax;
}
-std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype)
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(ActivationQType qtype,
+ loco::DataType dtype)
{
auto qparam = std::make_unique<CircleQuantParam>();
@@ -309,9 +349,9 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
qparam->zerop.emplace_back(zp);
};
- switch (opcode)
+ switch (qtype)
{
- case CircleOpcode::LOGISTIC:
+ case ActivationQType::PreDefinedLogistic:
if (dtype == loco::DataType::U8)
set_qparam(1.0f / 256.0f, 0);
else
@@ -320,7 +360,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
set_qparam(1.0f / 32768.0f, 0);
}
break;
- case CircleOpcode::TANH:
+ case ActivationQType::PreDefinedTanh:
if (dtype == loco::DataType::U8)
set_qparam(2.0f / 256.0f, 128);
else
@@ -329,7 +369,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
set_qparam(1.0f / 32768.0f, 0);
}
break;
- case CircleOpcode::SOFTMAX:
+ case ActivationQType::PreDefinedSoftmax:
if (dtype == loco::DataType::U8)
set_qparam(1.0f / 255.0f, 0);
else
@@ -341,7 +381,7 @@ std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, lo
default:
throw std::runtime_error("Unsupported opcode with pre-defined qparam");
}
- return std::move(qparam);
+ return qparam;
}
// For nodes with integer output, we use integer scale
@@ -395,4 +435,74 @@ void quant_const(luci::CircleConst *node, loco::DataType quant_type)
node->quantparam(std::move(quantparam));
}
+namespace
+{
+
+// TODO move this to a more global helper file
+int nbits(loco::DataType dt) noexcept
+{
+ switch (dt)
+ {
+ case loco::DataType::S8:
+ case loco::DataType::U8:
+ return 8;
+ case loco::DataType::S16:
+ case loco::DataType::U16:
+ case loco::DataType::FLOAT16:
+ return 16;
+ case loco::DataType::S32:
+ case loco::DataType::U32:
+ case loco::DataType::FLOAT32:
+ return 32;
+ case loco::DataType::S64:
+ return 64;
+ default:
+ return 64; // a safe large default
+ }
+}
+
+// TODO Check if the metric is valid
+// Returns true if [min,max] is poorly representable
+bool range_check(float min, float max, loco::DataType dtype)
+{
+ float thresh = 1.5f;
+ return log2f(max) - log2f(min) > nbits(dtype) * thresh;
+}
+
+bool warn_scale_zp(float scale, int64_t zp, luci::CircleNode *n)
+{
+ float min, max;
+ // estimate min/max
+ switch (n->dtype())
+ {
+ case loco::DataType::U8:
+ min = scale * (0 - zp);
+ max = scale * (255 - zp);
+ break;
+ case loco::DataType::S16:
+ min = scale * (-32767);
+ max = scale * (32767);
+ break;
+ default:
+ return false;
+ }
+ return range_check(min, max, n->dtype());
+}
+
+} // namespace
+
+void warn_accuracy_with_range(luci::CircleNode *n)
+{
+ LOGGER(l);
+ auto qp = n->quantparam();
+ auto k = qp->zerop.size();
+ for (uint32_t i = 0; i < k; i++)
+ {
+ if (warn_scale_zp(qp->scale[i], qp->zerop[i], n))
+ WARN(l) << "Quantization of " << i << "-th channel of " << n->name()
+ << "'s quantization may cause accuracy issues" << std::endl;
+ ;
+ }
+}
+
} // namespace luci
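
Reading the metric in range_check(): since

    $\log_2(\max) - \log_2(\min) > 1.5\,n \iff \max/\min > 2^{1.5 n}$

a warning requires max/min > 2^12 = 4096 for U8 (n = 8) and max/min > 2^24 (about 1.7e7) for S16 (n = 16). One caveat worth keeping in mind alongside the TODO above: log2f() returns -inf for a zero argument and NaN for a negative one, so a channel whose estimated min is exactly zero always warns, while a channel whose min is negative never does.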
diff --git a/compiler/luci/pass/src/QuantizationUtils.h b/compiler/luci/pass/src/QuantizationUtils.h
index cd8cec95a..4d5316ccb 100644
--- a/compiler/luci/pass/src/QuantizationUtils.h
+++ b/compiler/luci/pass/src/QuantizationUtils.h
@@ -62,15 +62,19 @@ bool is_quantized(const CircleNode *node);
enum ActivationQType
{
- MinMax, // Quantize using recorded min/max
- PreDefinedValue, // Quantize using pre-defined values
- IntScale, // Round scale to a positive integer
+ MinMax, // Quantize using recorded min/max
+ PreDefinedLogistic, // Quantize using pre-defined values
+ PreDefinedTanh, // Quantize using pre-defined values
+ PreDefinedSoftmax, // Quantize using pre-defined values
+ IntScale, // Round scale to a positive integer
};
ActivationQType activation_qtype(const CircleNode *node);
// Create qparam with pre-defined values for special operators
-std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleOpcode opcode, loco::DataType dtype);
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(CircleNode *node, loco::DataType dtype);
+std::unique_ptr<CircleQuantParam> make_predefined_qparam(ActivationQType qtype,
+ loco::DataType dtype);
// Update node's scale to a positive integer (for special Ops e.g., Floor, Ceil)
void set_int_scale(luci::CircleNode *node);
@@ -78,6 +82,10 @@ void set_int_scale(luci::CircleNode *node);
// Quantize const tensor using its min/max values
void quant_const(luci::CircleConst *node, loco::DataType quant_type);
+// Check whether a node is quantized without significant loss of precision;
+// emits warnings to the log with WARN
+void warn_accuracy_with_range(luci::CircleNode *n);
+
} // namespace luci
#endif // __LUCI_QUANTIZATION_UTILS_H__
diff --git a/compiler/luci/pass/src/QuantizeActivation.cpp b/compiler/luci/pass/src/QuantizeActivation.cpp
index 149331824..95251a82c 100644
--- a/compiler/luci/pass/src/QuantizeActivation.cpp
+++ b/compiler/luci/pass/src/QuantizeActivation.cpp
@@ -114,29 +114,26 @@ void QuantizeSpecialActivation::visit(luci::CircleNode *node)
auto fused_act_node = dynamic_cast<CircleNodeMixin<CircleNodeTrait::FusedActFunc> *>(node);
if (fused_act_node != nullptr && fused_act_node->fusedActivationFunction() == FusedActFunc::TANH)
{
- auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type);
node->quantparam(std::move(qparam));
}
}
void QuantizeSpecialActivation::visit(luci::CircleLogistic *node)
{
- assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
- auto qparam = make_predefined_qparam(luci::CircleOpcode::LOGISTIC, output_type);
+ auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedLogistic, output_type);
node->quantparam(std::move(qparam));
}
void QuantizeSpecialActivation::visit(luci::CircleTanh *node)
{
- assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
- auto qparam = make_predefined_qparam(luci::CircleOpcode::TANH, output_type);
+ auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedTanh, output_type);
node->quantparam(std::move(qparam));
}
void QuantizeSpecialActivation::visit(luci::CircleSoftmax *node)
{
- assert(activation_qtype(node) == luci::ActivationQType::PreDefinedValue);
- auto qparam = make_predefined_qparam(luci::CircleOpcode::SOFTMAX, output_type);
+ auto qparam = make_predefined_qparam(luci::ActivationQType::PreDefinedSoftmax, output_type);
node->quantparam(std::move(qparam));
}
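
For reference, the pre-defined qparams these visitors install (the values are visible in the QuantizationUtils.cpp hunk above) pin each op's output range to its mathematical range. A small sketch of what the U8 values mean:

    // dequantized value for a U8 tensor: f = (q - zp) * scale
    float dequant_u8(uint8_t q, float scale, int32_t zp) { return (q - zp) * scale; }
    // logistic: scale 1/256, zp 0   -> covers [0, 255/256], output range [0, 1)
    // tanh:     scale 2/256, zp 128 -> covers [-1, 127/128], output range [-1, 1]
    // softmax:  scale 1/255, zp 0   -> covers [0, 1] exactly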
diff --git a/compiler/luci/pass/src/QuantizeBias.cpp b/compiler/luci/pass/src/QuantizeBias.cpp
index aa496232a..de97a14dd 100644
--- a/compiler/luci/pass/src/QuantizeBias.cpp
+++ b/compiler/luci/pass/src/QuantizeBias.cpp
@@ -22,6 +22,7 @@
#include <algorithm>
#include <cmath>
+#include <limits>
using namespace luci;
@@ -201,6 +202,18 @@ CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *w
std::vector<float> scaling_factor(size);
std::vector<int64_t> zp(size);
+ if (const_bias->rank() == 0)
+ {
+ // TODO Support quantization of scalar bias
+ throw std::runtime_error("Quantization of scalar bias is not yet supported (" +
+ const_bias->name() + ")");
+ }
+ if (size != const_bias->dim(const_bias->rank() - 1).value())
+ {
+ throw std::runtime_error(const_bias->name() +
+ " (bias) should have the shape of [1, 1, .. 1, channel]");
+ }
+
if (output_type == loco::DataType::U8)
{
new_bias = quant_bias_per_channel(const_bias, input_scale, weight_scale, scaling_factor, zp);
@@ -218,6 +231,7 @@ CircleConst *QuantizeBias::quantized_bias(CircleNode *input, const CircleNode *w
auto quantparam = std::make_unique<CircleQuantParam>();
quantparam->scale = scaling_factor;
quantparam->zerop = zp;
+ quantparam->quantized_dimension = const_bias->rank() - 1;
assert(new_bias->quantparam() == nullptr); // bias should not be quantized before
new_bias->quantparam(std::move(quantparam));
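
The quantized_dimension set above records that the bias scales run along the last axis, matching the shape check introduced earlier in this hunk. The scales themselves follow the usual per-channel convention, which quant_bias_per_channel() is expected to implement: for each output channel c,

    $s_b[c] = s_{\mathrm{in}} \cdot s_w[c], \qquad q_b[c] = \mathrm{round}(b[c] / s_b[c]), \qquad zp_b[c] = 0$

so the quantized bias can be added directly to the integer accumulator of the convolution or FC.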
diff --git a/compiler/luci/pass/src/QuantizeBias.test.cpp b/compiler/luci/pass/src/QuantizeBias.test.cpp
new file mode 100644
index 000000000..0104a191b
--- /dev/null
+++ b/compiler/luci/pass/src/QuantizeBias.test.cpp
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "QuantizeBias.h"
+
+#include <luci/test/TestIOGraph.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleQuantParam.h>
+
+#include <gtest/gtest.h>
+
+using namespace luci;
+
+namespace
+{
+
+using namespace luci::test;
+
+// TODO Reduce duplicate code in ResolveCustomOpMatMulPass.cpp
+template <typename T>
+luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
+ const std::vector<uint32_t> &shape, T value)
+{
+ auto node = g->nodes()->create<luci::CircleConst>();
+ node->dtype(dtype);
+ node->rank(shape.size());
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ node->dim(i) = shape.at(i);
+ size *= shape.at(i);
+ }
+ node->shape_status(luci::ShapeStatus::VALID);
+
+#define INIT_VALUES(DT) \
+ { \
+ node->size<DT>(size); \
+ for (uint32_t i = 0; i < size; ++i) \
+ node->at<DT>(i) = value; \
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::U8:
+ INIT_VALUES(loco::DataType::U8);
+ break;
+ case loco::DataType::S16:
+ INIT_VALUES(loco::DataType::S16);
+ break;
+ case loco::DataType::S32:
+ INIT_VALUES(loco::DataType::S32);
+ break;
+ case loco::DataType::FLOAT32:
+ INIT_VALUES(loco::DataType::FLOAT32)
+ break;
+ default:
+ INTERNAL_EXN("create_const_node called with unsupported type");
+ break;
+ }
+ return node;
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [IFM] [WEIGHTS] [BIAS(FP32)]
+ * \ | /
+ * [FC]
+ * |
+ * [OFM]
+ *
+ * AFTER
+ *
+ * [IFM] [WEIGHTS] [BIAS(Quantized)]
+ * \ | /
+ * [FC]
+ * |
+ * [OFM]
+ */
+struct Q8FCGraphlet
+{
+public:
+ Q8FCGraphlet() = default;
+ virtual ~Q8FCGraphlet() = default;
+
+ void init(loco::Graph *g, const ShapeU32 out_shape, const ShapeU32 w_shape,
+ const ShapeU32 bias_shape, const float bv)
+ {
+ _fc = g->nodes()->create<luci::CircleFullyConnected>();
+ _fc->input(_x);
+ _x->dtype(loco::DataType::U8);
+ {
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(1.0);
+ quantparam->zerop.push_back(0);
+ quantparam->quantized_dimension = 0;
+ _x->quantparam(std::move(quantparam));
+ }
+
+ auto weights = create_const_node<uint8_t>(g, loco::DataType::U8, w_shape, 1.0);
+ auto w_qparam = std::make_unique<CircleQuantParam>();
+ std::vector<float> w_scale(weights->dim(0).value(), 1.0);
+ std::vector<int64_t> w_zp(weights->dim(0).value(), 0);
+ w_qparam->scale = w_scale;
+ w_qparam->zerop = w_zp;
+ w_qparam->quantized_dimension = 0;
+ weights->quantparam(std::move(w_qparam));
+ _fc->weights(weights);
+ _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _fc->dtype(loco::DataType::U8);
+ _fc->shape(out_shape);
+ _fc->bias(create_const_node(g, loco::DataType::FLOAT32, bias_shape, bv));
+ _fc->name("fc");
+ {
+ auto quantparam = std::make_unique<CircleQuantParam>();
+ quantparam->scale.push_back(1.0);
+ quantparam->zerop.push_back(0);
+ quantparam->quantized_dimension = 0;
+ _fc->quantparam(std::move(quantparam));
+ }
+ }
+
+public:
+ luci::CircleFullyConnected *fc() { return _fc; }
+
+protected:
+ luci::CircleFullyConnected *_fc = nullptr;
+ luci::CircleInput *_x = nullptr;
+};
+
+struct Q8FCGraph final : public TestIGraphlet, public TestOGraphlet, public Q8FCGraphlet
+{
+ void init(const ShapeU32 in_shape, const ShapeU32 w_shape, const ShapeU32 out_shape,
+ const ShapeU32 bias_shape, const float bv)
+ {
+ TestIGraphlet::init(g(), in_shape);
+ TestOGraphlet::init(g(), out_shape);
+ _x = input();
+ Q8FCGraphlet::init(g(), out_shape, w_shape, bias_shape, bv);
+ output()->from(_fc);
+ }
+};
+
+class CQ8QuantizeBiasFCTest : public ::testing::Test
+{
+public:
+ Q8FCGraph g;
+ luci::QuantizeBias qb{loco::DataType::FLOAT32, loco::DataType::U8,
+ luci::QuantizationGranularity::ChannelWise};
+};
+
+} // namespace
+
+TEST_F(CQ8QuantizeBiasFCTest, fully_connected)
+{
+ g.init({1, 18, 80}, {256, 80}, {18, 256}, {1, 256}, 1);
+ g.fc()->accept(&qb);
+
+ auto bias = loco::must_cast<CircleConst *>(g.fc()->bias());
+ auto qparam = bias->quantparam();
+
+ EXPECT_NE(nullptr, qparam);
+ EXPECT_EQ(256, qparam->scale.size());
+ EXPECT_EQ(256, qparam->zerop.size());
+ EXPECT_EQ(1, qparam->quantized_dimension);
+}
+
+TEST_F(CQ8QuantizeBiasFCTest, wrong_bias_shape_NEG)
+{
+ g.init({1, 18, 80}, {256, 80}, {18, 256}, {1, 2, 128}, 1);
+ EXPECT_ANY_THROW(g.fc()->accept(&qb)); // Wrong bias shape
+}
diff --git a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
index c9b35e0be..ef047d35d 100644
--- a/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
+++ b/compiler/luci/pass/src/QuantizeDequantizeWeightsPass.cpp
@@ -27,6 +27,7 @@
#include <iostream>
#include <cmath>
#include <functional>
+#include <limits>
namespace
{
@@ -352,15 +353,15 @@ private:
private:
// Check if
// 1. node is const
- // 2. node was not quantized
+ // 2. node's dtype is float32
bool is_quantizable(loco::Node *node)
{
auto const_node = dynamic_cast<luci::CircleConst *>(node);
if (not const_node)
return false;
- // Skip if this is already quantized
- if (is_quantized(const_node))
+ // Skip if this is not float32
+ if (const_node->dtype() != loco::DataType::FLOAT32)
return false;
return true;
diff --git a/compiler/luci/pass/src/QuantizeWeights.cpp b/compiler/luci/pass/src/QuantizeWeights.cpp
index 11322ab44..500ae12ed 100644
--- a/compiler/luci/pass/src/QuantizeWeights.cpp
+++ b/compiler/luci/pass/src/QuantizeWeights.cpp
@@ -23,6 +23,7 @@
#include <cmath>
#include <vector>
#include <functional>
+#include <limits>
using namespace luci;
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index d9a9d4db7..005144516 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -41,10 +41,28 @@ namespace
{
using namespace luci;
+
+bool use_predefined_values(ActivationQType qtype)
+{
+ switch (qtype)
+ {
+ case ActivationQType::PreDefinedLogistic:
+ case ActivationQType::PreDefinedTanh:
+ case ActivationQType::PreDefinedSoftmax:
+ return true;
+ default:
+ // This ensures this switch-statement handles all ActivationQTypes
+ assert(qtype == ActivationQType::IntScale or qtype == ActivationQType::MinMax);
+ break;
+ }
+
+ return false;
+}
+
// Create a Quantize Op whose
// dtype is out_type
// shape is the same as node's
-// qparam is computed using node's min/max
+// qparam is computed according to node's qtype
luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType out_type)
{
auto quantize = node->graph()->nodes()->create<CircleQuantize>();
@@ -60,9 +78,9 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
assert(qparam); // FIX_CALLER_UNLESS
auto qtype = luci::activation_qtype(node);
- if (qtype == ActivationQType::PreDefinedValue)
+ if (use_predefined_values(qtype))
{
- quantize->quantparam(luci::make_predefined_qparam(node->opcode(), out_type));
+ quantize->quantparam(luci::make_predefined_qparam(qtype, out_type));
return quantize;
}
@@ -105,6 +123,23 @@ luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType
return quantize;
}
+// Create Dequantize Op whose shape is the same as node's
+luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
+{
+ auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
+ dequantize->name(node->name() + "_Dequantize");
+ dequantize->dtype(loco::DataType::FLOAT32);
+ dequantize->rank(node->rank());
+ for (uint32_t i = 0; i < node->rank(); i++)
+ dequantize->dim(i).set(node->dim(i).value());
+
+ dequantize->shape_status(luci::ShapeStatus::VALID);
+
+ luci::add_origin(dequantize, luci::get_origin(node));
+
+ return dequantize;
+}
+
} // namespace
namespace luci
@@ -229,11 +264,13 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleNeg, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input)
@@ -241,6 +278,7 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu6, features)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input)
@@ -250,6 +288,7 @@ private:
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input)
+ INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqueeze, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input)
INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input)
@@ -353,7 +392,9 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
luci::add_origin(quant_op, luci::get_origin(succ));
}
- // Requantize input
+ // Update qparam of input
+ // This step is skipped if input_type is float32
+ if (_ctx->input_type != loco::DataType::FLOAT32)
{
auto quantparam = input->quantparam();
assert(quantparam);
@@ -376,11 +417,13 @@ void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
assert(_ctx->input_type == loco::DataType::S16);
compute_sym_scale_zp(min, max, scaling_factor, zp, nudged_min, nudged_max);
}
- input->dtype(_ctx->input_type);
input->quantparam()->scale[0] = scaling_factor;
input->quantparam()->zerop[0] = zp;
}
+ // Update dtype of input
+ input->dtype(_ctx->input_type);
+
auto graph_input = inputs->at(input->index());
graph_input->dtype(_ctx->input_type);
}
@@ -405,13 +448,26 @@ void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
if (not from->quantparam())
continue;
- // Insert Quantize Op
- auto quant_op = create_quantize_op(from, _ctx->output_type);
- loco::replace(from).with(quant_op);
- quant_op->input(from);
+ // Insert Dequantize Op for float32 output_type
+ if (_ctx->output_type == loco::DataType::FLOAT32)
+ {
+ auto dequant_op = create_dequantize(from);
+ loco::replace(from).with(dequant_op);
+ dequant_op->input(from);
+ }
+ else
+ {
+ // Insert Quantize Op for non-float32 output_type
+ auto quant_op = create_quantize_op(from, _ctx->output_type);
+ loco::replace(from).with(quant_op);
+ quant_op->input(from);
- // TODO Set a proper origin (Quantize should have its own Origin)
- luci::add_origin(quant_op, luci::get_origin(from));
+ // TODO Set a proper origin (Quantize should have its own Origin)
+ luci::add_origin(quant_op, luci::get_origin(from));
+ }
+
+ // Update dtype of output
+ output->dtype(_ctx->output_type);
auto graph_output = outputs->at(output->index());
graph_output->dtype(_ctx->output_type);
@@ -594,12 +650,25 @@ bool QuantizeWithMinMaxPass::run(loco::Graph *g)
// Set output type
set_output_type(g);
+ // Remove redundant Quantize Op
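+ // (set_input_type/set_output_type can leave back-to-back Quantize ops behind)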
+ {
+ logo::Phase phase;
+
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());
+
+ ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
+ logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
+ phase_runner.attach(&prog);
+ phase_runner.run(phase);
+ }
+
// Remove min/max values
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
if (auto qparam = circle_node->quantparam())
{
+ warn_accuracy_with_range(circle_node);
qparam->min.clear();
qparam->max.clear();
}
diff --git a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
index cebafd32b..21b4fe1c6 100644
--- a/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
+++ b/compiler/luci/pass/src/QuantizedModelVerifier.test.cpp
@@ -1088,6 +1088,31 @@ private:
luci::CircleConst *_const = nullptr;
};
+class ReduceMaxTestGraph final : public SimpleTestGraph
+{
+public:
+ void init(void) override
+ {
+ TestIOGraph::init({4, 3, 2}, {2});
+
+ _axis = create_const<Type::S32, int32_t>(g(), {4}, {1, 0, -3, -3});
+ _reduce_max = g()->nodes()->create<luci::CircleReduceMax>();
+ {
+ _reduce_max->input(input());
+ _reduce_max->reduction_indices(_axis);
+ _reduce_max->name("test");
+ _reduce_max->keep_dims(false);
+ }
+ output()->from(_reduce_max);
+
+ set_minmax_to_non_const(g(), -1, 1);
+ }
+
+private:
+ luci::CircleReduceMax *_reduce_max = nullptr;
+ luci::CircleConst *_axis = nullptr;
+};
+
class ResizeBilinearTestGraph final : public SimpleTestGraph
{
public:
@@ -2345,6 +2370,34 @@ TEST(QuantizedModelVerifierTest, Pow_wrong_granularity_NEG)
SUCCEED();
}
+TEST(QuantizedModelVerifierTest, ReduceMax)
+{
+ TEST_WITH_GRAPH(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_GRAPH(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_GRAPH(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise);
+
+ TEST_WITH_LAYER_INFO(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_LAYER_INFO(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_LAYER_INFO(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ReduceMax_wrong_type_NEG)
+{
+ TEST_WITH_WRONG_TYPE(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise, Type::S16);
+ TEST_WITH_WRONG_TYPE(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise, Type::U8);
+ SUCCEED();
+}
+
+TEST(QuantizedModelVerifierTest, ReduceMax_wrong_granularity_NEG)
+{
+ TEST_WITH_WRONG_GRANULARITY(ReduceMaxTestGraph, Type::U8, Granularity::LayerWise);
+ TEST_WITH_WRONG_GRANULARITY(ReduceMaxTestGraph, Type::U8, Granularity::ChannelWise);
+ TEST_WITH_WRONG_GRANULARITY(ReduceMaxTestGraph, Type::S16, Granularity::ChannelWise);
+ SUCCEED();
+}
+
TEST(QuantizedModelVerifierTest, ResizeBilinear)
{
TEST_WITH_GRAPH(ResizeBilinearTestGraph, Type::U8, Granularity::LayerWise);
diff --git a/compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp b/compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp
new file mode 100644
index 000000000..66cd9d791
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantDequantizePass.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantDequantizePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool remove_redundant_dequant(luci::CircleDequantize *dequant)
+{
+ assert(dequant != nullptr);
+
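+ // Dequantize of an fp32 tensor converts nothing, so it can be bypassed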
+ auto prev = loco::must_cast<luci::CircleNode *>(dequant->input());
+ if (prev->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ replace(dequant).with(prev);
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+/**
+ * Dequantize Op does two things to its input (ifm):
+ * 1. Element-wise conversion of quantized values (u8/s16) to fp32 values
+ * 2. Update of the dtype to fp32
+ * If the preceding node is not quantized, the Dequantize Op is redundant.
+ *
+ * BEFORE
+ *
+ * [CircleNode (A)]
+ * |
+ * [CircleNode (B)] (fp32)
+ * |
+ * [CircleDequantize]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode (A)]
+ * |
+ * [CircleNode (B)] (fp32)
+ * |
+ * [CircleNode]
+ */
+bool RemoveRedundantDequantizePass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto target_node = dynamic_cast<luci::CircleDequantize *>(node);
+ if (target_node != nullptr)
+ {
+ if (remove_redundant_dequant(target_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp b/compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp
new file mode 100644
index 000000000..adb2f14a4
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveRedundantDequantizePass.test.cpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveRedundantDequantizePass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <luci/test/TestIOGraph.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+class DequantizeGraphlet
+{
+public:
+ DequantizeGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ _dequantize = g->nodes()->create<luci::CircleDequantize>();
+ _dequantize->dtype(loco::DataType::FLOAT32);
+ _dequantize->name("dequantize");
+ }
+
+protected:
+ luci::CircleDequantize *_dequantize = nullptr;
+};
+
+class RedundantDequantizeGraph : public TestIOGraph, public DequantizeGraphlet
+{
+public:
+ RedundantDequantizeGraph() = default;
+
+public:
+ void init(void)
+ {
+ TestIOGraph::init({1}, {1});
+ DequantizeGraphlet::init(g());
+
+ _dequantize->input(input());
+
+ output()->from(_dequantize);
+ }
+
+ void init_u8_input(void)
+ {
+ TestIOGraph::init({1}, {1});
+ DequantizeGraphlet::init(g());
+
+ // Use u8 input (dequantize is not redundant anymore)
+ input()->dtype(loco::DataType::U8);
+ {
+ auto qparam = std::make_unique<luci::CircleQuantParam>();
+ qparam->scale = {1};
+ qparam->zerop = {1};
+ input()->quantparam(std::move(qparam));
+ }
+
+ _dequantize->input(input());
+
+ output()->from(_dequantize);
+ }
+};
+
+} // namespace
+
+TEST(RemoveRedundantDequantizePass, single_redundant_dequantize)
+{
+ RedundantDequantizeGraph g;
+ luci::RemoveRedundantDequantizePass pass;
+
+ g.init();
+
+ EXPECT_TRUE(pass.run(g.g()));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(g.g())))
+ {
+ if (dynamic_cast<luci::CircleDequantize *>(node))
+ {
+ count++;
+ }
+ }
+
+ ASSERT_EQ(0, count);
+}
+
+TEST(RemoveRedundantDequantizePass, wrong_dtype_NEG)
+{
+ RedundantDequantizeGraph g;
+ luci::RemoveRedundantDequantizePass pass;
+
+ g.init_u8_input();
+
+ EXPECT_FALSE(pass.run(g.g()));
+}
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp
new file mode 100644
index 000000000..476ec68bf
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+
+bool acceptable_intermediate_op(const loco::Node *node)
+{
+ if (not node)
+ return false;
+
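+ // Reshape commutes with these element-wise ops, so it can safely be hoisted across them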
+ const auto opcode = loco::must_cast<const luci::CircleNode *>(node)->opcode();
+
+ switch (opcode)
+ {
+ case luci::CircleOpcode::ADD:
+ case luci::CircleOpcode::MUL:
+ case luci::CircleOpcode::TANH:
+ case luci::CircleOpcode::LOGISTIC:
+ break;
+
+ default:
+ return false;
+ }
+
+ return true;
+}
+
+bool same_shape(const loco::Node *a, const loco::Node *b)
+{
+ auto a_cnode = loco::must_cast<const luci::CircleNode *>(a);
+ auto b_cnode = loco::must_cast<const luci::CircleNode *>(b);
+
+ if (a_cnode->rank() != b_cnode->rank())
+ return false;
+
+ for (uint32_t i = 0; i < a_cnode->rank(); i++)
+ {
+ if (not(a_cnode->dim(i) == b_cnode->dim(i)))
+ return false;
+ }
+ return true;
+}
+
+class PreReshapeFinder
+{
+public:
+ PreReshapeFinder(const luci::CircleReshape *post_reshape) : _post_reshape(post_reshape)
+ {
+ assert(post_reshape != nullptr); // FIX_CALLER_UNLESS
+ }
+
+public:
+ // Return true if pre_reshapes are found
+ bool collect_pre_reshapes(loco::Node *node)
+ {
+ // TODO Support diamond case
+ if (loco::succs(node).size() != 1)
+ return false;
+
+ if (auto pre_reshape = dynamic_cast<luci::CircleReshape *>(node))
+ {
+ // Check ifm of pre-reshape and ofm of post_reshape
+ if (not same_shape(pre_reshape->tensor(), _post_reshape))
+ return false;
+
+ // Check ofm of pre-reshape and ifm of post_reshape
+ if (not same_shape(pre_reshape, _post_reshape->tensor()))
+ return false;
+
+ _pre_reshapes.emplace_back(pre_reshape);
+ return true;
+ }
+
+ if (not acceptable_intermediate_op(node))
+ return false;
+
+ for (uint32_t i = 0; i < node->arity(); i++)
+ {
+ if (not collect_pre_reshapes(node->arg(i)))
+ return false;
+ }
+
+ return true;
+ }
+
+public:
+ std::vector<luci::CircleReshape *> pre_reshapes(void) const { return _pre_reshapes; }
+
+private:
+ const luci::CircleReshape *_post_reshape = nullptr;
+ std::vector<luci::CircleReshape *> _pre_reshapes;
+};
+
+bool remove_unnecessary_reshape_net(luci::CircleReshape *reshape)
+{
+ PreReshapeFinder finder(reshape);
+ if (not finder.collect_pre_reshapes(reshape->tensor()))
+ return false;
+
+ // Remove pre_reshapes
+ for (auto pre_reshape : finder.pre_reshapes())
+ {
+ loco::replace(pre_reshape).with(pre_reshape->tensor());
+ }
+
+ // Remove post_reshape
+ loco::replace(reshape).with(reshape->tensor());
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+/**
+ * BEFORE
+ *
+ * [CircleNode]
+ * |
+ * [CircleReshape_1] (shape: A -> B)
+ * |
 + *        [CircleNode] (e.g. Add/Mul/Tanh/Logistic ...)
+ * |
+ * [CircleReshape_2] (shape: B -> A)
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * | \
+ * | [CircleReshape_1]
+ * [CircleNode]
+ * | \
+ * | [CircleReshape_2]
+ * [CircleNode]
+ **/
+bool RemoveUnnecessaryReshapeNetPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto reshape_node = dynamic_cast<luci::CircleReshape *>(node))
+ {
+ if (remove_unnecessary_reshape_net(reshape_node))
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp
new file mode 100644
index 000000000..4ad707ba3
--- /dev/null
+++ b/compiler/luci/pass/src/RemoveUnnecessaryReshapeNetPass.test.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "luci/Pass/RemoveUnnecessaryReshapeNetPass.h"
+
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+class RemoveUnnecessaryReshapeNet : public ::testing::Test
+{
+public:
+ RemoveUnnecessaryReshapeNet() {}
+
+ void createReshapeConst(luci::CircleReshape *target, const std::vector<uint32_t> shape)
+ {
+ auto shape_const = g.nodes()->create<luci::CircleConst>();
+ shape_const->dtype(loco::DataType::S32);
+ shape_const->size<loco::DataType::S32>(shape.size());
+ shape_const->shape_status(luci::ShapeStatus::VALID);
+ shape_const->rank(1);
+ shape_const->dim(0).set(shape.size());
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ shape_const->at<loco::DataType::S32>(i) = static_cast<int32_t>(shape.at(i));
+ }
+ shape_const->name("shape_const");
+ target->shape(shape_const);
+ target->rank(shape.size());
+ for (uint32_t i = 0; i < shape.size(); i++)
+ {
+ target->dim(i) = shape[i];
+ }
+ target->shape_status(luci::ShapeStatus::VALID);
+ }
+
+ void buildGraph(const std::initializer_list<uint32_t> base_shape,
+ const std::initializer_list<uint32_t> first_shape,
+ const std::initializer_list<uint32_t> second_shape)
+ {
+ // Create input.
+ input = g.nodes()->create<luci::CircleInput>();
+ auto graph_input = g.inputs()->create();
+ input->index(graph_input->index());
+ input->shape_status(luci::ShapeStatus::VALID);
+ input->shape(base_shape);
+ input->name("input");
+
+ // Create first reshape.
+ first_reshape = g.nodes()->create<luci::CircleReshape>();
+ first_reshape->tensor(input);
+ first_reshape->name("Reshape");
+ createReshapeConst(first_reshape, first_shape);
+
+ // Create logistic.
+ logistic = g.nodes()->create<luci::CircleLogistic>();
+ logistic->x(first_reshape);
+ logistic->name("logistic");
+ logistic->shape(first_shape);
+ logistic->shape_status(luci::ShapeStatus::VALID);
+
+ // Create second reshape.
+ second_reshape = g.nodes()->create<luci::CircleReshape>();
+ second_reshape->tensor(logistic);
+ second_reshape->name("second_reshape");
+ createReshapeConst(second_reshape, second_shape);
+
+ // Connect output.
+ output = g.nodes()->create<luci::CircleOutput>();
+ output->from(second_reshape);
+ output->name("output");
+ auto graph_output = g.outputs()->create();
+ output->index(graph_output->index());
+ }
+
+public:
+ loco::Graph g;
+ luci::CircleInput *input = nullptr;
+ luci::CircleReshape *first_reshape = nullptr;
+ luci::CircleLogistic *logistic = nullptr;
+ luci::CircleReshape *second_reshape = nullptr;
+ luci::CircleOutput *output = nullptr;
+};
+
+} // namespace
+
+TEST_F(RemoveUnnecessaryReshapeNet, simple_case)
+{
+ buildGraph({1, 1, 1, 32}, {1, 1, 32, 1}, {1, 1, 1, 32});
+ luci::RemoveUnnecessaryReshapeNetPass pass;
+
+ ASSERT_TRUE(pass.run(&g));
+
+ int count = 0;
+ for (auto node : loco::active_nodes(loco::output_nodes(&g)))
+ {
+ if (dynamic_cast<luci::CircleReshape *>(node))
+ count++;
+ }
+ ASSERT_EQ(0, count);
+}
+
+TEST_F(RemoveUnnecessaryReshapeNet, shape_mismatch_NEG)
+{
+ buildGraph({1, 1, 1, 32}, {1, 1, 32, 1}, {1, 1, 2, 16});
+ luci::RemoveUnnecessaryReshapeNetPass pass;
+ ASSERT_FALSE(pass.run(&g));
+}
diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp
new file mode 100644
index 000000000..741b70956
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.cpp
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h>
+
+namespace
+{
+
+// TODO move to global helper list if needed
+/**
+ * @brief Create a node with `inp` as input from fused activation fucntion `act`
+ */
+luci::CircleNode *fromActivation(luci::CircleNode *inp, luci::FusedActFunc act)
+{
+ switch (act)
+ {
+ case luci::FusedActFunc::NONE:
+ return inp;
+ case luci::FusedActFunc::RELU:
+ {
+ auto n = inp->graph()->nodes()->create<luci::CircleRelu>();
+ n->features(inp);
+ return n;
+ }
+ case luci::FusedActFunc::RELU6:
+ {
+ auto n = inp->graph()->nodes()->create<luci::CircleRelu6>();
+ n->features(inp);
+ return n;
+ }
+ case luci::FusedActFunc::RELU_N1_TO_1:
+ {
+ auto n = inp->graph()->nodes()->create<luci::CircleReluN1To1>();
+ n->features(inp);
+ return n;
+ }
+ case luci::FusedActFunc::TANH:
+ {
+ auto n = inp->graph()->nodes()->create<luci::CircleTanh>();
+ n->x(inp);
+ return n;
+ }
+ case luci::FusedActFunc::SIGN_BIT:
+ {
+ throw std::invalid_argument("no matching node to create from fused activation");
+ }
+ default:
+ throw std::invalid_argument("invalid fused activation");
+ }
+}
+
+/**
+ * Replace Fully Connected with Batched MatMul
+ *
+ * BEFORE
+ *
+ * [Node1] [Node2]
+ * | |
+ * [transpose]? [transpose]?
+ * \ /
+ * [FullyConnected]
+ *
+ * AFTER
+ *
+ * [Node1] [Node2]
+ * \ /
+ * [BatchMatMul] [BiasValue]?
+ * \ /
+ * [Add]?
+ * |
+ * [Activation]?
+ *
+ * Nodes with "?" denote optional elements
+ */
+bool replace_fc_with_matmul(luci::CircleFullyConnected *fc)
+{
+ luci::CircleNode *x = nullptr;
+ luci::CircleNode *y = nullptr;
+ luci::CircleNode *b = nullptr;
+ luci::CircleTranspose *ty = nullptr;
+ luci::CircleTranspose *tx = nullptr;
+ bool adj_x = false;
+ bool adj_y = true;
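+ // FullyConnected computes x * W^T, hence adj_y defaults to true; an explicit
+ // Transpose feeding the weights already supplies W^T, so adj_y is cleared there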
+
+ if (dynamic_cast<luci::CircleConst *>(fc->weights()))
+ return false; // weights are const; this pass targets non-const FC only
+
+ if ((ty = dynamic_cast<luci::CircleTranspose *>(fc->weights()))) // is y a transpose?
+ {
+ adj_y = false;
+ if (dynamic_cast<luci::CircleConst *>(ty->a()))
+ return false;
+ else
+ y = loco::must_cast<luci::CircleNode *>(ty->a());
+ }
+ else
+ { // y is not transpose and not const
+ y = loco::must_cast<luci::CircleNode *>(fc->weights());
+ }
+ if ((tx = dynamic_cast<luci::CircleTranspose *>(fc->input())))
+ {
+ adj_x = true;
+ x = loco::must_cast<luci::CircleNode *>(tx->a());
+ }
+ else
+ {
+ x = loco::must_cast<luci::CircleNode *>(fc->input());
+ }
+
+ b = loco::must_cast<luci::CircleNode *>(fc->bias());
+
+ if (x->dtype() != loco::DataType::FLOAT32 || y->dtype() != loco::DataType::FLOAT32 ||
+ b->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ auto name = fc->name();
+ assert(name.length() > 0);
+
+ auto matmul = fc->graph()->nodes()->create<luci::CircleBatchMatMul>();
+ matmul->x(x);
+ matmul->y(y);
+ matmul->adj_x(adj_x);
+ matmul->adj_y(adj_y);
+ matmul->name(name);
+ matmul->dtype(fc->dtype());
+
+ luci::add_origin(matmul, luci::get_origin(fc));
+
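+ // A const all-zero bias contributes nothing, so the bias Add below can be skipped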
+ auto all_zero = [](const luci::CircleConst *c) {
+ bool ac = true;
+ for (uint32_t i = 0; i < c->size<loco::DataType::FLOAT32>() && ac; i++)
+ {
+ ac &= c->at<loco::DataType::FLOAT32>(i) == 0.0f;
+ }
+ return ac;
+ };
+
+ auto bc = dynamic_cast<luci::CircleConst *>(b);
+ if ((nullptr != bc) && !all_zero(bc))
+ {
+ auto bias_add = fc->graph()->nodes()->create<luci::CircleAdd>();
+ bias_add->x(matmul);
+ bias_add->y(b);
+ bias_add->name(fc->name() + "/bias_add");
+ bias_add->dtype(fc->dtype());
+ add_origin(bias_add, get_origin(fc));
+ bias_add->fusedActivationFunction(fc->fusedActivationFunction());
+ loco::replace(fc).with(bias_add);
+ }
+ else
+ {
+ auto n = fromActivation(matmul, fc->fusedActivationFunction());
+ add_origin(n, luci::get_origin(fc));
+ n->name(fc->name() + "/fusedActivation");
+ n->dtype(fc->dtype());
+ loco::replace(fc).with(n);
+ }
+
+ return true;
+}
+} // namespace
+
+namespace luci
+{
+
+bool ReplaceNonConstFCWithBatchMatMulPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ if (auto fc = dynamic_cast<luci::CircleFullyConnected *>(node))
+ {
+ if (replace_fc_with_matmul(fc))
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
new file mode 100644
index 000000000..7606a6125
--- /dev/null
+++ b/compiler/luci/pass/src/ReplaceNonConstFCWithBatchMatMulPass.test.cpp
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ReplaceNonConstFCWithBatchMatMulPass.h"
+
+#include <luci/test/TestIOGraph.h>
+#include <luci/IR/CircleNodes.h>
+
+#include <gtest/gtest.h>
+
+namespace
+{
+
+using namespace luci::test;
+
+// TODO Reduce code duplication with ResolveCustomOpMatMulPass.cpp
+template <typename T>
+luci::CircleConst *create_const_node(loco::Graph *g, const loco::DataType dtype,
+ const std::vector<uint32_t> &shape,
+ const std::vector<T> &values)
+{
+ auto node = g->nodes()->create<luci::CircleConst>();
+ node->dtype(dtype);
+ node->rank(shape.size());
+
+ uint32_t size = 1;
+ for (uint32_t i = 0; i < shape.size(); ++i)
+ {
+ node->dim(i) = shape.at(i);
+ size *= shape.at(i);
+ }
+ node->shape_status(luci::ShapeStatus::VALID);
+
+#define INIT_VALUES(DT) \
+ { \
+ node->size<DT>(size); \
+ for (uint32_t i = 0; i < values.size(); ++i) \
+ node->at<DT>(i) = values[i]; \
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::U8:
+ INIT_VALUES(loco::DataType::U8);
+ break;
+ case loco::DataType::S16:
+ INIT_VALUES(loco::DataType::S16);
+ break;
+ case loco::DataType::S32:
+ INIT_VALUES(loco::DataType::S32);
+ break;
+ case loco::DataType::FLOAT32:
+ INIT_VALUES(loco::DataType::FLOAT32);
+ break;
+ default:
+ INTERNAL_EXN("create_const_node called with unsupported type");
+ break;
+ }
+ return node;
+}
+
+/**
+ * Simple graph for test
+ *
+ * BEFORE
+ *
+ * [IFM1] [IFM2] [BIAS]
+ * \ | /
+ * [FC]
+ * |
+ * [Res]
+ *
+ * AFTER
+ * [IFM1] [IFM2]
+ * \ |
+ * [BatchMatMul] [BIAS]
+ * \ /
+ * [Add]
+ * |
+ * [Res]
+ *
+ */
+struct FCGraphlet
+{
+public:
+ FCGraphlet() = default;
+ virtual ~FCGraphlet() = default;
+
+ void init(loco::Graph *g, const ShapeU32 r_shape, const float bv)
+ {
+ _tr_y = g->nodes()->create<luci::CircleTranspose>();
+ _tr_y->a(_y);
+ std::vector<int32_t> tr_val = {1, 0};
+ _tr_y->perm(create_const_node(g, loco::DataType::S32, {2}, tr_val));
+
+ _fc = g->nodes()->create<luci::CircleFullyConnected>();
+ _fc->input(_x);
+ _fc->weights(_tr_y);
+ _fc->fusedActivationFunction(luci::FusedActFunc::NONE);
+ _fc->dtype(loco::DataType::FLOAT32);
+ _fc->shape(r_shape);
+ auto l = _fc->dim(_fc->rank() - 1).value();
+ std::vector<float> bias_val(l, bv);
+ _fc->bias(create_const_node(g, loco::DataType::FLOAT32, {l}, bias_val));
+ _fc->name("fc");
+ }
+
+public:
+ luci::CircleFullyConnected *fc() { return _fc; }
+
+protected:
+ luci::CircleFullyConnected *_fc = nullptr;
+ luci::CircleTranspose *_tr_y = nullptr;
+ luci::CircleInput *_x = nullptr;
+ luci::CircleInput *_y = nullptr;
+};
+
+struct FCGraph : public TestIsGraphlet<2>, public TestOGraphlet, public FCGraphlet
+{
+ FCGraph() = default;
+ virtual ~FCGraph() = default;
+ void init(const ShapeU32 x_shape, const ShapeU32 y_shape, const ShapeU32 r_shape, const float bv)
+ {
+ TestIsGraphlet<2>::init(g(), {x_shape, y_shape});
+ TestOGraphlet::init(g(), r_shape);
+ _x = input(0);
+ _y = input(1);
+ FCGraphlet::init(g(), r_shape, bv);
+ output()->from(_fc);
+ }
+};
+
+class ReplaceNonConstFCWithBatchMatMulPassTest : public ::testing::Test
+{
+public:
+ FCGraph g;
+ luci::ReplaceNonConstFCWithBatchMatMulPass pass;
+};
+
+} // namespace
+
+TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, simple_test)
+{
+ g.init({2, 3}, {2, 3}, {2, 2}, 0.0f);
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(true, ret);
+
+ auto mm = dynamic_cast<luci::CircleBatchMatMul *>(g.output()->from());
+ EXPECT_NE(nullptr, mm);
+}
+
+TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, nonzero_bias_test)
+{
+ g.init({2, 3}, {2, 3}, {2, 2}, 1.0f);
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(true, ret);
+
+  auto add = dynamic_cast<luci::CircleAdd *>(g.output()->from());
+  EXPECT_NE(nullptr, add);
+}
+
+TEST_F(ReplaceNonConstFCWithBatchMatMulPassTest, wrong_op_NEG)
+{
+ loco::Graph g;
+
+ auto inp = g.nodes()->create<luci::CircleInput>();
+ auto relu = g.nodes()->create<luci::CircleRelu>();
+ relu->features(inp);
+
+ luci::ReplaceNonConstFCWithBatchMatMulPass pass;
+ auto changed = pass.run(&g);
+
+ EXPECT_EQ(false, changed);
+}
diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp
new file mode 100644
index 000000000..a65065800
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpSplitVPass.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/Service/Nodes/CircleConst.h>
+
+namespace
+{
+
+// Input node must be a const S64 node
+// Returns its S32 clone
+// Returns nullptr if any S64 value is out of S32 range
+luci::CircleConst *s64_to_s32(luci::CircleConst *node)
+{
+ assert(node);
+ assert(node->dtype() == loco::DataType::S64);
+
+ auto cloned = luci::clone(node);
+ luci::add_origin(cloned, luci::get_origin(node));
+
+ const auto num_elems = node->size<loco::DataType::S64>();
+
+ cloned->dtype(loco::DataType::S32);
+ cloned->size<loco::DataType::S32>(num_elems);
+
+ for (uint32_t i = 0; i < num_elems; i++)
+ {
+ int64_t val = node->at<loco::DataType::S64>(i);
+ if (val < std::numeric_limits<int32_t>::min() or val > std::numeric_limits<int32_t>::max())
+ return nullptr;
+
+ cloned->at<loco::DataType::S32>(i) = static_cast<int32_t>(val);
+ }
+
+ return cloned;
+}
+
+/** BEFORE
+ *
+ * [CircleNode]
+ * \
+ * \ [size_splits] [split_dim]
+ * \ | /
+ *        [CircleCustom(SplitV)]
+ * |
+ * [CircleCustomOut]
+ * |
+ * [CircleNode]
+ *
+ * AFTER
+ *
+ * [CircleNode]
+ * | \
+ * | \ [size_splits] [split_dim]
+ * | \ | /
+ * | \ | /
+ * | \ | /
+ * [CircleCustom(SplitV)] [CircleSplitV]
+ * | |
+ * [CircleCustomOut] [CircleSplitVOut]
+ * |
+ * [CircleNode]
+ */
+bool resolve_splitv(luci::CircleCustom *node)
+{
+ const std::string custom_code = node->custom_code();
+
+ if (custom_code != "SplitV")
+ return false;
+
+ if (node->numInputs() != 3)
+ return false;
+
+ auto size_splits = dynamic_cast<luci::CircleConst *>(node->inputs(1));
+ if (not size_splits)
+ return false;
+
+ // Convert size_splits to S32, because luci-interpreter does not support
+ // S64 size_splits yet
+ // TODO Support S64 size_splits
+ if (size_splits->dtype() == loco::DataType::S64)
+ {
+ size_splits = s64_to_s32(size_splits);
+ if (not size_splits)
+ return false;
+ }
+ if (size_splits->dtype() != loco::DataType::S32)
+ return false;
+
+ auto split_dim = dynamic_cast<luci::CircleConst *>(node->inputs(2));
+ if (not split_dim)
+ return false;
+
+ if (split_dim->dtype() == loco::DataType::S64)
+ {
+ split_dim = s64_to_s32(split_dim);
+ if (not split_dim)
+ return false;
+ }
+ if (split_dim->dtype() != loco::DataType::S32)
+ return false;
+
+ if (size_splits->rank() != 1)
+ return false;
+
+ const auto num_split = size_splits->dim(0).value();
+
+ auto split_v = node->graph()->nodes()->create<luci::CircleSplitV>();
+ split_v->input(node->inputs(0));
+ split_v->size_splits(size_splits);
+ split_v->split_dim(split_dim);
+ split_v->num_split(num_split);
+ split_v->name(node->name());
+ luci::add_origin(split_v, luci::get_origin(node));
+
+ const auto succs = loco::succs(node);
+ for (auto succ : succs)
+ {
+ auto custom_out = loco::must_cast<luci::CircleCustomOut *>(succ); // FIX_CALLER_UNLESS
+
+ // Reuse each CustomOut's index: loco::succs has no guaranteed iteration order,
+ // so a running counter could scramble the outputs
+ const auto index = custom_out->index();
+ auto split_v_out = node->graph()->nodes()->create<luci::CircleSplitVOut>();
+ split_v_out->input(split_v);
+ split_v_out->name(node->name() + "_out_" + std::to_string(index));
+ split_v_out->index(index);
+ luci::add_origin(split_v_out, luci::get_origin(node));
+ loco::replace(custom_out).with(split_v_out);
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool ResolveCustomOpSplitVPass::run(loco::Graph *g)
+{
+ bool changed = false;
+
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto cop = dynamic_cast<luci::CircleCustom *>(node);
+ if (not cop)
+ continue;
+
+ if (resolve_splitv(cop))
+ changed = true;
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp
new file mode 100644
index 000000000..e7738aadb
--- /dev/null
+++ b/compiler/luci/pass/src/ResolveCustomOpSplitVPass.test.cpp
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ResolveCustomOpSplitVPass.h"
+
+#include <luci/test/TestIOGraph.h>
+
+#include <luci/IR/CircleNodes.h>
+#include <gtest/gtest.h>
+
+using namespace luci::test;
+
+namespace
+{
+
+/**
+ * graph having Custom operator SplitV
+ *
+ * [Input] [Const] [Const]
+ * \ | /
+ * [Custom(SplitV)]
+ * / | \
+ * [CustomOut] [CustomOut] [CustomOut]
+ * | | |
+ * [Output] [Output] [Output]
+ */
+class SplitVGraphlet
+{
+public:
+ SplitVGraphlet() = default;
+
+public:
+ void init(loco::Graph *g)
+ {
+ // CircleCustom(SplitV)
+ _splitv = g->nodes()->create<luci::CircleCustom>(3, 3);
+ _splitv->custom_code("SplitV");
+ _splitv->shape({1, 2, 2, 192});
+ _splitv->dtype(loco::DataType::FLOAT32);
+ _splitv->name("splitv");
+
+ // CircleConst
+ auto size_splits = g->nodes()->create<luci::CircleConst>();
+ size_splits->dtype(loco::DataType::S64);
+ size_splits->shape({3});
+ size_splits->size<loco::DataType::S64>(3);
+ size_splits->at<loco::DataType::S64>(0) = 32;
+ size_splits->at<loco::DataType::S64>(1) = 32;
+ size_splits->at<loco::DataType::S64>(2) = 128;
+
+ // CircleConst
+ auto split_dim = g->nodes()->create<luci::CircleConst>();
+ split_dim->dtype(loco::DataType::S32);
+ split_dim->rank(0);
+ split_dim->size<loco::DataType::S32>(1);
+ split_dim->scalar<loco::DataType::S32>() = 3;
+
+ _splitv->inputs(1, size_splits);
+ _splitv->inputs(2, split_dim);
+
+ // CircleCustomOut
+ _splitv_out1 = g->nodes()->create<luci::CircleCustomOut>();
+ _splitv_out1->shape({1, 2, 2, 32});
+ _splitv_out1->dtype(loco::DataType::FLOAT32);
+ _splitv_out1->index(0);
+ _splitv_out1->input(_splitv);
+
+ // CircleCustomOut
+ _splitv_out2 = g->nodes()->create<luci::CircleCustomOut>();
+ _splitv_out2->shape({1, 2, 2, 32});
+ _splitv_out2->dtype(loco::DataType::FLOAT32);
+ _splitv_out2->index(1);
+ _splitv_out2->input(_splitv);
+
+ // CircleCustomOut
+ _splitv_out3 = g->nodes()->create<luci::CircleCustomOut>();
+ _splitv_out3->shape({1, 2, 2, 128});
+ _splitv_out3->dtype(loco::DataType::FLOAT32);
+ _splitv_out3->index(2);
+ _splitv_out3->input(_splitv);
+ }
+
+public:
+ luci::CircleCustom *splitv() { return _splitv; }
+
+protected:
+ luci::CircleCustom *_splitv = nullptr;
+ luci::CircleCustomOut *_splitv_out1 = nullptr;
+ luci::CircleCustomOut *_splitv_out2 = nullptr;
+ luci::CircleCustomOut *_splitv_out3 = nullptr;
+};
+
+class SplitVGraph : public TestIGraphlet, public TestOsGraphlet<3>, public SplitVGraphlet
+{
+public:
+ SplitVGraph() = default;
+
+ void init(void)
+ {
+ TestIGraphlet::init(g(), {1, 2, 2, 192});
+ TestOsGraphlet<3>::init(g(), {{1, 2, 2, 32}, {1, 2, 2, 32}, {1, 2, 2, 128}});
+ SplitVGraphlet::init(g());
+
+ // connect graph
+ _splitv->inputs(0, input());
+
+ output(0)->from(_splitv_out1);
+ output(1)->from(_splitv_out2);
+ output(2)->from(_splitv_out3);
+ }
+};
+
+class SplitVGraphTest : public ::testing::Test
+{
+public:
+ SplitVGraph g;
+ luci::ResolveCustomOpSplitVPass pass;
+};
+
+} // namespace
+
+TEST_F(SplitVGraphTest, simple_test)
+{
+ g.init();
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(true, ret);
+
+ auto svo_1 = dynamic_cast<luci::CircleSplitVOut *>(g.output(0)->from());
+ EXPECT_NE(nullptr, svo_1);
+ auto svo_2 = dynamic_cast<luci::CircleSplitVOut *>(g.output(1)->from());
+ EXPECT_NE(nullptr, svo_2);
+ auto svo_3 = dynamic_cast<luci::CircleSplitVOut *>(g.output(2)->from());
+ EXPECT_NE(nullptr, svo_3);
+
+ auto sv = dynamic_cast<luci::CircleSplitV *>(svo_1->input());
+ EXPECT_NE(nullptr, sv);
+ sv = dynamic_cast<luci::CircleSplitV *>(svo_2->input());
+ EXPECT_NE(nullptr, sv);
+ sv = dynamic_cast<luci::CircleSplitV *>(svo_3->input());
+ EXPECT_NE(nullptr, sv);
+
+ auto size_splits = loco::must_cast<luci::CircleConst *>(sv->size_splits());
+ EXPECT_EQ(loco::DataType::S32, size_splits->dtype());
+ EXPECT_EQ(32, size_splits->at<loco::DataType::S32>(0));
+ EXPECT_EQ(32, size_splits->at<loco::DataType::S32>(1));
+ EXPECT_EQ(128, size_splits->at<loco::DataType::S32>(2));
+
+ auto split_dim = loco::must_cast<luci::CircleConst *>(sv->split_dim());
+ EXPECT_EQ(loco::DataType::S32, split_dim->dtype());
+ EXPECT_EQ(3, split_dim->scalar<loco::DataType::S32>());
+}
+
+TEST_F(SplitVGraphTest, wrong_op_NEG)
+{
+ g.init();
+
+ g.splitv()->custom_code("AddV2");
+
+ auto ret = pass.run(g.g());
+ EXPECT_EQ(false, ret);
+}
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
index 442183c18..408e6b8d9 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeGranularity.h
@@ -197,6 +197,13 @@ private:
return true;
}
+ bool visit(const luci::CircleReduceMax *node)
+ {
+ RETURN_FALSE_UNLESS(is_lwq(node));
+ RETURN_FALSE_UNLESS(is_lwq(node->input()));
+ return true;
+ }
+
bool visit(const luci::CircleRelu *node)
{
RETURN_FALSE_UNLESS(is_lwq(node));
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
index 4e1c062c0..cf86acabe 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.cpp
@@ -302,6 +302,15 @@ bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CirclePow *nod
}
template <loco::DataType Qtype, loco::DataType Btype>
+bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleReduceMax *node)
+{
+ RETURN_FALSE_UNLESS(has_type(node, Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->input(), Qtype))
+ RETURN_FALSE_UNLESS(has_type(node->reduction_indices(), loco::DataType::S32))
+ return true;
+}
+
+template <loco::DataType Qtype, loco::DataType Btype>
bool VerifyQuantizedNodeTypeBase<Qtype, Btype>::visit(const luci::CircleRelu *node)
{
return group_has_type(node, Qtype);
diff --git a/compiler/luci/pass/src/VerifyQuantizedNodeType.h b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
index ff1acbd6f..789d3c7cd 100644
--- a/compiler/luci/pass/src/VerifyQuantizedNodeType.h
+++ b/compiler/luci/pass/src/VerifyQuantizedNodeType.h
@@ -104,6 +104,7 @@ private:
bool visit(const luci::CirclePadV2 *node);
bool visit(const luci::CirclePRelu *node);
bool visit(const luci::CirclePow *node);
+ bool visit(const luci::CircleReduceMax *node);
bool visit(const luci::CircleRelu *node);
bool visit(const luci::CircleReshape *node);
bool visit(const luci::CircleResizeBilinear *node);
diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp b/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp
new file mode 100644
index 000000000..72b7d60ff
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.cpp
@@ -0,0 +1,312 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// codes under namespace sparsity referenced from
+// https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/
+// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h
+// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc
+
+#include "SparsityFormatConverter.h"
+
+#include <oops/InternalExn.h>
+
+#include <cassert>
+
+namespace sparsity
+{
+
+namespace
+{
+
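+// Compute the row-major flat offset of a multi-dimensional index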
+uint64_t GetFlattenedIndex(const std::vector<int> &indices, const std::vector<int> &shape)
+{
+ uint64_t index = 0;
+ int sub_elements = 1;
+ for (int i = shape.size() - 1; i >= 0; i--)
+ {
+ index += indices[i] * sub_elements;
+ sub_elements *= shape[i];
+ }
+ return index;
+}
+
+std::vector<int> TfLiteIntArrayToVector(const TfLiteIntArray *int_array)
+{
+ std::vector<int> values;
+ if (!int_array)
+ {
+ return values;
+ }
+
+ values.resize(int_array->size);
+ for (int i = 0; i < int_array->size; i++)
+ {
+ values[i] = int_array->data[i];
+ }
+
+ return values;
+}
+
+} // namespace
+
+template <typename T>
+FormatConverter<T>::FormatConverter(const std::vector<int> &shape, const TfLiteSparsity &sparsity)
+{
+ auto traversal_order = TfLiteIntArrayToVector(sparsity.traversal_order);
+ auto block_map = TfLiteIntArrayToVector(sparsity.block_map);
+
+ std::vector<TfLiteDimensionType> format(sparsity.dim_metadata_size);
+ std::vector<int> dense_size(sparsity.dim_metadata_size);
+ std::vector<std::vector<int>> segments(sparsity.dim_metadata_size);
+ std::vector<std::vector<int>> indices(sparsity.dim_metadata_size);
+ for (int i = 0; i < sparsity.dim_metadata_size; i++)
+ {
+ format[i] = sparsity.dim_metadata[i].format;
+ dense_size[i] = sparsity.dim_metadata[i].dense_size;
+ segments[i] = TfLiteIntArrayToVector(sparsity.dim_metadata[i].array_segments);
+ indices[i] = TfLiteIntArrayToVector(sparsity.dim_metadata[i].array_indices);
+ }
+
+ InitSparseToDenseConverter(shape, std::move(traversal_order), std::move(format),
+ std::move(dense_size), std::move(segments), std::move(indices),
+ std::move(block_map));
+}
+
+template <typename T>
+void FormatConverter<T>::InitSparseToDenseConverter(
+ std::vector<int> shape, std::vector<int> traversal_order, std::vector<TfLiteDimensionType> format,
+ std::vector<int> dense_size, std::vector<std::vector<int>> segments,
+ std::vector<std::vector<int>> indices, std::vector<int> block_map)
+{
+ dense_shape_ = std::move(shape);
+ traversal_order_ = std::move(traversal_order);
+ block_map_ = std::move(block_map);
+ format_ = std::move(format);
+
+ dense_size_ = 1;
+ for (size_t i = 0; i < dense_shape_.size(); i++)
+ {
+ dense_size_ *= dense_shape_[i];
+ }
+
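+ // Two metadata slots per dimension: {segments, indices} for sparse dims, {dense size} for dense dims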
+ dim_metadata_.resize(2 * format_.size());
+ for (size_t i = 0; i < format_.size(); i++)
+ {
+ if (format_[i] == kTfLiteDimDense)
+ {
+ dim_metadata_[2 * i] = {dense_size[i]};
+ }
+ else
+ {
+ dim_metadata_[2 * i] = std::move(segments[i]);
+ dim_metadata_[2 * i + 1] = std::move(indices[i]);
+ }
+ }
+
+ int original_rank = dense_shape_.size();
+ int block_dim = 0;
+
+ blocked_shape_.resize(original_rank);
+ block_size_.resize(block_map_.size());
+ for (int i = 0; i < original_rank; i++)
+ {
+ if (block_dim < (int)block_map_.size() && block_map_[block_dim] == i)
+ {
+ if (original_rank + block_dim < (int)traversal_order_.size())
+ {
+ int orig_dim = traversal_order_[original_rank + block_dim];
+ block_size_[block_dim] = dense_size[orig_dim];
+ blocked_shape_[i] = dense_shape_[i] / dense_size[orig_dim];
+ block_dim++;
+ }
+ }
+ else
+ {
+ blocked_shape_[i] = dense_shape_[i];
+ }
+ }
+}
+
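+// Recursively walk traversal_order, writing each stored value to its row-major offset in dest_data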
+template <typename T>
+void FormatConverter<T>::Populate(const T *src_data, std::vector<int> indices, int level,
+ int prev_idx, int *src_data_ptr, T *dest_data)
+{
+ if (static_cast<size_t>(level) == indices.size())
+ {
+ int orig_rank = dense_shape_.size();
+ std::vector<int> orig_idx;
+ orig_idx.resize(orig_rank);
+ int i = 0;
+ for (; static_cast<size_t>(i) < orig_idx.size(); i++)
+ {
+ int orig_dim = traversal_order_[i];
+ orig_idx[orig_dim] = indices[i];
+ }
+
+ for (; static_cast<size_t>(i) < indices.size(); i++)
+ {
+ const int block_idx = traversal_order_[i] - orig_rank;
+ const int orig_dim = block_map_[block_idx];
+ orig_idx[orig_dim] = orig_idx[orig_dim] * block_size_[block_idx] + indices[i];
+ }
+
+ dest_data[GetFlattenedIndex(orig_idx, dense_shape_)] = src_data[*src_data_ptr];
+
+ *src_data_ptr = *src_data_ptr + 1;
+ return;
+ }
+
+ const int metadata_idx = 2 * level;
+ const int shape_of_level = dim_metadata_[metadata_idx][0];
+ if (format_[level] == kTfLiteDimDense)
+ {
+ for (int i = 0; i < shape_of_level; i++)
+ {
+ indices[level] = i;
+ Populate(src_data, indices, level + 1, prev_idx * shape_of_level + i, src_data_ptr,
+ dest_data);
+ }
+ }
+ else if (static_cast<size_t>(prev_idx + 1) < dim_metadata_[metadata_idx].size())
+ {
+ const auto &array_segments = dim_metadata_[metadata_idx];
+ const auto &array_indices = dim_metadata_[metadata_idx + 1];
+ for (int i = array_segments[prev_idx]; i < array_segments[prev_idx + 1]; i++)
+ {
+ if (static_cast<size_t>(i) < array_indices.size() &&
+ static_cast<size_t>(level) < indices.size())
+ {
+ indices[level] = array_indices[i];
+ Populate(src_data, indices, level + 1, i, src_data_ptr, dest_data);
+ }
+ }
+ }
+}
+
+template <typename T> bool FormatConverter<T>::SparseToDense(const T *src_data)
+{
+ data_.resize(dense_size_);
+ std::fill(data_.begin(), data_.end(), T(0));
+
+ int total_rank = traversal_order_.size();
+ int src_data_ptr = 0;
+ std::vector<int> indices(total_rank);
+ Populate(src_data, indices, 0, 0, &src_data_ptr, data_.data());
+
+ return true;
+}
+
+template class FormatConverter<float>;
+template class FormatConverter<uint16_t>;
+
+} // namespace sparsity
+
+#include <luci/IR/SparsityParam.h>
+
+namespace luci
+{
+
+sparsity::TfLiteDimensionType to_tflite_sparsity(luci::DimensionType dt)
+{
+ switch (dt)
+ {
+ case luci::DimensionType::DENSE:
+ return sparsity::TfLiteDimensionType::kTfLiteDimDense;
+ case luci::DimensionType::SPARSE_CSR:
+ return sparsity::TfLiteDimensionType::kTfLiteDimSparseCSR;
+ }
+ return sparsity::TfLiteDimensionType::kTfLiteDimDense;
+}
+
+sparsity::TfLiteIntArray *to_tflite_sparsity(const luci::SparseIndexVector &data)
+{
+ auto type = data.type();
+ switch (type)
+ {
+ case luci::SparseIndexVectorType::NONE:
+ {
+ std::vector<int32_t> empty;
+ return makeTfLiteArray(empty);
+ }
+ case luci::SparseIndexVectorType::I32:
+ return makeTfLiteArray<int32_t>(*data.as_int32_vector());
+ case luci::SparseIndexVectorType::U16:
+ return makeTfLiteArray<uint16_t>(*data.as_uint16_vector());
+ case luci::SparseIndexVectorType::U8:
+ return makeTfLiteArray<uint8_t>(*data.as_uint8_vector());
+ default:
+ INTERNAL_EXN_V("unsupported SparseIndexVectorType", oops::to_uint32(type));
+ }
+}
+
+sparsity::TfLiteSparsity to_tflite_sparsity(const luci::SparsityParam *sp)
+{
+ sparsity::TfLiteSparsity tflsp;
+ tflsp.traversal_order = makeTfLiteArray(sp->traversal_order);
+ tflsp.block_map = makeTfLiteArray(sp->block_map);
+ tflsp.dim_metadata = makeTfLiteDimensionMetadata(sp->dim_metadata);
+ tflsp.dim_metadata_size = sp->dim_metadata.size();
+ return tflsp;
+}
+
+template <typename T> sparsity::TfLiteIntArray *makeTfLiteArray(const std::vector<T> &data)
+{
+ size_t cn = data.size();
+ size_t sz = 1 + data.size();
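+ // TfLiteIntArray ends with a flexible array member: allocate one int for 'size' plus the payload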
+ sparsity::TfLiteIntArray *sp = (sparsity::TfLiteIntArray *)(new int[sz]);
+ sp->size = cn;
+ for (size_t i = 0; i < cn; ++i)
+ {
+ sp->data[i] = data[i];
+ }
+ return sp;
+}
+
+sparsity::TfLiteDimensionMetadata *
+makeTfLiteDimensionMetadata(const std::vector<luci::DimMetaData> &data)
+{
+ size_t cn = data.size();
+ sparsity::TfLiteDimensionMetadata *tfldm = new sparsity::TfLiteDimensionMetadata[cn];
+
+ for (size_t i = 0; i < cn; ++i)
+ {
+ tfldm[i].format = to_tflite_sparsity(data[i].format());
+ tfldm[i].dense_size = data[i].dense_size();
+ tfldm[i].array_segments = to_tflite_sparsity(data[i].array_segments());
+ tfldm[i].array_indices = to_tflite_sparsity(data[i].array_indices());
+ }
+
+ return tfldm;
+}
+
+void freeTfLiteSparsity(sparsity::TfLiteSparsity &tflsp)
+{
+ assert(tflsp.traversal_order);
+ assert(tflsp.block_map);
+ // These arrays were allocated as int[] in makeTfLiteArray, so free them as int[]
+ delete[] reinterpret_cast<int *>(tflsp.traversal_order);
+ delete[] reinterpret_cast<int *>(tflsp.block_map);
+
+ for (int i = 0; i < tflsp.dim_metadata_size; ++i)
+ {
+ assert(tflsp.dim_metadata[i].array_segments);
+ assert(tflsp.dim_metadata[i].array_indices);
+ delete[] reinterpret_cast<int *>(tflsp.dim_metadata[i].array_segments);
+ delete[] reinterpret_cast<int *>(tflsp.dim_metadata[i].array_indices);
+ }
+ delete[] tflsp.dim_metadata; // allocated with new[] in makeTfLiteDimensionMetadata
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/helpers/SparsityFormatConverter.h b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h
new file mode 100644
index 000000000..fcd9bbcd0
--- /dev/null
+++ b/compiler/luci/pass/src/helpers/SparsityFormatConverter.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__
+#define __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__
+
+#include <cstdint>
+#include <vector>
+
+// The code under namespace sparsity below is adapted from
+// https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/
+// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h
+// tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc
+
+namespace sparsity
+{
+
+// Storage format of each dimension in a sparse tensor.
+typedef enum TfLiteDimensionType
+{
+ kTfLiteDimDense = 0,
+ kTfLiteDimSparseCSR,
+} TfLiteDimensionType;
+
+// Fixed-size list of integers. Used for dimensions and input/output tensor
+// indices.
+typedef struct TfLiteIntArray
+{
+ int size;
+ int data[];
+} TfLiteIntArray;
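+
+// Note: TfLiteIntArray ends in a C99-style flexible array member; instances
+// are allocated as one contiguous int block (see makeTfLiteArray in the .cpp)
+// rather than with `new TfLiteIntArray`.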
+
+// Metadata to encode each dimension in a sparse tensor.
+typedef struct TfLiteDimensionMetadata
+{
+ TfLiteDimensionType format;
+ int dense_size;
+ TfLiteIntArray *array_segments;
+ TfLiteIntArray *array_indices;
+} TfLiteDimensionMetadata;
+
+// Parameters used to encode a sparse tensor. For detailed explanation of each
+// field please refer to lite/schema/schema.fbs.
+typedef struct TfLiteSparsity
+{
+ TfLiteIntArray *traversal_order;
+ TfLiteIntArray *block_map;
+ TfLiteDimensionMetadata *dim_metadata;
+ int dim_metadata_size;
+} TfLiteSparsity;
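+
+// Note: dim_metadata holds one entry per traversal-order dimension; dense
+// dimensions use dense_size, while sparse (CSR) dimensions use the
+// array_segments / array_indices pair.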
+
+// A converter that keeps an internal representation of sparse tensor
+// parameters and converts tensors between dense and sparse formats. This
+// trimmed-down port only supports the sparse-to-dense direction.
+template <typename T> class FormatConverter
+{
+public:
+ /* Creates a sparse-to-dense converter.
+ * @param shape Shape of the target dense tensor.
+ * @param sparsity Sparsity parameter of the sparse TfLiteTensor.
+ */
+ FormatConverter(const std::vector<int> &shape, const TfLiteSparsity &sparsity);
+
+ const std::vector<T> &GetData() { return data_; }
+ const std::vector<std::vector<int>> &GetDimMetadata() { return dim_metadata_; }
+
+ bool SparseToDense(const T *src_data);
+
+private:
+ // Helper function for initializing this converter for sparse to dense
+ // conversion.
+ void InitSparseToDenseConverter(std::vector<int> shape, std::vector<int> traversal_order,
+ std::vector<TfLiteDimensionType> format,
+ std::vector<int> dense_size,
+ std::vector<std::vector<int>> segments,
+ std::vector<std::vector<int>> indices,
+ std::vector<int> block_map);
+
+ void Populate(const T *src_data, std::vector<int> indices, int level, int prev_idx,
+ int *src_data_ptr, T *dest_data);
+
+private:
+ std::vector<int> dense_shape_;
+ std::vector<int> blocked_shape_;
+ size_t dense_size_;
+ std::vector<int> traversal_order_;
+ std::vector<TfLiteDimensionType> format_;
+ std::vector<int> block_size_;
+ std::vector<int> block_map_;
+ std::vector<std::vector<int>> dim_metadata_;
+ std::vector<T> data_;
+};
+
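+// Only the float and uint16_t instantiations are provided (matching the
+// explicit instantiations in the .cpp); other element types will fail to link.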
+extern template class FormatConverter<float>;
+extern template class FormatConverter<uint16_t>;
+
+} // namespace sparsity
+
+#include <luci/IR/SparsityParam.h>
+
+namespace luci
+{
+
+sparsity::TfLiteDimensionType to_tflite_sparsity(luci::DimensionType dt);
+sparsity::TfLiteIntArray *to_tflite_sparsity(const luci::SparseIndexVector &data);
+sparsity::TfLiteSparsity to_tflite_sparsity(const luci::SparsityParam *sp);
+
+template <typename T> sparsity::TfLiteIntArray *makeTfLiteArray(const std::vector<T> &data);
+sparsity::TfLiteDimensionMetadata *
+makeTfLiteDimensionMetadata(const std::vector<luci::DimMetaData> &data);
+
+void freeTfLiteSparsity(sparsity::TfLiteSparsity &tflsp);
+
+} // namespace luci
+
+#endif // __LUCI_PASS_HELPERS_SPARSITY_FORMAT_CONVERTER_H__