summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>2019-09-17 23:33:09 (GMT)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 23:33:09 (GMT)
commit3e621afddd5eb99e7d1fcfb3f3c253698f833461 (patch)
tree4b5ea87d7dade079032d5bc654b38febe339fa71
parent6601e7963fa4a356a7982e55eed79b751598921c (diff)
downloadnnfw-3e621afddd5eb99e7d1fcfb3f3c253698f833461.zip
nnfw-3e621afddd5eb99e7d1fcfb3f3c253698f833461.tar.gz
nnfw-3e621afddd5eb99e7d1fcfb3f3c253698f833461.tar.bz2
[loco] Backward plane inference (#7458)
This commit introduces backward plane inference, which would be used for canonical shape inference rule, especially for TransposedConv2D. Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp44
1 files changed, 44 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
index 591b024..5934007 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -151,6 +151,50 @@ private:
const loco::Stride<2> *_stride = nullptr;
};
+template <> class PlaneInference<Direction::Backward> final
+{
+public:
+ PlaneShape operator()(const PlaneShape &in) const
+ {
+ assert(_pad != nullptr);
+ assert(_window != nullptr);
+ assert(_stride != nullptr);
+
+ uint32_t const input_height = in.height.value();
+ uint32_t const input_width = in.width.value();
+
+ uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+ uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+ uint32_t const raw_window_height = _window->vertical();
+ uint32_t const raw_window_width = _window->horizontal();
+
+ // TODO Support "dilation"
+ uint32_t const effective_window_height = raw_window_height;
+ uint32_t const effective_window_width = raw_window_width;
+
+ uint32_t const vertical_stride = _stride->vertical();
+ uint32_t const horizontal_stride = _stride->horizontal();
+
+ PlaneShape res;
+
+ res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding;
+ res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding;
+
+ return res;
+ }
+
+public:
+ void pad(const loco::Padding2D *value) { _pad = value; }
+ void window(const loco::Window<2> *value) { _window = value; }
+ void stride(const loco::Stride<2> *value) { _stride = value; }
+
+private:
+ const loco::Padding2D *_pad = nullptr;
+ const loco::Window<2> *_window = nullptr;
+ const loco::Stride<2> *_stride = nullptr;
+};
+
/**
* There are two possible maintenance policies.
* - Introduce a new canonical node first, and then extend this algorithm later