diff options
-rw-r--r-- | compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp index 591b02450..59340076d 100644 --- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp +++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp @@ -151,6 +151,50 @@ private: const loco::Stride<2> *_stride = nullptr; }; +template <> class PlaneInference<Direction::Backward> final +{ +public: + PlaneShape operator()(const PlaneShape &in) const + { + assert(_pad != nullptr); + assert(_window != nullptr); + assert(_stride != nullptr); + + uint32_t const input_height = in.height.value(); + uint32_t const input_width = in.width.value(); + + uint32_t const vertical_padding = _pad->top() + _pad->bottom(); + uint32_t const horizontal_padding = _pad->left() + _pad->right(); + + uint32_t const raw_window_height = _window->vertical(); + uint32_t const raw_window_width = _window->horizontal(); + + // TODO Support "dilation" + uint32_t const effective_window_height = raw_window_height; + uint32_t const effective_window_width = raw_window_width; + + uint32_t const vertical_stride = _stride->vertical(); + uint32_t const horizontal_stride = _stride->horizontal(); + + PlaneShape res; + + res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding; + res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding; + + return res; + } + +public: + void pad(const loco::Padding2D *value) { _pad = value; } + void window(const loco::Window<2> *value) { _window = value; } + void stride(const loco::Stride<2> *value) { _stride = value; } + +private: + const loco::Padding2D *_pad = nullptr; + const loco::Window<2> *_window = nullptr; + const loco::Stride<2> *_stride = nullptr; +}; + /** * There are two possible maintenance policies. * - Introduce a new canonical node first, and then extend this algorithm later |