summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 18:07:17 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 18:07:17 +0900
commitbdf28f49e816ac22193073be5da1acae0cc655e6 (patch)
tree9e1c4fe3a733cb29a7e3bf4fbbd69b5a7567ef70
parentde4722b761d55b30cd2f0e51e8a5fa090f3c349e (diff)
downloadnnfw-bdf28f49e816ac22193073be5da1acae0cc655e6.tar.gz
nnfw-bdf28f49e816ac22193073be5da1acae0cc655e6.tar.bz2
nnfw-bdf28f49e816ac22193073be5da1acae0cc655e6.zip
[moco-tf] copy shape for binary input nodes (#7525)
This will update FixShape to use correct copy_shapedata() for binary input nodes Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r--compiler/moco-tf/src/Transforms/FixShapeTransform.cpp18
1 files changed, 6 insertions, 12 deletions
diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
index 818a7f7da..934f5793a 100644
--- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
+++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp
@@ -546,10 +546,9 @@ bool fix_shape(moco::tf::TFAdd *node)
return false;
if (!node_shape(y, y_shape))
return false;
- // TODO check shape difference
// Output shape is same as the input
- return copy_shapedata(x, node);
+ return copy_shapedata(x, y, node);
}
bool fix_shape(moco::tf::TFAvgPool *node)
@@ -1102,10 +1101,9 @@ bool fix_shape(moco::tf::TFMul *node)
return false;
if (!node_shape(y, y_shape))
return false;
- // TODO check shape difference
// Output shape is same as the input
- return copy_shapedata(x, node);
+ return copy_shapedata(x, y, node);
}
bool fix_shape(moco::tf::TFMean *node)
@@ -1194,10 +1192,8 @@ bool fix_shape(moco::tf::TFRealDiv *node)
if (!node_shape(y, y_shape))
return false;
- // TODO check shape difference
-
// Output shape is same as the input
- return copy_shapedata(x, node);
+ return copy_shapedata(x, y, node);
}
bool fix_shape(moco::tf::TFRelu *node)
@@ -1324,9 +1320,9 @@ bool fix_shape(moco::tf::TFSoftmax *node)
bool fix_shape(moco::tf::TFSquaredDifference *node)
{
- // Output shape is same as the input x
auto x = node->x();
- return copy_shapedata(x, node);
+ auto y = node->y();
+ return copy_shapedata(x, y, node);
}
bool fix_shape(moco::tf::TFSqueeze *node)
@@ -1451,10 +1447,8 @@ bool fix_shape(moco::tf::TFSub *node)
if (!node_shape(y, y_shape))
return false;
- // TODO check shape difference
-
// Output shape is same as the input
- return copy_shapedata(x, node);
+ return copy_shapedata(x, y, node);
}
bool fix_shape(moco::tf::TFTanh *node)