summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/QuantizeWithMinMaxPass.cpp
blob: 005144516ec308f635514e416b27e4755214cdb7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "luci/Pass/QuantizeWithMinMaxPass.h"
#include "luci/Pass/PropagateQParamForwardPass.h"
#include "luci/Pass/PropagateQParamBackwardPass.h"
#include "luci/Pass/RemoveRedundantQuantizePass.h"
#include "QuantizeActivation.h"
#include "QuantizeWeights.h"
#include "QuantizeBias.h"
#include "QuantizationUtils.h"
#include "ProgressReporter.h"
#include "helpers/LayerInfoMap.h"

#include <luci/IR/CircleNodes.h>
#include <luci/IR/CircleNodeVisitor.h>
#include <luci/Service/Nodes/CircleConst.h>
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Log.h>
#include <logo/Phase.h>

#include <oops/UserExn.h>

#include <iostream>
#include <cmath>

namespace
{

using namespace luci;

/**
 * @brief Tell whether the given activation qtype uses pre-defined qparam
 *
 * @param qtype Activation quantization type to inspect
 * @return true for PreDefinedLogistic, PreDefinedTanh and PreDefinedSoftmax;
 *         false otherwise
 *
 * @note Any qtype that is not pre-defined is expected to be IntScale or
 *       MinMax. This is checked with an assert.
 */
bool use_predefined_values(ActivationQType qtype)
{
  const bool predefined = qtype == ActivationQType::PreDefinedLogistic or
                          qtype == ActivationQType::PreDefinedTanh or
                          qtype == ActivationQType::PreDefinedSoftmax;
  if (predefined)
    return true;

  // Every remaining ActivationQType must be one of the computed kinds
  assert(qtype == ActivationQType::IntScale or qtype == ActivationQType::MinMax);
  return false;
}

/**
 * @brief Create a Quantize Op that follows the given node
 *
 * The created Op has
 * - dtype: out_type
 * - shape: the same as node
 * - qparam: pre-defined values if node's activation qtype requires them,
 *           otherwise computed from node's recorded min/max
 *
 * @param node     Node whose shape and quantparam are used (must have quantparam)
 * @param out_type Output dtype of the Quantize Op. When qparam is computed,
 *                 U8 uses the asymmetric scheme and S16 the symmetric one.
 * @return Newly created Quantize Op. Its input and origin are NOT set here.
 */
luci::CircleQuantize *create_quantize_op(luci::CircleNode *node, loco::DataType out_type)
{
  auto quant_node = node->graph()->nodes()->create<CircleQuantize>();
  quant_node->name(node->name() + "_Quantize");
  quant_node->dtype(out_type);

  // Mirror the shape of the source node
  const auto node_rank = node->rank();
  quant_node->rank(node_rank);
  for (uint32_t axis = 0; axis < node_rank; ++axis)
    quant_node->dim(axis).set(node->dim(axis).value());
  quant_node->shape_status(luci::ShapeStatus::VALID);

  auto src_qparam = node->quantparam();
  assert(src_qparam); // FIX_CALLER_UNLESS

  const auto qtype = luci::activation_qtype(node);

  // Pre-defined qparam does not depend on the recorded min/max
  if (use_predefined_values(qtype))
  {
    quant_node->quantparam(luci::make_predefined_qparam(qtype, out_type));
    return quant_node;
  }

  assert(qtype == ActivationQType::MinMax or qtype == ActivationQType::IntScale);

  // Only layer-wise min/max is expected here
  assert(src_qparam->min.size() == 1); // FIX_CALLER_UNLESS
  assert(src_qparam->max.size() == 1); // FIX_CALLER_UNLESS
  auto recorded_min = src_qparam->min[0];
  auto recorded_max = src_qparam->max[0];

  float scale{0};
  int64_t zero_point{0};
  float nudged_min{0};
  float nudged_max{0};

  if (out_type != loco::DataType::U8)
  {
    assert(out_type == loco::DataType::S16);
    compute_sym_scale_zp(recorded_min, recorded_max, scale, zero_point, nudged_min, nudged_max);
  }
  else
  {
    compute_asym_scale_zp(recorded_min, recorded_max, scale, zero_point, nudged_min, nudged_max);
  }

  auto new_qparam = std::make_unique<CircleQuantParam>();
  new_qparam->scale.push_back(scale);
  new_qparam->zerop.push_back(zero_point);
  // Keep the recorded min/max rather than the nudged ones. The nudged
  // values differ from the real range, which would produce a wrong
  // qparam if the quantization dtype is changed afterwards.
  new_qparam->min.push_back(recorded_min);
  new_qparam->max.push_back(recorded_max);

  quant_node->quantparam(std::move(new_qparam));

  if (qtype == ActivationQType::IntScale)
    set_int_scale(quant_node);

  return quant_node;
}

/**
 * @brief Create a FLOAT32 Dequantize Op whose shape is the same as node
 *
 * @param node Node whose shape, name and origin are used
 * @return Newly created Dequantize Op. Its input is NOT set here.
 */
luci::CircleDequantize *create_dequantize(luci::CircleNode *node)
{
  auto dequant_node = node->graph()->nodes()->create<luci::CircleDequantize>();
  dequant_node->name(node->name() + "_Dequantize");
  dequant_node->dtype(loco::DataType::FLOAT32);

  // Mirror the shape of the source node
  const auto node_rank = node->rank();
  dequant_node->rank(node_rank);
  for (uint32_t axis = 0; axis < node_rank; ++axis)
    dequant_node->dim(axis).set(node->dim(axis).value());
  dequant_node->shape_status(luci::ShapeStatus::VALID);

  // The new Op inherits the origin of the source node
  luci::add_origin(dequant_node, luci::get_origin(node));

  return dequant_node;
}

} // namespace

