diff options
Diffstat (limited to 'onert-micro/luci-interpreter/src/kernels/Shape.cpp')
-rw-r--r-- | onert-micro/luci-interpreter/src/kernels/Shape.cpp | 55 |
1 files changed, 15 insertions, 40 deletions
diff --git a/onert-micro/luci-interpreter/src/kernels/Shape.cpp b/onert-micro/luci-interpreter/src/kernels/Shape.cpp index 2f16ac884..31a0b62bc 100644 --- a/onert-micro/luci-interpreter/src/kernels/Shape.cpp +++ b/onert-micro/luci-interpreter/src/kernels/Shape.cpp @@ -1,6 +1,5 @@ /* - * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved - * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,57 +14,33 @@ * limitations under the License. */ -#include "kernels/Shape.h" +#include "Builders.h" +#include "SISOKernel.h" #include "kernels/Utils.h" namespace luci_interpreter { -namespace kernels +void configure_kernel_CircleShape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) { - -ShapeKernel::ShapeKernel(const Tensor *input, Tensor *output, const ShapeParams ¶ms) - : KernelWithParams<ShapeParams>({input}, {output}, params) -{ -} - -void ShapeKernel::configure() -{ - LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S32 or - output()->element_type() == DataType::S64); - const auto input_shape = input()->shape(); - - Shape output_shape(1); - output_shape.dim(0) = input_shape.num_dims(); - // TODO: enable it only if kernel with dynamic shapes - output()->resize(output_shape); + kernels::SISOKernel kernel(cur_op, runtime_graph); + LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::S32); } -void ShapeKernel::execute() const +void execute_kernel_CircleShape(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph) { - switch (params().out_type) - { - case DataType::S32: - evalInt<int32_t>(); - break; - case DataType::S64: - evalInt<int64_t>(); - break; - default: - assert(false && "Unsupported type."); - } -} + kernels::SISOKernel kernel(cur_op, runtime_graph); -template <typename T> void ShapeKernel::evalInt() const -{ - const auto input_shape = input()->shape(); + const circle::Tensor *input = kernel.input(); + const circle::Tensor *output = kernel.output(); - auto output_data = getTensorData<T>(output()); + assert(Tensor::element_type(output) == DataType::S32); + int32_t *output_data = kernels::getTensorData<int32_t>(runtime_graph->getDataByTensor(output)); - for (int i = 0; i < input_shape.num_dims(); ++i) + const int rank = Tensor::num_dims(input); + for (int i = 0; i < rank; ++i) { - output_data[i] = input_shape.dim(i); + output_data[i] = Tensor::dim(input, i); } } -} // namespace kernels } // namespace luci_interpreter |