/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNFW_CKER_SOFTMAX_H__
#define __NNFW_CKER_SOFTMAX_H__

#include "cker/Shape.h"
#include "cker/Utils.h"
#include "cker/gemmlowp/FixedPoint.h"

#include <algorithm> // std::max, std::min
#include <cmath>     // std::exp
#include <limits>    // std::numeric_limits

namespace nnfw
{
namespace cker
{

struct SoftmaxParams
{
  // beta scales the logits before exponentiation. The float path below applies
  // it directly; the quantized path expects it to have been folded into
  // input_multiplier ahead of time. beta is not implemented for LogSoftmax.
  double beta;
  // uint8 inference params.  Used even when beta defaults to 1.0.
  int32_t input_multiplier;
  int32_t input_left_shift;
  // Reverse scaling is only used by LogSoftmax.
  int32_t reverse_scaling_divisor;
  int32_t reverse_scaling_right_shift;
  int diff_min;
};
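
// Note: for the uint8 path, input_multiplier / input_left_shift / diff_min are
// expected to be precomputed from the input scale and beta at kernel-prepare
// time (in TensorFlow Lite-derived code this is typically done with helpers
// such as PreprocessSoftmaxScaling and CalculateInputRadius; whether this
// codebase uses those exact helpers is an assumption).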

inline void Softmax(const SoftmaxParams &params, const Shape &input_shape, const float *input_data,
                    const Shape &output_shape, float *output_data)
{
  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i)
  {
    // Find max element value which we'll use to ensure numerical stability
    // taking advantage of the following equality:
    // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
    float max = std::numeric_limits<float>::lowest();
    for (int c = 0; c < depth; ++c)
    {
      max = std::max(max, input_data[i * depth + c]);
    }

    // Compute sum.
    float sum = 0.f;
    for (int c = 0; c < depth; ++c)
    {
      sum += std::exp((input_data[i * depth + c] - max) * params.beta);
    }

    // Compute result.
    for (int c = 0; c < depth; ++c)
    {
      output_data[i * depth + c] = std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
    }
  }
}
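
// A minimal usage sketch for the float path (assumes the initializer-list
// Shape constructor from cker/Shape.h; values are illustrative):
//
//   SoftmaxParams params;
//   params.beta = 1.0; // plain softmax
//   const Shape shape{1, 3}; // one row of three logits
//   const float input[3] = {1.f, 2.f, 3.f};
//   float output[3];
//   Softmax(params, shape, input, shape, output);
//   // output ~= {0.090, 0.245, 0.665}, summing to 1.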

inline void Softmax(const SoftmaxParams &params, const Shape &input_shape,
                    const uint8_t *input_data, const Shape &output_shape, uint8_t *output_data)
{
  const int32_t input_beta_multiplier = params.input_multiplier;
  const int32_t input_beta_left_shift = params.input_left_shift;
  const int diff_min = params.diff_min;
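  // Entries whose difference from the row max falls below diff_min contribute
  // a negligible exp(): they are left out of the sum and their output is
  // written as 0.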
  // The representation chosen for the input to the exp() function is Q5.26.
  // We need to leave extra space since values that we skip might be as large as
  // -32 before multiplying by input_beta_multiplier, and therefore as large as
  // -16 afterwards.  Note that exp(-8) is definitely not insignificant to
  // accumulation, but exp(-16) definitely is.
  static const int kScaledDiffIntegerBits = 5;
  static const int kAccumulationIntegerBits = 12;
  using FixedPointScaledDiff = gemmlowp::FixedPoint<kScaledDiffIntegerBits>;
  using FixedPointAccum = gemmlowp::FixedPoint<kAccumulationIntegerBits>;
  using FixedPoint0 = gemmlowp::FixedPoint<0>;
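  // gemmlowp::FixedPoint<N> here keeps 31 - N fractional bits in an int32, so
  // these types are Q5.26, Q12.19 and Q0.31 respectively.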

  const int trailing_dim = input_shape.DimensionsCount() - 1;
  const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
  const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);

  for (int i = 0; i < outer_size; ++i)
  {
    uint8_t max_in_row = 0;
    for (int c = 0; c < depth; ++c)
    {
      max_in_row = std::max(max_in_row, input_data[i * depth + c]);
    }

    FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
    for (int c = 0; c < depth; ++c)
    {
      int32_t input_diff = static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min)
      {
        const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne(
            input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);
        sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
                                        exp_on_negative_values(scaled_diff_f8));
      }
    }

    int32_t fixed_sum_of_exps = sum_of_exps.raw();
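    // The row max itself always passes the diff_min test (its diff is 0 and
    // diff_min is non-positive in practice), so the sum is at least
    // exp(0) = 1.0 and CountLeadingZeros is never called on 0.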
    int headroom_plus_one = CountLeadingZeros(static_cast<uint32_t>(fixed_sum_of_exps));
    // This is the number of bits to the left of the binary point above 1.0.
    // Consider fixed_sum_of_exps=1.25.  In that case shifted_scale=0.8 and
    // no later adjustment will be needed.
    int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
    int32_t shifted_sum_minus_one =
        static_cast<int32_t>((static_cast<uint32_t>(fixed_sum_of_exps) << headroom_plus_one) -
                             (static_cast<uint32_t>(1) << 31));

    FixedPoint0 shifted_scale =
        one_over_one_plus_x_for_x_in_0_1(FixedPoint0::FromRaw(shifted_sum_minus_one));
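
    // Worked example: fixed_sum_of_exps = 1.25 has raw value 1.25 * 2^19 in
    // Q12.19, which carries 12 leading zeros, so num_bits_over_unit = 0;
    // shifting left by 12 and subtracting 2^31 leaves 0.25 in Q0.31, and
    // 1 / (1 + 0.25) = 0.8 = shifted_scale, matching the comment above.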

    for (int c = 0; c < depth; ++c)
    {
      int32_t input_diff = static_cast<int32_t>(input_data[i * depth + c]) - max_in_row;
      if (input_diff >= diff_min)
      {
        const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne(
            input_diff, input_beta_multiplier, input_beta_left_shift);
        const FixedPointScaledDiff scaled_diff_f8 =
            FixedPointScaledDiff::FromRaw(input_diff_rescaled);

        FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
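        // (shifted_scale * exp_in_0) is exp(diff) / sum scaled up by
        // 2^num_bits_over_unit, held in Q0.31. Dividing its raw value by
        // 2^(num_bits_over_unit + 31 - 8) leaves the probability times 2^8,
        // so the result below is effectively quantized with scale 1/256.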
        int32_t unsat_output = gemmlowp::RoundingDivideByPOT((shifted_scale * exp_in_0).raw(),
                                                             num_bits_over_unit + 31 - 8);

        output_data[i * depth + c] = static_cast<uint8_t>(
            std::max(std::min(unsat_output, static_cast<int32_t>(255)), static_cast<int32_t>(0)));
      }
      else
      {
        output_data[i * depth + c] = 0;
      }
    }
  }
}

} // namespace cker
} // namespace nnfw

#endif // __NNFW_CKER_SOFTMAX_H__