namespace luci
{

namespace
{

/**
 * Insert Quantize operator for mixed-precision quantization
 * 1. Before input feature map (only for non-const)
 * 2. After output feature map
 *
 * For example, if default_dtype = U8 and op_dtype = S16,
 * 1. Quantize Op for U8->S16 is inserted before ifm
 * 2. Quantize Op for S16->U8 is inserted after ofm
 *
 * Why not insert Quantize Op for const ifm?
 * We quantize const tensor at once to preserve precision.
 * For example, if default dtype = U8, op_dtype = S16, and op is CONV2D,
 * We directly quantize weights to 16 bits, not 8->16 bits.
 *
 * Usage: construct with (default_dtype, op_dtype) and pass to
 * CircleNode::accept(). Ops without a dedicated visit() overload end up in
 * visit(luci::CircleNode *), which throws std::runtime_error.
 */
struct InsertQuantizeOp final : public luci::CircleNodeMutableVisitor<void>
{
  // default_dtype and op_dtype must differ; otherwise no Quantize Op is needed
  InsertQuantizeOp(loco::DataType default_dtype, loco::DataType op_dtype)
    : _default_dtype(default_dtype), _op_dtype(op_dtype)
  {
    assert(default_dtype != op_dtype); // FIX_CALLER_UNLESS
  }

private:
  // dtype of Quantize Ops inserted after the output of the visited Op
  loco::DataType _default_dtype;
  // dtype of Quantize Ops inserted before the inputs of the visited Op
  loco::DataType _op_dtype;

private:
  // Create a Quantize Op (dtype: _op_dtype) that takes 'in' as its input.
  // Its origin is copied from 'origin' (the Op being visited).
  // Return nullptr if 'in' is CircleConst (no Quantize Op for const input).
  // NOTE The caller is responsible for connecting the returned Op to the
  //      visited Op.
  luci::CircleQuantize *create_in_quantize(loco::Node *in, loco::Node *origin)
  {
    auto input = loco::must_cast<luci::CircleNode *>(in);
    if (input->opcode() == luci::CircleOpcode::CIRCLECONST)
      return nullptr;

    auto input_quant = create_quantize_op(input, _op_dtype);
    input_quant->input(input);
    auto origin_node = loco::must_cast<luci::CircleNode *>(origin);
    luci::add_origin(input_quant, luci::get_origin(origin_node));
    return input_quant;
  }

