summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2019-09-17 07:01:54 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2019-09-17 07:01:54 +0900
commit93d85e1d258772b6fab86946bc1c4cfe60f18458 (patch)
tree1bb5a17c9dea8cecda8016a6f45391a961b7c248
parentd96fbc0347d2f3455209e72e4478ee6840b58b4f (diff)
downloadnnfw-93d85e1d258772b6fab86946bc1c4cfe60f18458.tar.gz
nnfw-93d85e1d258772b6fab86946bc1c4cfe60f18458.tar.bz2
nnfw-93d85e1d258772b6fab86946bc1c4cfe60f18458.zip
[exo-tflite] Shape and Type inference for TFLMul (#7470)
This will add Shape and Type inference for TFLMul node Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r--compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp10
-rw-r--r--compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp2
2 files changed, 10 insertions, 2 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
index e8df92ad7..7b65705d4 100644
--- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
+++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp
@@ -226,7 +226,15 @@ public:
return infer_pool_2d_shape(node);
}
- // TODO TFLMul
+ loco::NodeShape visit(const locoex::TFLMul *node) final
+ {
+ auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+
+ auto output_shape = broadcast_shape(x_shape, y_shape);
+
+ return loco::NodeShape{output_shape};
+ }
// TODO TFLNop
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp
index ee9918eb3..5be25b041 100644
--- a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp
+++ b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp
@@ -60,7 +60,7 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataTy
return loco::dtype_get(node->value());
}
- // TODO TFLMul
+ loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); }
// TODO TFLNop