summaryrefslogtreecommitdiff
path: root/onert-micro/luci-interpreter/src/kernels/Squeeze.cpp
blob: 9736dce3aaa91d63a82347ddcbe889c7fca1cf4d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/Squeeze.h"

#include "kernels/Utils.h"

#include <cassert>
#include <cstring>
#include <vector>

namespace luci_interpreter
{
namespace kernels
{

/**
 * @brief Creates a Squeeze kernel that reads @p input and writes @p output.
 *
 * @param input  tensor whose size-1 dimensions are removed by configure()
 * @param output tensor that is resized to the squeezed shape in configure() and receives
 *               a byte copy of the input data in execute()
 * @param params squeeze options; params.squeeze_dims lists the dimensions to remove
 *               (negative values count from the last dimension, an empty list means
 *               "remove every dimension of extent 1")
 */
Squeeze::Squeeze(const Tensor *input, Tensor *output, const SqueezeParams &params)
  : KernelWithParams<SqueezeParams>({input}, {output}, params)
{
}

/**
 * @brief Resizes the output to the input shape with size-1 dimensions removed.
 *
 * If params().squeeze_dims is empty, every input dimension of extent 1 is removed.
 * Otherwise only the listed dimensions are removed. Negative entries count from the
 * back (-1 is the last dimension), and a dimension listed more than once (directly or
 * via its negative alias) is removed only once.
 *
 * Every listed dimension must be in range and have extent 1. This is asserted in debug
 * builds. In release builds, invalid entries are ignored instead of writing out of
 * bounds or producing an output shape with fewer elements than the input.
 */
void Squeeze::configure()
{
  const auto &input_shape = input()->shape();
  const int input_num_dims = input_shape.num_dims();
  const auto &squeeze_dims = params().squeeze_dims;

  // One flag per input dimension. Sizing this from the actual rank (instead of a
  // fixed array of 8) removes the rank limit and the out-of-bounds risk when
  // assertions are compiled out.
  std::vector<bool> should_squeeze(static_cast<std::size_t>(input_num_dims), false);
  int num_squeezed_dims = 0;

  if (squeeze_dims.size() == 0)
  {
    // No explicit axes: drop every dimension of extent 1.
    for (int idx = 0; idx < input_num_dims; ++idx)
    {
      if (input_shape.dim(idx) == 1)
      {
        should_squeeze[idx] = true;
        ++num_squeezed_dims;
      }
    }
  }
  else
  {
    for (const auto squeeze_dim : squeeze_dims)
    {
      // Normalize negative axes so that -1 refers to the last dimension.
      const int current = squeeze_dim < 0 ? squeeze_dim + input_num_dims : squeeze_dim;
      const bool is_valid =
        current >= 0 && current < input_num_dims && input_shape.dim(current) == 1;
      assert(is_valid);
      if (!is_valid)
        continue; // release builds: never index out of range or drop a non-unit dim

      // Count each dimension only once, even if it is listed repeatedly.
      if (!should_squeeze[current])
      {
        should_squeeze[current] = true;
        ++num_squeezed_dims;
      }
    }
  }

  // TODO: enable it only if kernel with dynamic shapes
  Shape output_shape(input_num_dims - num_squeezed_dims);
  for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx)
  {
    if (!should_squeeze[in_idx])
      output_shape.dim(out_idx++) = input_shape.dim(in_idx);
  }
  output()->resize(output_shape);
}

/**
 * @brief Copies the input data into the output tensor unchanged.
 *
 * Squeeze only changes the shape (done in configure()), so the data is a plain byte
 * copy of num_elements * element-size bytes.
 */
void Squeeze::execute() const
{
  assert(input()->shape().num_elements() == output()->shape().num_elements());

  const auto *input_data = input()->data<void>();
  auto *output_data = output()->data<void>();
  const auto num_bytes =
    getDataTypeSize(input()->element_type()) * input()->shape().num_elements();

  // Skip the copy when there is nothing to copy (data pointers may be null for empty
  // tensors), and when both tensors share one buffer: memcpy with identical, hence
  // overlapping, source and destination is undefined behavior.
  if (num_bytes == 0 || input_data == output_data)
    return;

  std::memcpy(output_data, input_data, num_bytes);
}

} // namespace kernels
} // namespace luci_interpreter