diff options
-rw-r--r-- | compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp | 10 | ||||
-rw-r--r-- | compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp | 2 |
2 files changed, 10 insertions, 2 deletions
diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp index e8df92ad7..7b65705d4 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLShapeInferenceRule.cpp @@ -226,7 +226,15 @@ public: return infer_pool_2d_shape(node); } - // TODO TFLMul + loco::NodeShape visit(const locoex::TFLMul *node) final + { + auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>(); + auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>(); + + auto output_shape = broadcast_shape(x_shape, y_shape); + + return loco::NodeShape{output_shape}; + } // TODO TFLNop diff --git a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp index ee9918eb3..5be25b041 100644 --- a/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp +++ b/compiler/exo-tflite/src/Dialect/Service/TFLTypeInferenceRule.cpp @@ -60,7 +60,7 @@ struct TypeInferenceAlgorithm final : public locoex::TFLNodeVisitor<loco::DataTy return loco::dtype_get(node->value()); } - // TODO TFLMul + loco::DataType visit(const locoex::TFLMul *node) final { return loco::dtype_get(node->x()); } // TODO TFLNop |