summaryrefslogtreecommitdiff
path: root/compiler/enco/frontend/tflite/src/Op/Padding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/enco/frontend/tflite/src/Op/Padding.cpp')
-rw-r--r--compiler/enco/frontend/tflite/src/Op/Padding.cpp105
1 files changed, 105 insertions, 0 deletions
diff --git a/compiler/enco/frontend/tflite/src/Op/Padding.cpp b/compiler/enco/frontend/tflite/src/Op/Padding.cpp
new file mode 100644
index 000000000..9a0e4ef41
--- /dev/null
+++ b/compiler/enco/frontend/tflite/src/Op/Padding.cpp
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Padding.h"
+
+#include "Convert.h"
+#include "TensorBags.h"
+
+#include <coco/IR/Data.h>
+#include <coco/IR/Module.h>
+
+#include <nncc/core/ADT/tensor/Shape.h>
+#include <schema_generated.h>
+
+#include <map>
+#include <sstream>
+#include <algorithm>
+#include <cassert>
+
+using namespace nncc::core::ADT;
+
+namespace tflimport
+{
+
+coco::Padding2D get_padding(const tensor::Shape &ifm_shape, const int kernel_w, const int kernel_h,
+ tflite::Padding padding, int stride_w, int stride_h,
+ int dilation_w_factor, int dilation_h_factor)
+{
+ assert(stride_w != 0);
+ assert(stride_h != 0);
+ assert(ifm_shape.rank() == 4);
+
+ /**
+ * Compute [top padding + bottom padding] (or [left padding + right padding]).
+ * If this returns an even number, top = return value / 2 and bottom = return value - top
+ * If this returns an odd number, top = return value / 2 and bottom = return value - top (so,
+ * bottom = top + 1)
+ *
+ * Code based on https://www.tensorflow.org/api_guides/python/nn#Convolution
+ */
+ auto compute_padding = [](tflite::Padding padding, int stride, int dilation_rate, int in_size,
+ int filter_size) {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
+ if (padding == tflite::Padding_SAME)
+ {
+ if (in_size % stride == 0)
+ return std::max(effective_filter_size - stride, 0);
+ else
+ return std::max(effective_filter_size - (in_size % stride), 0);
+ }
+ else // padding == VALID
+ {
+ return 0;
+ }
+ };
+
+ // ifm shape is from order of NHWC. ifm W = dim(2), ifm H = dim(1)
+ int padding_w = compute_padding(padding, stride_w, dilation_w_factor, ifm_shape.dim(2), kernel_w);
+ int padding_h = compute_padding(padding, stride_h, dilation_h_factor, ifm_shape.dim(1), kernel_h);
+
+ coco::Padding2D coco_padding;
+ coco_padding.top(padding_h / 2).bottom(padding_h - padding_h / 2);
+ coco_padding.left(padding_w / 2).right(padding_w - padding_w / 2);
+
+ return coco_padding;
+}
+
+coco::Padding2D pool2D_padding(const tflite::Pool2DOptions *options, const tensor::Shape &ifm_shape,
+ const int filter_w, const int filter_h)
+{
+ return get_padding(ifm_shape, filter_w, filter_h, options->padding(), options->stride_w(),
+ options->stride_h(), 1, 1);
+}
+
+coco::Padding2D conv2D_padding(const tflite::Conv2DOptions *options, const tensor::Shape &ifm_shape,
+ const tensor::Shape &kernel_shape)
+{
+ return get_padding(ifm_shape, kernel_shape.dim(2), kernel_shape.dim(1), /* kernel layout: NHWC */
+ options->padding(), options->stride_w(), options->stride_h(),
+ options->dilation_w_factor(), options->dilation_h_factor());
+}
+
+coco::Padding2D depthwiseConv2D_padding(const tflite::DepthwiseConv2DOptions *options,
+ const tensor::Shape &ifm_shape,
+ const tensor::Shape &kernel_shape)
+{
+ return get_padding(ifm_shape, kernel_shape.dim(2), kernel_shape.dim(1), /* kernel layout: NHWC */
+ options->padding(), options->stride_w(), options->stride_h(),
+ options->dilation_w_factor(), options->dilation_h_factor());
+}
+
+} // namespace tflimport