summary | refs | log | tree | commit | diff
path: root/compiler/luci/pass/src/CircleQuantizer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/CircleQuantizer.cpp')
-rw-r--r-- compiler/luci/pass/src/CircleQuantizer.cpp 7
1 files changed, 5 insertions, 2 deletions
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp
index ce38a90b9..9a6550b9f 100644
--- a/compiler/luci/pass/src/CircleQuantizer.cpp
+++ b/compiler/luci/pass/src/CircleQuantizer.cpp
@@ -22,6 +22,7 @@
#include "luci/Pass/RequantizePass.h"
#include "luci/Pass/ConvertToFakeQuantizedModelPass.h"
#include "luci/Pass/FoldDequantizePass.h"
+#include "luci/Pass/RemoveRedundantDequantizePass.h"
#include "luci/Pass/QuantizePreCheckerPass.h"
#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/QuantizeDequantizeWeightsPass.h"
@@ -252,8 +253,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const
static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"};
static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"};
static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"};
- static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"};
- static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"};
+ static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "float32"};
+ static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "float32"};
auto input_model_dtype =
_options->param(Options::AlgorithmParameters::Quantize_input_model_dtype);
@@ -434,6 +435,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const
phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>());
phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>());
+ // Remove redundant Dequantize Ops generated during fake quantization
+ phase.emplace_back(std::make_unique<luci::RemoveRedundantDequantizePass>());
// Fold Dequantize Ops generated during fake quantization
phase.emplace_back(std::make_unique<luci::FoldDequantizePass>());