diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-17 16:59:34 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-17 16:59:34 +0900 |
commit | 41e9f2ac30b469f79e1f6e01f651b768950651ff (patch) | |
tree | c13a1d758df8935c07d6172b4edeffe58a0bacdd | |
parent | 421490ad05ba56798ce92b2e10dd60fbae168b2b (diff) | |
download | nnfw-41e9f2ac30b469f79e1f6e01f651b768950651ff.tar.gz nnfw-41e9f2ac30b469f79e1f6e01f651b768950651ff.tar.bz2 nnfw-41e9f2ac30b469f79e1f6e01f651b768950651ff.zip |
[moco-tf] copy_shapedata with two inputs (#7499)
* [moco-tf] copy_shapedata with two inputs
This will introduce copy_shapedata that can receive two inputs and return shape from broadcast algorithm for binary operation nodes
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
* add note
-rw-r--r-- | compiler/moco-tf/src/Transforms/FixShapeTransform.cpp | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 17de7c262..818a7f7da 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -285,6 +285,64 @@ bool copy_shapedata(const loco::Node *src, loco::Node *dst) } /** + * @note This will find broadcast shape from two inputs lhs and rhs using + * broadcast_shape() and return that shape to dst + */ +bool copy_shapedata(const loco::Node *lhs, const loco::Node *rhs, loco::Node *dst) +{ + // if dst already has ShapeInferenceData, skip + if (shape_inference_done(dst)) + return false; + + loco::NodeShape lhs_shape; + loco::NodeShape rhs_shape; + + if (loco::shape_known(lhs)) + { + lhs_shape = loco::shape_get(lhs); + } + else + { + if (!shape_inference_done(lhs)) + return false; + + lhs_shape = as_node_shape(lhs->annot<ShapeInferenceData>()); + } + + if (loco::shape_known(rhs)) + { + rhs_shape = loco::shape_get(rhs); + } + else + { + if (!shape_inference_done(rhs)) + return false; + + rhs_shape = as_node_shape(rhs->annot<ShapeInferenceData>()); + } + + if (lhs_shape.domain() != loco::Domain::Tensor || rhs_shape.domain() != loco::Domain::Tensor) + { + throw std::runtime_error("copy_shapedata supports only for Tensor"); + } + + loco::TensorShape lhs_tensorshape = lhs_shape.as<loco::TensorShape>(); + loco::TensorShape rhs_tensorshape = rhs_shape.as<loco::TensorShape>(); + loco::TensorShape sum_tensorshape = broadcast_shape(lhs_tensorshape, rhs_tensorshape); + + loco::NodeShape sum_shape({sum_tensorshape}); + auto shape_data = make_shape_inference_data(sum_shape); + dst->annot(std::move(shape_data)); + + LOGGER(l); + + INFO(l) << "copy_shapedata " << lhs_tensorshape << " or " << rhs_tensorshape << " -> " + << sum_tensorshape << std::endl; + + return true; +} + +/** * @note While in shape inference, Node maybe Canonical, TF dialect or other dialects * This will provide common loco::NodeShape as shape information */ |