diff options
author | 박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com> | 2020-03-31 15:06:32 +0900 |
---|---|---|
committer | GitHub Enterprise <noreply-CODE@samsung.com> | 2020-03-31 15:06:32 +0900 |
commit | 2a7bcc886e941466179d5d2383a32cf6ed68b318 (patch) | |
tree | 6e4bfbb79e76cfd920e3af65014803d9bb0e1bc6 | |
parent | a2bdfac4d51a30e47fb0917a669befca7193d478 (diff) | |
download | nnfw-2a7bcc886e941466179d5d2383a32cf6ed68b318.tar.gz nnfw-2a7bcc886e941466179d5d2383a32cf6ed68b318.tar.bz2 nnfw-2a7bcc886e941466179d5d2383a32cf6ed68b318.zip |
[luci] Shape/Type inference for Softmax (#10781)
This will enable Shape and Type inference for Softmax Op
Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r-- | compiler/luci/service/src/CircleShapeInferenceRule.cpp | 7 | ||||
-rw-r--r-- | compiler/luci/service/src/CircleTypeInferenceRule.cpp | 5 |
2 files changed, 10 insertions, 2 deletions
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index b58e6704d..bd8fb860b 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -618,7 +618,12 @@ public: return loco::NodeShape{input_shape}; } - // TODO CircleSoftmax + loco::NodeShape visit(const luci::CircleSoftmax *node) final + { + auto input_shape = loco::shape_get(node->logits()).as<loco::TensorShape>(); + + return loco::NodeShape{input_shape}; + } loco::NodeShape visit(const luci::CircleSqrt *node) final { diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index 37ab9b525..5f6df6a44 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -99,7 +99,10 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); } - // TODO CircleSoftmax + loco::DataType visit(const luci::CircleSoftmax *node) final + { + return loco::dtype_get(node->logits()); + } loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); } |