summaryrefslogtreecommitdiff
path: root/compiler/nnc/backends/soft_backend/code_snippets/cpp_conv_transpose.def
blob: 016ff15e192cf42d38ca7c0a4d088f747503c8f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iostream>

template <typename T>
void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
                     const RuntimeShape& input_shape, const T* input_data,
                     const RuntimeShape& filter_shape,
                     const RuntimeShape& output_shape, T* im2col_data) {
  const int stride_width = params.stride_width;
  const int stride_height = params.stride_height;
  const int pad_width = params.padding_values.width;
  const int pad_height = params.padding_values.height;
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
  TFLITE_DCHECK(im2col_data);

  const int batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int input_height = input_shape.Dims(1);
  const int input_width = input_shape.Dims(2);
  const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
  const int filter_height = filter_shape.Dims(1);
  const int filter_width = filter_shape.Dims(2);
  const int output_height = output_shape.Dims(1);
  const int output_width = output_shape.Dims(2);
  MatchingDim(output_shape, 3, filter_shape, 0);  // output_depth

  // Construct the MxN sized im2col matrix.
  // The rows M, are sub-ordered B x H x W
  const RuntimeShape row_shape({1, batches, output_height, output_width});
  // The columns, N, are sub-ordered Kh x Kw x Din
  const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
  // Use dimensions M and N to construct dims for indexing directly into im2col
  const RuntimeShape im2col_shape(
    {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});

  // Build the im2col matrix by looping through all the input pixels,
  // computing their influence on the output, rather than looping through all
  // the output pixels. We therefore must initialize the im2col array to zero.
  // This is potentially inefficient because we subsequently overwrite bytes
  // set here. However, in practice memset is very fast and costs negligible.
  memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));

  // Loop through the output batches
  for (int batch = 0; batch < batches; ++batch) {
    // Loop through input pixels one at a time.
    for (int in_y = 0; in_y < input_height; ++in_y) {
      for (int in_x = 0; in_x < input_width; ++in_x) {
        // Loop through the output pixels it will influence
        const int out_x_origin = (in_x * stride_width) - pad_width;
        const int out_y_origin = (in_y * stride_height) - pad_height;
        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
          const int out_y = out_y_origin + filter_y;
          // Is output pixel within height bounds?
          if ((out_y >= 0) && (out_y < output_height)) {
            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
              const int out_x = out_x_origin + filter_x;
              // Is output pixel within width bounds?
              if ((out_x >= 0) && (out_x < output_width)) {
                // Copy the input elements of this pixel
                T const* src =
                  input_data + Offset(input_shape, batch, in_y, in_x, 0);
                int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
                int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
                T* dst = im2col_data +
                         Offset(im2col_shape, 0, 0, row_offset, col_offset);
                memcpy(dst, src, input_depth * sizeof(T));
              }
            }
          }
        }
      }
    }
  }
}

// Float transposed convolution, computed as an im2col pass followed by a
// single matrix multiplication.
//
// `im2col_data` is a caller-provided scratch buffer (must be non-null) that
// is overwritten by TransposeIm2col; `im2col_shape` describes how that buffer
// is viewed as a matrix for the multiplication. The product is written
// through `output_data`.
inline void TransposeConv(
  const ConvParams& params, const RuntimeShape& input_shape,
  const float* input_data, const RuntimeShape& filter_shape,
  const float* filter_data, const RuntimeShape& output_shape,
  float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
  // Unstrided cases could instead run a forward convolution with transposed
  // weights, but this path already performs well enough as written.
  TFLITE_DCHECK(im2col_data);

  // Step 1: fill the scratch buffer (zero byte = 0) with the scattered input.
  TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
                  output_shape, im2col_data);

  // Step 2: wrap the raw buffers in matrix views.
  const auto patches = MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
  const auto weights = MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
  auto result = MapAsMatrixWithLastDimAsRows(output_data, output_shape);

  // Step 3: result = weights^T * patches.
  Gemm(weights.transpose(), patches, &result);
}