summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp44
1 files changed, 44 insertions, 0 deletions
diff --git a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
index 591b02450..59340076d 100644
--- a/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
+++ b/compiler/loco/src/Service/CanonicalShapeInferenceRule.cpp
@@ -151,6 +151,50 @@ private:
const loco::Stride<2> *_stride = nullptr;
};
+template <> class PlaneInference<Direction::Backward> final
+{
+public:
+ PlaneShape operator()(const PlaneShape &in) const
+ {
+ assert(_pad != nullptr);
+ assert(_window != nullptr);
+ assert(_stride != nullptr);
+
+ uint32_t const input_height = in.height.value();
+ uint32_t const input_width = in.width.value();
+
+ uint32_t const vertical_padding = _pad->top() + _pad->bottom();
+ uint32_t const horizontal_padding = _pad->left() + _pad->right();
+
+ uint32_t const raw_window_height = _window->vertical();
+ uint32_t const raw_window_width = _window->horizontal();
+
+ // TODO Support "dilation"
+ uint32_t const effective_window_height = raw_window_height;
+ uint32_t const effective_window_width = raw_window_width;
+
+ uint32_t const vertical_stride = _stride->vertical();
+ uint32_t const horizontal_stride = _stride->horizontal();
+
+ PlaneShape res;
+
+ res.height = vertical_stride * (input_height - 1) + effective_window_height - vertical_padding;
+ res.width = horizontal_stride * (input_width - 1) + effective_window_width - horizontal_padding;
+
+ return res;
+ }
+
+public:
+ void pad(const loco::Padding2D *value) { _pad = value; }
+ void window(const loco::Window<2> *value) { _window = value; }
+ void stride(const loco::Stride<2> *value) { _stride = value; }
+
+private:
+ const loco::Padding2D *_pad = nullptr;
+ const loco::Window<2> *_window = nullptr;
+ const loco::Stride<2> *_stride = nullptr;
+};
+
/**
* There are two possible maintenance policies.
* - Introduce a new canonical node first, and then extend this algorithm later