summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/PropagateQParamBackwardPass.cpp')
-rw-r--r--compiler/luci/pass/src/PropagateQParamBackwardPass.cpp63
1 files changed, 63 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
index e8fa2a478..18617e3b7 100644
--- a/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
+++ b/compiler/luci/pass/src/PropagateQParamBackwardPass.cpp
@@ -28,6 +28,25 @@
namespace
{
+// Return true if node is a virtual node
+bool virtual_op(const luci::CircleOpcode opcode)
+{
+ switch (opcode)
+ {
+#define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ return false;
+#define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
+ case luci::CircleOpcode::OPCODE: \
+ return true;
+#include <luci/IR/CircleNodes.lst>
+#undef CIRCLE_NODE
+#undef CIRCLE_VNODE
+ default:
+ throw std::runtime_error("Unknown opcode detected");
+ }
+}
+
void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
loco::DataType quant_type)
{
@@ -448,6 +467,50 @@ struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<voi
void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
+
+ // Propagate qparam for non-value changing Ops
+ // (ex: Reshape, Transpose, etc.)
+ // TODO Add more Ops
+
+ void visit(luci::CircleReshape *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
+
+ // Do not propagate qparam if input node has multiple users
+ if (loco::succs(input_node).size() > 1)
+ return;
+
+ const auto input_opcode = input_node->opcode();
+
+ // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
+ // Why? It is not safe to propagate qparam to some virtual nodes. For example,
+ // const node, multi-out nodes. Let's block them for now.
+ // TODO Revisit this condition
+ if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
+ return;
+
+ overwrite_quantparam(node, input_node);
+ }
+
+ void visit(luci::CircleTranspose *node)
+ {
+ auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
+
+ // Do not propagate qparam if input node has multiple users
+ if (loco::succs(input_node).size() > 1)
+ return;
+
+ const auto input_opcode = input_node->opcode();
+
+ // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
+ // Why? It is not safe to propagate qparam to some virtual nodes. For example,
+ // const node, multi-out nodes. Let's block them for now.
+ // TODO Revisit this condition
+ if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
+ return;
+
+ overwrite_quantparam(node, input_node);
+ }
};
} // namespace