summaryrefslogtreecommitdiff
path: root/compiler/luci/service
diff options
context:
space:
mode:
authorChunseok Lee <chunseok.lee@samsung.com>2020-10-28 12:16:55 +0900
committerChunseok Lee <chunseok.lee@samsung.com>2020-10-28 12:16:55 +0900
commitc55f8a6db48cda9d3a78048338b7f18c4cca62b8 (patch)
tree761ee8e171e5203f5c598ad93b2e7e0bc2e31aa2 /compiler/luci/service
parent74476a2d0296bdad70a2f7f90bc7419a8b05bffd (diff)
downloadnnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.tar.gz
nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.tar.bz2
nnfw-c55f8a6db48cda9d3a78048338b7f18c4cca62b8.zip
Diffstat (limited to 'compiler/luci/service')
-rw-r--r--compiler/luci/service/src/CircleShapeInferenceRule.cpp27
-rw-r--r--compiler/luci/service/src/CircleTypeInferenceRule.cpp7
-rw-r--r--compiler/luci/service/src/Validate.cpp5
3 files changed, 39 insertions, 0 deletions
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index db25186b1..a55f50b19 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -1608,6 +1608,22 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
+{
+ auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto recurrent_to_output_weights =
+ loco::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
+ auto rank = input_shape.rank();
+ loco::TensorShape output_shape;
+ output_shape.rank(rank);
+ for (uint32_t i = 0; i < rank - 1; i++)
+ {
+ output_shape.dim(i) = input_shape.dim(i);
+ }
+ output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1);
+ return loco::NodeShape{output_shape};
+}
+
loco::NodeShape infer_unique(const luci::CircleUnique *node)
{
auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
@@ -2047,6 +2063,12 @@ public:
return infer_depthwise_conv2d(node);
}
+ loco::NodeShape visit(const luci::CircleDequantize *node) final
+ {
+ const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ return loco::NodeShape{input_shape};
+ }
+
loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
loco::NodeShape visit(const luci::CircleElu *node) final
@@ -2373,6 +2395,11 @@ public:
loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); }
+ loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
+ {
+ return infer_unidirectionalsequencelstm(node);
+ }
+
loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); }
loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
diff --git a/compiler/luci/service/src/CircleTypeInferenceRule.cpp b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
index d28d8ac99..f738ab5a8 100644
--- a/compiler/luci/service/src/CircleTypeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleTypeInferenceRule.cpp
@@ -111,6 +111,8 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
return loco::dtype_get(node->input());
}
+ loco::DataType visit(const luci::CircleDequantize *) final { return loco::DataType::FLOAT32; }
+
loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); }
loco::DataType visit(const luci::CircleElu *node) final
@@ -490,6 +492,11 @@ struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataT
return loco::dtype_get(node->outBackprop());
}
+ loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
+ {
+ return loco::dtype_get(node->input());
+ }
+
loco::DataType visit(const luci::CircleUnique *node) final
{
return loco::dtype_get(node->input());
diff --git a/compiler/luci/service/src/Validate.cpp b/compiler/luci/service/src/Validate.cpp
index 282a068e0..d224fd172 100644
--- a/compiler/luci/service/src/Validate.cpp
+++ b/compiler/luci/service/src/Validate.cpp
@@ -75,6 +75,11 @@ bool validate_shape_dtype(loco::Graph *g)
assert(circle_output != nullptr);
assert(circle_output->from() != nullptr);
auto circle_node = loco::must_cast<luci::CircleNode *>(circle_output->from());
+
+ // Shape and dtype validation for CiecleOutputExclude is not needed
+ if (dynamic_cast<luci::CircleOutputExclude *>(circle_node))
+ continue;
+
assert(loco::shape_known(circle_node));
// check if output node shape is same as graph output shape