diff options
Diffstat (limited to 'runtimes/neurun/core/src/backend/CustomKernel.cc')
-rw-r--r-- | runtimes/neurun/core/src/backend/CustomKernel.cc | 97 |
1 files changed, 97 insertions, 0 deletions
diff --git a/runtimes/neurun/core/src/backend/CustomKernel.cc b/runtimes/neurun/core/src/backend/CustomKernel.cc new file mode 100644 index 000000000..198e223cf --- /dev/null +++ b/runtimes/neurun/core/src/backend/CustomKernel.cc @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/CustomKernel.h" + +namespace neurun +{ +namespace backend +{ +namespace custom +{ + +// TODO move this elsewhere +class APIConverter +{ +public: + static nnfw_operand convertOperand(void *alloc, const TypeInfo &type) + { + nnfw_operand api_operand; + api_operand.allocation = alloc; + api_operand.type = convertType(type); + return api_operand; + } + + static nnfw_tensorinfo convertType(const TypeInfo &type) + { + nnfw_tensorinfo api_type; + api_type.rank = type.shape.rank(); + assert(type.shape.rank() <= 6); + std::copy(type.shape.dims().begin(), type.shape.dims().end(), std::begin(api_type.dims)); + + switch (type.dtype) + { + case model::DataType::FLOAT32: + api_type.dtype = NNFW_TYPE_TENSOR_FLOAT32; + break; + case model::DataType::INT32: + api_type.dtype = NNFW_TYPE_TENSOR_INT32; + break; + case model::DataType::QUANT8_ASYMM: + api_type.dtype = NNFW_TYPE_TENSOR_QUANT8_ASYMM; + break; + case model::DataType::BOOL8: + api_type.dtype = NNFW_TYPE_TENSOR_BOOL; + break; + default: + throw std::runtime_error("Unsupported tensor datatype"); + } + return api_type; + } +}; + +Kernel::Kernel(const nnfw_custom_eval evalFunction) + : _params(), _userdata(nullptr), _userdata_size(0), _evalFunction(evalFunction) +{ +} + +void Kernel::configure(Kernel::CustomKernelConfigParams &&inParams) +{ + _userdata = inParams.userdata; + _userdata_size = inParams.userdata_size; + + _params.ninputs = inParams.input_allocations.size(); + _params.inputs = new nnfw_operand[_params.ninputs]; + for (size_t i = 0; i < _params.ninputs; ++i) + { + _params.inputs[i] = + APIConverter::convertOperand(inParams.input_allocations[i], inParams.input_types[i]); + } + + _params.noutputs = inParams.output_allocations.size(); + _params.outputs = new nnfw_operand[_params.noutputs]; + for (size_t i = 0; i < _params.noutputs; ++i) + { + _params.outputs[i] = + APIConverter::convertOperand(inParams.output_allocations[i], inParams.output_types[i]); + } +} + +void Kernel::run() { _evalFunction(&_params, _userdata, _userdata_size); } + +} // namespace custom +} // namespace backend +} // namespace neurun |