diff options
author | Chunseok Lee <chunseok.lee@samsung.com> | 2020-10-28 12:16:55 +0900 |
---|---|---|
committer | Chunseok Lee <chunseok.lee@samsung.com> | 2020-10-28 12:16:55 +0900 |
commit | c55f8a6db48cda9d3a78048338b7f18c4cca62b8 (patch) | |
tree | 761ee8e171e5203f5c598ad93b2e7e0bc2e31aa2 /compiler/luci/service | |
parent | 74476a2d0296bdad70a2f7f90bc7419a8b05bffd (diff) | |
download | nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.tar.gz nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.tar.bz2 nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.zip |
Imported Upstream version 1.10.0upstream/1.10.0submit/tizen/20201028.104702submit/tizen/20201028.031836accepted/tizen/unified/20201029.124827
Diffstat (limited to 'compiler/luci/service')
-rw-r--r-- | compiler/luci/service/src/CircleShapeInferenceRule.cpp | 27 | ||||
-rw-r--r-- | compiler/luci/service/src/CircleTypeInferenceRule.cpp | 7 | ||||
-rw-r--r-- | compiler/luci/service/src/Validate.cpp | 5 |
3 files changed, 39 insertions, 0 deletions
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp index db25186b1..a55f50b19 100644 --- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp @@ -1608,6 +1608,22 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node) return loco::NodeShape{output_shape}; } +loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node) +{ + auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + auto recurrent_to_output_weights = + loco::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>(); + auto rank = input_shape.rank(); + loco::TensorShape output_shape; + output_shape.rank(rank); + for (uint32_t i = 0; i < rank - 1; i++) + { + output_shape.dim(i) = input_shape.dim(i); + } + output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1); + return loco::NodeShape{output_shape}; +} + loco::NodeShape infer_unique(const luci::CircleUnique *node) { auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); @@ -2047,6 +2063,12 @@ public: return infer_depthwise_conv2d(node); } + loco::NodeShape visit(const luci::CircleDequantize *node) final + { + const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); + return loco::NodeShape{input_shape}; + } + loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); } loco::NodeShape visit(const luci::CircleElu *node) final @@ -2373,6 +2395,11 @@ public: loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); } + loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final + { + return infer_unidirectionalsequencelstm(node); + } + loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); } loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); } diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp index d28d8ac99..f738ab5a8 100644 --- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp +++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp @@ -111,6 +111,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT return loco::dtype_get(node->input()); } + loco::DataType visit(const luci::CircleDequantize *) final { return loco::DataType::FLOAT32; } + loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); } loco::DataType visit(const luci::CircleElu *node) final @@ -490,6 +492,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT return loco::dtype_get(node->outBackprop()); } + loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final + { + return loco::dtype_get(node->input()); + } + loco::DataType visit(const luci::CircleUnique *node) final { return loco::dtype_get(node->input()); diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp index 282a068e0..d224fd172 100644 --- a/compiler/luci/service/src/Validate.cpp +++ b/compiler/luci/service/src/Validate.cpp @@ -75,6 +75,11 @@ bool validate_shape_dtype(loco::Graph *g) assert(circle_output != nullptr); assert(circle_output->from() != nullptr); auto circle_node = loco::must_cast<luci::CircleNode *>(circle_output->from()); + + // Shape and dtype validation for CiecleOutputExclude is not needed + if (dynamic_cast<luci::CircleOutputExclude *>(circle_node)) + continue; + assert(loco::shape_known(circle_node)); // check if output node shape is same as graph output shape |