Diffstat (limited to 'compiler/luci/service/src/CircleShapeInferenceRule.cpp')
-rw-r--r--  compiler/luci/service/src/CircleShapeInferenceRule.cpp  622
1 file changed, 363 insertions(+), 259 deletions(-)
diff --git a/compiler/luci/service/src/CircleShapeInferenceRule.cpp b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
index db25186b1..e8febc58f 100644
--- a/compiler/luci/service/src/CircleShapeInferenceRule.cpp
+++ b/compiler/luci/service/src/CircleShapeInferenceRule.cpp
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,6 +18,7 @@
#include "luci/Service/CircleShapeInferenceRule.h"
#include "Check.h"
+#include "CircleShapeInferenceHelper.h"
#include "ShapeInfer_StridedSlice.h"
#include <luci/IR/CircleNodes.h>
@@ -41,7 +43,11 @@ std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape
{
if (r)
os << ",";
- os << tensor_shape.dim(r).value();
+
+ if (tensor_shape.dim(r).known())
+ os << tensor_shape.dim(r).value();
+ else
+ os << "?";
}
os << "]";
return os;
@@ -52,7 +58,15 @@ loco::TensorShape own_shape(const luci::CircleNode *node)
loco::TensorShape shape;
shape.rank(node->rank());
for (uint32_t r = 0; r < node->rank(); ++r)
- shape.dim(r) = loco::Dimension(node->dim(r).value());
+ {
+ // Shape inference rules in this file did not consider unknown dimension.
+ // If some node has unknown dimension, 0 is inserted and wrong shape
+ // inference was done as a result.
+ // To fix this, new shape inference algorithm is being implemented.
+ // Until new inference algorithm is fully implemented, unknown dimension
+ // would be represented as 1 along with TFLite expression.
+ shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
+ }
return shape;
}
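For illustration, a minimal sketch of the new fallback (the node and its dimension values here are hypothetical):

// Hypothetical node: rank 3 with dims {4, ?, 3} (second dimension unknown).
// The old code read the unknown dimension as 0, producing {4, 0, 3} and
// breaking downstream inference; the new code writes 1 instead, TFLite-style:
loco::TensorShape shape;
shape.rank(3);
shape.dim(0) = 4; // known
shape.dim(1) = 1; // unknown -> 1
shape.dim(2) = 3; // known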
@@ -102,7 +116,7 @@ private:
};
/**
- * @breif Expand shape x and y to same rank by align right and filling with 1
+ * @brief Expand shape x and y to the same rank by right-aligning and filling with 1
*/
void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
{
@@ -122,7 +136,7 @@ void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
}
/**
- * @breif Returns shape of expanded dimension of input x and y having same rank
+ * @brief Returns the expanded-dimension shape of inputs x and y, which must have the same rank
*/
loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
{
@@ -135,10 +149,8 @@ loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::Tenso
output_shape.rank(rank);
for (uint32_t axis = 0; axis < rank; ++axis)
{
- assert(x.dim(axis).known() && y.dim(axis).known());
-
- auto x_dim = x.dim(axis).value();
- auto y_dim = y.dim(axis).value();
+ auto x_dim = x.dim(axis).known() ? x.dim(axis).value() : 1;
+ auto y_dim = y.dim(axis).known() ? y.dim(axis).value() : 1;
// each dimension of x and y should be same or one must be 1 if different
if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
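A worked example of the relaxed broadcast rule, with hypothetical shapes:

// x = {2, 1, 5}, y = {7, 5}
// expand_rank right-aligns y and fills with 1: y -> {1, 7, 5}
// expand_dimension then applies "equal, or one side is 1" per axis:
//   axis 0: 2 vs 1 -> 2;  axis 1: 1 vs 7 -> 7;  axis 2: 5 vs 5 -> 5
// result: {2, 7, 5}; an unknown dimension now counts as 1 instead of
// tripping the removed assert.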
@@ -177,32 +189,34 @@ template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::Circ
template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
{
- auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
- auto y_shape = loco::shape_get(node->y()).template as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
+ auto y_shape = luci::shape_get(node->y()).template as<loco::TensorShape>();
auto output_shape = broadcast_shape(x_shape, y_shape);
return loco::NodeShape{output_shape};
}
-template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
-{
- auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
- return loco::NodeShape{x_shape};
-}
+#define DECLARE_USE_SINGLE(NAME) \
+ template <class CIRCLENODE> loco::NodeShape use_##NAME(const CIRCLENODE *node) \
+ { \
+ auto inputs_shape = luci::shape_get(node->NAME()).template as<loco::TensorShape>(); \
+ return loco::NodeShape{inputs_shape}; \
+ }
-template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
-{
- auto shape = loco::shape_get(node->logits()).template as<loco::TensorShape>();
- return loco::NodeShape{shape};
-}
+DECLARE_USE_SINGLE(input);
+DECLARE_USE_SINGLE(inputs);
+DECLARE_USE_SINGLE(x);
+DECLARE_USE_SINGLE(logits);
+
+#undef DECLARE_USE_SINGLE
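For reference, DECLARE_USE_SINGLE(x) expands to exactly the helper it replaces:

template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
{
  auto inputs_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
  return loco::NodeShape{inputs_shape};
}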
template <class CIRCLENODE>
loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
// TODO support other data type
LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
@@ -232,11 +246,11 @@ loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *pa
loco::NodeShape infer_add_n(const luci::CircleAddN *node)
{
- auto shape = loco::shape_get(node->inputs(0)).as<loco::TensorShape>();
+ auto shape = luci::shape_get(node->inputs(0)).as<loco::TensorShape>();
for (uint32_t idx = 1; idx < node->arity(); ++idx)
{
- auto shape_idx = loco::shape_get(node->inputs(idx)).as<loco::TensorShape>();
+ auto shape_idx = luci::shape_get(node->inputs(idx)).as<loco::TensorShape>();
if (!(shape == shape_idx))
{
INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
@@ -245,10 +259,10 @@ loco::NodeShape infer_add_n(const luci::CircleAddN *node)
return loco::NodeShape{shape};
}
-loco::NodeShape infer_arg_max(const luci::CircleArgMax *node)
+template <class CIRCLENODE> loco::NodeShape infer_arg_maxmin(const CIRCLENODE *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
+ auto dimension_shape = luci::shape_get(node->dimension()).template as<loco::TensorShape>();
int64_t select_axis = 0;
{
@@ -258,55 +272,19 @@ loco::NodeShape infer_arg_max(const luci::CircleArgMax *node)
// Support S32 for now.
auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
- "Only support int32 CircleConst for CircleArgMax");
+ "Only support int32 CircleConst for CircleArgMax/CircleArgMin");
if (const_shape_node->rank() > 1)
INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
oops::to_uint32(const_shape_node->rank()));
- select_axis = const_shape_node->scalar<loco::DataType::S32>();
- }
- assert(select_axis < input_shape.rank());
- assert(select_axis >= 0); // TODO support minus of this breaks
-
- // NOTE select_axis is removed
- loco::TensorShape shape_output;
- uint32_t rank = input_shape.rank();
- uint32_t shrink = static_cast<uint32_t>(select_axis);
- assert(rank > 0);
- shape_output.rank(rank - 1);
- for (uint32_t r = 0, d = 0; r < rank; ++r)
- {
- if (r == shrink)
- continue;
- shape_output.dim(d++) = input_shape.dim(r);
+ select_axis = const_shape_node->template scalar<loco::DataType::S32>();
}
- return loco::NodeShape{shape_output};
-}
-
-loco::NodeShape infer_arg_min(const luci::CircleArgMin *node)
-{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
-
- int64_t select_axis = 0;
- {
- LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
-
- // Only support node's shape() is CircleConst with S32/S64
- // Support S32 for now.
- auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
- LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
- "Only support int32 CircleConst for CircleArgMin");
- if (const_shape_node->rank() > 1)
- INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
- oops::to_uint32(const_shape_node->rank()));
-
- select_axis = const_shape_node->scalar<loco::DataType::S32>();
- }
assert(select_axis < input_shape.rank());
- assert(select_axis >= 0); // TODO support minus of this breaks
+
+ if (select_axis < 0)
+ select_axis += static_cast<int64_t>(input_shape.rank());
// NOTE select_axis is removed
loco::TensorShape shape_output;
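A quick sketch of the new negative-axis handling (hypothetical shapes):

// input_shape = {1, 8, 8, 64} (rank 4), dimension() holds -1
// select_axis = -1 + 4 = 3, so dimension 3 is removed:
// output shape = {1, 8, 8}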
@@ -326,10 +304,10 @@ loco::NodeShape infer_arg_min(const luci::CircleArgMin *node)
// Call this for CircleAvgPool2D and CircleMaxPool2D only
template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
{
- LUCI_ASSERT(loco::shape_known(node->value()), "Shape must be known");
-
- auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
+ auto ifm_shape = luci::shape_get(node->value()).template as<loco::TensorShape>();
assert(ifm_shape.rank() == 4);
+ assert(ifm_shape.dim(1).known());
+ assert(ifm_shape.dim(2).known());
uint32_t input_height = ifm_shape.dim(1).value();
uint32_t input_width = ifm_shape.dim(2).value();
@@ -347,6 +325,8 @@ template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType
if (node->padding() == luci::Padding::VALID)
{
+ LUCI_ASSERT(input_height + stride_height > effective_window_height, "Invalid shape");
+ LUCI_ASSERT(input_width + stride_width > effective_window_width, "Invalid shape");
output_height = (input_height + stride_height - effective_window_height) / stride_height;
output_width = (input_width + stride_width - effective_window_width) / stride_width;
}
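Worked arithmetic for the VALID branch, with hypothetical sizes, showing what the new asserts guard against:

// input_height = 32, effective_window_height = 5, stride_height = 2
//   output_height = (32 + 2 - 5) / 2 = 14
// With input 2, window 5, stride 2 the unsigned subtraction (2 + 2 - 5)
// would wrap around; the LUCI_ASSERTs reject such shapes up front.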
@@ -372,7 +352,7 @@ loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
// Support only input rank is 3 and 4
assert(input_shape.rank() == 3 || input_shape.rank() == 4);
@@ -384,8 +364,8 @@ loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
- auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
- auto const_crops_shape = loco::shape_get(const_crops).as<loco::TensorShape>();
+ auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
+ auto const_crops_shape = luci::shape_get(const_crops).as<loco::TensorShape>();
assert(const_block_shape_shape.rank() == 1);
assert(const_crops_shape.rank() == 2);
@@ -423,10 +403,14 @@ struct OutputSize
template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
{
- auto ifm_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
- auto ker_shape = loco::shape_get(node->filter()).template as<loco::TensorShape>();
+ auto ifm_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
+ auto ker_shape = luci::shape_get(node->filter()).template as<loco::TensorShape>();
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
+ assert(ifm_shape.dim(1).known());
+ assert(ifm_shape.dim(2).known());
+ assert(ker_shape.dim(1).known());
+ assert(ker_shape.dim(2).known());
uint32_t input_height = ifm_shape.dim(1).value();
uint32_t input_width = ifm_shape.dim(2).value();
@@ -444,6 +428,8 @@ template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
if (node->padding() == luci::Padding::VALID)
{
+ LUCI_ASSERT(input_height + stride_height > effective_ker_height, "Invalid shape");
+ LUCI_ASSERT(input_width + stride_width > effective_ker_width, "Invalid shape");
output_height = (input_height + stride_height - effective_ker_height) / stride_height;
output_width = (input_width + stride_width - effective_ker_width) / stride_width;
}
@@ -496,7 +482,7 @@ loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
- if (not(x_rhs == y_lhs))
+ if (x_rhs.known() && y_lhs.known() && not(x_rhs == y_lhs))
INTERNAL_EXN("x_rhs and y_lhs should be same");
uint32_t out_rank = output_shape.rank();
@@ -506,12 +492,42 @@ loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_broadcast_to(const luci::CircleBroadcastTo *node)
+{
+ const loco::DataType S32 = loco::DataType::S32;
+
+ loco::TensorShape shape_by_input;
+ {
+ LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
+
+ // Only support node's shape() is CircleConst with S32
+ auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
+ if (const_shape_node != nullptr)
+ {
+ LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
+
+ shape_by_input.rank(const_shape_node->size<S32>());
+ for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
+ {
+ shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
+ }
+ }
+ else
+ {
+ // We use shape from the node itself
+ shape_by_input = own_shape(node);
+ }
+ }
+
+ return loco::NodeShape{shape_by_input};
+}
+
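A minimal usage sketch of infer_broadcast_to (the shape values are hypothetical):

// shape() is a rank-1 S32 CircleConst holding {2, 3} -> output TensorShape {2, 3}
// shape() is not a CircleConst (e.g. computed at runtime) -> output = own_shape(node)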
loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
{
// TODO Support when CircleConcatenation has 0 input
assert(node->numValues() > 0);
- auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
+ auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
auto axis = node->axis();
if (axis < 0)
axis += first_shape.rank();
@@ -527,14 +543,22 @@ loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
for (uint32_t i = 1; i < node->numValues(); ++i)
{
- auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
+ if (input_shape.rank() != output_shape.rank())
+ INTERNAL_EXN_V("Input has incompatible shape", node->name());
for (uint32_t j = 0; j < output_shape.rank(); ++j)
{
if (j == static_cast<uint32_t>(axis))
+ {
+ // If a dimension is unknown, value() will return 0.
+ // This is wrong, but to keep compatibility this code will not be
+ // modified until the new inference algorithm is implemented.
output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
+ }
else
- assert(output_shape.dim(j) == input_shape.dim(j));
+ assert(!output_shape.dim(j).known() || !input_shape.dim(j).known() ||
+ output_shape.dim(j) == input_shape.dim(j));
}
}
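A sketch of the compatibility behavior described in the comment above, with hypothetical shapes:

// values(0) = {4, ?}, values(1) = {4, 3}, axis = 1
// dim(1) of values(0) is unknown, so value() returns 0:
//   output dim(1) = 0 + 3 = 3   // known to be wrong, kept for compatibility
// axis 0 passes the relaxed assert because both dims are known and equal (4).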
@@ -545,11 +569,8 @@ loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
{
LOGGER(l);
- auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
- auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
-
- INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
- << ")" << std::endl;
+ auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
+ auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
@@ -564,12 +585,17 @@ loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
ofm_shape.dim(2) = os.width;
ofm_shape.dim(3) = ker_shape.dim(0);
+ INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
+ << ") output(" << ofm_shape.dim(0).value() << "," << ofm_shape.dim(1).value() << ","
+ << ofm_shape.dim(2).value() << "," << ofm_shape.dim(3).value() << ") " << node->name()
+ << std::endl;
+
return loco::NodeShape{ofm_shape};
}
loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
// Only data format NHWC is supported
@@ -601,12 +627,13 @@ loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
{
- auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
- auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
+ auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>(); // in NHWC
+ auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
assert(ifm_shape.rank() == 4);
assert(ker_shape.rank() == 4);
assert(ker_shape.dim(0).value() == 1);
+ assert(ifm_shape.dim(3).value() * node->depthMultiplier() == ker_shape.dim(3).value());
auto os = infer_conv2d_type(node);
@@ -623,7 +650,7 @@ loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto x_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
if (x_shape.rank() == 0)
{
// This maybe for unknown shape. We use shape from the node itself.
@@ -637,7 +664,7 @@ loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
}
int32_t axis = const_axis->at<S32>(0);
LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
- (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
+ (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
"Axis has to be between [-(D+1), D], where D is rank of input.");
size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
loco::TensorShape output_shape;
@@ -684,29 +711,41 @@ loco::NodeShape infer_fill(const luci::CircleFill *node)
loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>();
+
+ loco::TensorShape out_shape;
- // Checking shape capability for fully connected layer
- // Input: a tensor of at least rank 2 [D1, D2, ... Dn]
- // Weight: [# of units, K]
- // Output: [D1 * D2 * ... * Dn / K, # of units]
- if (input_shape.rank() < 2 || weights_shape.rank() != 2)
+ // NOTE Some recipes in some repositories use rank-4 inputs for FullyConnected.
+ // Until they are all fixed, the following assert is disabled.
+ // TODO Enable the following assert after the related fixes are applied
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
+ // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
+ // "Input rank of FullyConnected should be 2 or 3");
+
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
+ LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be rank 2");
+
+ // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
+ if (node->keep_num_dims())
{
- // Return node own shape if shape inference is not possible
- return use_own(node);
+ out_shape.rank(input_shape.rank());
+ for (uint32_t i = 0; i < input_shape.rank(); ++i)
+ out_shape.dim(i) = input_shape.dim(i);
+ out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
}
-
- uint32_t input_size = 1;
- for (uint32_t i = 0; i < input_shape.rank(); i++)
+ else
{
- input_size = input_size * input_shape.dim(i).value();
+ uint32_t input_size = 1;
+ for (uint32_t i = 0; i < input_shape.rank(); i++)
+ {
+ input_size = input_size * input_shape.dim(i).value();
+ }
+ const uint32_t batch_size = input_size / weights_shape.dim(1).value();
+ out_shape.rank(2);
+ out_shape.dim(0) = batch_size;
+ out_shape.dim(1) = weights_shape.dim(0);
}
- const uint32_t batch_size = input_size / weights_shape.dim(1).value();
- loco::TensorShape out_shape;
- out_shape.rank(2);
- out_shape.dim(0) = batch_size;
- out_shape.dim(1) = weights_shape.dim(0);
return loco::NodeShape{out_shape};
}
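Worked shapes for both branches of infer_fully_connected (all values hypothetical):

// input = {2, 3, 10}, weights = {16, 10}
// keep_num_dims == true:  out = {2, 3, 16} (last dim replaced by weights dim 0)
// keep_num_dims == false: input_size = 2 * 3 * 10 = 60
//                         batch_size = 60 / 10 = 6  ->  out = {6, 16}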
@@ -715,8 +754,8 @@ loco::NodeShape infer_gather(const luci::CircleGather *node)
{
loco::TensorShape output_shape;
- const auto input_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
- const auto positions_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
+ const auto positions_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
int32_t axis = node->axis();
// If CircleGather input has a dynamic shape, it can't infer this shape. So, it returns the
@@ -743,8 +782,8 @@ loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node)
{
loco::TensorShape output_shape;
- const auto params_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
- const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ const auto params_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
+ const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
const auto params_rank = params_shape.rank();
const auto indices_rank = indices_shape.rank();
@@ -791,7 +830,7 @@ loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
{
loco::TensorShape output_shape;
- auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
+ auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
auto rank = diagonal_shape.rank();
output_shape.rank(rank + 1);
@@ -808,8 +847,8 @@ loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
auto rank = diagonal_shape.rank();
@@ -831,7 +870,7 @@ loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indic
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(input).as<loco::TensorShape>();
auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
{ // Exceptions
@@ -892,7 +931,7 @@ loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node)
loco::NodeShape infer_one_hot(const luci::CircleOneHot *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
// Only support OneHot node's depth() is CircleConst with type S32
// TODO support depth with other types
auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
@@ -925,11 +964,11 @@ loco::NodeShape infer_pack(const luci::CirclePack *node)
{
LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
- auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
+ auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
// Make sure all inputs have the same shape.
for (uint32_t i = 1; i < node->values_count(); ++i)
{
- auto in_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
+ auto in_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
"All inputs must have the same shape");
}
@@ -985,8 +1024,8 @@ loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node)
loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto alpha_shape = loco::shape_get(node->alpha()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto alpha_shape = luci::shape_get(node->alpha()).as<loco::TensorShape>();
auto output_shape = broadcast_shape(input_shape, alpha_shape);
@@ -1087,10 +1126,12 @@ loco::NodeShape infer_reshape(const luci::CircleReshape *node)
loco::TensorShape output_shape = shape_by_input;
// One of the dimensions can have special value -1, meaning its actual value should be inferred.
- const auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
- const uint32_t input_element_count = loco::element_count(&input_shape);
+ const auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
+ uint32_t input_element_count = 1;
uint32_t output_element_count = 1;
uint32_t unknown_dim_index = UINT32_MAX;
+ for (uint32_t i = 0; i < input_shape.rank(); ++i)
+ input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
{
const uint32_t dim_value = output_shape.dim(dim_index).value();
@@ -1112,45 +1153,17 @@ loco::NodeShape infer_reshape(const luci::CircleReshape *node)
return loco::NodeShape{output_shape};
}
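A worked example of the -1 (wildcard) handling, assuming standard reshape semantics and with unknown input dimensions counted as 1 per the loop above:

// input shape = {2, 3, 4}    -> input_element_count = 24
// requested shape = {4, -1}  -> output_element_count = 4, unknown_dim_index = 1
// inferred: output.dim(1) = 24 / 4 = 6  ->  output shape = {4, 6}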
-loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node)
-{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
-
- if (input_shape.rank() != 4)
- INTERNAL_EXN("Expected ResizeBilinear input to have rank 4");
-
- auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
-
- if (const_node->dtype() != loco::DataType::S32)
- INTERNAL_EXN("Only S32 datatype is supported for ResizeBilinear size");
-
- if (const_node->rank() != 1)
- INTERNAL_EXN("Expected size tensor of rank 1");
-
- if (const_node->dim(0).value() != 2)
- INTERNAL_EXN("Expected size tensor with shape [2]");
-
- loco::TensorShape output_shape;
- output_shape.rank(4);
- output_shape.dim(0) = input_shape.dim(0);
- output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
- output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
- output_shape.dim(3) = input_shape.dim(3);
-
- return loco::NodeShape{output_shape};
-}
-
-loco::NodeShape infer_resize_nearest_neighbor(const luci::CircleResizeNearestNeighbor *node)
+template <class CIRCLENODE> loco::NodeShape infer_resize_type(const CIRCLENODE *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
if (input_shape.rank() != 4)
- INTERNAL_EXN("Expected ResizeNearesNeighbor input to have rank 4");
+ INTERNAL_EXN("Expected input to have rank 4");
auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
if (const_node->dtype() != loco::DataType::S32)
- INTERNAL_EXN("Only S32 datatype is supported for ResizeNearesNeighbor size");
+ INTERNAL_EXN("Only S32 datatype is supported for size");
if (const_node->rank() != 1)
INTERNAL_EXN("Expected size tensor of rank 1");
@@ -1161,8 +1174,8 @@ loco::NodeShape infer_resize_nearest_neighbor(const luci::CircleResizeNearestNei
loco::TensorShape output_shape;
output_shape.rank(4);
output_shape.dim(0) = input_shape.dim(0);
- output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
- output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
+ output_shape.dim(1) = const_node->template at<loco::DataType::S32>(0);
+ output_shape.dim(2) = const_node->template at<loco::DataType::S32>(1);
output_shape.dim(3) = input_shape.dim(3);
return loco::NodeShape{output_shape};
@@ -1195,8 +1208,8 @@ loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node)
loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
- auto segment_shape = loco::shape_get(node->segment_ids()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto segment_shape = luci::shape_get(node->segment_ids()).as<loco::TensorShape>();
LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
@@ -1226,11 +1239,11 @@ loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
loco::NodeShape infer_select(const luci::CircleSelect *node)
{
- auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
- assert(t_shape == loco::shape_get(node->e()).as<loco::TensorShape>());
+ auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
+ assert(t_shape == luci::shape_get(node->e()).as<loco::TensorShape>());
// condition shape validation
- auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
+ auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
if (c_shape.rank() != t_shape.rank())
{
if (c_shape.rank() != 0 && c_shape.rank() != 1)
@@ -1248,9 +1261,9 @@ loco::NodeShape infer_select(const luci::CircleSelect *node)
loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
{
- auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
- auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
- auto e_shape = loco::shape_get(node->e()).as<loco::TensorShape>();
+ auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
+ auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
+ auto e_shape = luci::shape_get(node->e()).as<loco::TensorShape>();
// validate ability to broadcast shapes to each other
auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
@@ -1259,7 +1272,7 @@ loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
loco::NodeShape infer_shape(const luci::CircleShape *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
loco::TensorShape output_shape;
@@ -1274,7 +1287,7 @@ loco::NodeShape infer_slice(const luci::CircleSlice *node)
const loco::DataType S32 = loco::DataType::S32;
const loco::DataType S64 = loco::DataType::S64;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
auto const_size = loco::must_cast<luci::CircleConst *>(node->size());
@@ -1306,7 +1319,7 @@ loco::NodeShape infer_slice(const luci::CircleSlice *node)
auto size = vect_size.at(idx);
if (size == -1)
{
- size = input_shape.dim(idx).value() - vect_begin.at(idx);
+ size = static_cast<int64_t>(input_shape.dim(idx).value()) - vect_begin.at(idx);
}
output_shape.dim(idx) = size;
}
@@ -1318,7 +1331,7 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
// Support only input rank is 3 and 4
assert(input_shape.rank() == 3 || input_shape.rank() == 4);
@@ -1330,8 +1343,8 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");
- auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
- auto const_paddings_shape = loco::shape_get(const_paddings).as<loco::TensorShape>();
+ auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
+ auto const_paddings_shape = luci::shape_get(const_paddings).as<loco::TensorShape>();
assert(const_block_shape_shape.rank() == 1);
assert(const_paddings_shape.rank() == 2);
@@ -1374,7 +1387,7 @@ loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
// Only data format NHWC is supported
@@ -1412,19 +1425,33 @@ loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node)
auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
if (output_shape_node != nullptr)
{
- // Only support node with S32
- LUCI_ASSERT(output_shape_node->dtype() == loco::DataType::S32,
- "Only support int32 CircleConst");
+ const auto output_shape_type = output_shape_node->dtype();
if (output_shape_node->rank() != 1)
INTERNAL_EXN_V("Only support rank 1 CircleConst",
oops::to_uint32(output_shape_node->rank()));
- shape.rank(output_shape_node->size<loco::DataType::S32>());
+ if (output_shape_type == loco::DataType::S32)
+ {
+ shape.rank(output_shape_node->size<loco::DataType::S32>());
- for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
+ }
+ }
+ else if (output_shape_type == loco::DataType::S64)
{
- shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
+ shape.rank(output_shape_node->size<loco::DataType::S64>());
+
+ for (uint32_t axis = 0; axis < shape.rank(); ++axis)
+ {
+ shape.dim(axis) = output_shape_node->at<loco::DataType::S64>(axis);
+ }
+ }
+ else
+ {
+ INTERNAL_EXN("Output shape of SparseToDense must be either int32 or int64");
}
}
else
@@ -1453,7 +1480,7 @@ loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node)
loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
// TODO input shape may be unknown before runtime
std::vector<bool> do_squeeze(input_shape.rank(), false);
@@ -1504,11 +1531,35 @@ loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_svdf(const luci::CircleSVDF *node)
+{
+ const auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ const auto weight_feature_shape = luci::shape_get(node->weight_feature()).as<loco::TensorShape>();
+
+ assert(ifm_shape.rank() == 2);
+ assert(weight_feature_shape.rank() == 2);
+
+ assert(ifm_shape.dim(1) == weight_feature_shape.dim(1));
+ assert(weight_feature_shape.dim(0).known());
+
+ const auto rank = node->svdf_rank();
+ const auto num_filters = weight_feature_shape.dim(0).value();
+ assert(num_filters % rank == 0);
+ const auto num_units = num_filters / rank;
+
+ loco::TensorShape ofm_shape;
+ ofm_shape.rank(2);
+ ofm_shape.dim(0) = ifm_shape.dim(0);
+ ofm_shape.dim(1) = num_units;
+
+ return loco::NodeShape{ofm_shape};
+}
+
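Worked shapes for infer_svdf (hypothetical values; num_units = num_filters / rank follows the TFLite SVDF convention):

// input = {1, 16}, weight_feature = {64, 16}, svdf_rank() = 4
// num_filters = 64, num_units = 64 / 4 = 16
// output = {batch, num_units} = {1, 16}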
loco::NodeShape infer_tile(const luci::CircleTile *node)
{
const loco::DataType S32 = loco::DataType::S32;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
// TODO support non-const case
@@ -1534,7 +1585,7 @@ loco::NodeShape infer_tile(const luci::CircleTile *node)
loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
{
- auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->a()).as<loco::TensorShape>();
auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
@@ -1556,7 +1607,9 @@ loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
loco::NodeShape infer_transpose_conv(const luci::CircleTransposeConv *node)
{
// TransposeConv's output shape is written in its 'inputSizes' argument
- auto input_sizes_const = loco::must_cast<luci::CircleConst *>(node->inputSizes());
+ auto input_sizes_const = dynamic_cast<luci::CircleConst *>(node->inputSizes());
+ if (not input_sizes_const)
+ return use_own(node);
// TODO support non-const type
LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
@@ -1576,7 +1629,7 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
// CircleUnpack provides a list (array) of Tensors, each having one less dimension than the input
// We'll set the shape of CircleUnpack to the shape of its actual outputs
// TODO fix this if any problem arises
- auto value_shape = loco::shape_get(node->value()).as<loco::TensorShape>();
+ auto value_shape = luci::shape_get(node->value()).as<loco::TensorShape>();
auto axis = node->axis();
auto num = node->num();
@@ -1608,9 +1661,25 @@ loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
+{
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ auto recurrent_to_output_weights =
+ luci::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
+ auto rank = input_shape.rank();
+ loco::TensorShape output_shape;
+ output_shape.rank(rank);
+ for (uint32_t i = 0; i < rank - 1; i++)
+ {
+ output_shape.dim(i) = input_shape.dim(i);
+ }
+ output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1);
+ return loco::NodeShape{output_shape};
+}
+
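A sketch of infer_unidirectionalsequencelstm with hypothetical shapes (assuming recurrent_to_output_weights is laid out [num_units, output_size] as in TFLite):

// input = {1, 28, 28}, recurrent_to_output_weights = {20, 20}
// all but the last input dim are copied, then dim(1) of the recurrent
// weights is appended: output = {1, 28, 20}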
loco::NodeShape infer_unique(const luci::CircleUnique *node)
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
assert(input_shape.rank() == 1);
@@ -1625,7 +1694,7 @@ loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *n
{
loco::TensorShape out_shape;
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());
LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");
@@ -1648,8 +1717,8 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
loco::TensorShape input_shape;
loco::TensorShape output_shape;
- const auto input_binary_shape = loco::shape_get(node->input_binary()).as<loco::TensorShape>();
- const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
+ const auto input_binary_shape = luci::shape_get(node->input_binary()).as<loco::TensorShape>();
+ const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
auto axis = node->axis();
auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
@@ -1675,6 +1744,28 @@ loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
return loco::NodeShape{output_shape};
}
+loco::NodeShape infer_circle_gru(const luci::CircleGRU *node)
+{
+ loco::TensorShape output_shape;
+
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ const auto state_shape = luci::shape_get(node->state()).as<loco::TensorShape>();
+
+ auto rank = input_shape.rank();
+ assert(rank > 1);
+ output_shape.rank(rank);
+ for (uint32_t i = 0; i < rank - 1; i++)
+ {
+ output_shape.dim(i) = input_shape.dim(i);
+ }
+ output_shape.dim(rank - 1) = state_shape.dim(1);
+
+ if (not node->returnSequences())
+ output_shape.dim(0) = 1;
+
+ return loco::NodeShape{output_shape};
+}
+
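A sketch of infer_circle_gru with hypothetical shapes:

// input = {10, 1, 8} (time, batch, feature), state = {1, 32}
// returnSequences == true:  output = {10, 1, 32}
// returnSequences == false: dim(0) is forced to 1 -> output = {1, 1, 32}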
// Virtual
loco::NodeShape infer_input(const luci::CircleInput *node)
{
@@ -1696,46 +1787,6 @@ loco::NodeShape infer_output(const luci::CircleOutput *node)
return loco::NodeShape{*output_shape};
}
-loco::NodeShape infer_if_out(const luci::CircleIfOut *node)
-{
- /**
- * @note IF operator type and shape are that of the "then" and "else"
- * Graph Outputs.
- */
- auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
- if (circle_if == nullptr)
- {
- INTERNAL_EXN("CircleIf IR is not configured correctly");
- }
-
- auto index = node->index();
- auto then_graph = circle_if->then_graph();
- auto else_graph = circle_if->else_graph();
- assert(then_graph != nullptr);
- assert(else_graph != nullptr);
-
- // shape and type are assumed to be same
- // these are checked at post_import_graph() in Import
- auto then_outputs = loco::output_nodes(then_graph);
- auto else_outputs = loco::output_nodes(else_graph);
- assert(then_outputs.size() == else_outputs.size());
- assert(index < static_cast<int32_t>(then_outputs.size()));
-
- auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
- auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
-
- auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
- auto else_graph_outputs = else_graph->outputs();
- assert(then_graph_outputs->size() == else_graph_outputs->size());
-
- auto then_graph_output = then_graph_outputs->at(then_out->index());
- auto else_graph_output = else_graph_outputs->at(else_out->index());
- (void)else_graph_output; // make compiler happy for unused variable warnings
- assert(*then_graph_output->shape() == *else_graph_output->shape());
-
- return loco::NodeShape{*then_graph_output->shape()};
-}
-
loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
{
const loco::DataType S32 = loco::DataType::S32;
@@ -1802,7 +1853,7 @@ loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
loco::NodeShape unknown;
- auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
+ auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
if (split_dim == nullptr)
@@ -1836,7 +1887,7 @@ loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
loco::NodeShape unknown;
- auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
+ auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
if (size_splits == nullptr)
@@ -1879,7 +1930,7 @@ loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
assert(0 <= index_this && index_this < split->num_split());
auto split_depth = size_splits->at<S32>(index_this);
if (split_depth == -1)
- split_depth = input_size - size_splits_sum;
+ split_depth = static_cast<int32_t>(input_size) - static_cast<int32_t>(size_splits_sum);
loco::TensorShape output_shape = split_shape;
@@ -1897,7 +1948,7 @@ loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
INTERNAL_EXN("CircleSplit IR is not configured correctly");
// shape of topkv2 is same as topkv2->input()
- auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(topkv2).as<loco::TensorShape>();
auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
@@ -1924,7 +1975,7 @@ loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
}
assert(node->index() == 1);
auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
- auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>();
+ auto unique_shape = luci::shape_get(unique->input()).as<loco::TensorShape>();
assert(unique_shape.rank() == 1);
@@ -1942,7 +1993,7 @@ loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
INTERNAL_EXN("CircleUnpack IR is not configured correctly");
}
- auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
+ auto unpack_shape = luci::shape_get(unpack).as<loco::TensorShape>();
return loco::NodeShape{unpack_shape};
}
@@ -1971,7 +2022,7 @@ loco::NodeShape infer_while_out(const luci::CircleWhileOut *node)
auto cond_graph_inputs = cond_graph->inputs();
auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
- auto cond_graph_input_shape = *cond_graph_input->shape();
+ const auto &cond_graph_input_shape = *cond_graph_input->shape();
auto this_shape = own_shape(node);
if (!(this_shape == cond_graph_input_shape))
@@ -1998,9 +2049,9 @@ public:
loco::NodeShape visit(const luci::CircleAddN *node) final { return infer_add_n(node); }
- loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_max(node); }
+ loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_maxmin(node); }
- loco::NodeShape visit(const luci::CircleArgMin *node) final { return infer_arg_min(node); }
+ loco::NodeShape visit(const luci::CircleArgMin *node) final { return infer_arg_maxmin(node); }
loco::NodeShape visit(const luci::CircleAveragePool2D *node) final
{
@@ -2009,8 +2060,8 @@ public:
loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
{
- auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
- auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
+ auto x_shape = luci::shape_get(node->x()).as<loco::TensorShape>();
+ auto y_shape = luci::shape_get(node->y()).as<loco::TensorShape>();
return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
}
@@ -2020,6 +2071,11 @@ public:
return infer_batch_to_space_nd(node);
}
+ loco::NodeShape visit(const luci::CircleBroadcastTo *node) final
+ {
+ return infer_broadcast_to(node);
+ }
+
loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); }
loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); }
@@ -2035,8 +2091,12 @@ public:
loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); }
+ loco::NodeShape visit(const luci::CircleCumSum *node) final { return use_input(node); }
+
loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); }
+ loco::NodeShape visit(const luci::CircleDensify *node) final { return use_input(node); }
+
loco::NodeShape visit(const luci::CircleDepthToSpace *node) final
{
return infer_depth_to_space(node);
@@ -2047,11 +2107,17 @@ public:
return infer_depthwise_conv2d(node);
}
+ loco::NodeShape visit(const luci::CircleDequantize *node) final
+ {
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ return loco::NodeShape{input_shape};
+ }
+
loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
loco::NodeShape visit(const luci::CircleElu *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2065,6 +2131,8 @@ public:
return infer_expand_dims(node);
}
+ loco::NodeShape visit(const luci::CircleFakeQuant *node) final { return use_inputs(node); }
+
loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); }
loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
@@ -2082,15 +2150,29 @@ public:
loco::NodeShape visit(const luci::CircleGatherNd *node) final { return infer_gather_nd(node); }
+ loco::NodeShape visit(const luci::CircleGelu *node) final
+ {
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
loco::NodeShape visit(const luci::CircleGreater *node) final { return broadcast_xy(node); }
loco::NodeShape visit(const luci::CircleGreaterEqual *node) final { return broadcast_xy(node); }
+ loco::NodeShape visit(const luci::CircleHardSwish *node) final
+ {
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
loco::NodeShape visit(const luci::CircleIf *node) final
{
// Shape of CircleIf is not used. Just use input 0
assert(node->input_count() > 0);
- const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2103,7 +2185,7 @@ public:
loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
{
- const auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2113,7 +2195,7 @@ public:
loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
{
- const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2162,13 +2244,13 @@ public:
loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
{
- const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
+ const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
return loco::NodeShape{boxes_shape};
}
loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final
{
- const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
+ const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
return loco::NodeShape{boxes_shape};
}
@@ -2186,6 +2268,12 @@ public:
loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }
+ loco::NodeShape visit(const luci::CircleQuantize *node) final
+ {
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
+ return loco::NodeShape{input_shape};
+ }
+
loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }
loco::NodeShape visit(const luci::CircleRank *) final
@@ -2222,21 +2310,28 @@ public:
loco::NodeShape visit(const luci::CircleRelu *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
+
+ return loco::NodeShape{input_shape};
+ }
+
+ loco::NodeShape visit(const luci::CircleRelu0To1 *node) final
+ {
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleRelu6 *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
{
- auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2252,17 +2347,17 @@ public:
loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
{
- return infer_resize_bilinear(node);
+ return infer_resize_type(node);
}
loco::NodeShape visit(const luci::CircleResizeNearestNeighbor *node) final
{
- return infer_resize_nearest_neighbor(node);
+ return infer_resize_type(node);
}
loco::NodeShape visit(const luci::CircleReverseSequence *node) final
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2271,9 +2366,9 @@ public:
loco::NodeShape visit(const luci::CircleReverseV2 *node) final
{
- auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
- LUCI_ASSERT(loco::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
+ LUCI_ASSERT(luci::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
"Tensor must be 1-D");
return loco::NodeShape{input_shape};
@@ -2318,14 +2413,14 @@ public:
loco::NodeShape visit(const luci::CircleSplit *node) final
{
// We'll set Split output the same as the input so that SplitOut can handle its own shape
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleSplitV *node) final
{
// We'll set SplitV output the same as the input so that SplitOut can handle its own shape
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2353,6 +2448,8 @@ public:
return loco::NodeShape{output_shape};
}
+ loco::NodeShape visit(const luci::CircleSVDF *node) final { return infer_svdf(node); }
+
loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
@@ -2360,7 +2457,7 @@ public:
loco::NodeShape visit(const luci::CircleTopKV2 *node) final
{
// set shape of this node as same as input
- const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2373,6 +2470,11 @@ public:
loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); }
+ loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
+ {
+ return infer_unidirectionalsequencelstm(node);
+ }
+
loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); }
loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
@@ -2381,13 +2483,13 @@ public:
{
// Shape of CircleWhile is not used. Just use input 0
assert(node->arity() > 0);
- const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
+ const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
loco::NodeShape visit(const luci::CircleZerosLike *node) final
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
@@ -2402,11 +2504,13 @@ public:
loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
{
- auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
+ auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
return loco::NodeShape{input_shape};
}
+ loco::NodeShape visit(const luci::CircleGRU *node) final { return infer_circle_gru(node); }
+
// Virtual
loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }
@@ -2418,8 +2522,6 @@ public:
loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
- loco::NodeShape visit(const luci::CircleIfOut *node) final { return infer_if_out(node); }
-
loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
{
return infer_non_max_suppression_v4_out(node);
@@ -2443,6 +2545,8 @@ public:
loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
+ loco::NodeShape visit(const luci::CircleVariable *node) final { return use_own(node); }
+
loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
};