diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 968495e53..9abe8ecc3 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -93,7 +93,15 @@ loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_sha return window; } -class PlaneInference final +enum class Direction +{ + Forward, + Backward, +}; + +template <Direction> class PlaneInference; + +template <> class PlaneInference<Direction::Forward> final { public: PlaneShape operator()(const PlaneShape &in) const @@ -187,7 +195,7 @@ public: // CASE: AvgPool2D loco::NodeShape visit(const loco::AvgPool2D *node) final { - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(node->window()); @@ -238,7 +246,7 @@ public: auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>(); auto filter_window = window_of(filter_shape); - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(&filter_window); @@ -266,7 +274,7 @@ public: auto depthwise_filter_shape = node_shape(node->ker()).as<loco::DepthwiseFilterShape>(); auto dpethwise_filter_window = window_of(depthwise_filter_shape); - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(&dpethwise_filter_window); @@ -375,7 +383,7 @@ public: // CASE: MaxPool2D loco::NodeShape visit(const loco::MaxPool2D *node) final { - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(node->window()); |