summaryrefslogtreecommitdiff
path: root/compiler/luci-interpreter/src/kernels/MirrorPad.cpp
blob: 2fbeefce4cdaa5c69001acba8cd0e5738e357e48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/MirrorPad.h"

#include "kernels/Utils.h"

namespace luci_interpreter
{
namespace kernels
{

// Creates a MirrorPad kernel.
//   input    - tensor to be padded.
//   paddings - per-axis (before, after) pad amounts; configure() asserts it is an S32 tensor of
//              shape [rank(input), 2].
//   output   - destination tensor; configure() resizes it to the padded shape.
//   params   - kernel parameters; params.mode selects REFLECT or the other (symmetric) mode.
// The inputs are registered as {input, paddings} and the single output as {output} with the
// KernelWithParams base.
MirrorPad::MirrorPad(const Tensor *input, const Tensor *paddings, Tensor *output,
                     const MirrorPadParams &params)
  : KernelWithParams<MirrorPadParams>({input, paddings}, {output}, params)
{
}

void MirrorPad::configure()
{
  const Shape &input_shape = input()->shape();
  const int num_dims = input_shape.num_dims();

  if (num_dims > 4)
    throw std::runtime_error("Unsupported number of dimensions.");

  assert(output()->element_type() == input()->element_type());
  assert(paddings()->element_type() == DataType::S32);
  // Paddings shape should be [N, 2].
  assert(paddings()->shape().num_dims() == 2);
  assert(paddings()->shape().dim(0) == num_dims);
  assert(paddings()->shape().dim(1) == 2);

  Shape output_shape(num_dims);
  const auto *paddings_data = getTensorData<int32_t>(paddings());
  for (int i = 0; i < num_dims; ++i)
  {
    const int32_t padding_before = paddings_data[i * 2];
    const int32_t padding_after = paddings_data[i * 2 + 1];
    assert(padding_before >= 0 && padding_after >= 0);
    output_shape.dim(i) = input_shape.dim(i) + padding_before + padding_after;
  }

  output()->resize(output_shape);
}

// Element-type-generic padding routine; declared here so execute() can instantiate it before
// its definition further down in this file.
template <typename T>
inline void MirrorPadImpl(const Tensor &input, const Tensor &paddings, MirrorPadMode mode,
                          Tensor &output);

// Runs the padding for the input's element type.
// Supports FLOAT32 and U8; any other type throws std::runtime_error.
void MirrorPad::execute() const
{
  const auto type = input()->element_type();

  if (type == DataType::FLOAT32)
  {
    MirrorPadImpl<float>(*input(), *paddings(), params().mode, *output());
  }
  else if (type == DataType::U8)
  {
    // Debug-only check that the output zero point is representable as uint8_t.
    assert(output()->zero_point() >= std::numeric_limits<uint8_t>::min());
    assert(output()->zero_point() <= std::numeric_limits<uint8_t>::max());

    MirrorPadImpl<uint8_t>(*input(), *paddings(), params().mode, *output());
  }
  else
  {
    throw std::runtime_error("Unsupported type.");
  }
}

namespace
{

// Maps coordinate `i` of a padded axis back to a source coordinate in [0, size) for
// MirrorPadMode::REFLECT, where the border element is NOT repeated. For a source axis
// [a, b, c] with left_pad == 2 and right_pad == 2, the padded axis reads [c, b, a, b, c, b, a].
// The mapping is periodic with period 2 * (size - 1), so it stays in range even for pads wider
// than size - 1. For size <= 2 it coincides with plain wrap-around.
// Requires size >= 1.
inline int32_t mirrorReflectIndex(int32_t i, int32_t left_pad, int32_t size)
{
  if (size <= 1)
    return 0;
  const int32_t period = 2 * (size - 1);
  int32_t pos = (i - left_pad) % period;
  if (pos < 0) // built-in % truncates toward zero, so fold negatives into [0, period)
    pos += period;
  return pos < size ? pos : period - pos;
}

// Maps coordinate `i` of a padded axis back to a source coordinate in [0, size) for the
// symmetric mode, where the border element IS repeated. For a source axis [a, b, c] with
// left_pad == 2 and right_pad == 2, the padded axis reads [b, a, a, b, c, c, b].
// The mapping is periodic with period 2 * size. Requires size >= 1.
inline int32_t mirrorSymmetricIndex(int32_t i, int32_t left_pad, int32_t size)
{
  const int32_t period = 2 * size;
  int32_t pos = (i - left_pad) % period;
  if (pos < 0) // built-in % truncates toward zero, so fold negatives into [0, period)
    pos += period;
  return pos < size ? pos : period - 1 - pos;
}

} // namespace

// Fills `output` with `input` padded by mirroring along every axis.
//
// The tensor (rank 0..4; configure() rejects rank > 4) is processed as a 4-D [b, h, w, d] view.
// Leading view axes that the tensor does not have get size 1 and zero padding. `paddings` holds
// one (before, after) int32 pair per tensor axis.
//
// Output elements are written in row-major order:
//   - elements in the unpadded interior are copied sequentially from `input`;
//   - all others are fetched through the REFLECT or symmetric index mapping, applied per axis.
//
// Throws std::runtime_error if a non-empty output would have to be mirrored from an empty input
// axis.
template <typename T>
inline void MirrorPadImpl(const Tensor &input, const Tensor &paddings, MirrorPadMode mode,
                          Tensor &output)
{
  const int input_dims = static_cast<int>(input.shape().num_dims());
  auto const input_data = input.data<T>();
  auto const paddings_data = paddings.data<int32_t>();
  auto const output_data = output.data<T>();

  // Sizes and pads of the [b, h, w, d] view. View axis k corresponds to tensor axis
  // (input_dims - 4 + k). When that index is negative the tensor has no such axis, so padding
  // is NOT read from `paddings_data`; reading it would index before the start of the buffer.
  int32_t in_size[4];
  int32_t out_size[4];
  int32_t left_pad[4];
  int32_t right_pad[4];
  bool input_empty = false;
  bool output_empty = false;
  for (int k = 0; k < 4; ++k)
  {
    const int axis = input_dims - 4 + k;
    if (axis < 0)
    {
      in_size[k] = 1;
      out_size[k] = 1;
      left_pad[k] = 0;
      right_pad[k] = 0;
      continue;
    }
    in_size[k] = input.shape().dim(axis);
    out_size[k] = output.shape().dim(axis);
    left_pad[k] = paddings_data[2 * axis];
    right_pad[k] = paddings_data[2 * axis + 1];
    input_empty = input_empty || in_size[k] == 0;
    output_empty = output_empty || out_size[k] == 0;
  }

  // Mirroring needs at least one source element to copy from.
  if (input_empty && !output_empty)
    throw std::runtime_error("Unsupported padding of an empty dimension.");

  const int32_t input_b = in_size[0];
  const int32_t input_h = in_size[1];
  const int32_t input_w = in_size[2];
  const int32_t input_d = in_size[3];

  const int32_t output_b = out_size[0];
  const int32_t output_h = out_size[1];
  const int32_t output_w = out_size[2];
  const int32_t output_d = out_size[3];

  const int32_t left_b_pad = left_pad[0];
  const int32_t left_h_pad = left_pad[1];
  const int32_t left_w_pad = left_pad[2];
  const int32_t left_d_pad = left_pad[3];

  const int32_t right_b_pad = right_pad[0];
  const int32_t right_h_pad = right_pad[1];
  const int32_t right_w_pad = right_pad[2];
  const int32_t right_d_pad = right_pad[3];

  const int32_t input_h_offset = input_d * input_w;
  const int32_t input_b_offset = input_h_offset * input_h;

  // Row-major linear offset into `input` for view coordinates (b, h, w, d).
  const auto offset_index = [input_d, input_h_offset, input_b_offset](int32_t d, int32_t w,
                                                                      int32_t h, int32_t b) {
    return d + w * input_d + h * input_h_offset + b * input_b_offset;
  };

  // Any mode other than REFLECT is handled as symmetric, as before.
  const bool reflect = mode == MirrorPadMode::REFLECT;
  const auto source_index = [reflect](int32_t i, int32_t pad, int32_t size) {
    return reflect ? mirrorReflectIndex(i, pad, size) : mirrorSymmetricIndex(i, pad, size);
  };

  const T *in_ptr = input_data;
  T *out_ptr = output_data;

  for (int32_t b = 0; b < output_b; ++b)
  {
    const bool b_inside = b >= left_b_pad && b < output_b - right_b_pad;
    for (int32_t h = 0; h < output_h; ++h)
    {
      const bool h_inside = h >= left_h_pad && h < output_h - right_h_pad;
      for (int32_t w = 0; w < output_w; ++w)
      {
        const bool w_inside = w >= left_w_pad && w < output_w - right_w_pad;
        for (int32_t d = 0; d < output_d; ++d)
        {
          const bool d_inside = d >= left_d_pad && d < output_d - right_d_pad;
          if (b_inside && h_inside && w_inside && d_inside)
          {
            // The interior is visited in the same row-major order as `input` itself.
            *out_ptr++ = *in_ptr++;
          }
          else
          {
            *out_ptr++ = input_data[offset_index(
              source_index(d, left_d_pad, input_d), source_index(w, left_w_pad, input_w),
              source_index(h, left_h_pad, input_h), source_index(b, left_b_pad, input_b))];
          }
        }
      }
    }
  }
}

} // namespace kernels
} // namespace luci_interpreter