diff options
Diffstat (limited to 'onert-micro/luci-interpreter/src/kernels/Reshape.cpp')
-rw-r--r-- | onert-micro/luci-interpreter/src/kernels/Reshape.cpp | 53 |
1 files changed, 47 insertions, 6 deletions
diff --git a/onert-micro/luci-interpreter/src/kernels/Reshape.cpp b/onert-micro/luci-interpreter/src/kernels/Reshape.cpp index ba47df4ab..7fe3e5636 100644 --- a/onert-micro/luci-interpreter/src/kernels/Reshape.cpp +++ b/onert-micro/luci-interpreter/src/kernels/Reshape.cpp @@ -16,6 +16,7 @@ */ #include "Builders.h" +#include "Utils.h" #include <cassert> #include <cstring> @@ -23,35 +24,75 @@ namespace luci_interpreter { -void configure_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) +void configure_kernel_CircleReshape(const circle::Operator *, BaseRuntimeGraph *) { // Do nothing } -void execute_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph, - bool is_inplace) +// TODO: reduce code duplication with ExpandDims +void execute_kernel_CircleReshape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) { const auto input_index = cur_op->inputs()->operator[](0); + const auto shape_index = cur_op->inputs()->operator[](1); const auto output_index = cur_op->outputs()->operator[](0); assert(input_index != -1); + assert(shape_index != -1); assert(output_index != -1); const auto input = runtime_graph->getCircleTensorByIndex(input_index); + const auto shape = runtime_graph->getCircleTensorByIndex(shape_index); const auto output = runtime_graph->getCircleTensorByIndex(output_index); - + bool is_inplace = runtime_graph->is_inplace_op(cur_op); if (is_inplace) { runtime_graph->makeInplaceOperation(input, output); return; } - const auto input_data = (runtime_graph->getDataByTensor(input)); - auto output_data = (runtime_graph->getDataByTensor(output)); + const auto input_data = runtime_graph->getDataByTensor(input); + auto shape_data = runtime_graph->getConstDataByTensor(shape); + auto output_data = runtime_graph->getDataByTensor(output); assert(input_data != nullptr); assert(output_data != nullptr); +#ifndef DIS_DYN_SHAPES + if (shape_data == nullptr) + { + shape_data = runtime_graph->getDataByTensor(shape); + assert(shape_data != nullptr); + + assert(Tensor::element_type(shape) == DataType::S32); + + const int32_t *shape_data_int = kernels::getTensorData<int32_t>(shape_data); + const auto num_elements = Tensor::num_elements(shape); + + luci_interpreter::RuntimeShape dynamic_shape(num_elements); + int32_t data_size = 1; + for (int i = 0; i < num_elements; ++i) + { + dynamic_shape.setDim(i, shape_data_int[i]); + data_size *= shape_data_int[i]; + } + data_size *= size(Tensor::element_type(output)); + + runtime_graph->addDynamicShapeTensor(output, std::move(dynamic_shape)); + + if (data_size == 0) + { + runtime_graph->resetTensorData(nullptr, output); + return; + } + + auto new_output_data = new uint8_t[data_size]; + output_data = new_output_data; + runtime_graph->resetTensorData(new_output_data, output); + } +#else + assert(shape_data != nullptr); +#endif // DIS_DYN_SHAPES + const size_t element_size = getDataTypeSize(Tensor::element_type(input)); const int32_t num_elements = Tensor::num_elements(input); std::memcpy(output_data, input_data, num_elements * element_size); |