diff options
author | 박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com> | 2019-09-16 15:50:46 +0900 |
---|---|---|
committer | 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com> | 2019-09-16 15:50:46 +0900 |
commit | 32150fd2b33c75abd339b8e2f7ff9bfbee31ad60 (patch) | |
tree | 5d9b963c9d20a10e7c752f53c8df564ace9e70c1 /compiler | |
parent | c6ca30584d2621546ef05d7d2e3326364bef07e9 (diff) | |
download | nnfw-32150fd2b33c75abd339b8e2f7ff9bfbee31ad60.tar.gz nnfw-32150fd2b33c75abd339b8e2f7ff9bfbee31ad60.tar.bz2 nnfw-32150fd2b33c75abd339b8e2f7ff9bfbee31ad60.zip |
[loco] PlaneInference by direction (#7452)
This commit fixed PlaneInference into template with argument 'Direction'
Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 968495e53..9abe8ecc3 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -93,7 +93,15 @@ loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_sha return window; } -class PlaneInference final +enum class Direction +{ + Forward, + Backward, +}; + +template <Direction> class PlaneInference; + +template <> class PlaneInference<Direction::Forward> final { public: PlaneShape operator()(const PlaneShape &in) const @@ -187,7 +195,7 @@ public: // CASE: AvgPool2D loco::NodeShape visit(const loco::AvgPool2D *node) final { - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(node->window()); @@ -238,7 +246,7 @@ public: auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>(); auto filter_window = window_of(filter_shape); - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(&filter_window); @@ -266,7 +274,7 @@ public: auto depthwise_filter_shape = node_shape(node->ker()).as<loco::DepthwiseFilterShape>(); auto dpethwise_filter_window = window_of(depthwise_filter_shape); - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(&dpethwise_filter_window); @@ -375,7 +383,7 @@ public: // CASE: MaxPool2D loco::NodeShape visit(const loco::MaxPool2D *node) final { - PlaneInference infer_plane_shape; + PlaneInference<Direction::Forward> infer_plane_shape; infer_plane_shape.pad(node->pad()); infer_plane_shape.window(node->window()); |