From bdf28f49e816ac22193073be5da1acae0cc655e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princi?= =?UTF-8?q?pal=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Tue, 17 Sep 2019 18:07:17 +0900 Subject: [moco-tf] copy shape for binary input nodes (#7525) This will update FixShape to use correct copy_shapedata() for binary input nodes Signed-off-by: SaeHie Park --- compiler/moco-tf/src/Transforms/FixShapeTransform.cpp | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp index 818a7f7da..934f5793a 100644 --- a/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp +++ b/compiler/moco-tf/src/Transforms/FixShapeTransform.cpp @@ -546,10 +546,9 @@ bool fix_shape(moco::tf::TFAdd *node) return false; if (!node_shape(y, y_shape)) return false; - // TODO check shape difference // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFAvgPool *node) @@ -1102,10 +1101,9 @@ bool fix_shape(moco::tf::TFMul *node) return false; if (!node_shape(y, y_shape)) return false; - // TODO check shape difference // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFMean *node) @@ -1194,10 +1192,8 @@ bool fix_shape(moco::tf::TFRealDiv *node) if (!node_shape(y, y_shape)) return false; - // TODO check shape difference - // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFRelu *node) @@ -1324,9 +1320,9 @@ bool fix_shape(moco::tf::TFSoftmax *node) bool fix_shape(moco::tf::TFSquaredDifference *node) { - // Output shape is same as the input x auto x = node->x(); - return copy_shapedata(x, node); + auto y = node->y(); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFSqueeze *node) @@ -1451,10 +1447,8 @@ bool fix_shape(moco::tf::TFSub *node) if (!node_shape(y, y_shape)) return false; - // TODO check shape difference - // Output shape is same as the input - return copy_shapedata(x, node); + return copy_shapedata(x, y, node); } bool fix_shape(moco::tf::TFTanh *node) -- cgit v1.2.3