// Source path: runtime/onert/core/src/interp/operations/BinaryArithmeticOps.cc
// Blob: 8ffc3cd33814bd55229411856109fb17651ac338
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cker/operation/BinaryArithmeticOps.h>

#include "OperationUtil.h"

#include "interp/Registration.h"
#include "ir/operation/Add.h"
#include "ir/operation/Sub.h"
#include "ir/operation/Mul.h"
#include "misc/polymorphic_downcast.h"

namespace onert
{
namespace interp
{
namespace
{

/**
 * @brief Selects which binary arithmetic kernel a template instantiation runs
 *
 * Used as a non-type template argument of invoke() and invokeAdd().
 */
enum class OpType
{
  /// Element-wise addition
  ADD,
  /// Element-wise subtraction
  SUB,
  /// Element-wise multiplication
  MUL
};

/**
 * @brief Validate the inputs of a binary arithmetic node and make sure its output tensor exists
 *
 * This template is shared by the Add, Sub and Mul kernels, so the error messages
 * below deliberately do not name one specific operation.
 *
 * @tparam node_type Concrete operation class; it must provide the LHS and RHS input positions
 * @param env        Execution environment the tensors are looked up in and allocated through
 * @param node       Operation to prepare; it is downcast to node_type
 * @throw std::runtime_error if the two inputs have different data types, if their shapes
 *                           cannot be broadcast, or if the output data type differs from the
 *                           input data type
 */
template <typename node_type> void prepareAdd(ExecEnv *env, const ir::Operation &node)
{
  const auto &arithmetic_node = nnfw::misc::polymorphic_downcast<const node_type &>(node);

  const auto lhs_index = node.getInputs().at(arithmetic_node.LHS);
  const auto rhs_index = node.getInputs().at(arithmetic_node.RHS);
  const auto out_index = node.getOutputs().at(0);

  const auto lhs_tensor = env->tensorAt(lhs_index);
  const auto rhs_tensor = env->tensorAt(rhs_index);

  // Both operands must share one data type
  // TODO Util function to compare TensorInfo
  if (lhs_tensor->data_type() != rhs_tensor->data_type())
  {
    throw std::runtime_error{"Interp(BinaryArithmetic): Different input types"};
  }

  // Broadcasting is only attempted when the operand shapes differ
  bool try_broadcast = (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape());
  if (try_broadcast)
  {
    // calcBroadcastShape() reports failure through the `success` out-argument
    bool success = true;
    auto out_shape = calcBroadcastShape(lhs_tensor->tensorInfo().shape(),
                                        rhs_tensor->tensorInfo().shape(), success);
    if (!success)
    {
      throw std::runtime_error{"Interp(BinaryArithmetic): Fail to broadcast"};
    }

    // Output takes the broadcast shape and the type information of the LHS operand
    auto output_info = ir::OperandInfo(out_shape, lhs_tensor->tensorInfo().typeInfo());
    // We can handle already allocated (ex. model output)
    env->allocateIfNeeded(out_index, output_info);
  }
  else
  {
    // Output's shape and type is same with input
    auto output_info = lhs_tensor->tensorInfo();
    // We can handle already allocated (ex. model output)
    env->allocateIfNeeded(out_index, output_info);
  }

  auto out_tensor = env->tensorAt(out_index);
  // Only the data type of the output is verified here; its shape is not compared
  // TODO Util function to compare TensorInfo
  if (lhs_tensor->data_type() != out_tensor->data_type())
  {
    throw std::runtime_error{"Interp(BinaryArithmetic): Invalid output type"};
  }
}

/**
 * @brief Store a floating-point activation range into the cker parameter block
 *
 * @param min    Lower bound written to float_activation_min
 * @param max    Upper bound written to float_activation_max
 * @param params Parameter block to update
 */
inline void setActivationParams(float min, float max, nnfw::cker::BinaryArithmeticOpParam *params)
{
  auto &target = *params;
  target.float_activation_min = min;
  target.float_activation_max = max;
}

/**
 * @brief Store an integer activation range into the cker parameter block
 *
 * @param min    Lower bound written to quantized_activation_min
 * @param max    Upper bound written to quantized_activation_max
 * @param params Parameter block to update
 */
inline void setActivationParams(int32_t min, int32_t max,
                                nnfw::cker::BinaryArithmeticOpParam *params)
{
  auto &target = *params;
  target.quantized_activation_min = min;
  target.quantized_activation_max = max;
}

template <typename raw_type, typename param_type, OpType op_type>
void invoke(const ITensor *lhs_tensor, const ITensor *rhs_tensor, const ITensor *out_tensor,
            const param_type &param)
{
  const auto lhs_buffer = lhs_tensor->bufferRO();
  const auto rhs_buffer = rhs_tensor->bufferRO();
  auto out_buffer = out_tensor->buffer();

  nnfw::cker::BinaryArithmeticOpParam cker_param;
  raw_type activation_min, activation_max;
  calculateActivationRange(param.activation, &activation_min, &activation_max);
  setActivationParams(activation_min, activation_max, &cker_param);
  const raw_type *lhs_ptr = reinterpret_cast<const raw_type *>(lhs_buffer);
  const raw_type *rhs_ptr = reinterpret_cast<const raw_type *>(rhs_buffer);
  raw_type *out_ptr = reinterpret_cast<raw_type *>(out_buffer);

  cker_param.type = (op_type == OpType::ADD)
                        ? nnfw::cker::BinaryArithmeticOpType::ADD
                        : ((op_type == OpType::SUB) ? nnfw::cker::BinaryArithmeticOpType::SUB
                                                    : nnfw::cker::BinaryArithmeticOpType::MUL);

  if (lhs_tensor->tensorInfo().shape() != rhs_tensor->tensorInfo().shape())
  {
    const auto lhs_shape = convertExtendShape(lhs_tensor->tensorInfo().shape());
    const auto rhs_shape = convertExtendShape(rhs_tensor->tensorInfo().shape());
    const auto out_shape = convertExtendShape(out_tensor->tensorInfo().shape());
    nnfw::cker::BroadcastBinaryArithmeticOpSlow(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr,
                                                out_shape, out_ptr);
    return;
  }

  const auto lhs_shape = convertShape(lhs_tensor->tensorInfo().shape());
  const auto rhs_shape = convertShape(rhs_tensor->tensorInfo().shape());
  const auto out_shape = convertShape(out_tensor->tensorInfo().shape());
  nnfw::cker::BinaryArithmeticOp(cker_param, lhs_shape, lhs_ptr, rhs_shape, rhs_ptr, out_shape,
                                 out_ptr);
}

/**
 * @brief Execute a binary arithmetic node, dispatching on the data type of its LHS input
 *
 * @tparam node_type  Concrete operation class; it must provide LHS, RHS and param()
 * @tparam param_type Parameter type handed on to invoke()
 * @tparam op_type    Arithmetic operation to perform
 * @param env         Execution environment the tensors are looked up in
 * @param node        Operation to execute; it is downcast to node_type
 * @throw std::runtime_error if the LHS data type is neither INT32 nor FLOAT32
 */
template <typename node_type, typename param_type, OpType op_type>
void invokeAdd(const ExecEnv *env, const ir::Operation &node)
{
  const auto &op = nnfw::misc::polymorphic_downcast<const node_type &>(node);

  // Resolve operand indices first, then the tensors behind them
  const auto lhs_idx = node.getInputs().at(op.LHS);
  const auto rhs_idx = node.getInputs().at(op.RHS);
  const auto out_idx = node.getOutputs().at(0);

  const auto lhs = env->tensorAt(lhs_idx);
  const auto rhs = env->tensorAt(rhs_idx);
  const auto out = env->tensorAt(out_idx);

  // The LHS data type alone selects the element type used by the kernel
  const auto element_type = lhs->data_type();

  if (element_type == ir::DataType::INT32)
  {
    invoke<int32_t, param_type, op_type>(lhs, rhs, out, op.param());
    return;
  }

  if (element_type == ir::DataType::FLOAT32)
  {
    invoke<float, param_type, op_type>(lhs, rhs, out, op.param());
    return;
  }

  throw std::runtime_error{"NYI: Unsupported data type"};
}
} // namespace

/**
 * @brief Get the interpreter kernel for the Add operation
 *
 * @return Pointer to a function-local static kernel holding the prepare and invoke entry points
 */
OpKernel *getAdd()
{
  using Node = ir::operation::Add;
  static OpKernel add_kernel{prepareAdd<Node>, invokeAdd<Node, Node::Param, OpType::ADD>};
  return &add_kernel;
}

/**
 * @brief Get the interpreter kernel for the Sub operation
 *
 * @return Pointer to a function-local static kernel holding the prepare and invoke entry points
 */
OpKernel *getSub()
{
  using Node = ir::operation::Sub;
  static OpKernel sub_kernel{prepareAdd<Node>, invokeAdd<Node, Node::Param, OpType::SUB>};
  return &sub_kernel;
}

/**
 * @brief Get the interpreter kernel for the Mul operation
 *
 * @return Pointer to a function-local static kernel holding the prepare and invoke entry points
 */
OpKernel *getMul()
{
  using Node = ir::operation::Mul;
  static OpKernel mul_kernel{prepareAdd<Node>, invokeAdd<Node, Node::Param, OpType::MUL>};
  return &mul_kernel;
}

} // namespace interp
} // namespace onert