summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-16 13:32:34 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-16 13:32:34 +0900
commit7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d (patch)
tree1bdd81f00beca451163c82f8ef283fb1f555623e
parentb732c95fef413cc5fd4d7eab34c0d8ee66386ea5 (diff)
downloadnnfw-7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d.tar.gz
nnfw-7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d.tar.bz2
nnfw-7087e0fcdd6c32582a4cbb9e2aa9091a6cc2c49d.zip
[exo-tflite] shape inference for broadcast and TFLAdd (#7429)
* [exo-tflite] shape inference for broadcast This will add helper functions for broadcast and shape inference of TFLAdd using them Signed-off-by: SaeHie Park <saehie.park@samsung.com> * make compiler happy * fix message
-rw-r--r--compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp73
1 files changed, 72 insertions, 1 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
index dd8fe85fc..e8df92ad7 100644
--- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
+++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
@@ -24,7 +24,9 @@
#include "Check.h"
+#include <algorithm>
#include <cassert>
+#include <stdexcept>
namespace
{
@@ -113,6 +115,67 @@ private:
};
/**
+ * @breif Expand shape x and y to same rank by align right and filling with 1
+ */
+void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
+{
+ auto x_rank = x.rank();
+ auto y_rank = y.rank();
+
+ if (x_rank == y_rank)
+ return;
+
+ TensorShapeExpander x_exp(x);
+ TensorShapeExpander y_exp(y);
+
+ auto xy_rank = std::max(x_rank, y_rank);
+
+ x = x_rank > y_rank ? x : x_exp.to(xy_rank);
+ y = y_rank > x_rank ? y : y_exp.to(xy_rank);
+}
+
+/**
+ * @breif Returns shape of expanded dimension of input x and y having same rank
+ */
+loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ assert(x.rank() == y.rank());
+
+ auto rank = x.rank();
+
+ loco::TensorShape output_shape;
+
+ output_shape.rank(rank);
+ for (uint32_t axis = 0; axis < rank; ++axis)
+ {
+ assert(x.dim(axis).known() && y.dim(axis).known());
+
+ auto x_dim = x.dim(axis).value();
+ auto y_dim = y.dim(axis).value();
+
+ // each dimension of x and y should be same or one must be 1 if different
+ if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
+ throw std::runtime_error("Cannot produce expand_dimension of two shapes");
+
+ output_shape.dim(axis) = std::max(x_dim, y_dim);
+ }
+
+ return output_shape;
+}
+
+loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
+{
+ auto x_match = x;
+ auto y_match = y;
+
+ expand_rank(x_match, y_match);
+
+ auto output_shape = expand_dimension(x_match, y_match);
+
+ return output_shape;
+}
+
+/**
* @brief Class to infer the shape of TFLNode
*
* @note All TFLNode's inputs and outouts are always loco::Domain::Tensor
@@ -135,7 +198,15 @@ public:
}
}
- // TFLAdd
+ loco::NodeShape visit(const locoex::TFLAdd *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
loco::NodeShape visit(const locoex::TFLAveragePool2D *node) final
{