summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>2020-03-31 15:06:32 +0900
committerGitHub Enterprise <noreply-CODE@samsung.com>2020-03-31 15:06:32 +0900
commit2a7bcc886e941466179d5d2383a32cf6ed68b318 (patch)
tree6e4bfbb79e76cfd920e3af65014803d9bb0e1bc6
parenta2bdfac4d51a30e47fb0917a669befca7193d478 (diff)
downloadnnfw-2a7bcc886e941466179d5d2383a32cf6ed68b318.tar.gz
nnfw-2a7bcc886e941466179d5d2383a32cf6ed68b318.tar.bz2
nnfw-2a7bcc886e941466179d5d2383a32cf6ed68b318.zip
[luci] Shape/Type inference for Softmax (#10781)
This will enable Shape and Type inference for Softmax Op Signed-off-by: SaeHie Park <saehie.park@samsung.com>
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.cpp7
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceRule.cpp5
2 files changed, 10 insertions, 2 deletions
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index b58e6704d..bd8fb860b 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -618,7 +618,12 @@ public:
return loco::NodeShape{input_shape};
}
- // TODO CircleSoftmax
+ loco::NodeShape visit(const luci::CircleSoftmax *node) final
+ {
+ auto input_shape = loco::shape_get(node->logits()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
loco::NodeShape visit(const luci::CircleSqrt *node) final
{
diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
index 37ab9b525..5f6df6a44 100644
--- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
@@ -99,7 +99,10 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); }
- // TODO CircleSoftmax
+ loco::DataType visit(const luci::CircleSoftmax *node) final
+ {
+ return loco::dtype_get(node->logits());
+ }
loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); }