diff options
Diffstat (limited to 'libs/tflite/src/ext/kernels/Abs.cpp')
-rw-r--r-- | libs/tflite/src/ext/kernels/Abs.cpp | 103 |
1 files changed, 103 insertions, 0 deletions
diff --git a/libs/tflite/src/ext/kernels/Abs.cpp b/libs/tflite/src/ext/kernels/Abs.cpp new file mode 100644 index 000000000..7e9c2338d --- /dev/null +++ b/libs/tflite/src/ext/kernels/Abs.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tflite/ext/kernels/Abs.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +#include <iostream> +#include <cmath> + +namespace nnfw +{ +namespace tflite +{ +namespace custom +{ +namespace Abs +{ + +void *InitAbs(TfLiteContext *context, const char *buffer, size_t length) { return nullptr; } + +void FreeAbs(TfLiteContext *context, void *buffer) {} + +TfLiteStatus PrepareAbs(TfLiteContext *context, TfLiteNode *node) +{ + TF_LITE_ENSURE_EQ(context, ::tflite::NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, ::tflite::NumOutputs(node), 1); + + const TfLiteTensor *input = ::tflite::GetInput(context, node, 0); + TfLiteTensor *output = ::tflite::GetOutput(context, node, 0); + + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus EvalAbs(TfLiteContext *context, TfLiteNode *node) +{ + const TfLiteTensor *input = ::tflite::GetInput(context, node, 0); + TfLiteTensor *output = ::tflite::GetOutput(context, node, 0); + size_t elements = ::tflite::NumElements(input); + switch (input->type) + { + case kTfLiteFloat32: + { + auto *in = input->data.f; + auto *in_end = in + elements; + auto *out = output->data.f; + for (; in < in_end; in++, out++) + *out = std::abs(*in); + return kTfLiteOk; + } + case kTfLiteInt32: + { + auto *in = input->data.i32; + auto *in_end = in + elements; + auto *out = output->data.i32; + for (; in < in_end; in++, out++) + *out = std::abs(*in); + return kTfLiteOk; + } + case kTfLiteInt64: + { + auto *in = input->data.i64; + auto *in_end = in + elements; + auto *out = output->data.i64; + for (; in < in_end; in++, out++) + *out = std::abs(*in); + return kTfLiteOk; + } + case kTfLiteUInt8: + { + auto *in = input->data.uint8; + auto *in_end = in + elements; + auto *out = output->data.uint8; + for (; in < in_end; in++, out++) + *out = std::abs(*in); + return kTfLiteOk; + } + default: + { + context->ReportError(context, "Input type %d is not supported", input->type); + return kTfLiteError; + } + } +} + +} // namespace Abs +} // namespace custom +} // namespace tflite +} // namespace nnfw |