diff options
Diffstat (limited to 'libs/support/tflite/src/kernels/SquaredDifference.cpp')
-rw-r--r-- | libs/support/tflite/src/kernels/SquaredDifference.cpp | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/libs/support/tflite/src/kernels/SquaredDifference.cpp b/libs/support/tflite/src/kernels/SquaredDifference.cpp new file mode 100644 index 000000000..25e10a8ed --- /dev/null +++ b/libs/support/tflite/src/kernels/SquaredDifference.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "support/tflite/kernels/SquaredDifference.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +#include <iostream> + +namespace tflite +{ +namespace ops +{ +namespace custom +{ +namespace nnfw +{ +namespace SquaredDifference +{ + +void *InitSquaredDifference(TfLiteContext *context, const char *buffer, size_t length) +{ + return nullptr; +} + +void FreeSquaredDifference(TfLiteContext *context, void *buffer) {} + +TfLiteStatus PrepareSquaredDifference(TfLiteContext *context, TfLiteNode *node) +{ + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor *input1 = GetInput(context, node, 0); + const TfLiteTensor *input2 = GetInput(context, node, 1); + TfLiteTensor *output = GetOutput(context, node, 0); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + + return context->ResizeTensor(context, output, TfLiteIntArrayCopy(input1->dims)); +} + +TfLiteStatus EvalSquaredDifference(TfLiteContext *context, TfLiteNode *node) +{ + + const TfLiteTensor *input1 = GetInput(context, node, 0); + const TfLiteTensor *input2 = GetInput(context, node, 1); + + TfLiteTensor *output = GetOutput(context, node, 0); + + size_t elements = NumElements(input1); + + switch (input1->type) + { + case kTfLiteFloat32: + { + const float *in1 = input1->data.f; + const float *in2 = input2->data.f; + const float *in_end1 = in1 + elements; + float *out = output->data.f; + + for (; in1 < in_end1; in1++, in2++, out++) + *out = ((*in1 - *in2) * (*in1 - *in2)); + + return kTfLiteOk; + } + case kTfLiteInt32: + { + const int *in1 = input1->data.i32; + const int *in2 = input2->data.i32; + const int *in_end1 = in1 + elements; + int *out = output->data.i32; + + for (; in1 < in_end1; in1++, in2++, out++) + *out = ((*in1 - *in2) * (*in1 - *in2)); + + return kTfLiteOk; + } + case kTfLiteInt64: + { + const int64_t *in1 = input1->data.i64; + const int64_t *in2 = input1->data.i64; + const int64_t *in_end1 = in1 + elements; + int64_t *out = output->data.i64; + + for (; in1 < in_end1; in1++, in2++, out++) + *out = ((*in1 - *in2) * (*in1 - *in2)); + + return kTfLiteOk; + } + default: + { + context->ReportError(context, "InputType is %d Unsupported", input1->type); + return kTfLiteError; + } + } +} + +} // namespace SquaredDifference +} // nnfw +} // namespace custom +} // namespace ops +} // namespace tflite |