diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2019-09-17 07:01:54 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2019-09-17 07:01:54 +0900 |
commit | 93d85e1d258772b6fab86946bc1c4cfe60f18458 (patch) | |
tree | 1bb5a17c9dea8cecda8016a6f45391a961b7c248 | |
parent | d96fbc0347d2f3455209e72e4478ee6840b58b4f (diff) | |
download | nnfw-93d85e1d258772b6fab86946bc1c4cfe60f18458.tar.gz nnfw-93d85e1d258772b6fab86946bc1c4cfe60f18458.tar.bz2 nnfw-93d85e1d258772b6fab86946bc1c4cfe60f18458.zip |
[exo-tflite] Shape and Type inference for TFLMul (#7470)
This will add Shape and Type inference for TFLMul node
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r-- | compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp | 10 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp | 2 |
2 files changed, 10 insertions, 2 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp index e8df92ad7..7b65705d4 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -226,7 +226,15 @@ public: return infer_pool_2d_shape(node); } - // TODO TFLMul + loco::NodeShape visit(const locoex::TFLMul *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } // TODO TFLNop diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp index ee9918eb3..5be25b041 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp @@ -60,7 +60,7 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataTy return loco::dtype_get(node->value()); } - // TODO TFLMul + loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); } // TODO TFLNop |