/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Cast.h"
#include "kernels/Utils.h"

#include <algorithm>
#include <cassert>
#include <cstdint>

namespace
{

using namespace luci_interpreter;
using namespace luci_interpreter::kernels;

// Returns true when `type` is one of the element types handled by both
// cast_from_pointer_to_tensor and cast_from_tensor_to_tensor in this file.
// Keep this list in sync with the `case` labels of those two switches.
bool is_supported_cast_type(DataType type)
{
  switch (type)
  {
    case DataType::U8:
    case DataType::U16:
    case DataType::U32:
    case DataType::U64:
    case DataType::S8:
    case DataType::S16:
    case DataType::S32:
    case DataType::S64:
    case DataType::FLOAT32:
    case DataType::BOOL:
      return true;
    default:
      return false;
  }
}

// Converts `elements_count` values from `in_data` into `out_data`, one element at a time,
// using static_cast<OutT>. Ordinary C++ conversion rules apply. For example, any non-zero
// value becomes `true` for bool, and floating -> integral truncates toward zero. Converting
// an out-of-range floating value to an integral type is undefined behaviour.
template <typename InT, typename OutT>
void cast_data(const InT *in_data, OutT *out_data, uint32_t elements_count)
{
  std::transform(in_data, in_data + elements_count, out_data,
                 [](InT a) { return static_cast<OutT>(a); });
}

// Writes the values of `in_data` into `out_tensor`, converting them to the output tensor's
// element type. The element count comes from the OUTPUT tensor's shape, so `in_data` must
// hold at least that many values.
template <typename InT> void cast_from_pointer_to_tensor(const InT *in_data, Tensor *out_tensor)
{
  auto const out_type = out_tensor->element_type();
  auto const elements_count = out_tensor->shape().num_elements();

  switch (out_type)
  {
    case DataType::U8:
      cast_data(in_data, getTensorData<uint8_t>(out_tensor), elements_count);
      break;
    case DataType::U16:
      cast_data(in_data, getTensorData<uint16_t>(out_tensor), elements_count);
      break;
    case DataType::U32:
      cast_data(in_data, getTensorData<uint32_t>(out_tensor), elements_count);
      break;
    case DataType::U64:
      cast_data(in_data, getTensorData<uint64_t>(out_tensor), elements_count);
      break;
    case DataType::S8:
      cast_data(in_data, getTensorData<int8_t>(out_tensor), elements_count);
      break;
    case DataType::S16:
      cast_data(in_data, getTensorData<int16_t>(out_tensor), elements_count);
      break;
    case DataType::S32:
      cast_data(in_data, getTensorData<int32_t>(out_tensor), elements_count);
      break;
    case DataType::S64:
      cast_data(in_data, getTensorData<int64_t>(out_tensor), elements_count);
      break;
    case DataType::FLOAT32:
      cast_data(in_data, getTensorData<float>(out_tensor), elements_count);
      break;
    case DataType::BOOL:
      cast_data(in_data, getTensorData<bool>(out_tensor), elements_count);
      break;
    default:
      // Unreachable once Cast::configure() has validated the output type.
      assert(false && "Unsupported output type.");
  }
}

// Selects the C++ element type of `in_tensor` and forwards its typed data pointer to
// cast_from_pointer_to_tensor, which handles the output-type dispatch.
void cast_from_tensor_to_tensor(const Tensor *in_tensor, Tensor *out_tensor)
{
  auto in_type = in_tensor->element_type();

  switch (in_type)
  {
    case DataType::U8:
      cast_from_pointer_to_tensor(getTensorData<uint8_t>(in_tensor), out_tensor);
      break;
    case DataType::U16:
      cast_from_pointer_to_tensor(getTensorData<uint16_t>(in_tensor), out_tensor);
      break;
    case DataType::U32:
      cast_from_pointer_to_tensor(getTensorData<uint32_t>(in_tensor), out_tensor);
      break;
    case DataType::U64:
      cast_from_pointer_to_tensor(getTensorData<uint64_t>(in_tensor), out_tensor);
      break;
    case DataType::S8:
      cast_from_pointer_to_tensor(getTensorData<int8_t>(in_tensor), out_tensor);
      break;
    case DataType::S16:
      cast_from_pointer_to_tensor(getTensorData<int16_t>(in_tensor), out_tensor);
      break;
    case DataType::S32:
      cast_from_pointer_to_tensor(getTensorData<int32_t>(in_tensor), out_tensor);
      break;
    case DataType::S64:
      cast_from_pointer_to_tensor(getTensorData<int64_t>(in_tensor), out_tensor);
      break;
    case DataType::FLOAT32:
      cast_from_pointer_to_tensor(getTensorData<float>(in_tensor), out_tensor);
      break;
    case DataType::BOOL:
      cast_from_pointer_to_tensor(getTensorData<bool>(in_tensor), out_tensor);
      break;
    default:
      // Unreachable once Cast::configure() has validated the input type.
      assert(false && "Unsupported input type.");
  }
}

} // namespace

namespace luci_interpreter
{
namespace kernels
{

Cast::Cast(const Tensor *input, Tensor *output) : Kernel({input}, {output}) {}

// Validates the element types and gives the output the same shape as the input.
void Cast::configure()
{
  LUCI_INTERPRETER_CHECK(input()->element_type() != DataType::Unknown);
  LUCI_INTERPRETER_CHECK(output()->element_type() != DataType::Unknown);
  // Reject types that execute() cannot convert. Without these checks, a build with
  // NDEBUG would skip the conversion silently (the dispatch `default:` only asserts)
  // and leave the output buffer unwritten.
  LUCI_INTERPRETER_CHECK(is_supported_cast_type(input()->element_type()));
  LUCI_INTERPRETER_CHECK(is_supported_cast_type(output()->element_type()));

  const Shape &shape = input()->shape();
  // TODO: enable it only if kernel with dynamic shapes
  output()->resize(shape);
}

// Converts every input element into the output tensor's element type.
void Cast::execute() const
{
  assert(input()->shape().num_elements() == output()->shape().num_elements());

  cast_from_tensor_to_tensor(input(), output());
}

} // namespace kernels
} // namespace luci_interpreter