summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp18
1 files changed, 13 insertions, 5 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
index 968495e53..9abe8ecc3 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -93,7 +93,15 @@ loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_sha
return window;
}
-class PlaneInference final
+enum class Direction
+{
+ Forward,
+ Backward,
+};
+
+template <Direction> class PlaneInference;
+
+template <> class PlaneInference<Direction::Forward> final
{
public:
PlaneShape operator()(const PlaneShape &in) const
@@ -187,7 +195,7 @@ public:
// CASE: AvgPool2D
loco::NodeShape visit(const loco::AvgPool2D *node) final
{
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(node->window());
@@ -238,7 +246,7 @@ public:
auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>();
auto filter_window = window_of(filter_shape);
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(&filter_window);
@@ -266,7 +274,7 @@ public:
auto depthwise_filter_shape = node_shape(node->ker()).as<loco::DepthwiseFilterShape>();
auto dpethwise_filter_window = window_of(depthwise_filter_shape);
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(&dpethwise_filter_window);
@@ -375,7 +383,7 @@ public:
// CASE: MaxPool2D
loco::NodeShape visit(const loco::MaxPool2D *node) final
{
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(node->window());