summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 16:59:34 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 16:59:34 +0900
commit41e9f2ac30b469f79e1f6e01f651b768950651ff (patch)
treec13a1d758df8935c07d6172b4edeffe58a0bacdd
parent421490ad05ba56798ce92b2e10dd60fbae168b2b (diff)
downloadnnfw-41e9f2ac30b469f79e1f6e01f651b768950651ff.tar.gz
nnfw-41e9f2ac30b469f79e1f6e01f651b768950651ff.tar.bz2
nnfw-41e9f2ac30b469f79e1f6e01f651b768950651ff.zip
[moco-tf] copy_shapedata with two inputs (#7499)
* [moco-tf] copy_shapedata with two inputs This will introduce copy_shapedata that can receive two inputs and return shape from broadcast algorithm for binary operation nodes Signed-off-by: SaeHie Park <saehie.park@samsung.com> * add note
-rw-r--r--compiler/moco-tf/src/Transforms/FixShapeTransform.cpp58
1 files changed, 58 insertions, 0 deletions
diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
index 17de7c262..818a7f7da 100644
--- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
+++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
@@ -285,6 +285,64 @@ bool copy_shapedata(const loco::Node *src, loco::Node *dst)
}
/**
+ * @note This will find broadcast shape from two inputs lhs and rhs using
+ * broadcast_shape() and return that shape to dst
+ */
+bool copy_shapedata(const loco::Node *lhs, const loco::Node *rhs, loco::Node *dst)
+{
+ // if dst already has ShapeInferenceData, skip
+ if (shape_inference_done(dst))
+ return false;
+
+ loco::NodeShape lhs_shape;
+ loco::NodeShape rhs_shape;
+
+ if (loco::shape_known(lhs))
+ {
+ lhs_shape = loco::shape_get(lhs);
+ }
+ else
+ {
+ if (!shape_inference_done(lhs))
+ return false;
+
+ lhs_shape = as_node_shape(lhs->annot<ShapeInferenceData>());
+ }
+
+ if (loco::shape_known(rhs))
+ {
+ rhs_shape = loco::shape_get(rhs);
+ }
+ else
+ {
+ if (!shape_inference_done(rhs))
+ return false;
+
+ rhs_shape = as_node_shape(rhs->annot<ShapeInferenceData>());
+ }
+
+ if (lhs_shape.domain() != loco::Domain::Tensor || rhs_shape.domain() != loco::Domain::Tensor)
+ {
+ throw std::runtime_error("copy_shapedata supports only for Tensor");
+ }
+
+ loco::TensorShape lhs_tensorshape = lhs_shape.as<loco::TensorShape>();
+ loco::TensorShape rhs_tensorshape = rhs_shape.as<loco::TensorShape>();
+ loco::TensorShape sum_tensorshape = broadcast_shape(lhs_tensorshape, rhs_tensorshape);
+
+ loco::NodeShape sum_shape({sum_tensorshape});
+ auto shape_data = make_shape_inference_data(sum_shape);
+ dst->annot(std::move(shape_data));
+
+ LOGGER(l);
+
+ INFO(l) << "copy_shapedata " << lhs_tensorshape << " or " << rhs_tensorshape << " -> "
+ << sum_tensorshape << std::endl;
+
+ return true;
+}
+
+/**
* @note While in shape inference, Node maybe Canonical, TF dialect or other dialects
* This will provide common loco::NodeShape as shape information
*/