summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp')
-rw-r--r--compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp214
1 files changed, 214 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
new file mode 100644
index 000000000..11970fff5
--- /dev/null
+++ b/compiler/luci/pass/src/ConvertToFakeQuantizedModelPass.cpp
@@ -0,0 +1,214 @@
+/*
+ * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
+#include "luci/Pass/QuantizationParameters.h"
+
+#include "QuantizationUtils.h"
+
+#include <luci/Profile/CircleNodeOrigin.h>
+#include <luci/IR/CircleNodes.h>
+#include <luci/IR/CircleNodeVisitor.h>
+#include <luci/Log.h>
+
+namespace
+{
+
+// Create a Quantize Op that mirrors 'node'
+//
+// The new Op is created in the graph that owns 'node' and is named
+// "<node name>_Quantize". Its dtype, rank and dimensions are taken from 'node',
+// its shape status is set to VALID, the quantization parameters of 'node' are
+// copied with copy_quantparam(), and the origin of 'node' is added to it.
+//
+// NOTE The input of the returned Op is not set here; the caller connects it
+luci::CircleQuantize *create_quantize(luci::CircleNode *node)
+{
+  auto quantize = node->graph()->nodes()->create<luci::CircleQuantize>();
+  quantize->name(node->name() + "_Quantize");
+  quantize->dtype(node->dtype());
+
+  const auto rank = node->rank();
+  quantize->rank(rank);
+  // Copy each Dimension as a whole (not only its value) so that a dimension
+  // which is unknown in 'node' stays unknown instead of being marked as known
+  for (uint32_t i = 0; i < rank; i++)
+    quantize->dim(i) = node->dim(i);
+
+  quantize->shape_status(luci::ShapeStatus::VALID);
+
+  copy_quantparam(node, quantize);
+
+  luci::add_origin(quantize, luci::get_origin(node));
+
+  return quantize;
+}
+
+// Create a Dequantize Op whose shape is the same as that of 'node'
+//
+// The new Op is created in the graph that owns 'node' and is named
+// "<node name>_Dequantize". Its dtype is always FLOAT32, its rank and
+// dimensions are taken from 'node', its shape status is set to VALID, and the
+// origin of 'node' is added to it.
+//
+// NOTE The input of the returned Op is not set here; the caller connects it
+luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
+{
+  auto dequantize = node->graph()->nodes()->create<luci::CircleDequantize>();
+  dequantize->name(node->name() + "_Dequantize");
+  dequantize->dtype(loco::DataType::FLOAT32);
+
+  const auto rank = node->rank();
+  dequantize->rank(rank);
+  // Copy each Dimension as a whole (not only its value) so that a dimension
+  // which is unknown in 'node' stays unknown instead of being marked as known
+  for (uint32_t i = 0; i < rank; i++)
+    dequantize->dim(i) = node->dim(i);
+
+  dequantize->shape_status(luci::ShapeStatus::VALID);
+
+  luci::add_origin(dequantize, luci::get_origin(node));
+
+  return dequantize;
+}
+
+// Tell whether 'node' is a quantized activation, i.e. both of these hold
+// - dtype is u8 or s16
+// - node has qparam
+bool is_quant_act(const luci::CircleNode *node)
+{
+  const auto dtype = node->dtype();
+  const bool quant_dtype = (dtype == loco::DataType::U8 or dtype == loco::DataType::S16);
+
+  // qparam is looked up only when the dtype already qualifies
+  return quant_dtype and node->quantparam() != nullptr;
+}
+
+// Tell whether 'node' is a quantized const, i.e. both of these hold
+// - dtype is anything but fp32
+// - node has qparam
+// NOTE Quantized const can have the following types
+// u8 (weights, activation), s16 (weights, activation), s32 (bias), s64 (bias)
+bool is_quant_const(const luci::CircleConst *node)
+{
+  const bool is_float = (node->dtype() == loco::DataType::FLOAT32);
+
+  // qparam is looked up only for non-fp32 const
+  return not is_float and node->quantparam() != nullptr;
+}
+
+// Put a new Dequantize Op right after 'lnode'
+// 'lnode' is cast to luci::CircleNode with loco::must_cast
+void insert_dequantize(loco::Node *lnode)
+{
+  auto target = loco::must_cast<luci::CircleNode *>(lnode);
+  auto dq_op = create_dequantize(target);
+
+  // Reroute the users of 'target' to the new Op first, and only then let the
+  // new Op take 'target' as its input
+  loco::replace(target).with(dq_op);
+  dq_op->input(target);
+}
+
+// Put a new Quantize Op right after 'lnode' and hand it back to the caller
+// 'lnode' is cast to luci::CircleNode with loco::must_cast
+luci::CircleQuantize *insert_quantize(loco::Node *lnode)
+{
+  auto target = loco::must_cast<luci::CircleNode *>(lnode);
+  auto q_op = create_quantize(target);
+
+  // Reroute the users of 'target' to the new Op first, and only then let the
+  // new Op take 'target' as its input
+  loco::replace(target).with(q_op);
+  q_op->input(target);
+
+  return q_op;
+}
+
+// Turn 'node' into a float node
+// Its quantparam is cleared and its dtype becomes FLOAT32
+void dequantize(luci::CircleNode *node)
+{
+  node->quantparam(nullptr);
+  node->dtype(loco::DataType::FLOAT32);
+}
+
+// Fake-quantize a quantized activation (nodes that are not a quantized
+// activation are left untouched)
+// 1. Quantize -> Dequantize Ops are inserted after the node
+// 2. dtype/quantparam of the node itself are reset to float
+void fq_activation(luci::CircleNode *node)
+{
+  if (is_quant_act(node))
+  {
+    // The Quantize Op must be created while 'node' still holds its
+    // quantized dtype/qparam, hence dequantize(node) comes last
+    insert_dequantize(insert_quantize(node));
+
+    dequantize(node);
+  }
+}
+
+// Visitor to do fake quantization for each Op
+// For non-const activation, insert Quantize-Dequantize after the ofm
+// For quantized const, insert Dequantize after the const
+//
+// Any Op without a dedicated visit() below ends up in the CircleNode fallback,
+// which throws std::runtime_error.
+struct FakeQuantize final : public luci::CircleNodeMutableVisitor<void>
+{
+  // Fallback for unsupported Ops
+  void visit(luci::CircleNode *node) override
+  {
+    throw std::runtime_error("Unsupported op for fake quantization in " + node->name());
+  }
+
+  // For quantized graph input, insert Quantize-Dequantize Ops after it and
+  // switch both the node and the matching graph input to FLOAT32
+  void visit(luci::CircleInput *node) override
+  {
+    if (not is_quant_act(node))
+      return;
+
+    auto quant = insert_quantize(node);
+    insert_dequantize(quant);
+
+    dequantize(node);
+
+    // Update graph input
+    const auto inputs = node->graph()->inputs();
+    auto graph_input = inputs->at(node->index());
+    graph_input->dtype(loco::DataType::FLOAT32);
+  }
+
+  // For quantized graph output, no Op is inserted; only the node and the
+  // matching graph output are switched to FLOAT32
+  void visit(luci::CircleOutput *node) override
+  {
+    if (not is_quant_act(node))
+      return;
+
+    dequantize(node);
+
+    // Update graph output
+    const auto outputs = node->graph()->outputs();
+    auto graph_output = outputs->at(node->index());
+    graph_output->dtype(loco::DataType::FLOAT32);
+  }
+
+  // For quantized const, insert Dequantize Op
+  // NOTE The const itself keeps its dtype/quantparam
+  void visit(luci::CircleConst *node) override
+  {
+    if (not is_quant_const(node))
+      return;
+
+    insert_dequantize(node);
+  }
+
+  // For non-const activation, insert Quantize-Dequantize Ops
+  // and dequantize the node
+  void visit(luci::CircleConv2D *node) override { fq_activation(node); }
+  void visit(luci::CircleAdd *node) override { fq_activation(node); }
+};
+
+} // namespace
+
+namespace luci
+{
+
+// Apply the FakeQuantize visitor to every active node reachable from the
+// outputs of 'g'
+// Always returns false, as this pass is meant to run only once
+bool ConvertToFakeQuantizedModelPass::run(loco::Graph *g)
+{
+  LOGGER(l);
+
+  // The visitor has no data members, so one instance serves every node
+  FakeQuantize fq;
+
+  // The node set is collected before any Op is inserted by the visitor
+  const auto targets = loco::active_nodes(loco::output_nodes(g));
+  for (auto node : targets)
+  {
+    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
+    INFO(l) << "ConvertToFakeQuantizedModelPass visit node: " << circle_node->name() << std::endl;
+
+    circle_node->accept(&fq);
+  }
+
+  // One time run
+  return false;
+}
+
+} // namespace luci