diff options
Diffstat (limited to 'compiler/luci/pass/src/PropagateQParamBackwardPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/PropagateQParamBackwardPass.cpp | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp index e8fa2a478..18617e3b7 100644 --- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp +++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp @@ -28,6 +28,25 @@ namespace { +// Return true if node is a virtual node +bool virtual_op(const luci::CircleOpcode opcode) +{ + switch (opcode) + { +#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \ + case luci::CircleOpcode::OPCODE: \ + return false; +#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \ + case luci::CircleOpcode::OPCODE: \ + return true; +#include <luci/IR/CircleNodes.lst> +#undef CIRCLE_NODE +#undef CIRCLE_VNODE + default: + throw std::runtime_error("Unknown opcode detected"); + } +} + void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop, loco::DataType quant_type) { @@ -448,6 +467,50 @@ struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<voi void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); } void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); } + + // Propagate qparam for non-value changing Ops + // (ex: Reshape, Transpose, etc.) + // TODO Add more Ops + + void visit(luci::CircleReshape *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor()); + + // Do not propagate qparam if input node has multiple users + if (loco::succs(input_node).size() > 1) + return; + + const auto input_opcode = input_node->opcode(); + + // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT) + // Why? It is not safe to propagate qparam to some virtual nodes. For example, + // const node, multi-out nodes. Let's block them for now. + // TODO Revisit this condition + if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT) + return; + + overwrite_quantparam(node, input_node); + } + + void visit(luci::CircleTranspose *node) + { + auto input_node = loco::must_cast<luci::CircleNode *>(node->a()); + + // Do not propagate qparam if input node has multiple users + if (loco::succs(input_node).size() > 1) + return; + + const auto input_opcode = input_node->opcode(); + + // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT) + // Why? It is not safe to propagate qparam to some virtual nodes. For example, + // const node, multi-out nodes. Let's block them for now. + // TODO Revisit this condition + if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT) + return; + + overwrite_quantparam(node, input_node); + } }; } // namespace |