diff options
author | 박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com> | 2019-09-18 08:33:09 +0900 |
---|---|---|
committer | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-18 08:33:09 +0900 |
commit | 3e621afddd5eb99e7d1fcfb3f3c253698f833461 (patch) | |
tree | 4b5ea87d7dade079032d5bc654b38febe339fa71 | |
parent | 6601e7963fa4a356a7982e55eed79b751598921c (diff) | |
download | nnfw-3e621afddd5eb99e7d1fcfb3f3c253698f833461.tar.gz nnfw-3e621afddd5eb99e7d1fcfb3f3c253698f833461.tar.bz2 nnfw-3e621afddd5eb99e7d1fcfb3f3c253698f833461.zip |
[loco] Backward plane inference (#7458)
This commit introduces backward plane inference, which would be used for
canonical shape inference rule, especially for TransposedConv2D.
Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
-rw-r--r-- | compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 591b02450..59340076d 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -151,6 +151,50 @@ private: const loco::Stride<2> *_stride = nullptr; }; +template <> class PlaneInference<Direction::Backward> final +{ +public: + PlaneShape operator()(const PlaneShape &in) const + { + assert(_pad != nullptr); + assert(_window != nullptr); + assert(_stride != nullptr); + + uint32_t const input_height = in.height.value(); + uint32_t const input_width = in.width.value(); + + uint32_t const vertical_padding = _pad->top() + _pad->bottom(); + uint32_t const horizontal_padding = _pad->left() + _pad->right(); + + uint32_t const raw_window_height = _window->vertical(); + uint32_t const raw_window_width = _window->horizontal(); + + // TODO Support "dilation" + uint32_t const effective_window_height = raw_window_height; + uint32_t const effective_window_width = raw_window_width; + + uint32_t const vertical_stride = _stride->vertical(); + uint32_t const horizontal_stride = _stride->horizontal(); + + PlaneShape res; + + res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding; + res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding; + + return res; + } + +public: + void pad(const loco::Padding2D *value) { _pad = value; } + void window(const loco::Window<2> *value) { _window = value; } + void stride(const loco::Stride<2> *value) { _stride = value; } + +private: + const loco::Padding2D *_pad = nullptr; + const loco::Window<2> *_window = nullptr; + const loco::Stride<2> *_stride = nullptr; +}; + /** * There are two possible maintenance policies. * - Introduce a new canonical node first, and then extend this algorithm later |