  // Insert a Quantize Op (dtype: _default_dtype) after 'node'.
  // 'node' must not be CircleConst.
  void insert_out_quantize(loco::Node *node)
  {
    auto output = loco::must_cast<luci::CircleNode *>(node);
    assert(output->opcode() != luci::CircleOpcode::CIRCLECONST); // FIX_CALLER_UNLESS
    auto output_quant = create_quantize_op(output, _default_dtype);

    luci::add_origin(output_quant, luci::get_origin(output));
    // NOTE Order matters. All uses of 'node' are redirected to the Quantize
    //      Op first, and only then 'node' is set as the input of the
    //      Quantize Op.
    loco::replace(node).with(output_quant);
    output_quant->input(node);
  }

// Define visit(NODE *) which inserts a Quantize Op before INPUT_NAME
// (skipped for const input) and another one after NODE.
// INPUT_NAME is the only activation of NODE
#define INSERT_QUANTIZE_TO_UNARY_OP(NODE, INPUT_NAME)                    \
  void visit(NODE *node)                                                 \
  {                                                                      \
    if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node)) \
      node->INPUT_NAME(input_quant);                                     \
                                                                         \
    insert_out_quantize(node);                                           \
  }

// Define visit(NODE *) which inserts a Quantize Op before INPUT_NAME
// (skipped for const input) and one after every successor of NODE.
// INPUT_NAME is the only activation of NODE
// OUT_NAME is the type every successor of NODE is cast to
#define INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(NODE, INPUT_NAME, OUT_NAME) \
  void visit(NODE *node)                                                     \
  {                                                                          \
    if (auto input_quant = create_in_quantize(node->INPUT_NAME(), node))     \
      node->INPUT_NAME(input_quant);                                         \
                                                                             \
    auto out_nodes = loco::succs(node);                                      \
    for (auto out_node : out_nodes)                                          \
    {                                                                        \
      auto out_circle = loco::must_cast<OUT_NAME *>(out_node);               \
      insert_out_quantize(out_circle);                                       \
    }                                                                        \
  }

// Define visit(NODE *) which inserts a Quantize Op before each of the two
// inputs (skipped for const input) and another one after NODE.
// INPUT_NAME1 and INPUT_NAME2 are the only activations of NODE
#define INSERT_QUANTIZE_TO_BINARY_OP(NODE, INPUT_NAME1, INPUT_NAME2)       \
  void visit(NODE *node)                                                   \
  {                                                                        \
    if (auto input1_quant = create_in_quantize(node->INPUT_NAME1(), node)) \
      node->INPUT_NAME1(input1_quant);                                     \
                                                                           \
    if (auto input2_quant = create_in_quantize(node->INPUT_NAME2(), node)) \
      node->INPUT_NAME2(input2_quant);                                     \
                                                                           \
    insert_out_quantize(node);                                             \
  }

  // Default behavior (NYI)
  // Any Op without a dedicated overload below is rejected with an exception
  void visit(luci::CircleNode *node)
  {
    throw std::runtime_error("Unsupported Op for mixed-precision quantization. Layer name: " +
                             node->name());
  }

  // Skip output layer
  // The *Out nodes below are handled when their multi-output Op is visited
  // (see INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP), so nothing is done here.
  void visit(luci::CircleOutput *) {}
  void visit(luci::CircleSplitVOut *) {}
  void visit(luci::CircleSplitOut *) {}
  void visit(luci::CircleTopKV2Out *) {}
  void visit(luci::CircleUniqueOut *) {}
  void visit(luci::CircleUnpackOut *) {}

  // Ops that receive a single activation as an input
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleAveragePool2D, value)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleBatchToSpaceND, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleConv2D, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthToSpace, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleDepthwiseConv2D, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleElu, features)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleExp, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFloor, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleFullyConnected, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleGather, params)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleInstanceNorm, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLeakyRelu, features)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLocalResponseNormalization, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleLogistic, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMaxPool2D, value)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMean, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleMirrorPad, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleNeg, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePad, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePadV2, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CirclePRelu, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceProd, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMax, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReduceMin, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu, features)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRelu6, features)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReshape, tensor)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeBilinear, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleResizeNearestNeighbor, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleReverseSequence, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleRsqrt, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSlice, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSoftmax, logits)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToBatchND, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSpaceToDepth, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqueeze, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSqrt, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleStridedSlice, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleSum, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTanh, x)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTile, input)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTranspose, a)
  INSERT_QUANTIZE_TO_UNARY_OP(luci::CircleTransposeConv, outBackprop)

  // Ops that receive two activations as inputs
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleAdd, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleBatchMatMul, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleDiv, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleFloorDiv, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMaximum, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMinimum, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleMul, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleOneHot, on_value, off_value)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CirclePow, x, y)
  INSERT_QUANTIZE_TO_BINARY_OP(luci::CircleSub, x, y)

  // Multiple-output ops that receive one activation as inputs
  INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplit, input, luci::CircleSplitOut)
  INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleSplitV, input, luci::CircleSplitVOut)
  INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleTopKV2, input, luci::CircleTopKV2Out)
  INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnique, input, luci::CircleUniqueOut)
  INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP(luci::CircleUnpack, value, luci::CircleUnpackOut)

  // AddN has arbitrary number of inputs
  // A Quantize Op is inserted before every non-const input and after the Op
  void visit(luci::CircleAddN *node)
  {
    auto arity = node->arity();
    for (uint32_t i = 0; i < arity; i++)
    {
      if (auto input_quant = create_in_quantize(node->inputs(i), node))
        node->inputs(i, input_quant);
    }

    insert_out_quantize(node);
  }

  // Concat has arbitrary number of inputs
  // A Quantize Op is inserted before every non-const input and after the Op
  void visit(luci::CircleConcatenation *node)
  {
    auto arity = node->arity();
    for (uint32_t i = 0; i < arity; i++)
    {
      if (auto input_quant = create_in_quantize(node->values(i), node))
        node->values(i, input_quant);
    }

    insert_out_quantize(node);
  }

  // Pack has arbitrary number of inputs
  // A Quantize Op is inserted before every non-const input and after the Op
  void visit(luci::CirclePack *node)
  {
    auto arity = node->arity();
    for (uint32_t i = 0; i < arity; i++)
    {
      if (auto input_quant = create_in_quantize(node->values(i), node))
        node->values(i, input_quant);
    }

    insert_out_quantize(node);
  }

