diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-17 18:07:17 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-17 18:07:17 +0900 |
commit | bdf28f49e816ac22193073be5da1acae0cc655e6 (patch) | |
tree | 9e1c4fe3a733cb29a7e3bf4fbbd69b5a7567ef70 | |
parent | de4722b761d55b30cd2f0e51e8a5fa090f3c349e (diff) | |
download | nnfw-bdf28f49e816ac22193073be5da1acae0cc655e6.tar.gz nnfw-bdf28f49e816ac22193073be5da1acae0cc655e6.tar.bz2 nnfw-bdf28f49e816ac22193073be5da1acae0cc655e6.zip |
[moco-tf] copy shape for binary input nodes (#7525)
This will update FixShape to use correct copy_shapedata() for binary input nodes
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r-- | compiler/moco-tf/src/Transforms/FixShapeTransform.cpp | 18 |
1 files changed, 6 insertions, 12 deletions
diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 818a7f7da..934f5793a 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -546,10 +546,9 @@ bool fix_shape(moco::tf::TFAdd *node) return false; if (!node_shape(y, y_shape)) return false; - // TODO check shape difference // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFAvgPool *node) @@ -1102,10 +1101,9 @@ bool fix_shape(moco::tf::TFMul *node) return false; if (!node_shape(y, y_shape)) return false; - // TODO check shape difference // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFMean *node) @@ -1194,10 +1192,8 @@ bool fix_shape(moco::tf::TFRealDiv *node) if (!node_shape(y, y_shape)) return false; - // TODO check shape difference - // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFRelu *node) @@ -1324,9 +1320,9 @@ bool fix_shape(moco::tf::TFSoftmax *node) bool fix_shape(moco::tf::TFSquaredDifference *node) { - // Output shape is same as the input x auto x = node->x(); - return copy_shapedata(x, node); + auto y = node->y(); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFSqueeze *node) @@ -1451,10 +1447,8 @@ bool fix_shape(moco::tf::TFSub *node) if (!node_shape(y, y_shape)) return false; - // TODO check shape difference - // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFTanh *node) |