From 7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princi?= =?UTF-8?q?pal=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 16 Sep 2019 13:32:34 +0900 Subject: [exo-tflite] shape inference for broadcast and TFLAdd (#7429) * [exo-tflite] shape inference for broadcast This will add helper functions for broadcast and shape inference of TFLAdd using them Signed-off-by: SaeHie Park * make compiler happy * fix message --- .../src/Dialect/Service/TFLShapeInferenceRule.cpp | 73 +++++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp index dd8fe85fc..e8df92ad7 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -24,7 +24,9 @@ #include "Check.h" +#include #include +#include namespace { @@ -112,6 +114,67 @@ private: const loco::TensorShape _shape; }; +/** + * @breif Expand shape x and y to same rank by align right and filling with 1 + */ +void expand_rank(loco::TensorShape &x, loco::TensorShape &y) +{ + auto x_rank = x.rank(); + auto y_rank = y.rank(); + + if (x_rank == y_rank) + return; + + TensorShapeExpander x_exp(x); + TensorShapeExpander y_exp(y); + + auto xy_rank = std::max(x_rank, y_rank); + + x = x_rank > y_rank ? x : x_exp.to(xy_rank); + y = y_rank > x_rank ? y : y_exp.to(xy_rank); +} + +/** + * @breif Returns shape of expanded dimension of input x and y having same rank + */ +loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y) +{ + assert(x.rank() == y.rank()); + + auto rank = x.rank(); + + loco::TensorShape output_shape; + + output_shape.rank(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + { + assert(x.dim(axis).known() && y.dim(axis).known()); + + auto x_dim = x.dim(axis).value(); + auto y_dim = y.dim(axis).value(); + + // each dimension of x and y should be same or one must be 1 if different + if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1))) + throw std::runtime_error("Cannot produce expand_dimension of two shapes"); + + output_shape.dim(axis) = std::max(x_dim, y_dim); + } + + return output_shape; +} + +loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y) +{ + auto x_match = x; + auto y_match = y; + + expand_rank(x_match, y_match); + + auto output_shape = expand_dimension(x_match, y_match); + + return output_shape; +} + /** * @brief Class to infer the shape of TFLNode * @@ -135,7 +198,15 @@ public: } } - // TFLAdd + loco::NodeShape visit(const locoex::TFLAdd *node) final + { + auto x_shape = loco::shape_get(node->x()).as(); + auto y_shape = loco::shape_get(node->y()).as(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final { -- cgit v1.2.3