summaryrefslogtreecommitdiff
path: root/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp')
-rw-r--r--compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp49
1 files changed, 49 insertions, 0 deletions
diff --git a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
index 2c990f0a5..bc09abee2 100644
--- a/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
+++ b/compiler/luci/pass/src/ForwardReshapeToUnaryOpPass.cpp
@@ -22,6 +22,7 @@
#include <luci/Profile/CircleNodeOrigin.h>
#include <luci/Service/CircleShapeInference.h>
#include <luci/Service/Nodes/CircleConst.h>
+#include <luci/Service/CircleNodeClone.h>
namespace
{
@@ -55,6 +56,26 @@ void copy_shape(luci::CircleReshape *reshape, luci::CircleReshape *new_reshape)
new_reshape->newShape()->dim(r) = reshape->newShape()->dim(r);
}
+luci::CircleReshape *create_cloned_reshape(luci::CircleReshape *reshape)
+{
+ assert(reshape != nullptr); // FIX_CALLER_UNLESS
+
+ luci::CircleConst *cloned_shape = clone_shape(reshape);
+ if (cloned_shape == nullptr)
+ return nullptr;
+
+ auto cloned_node = luci::clone_node(reshape, reshape->graph());
+ if (cloned_node == nullptr)
+ return nullptr;
+
+ auto new_reshape = loco::must_cast<luci::CircleReshape *>(cloned_node);
+ new_reshape->shape(cloned_shape);
+ new_reshape->name(reshape->name() + "_C");
+ luci::add_origin(new_reshape, luci::get_origin(reshape));
+
+ return new_reshape;
+}
+
bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg)
{
assert(reshape != nullptr);
@@ -85,6 +106,26 @@ bool forward_reshape(luci::CircleReshape *reshape, luci::CircleNeg *neg)
return true;
}
+bool forward_reshape(luci::CircleReshape *reshape, luci::CircleLogistic *logit)
+{
+ assert(reshape != nullptr); // FIX_CALLER_UNLESS
+ assert(logit != nullptr); // FIX_CALLER_UNLESS
+
+ auto new_reshape = create_cloned_reshape(reshape);
+ if (not new_reshape)
+ return false;
+
+ // reconnect network
+ loco::replace(logit).with(new_reshape);
+ logit->x(reshape->tensor());
+ new_reshape->tensor(logit);
+
+ // Do shape inference for this node again.
+ logit->shape_status(luci::ShapeStatus::UNDEFINED);
+
+ return true;
+}
+
class ForwardReshape final : public luci::CircleNodeMutableVisitor<bool>
{
protected:
@@ -103,6 +144,14 @@ protected:
return forward_reshape(reshape, node);
}
+ bool visit(luci::CircleLogistic *node)
+ {
+ auto reshape = as_reshape(node->x());
+ if (reshape == nullptr)
+ return false;
+
+ return forward_reshape(reshape, node);
+ }
// TODO add more unary operators
};