Diffstat (limited to 'onert-micro/luci-interpreter/src/kernels/Reshape.cpp')
-rw-r--r--  onert-micro/luci-interpreter/src/kernels/Reshape.cpp  53
1 file changed, 47 insertions(+), 6 deletions(-)
diff --git a/onert-micro/luci-interpreter/src/kernels/Reshape.cpp b/onert-micro/luci-interpreter/src/kernels/Reshape.cpp
index ba47df4ab..7fe3e5636 100644
--- a/onert-micro/luci-interpreter/src/kernels/Reshape.cpp
+++ b/onert-micro/luci-interpreter/src/kernels/Reshape.cpp
@@ -16,6 +16,7 @@
  */
 
 #include "Builders.h"
+#include "Utils.h"
 
 #include <cassert>
 #include <cstring>
@@ -23,35 +24,75 @@
 namespace luci_interpreter
 {
 
-void configure_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+void configure_kernel_CircleReshape(const circle::Operator *, BaseRuntimeGraph *)
 {
   // Do nothing
 }
 
-void execute_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
-                                  bool is_inplace)
+// TODO: reduce code duplication with ExpandDims
+void execute_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
   const auto input_index = cur_op->inputs()->operator[](0);
+  const auto shape_index = cur_op->inputs()->operator[](1);
   const auto output_index = cur_op->outputs()->operator[](0);
 
   assert(input_index != -1);
+  assert(shape_index != -1);
   assert(output_index != -1);
 
   const auto input = runtime_graph->getCircleTensorByIndex(input_index);
+  const auto shape = runtime_graph->getCircleTensorByIndex(shape_index);
   const auto output = runtime_graph->getCircleTensorByIndex(output_index);
-
+  bool is_inplace = runtime_graph->is_inplace_op(cur_op);
   if (is_inplace)
   {
     runtime_graph->makeInplaceOperation(input, output);
     return;
   }
 
-  const auto input_data = (runtime_graph->getDataByTensor(input));
-  auto output_data = (runtime_graph->getDataByTensor(output));
+  const auto input_data = runtime_graph->getDataByTensor(input);
+  auto shape_data = runtime_graph->getConstDataByTensor(shape);
+  auto output_data = runtime_graph->getDataByTensor(output);
 
   assert(input_data != nullptr);
   assert(output_data != nullptr);
 
+#ifndef DIS_DYN_SHAPES
+  if (shape_data == nullptr)
+  {
+    shape_data = runtime_graph->getDataByTensor(shape);
+    assert(shape_data != nullptr);
+
+    assert(Tensor::element_type(shape) == DataType::S32);
+
+    const int32_t *shape_data_int = kernels::getTensorData<int32_t>(shape_data);
+    const auto num_elements = Tensor::num_elements(shape);
+
+    luci_interpreter::RuntimeShape dynamic_shape(num_elements);
+    int32_t data_size = 1;
+    for (int i = 0; i < num_elements; ++i)
+    {
+      dynamic_shape.setDim(i, shape_data_int[i]);
+      data_size *= shape_data_int[i];
+    }
+    data_size *= size(Tensor::element_type(output));
+
+    runtime_graph->addDynamicShapeTensor(output, std::move(dynamic_shape));
+
+    if (data_size == 0)
+    {
+      runtime_graph->resetTensorData(nullptr, output);
+      return;
+    }
+
+    auto new_output_data = new uint8_t[data_size];
+    output_data = new_output_data;
+    runtime_graph->resetTensorData(new_output_data, output);
+  }
+#else
+  assert(shape_data != nullptr);
+#endif // DIS_DYN_SHAPES
+
   const size_t element_size = getDataTypeSize(Tensor::element_type(input));
   const int32_t num_elements = Tensor::num_elements(input);
   std::memcpy(output_data, input_data, num_elements * element_size);
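
For reference, the new DIS_DYN_SHAPES branch above covers the case where the shape tensor is not a compile-time constant: it reads the S32 shape values at run time, records them as the output's dynamic shape, and reallocates the output buffer to hold prod(shape) * sizeof(element) bytes, resetting the output data to null when that product is zero. The minimal standalone sketch below reproduces only that size computation; the helper name compute_output_byte_size is illustrative and is not part of the luci-interpreter API.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Multiply the run-time shape entries and scale by the element size, as the
// kernel does before deciding whether to allocate a new output buffer.
int32_t compute_output_byte_size(const int32_t *shape_data, int32_t num_dims, size_t element_size)
{
  int32_t data_size = 1;
  for (int32_t i = 0; i < num_dims; ++i)
    data_size *= shape_data[i]; // flattened element count
  return data_size * static_cast<int32_t>(element_size);
}

int main()
{
  const std::vector<int32_t> new_shape = {2, 3, 4}; // example run-time shape values
  const int32_t bytes =
    compute_output_byte_size(new_shape.data(), static_cast<int32_t>(new_shape.size()), sizeof(float));
  assert(bytes == 96); // 2 * 3 * 4 elements of FLOAT32
  (void)bytes;
  return 0;
}

A zero-sized dimension drives the product to zero; the kernel then resets the output tensor's data pointer instead of allocating, mirroring the data_size == 0 branch in the diff.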