// The helper macros are only meaningful inside this struct
#undef INSERT_QUANTIZE_TO_UNARY_OP
#undef INSERT_QUANTIZE_TO_BINARY_OP
#undef INSERT_QUANTIZE_TO_UNARY_MULTI_OUTPUT_OP
};

} // namespace

/**
 * @brief Change the dtype of graph inputs to _ctx->input_type
 *
 * For every input whose dtype differs from the requested one (BOOL, S32 and
 * S64 inputs are left untouched),
 * 1. a Quantize Op producing the current dtype is inserted right after the
 *    input, so that the rest of the graph sees the same dtype as before
 * 2. scale/zero-point of the input are recomputed from its recorded min/max
 *    (skipped when the requested type is FLOAT32)
 * 3. dtype of the input node and of the matching graph input is updated
 */
void QuantizeWithMinMaxPass::set_input_type(loco::Graph *g) const
{
  const auto target_dtype = _ctx->input_type;
  auto graph_inputs = g->inputs();

  for (auto node : loco::input_nodes(g))
  {
    auto input = loco::must_cast<luci::CircleInput *>(node);
    const auto current_dtype = input->dtype();

    // Nothing to do if the input already has the requested type
    if (current_dtype == target_dtype)
      continue;

    // Bool type is not quantizable. S32/S64 inputs are skipped as well.
    const bool skipped = current_dtype == loco::DataType::BOOL or
                         current_dtype == loco::DataType::S32 or
                         current_dtype == loco::DataType::S64;
    if (skipped)
      continue;

    // Insert Quantize Op between the input and all of its users
    auto quant_op = create_quantize_op(input, current_dtype);
    loco::replace(input).with(quant_op);
    quant_op->input(input);

    // TODO Set a proper origin (Quantize should have its own Origin)
    {
      auto users = loco::succs(quant_op);
      assert(users.size() > 0);
      auto first_user = loco::must_cast<luci::CircleNode *>(*users.begin());
      luci::add_origin(quant_op, luci::get_origin(first_user));
    }

    // Update qparam of input
    // This step is skipped if input_type is float32
    if (target_dtype != loco::DataType::FLOAT32)
    {
      auto input_qparam = input->quantparam();
      assert(input_qparam);
      assert(input_qparam->min.size() == 1); // only support layer-wise quant
      assert(input_qparam->max.size() == 1); // only support layer-wise quant
      auto recorded_min = input_qparam->min[0];
      auto recorded_max = input_qparam->max[0];

      float scale{0};
      int64_t zero_point{0};
      float nudged_min{0};
      float nudged_max{0};

      if (target_dtype != loco::DataType::U8)
      {
        assert(target_dtype == loco::DataType::S16);
        compute_sym_scale_zp(recorded_min, recorded_max, scale, zero_point, nudged_min,
                             nudged_max);
      }
      else
      {
        compute_asym_scale_zp(recorded_min, recorded_max, scale, zero_point, nudged_min,
                              nudged_max);
      }

      input_qparam->scale[0] = scale;
      input_qparam->zerop[0] = zero_point;
    }

    // Update dtype of the input node and of the matching graph input
    input->dtype(target_dtype);
    graph_inputs->at(input->index())->dtype(target_dtype);
  }
}

