diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-16 13:32:34 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-16 13:32:34 +0900 |
commit | 7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d (patch) | |
tree | 1bdd81f00beca451163c82f8ef283fb1f555623e | |
parent | b732c95fef413cc5fd4d7eab34c0d8ee66386ea5 (diff) | |
download | nnfw-7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d.tar.gz nnfw-7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d.tar.bz2 nnfw-7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d.zip |
[exo-tflite] shape inference for broadcast and TFLAdd (#7429)
* [exo-tflite] shape inference for broadcast
This will add helper functions for broadcast and shape inference of TFLAdd using them
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* make compiler happy
* fix message
-rw-r--r-- | compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp | 73 |
1 files changed, 72 insertions, 1 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp index dd8fe85fc..e8df92ad7 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -24,7 +24,9 @@ #include "Check.h" +#include <algorithm> #include <cassert> +#include <stdexcept> namespace { @@ -113,6 +115,67 @@ private: }; /** + * @breif Expand shape x and y to same rank by align right and filling with 1 + */ +void expand_rank(loco::TensorShape &x, loco::TensorShape &y) +{ + auto x_rank = x.rank(); + auto y_rank = y.rank(); + + if (x_rank == y_rank) + return; + + TensorShapeExpander x_exp(x); + TensorShapeExpander y_exp(y); + + auto xy_rank = std::max(x_rank, y_rank); + + x = x_rank > y_rank ? x : x_exp.to(xy_rank); + y = y_rank > x_rank ? y : y_exp.to(xy_rank); +} + +/** + * @breif Returns shape of expanded dimension of input x and y having same rank + */ +loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y) +{ + assert(x.rank() == y.rank()); + + auto rank = x.rank(); + + loco::TensorShape output_shape; + + output_shape.rank(rank); + for (uint32_t axis = 0; axis < rank; ++axis) + { + assert(x.dim(axis).known() && y.dim(axis).known()); + + auto x_dim = x.dim(axis).value(); + auto y_dim = y.dim(axis).value(); + + // each dimension of x and y should be same or one must be 1 if different + if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1))) + throw std::runtime_error("Cannot produce expand_dimension of two shapes"); + + output_shape.dim(axis) = std::max(x_dim, y_dim); + } + + return output_shape; +} + +loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y) +{ + auto x_match = x; + auto y_match = y; + + expand_rank(x_match, y_match); + + auto output_shape = expand_dimension(x_match, y_match); + + return output_shape; +} + +/** * @brief Class to infer the shape of TFLNode * * @note All TFLNode's inputs and outouts are always loco::Domain::Tensor @@ -135,7 +198,15 @@ public: } } - // TFLAdd + loco::NodeShape visit(const locoex::TFLAdd *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final { |