/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "GatherLayer.h"

// NOTE(review): the original header name was lost in this copy of the file;
// cker's Gather header is what declares nnfw::cker::Gather / GatherParams — confirm path.
#include <cker/operation/Gather.h>

#include <cstdint>
#include <stdexcept>

#include "OperationUtils.h"

namespace neurun
{
namespace backend
{
namespace cpu
{
namespace kernel
{

namespace
{

/**
 * @brief Convert the three tensor descriptors to cker shapes and invoke
 *        nnfw::cker::Gather for a single element type.
 *
 * @param params       Gather parameters (axis) forwarded unchanged to cker
 * @param inputDescr   Descriptor of the tensor being gathered from
 * @param inputData    Typed pointer to the input buffer
 * @param indicesDescr Descriptor of the indices tensor
 * @param indicesData  Indices buffer, read as int32 values
 * @param outputDescr  Descriptor of the output tensor
 * @param outputData   Typed pointer to the output buffer (same element type as input)
 */
template <typename T>
void gatherTyped(const nnfw::cker::GatherParams &params, const TensorDescriptor &inputDescr,
                 const T *inputData, const TensorDescriptor &indicesDescr,
                 const int32_t *indicesData, const TensorDescriptor &outputDescr, T *outputData)
{
  nnfw::cker::Gather(params, convertTensorDescriptorToCkerShape(inputDescr), inputData,
                     convertTensorDescriptorToCkerShape(indicesDescr), indicesData,
                     convertTensorDescriptorToCkerShape(outputDescr), outputData);
}

} // namespace

/**
 * @brief Store the buffers, descriptors and axis used by run().
 *
 * No data is copied; only the raw pointers are kept, so the buffers must stay
 * valid until run() has executed. The element type used for dispatch in run()
 * is taken from the input descriptor.
 */
void GatherLayer::configure(uint8_t *inputData, const TensorDescriptor &inputDescr,
                            uint8_t *indicesData, const TensorDescriptor &indicesDescr,
                            uint8_t *outputData, const TensorDescriptor &outputDescr,
                            int32_t axis)
{
  _inputData.u8 = inputData;
  _inputDescr = inputDescr;
  _indicesData.u8 = indicesData;
  _indicesDescr = indicesDescr;
  _axis = axis;
  _inputType = inputDescr.type;
  _outputData.u8 = outputData;
  _outputDescr = outputDescr;
}

/**
 * @brief Execute Gather for the configured input element type.
 *
 * Indices are always read through the int32 view of the indices buffer.
 *
 * @throws std::runtime_error if the input type is not FLOAT32, QUANT8_ASYMM or INT32
 */
void GatherLayer::run()
{
  nnfw::cker::GatherParams op_params;
  op_params.axis = _axis;

  switch (_inputType)
  {
    case OperandType::FLOAT32:
      gatherTyped(op_params, _inputDescr, _inputData.f, _indicesDescr, _indicesData.i32,
                  _outputDescr, _outputData.f);
      break;
    case OperandType::QUANT8_ASYMM:
      gatherTyped(op_params, _inputDescr, _inputData.u8, _indicesDescr, _indicesData.i32,
                  _outputDescr, _outputData.u8);
      break;
    case OperandType::INT32:
      gatherTyped(op_params, _inputDescr, _inputData.i32, _indicesDescr, _indicesData.i32,
                  _outputDescr, _outputData.i32);
      break;
    default:
      throw std::runtime_error("Gather NYI for this operand type!");
  }
}

} // namespace kernel
} // namespace cpu
} // namespace backend
} // namespace neurun