summaryrefslogtreecommitdiff
path: root/libs/tflite/src/ext/kernels/Abs.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'libs/tflite/src/ext/kernels/Abs.cpp')
-rw-r--r--libs/tflite/src/ext/kernels/Abs.cpp103
1 files changed, 103 insertions, 0 deletions
diff --git a/libs/tflite/src/ext/kernels/Abs.cpp b/libs/tflite/src/ext/kernels/Abs.cpp
new file mode 100644
index 000000000..7e9c2338d
--- /dev/null
+++ b/libs/tflite/src/ext/kernels/Abs.cpp
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tflite/ext/kernels/Abs.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+#include <iostream>
+#include <cmath>
+
+namespace nnfw
+{
+namespace tflite
+{
+namespace custom
+{
+namespace Abs
+{
+
+void *InitAbs(TfLiteContext *context, const char *buffer, size_t length) { return nullptr; }
+
+void FreeAbs(TfLiteContext *context, void *buffer) {}
+
+TfLiteStatus PrepareAbs(TfLiteContext *context, TfLiteNode *node)
+{
+ TF_LITE_ENSURE_EQ(context, ::tflite::NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, ::tflite::NumOutputs(node), 1);
+
+ const TfLiteTensor *input = ::tflite::GetInput(context, node, 0);
+ TfLiteTensor *output = ::tflite::GetOutput(context, node, 0);
+
+ TF_LITE_ENSURE_EQ(context, input->type, output->type);
+
+ return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus EvalAbs(TfLiteContext *context, TfLiteNode *node)
+{
+ const TfLiteTensor *input = ::tflite::GetInput(context, node, 0);
+ TfLiteTensor *output = ::tflite::GetOutput(context, node, 0);
+ size_t elements = ::tflite::NumElements(input);
+ switch (input->type)
+ {
+ case kTfLiteFloat32:
+ {
+ auto *in = input->data.f;
+ auto *in_end = in + elements;
+ auto *out = output->data.f;
+ for (; in < in_end; in++, out++)
+ *out = std::abs(*in);
+ return kTfLiteOk;
+ }
+ case kTfLiteInt32:
+ {
+ auto *in = input->data.i32;
+ auto *in_end = in + elements;
+ auto *out = output->data.i32;
+ for (; in < in_end; in++, out++)
+ *out = std::abs(*in);
+ return kTfLiteOk;
+ }
+ case kTfLiteInt64:
+ {
+ auto *in = input->data.i64;
+ auto *in_end = in + elements;
+ auto *out = output->data.i64;
+ for (; in < in_end; in++, out++)
+ *out = std::abs(*in);
+ return kTfLiteOk;
+ }
+ case kTfLiteUInt8:
+ {
+ auto *in = input->data.uint8;
+ auto *in_end = in + elements;
+ auto *out = output->data.uint8;
+ for (; in < in_end; in++, out++)
+ *out = std::abs(*in);
+ return kTfLiteOk;
+ }
+ default:
+ {
+ context->ReportError(context, "Input type %d is not supported", input->type);
+ return kTfLiteError;
+ }
+ }
+}
+
+} // namespace Abs
+} // namespace custom
+} // namespace tflite
+} // namespace nnfw