summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-09-05 21:49:46 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-09-05 21:49:46 +0900
commit74476a2d0296bdad70a2f7f90bc7419a8b05bffd (patch)
tree3f991636c1e9423d38eb16a384c20b569b0d678e /compiler/luci/pass/src
parent042b262b3633b6c0f577aed6cb4b980ad0c1dcf3 (diff)
downloadnnfw-74476a2d0296bdad70a2f7f90bc7419a8b05bffd.tar.gz
nnfw-74476a2d0296bdad70a2f7f90bc7419a8b05bffd.tar.bz2
nnfw-74476a2d0296bdad70a2f7f90bc7419a8b05bffd.zip
Diffstat (limited to 'compiler/luci/pass/src')
-rw-r--r--compiler/luci/pass/src/CircleOptimizer.cpp36
-rw-r--r--compiler/luci/pass/src/FuseBCQPass.cpp4
-rw-r--r--compiler/luci/pass/src/FuseBatchNormWithTConv.cpp159
-rw-r--r--compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp7
-rw-r--r--compiler/luci/pass/src/RequantizePass.cpp241
5 files changed, 444 insertions, 3 deletions
diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp
index 2edf7a9c6..2ee759b4e 100644
--- a/compiler/luci/pass/src/CircleOptimizer.cpp
+++ b/compiler/luci/pass/src/CircleOptimizer.cpp
@@ -16,11 +16,13 @@
#include "luci/CircleOptimizer.h"
+#include "luci/Pass/FuseBatchNormWithTConv.h"
#include "luci/Pass/FuseBCQPass.h"
#include "luci/Pass/FuseInstanceNormPass.h"
#include "luci/Pass/ResolveCustomOpAddPass.h"
#include "luci/Pass/ResolveCustomOpBatchMatMulPass.h"
#include "luci/Pass/ResolveCustomOpMatMulPass.h"
+#include "luci/Pass/RequantizePass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
// TODO add more passes
@@ -34,6 +36,7 @@
#include "ProgressReporter.h"
#include "CircleOptimizerUtils.h"
+#include <luci/IR/CircleNodes.h>
#include <logo/Phase.h>
#include <memory>
@@ -125,6 +128,10 @@ void CircleOptimizer::optimize(loco::Graph *g) const
{
phase.emplace_back(std::make_unique<FuseBCQPass>());
}
+ if (_options->query(Options::Algorithm::FuseBatchNormWithTConv))
+ {
+ phase.emplace_back(std::make_unique<FuseBatchNormWithTConvPass>());
+ }
// Shape inference is needed for added nodes doing above transformations
phase.emplace_back(std::make_unique<luci::ShapeInferencePass>());
@@ -163,6 +170,14 @@ void CircleOptimizer::quantize(loco::Graph *g) const
throw std::runtime_error("Unsupported granularity. List of supported granularity: " +
to_string(fakeq_supported_granularity));
+ // Clear existing quantparams before doing fake quantization
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ if (circle_node->quantparam() != nullptr)
+ circle_node->quantparam(nullptr);
+ }
+
luci::QuantizeDequantizeWeightsPass fake_quantizer(
str_to_dtype(input_dtype), str_to_dtype(output_dtype), str_to_granularity(granularity));
fake_quantizer.run(g);
@@ -196,6 +211,27 @@ void CircleOptimizer::quantize(loco::Graph *g) const
quantizer.run(g);
}
+ // Requantize
+ if (_options->query(Options::Algorithm::Requantize))
+ {
+ static const std::vector<std::string> rq_supported_input_dtype{"int8"};
+ static const std::vector<std::string> rq_supported_output_dtype{"uint8"};
+
+ auto input_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_dtype);
+ auto output_dtype = _options->param(Options::AlgorithmParameters::Quantize_output_dtype);
+
+ if (!in_array(to_lower_case(input_dtype), rq_supported_input_dtype))
+ throw std::runtime_error("Unsupported input type. List of supported input types: " +
+ to_string(rq_supported_input_dtype));
+
+ if (!in_array(to_lower_case(output_dtype), rq_supported_output_dtype))
+ throw std::runtime_error("Unsupported output type. List of supported output types: " +
+ to_string(rq_supported_output_dtype));
+
+ luci::RequantizePass requantizer(str_to_dtype(input_dtype), str_to_dtype(output_dtype));
+ requantizer.run(g);
+ }
+
logo::Phase phase;
// Do Shape/Type inference
diff --git a/compiler/luci/pass/src/FuseBCQPass.cpp b/compiler/luci/pass/src/FuseBCQPass.cpp
index 260de5b30..7aa2e3e80 100644
--- a/compiler/luci/pass/src/FuseBCQPass.cpp
+++ b/compiler/luci/pass/src/FuseBCQPass.cpp
@@ -38,9 +38,9 @@ const std::string node_name_prefix(luci::NodeName node_name)
{
std::string prefix = node_name;
- if (prefix.find("ReadVariableOp/resource/") != std::string::npos)
+ if (prefix.find("/ReadVariableOp/resource") != std::string::npos)
{
- const auto start_index = prefix.find("ReadVariableOp/resource/");
+ const auto start_index = prefix.find("/ReadVariableOp/resource");
const auto left_prefix = prefix.substr(0, start_index);
const auto right_prefix = prefix.substr(start_index + 24);
diff --git a/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
new file mode 100644
index 000000000..e39455b1a
--- /dev/null
+++ b/compiler/luci/pass/src/FuseBatchNormWithTConv.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/FuseBatchNormWithTConv.h"
+
+#include <luci/IR/CircleNodes.h>
+
+namespace
+{
+/**
+ * NOTE TF's fusedBatchNorm is converted to mul and add of Circle.
+ *
+ * BEFORE
+ *
+ * [CircleTransposeConv]
+ * |
+ * [mul]
+ * |
+ * [add]
+ * AFTER
+ *
+ * [CircleTransposeConv]
+ */
+bool fused_batch_norm_with_tconv(luci::CircleTransposeConv *tconv)
+{
+ // check whether it has bias or not. This optimization works only if it doesn't.
+ auto bias = dynamic_cast<luci::CircleOutputExclude *>(tconv->bias());
+ if (not bias)
+ return false;
+
+ // get weight of tconv
+ auto filter = dynamic_cast<luci::CircleConst *>(tconv->filter());
+ if (not filter)
+ return false;
+ if (filter->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ // get mul node
+ auto tconv_output = loco::succs(tconv);
+ assert(tconv_output.size() == 1);
+ auto mul = dynamic_cast<luci::CircleMul *>(*tconv_output.begin());
+ if (not mul)
+ return false;
+ if (mul->dtype() != loco::DataType::FLOAT32)
+ return false;
+
+ // get add node
+ auto mul_output = loco::succs(mul);
+ assert(mul_output.size() == 1);
+ auto add = dynamic_cast<luci::CircleAdd *>(*mul_output.begin());
+ if (not add)
+ return false;
+ if (add->dtype() != loco::DataType::FLOAT32)
+ return false;
+ if (add->fusedActivationFunction() != luci::FusedActFunc::NONE &&
+ add->fusedActivationFunction() != luci::FusedActFunc::RELU6)
+ return false;
+
+ // get scale of batchnorm
+ auto scale = dynamic_cast<luci::CircleConst *>(mul->y());
+ if (not scale)
+ return false;
+
+ // scale dim(0) == tconv filter channel dim
+ if (filter->rank() != 4)
+ return false;
+ auto filter_channel_dim = filter->dim(3).value();
+ if (scale->rank() != 1)
+ return false;
+ auto scale_dim = scale->dim(0).value();
+ if (filter_channel_dim != scale_dim)
+ return false;
+
+ // get shift of batchnorm
+ auto shift = dynamic_cast<luci::CircleConst *>(add->y());
+ if (not shift)
+ return false;
+
+ // shift dim(0) == tconv filter channel dim
+ if (shift->rank() != 1)
+ return false;
+ auto shift_dim = shift->dim(0).value();
+ if (filter_channel_dim != shift_dim)
+ return false;
+
+ // filter weight = filter weight * mul(scale) + add(shift)
+ uint32_t filter_batch_dim = filter->dim(0).value();
+ uint32_t filter_height_dim = filter->dim(1).value();
+ uint32_t filter_width_dim = filter->dim(2).value();
+ for (uint32_t c = 0; c < filter_channel_dim; c++)
+ {
+ for (uint32_t n = 0; n < filter_batch_dim; n++)
+ {
+ for (uint32_t h = 0; h < filter_height_dim; h++)
+ {
+ for (uint32_t w = 0; w < filter_width_dim; w++)
+ {
+ uint32_t offset = n * filter_height_dim * filter_width_dim * filter_channel_dim +
+ h * filter_width_dim * filter_channel_dim + w * filter_channel_dim + c;
+ filter->at<loco::DataType::FLOAT32>(offset) *= scale->at<loco::DataType::FLOAT32>(c);
+ }
+ }
+ }
+ }
+
+ // fuse shift with transposed conv
+ tconv->bias(shift);
+
+ if (add->fusedActivationFunction() == luci::FusedActFunc::RELU6)
+ {
+ // separate relu op from add op
+ auto relu = add->graph()->nodes()->create<luci::CircleRelu6>();
+ relu->features(tconv);
+
+ // remove mul node
+ replace(add).with(relu);
+ }
+ else
+ {
+ replace(add).with(tconv);
+ }
+
+ return true;
+}
+
+} // namespace
+
+namespace luci
+{
+
+bool FuseBatchNormWithTConvPass::run(loco::Graph *g)
+{
+ bool changed = false;
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ auto tconv = dynamic_cast<luci::CircleTransposeConv *>(node);
+ if (not tconv)
+ continue;
+
+ changed |= fused_batch_norm_with_tconv(tconv);
+ }
+
+ return changed;
+}
+
+} // namespace luci
diff --git a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
index b335a53b4..60c1cdd72 100644
--- a/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
+++ b/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
@@ -472,7 +472,12 @@ struct QuantizeWeights final : public luci::CircleNodeMutableVisitor<bool>
if (granularity == QuantizationGranularity::ChannelWise)
{
auto quantparam = circle_node->quantparam();
- assert(quantparam != nullptr);
+ if (quantparam == nullptr)
+ {
+ assert(false && "quantparam is nullptr");
+ return false;
+ }
+
auto min = quantparam->min;
auto scaling_factor = quantparam->scale;
int32_t channel_dim_index = 0;
diff --git a/compiler/luci/pass/src/RequantizePass.cpp b/compiler/luci/pass/src/RequantizePass.cpp
new file mode 100644
index 000000000..49fbf76ec
--- /dev/null
+++ b/compiler/luci/pass/src/RequantizePass.cpp
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/RequantizePass.h"
+#include "QuantizationUtils.h"
+
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+#include <oops/UserExn.h>
+
+#include <iostream>
+#include <cmath>
+
+namespace luci
+{
+
+namespace
+{
+
+// Check if the node is the bias of Conv2D, DepthwiseConv2D, or FullyConnected layer
+bool is_bias(CircleConst *node)
+{
+ if (node == nullptr)
+ return false;
+
+ auto succs = loco::succs(node);
+ if (succs.size() != 1) // assume bias is used by only one node
+ return false;
+
+ for (auto out : succs)
+ {
+ auto conv = dynamic_cast<CircleConv2D *>(out);
+ if (conv != nullptr && conv->bias() == node)
+ return true;
+
+ auto dw_conv = dynamic_cast<CircleDepthwiseConv2D *>(out);
+ if (dw_conv != nullptr && dw_conv->bias() == node)
+ return true;
+
+ auto fc = dynamic_cast<CircleFullyConnected *>(out);
+ if (fc != nullptr && fc->bias() == node)
+ return true;
+
+ // TODO: add TransposeConv when bias is supported in CircleTransposeConv
+ }
+ return false;
+}
+
+void requant_nonconst_int8_to_uint8(CircleNode *circle_node)
+{
+ assert(circle_node->dtype() == loco::DataType::S8);
+
+ auto quantparam = circle_node->quantparam();
+ assert(quantparam != nullptr);
+ for (size_t i = 0; i < quantparam->zerop.size(); ++i)
+ {
+ quantparam->zerop[i] += 128;
+ }
+ circle_node->dtype(loco::DataType::U8);
+}
+
+// Requantize CircleConst from symmetric int8 to asymmetric uint8
+// Original values: -127 ~ 127
+// After requantization: 1 ~ 255 (zp <- zp + 128)
+void requant_const_int8_to_uint8(CircleConst *node)
+{
+ assert(node->dtype() == loco::DataType::S8);
+
+ uint32_t size = node->size<loco::DataType::S8>();
+ std::vector<int32_t> requantized_values(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ int32_t data = node->at<loco::DataType::S8>(i);
+ requantized_values[i] = data + 128;
+ }
+
+ node->dtype(loco::DataType::U8); // change the type of tensor
+ node->size<loco::DataType::U8>(size);
+ for (uint32_t i = 0; i < size; ++i)
+ {
+ assert(1 <= requantized_values[i] && requantized_values[i] <= 255);
+ node->at<loco::DataType::U8>(i) = requantized_values[i];
+ }
+
+ auto quantparam = node->quantparam();
+ assert(quantparam != nullptr);
+ for (size_t i = 0; i < quantparam->zerop.size(); ++i)
+ {
+ quantparam->zerop[i] += 128;
+ }
+}
+
+/**
+ * @brief RequantizeNonConst requantizes tensors for activations
+ */
+struct RequantizeNonConst final : public luci::CircleNodeMutableVisitor<bool>
+{
+ RequantizeNonConst(loco::DataType input, loco::DataType output)
+ : _input_type(input), _output_type(output)
+ {
+ }
+
+ loco::DataType _input_type;
+ loco::DataType _output_type;
+
+ // Requantize input tensors of each node
+ bool visit(luci::CircleNode *node)
+ {
+ LOGGER(l);
+ INFO(l) << "RequantizeNonConst visit node: " << node->name() << std::endl;
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ auto input_node = node->arg(i);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+
+ // Check if this was quantized (only quantized tensors are requantized)
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
+ // Check if this is already requantized
+ if (circle_node->dtype() == _output_type)
+ continue;
+
+ // Check if this is not const (only non-const is requantized in this function)
+ auto circle_const = dynamic_cast<CircleConst *>(circle_node);
+ if (circle_const != nullptr)
+ continue;
+
+ if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
+ requant_nonconst_int8_to_uint8(circle_node);
+ }
+ return false;
+ }
+};
+
+/**
+ * @brief RequantizeConst requantizes tensors for weights
+ */
+struct RequantizeConst final : public luci::CircleNodeMutableVisitor<bool>
+{
+ RequantizeConst(loco::DataType input, loco::DataType output)
+ : _input_type(input), _output_type(output)
+ {
+ }
+
+ loco::DataType _input_type;
+ loco::DataType _output_type;
+
+ // Requantize input tensors of each node
+ bool visit(luci::CircleNode *node)
+ {
+ LOGGER(l);
+ INFO(l) << "RequantizeConst visit node: " << node->name() << std::endl;
+ auto arity = node->arity();
+ for (uint32_t i = 0; i < arity; i++)
+ {
+ auto input_node = node->arg(i);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(input_node);
+
+ // Check if this was quantized (only quantized tensors are requantized)
+ if (circle_node->quantparam() == nullptr)
+ continue;
+
+ // Check if this is already requantized
+ if (circle_node->dtype() == _output_type)
+ continue;
+
+ // Check if this is const (only const is requantized in this function)
+ auto circle_const = dynamic_cast<CircleConst *>(circle_node);
+ if (circle_const == nullptr)
+ continue;
+
+ // Check if this is not bias
+ // bias is not requantized when int8 -> uint8
+ if (is_bias(circle_const))
+ continue;
+
+ if (_input_type == loco::DataType::S8 && _output_type == loco::DataType::U8)
+ requant_const_int8_to_uint8(circle_const);
+ }
+ return false;
+ }
+};
+
+} // namespace
+
+bool RequantizePass::run(loco::Graph *g)
+{
+ LOGGER(l);
+ INFO(l) << "RequantizePass Start" << std::endl;
+
+ // Requantize non-const (activations)
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ RequantizeNonConst rqnc(_input_dtype, _output_dtype);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&rqnc);
+ }
+
+ // Requantize const (including weights, constants)
+ for (auto node : loco::active_nodes(loco::output_nodes(g)))
+ {
+ RequantizeConst rqc(_input_dtype, _output_dtype);
+ auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+ circle_node->accept(&rqc);
+ }
+
+ // Update output dtype
+ auto graph_outputs = g->outputs();
+ for (auto node : loco::output_nodes(g))
+ {
+ auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
+ if (static_cast<luci::CircleNode *>(circle_node->from())->dtype() == _output_dtype)
+ {
+ circle_node->dtype(_output_dtype);
+ auto graph_output = graph_outputs->at(circle_node->index());
+ graph_output->dtype(_output_dtype);
+ }
+ }
+
+ INFO(l) << "RequantizePass End" << std::endl;
+ return false; // one time run
+}
+
+} // namespace luci