diff options
Diffstat (limited to 'compiler/luci/pass/src/CircleQuantizer.cpp')
-rw-r--r-- | compiler/luci/pass/src/CircleQuantizer.cpp | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/compiler/luci/pass/src/CircleQuantizer.cpp b/compiler/luci/pass/src/CircleQuantizer.cpp index ce38a90b9..9a6550b9f 100644 --- a/compiler/luci/pass/src/CircleQuantizer.cpp +++ b/compiler/luci/pass/src/CircleQuantizer.cpp @@ -22,6 +22,7 @@ #include "luci/Pass/RequantizePass.h" #include "luci/Pass/ConvertToFakeQuantizedModelPass.h" #include "luci/Pass/FoldDequantizePass.h" +#include "luci/Pass/RemoveRedundantDequantizePass.h" #include "luci/Pass/QuantizePreCheckerPass.h" #include "luci/Pass/QuantizeWithMinMaxPass.h" #include "luci/Pass/QuantizeDequantizeWeightsPass.h" @@ -252,8 +253,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const static const std::vector<std::string> qwmm_supported_input_model_dtype{"float32"}; static const std::vector<std::string> qwmm_supported_output_model_dtype{"uint8", "int16"}; static const std::vector<std::string> qwmm_supported_granularity{"layer", "channel"}; - static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16"}; - static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16"}; + static const std::vector<std::string> qwmm_supported_input_type{"uint8", "int16", "float32"}; + static const std::vector<std::string> qwmm_supported_output_type{"uint8", "int16", "float32"}; auto input_model_dtype = _options->param(Options::AlgorithmParameters::Quantize_input_model_dtype); @@ -434,6 +435,8 @@ void CircleQuantizer::quantize(loco::Graph *g) const phase.emplace_back(std::make_unique<luci::CircleShapeInferencePass>()); phase.emplace_back(std::make_unique<luci::CircleTypeInferencePass>()); + // Remove redundant Dequantize Ops generated during fake quantization + phase.emplace_back(std::make_unique<luci::RemoveRedundantDequantizePass>()); // Fold Dequantize Ops generated during fake quantization phase.emplace_back(std::make_unique<luci::FoldDequantizePass>()); |