/**
 * @brief Change the dtype of graph outputs to _ctx->output_type
 *
 * For every output whose dtype differs from the requested one (BOOL outputs
 * and outputs produced by an Op without quantparam are left untouched),
 * - a Dequantize Op is inserted before the output if the requested type is
 *   FLOAT32
 * - otherwise a Quantize Op of the requested type is inserted
 * Then dtype of the output node and of the matching graph output is updated.
 */
void QuantizeWithMinMaxPass::set_output_type(loco::Graph *g) const
{
  const auto target_dtype = _ctx->output_type;
  auto graph_outputs = g->outputs();

  for (auto node : loco::output_nodes(g))
  {
    auto output = loco::must_cast<luci::CircleOutput *>(node);
    const auto current_dtype = output->dtype();

    // Skip outputs that already have the requested type.
    // Bool type is not quantizable.
    if (current_dtype == target_dtype or current_dtype == loco::DataType::BOOL)
      continue;

    auto producer = loco::must_cast<luci::CircleNode *>(output->from());

    // The last Op is not quantizable Op (ex: ArgMax)
    if (producer->quantparam() == nullptr)
      continue;

    if (target_dtype != loco::DataType::FLOAT32)
    {
      // Insert Quantize Op for non-float32 output_type
      auto quant_op = create_quantize_op(producer, target_dtype);
      loco::replace(producer).with(quant_op);
      quant_op->input(producer);

      // TODO Set a proper origin (Quantize should have its own Origin)
      luci::add_origin(quant_op, luci::get_origin(producer));
    }
    else
    {
      // Insert Dequantize Op for float32 output_type
      auto dequant_op = create_dequantize(producer);
      loco::replace(producer).with(dequant_op);
      dequant_op->input(producer);
    }

    // Update dtype of the output node and of the matching graph output
    output->dtype(target_dtype);
    graph_outputs->at(output->index())->dtype(target_dtype);
  }
}

/**
 * How QuantizeWithMinMax works?
 *
 * We categorized tensors into four groups
 * - Activation: Feature maps (both Const/Non-const)
 * - Weights: Const tensors of specific Ops (Conv, FC, ...)
 * - Bias: Const tensors of specific Ops (Conv, FC, ...)
 * - Others: padding value, one_hot value, axis, ..
 *
 * Activation is quantized in different ways
 * 1. For non-constant activation, quantize using recorded min/max
 * 2. For constant activation, quantize using min/max of its value
 * 3. For some Ops (ex: pad_v2), output qparam is used as input qparam (backward propagation)
 * 4. For some Ops (ex: reshape), input qparam is used as output qparam (forward propagation)
 * 5. For some Ops (ex: tanh), output qparam has pre-defined values
 *
 * Weights is quantized using min/max of its value
 *
 * Bias is quantized using input scale (s_i) and weights scale (s_w)
 * - Activation and weights should be quantized earlier than bias
 *
 * Quantization Steps
 * 1. Quantize Activation
 *   - Quantize using recorded min/max (QuantizeActivation)
 *   - Insert Quantize Ops for mixed-precision quantization (InsertQuantizeOp)
 *   - Remove redundant Quantize Ops (RemoveRedundantQuantizePass)
 *   - Propagate qparam backward (PropagateQParamBackwardPass)
 *   - Quantize const inputs (QuantizeConstInputActivation)
 *   - Quantize using pre-defined values (QuantizeSpecialActivation)
 *   - Propagate qparam forward (PropagateQParamForwardPass)
 * 2. Quantize Weights
 * 3. Quantize Bias
 * 4. Set input dtype
 * 5. Set output dtype
 * 6. Remove redundant Quantize Ops again (RemoveRedundantQuantizePass)
 * 7. Remove min/max values from qparam
 *
 * Why quantization sequence was determined as above?
 * - Activation and weights should be quantized before bias (1->2->3). Input/Output
 *   dtype can be updated at the end (4->5).
 * - During activation quantization,
 *   - Backward propagation is performed earlier than forward propagation. This allows
 *     backward-propagated qparam to be overwritten during forward propagation.
 *     We made this decision as Ops for forward propagation (reshape, transpose, ..)
 *     are more common than backward propagation. TODO Check this decision is safe.
 *   - QuantizeSpecialActivation is called before forward propagation to make sure that
 *     the pre-defined qparam values are propagated.
 *
 * @param g Graph to quantize (modified in place)
 * @return Always false, as this pass is designed to run only once
 */
