summaryrefslogtreecommitdiff
path: root/runtimes/neurun/src/kernel/cpu/SoftMaxLayer.cc
blob: c998c65f6b562aa87847d0dbd027e7a23cd284c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
/*
 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "SoftMaxLayer.h"

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "kernel/cpu/OperationUtils.h"

namespace neurun
{
namespace kernel
{
namespace cpu
{

SoftMaxLayer::SoftMaxLayer()
    : _inputData(nullptr), _outputData(nullptr), _beta(0.0), _inputShape(), _outputShape(),
      _inputType(OperandType::SCALAR_FLOAT32)
{
  // DO NOTHING
}

// Row-wise softmax over a contiguous buffer holding batch_size rows of input_size floats.
// Each output row is exp(beta * (x - max(x))) divided by the row's sum of those values.
// Subtracting the row maximum first keeps std::exp from overflowing on large logits
// (NOTE(review): this holds for beta >= 0; a negative beta is not guarded against).
// Aborts via TF_LITE_ASSERT when input_size is not positive.
void Softmax(const float *in, const int input_size, const int batch_size, const float beta,
             float *out)
{
  TF_LITE_ASSERT(input_size > 0);

  const float *src = in;
  float *dst = out;
  for (int row = 0; row < batch_size; ++row, src += input_size, dst += input_size)
  {
    // Row maximum: seeded by the first element, replaced only by strictly larger ones.
    float row_max = src[0];
    for (int k = 1; k < input_size; ++k)
    {
      row_max = (src[k] > row_max) ? src[k] : row_max;
    }

    // Exponentiate the shifted, beta-scaled logits into the output and sum them.
    float total = 0.0;
    for (int k = 0; k < input_size; ++k)
    {
      dst[k] = std::exp((src[k] - row_max) * beta);
      total += dst[k];
    }

    // Normalize by multiplying with the reciprocal of the sum.
    const float inv_total = 1.f / total;
    for (int k = 0; k < input_size; ++k)
    {
      dst[k] *= inv_total;
    }
  }
}

// Float32 softmax for a 2D [batch, input_size] tensor (via the local Softmax helper)
// or a 4D tensor (via TFLite's optimized kernel).
// Returns false, after logging to std::cout, for any other rank; otherwise returns true.
// An empty 2D tensor (either dimension 0) is a successful no-op.
bool SoftMaxLayer::softmaxFloat32()
{
  const auto rank = getNumberOfDimensions(_inputShape);

  if (rank == 2)
  {
    // For a 2D tensor, dimension 1 equals numElements / dim0. Reading it directly
    // avoids a division by zero when the batch dimension is empty.
    const uint32_t batch_size = getSizeOfDimension(_inputShape, 0);
    const uint32_t input_size = getSizeOfDimension(_inputShape, 1);
    if (batch_size == 0 || input_size == 0)
    {
      // Nothing to compute. Softmax() would otherwise abort on input_size == 0.
      return true;
    }
    Softmax(reinterpret_cast<const float *>(_inputData), input_size, batch_size, _beta,
            reinterpret_cast<float *>(_outputData));
  }
  else if (rank == 4)
  {
    ::tflite::SoftmaxParams op_params;
    op_params.beta = _beta;
    ::tflite::optimized_ops::Softmax(op_params, convertShapeToTFLiteShape(_inputShape),
                                     reinterpret_cast<const float *>(_inputData),
                                     convertShapeToTFLiteShape(_outputShape),
                                     reinterpret_cast<float *>(_outputData));
  }
  else
  {
    std::cout << "only 2D and 4D tensors supported" << std::endl;
    return false;
  }

  return true;
}

// Quantized (uint8) softmax through TFLite's optimized kernel.
// Accepts rank 2 (viewed as [batch, 1, 1, depth]) or rank 4 inputs, and requires the
// output quantization to be scale 1/256 with zero offset.
// Returns false, after logging to std::cout where the original did, when the rank
// or output quantization is unsupported, or when the fixed-point multiplier cannot
// be derived. Otherwise returns true.
bool SoftMaxLayer::softmaxQuant8()
{
  const auto rank = getNumberOfDimensions(_inputShape);

  // The optimized kernel is called with a 4D shape.
  Shape shapeIn4D = _inputShape;
  if (rank == 2)
  {
    // Read dimension 1 directly instead of numElements / dim0. This avoids a
    // division by zero when the batch dimension is empty.
    const uint32_t batch_size = getSizeOfDimension(_inputShape, 0);
    const uint32_t input_size = getSizeOfDimension(_inputShape, 1);
    shapeIn4D.dimensions = {batch_size, 1, 1, input_size};
  }
  else if (rank != 4)
  {
    std::cout << "only 2D and 4D tensors supported" << std::endl;
    return false;
  }

  if (_outputShape.offset != 0 || _outputShape.scale != 1.f / 256)
  {
    std::cout << "incorrect scale / offset for output" << std::endl;
    return false;
  }

  // Fold beta and the input scale into a real multiplier, clamped to the int32 maximum,
  // then convert it to a fixed-point multiplier/shift pair.
  static const int32_t kScaledDiffIntegerBits = 5;
  const double input_beta_real_multiplier = std::min(
      1.0 * _beta * _inputShape.scale * (1 << (31 - kScaledDiffIntegerBits)), (1ll << 31) - 1.0);
  int32_t input_multiplier = 0;
  int32_t input_left_shift = 0;
  if (!QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, &input_multiplier,
                                        &input_left_shift))
  {
    return false;
  }

  ::tflite::SoftmaxParams op_params;
  op_params.input_multiplier = input_multiplier;
  op_params.input_left_shift = input_left_shift;
  // diff_min is the negated input radius for the chosen shift. It is kept in integer
  // arithmetic because the params field is an integer; no float round-trip is needed.
  op_params.diff_min = -CalculateInputRadius(kScaledDiffIntegerBits, input_left_shift);
  ::tflite::optimized_ops::Softmax(op_params, convertShapeToTFLiteShape(shapeIn4D), _inputData,
                                   convertShapeToTFLiteShape(shapeIn4D), _outputData);
  return true;
}

// Binds the input/output buffers and shapes and records beta; no computation happens here.
// The element type that run() dispatches on is taken from the input shape.
void SoftMaxLayer::configure(uint8_t *inputData, const Shape &inputShape, const float beta,
                             uint8_t *outputData, const Shape &outputShape)
{
  // Softmax scaling factor.
  _beta = beta;

  // Input side; the dispatch type mirrors the copied shape's operand type.
  _inputData = inputData;
  _inputShape = inputShape;
  _inputType = _inputShape.type;

  // Output side.
  _outputData = outputData;
  _outputShape = outputShape;
}

// Executes softmax on the configured buffers, dispatching on the input operand type.
// Throws std::runtime_error in three cases:
//   - the float kernel reports failure (e.g. an unsupported rank);
//   - the input is TENSOR_QUANT8_ASYMM, which is not yet enabled;
//   - the input is any other type.
// This way a failure never returns normally with the output buffer left unwritten.
void SoftMaxLayer::run()
{
  if (_inputType == OperandType::TENSOR_FLOAT32)
  {
    if (!softmaxFloat32())
    {
      throw std::runtime_error{"SoftMaxLayer : softmaxFloat32 failed"};
    }
  }
  else if (_inputType == OperandType::TENSOR_QUANT8_ASYMM)
  {
    throw std::runtime_error{"SoftMaxLayer : Not tested for TENSOR_QUANT8_ASYMM"};
    // softmaxQuant8();
  }
  else
  {
    throw std::runtime_error{"SoftMaxLayer : Unsupported input type"};
  }
}

} // namespace cpu
} // namespace kernel
} // namespace neurun