diff options
Diffstat (limited to 'compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp')
-rw-r--r-- | compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp index 2c990f0a5..bc09abee2 100644 --- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp +++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp @@ -22,6 +22,7 @@ #include <luci/Profile/CircleNodeOrigin.h> #include <luci/Service/CircleShapeInference.h> #include <luci/Service/Nodes/CircleConst.h> +#include <luci/Service/CircleNodeClone.h> namespace { @@ -55,6 +56,26 @@ void copy_shape(luci::CircleReshape *reshape, luci::CircleReshape *new_reshape) new_reshape->newShape()->dim(r) = reshape->newShape()->dim(r); } +luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + + luci::CircleConst *cloned_shape = clone_shape(reshape); + if (cloned_shape == nullptr) + return nullptr; + + auto cloned_node = luci::clone_node(reshape, reshape->graph()); + if (cloned_node == nullptr) + return nullptr; + + auto new_reshape = loco::must_cast<luci::CircleReshape *>(cloned_node); + new_reshape->shape(cloned_shape); + new_reshape->name(reshape->name() + "_C"); + luci::add_origin(new_reshape, luci::get_origin(reshape)); + + return new_reshape; +} + bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg) { assert(reshape != nullptr); @@ -85,6 +106,26 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg) return true; } +bool forward_reshape(luci::CircleReshape *reshape, luci::CircleLogistic *logit) +{ + assert(reshape != nullptr); // FIX_CALLER_UNLESS + assert(logit != nullptr); // FIX_CALLER_UNLESS + + auto new_reshape = create_cloned_reshape(reshape); + if (not new_reshape) + return false; + + // reconnect network + loco::replace(logit).with(new_reshape); + logit->x(reshape->tensor()); + new_reshape->tensor(logit); + + // Do shape inference for this node again. + logit->shape_status(luci::ShapeStatus::UNDEFINED); + + return true; +} + class ForwardReshape final : public luci::CircleNodeMutableVisitor<bool> { protected: @@ -103,6 +144,14 @@ protected: return forward_reshape(reshape, node); } + bool visit(luci::CircleLogistic *node) + { + auto reshape = as_reshape(node->x()); + if (reshape == nullptr) + return false; + + return forward_reshape(reshape, node); + } // TODO add more unary operators }; |