summaryrefslogtreecommitdiff
path: root/compiler/circle-mpqsolver/src/bisection/Quantizer.cpp
blob: 6fe1d560b2a811ce967c7688a4e7c1ef2ef112da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Quantizer.h"
#include <luci/Service/Validate.h>

#include <iostream>

using namespace mpqsolver::bisection;
using AlgorithmParameters = luci::CircleQuantizer::Options::AlgorithmParameters;
using Algorithms = luci::CircleQuantizer::Options::Algorithm;

namespace
{

/**
 * @brief Convert every graph of a module into a fake-quantized model, in place.
 *
 * Runs luci::CircleQuantizer with the ConvertToFakeQuantizedModel algorithm on
 * each graph of @p module and validates each result with luci::validate.
 * Intended to run after Quantizer::quantize (see Quantizer::fake_quantize).
 *
 * @param module module to convert; nullptr is rejected
 * @return true if every graph was converted and passed validation,
 *         false if @p module is null or any graph fails validation
 */
bool make_model_fake_quantized(luci::Module *module)
{
  // Guard against a null module instead of dereferencing it below.
  if (module == nullptr)
    return false;

  luci::CircleQuantizer quantizer;

  auto options = quantizer.options();
  options->enable(Algorithms::ConvertToFakeQuantizedModel);

  for (size_t idx = 0; idx < module->size(); ++idx)
  {
    auto graph = module->graph(idx);
    // convert the graph into its fake-quantized form
    quantizer.quantize(graph);
    if (!luci::validate(graph))
    {
      // Report why false is returned rather than failing silently.
      std::cerr << "ERROR: Fake-quantized graph is invalid" << std::endl;
      return false;
    }
  }

  return true;
}

} // namespace

Quantizer::Quantizer(const std::string &input_dtype, const std::string &output_dtype)
  : _input_dtype(input_dtype), _output_dtype(output_dtype)
{
}

/**
 * @brief Quantize a recorded module (min/max already initialized) in place.
 *
 * Every graph of @p module is processed by luci::CircleQuantizer with the
 * QuantizeWithMinMax algorithm. The source model is treated as float32,
 * channel-wise granularity is requested, TF-style max-pool handling is
 * turned off, and the model input/output tensor types are the ones given
 * to the constructor.
 *
 * @param module       module to quantize; nullptr is rejected
 * @param quant_dtype  target dtype, passed as Quantize_output_model_dtype
 * @param layer_params per-layer overrides; applied only when non-empty
 * @return true on success; false if @p module is null, if the layer
 *         parameters are rejected (the std::runtime_error message is printed
 *         to stderr), or if any quantized graph fails luci::validate
 */
bool Quantizer::quantize(luci::Module *module, const std::string &quant_dtype,
                         LayerParams &layer_params)
{
  if (module == nullptr)
  {
    return false;
  }

  static const std::string float_model_dtype = "float32";
  static const std::string per_channel = "channel";

  luci::CircleQuantizer quantizer;
  auto opts = quantizer.options();
  opts->enable(Algorithms::QuantizeWithMinMax);

  // Model-wide dtypes and quantization granularity.
  opts->param(AlgorithmParameters::Quantize_input_model_dtype, float_model_dtype);
  opts->param(AlgorithmParameters::Quantize_output_model_dtype, quant_dtype);
  opts->param(AlgorithmParameters::Quantize_granularity, per_channel);
  // Model I/O tensor types chosen when this Quantizer was constructed.
  opts->param(AlgorithmParameters::Quantize_input_type, _input_dtype);
  opts->param(AlgorithmParameters::Quantize_output_type, _output_dtype);
  opts->param(AlgorithmParameters::Quantize_TF_style_maxpool, "False");

  // Per-layer overrides are optional; a rejected set aborts quantization.
  if (!layer_params.empty())
  {
    try
    {
      opts->layer_params(AlgorithmParameters::Quantize_layer_params, layer_params);
    }
    catch (const std::runtime_error &err)
    {
      std::cerr << err.what() << '\n';
      return false;
    }
  }

  for (size_t gi = 0; gi < module->size(); ++gi)
  {
    const auto g = module->graph(gi);
    quantizer.quantize(g);
    if (luci::validate(g))
      continue;

    std::cerr << "ERROR: Quantized graph is invalid" << std::endl;
    return false;
  }

  return true;
}

/**
 * @brief Fake-quantize a recorded module (min/max already initialized) in place.
 *
 * First quantizes @p module through quantize() using @p quant_dtype and
 * @p layer_params, then converts the result into a fake-quantized model.
 * The conversion step is skipped when quantization fails.
 *
 * @return true only if both quantization and the fake-quantization
 *         conversion succeed
 */
bool Quantizer::fake_quantize(luci::Module *module, const std::string &quant_dtype,
                              LayerParams &layer_params)
{
  // '&&' short-circuits, so the conversion runs only after a successful quantize().
  return quantize(module, quant_dtype, layer_params) && make_model_fake_quantized(module);
}