summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>2019-09-16 15:50:46 +0900
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>2019-09-16 15:50:46 +0900
commit32150fd2b33c75abd339b8e2f7ff9bfbee31ad60 (patch)
tree5d9b963c9d20a10e7c752f53c8df564ace9e70c1 /compiler
parentc6ca30584d2621546ef05d7d2e3326364bef07e9 (diff)
downloadnnfw-32150fd2b33c75abd339b8e2f7ff9bfbee31ad60.tar.gz
nnfw-32150fd2b33c75abd339b8e2f7ff9bfbee31ad60.tar.bz2
nnfw-32150fd2b33c75abd339b8e2f7ff9bfbee31ad60.zip
[loco] PlaneInference by direction (#7452)
This commit fixed PlaneInference into template with argument 'Direction' Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
Diffstat (limited to 'compiler')
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp18
1 files changed, 13 insertions, 5 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
index 968495e53..9abe8ecc3 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -93,7 +93,15 @@ loco::Window<2> window_of(const loco::DepthwiseFilterShape &depthwise_filter_sha
return window;
}
-class PlaneInference final
+enum class Direction
+{
+ Forward,
+ Backward,
+};
+
+template <Direction> class PlaneInference;
+
+template <> class PlaneInference<Direction::Forward> final
{
public:
PlaneShape operator()(const PlaneShape &in) const
@@ -187,7 +195,7 @@ public:
// CASE: AvgPool2D
loco::NodeShape visit(const loco::AvgPool2D *node) final
{
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(node->window());
@@ -238,7 +246,7 @@ public:
auto filter_shape = node_shape(node->ker()).as<loco::FilterShape>();
auto filter_window = window_of(filter_shape);
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(&filter_window);
@@ -266,7 +274,7 @@ public:
auto depthwise_filter_shape = node_shape(node->ker()).as<loco::DepthwiseFilterShape>();
auto dpethwise_filter_window = window_of(depthwise_filter_shape);
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(&dpethwise_filter_window);
@@ -375,7 +383,7 @@ public:
// CASE: MaxPool2D
loco::NodeShape visit(const loco::MaxPool2D *node) final
{
- PlaneInference infer_plane_shape;
+ PlaneInference<Direction::Forward> infer_plane_shape;
infer_plane_shape.pad(node->pad());
infer_plane_shape.window(node->window());