bool QuantizeWithMinMaxPass::run(loco::Graph *g)
{
  LOGGER(l);
  INFO(l) << "QuantizeWithMinMaxPass Start" << std::endl;

  // Per-layer quantization settings given by the user, keyed by layer name
  auto info_by_name = layer_info_map(g, _ctx->layers_info);

  auto quantize_dtype = [&](const luci::CircleNode *node) {
    auto iter = info_by_name.find(node->name());

    // Return designated quantization dtype
    if (iter != info_by_name.end())
      return iter->second.dtype;

    // Return default quantization dtype
    return _ctx->output_model_dtype;
  };

  auto quantize_granularity = [&](const luci::CircleNode *node) {
    auto iter = info_by_name.find(node->name());

    // Return designated quantization granularity
    if (iter != info_by_name.end())
      return iter->second.granularity;

    // Return default quantization granularity
    return _ctx->granularity;
  };

  // Quantize activation
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeActivation qa(_ctx->input_model_dtype, quantize_dtype(circle_node));
    circle_node->accept(&qa);
  }

  // Insert Quantize Op
  // Only Ops whose designated dtype differs from the default dtype need it
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    auto op_dtype = quantize_dtype(circle_node);
    if (op_dtype != _ctx->output_model_dtype)
    {
      InsertQuantizeOp iqo(_ctx->output_model_dtype, op_dtype);
      circle_node->accept(&iqo);
    }
  }

  // Remove redundant Quantize Op
  {
    logo::Phase phase;

    phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());

    ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
    logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
    phase_runner.attach(&prog);
    phase_runner.run(phase);
  }

  // Backward propagation of activation qparam
  {
    PropagateQParamBackwardPass pqbp(_ctx->output_model_dtype);
    pqbp.run(g);
  }

  // Quantize const input activation
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeConstInputActivation qcia(quantize_dtype(circle_node));
    circle_node->accept(&qcia);
  }

  // Update qparam of output of special Ops
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeSpecialActivation qsa(_ctx->input_model_dtype, quantize_dtype(circle_node));
    circle_node->accept(&qsa);
  }

  // Forward propagation of activation qparam
  // NOTE Scoped like the other phases so that its locals are not shadowed
  //      by the phase that runs later in this function.
  {
    logo::Phase phase;

    phase.emplace_back(std::make_unique<luci::PropagateQParamForwardPass>(_ctx->TF_style_maxpool));

    ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
    logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
    phase_runner.attach(&prog);
    phase_runner.run(phase);
  }

  // Quantize weights
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeWeights qw(_ctx->input_model_dtype, quantize_dtype(circle_node),
                       quantize_granularity(circle_node));
    circle_node->accept(&qw);
  }

  // Quantize bias
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    QuantizeBias qb(_ctx->input_model_dtype, quantize_dtype(circle_node),
                    quantize_granularity(circle_node));
    circle_node->accept(&qb);
  }

  // Update output dtype
  // An output follows the default quantization dtype only when the node
  // feeding it was actually quantized to that dtype.
  auto graph_outputs = g->outputs();
  for (auto node : loco::output_nodes(g))
  {
    auto circle_node = loco::must_cast<luci::CircleOutput *>(node);
    // Checked cast, as everywhere else in this file (was an unchecked static_cast)
    auto from = loco::must_cast<luci::CircleNode *>(circle_node->from());
    if (from->dtype() == _ctx->output_model_dtype)
    {
      circle_node->dtype(_ctx->output_model_dtype);
      auto graph_output = graph_outputs->at(circle_node->index());
      graph_output->dtype(_ctx->output_model_dtype);
    }
  }

  // Set input type
  set_input_type(g);

  // Set output type
  set_output_type(g);

  // Remove redundant Quantize Op
  // (set_input_type/set_output_type may have inserted new Quantize Ops)
  {
    logo::Phase phase;

    phase.emplace_back(std::make_unique<luci::RemoveRedundantQuantizePass>());

    ProgressReporter prog(g, logo::PhaseStrategy::Saturate);
    logo::PhaseRunner<logo::PhaseStrategy::Saturate> phase_runner{g};
    phase_runner.attach(&prog);
    phase_runner.run(phase);
  }

  // Remove min/max values
  for (auto node : loco::active_nodes(loco::output_nodes(g)))
  {
    auto circle_node = loco::must_cast<luci::CircleNode *>(node);
    if (auto qparam = circle_node->quantparam())
    {
      // The check needs min/max, so it runs before they are cleared
      warn_accuracy_with_range(circle_node);
      qparam->min.clear();
      qparam->max.clear();
    }
  }

  INFO(l) << "QuantizeWithMinMaxPass End" << std::endl;
  return false; // one time run
}

} // namespace luci