Diffstat (limited to 'compute/cker/include/cker')
-rw-r--r-- compute/cker/include/cker/NeonTensorUtils.h | 8
-rw-r--r-- compute/cker/include/cker/PortableTensorUtils.h | 3
-rw-r--r-- compute/cker/include/cker/TensorUtils.h | 4
-rw-r--r-- compute/cker/include/cker/Types.h | 25
-rw-r--r-- compute/cker/include/cker/Utils.h | 62
-rw-r--r-- compute/cker/include/cker/operation/BatchToSpaceND.h | 133
-rw-r--r-- compute/cker/include/cker/operation/FullyConnected.h | 67
-rw-r--r-- compute/cker/include/cker/operation/Helper/PhiloxRandom.h | 276
-rw-r--r-- compute/cker/include/cker/operation/Helper/RandomDistributions.h | 778
-rw-r--r-- compute/cker/include/cker/operation/Helper/RandomOp.h | 52
-rw-r--r-- compute/cker/include/cker/operation/Helper/RandomOpCpu.h | 163
-rw-r--r-- compute/cker/include/cker/operation/L2Normalize.h | 94
-rw-r--r-- compute/cker/include/cker/operation/Logistic.h | 9
-rw-r--r-- compute/cker/include/cker/operation/MatrixBandPart.h | 6
-rw-r--r-- compute/cker/include/cker/operation/Pad.h | 15
-rw-r--r-- compute/cker/include/cker/operation/Quantize.h | 47
-rw-r--r-- compute/cker/include/cker/operation/ReLU6.h | 56
-rw-r--r-- compute/cker/include/cker/operation/Reduce.h | 86
-rw-r--r-- compute/cker/include/cker/operation/ResizeBilinear.h | 270
-rw-r--r-- compute/cker/include/cker/operation/SpaceToDepth.h | 71
-rw-r--r-- compute/cker/include/cker/operation/SplitV.h | 81
-rw-r--r-- compute/cker/include/cker/operation/StatelessRandomUniform.h | 103
-rw-r--r-- compute/cker/include/cker/ruy/RuySupport.h | 41
23 files changed, 2377 insertions, 73 deletions
diff --git a/compute/cker/include/cker/NeonTensorUtils.h b/compute/cker/include/cker/NeonTensorUtils.h
index 5c38bc6f3..246fd9a46 100644
--- a/compute/cker/include/cker/NeonTensorUtils.h
+++ b/compute/cker/include/cker/NeonTensorUtils.h
@@ -546,7 +546,7 @@ bool NeonIsZeroVector(const float *vector, int v_size)
void NeonCpuBackendGemm(const int8_t *input, const int32_t *bias,
const int8_t *input_to_gate_weights, int32_t n_batch, int32_t n_input,
- int32_t n_output, int32_t, int32_t *scratch)
+ int32_t n_output, int32_t, int32_t *scratch, ruy::Context *ruy_context)
{
MatrixParams<int8_t> lhs_params;
lhs_params.order = Order::kRowMajor;
@@ -571,8 +571,6 @@ void NeonCpuBackendGemm(const int8_t *input, const int32_t *bias,
}
// Below code is from tflite::cpu_backend_gemm::detail::GemmImplUsingRuy
- ruy::Context *ruy_context = ruy_support::GetRuyContext();
-
ruy::Matrix<int8_t> ruy_lhs;
ruy::Matrix<int8_t> ruy_rhs;
ruy::Matrix<int32_t> ruy_dst;
@@ -851,13 +849,13 @@ void NeonMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
const int m_cols, const int8_t *__restrict__ vectors,
const float *scaling_factors, int n_batch,
int32_t *scratch, float *__restrict__ result,
- int result_stride)
+ int result_stride, ruy::Context *ruy_context)
{
if (m_rows % 4 == 0 && result_stride == 1)
{
const int32_t *bias = static_cast<const int32_t *>(nullptr);
NeonCpuBackendGemm(vectors, bias, matrix, n_batch, m_cols, m_rows,
- /*output_zp =*/0, scratch);
+ /*output_zp =*/0, scratch, ruy_context);
// Multiply by float scaling factors and write to result
const int total_size = n_batch * m_rows;
diff --git a/compute/cker/include/cker/PortableTensorUtils.h b/compute/cker/include/cker/PortableTensorUtils.h
index 9769d4ba6..54714e214 100644
--- a/compute/cker/include/cker/PortableTensorUtils.h
+++ b/compute/cker/include/cker/PortableTensorUtils.h
@@ -20,6 +20,7 @@
#include "cker/Types.h"
#include "cker/neon/neon_check.h"
+#include <ruy/context.h>
#include <cstring>
#include <cmath>
@@ -142,7 +143,7 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matr
const int8_t *__restrict__ vector,
const float *scaling_factors, int n_batch,
int32_t *, float *__restrict__ result,
- int result_stride)
+ int result_stride, ruy::Context *)
{
PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector, scaling_factors,
n_batch, result, result_stride);
diff --git a/compute/cker/include/cker/TensorUtils.h b/compute/cker/include/cker/TensorUtils.h
index 6b23c0b30..e07c91239 100644
--- a/compute/cker/include/cker/TensorUtils.h
+++ b/compute/cker/include/cker/TensorUtils.h
@@ -73,10 +73,10 @@ void MatrixBatchVectorMultiplyAccumulate(const float *matrix, int m_rows, int m_
void MatrixBatchVectorMultiplyAccumulate(const int8_t *matrix, const int m_rows, const int m_cols,
const int8_t *vectors, const float *scaling_factors,
int n_batch, int32_t *scratch, float *result,
- int result_stride)
+ int result_stride, ruy::Context *ruy_context)
{
NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, vectors,
- scaling_factors, n_batch, scratch, result, result_stride);
+ scaling_factors, n_batch, scratch, result, result_stride, ruy_context);
}
void ZeroVector(float *vector, int v_size) { PortableZeroVector(vector, v_size); }
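The hunks above thread an explicit ruy::Context pointer from the caller down to NeonCpuBackendGemm, instead of fetching a context via ruy_support::GetRuyContext() inside the kernel. A minimal sketch of a call site under the new signature; the wrapper function and thread count are illustrative, not part of this change:

#include <ruy/context.h>
#include "cker/TensorUtils.h"

// Hypothetical call site: the backend owns one ruy::Context and passes it
// explicitly to every hybrid matmul call.
void RunHybridMatmul(const int8_t *filter, int rows, int cols, const int8_t *quant_input,
                     const float *scaling_factors, int n_batch, int32_t *scratch, float *output)
{
  ruy::Context ruy_context;           // owned and reused by the caller
  ruy_context.set_max_num_threads(4); // example thread setting
  nnfw::cker::MatrixBatchVectorMultiplyAccumulate(filter, rows, cols, quant_input,
                                                  scaling_factors, n_batch, scratch, output,
                                                  /*result_stride=*/1, &ruy_context);
}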
diff --git a/compute/cker/include/cker/Types.h b/compute/cker/include/cker/Types.h
index 41b1916cf..886ce5e5e 100644
--- a/compute/cker/include/cker/Types.h
+++ b/compute/cker/include/cker/Types.h
@@ -259,6 +259,12 @@ struct FullyConnectedParams
// FullyConnectedWeightsFormat weights_format;
};
+struct L2NormParams
+{
+ // uint8 inference params.
+ int32_t input_zero_point;
+};
+
struct GatherParams
{
int32_t axis;
@@ -271,6 +277,14 @@ struct InstanceNormParams
float float_activation_max;
};
+struct ResizeBilinearParams
+{
+ int32_t output_height;
+ int32_t output_width;
+ bool align_corners;
+ bool half_pixel_centers;
+};
+
struct TransposeConvParams
{
PaddingType padding_type;
@@ -325,6 +339,12 @@ struct SplitParams
int16_t axis;
};
+struct SplitVParams
+{
+ uint16_t num_split;
+ int16_t axis;
+};
+
struct FusedBatchNormParams
{
bool is_training;
@@ -338,6 +358,11 @@ struct SpaceToBatchParams
int32_t output_offset;
};
+struct SpaceToDepthParams
+{
+ int32_t block_size;
+};
+
enum class Order
{
kColMajor,
diff --git a/compute/cker/include/cker/Utils.h b/compute/cker/include/cker/Utils.h
index b69d55c26..2abb998d0 100644
--- a/compute/cker/include/cker/Utils.h
+++ b/compute/cker/include/cker/Utils.h
@@ -123,6 +123,68 @@ inline int CountLeadingZeros(uint32_t integer_input)
return leading_zeros;
}
+inline void GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
+ int32_t *output_inv_sqrt, int *output_shift)
+{
+ assert(input >= 0);
+ if (input <= 1)
+ {
+ // Handle the input value 1 separately to avoid overflow in that case
+ // in the general computation below (b/143972021). Also handle 0 as if it
+ // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
+ // but rare/unrealistic input value. We can expect both to occur in some
+ // incompletely trained models, but probably not in fully trained models.
+ *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
+ *output_shift = 0;
+ return;
+ }
+ assert(input > 1);
+ *output_shift = 11;
+ while (input >= (1 << 29))
+ {
+ input /= 4;
+ ++*output_shift;
+ }
+ const unsigned max_left_shift_bits = CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
+ const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
+ const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
+ *output_shift -= left_shift_bit_pairs;
+ input <<= 2 * left_shift_bit_pairs;
+ assert(input >= (1 << 27));
+ assert(input < (1 << 29));
+ using gemmlowp::FixedPoint;
+ using gemmlowp::Rescale;
+ using gemmlowp::SaturatingRoundingMultiplyByPOT;
+ // Using 3 integer bits gives us enough room for the internal arithmetic in
+ // this Newton-Raphson iteration.
+ using F3 = FixedPoint<int32_t, 3>;
+ using F0 = FixedPoint<int32_t, 0>;
+ const F3 fixedpoint_input = F3::FromRaw(input >> 1);
+ const F3 fixedpoint_half_input = SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
+ const F3 fixedpoint_half_three =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
+ // Newton-Raphson iteration
+ // Naive unoptimized starting guess: x = 1
+ F3 x = F3::One();
+ // Naive unoptimized number of iterations: 5
+ for (int i = 0; i < 5; i++)
+ {
+ const F3 x3 = Rescale<3>(x * x * x);
+ x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
+ }
+ const F0 fixedpoint_half_sqrt_2 =
+ GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
+ x = x * fixedpoint_half_sqrt_2;
+ *output_inv_sqrt = x.raw();
+ if (*output_shift < 0)
+ {
+ *output_inv_sqrt <<= -*output_shift;
+ *output_shift = 0;
+ }
+ // Convert right shift (right is positive) to left shift.
+ *output_shift *= reverse_shift;
+}
+
// Comment from tensorflow lite:
//
// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING
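GetInvSqrtQuantizedMultiplierExp approximates 1/sqrt(input) by Newton-Raphson in 32-bit fixed point and returns a Q0.31 multiplier plus a power-of-two exponent. A small sketch of how the pair recombines, assuming the usual TFLite convention of reverse_shift = -1 (as in the L2 normalization caller); the recombination formula is an interpretation of that convention, not something stated in this diff:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include "cker/Utils.h"

int main()
{
  const int32_t input = 1000000; // e.g. a sum of squares from L2 normalization
  int32_t inv_sqrt_multiplier = 0;
  int inv_sqrt_shift = 0;
  nnfw::cker::GetInvSqrtQuantizedMultiplierExp(input, /*reverse_shift=*/-1, &inv_sqrt_multiplier,
                                               &inv_sqrt_shift);

  // Assumed recombination: 1/sqrt(input) ~= multiplier / 2^31 * 2^shift
  const double approx =
      static_cast<double>(inv_sqrt_multiplier) / (1LL << 31) * std::pow(2.0, inv_sqrt_shift);
  std::printf("approx=%.8f exact=%.8f\n", approx, 1.0 / std::sqrt(static_cast<double>(input)));
  return 0;
}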
diff --git a/compute/cker/include/cker/operation/BatchToSpaceND.h b/compute/cker/include/cker/operation/BatchToSpaceND.h
new file mode 100644
index 000000000..e33b2fba5
--- /dev/null
+++ b/compute/cker/include/cker/operation/BatchToSpaceND.h
@@ -0,0 +1,133 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_BATCH_TO_SPACE_ND_H__
+#define __NNFW_CKER_BATCH_TO_SPACE_ND_H__
+
+#include "cker/Shape.h"
+
+#define UNUSED(x) ((void)(x))
+
+namespace nnfw
+{
+namespace cker
+{
+
+// Helper methods for BatchToSpaceND.
+// `spatial_index_dim` specifies post-crop offset index in this spatial
+// dimension, i.e. spatial offset introduced by flattening batch to spatial
+// dimension minus the crop size at beginning. `block_shape_dim` is the block
+// size in current dimension. `input_dim` and `output_dim` are input and output
+// size of BatchToSpaceND operation in current dimension.
+// Output start index is inclusive and end index is exclusive.
+inline void GetIndexRange(int spatial_index_dim, int block_shape_dim, int input_dim, int output_dim,
+ int *start_index, int *end_index)
+{
+ // (*start_index) * block_shape_dim is effectively rounded up to the next
+ // multiple of block_shape_dim by the integer division.
+ *start_index = std::max(0, (-spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
+ // Similarly, (*end_index) * block_shape_dim is rounded up too (note that
+ // end_index is exclusive).
+ *end_index =
+ std::min(input_dim, (output_dim - spatial_index_dim + block_shape_dim - 1) / block_shape_dim);
+}
+
+template <typename T>
+inline void BatchToSpaceND(const Shape &unextended_input1_shape, const T *input1_data,
+ const int32_t *block_shape_data, const int32_t *crops_data,
+ const Shape &unextended_output_shape, T *output_data)
+{
+ auto input_dim = unextended_input1_shape.DimensionsCount();
+ auto output_dim = unextended_output_shape.DimensionsCount();
+
+ assert(input_dim == 3 || input_dim == 4);
+ assert(input_dim == output_dim);
+
+ UNUSED(input_dim);
+ UNUSED(output_dim);
+
+ // Extends the input/output shape from 3D to 4D if needed, NHC -> NH1C.
+ auto extend_shape = [](const Shape &shape) {
+ if (shape.DimensionsCount() == 4)
+ {
+ return shape;
+ }
+ Shape new_shape(4, 1);
+ new_shape.SetDim(0, shape.Dims(0));
+ new_shape.SetDim(1, shape.Dims(1));
+ new_shape.SetDim(3, shape.Dims(2));
+ return new_shape;
+ };
+ const Shape input1_shape = extend_shape(unextended_input1_shape);
+ const Shape output_shape = extend_shape(unextended_output_shape);
+
+ const int32_t output_width = output_shape.Dims(2);
+ const int32_t output_height = output_shape.Dims(1);
+ const int32_t output_batch_size = output_shape.Dims(0);
+
+ const int32_t depth = input1_shape.Dims(3);
+ const int32_t input_width = input1_shape.Dims(2);
+ const int32_t input_height = input1_shape.Dims(1);
+ const int32_t input_batch_size = input1_shape.Dims(0);
+
+ const int32_t block_shape_height = block_shape_data[0];
+ const int32_t block_shape_width = block_shape_data[1];
+
+ const int32_t crops_top = crops_data[0];
+ const int32_t crops_left = crops_data[2];
+
+ for (int in_batch = 0; in_batch < input_batch_size; ++in_batch)
+ {
+ const int out_batch = in_batch % output_batch_size;
+ const int spatial_offset = in_batch / output_batch_size;
+
+ int in_h_start = 0;
+ int in_h_end = 0;
+ // GetIndexRange ensures start and end indices are in [0, output_height).
+ GetIndexRange(spatial_offset / block_shape_width - crops_top, block_shape_height, input_height,
+ output_height, &in_h_start, &in_h_end);
+
+ for (int in_h = in_h_start; in_h < in_h_end; ++in_h)
+ {
+ const int out_h = in_h * block_shape_height + spatial_offset / block_shape_width - crops_top;
+ assert(out_h >= 0);
+ assert(out_h < output_height);
+
+ int in_w_start = 0;
+ int in_w_end = 0;
+ // GetIndexRange ensures start and end indices are in [0, output_width).
+ GetIndexRange(spatial_offset % block_shape_width - crops_left, block_shape_width, input_width,
+ output_width, &in_w_start, &in_w_end);
+
+ for (int in_w = in_w_start; in_w < in_w_end; ++in_w)
+ {
+ const int out_w =
+ in_w * block_shape_width + spatial_offset % block_shape_width - crops_left;
+ assert(out_w >= 0);
+ assert(out_w < output_width);
+ T *out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T *in = input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
+ memcpy(out, in, depth * sizeof(T));
+ }
+ }
+ }
+}
+
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_BATCH_TO_SPACE_ND_H__
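GetIndexRange inverts the crop/block arithmetic per spatial dimension: it returns the half-open range of input positions whose expanded location in_h * block + offset - crop lands inside the output. A standalone check with made-up sizes (block 2, top crop 1, input height 3, output height 5):

#include <cstdio>
#include "cker/operation/BatchToSpaceND.h"

int main()
{
  int start = 0, end = 0;
  // spatial_index_dim = spatial_offset - crop = 0 - 1 = -1 (values chosen for illustration)
  nnfw::cker::GetIndexRange(/*spatial_index_dim=*/-1, /*block_shape_dim=*/2,
                            /*input_dim=*/3, /*output_dim=*/5, &start, &end);
  std::printf("in_h range: [%d, %d)\n", start, end); // prints [1, 3): rows 1 and 2 map to out_h 1 and 3
  return 0;
}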
diff --git a/compute/cker/include/cker/operation/FullyConnected.h b/compute/cker/include/cker/operation/FullyConnected.h
index 9bcf3fd82..4280c9ae2 100644
--- a/compute/cker/include/cker/operation/FullyConnected.h
+++ b/compute/cker/include/cker/operation/FullyConnected.h
@@ -18,6 +18,7 @@
#ifndef __NNFW_CKER_FULLY_CONNECTED_H__
#define __NNFW_CKER_FULLY_CONNECTED_H__
+#include <ruy/context.h>
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"
@@ -78,8 +79,11 @@ inline void FullyConnected(const FullyConnectedParams &params, const Shape &inpu
MatrixBatchVectorMultiplyAccumulate(weights_data, num_units, input_size, input_data, batch_size,
output_data, /*result_stride=*/1);
- // Apply activation function
- ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+ if (params.activation != FusedActivationFunctionType::kNone)
+ {
+ // Apply activation function
+ ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+ }
}
inline void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape,
@@ -140,7 +144,7 @@ inline void FullyConnectedHybrid(const FullyConnectedParams &params, const Shape
const float *input_data, const Shape &filter_shape,
const int8_t *filter_data, const Shape &, const float *bias_data,
const Shape &output_shape, float *output_data,
- FCTempArena &temp_arena)
+ FCTempArena &temp_arena, ruy::Context *ruy_context)
{
int total_input_size = input_shape.FlatSize();
const int input_size = filter_shape.Dims(1);
@@ -186,19 +190,72 @@ inline void FullyConnectedHybrid(const FullyConnectedParams &params, const Shape
int32_t *scratch = temp_arena.accum_scratch.data();
MatrixBatchVectorMultiplyAccumulate(filter_data, num_units, input_size, quant_data,
scaling_factors_ptr, batch_size, scratch, output_data,
- /*result_stride=*/1);
+ /*result_stride=*/1, ruy_context);
#else
MatrixBatchVectorMultiplyAccumulate(filter_data, num_units, input_size, quant_data,
scaling_factors_ptr, batch_size, output_data,
/*result_stride=*/1);
+ UNUSED_RELEASE(ruy_context);
UNUSED_RELEASE(output_shape);
#endif
// Apply activation function to floats.
- ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+ if (params.activation != FusedActivationFunctionType::kNone)
+ {
+ // Apply activation function
+ ApplyActivationToVector(output_data, batch_size * num_units, params.activation, output_data);
+ }
return;
}
+inline void FullyConnectedSparseWeight(const FullyConnectedParams &params, const Shape &input_shape,
+ const float *input_data, const Shape &weights_shape,
+ const float *weights_data, const Shape &bias_shape,
+ const float *bias_data, const Shape &output_shape,
+ float *output_data, int w0_size, const uint16_t *w1_segments,
+ const uint16_t *w1_indices)
+{
+ UNUSED_RELEASE(params);
+ UNUSED_RELEASE(input_shape);
+
+ assert(weights_shape.DimensionsCount() == 2);
+ assert(output_shape.DimensionsCount() == 2);
+
+ const int output_dims_count = output_shape.DimensionsCount();
+ const int weights_dims_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+ const int output_depth =
+ MatchingDim(weights_shape, weights_dims_count - 2, output_shape, output_dims_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+
+ UNUSED_RELEASE(bias_shape);
+ if (bias_data)
+ {
+ VectorBatchVectorAssign(bias_data, output_depth, batches, output_data);
+ }
+ else
+ {
+ ZeroVector(output_data, batches * output_depth);
+ }
+ for (int b = 0; b < batches; ++b)
+ {
+ for (int idx_0 = 0; idx_0 < w0_size; ++idx_0)
+ {
+ for (int pw1 = w1_segments[idx_0]; pw1 < w1_segments[idx_0 + 1]; ++pw1)
+ {
+ int idx_1 = w1_indices[pw1];
+ output_data[b * output_depth + idx_0] +=
+ weights_data[pw1] * input_data[b * accum_depth + idx_1];
+ }
+ }
+ }
+ if (params.activation != FusedActivationFunctionType::kNone)
+ {
+ // Apply activation function
+ ApplyActivationToVector(output_data, batches * output_depth, params.activation, output_data);
+ }
+}
+
} // namespace cker
} // namespace nnfw
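FullyConnectedSparseWeight walks a CSR-style layout: w1_segments[i]..w1_segments[i+1] delimits the non-zero entries of output row i, and w1_indices stores the input column of each kept value. A tiny hand-rolled version of that inner loop, with invented data, just to show what the two index arrays mean:

#include <cstdint>
#include <cstdio>

int main()
{
  // 2x4 weight matrix: row 0 keeps columns {0, 3}, row 1 keeps columns {1, 2}.
  const int w0_size = 2;                      // output rows
  const uint16_t w1_segments[] = {0, 2, 4};   // row i owns non-zeros [seg[i], seg[i+1])
  const uint16_t w1_indices[] = {0, 3, 1, 2}; // input column of each non-zero
  const float weights_data[] = {0.5f, -1.0f, 2.0f, 0.25f};
  const float input[] = {1.f, 2.f, 3.f, 4.f}; // one batch, accum_depth = 4

  float output[2] = {0.f, 0.f}; // bias omitted for brevity
  for (int idx_0 = 0; idx_0 < w0_size; ++idx_0)
    for (int pw1 = w1_segments[idx_0]; pw1 < w1_segments[idx_0 + 1]; ++pw1)
      output[idx_0] += weights_data[pw1] * input[w1_indices[pw1]];

  std::printf("%f %f\n", output[0], output[1]); // -3.5 and 4.75
  return 0;
}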
diff --git a/compute/cker/include/cker/operation/Helper/PhiloxRandom.h b/compute/cker/include/cker/operation/Helper/PhiloxRandom.h
new file mode 100644
index 000000000..8e8879ce9
--- /dev/null
+++ b/compute/cker/include/cker/operation/Helper/PhiloxRandom.h
@@ -0,0 +1,276 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
+
+#include <stdlib.h>
+
+#include "cker/Types.h"
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+
+// Function qualifiers that need to work on both CPU and GPU.
+#if defined(__CUDACC__) || defined(__HIPCC__)
+// For nvcc.
+#define PHILOX_DEVICE_FUNC __host__ __device__
+#define PHILOX_INLINE __inline__
+#else
+// For non-nvcc.
+#define PHILOX_DEVICE_FUNC
+#define PHILOX_INLINE inline
+#endif
+#define PHILOX_DEVICE_INLINE PHILOX_DEVICE_FUNC PHILOX_INLINE
+
+#include <math.h>
+
+namespace nnfw
+{
+namespace cker
+{
+namespace random
+{
+
+// A class that represents an inline array. It can be used on both CPU and GPU,
+// and also trivially copyable between CPU and GPU.
+// Arguments:
+// T: the array element type;
+// ElementCount: the fixed size of the array;
+template <typename T, int ElementCount> class Array
+{
+public:
+ static constexpr int kElementCount = ElementCount;
+ PHILOX_DEVICE_INLINE Array()
+ {
+ for (int i = 0; i < ElementCount; ++i)
+ {
+ data_[i] = T(0);
+ }
+ }
+
+ PHILOX_DEVICE_INLINE const T &operator[](int index) const { return data_[index]; }
+
+ PHILOX_DEVICE_INLINE T &operator[](int index) { return data_[index]; }
+
+ size_t size() const { return ElementCount; }
+
+private:
+ T data_[ElementCount];
+};
+
+// A class that encapsulates all the states for a random number generator using
+// the philox_4x32_10 algorithm. Each invocation returns 128 random bits in the
+// form of four uint32 values.
+// There are multiple variants of this algorithm; we picked the 4x32_10 version
+// that is most suited for our applications.
+// Since this class is meant to be copied between CPU and GPU, it maintains
+// value semantics.
+//
+// For example: To use this class and populate an array of 1024 randoms on CPU
+// with two threads,
+//
+// void Fill(PhiloxRandom rnd, uint32* output, int start, int limit) {
+// assert(start % 4 == 0);
+// assert(limit % 4 == 0);
+// rnd.Skip(start / 4);
+// for (int i = start; i < limit; i += 4) {
+// auto sample = rnd();
+// ... copy sample[0..3] to output[i..i+3]
+// }
+// }
+//
+// PhiloxRandom rng(seed);
+// PhiloxRandom rng_copy = rng;
+// rng.Skip(1000/4);
+//
+// ... schedule Fill(rng_copy, output, 0, 512) in thread 1;
+// ... schedule Fill(rng_copy, output, 512, 1024) in thread 2;
+// ... wait for thread 1 & 2 to finish executing Fill().
+//
+// NOTE:
+// 1. PhiloxRandom is trivially copyable.
+// 2. PhiloxRandom is compilable by gcc and nvcc.
+class PhiloxRandom
+{
+public:
+ using ResultType = Array<uint32_t, 4>;
+ using ResultElementType = uint32_t;
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = 4;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 10;
+ // The type for the 64-bit key stored in the form of two 32-bit uint
+ // that are used in the diffusion process.
+ using Key = Array<uint32_t, 2>;
+
+ PHILOX_DEVICE_INLINE
+ PhiloxRandom() {}
+
+ PHILOX_DEVICE_INLINE
+ explicit PhiloxRandom(uint64_t seed)
+ {
+ key_[0] = static_cast<uint32_t>(seed);
+ key_[1] = static_cast<uint32_t>(seed >> 32);
+ }
+
+ PHILOX_DEVICE_INLINE
+ explicit PhiloxRandom(uint64_t seed_lo, uint64_t seed_hi)
+ {
+ key_[0] = static_cast<uint32_t>(seed_lo);
+ key_[1] = static_cast<uint32_t>(seed_lo >> 32);
+ counter_[2] = static_cast<uint32_t>(seed_hi);
+ counter_[3] = static_cast<uint32_t>(seed_hi >> 32);
+ }
+
+ PHILOX_DEVICE_INLINE
+ PhiloxRandom(ResultType counter, Key key) : counter_(counter), key_(key) {}
+
+ PHILOX_DEVICE_INLINE
+ ResultType const &counter() const { return counter_; }
+
+ PHILOX_DEVICE_INLINE
+ Key const &key() const { return key_; }
+
+ // Skip the specified number of samples of 128-bits in the current stream.
+ PHILOX_DEVICE_INLINE
+ void Skip(uint64_t count)
+ {
+ const uint32_t count_lo = static_cast<uint32_t>(count);
+ uint32_t count_hi = static_cast<uint32_t>(count >> 32);
+
+ counter_[0] += count_lo;
+ if (counter_[0] < count_lo)
+ {
+ ++count_hi;
+ }
+
+ counter_[1] += count_hi;
+ if (counter_[1] < count_hi)
+ {
+ if (++counter_[2] == 0)
+ {
+ ++counter_[3];
+ }
+ }
+ }
+
+ // Returns a group of four random numbers using the underlying Philox
+ // algorithm.
+ PHILOX_DEVICE_INLINE ResultType operator()()
+ {
+ ResultType counter = counter_;
+ Key key = key_;
+
+ // Run the single rounds for ten times. Manually unrolling the loop
+ // for better performance.
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+ RaiseKey(&key);
+ counter = ComputeSingleRound(counter, key);
+
+ SkipOne();
+
+ return counter;
+ }
+
+private:
+ // We use the same constants as recommended by the original paper.
+ static constexpr uint32_t kPhiloxW32A = 0x9E3779B9;
+ static constexpr uint32_t kPhiloxW32B = 0xBB67AE85;
+ static constexpr uint32_t kPhiloxM4x32A = 0xD2511F53;
+ static constexpr uint32_t kPhiloxM4x32B = 0xCD9E8D57;
+
+ // Helper function to skip the next sample of 128-bits in the current stream.
+ PHILOX_DEVICE_INLINE void SkipOne()
+ {
+ if (++counter_[0] == 0)
+ {
+ if (++counter_[1] == 0)
+ {
+ if (++counter_[2] == 0)
+ {
+ ++counter_[3];
+ }
+ }
+ }
+ }
+
+ // Helper function to return the lower and higher 32-bits from two 32-bit
+ // integer multiplications.
+ PHILOX_DEVICE_INLINE
+ static void MultiplyHighLow(uint32_t a, uint32_t b, uint32_t *result_low, uint32_t *result_high)
+ {
+#ifndef __CUDA_ARCH__
+ const uint64_t product = static_cast<uint64_t>(a) * b;
+ *result_low = static_cast<uint32_t>(product);
+ *result_high = static_cast<uint32_t>(product >> 32);
+#else
+ *result_low = a * b;
+ *result_high = __umulhi(a, b);
+#endif
+ }
+
+ // Helper function for a single round of the underlying Philox algorithm.
+ PHILOX_DEVICE_INLINE static ResultType ComputeSingleRound(const ResultType &counter,
+ const Key &key)
+ {
+ uint32_t lo0;
+ uint32_t hi0;
+ MultiplyHighLow(kPhiloxM4x32A, counter[0], &lo0, &hi0);
+
+ uint32_t lo1;
+ uint32_t hi1;
+ MultiplyHighLow(kPhiloxM4x32B, counter[2], &lo1, &hi1);
+
+ ResultType result;
+ result[0] = hi1 ^ counter[1] ^ key[0];
+ result[1] = lo1;
+ result[2] = hi0 ^ counter[3] ^ key[1];
+ result[3] = lo0;
+ return result;
+ }
+
+ PHILOX_DEVICE_INLINE void RaiseKey(Key *key)
+ {
+ (*key)[0] += kPhiloxW32A;
+ (*key)[1] += kPhiloxW32B;
+ }
+
+private:
+ ResultType counter_;
+ Key key_;
+};
+
+} // namespace random
+} // namespace cker
+} // namespace nnfw
+#endif // TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
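A minimal sketch of using the generator on its own: seed it, draw 128-bit blocks of four uint32 values, and use Skip() to jump ahead in the stream without producing the intermediate blocks.

#include <cstdio>
#include "cker/operation/Helper/PhiloxRandom.h"

int main()
{
  nnfw::cker::random::PhiloxRandom rng(/*seed=*/42);

  auto block0 = rng(); // four uint32_t samples
  rng.Skip(10);        // skip the next ten 128-bit blocks
  auto block1 = rng();

  std::printf("%u %u %u %u\n", block0[0], block0[1], block0[2], block0[3]);
  std::printf("%u %u %u %u\n", block1[0], block1[1], block1[2], block1[3]);
  return 0;
}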
diff --git a/compute/cker/include/cker/operation/Helper/RandomDistributions.h b/compute/cker/include/cker/operation/Helper/RandomDistributions.h
new file mode 100644
index 000000000..baeafd7c9
--- /dev/null
+++ b/compute/cker/include/cker/operation/Helper/RandomDistributions.h
@@ -0,0 +1,778 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_HELPER_RANDOM_DISTRIBUTIONS_H__
+#define __NNFW_CKER_HELPER_RANDOM_DISTRIBUTIONS_H__
+
+#include <string.h>
+
+#include <cmath>
+
+#include <algorithm>
+#include <type_traits>
+
+#include "cker/Types.h"
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+
+#include "cker/eigen/EigenSupport.h"
+#include "cker/operation/Helper/PhiloxRandom.h"
+
+namespace nnfw
+{
+namespace cker
+{
+namespace random
+{
+
+// Helper function to convert a 16-bit integer to a half between [0..1).
+PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16_t x);
+// Helper function to convert a 16-bit integer to a bfloat16 between [0..1).
+// PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x);
+// Helper function to convert a 32-bit integer to a float between [0..1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32_t x);
+// Helper function to convert two 32-bit integers to a double between [0..1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32_t x0, uint32_t x1);
+
+// Computes a + b. Requires that the result is representable in the destination
+// type and that b is not maximal (i.e. b + 1 is not 0). Notably, the addend b
+// need *not* be representable in that type. (The condition on b excludes the
+// extremal case INT_MIN + UINT_MAX = INT_MAX, which this function cannot
+// compute.)
+template <typename Int>
+PHILOX_DEVICE_INLINE Int SignedAdd(Int a, typename std::make_unsigned<Int>::type b)
+{
+ // Implementation note: both b_div_2 and b - b_div_2 are positive and
+ // representable as Int.
+ auto b_div_2 = b >> 1;
+ return a + static_cast<Int>(b_div_2) + static_cast<Int>(b - b_div_2);
+}
+
+// A class that generates uniform distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+// Generator: a generator type that returns a number of uint32 upon each
+// invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for the
+// actual returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType> class UniformDistribution;
+
+template <class Generator> class UniformDistribution<Generator, Eigen::half>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<Eigen::half, kResultElementCount> ResultType;
+ typedef Eigen::half ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ result[i] = Uint16ToHalf(sample[i]); // Truncate the upper 16 bits.
+ }
+ return result;
+ }
+};
+
+template <class Generator> class UniformDistribution<Generator, float>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ result[i] = Uint32ToFloat(sample[i]);
+ }
+ return result;
+ }
+};
+
+template <class Generator> class UniformDistribution<Generator, double>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ result[i] = Uint64ToDouble(sample[2 * i], sample[2 * i + 1]);
+ }
+ return result;
+ }
+};
+
+template <class Generator> class UniformDistribution<Generator, int32_t>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<int32_t, kResultElementCount> ResultType;
+ typedef int32_t ResultElementType;
+
+ // Must have lo < hi
+ UniformDistribution(int32_t lo, int32_t hi)
+ : lo_(lo), range_(static_cast<uint32_t>(hi) - static_cast<uint32_t>(lo))
+ {
+ }
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ result[i] = SignedAdd(lo_, sample[i] % range_);
+ }
+ return result;
+ }
+
+private:
+ // Note that lo_ is intentionally signed while range_ is intentionally
+ // unsigned. This is because hi - lo can overflow signed integers if
+ // lo < 0 < hi, but always fits in unsigned.
+ int32_t lo_;
+  uint32_t range_;
+};
+
+template <class Generator> class UniformDistribution<Generator, int64_t>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<int64_t, kResultElementCount> ResultType;
+ typedef int64_t ResultElementType;
+
+ // Must have lo < hi
+ UniformDistribution(int64_t lo, int64_t hi)
+ : lo_(lo), range_(static_cast<uint64_t>(hi) - static_cast<uint64_t>(lo))
+ {
+ }
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ auto bits = sample[2 * i] | static_cast<uint64_t>(sample[2 * i + 1]) << 32;
+ result[i] = SignedAdd(lo_, bits % range_);
+ }
+ return result;
+ }
+
+private:
+ // Note that lo_ is intentionally signed while range_ is intentionally
+ // unsigned. This is because hi - lo can overflow signed integers if
+ // lo < 0 < hi, but always fits in unsigned.
+ int64_t lo_;
+ uint64_t range_;
+};
+
+// Similar to `UniformDistribution`, except that instead of generating numbers
+// in the range [low, high), it generates numbers covering the whole range of
+// the integer type.
+template <typename Generator, typename IntType> class UniformFullIntDistribution;
+
+template <typename Generator, typename IntType> class UniformFullIntDistribution32
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<IntType, kResultElementCount> ResultType;
+ typedef IntType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ result[i] = sample[i];
+ }
+ return result;
+ }
+};
+
+template <typename Generator, typename IntType> class UniformFullIntDistribution64
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 3;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<IntType, kResultElementCount> ResultType;
+ typedef IntType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; ++i)
+ {
+ result[i] = sample[2 * i] | static_cast<uint64_t>(sample[2 * i + 1]) << 32;
+ }
+ return result;
+ }
+};
+
+template <typename Generator>
+class UniformFullIntDistribution<Generator, int32_t>
+ : public UniformFullIntDistribution32<Generator, int32_t>
+{
+};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, uint32_t>
+ : public UniformFullIntDistribution32<Generator, uint32_t>
+{
+};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, int64_t>
+ : public UniformFullIntDistribution64<Generator, int64_t>
+{
+};
+template <typename Generator>
+class UniformFullIntDistribution<Generator, uint64_t>
+ : public UniformFullIntDistribution64<Generator, uint64_t>
+{
+};
+
+// A class that adapts the underlying native multiple samples to return a single
+// sample at a time.
+template <class Generator> class SingleSampleAdapter
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = 1;
+ // The number of elements that will be returned by the underlying generator.
+ static constexpr int kNativeElementCount = Generator::kResultElementCount;
+ typedef typename Generator::ResultElementType ResultType;
+ typedef typename Generator::ResultElementType ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ explicit SingleSampleAdapter(Generator *gen)
+ : generator_(gen), used_result_index_(Generator::kResultElementCount)
+ {
+ }
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()()
+ {
+ if (used_result_index_ == Generator::kResultElementCount)
+ {
+ unused_results_ = (*generator_)();
+ used_result_index_ = 0;
+ }
+
+ return unused_results_[used_result_index_++];
+ }
+
+ PHILOX_DEVICE_INLINE
+ void Skip(uint64_t num_skips)
+ {
+ if (!num_skips)
+ {
+ return;
+ }
+ int num_unused_results = kNativeElementCount - used_result_index_;
+ if (num_skips <= num_unused_results)
+ {
+ used_result_index_ += num_skips;
+ return;
+ }
+ num_skips -= num_unused_results;
+ used_result_index_ = kNativeElementCount;
+ SkipFromGenerator(num_skips / kNativeElementCount);
+ num_skips = num_skips % kNativeElementCount;
+ if (num_skips)
+ {
+ unused_results_ = (*generator_)();
+ used_result_index_ = num_skips;
+ }
+ }
+
+private:
+ // This implementation iteratively skips over `num_skips` samples
+ // from `generator_`. There is an O(1) implementation for PhiloxRandom
+ // in random_distributions.cc.
+ PHILOX_DEVICE_INLINE
+ void SkipFromGenerator(uint64_t num_skips)
+ {
+ while (num_skips--)
+ {
+ (*generator_)();
+ }
+ }
+
+ Generator *generator_;
+ typename Generator::ResultType unused_results_;
+ int used_result_index_;
+};
+
+// A class that generates unit normal distribution random numbers from the
+// underlying random integer generator.
+// Arguments:
+// Generator: a generator type that returns a number of uint32 upon each
+// invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for actual
+// returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class Generator, typename RealType> class NormalDistribution;
+
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32_t x0, uint32_t x1, float *f0, float *f1);
+
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32_t x0, uint32_t x1, uint32_t x2, uint32_t x3, double *d0, double *d1);
+
+// Exactly like the float version, except that we convert to half afterwards;
+// since we don't have half-precision sin/cos even on GPUs, there's nothing to
+// gain from working in half internally.
+template <class Generator> class NormalDistribution<Generator, Eigen::half>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 70;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<Eigen::half, kResultElementCount> ResultType;
+ typedef Eigen::half ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2)
+ {
+ float f[2];
+ BoxMullerFloat(sample[i], sample[i + 1], &f[0], &f[1]);
+ result[i] = Eigen::half(f[0]);
+ result[i + 1] = Eigen::half(f[1]);
+ }
+ return result;
+ }
+};
+
+template <class Generator> class NormalDistribution<Generator, float>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 70;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2)
+ {
+ BoxMullerFloat(sample[i], sample[i + 1], &result[i], &result[i + 1]);
+ }
+ return result;
+ }
+};
+
+template <class Generator> class NormalDistribution<Generator, double>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = Generator::kResultElementCount / 2;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 70;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = false;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(Generator *gen)
+ {
+ typename Generator::ResultType sample = (*gen)();
+ ResultType result;
+ for (int i = 0; i < kResultElementCount; i += 2)
+ {
+ const int i2 = 2 * i;
+ BoxMullerDouble(sample[i2], sample[i2 + 1], sample[i2 + 2], sample[i2 + 3], &result[i],
+ &result[i + 1]);
+ }
+ return result;
+ }
+};
+
+// A class that returns standard normal distribution between
+// [-kTruncateValue, kTruncateValue].
+// Arguments:
+// Generator: a generator type that returns a number of uint32 upon each
+// invocation. It needs to define kResultElementCount for the
+// sample count for each invocation, and ResultType for actual
+// returned sample type.
+// RealType: the data type of the real numbers that will be returned by the
+// distribution. This could be either float or double for now.
+// This class is meant to be implemented through specialization. The default
+// is not defined by design.
+template <class SingleSampleGenerator, typename RealType> class TruncatedNormalDistribution;
+
+// Exactly like the float version, except that we convert to half afterwards;
+// since we don't have half-precision sin/cos even on GPUs, there's nothing to
+// gain from working in half internally.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, Eigen::half>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = SingleSampleGenerator::kNativeElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 90;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = true;
+ // The threshold where the normal distribution is truncated.
+ const float kTruncateValue = 2.0f;
+
+ typedef Array<Eigen::half, kResultElementCount> ResultType;
+ typedef Eigen::half ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator *gen)
+ {
+ ResultType results;
+ int index = 0;
+ while (true)
+ {
+ // Repeatedly take samples from the normal distribution, until we have
+ // the desired number of elements that fall within the pre-defined cutoff
+ // threshold.
+ const uint32_t x0 = (*gen)();
+ const uint32_t x1 = (*gen)();
+ float f[2];
+ BoxMullerFloat(x0, x1, &f[0], &f[1]);
+
+ if (Eigen::numext::abs(f[0]) < kTruncateValue)
+ {
+ results[index++] = Eigen::half(f[0]);
+ if (index >= kResultElementCount)
+ {
+ return results;
+ }
+ }
+ if (Eigen::numext::abs(f[1]) < kTruncateValue)
+ {
+ results[index++] = Eigen::half(f[1]);
+ if (index >= kResultElementCount)
+ {
+ return results;
+ }
+ }
+ }
+ }
+};
+
+// Partial specialization for float.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, float>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = SingleSampleGenerator::kNativeElementCount;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 90;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = true;
+ // The threshold where the normal distribution is truncated.
+ const float kTruncateValue = 2.0f;
+
+ typedef Array<float, kResultElementCount> ResultType;
+ typedef float ResultElementType;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator *gen)
+ {
+ ResultType results;
+ int index = 0;
+ while (true)
+ {
+ // Repeatedly take samples from the normal distribution, until we have
+ // the desired number of elements that fall within the pre-defined cutoff
+ // threshold.
+ const uint32_t x0 = (*gen)();
+ const uint32_t x1 = (*gen)();
+ float f[2];
+ BoxMullerFloat(x0, x1, &f[0], &f[1]);
+
+ if (Eigen::numext::abs(f[0]) < kTruncateValue)
+ {
+ results[index++] = f[0];
+ if (index >= kResultElementCount)
+ {
+ return results;
+ }
+ }
+ if (Eigen::numext::abs(f[1]) < kTruncateValue)
+ {
+ results[index++] = f[1];
+ if (index >= kResultElementCount)
+ {
+ return results;
+ }
+ }
+ }
+ }
+};
+
+// Partial specialization for double.
+template <class SingleSampleGenerator>
+class TruncatedNormalDistribution<SingleSampleGenerator, double>
+{
+public:
+ // The number of elements that will be returned.
+ static constexpr int kResultElementCount = (SingleSampleGenerator::kNativeElementCount > 1)
+ ? SingleSampleGenerator::kNativeElementCount / 2
+ : 1;
+ // Cost of generation of a single element (in cycles).
+ static constexpr int kElementCost = 90;
+ // Indicate that this distribution may take variable number of samples
+ // during the runtime.
+ static constexpr bool kVariableSamplesPerOutput = true;
+ typedef Array<double, kResultElementCount> ResultType;
+ typedef double ResultElementType;
+ const double kTruncateValue = 2.0;
+
+ PHILOX_DEVICE_INLINE
+ ResultType operator()(SingleSampleGenerator *gen)
+ {
+ ResultType results;
+ int index = 0;
+ while (1)
+ {
+ const uint32_t x0 = (*gen)();
+ const uint32_t x1 = (*gen)();
+ const uint32_t x2 = (*gen)();
+ const uint32_t x3 = (*gen)();
+ double d[2];
+ BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]);
+
+ if (Eigen::numext::abs(d[0]) < kTruncateValue)
+ {
+ results[index++] = d[0];
+ if (index >= kResultElementCount)
+ {
+ return results;
+ }
+ }
+ if (Eigen::numext::abs(d[1]) < kTruncateValue)
+ {
+ results[index++] = d[1];
+ if (index >= kResultElementCount)
+ {
+ return results;
+ }
+ }
+ }
+ }
+};
+
+// Helper function to convert two 32-bit uniform integers to two floats
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerFloat(uint32_t x0, uint32_t x1, float *f0, float *f1)
+{
+ // This function implements the Box-Muller transform:
+ // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+ // Do not send a really small number to log().
+ // We cannot mark "epsilon" as "static const" because NVCC would complain
+ const float epsilon = 1.0e-7f;
+ float u1 = Uint32ToFloat(x0);
+ if (u1 < epsilon)
+ {
+ u1 = epsilon;
+ }
+ const float v1 = 2.0f * M_PI * Uint32ToFloat(x1);
+ const float u2 = Eigen::numext::sqrt(-2.0f * Eigen::numext::log(u1));
+#if defined(TENSORFLOW_USE_SYCL) || !defined(__linux__)
+ *f0 = Eigen::numext::sin(v1);
+ *f1 = Eigen::numext::cos(v1);
+#else
+ sincosf(v1, f0, f1);
+#endif
+ *f0 *= u2;
+ *f1 *= u2;
+}
+
+// Helper function to convert four 32-bit uniform integers to two doubles
+// under the unit normal distribution.
+PHILOX_DEVICE_INLINE
+void BoxMullerDouble(uint32_t x0, uint32_t x1, uint32_t x2, uint32_t x3, double *d0, double *d1)
+{
+ // This function implements the Box-Muller transform:
+ // http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
+ // Do not send a really small number to log().
+ // We cannot mark "epsilon" as "static const" because NVCC would complain
+ const double epsilon = 1.0e-7;
+ double u1 = Uint64ToDouble(x0, x1);
+ if (u1 < epsilon)
+ {
+ u1 = epsilon;
+ }
+ const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3);
+ const double u2 = Eigen::numext::sqrt(-2.0 * Eigen::numext::log(u1));
+#if defined(TENSORFLOW_USE_SYCL) || !defined(__linux__)
+ *d0 = Eigen::numext::sin(v1);
+ *d1 = Eigen::numext::cos(v1);
+#else
+ sincos(v1, d0, d1);
+#endif
+ *d0 *= u2;
+ *d1 *= u2;
+}
+
+// Helper function to convert a 16-bit integer to a half between [0..1).
+PHILOX_DEVICE_INLINE Eigen::half Uint16ToHalf(uint16_t x)
+{
+ // IEEE754 halfs are formatted as follows (MSB first):
+ // sign(1) exponent(5) mantissa(10)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 15 -- an excess 15 representation of a zero exponent
+ // mantissa == 10 random bits
+ const uint16_t man = x & 0x3ffu; // 10 bit mantissa
+ const uint16_t exp = static_cast<uint16_t>(15);
+ const uint16_t val = (exp << 10) | man;
+
+ Eigen::half result;
+ result.x = val;
+ return result - Eigen::half(1.0);
+}
+
+// Helper function to convert a 32-bit integer to a float between [0..1).
+PHILOX_DEVICE_INLINE float Uint32ToFloat(uint32_t x)
+{
+ // IEEE754 floats are formatted as follows (MSB first):
+ // sign(1) exponent(8) mantissa(23)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 127 -- an excess 127 representation of a zero exponent
+ // mantissa == 23 random bits
+ const uint32_t man = x & 0x7fffffu; // 23 bit mantissa
+ const uint32_t exp = static_cast<uint32_t>(127);
+ const uint32_t val = (exp << 23) | man;
+
+ // Assumes that endian-ness is same for float and uint32.
+ float result;
+ memcpy(&result, &val, sizeof(val));
+ return result - 1.0f;
+}
+
+// Helper function to convert two 32-bit integers to a double between [0..1).
+PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32_t x0, uint32_t x1)
+{
+ // IEEE754 doubles are formatted as follows (MSB first):
+ // sign(1) exponent(11) mantissa(52)
+ // Conceptually construct the following:
+ // sign == 0
+ // exponent == 1023 -- an excess 1023 representation of a zero exponent
+ // mantissa == 52 random bits
+ const uint32_t mhi = x0 & 0xfffffu; // upper 20 bits of mantissa
+ const uint32_t mlo = x1; // lower 32 bits of mantissa
+ const uint64_t man = (static_cast<uint64_t>(mhi) << 32) | mlo; // mantissa
+ const uint64_t exp = static_cast<uint64_t>(1023);
+ const uint64_t val = (exp << 52) | man;
+ // Assumes that endian-ness is same for double and uint64.
+ double result;
+ memcpy(&result, &val, sizeof(val));
+ return result - 1.0;
+}
+
+} // namespace random
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_HELPER_RANDOM_DISTRIBUTIONS_H__
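A short sketch of pairing the generator with one of the distributions above: each call to the distribution consumes one Philox invocation and yields kResultElementCount floats in [0, 1).

#include <cstdio>
#include "cker/operation/Helper/PhiloxRandom.h"
#include "cker/operation/Helper/RandomDistributions.h"

int main()
{
  using Dist = nnfw::cker::random::UniformDistribution<nnfw::cker::random::PhiloxRandom, float>;

  nnfw::cker::random::PhiloxRandom gen(/*seed=*/1234);
  Dist dist;
  auto samples = dist(&gen); // Dist::kResultElementCount floats in [0, 1)
  for (int i = 0; i < Dist::kResultElementCount; ++i)
  {
    std::printf("%f\n", samples[i]);
  }
  return 0;
}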
diff --git a/compute/cker/include/cker/operation/Helper/RandomOp.h b/compute/cker/include/cker/operation/Helper/RandomOp.h
new file mode 100644
index 000000000..7dc51fe94
--- /dev/null
+++ b/compute/cker/include/cker/operation/Helper/RandomOp.h
@@ -0,0 +1,52 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_HELPER_RANDOM_OP_H__
+#define __NNFW_CKER_HELPER_RANDOM_OP_H__
+
+#include "cker/Types.h"
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+
+#include "cker/operation/Helper/RandomDistributions.h"
+
+namespace nnfw
+{
+namespace cker
+{
+
+namespace functor
+{
+
+template <typename Device, class Distribution> struct FillPhiloxRandom;
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+// Declares the partially CPU-specialized functor struct.
+//
+// NOTE: Due to inlining done by the compiler, you may need to add
+// explicit instantiation of the functor in random_op.cc. See example
+// functor::FillPhiloxRandom<CPUDevice, random::UniformDistribution>.
+template <class Distribution> struct FillPhiloxRandom<CPUDevice, Distribution>
+{
+ void operator()(random::PhiloxRandom gen, typename Distribution::ResultElementType *data,
+ int64_t size, Distribution dist);
+};
+
+} // namespace functor
+} // namespace cker
+} // namespace nnfw
+#endif // __NNFW_CKER_HELPER_RANDOM_OP_H__
diff --git a/compute/cker/include/cker/operation/Helper/RandomOpCpu.h b/compute/cker/include/cker/operation/Helper/RandomOpCpu.h
new file mode 100644
index 000000000..85d267723
--- /dev/null
+++ b/compute/cker/include/cker/operation/Helper/RandomOpCpu.h
@@ -0,0 +1,163 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_HELPER_RANDOM_OP_CPU_H__
+#define __NNFW_CKER_HELPER_RANDOM_OP_CPU_H__
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+
+#include "cker/Types.h"
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+
+#include "cker/eigen/EigenSupport.h"
+
+#include "cker/operation/Helper/PhiloxRandom.h"
+#include "cker/operation/Helper/RandomOp.h"
+#include "cker/operation/Helper/RandomDistributions.h"
+
+#if EIGEN_COMP_GNUC && __cplusplus > 199711L
+#define DISABLE_FLOAT_EQUALITY_WARNING \
+ _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
+#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
+#else
+#define DISABLE_FLOAT_EQUALITY_WARNING
+#define ENABLE_FLOAT_EQUALITY_WARNING
+#endif
+
+namespace nnfw
+{
+namespace cker
+{
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+namespace functor
+{
+using random::PhiloxRandom;
+using random::SingleSampleAdapter;
+
+// The default implementation of the functor, which should never be invoked
+// But we still need to provide implementation for now for the linker to work,
+// since we do not support all the distributions yet.
+template <typename Device, class Distribution> struct FillPhiloxRandom
+{
+ typedef typename Distribution::ResultElementType T;
+ void operator()() {}
+};
+
+// A class to fill a specified range of random groups
+template <class Distribution, bool VariableSamplesPerOutput> struct FillPhiloxRandomTask;
+
+// Specialization for distribution that takes a fixed number of samples for
+// each output.
+template <class Distribution> struct FillPhiloxRandomTask<Distribution, false>
+{
+ typedef typename Distribution::ResultElementType T;
+ static void Run(random::PhiloxRandom gen, T *data, int64_t size, Distribution dist)
+ {
+ const int kGroupSize = Distribution::kResultElementCount;
+ gen.Skip(0);
+ int64_t offset = 0;
+
+ // First fill all the full-size groups
+ int64_t limit_group_full = size / kGroupSize;
+ for (int64_t index = 0; index < limit_group_full; ++index)
+ {
+ auto samples = dist(&gen);
+ std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
+ offset += kGroupSize;
+ }
+
+ int64_t remaining_size = size - limit_group_full * kGroupSize;
+
+ // If there are any remaining elements that need to be filled, process them
+ if (remaining_size > 0)
+ {
+ auto samples = dist(&gen);
+ std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
+ }
+ }
+};
+
+// Specialization for distribution that takes a variable number of samples for
+// each output. This will be slower due to the generality.
+template <class Distribution> struct FillPhiloxRandomTask<Distribution, true>
+{
+ typedef typename Distribution::ResultElementType T;
+ static constexpr int64_t kReservedSamplesPerOutput = 256;
+
+ static void Run(random::PhiloxRandom base_gen, T *data, int64_t size, Distribution dist)
+ {
+ const int kGroupSize = Distribution::kResultElementCount;
+ static const int kGeneratorSkipPerOutputGroup =
+ kGroupSize * kReservedSamplesPerOutput / PhiloxRandom::kResultElementCount;
+
+ int64_t offset = 0;
+
+ // First fill all the full-size groups
+ int64_t limit_group_full = size / kGroupSize;
+ int64_t group_index;
+ for (group_index = 0; group_index < limit_group_full; ++group_index)
+ {
+ // Reset the generator to the beginning of the output group region
+ // This is necessary if we want the results to be independent of order
+ // of work
+ PhiloxRandom gen = base_gen;
+ gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ auto samples = dist(&single_samples);
+ std::copy(&samples[0], &samples[0] + kGroupSize, data + offset);
+ offset += kGroupSize;
+ }
+
+ int64_t remaining_size = size - limit_group_full * kGroupSize;
+ // If there are any remaining elements that need to be filled, process them
+ if (remaining_size > 0)
+ {
+ PhiloxRandom gen = base_gen;
+ gen.Skip(group_index * kGeneratorSkipPerOutputGroup);
+ SingleSampleAdapter<PhiloxRandom> single_samples(&gen);
+
+ auto samples = dist(&single_samples);
+ std::copy(&samples[0], &samples[0] + remaining_size, data + offset);
+ }
+ }
+};
+
+// Partial specialization for CPU to fill the entire region with random values.
+// This port fills the region in a single pass rather than sharding the work across threads.
+template <class Distribution>
+void FillPhiloxRandom<CPUDevice, Distribution>::
+operator()(random::PhiloxRandom gen, typename Distribution::ResultElementType *data, int64_t size,
+ Distribution dist)
+{
+ FillPhiloxRandomTask<Distribution, Distribution::kVariableSamplesPerOutput>::Run(gen, data, size,
+ dist);
+}
+
+} // namespace functor
+
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_HELPER_RANDOM_OP_CPU_H__
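
For reference, a minimal sketch of how the CPU specialization above might be driven directly. The wrapper function, buffer size, and key material are illustrative, and it assumes the Array helper types zero-initialize on default construction, as in the upstream TensorFlow port of PhiloxRandom.

#include <cstdint>
#include <vector>

#include "cker/operation/Helper/RandomOpCpu.h"

void FillUniformExample()
{
  using nnfw::cker::random::PhiloxRandom;
  using Dist = nnfw::cker::random::UniformDistribution<PhiloxRandom, float>;

  // Seed through the counter/key constructor that GenerateKey also uses.
  PhiloxRandom::Key key;            // assumed zero-initialized by the Array helper
  PhiloxRandom::ResultType counter; // assumed zero-initialized by the Array helper
  key[0] = 0x12345678u;             // arbitrary illustrative key material
  PhiloxRandom gen(counter, key);

  // Fill a flat buffer with uniform floats using the CPU functor defined above.
  std::vector<float> buffer(1024);
  nnfw::cker::functor::FillPhiloxRandom<Eigen::ThreadPoolDevice, Dist>()(
    gen, buffer.data(), static_cast<int64_t>(buffer.size()), Dist());
}
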
diff --git a/compute/cker/include/cker/operation/L2Normalize.h b/compute/cker/include/cker/operation/L2Normalize.h
new file mode 100644
index 000000000..a0075c3d0
--- /dev/null
+++ b/compute/cker/include/cker/operation/L2Normalize.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_L2NORMALIZE_H__
+#define __NNFW_CKER_L2NORMALIZE_H__
+
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+#include "cker/Types.h"
+
+namespace nnfw
+{
+namespace cker
+{
+
+inline void L2NormalizeFloat32(const Shape &input_shape, const float *input_data,
+                               const Shape &output_shape, float *output_data)
+{
+ float epsilon = 1e-6;
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ for (int i = 0; i < outer_size; ++i)
+ {
+ float squared_l2_norm = 0;
+ for (int c = 0; c < depth; ++c)
+ {
+ const float val = input_data[c];
+ squared_l2_norm += val * val;
+ }
+ float l2_norm = std::sqrt(squared_l2_norm);
+ l2_norm = std::max(l2_norm, epsilon);
+ for (int c = 0; c < depth; ++c)
+ {
+ *output_data = *input_data / l2_norm;
+ ++output_data;
+ ++input_data;
+ }
+ }
+}
+
+inline void L2NormalizeQuant8(L2NormParams &params, const Shape &input_shape,
+                              const uint8_t *input_data, const Shape &output_shape,
+                              uint8_t *output_data)
+{
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+ const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32_t input_zero_point = params.input_zero_point;
+
+ for (int i = 0; i < outer_size; ++i)
+ {
+ int32_t square_l2_norm = 0;
+ for (int c = 0; c < depth; c++)
+ {
+ // Note that input_data advances by depth in the second pass below.
+ int32_t diff = input_data[c] - input_zero_point;
+ square_l2_norm += diff * diff;
+ }
+ int32_t inv_l2norm_multiplier;
+ int inv_l2norm_shift;
+ GetInvSqrtQuantizedMultiplierExp(square_l2_norm, -1, &inv_l2norm_multiplier, &inv_l2norm_shift);
+ for (int c = 0; c < depth; c++)
+ {
+ int32_t diff = *input_data - input_zero_point;
+ int32_t rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp(
+ 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift);
+ int32_t unclamped_output_val = 128 + rescaled_diff;
+ int32_t output_val = std::min(static_cast<int32_t>(255),
+ std::max(static_cast<int32_t>(0), unclamped_output_val));
+ *output_data = static_cast<uint8_t>(output_val);
+ ++input_data;
+ ++output_data;
+ }
+ }
+}
+
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_L2NORMALIZE_H__
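
A quick illustrative call of the float path above, normalizing a single row. The wrapper function is hypothetical, and Shape brace initialization is assumed to be available, as in TFLite's RuntimeShape.

#include "cker/operation/L2Normalize.h"

void L2NormalizeExample()
{
  // One row of four values; the L2 norm of the row is 5.
  const nnfw::cker::Shape shape{1, 4};
  const float input[4] = {3.0f, 0.0f, 4.0f, 0.0f};
  float output[4] = {};

  // Writes {0.6, 0.0, 0.8, 0.0}.
  nnfw::cker::L2NormalizeFloat32(shape, input, shape, output);
}
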
diff --git a/compute/cker/include/cker/operation/Logistic.h b/compute/cker/include/cker/operation/Logistic.h
index 7477858fc..3d3e59e55 100644
--- a/compute/cker/include/cker/operation/Logistic.h
+++ b/compute/cker/include/cker/operation/Logistic.h
@@ -32,18 +32,9 @@ namespace cker
inline void Logistic(const Shape &input_shape, const float *input_data, const Shape &output_shape,
float *output_data)
{
-#ifdef __aarch64__
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().unaryExpr(Eigen::internal::scalar_logistic_op<float>());
-#else
- // Note, this can be done using TANH: (1/2) + (1/2) * TANH(x/2)
- const int size = MatchingFlatSize(input_shape, output_shape);
- for (int i = 0; i < size; i++)
- {
- output_data[i] = 1.f / (1.f + std::exp(-input_data[i]));
- }
-#endif
}
} // namespace cker
diff --git a/compute/cker/include/cker/operation/MatrixBandPart.h b/compute/cker/include/cker/operation/MatrixBandPart.h
index 9f49c8fdd..5674ff3ef 100644
--- a/compute/cker/include/cker/operation/MatrixBandPart.h
+++ b/compute/cker/include/cker/operation/MatrixBandPart.h
@@ -32,10 +32,10 @@ void MatrixBandPart(const T num_lower_diags, const T num_upper_diags, const Shap
{
auto last_dim = input_shape.DimensionsCount() - 1;
- T batch_num = 0;
- for (int dim = 0; dim < last_dim - 2; dim++)
+ T batch_num = 1;
+ for (int dim = 0; dim < input_shape.DimensionsCount() - 2; dim++)
{
- batch_num += input_shape.Dims(dim);
+ batch_num *= input_shape.Dims(dim);
}
const T row_num = input_shape.Dims(last_dim - 1);
diff --git a/compute/cker/include/cker/operation/Pad.h b/compute/cker/include/cker/operation/Pad.h
index af432f3a8..4a2732d82 100644
--- a/compute/cker/include/cker/operation/Pad.h
+++ b/compute/cker/include/cker/operation/Pad.h
@@ -26,9 +26,10 @@ namespace nnfw
{
namespace cker
{
+template <typename T>
inline void Pad(const int32_t *padding_data, int32_t pad_rank, const Shape &input_shape,
- const float *input_data, const Shape &output_shape, float *output_data,
- const float *constant_value_data)
+ const T *input_data, const Shape &output_shape, T *output_data,
+ const T *constant_value_data)
{
// Note, this is pad with mode=`CONSTANT`: it doesn't support `REFLECT` and `SYMMETRIC`
// TODO: come up with more subtle solution that uses subtensors like arm compute
@@ -38,7 +39,7 @@ inline void Pad(const int32_t *padding_data, int32_t pad_rank, const Shape &inpu
/** List of padding information */
using PaddingList = std::vector<PaddingInfo>;
- auto constant_value = constant_value_data ? *constant_value_data : 0;
+ const T constant_value = constant_value_data ? *constant_value_data : 0;
assert(output_shape.DimensionsCount() == input_shape.DimensionsCount());
PaddingList padding_list(pad_rank);
@@ -64,7 +65,7 @@ inline void Pad(const int32_t *padding_data, int32_t pad_rank, const Shape &inpu
{
const int32_t in_row_len = input_shape.Dims(0);
std::fill_n(output_data, padding_list[0].first, constant_value);
- std::memcpy(output_data + padding_list[0].first, input_data, in_row_len * sizeof(float));
+ std::memcpy(output_data + padding_list[0].first, input_data, in_row_len * sizeof(T));
std::fill_n(output_data + padding_list[0].first + in_row_len, padding_list[0].second,
constant_value);
break;
@@ -89,7 +90,7 @@ inline void Pad(const int32_t *padding_data, int32_t pad_rank, const Shape &inpu
out_offset += padding_list[1].first;
// copy a row of input data
- memcpy(output_data + out_offset, input_data + in_offset, in_row_len * sizeof(float));
+ memcpy(output_data + out_offset, input_data + in_offset, in_row_len * sizeof(T));
out_offset += in_row_len;
@@ -132,7 +133,7 @@ inline void Pad(const int32_t *padding_data, int32_t pad_rank, const Shape &inpu
out_offset += padding_list[2].first;
// copy a row of input data
- memcpy(output_data + out_offset, input_data + in_offset, in_row_len * sizeof(float));
+ memcpy(output_data + out_offset, input_data + in_offset, in_row_len * sizeof(T));
out_offset += in_row_len;
@@ -191,7 +192,7 @@ inline void Pad(const int32_t *padding_data, int32_t pad_rank, const Shape &inpu
out_c_offset += padding_list[3].first;
// copy a row of input data
- memcpy(output_data + out_c_offset, input_data + in_offset, in_row_len * sizeof(float));
+ memcpy(output_data + out_c_offset, input_data + in_offset, in_row_len * sizeof(T));
out_c_offset += in_row_len;
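
With the template parameter added above, the same routine now serves integer and float tensors alike. A minimal rank-1 sketch for uint8_t follows; the wrapper function is illustrative, padding_data is assumed to be laid out as (before, after) pairs per dimension, and Shape brace initialization is assumed.

#include <cstdint>

#include "cker/operation/Pad.h"

void PadExample()
{
  // (before, after) pad amounts for the single dimension.
  const int32_t padding[2] = {1, 2};
  const nnfw::cker::Shape input_shape{3};
  const nnfw::cker::Shape output_shape{6};
  const uint8_t input[3] = {7, 8, 9};
  const uint8_t pad_value = 0;
  uint8_t output[6] = {};

  // Writes {0, 7, 8, 9, 0, 0}.
  nnfw::cker::Pad<uint8_t>(padding, /*pad_rank=*/1, input_shape, input, output_shape, output,
                           &pad_value);
}
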
diff --git a/compute/cker/include/cker/operation/Quantize.h b/compute/cker/include/cker/operation/Quantize.h
new file mode 100644
index 000000000..5c82d111f
--- /dev/null
+++ b/compute/cker/include/cker/operation/Quantize.h
@@ -0,0 +1,47 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_QUANTIZE_H__
+#define __NNFW_CKER_QUANTIZE_H__
+
+#include "cker/Shape.h"
+#include "cker/Types.h"
+#include "cker/Utils.h"
+#include <algorithm>
+#include <cmath>
+#include <limits>
+namespace nnfw
+{
+namespace cker
+{
+template <typename InputT, typename OutputT>
+inline void Quantize(const Shape &input_shape, const InputT *input_data, const Shape &output_shape,
+ OutputT *output_data, const float output_scale, const int32_t output_offset)
+{
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
+  const int32_t min_val = std::numeric_limits<OutputT>::min();
+  const int32_t max_val = std::numeric_limits<OutputT>::max();
+
+  for (int i = 0; i < flat_size; i++)
+  {
+    const int32_t unclamped =
+      static_cast<int32_t>(std::round(input_data[i] / output_scale)) + output_offset;
+    const int32_t clamped = std::min(std::max(unclamped, min_val), max_val);
+    output_data[i] = static_cast<OutputT>(clamped);
+ }
+}
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_QUANTIZE_H__
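
A short sketch of the affine quantization above, mapping floats to uint8 with scale 0.5 and zero point 128. The wrapper function and values are illustrative, and Shape brace initialization is assumed.

#include <cstdint>

#include "cker/operation/Quantize.h"

void QuantizeExample()
{
  const nnfw::cker::Shape shape{4};
  const float input[4] = {-1.0f, 0.0f, 0.5f, 100.0f};
  uint8_t output[4] = {};

  // q = clamp(round(x / 0.5) + 128, 0, 255), so the result is {126, 128, 129, 255}.
  nnfw::cker::Quantize(shape, input, shape, output, /*output_scale=*/0.5f, /*output_offset=*/128);
}
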
diff --git a/compute/cker/include/cker/operation/ReLU6.h b/compute/cker/include/cker/operation/ReLU6.h
new file mode 100644
index 000000000..20df561dc
--- /dev/null
+++ b/compute/cker/include/cker/operation/ReLU6.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_RELU6_H__
+#define __NNFW_CKER_RELU6_H__
+
+#include "cker/Shape.h"
+#include "cker/eigen/Utils.h"
+
+#include <cmath>
+#include <Eigen/Core>
+
+namespace nnfw
+{
+namespace cker
+{
+
+inline void ReLU6(const Shape &input_shape, const float *input_data, float *output_data)
+{
+ int size = input_shape.FlatSize();
+
+ for (int i = 0; i < size; ++i)
+ {
+ if (input_data[i] <= 0)
+ {
+ output_data[i] = 0;
+ }
+ else if (input_data[i] > 6.0)
+ {
+ output_data[i] = 6.0;
+ }
+ else
+ {
+ output_data[i] = input_data[i];
+ }
+ }
+}
+
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_RELU6_H__
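
For completeness, a minimal call of the clamp above; the wrapper function is illustrative and Shape brace initialization is assumed.

#include "cker/operation/ReLU6.h"

void ReLU6Example()
{
  const nnfw::cker::Shape shape{4};
  const float input[4] = {-3.0f, 2.5f, 6.0f, 9.0f};
  float output[4] = {};

  // Clamps every element to [0, 6]: {0, 2.5, 6, 6}.
  nnfw::cker::ReLU6(shape, input, output);
}
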
diff --git a/compute/cker/include/cker/operation/Reduce.h b/compute/cker/include/cker/operation/Reduce.h
index 4ba3652d3..cf9634a67 100644
--- a/compute/cker/include/cker/operation/Reduce.h
+++ b/compute/cker/include/cker/operation/Reduce.h
@@ -159,6 +159,92 @@ public:
num_resolved_axis, temp_index_data(), reducer, output_data);
}
+  // Computes the mean of elements across the dimensions given in axis.
+  // It does so in two stages: first it calculates the sum of elements along the axis,
+  // then it divides that sum by the number of elements in the axis, for quantized values.
+ template <typename T, typename U>
+ inline bool QuantizedMeanOrSum(const T *input_data, int32_t input_zero_point, float input_scale,
+ const Shape &input_shape, T *output_data,
+ int32_t output_zero_point, float output_scale,
+ const Shape &output_shape, const std::vector<int> &axes,
+ bool /*keep_dims*/, U *temp_sum, bool compute_sum,
+ U reducer(const U current, const T in))
+ {
+ // Reset output data.
+ size_t num_outputs = 1;
+ for (int idx = 0; idx < output_shape.DimensionsCount(); ++idx)
+ {
+ size_t current = static_cast<size_t>(output_shape.Dims(idx));
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current)
+ {
+ return false;
+ }
+ num_outputs *= current;
+ }
+ for (size_t idx = 0; idx < num_outputs; ++idx)
+ {
+ output_data[idx] = T();
+ temp_sum[idx] = U();
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_shape.DimensionsCount(), axes, resolved_axis_data(), &num_resolved_axis))
+ {
+ return false;
+ }
+
+ if (!ReduceImpl<T, U>(input_data, input_shape, output_shape, resolved_axis_data(),
+ num_resolved_axis, temp_index_data(), reducer, temp_sum))
+ {
+ return false;
+ }
+
+ // Calculate mean by dividing output_data by num of aggregated element.
+ U num_elements_in_axis = 1;
+ for (int idx = 0; idx < num_resolved_axis; ++idx)
+ {
+ size_t current = static_cast<size_t>(input_shape.Dims(resolved_axis_data()[idx]));
+ // Overflow prevention.
+ if (current > static_cast<size_t>(std::numeric_limits<U>::max() / num_elements_in_axis))
+ {
+ return false;
+ }
+ num_elements_in_axis *= current;
+ }
+
+ if (num_elements_in_axis > 0)
+ {
+ const float scale = input_scale / output_scale;
+ if (compute_sum)
+ {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5f;
+ for (size_t idx = 0; idx < num_outputs; ++idx)
+ {
+ const U value =
+ static_cast<U>(std::round(temp_sum[idx] * scale + bias)) + output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ }
+ else
+ {
+ const float bias = -input_zero_point * scale + 0.5f;
+ for (size_t idx = 0; idx < num_outputs; ++idx)
+ {
+ float float_mean =
+ static_cast<float>(temp_sum[idx]) / static_cast<float>(num_elements_in_axis);
+ float result = std::min(std::round(float_mean * scale + bias) + output_zero_point,
+ static_cast<float>(std::numeric_limits<T>::max()));
+ result = std::max(result, static_cast<float>(std::numeric_limits<T>::min()));
+ output_data[idx] = static_cast<T>(result);
+ }
+ }
+ }
+ return true;
+ }
+
inline int32_t *resolved_axis_data(void)
{
return _resolved_axis.size() ? _resolved_axis.data() : _resolved_axis_small;
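
The requantization in the mean path of QuantizedMeanOrSum boils down to a scale-and-shift of the integer sum. The standalone helper below mirrors that arithmetic for a single uint8 output element; it is an illustrative sketch, not part of the Reduce class.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Mirrors the mean (non-sum) branch of QuantizedMeanOrSum for one uint8 output element.
inline uint8_t RequantizedMean(int32_t temp_sum, int32_t num_elements_in_axis,
                               int32_t input_zero_point, float input_scale,
                               int32_t output_zero_point, float output_scale)
{
  const float scale = input_scale / output_scale;
  const float bias = -input_zero_point * scale + 0.5f;
  const float mean = static_cast<float>(temp_sum) / static_cast<float>(num_elements_in_axis);
  float result = std::min(std::round(mean * scale + bias) + output_zero_point, 255.0f);
  result = std::max(result, 0.0f);
  return static_cast<uint8_t>(result);
}
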
diff --git a/compute/cker/include/cker/operation/ResizeBilinear.h b/compute/cker/include/cker/operation/ResizeBilinear.h
new file mode 100644
index 000000000..7fc1e9123
--- /dev/null
+++ b/compute/cker/include/cker/operation/ResizeBilinear.h
@@ -0,0 +1,270 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_RESIZEBILINEAR_H__
+#define __NNFW_CKER_RESIZEBILINEAR_H__
+
+#include "cker/Shape.h"
+#include "cker/Types.h"
+#include <cmath>
+
+namespace nnfw
+{
+namespace cker
+{
+
+inline void ResizeBilinearKernel2x2(int32_t x0, int32_t x1, int32_t y0, int32_t y1, int32_t x,
+ int32_t y, int32_t depth, int32_t batch,
+ const Shape &input_shape, const float *input_data,
+ const Shape &output_shape, float *output_data)
+{
+ const int32_t input_width = input_shape.Dims(2);
+ const int32_t output_width = output_shape.Dims(2);
+
+ const int32_t input_x_offset = (x1 - x0) * depth;
+ const int32_t input_y_offset = (y1 - y0) * depth * input_width;
+ const int32_t output_x_offset = depth;
+ const int32_t output_y_offset = depth * output_width;
+
+ for (int ch = 0; ch < depth; ch++)
+ {
+ const int32_t input_offset = Offset(input_shape, batch, y0, x0, ch);
+
+ float x0y0 = input_data[input_offset];
+ float x1y0 = input_data[input_offset + input_x_offset];
+ float x0y1 = input_data[input_offset + input_y_offset];
+ float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
+
+ // Top left corner.
+ const int32_t output_offset = Offset(output_shape, batch, y, x, ch);
+ output_data[output_offset] = x0y0;
+
+ // Top right corner.
+ output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2;
+
+ // Bottom left corner.
+ float output = (x0y0 + x0y1) / 2;
+ output_data[output_offset + output_y_offset] = output;
+
+ // Bottom right corner.
+ output_data[output_offset + output_x_offset + output_y_offset] =
+ (output + ((x1y0 + x1y1) / 2)) / 2;
+ }
+}
+
+inline void ResizeBilinear2x2(int32_t batches, int32_t input_height, int32_t input_width,
+ int32_t depth, int32_t output_height, int32_t output_width,
+ const Shape &input_shape, const float *input_data,
+ const Shape &output_shape, float *output_data)
+{
+ for (int b = 0; b < batches; b++)
+ {
+ for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++)
+ {
+ for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++)
+ {
+ int32_t x1 = std::min(x0 + 1, input_width - 1);
+ int32_t y1 = std::min(y0 + 1, input_height - 1);
+ ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape, input_data,
+ output_shape, output_data);
+ }
+ }
+ }
+}
+
+inline void ResizeBilinearKernel(const float *input_ptr, int32_t depth, float scale,
+ float *output_ptr)
+{
+ for (int32_t i = 0; i < depth; i++)
+ {
+ *output_ptr += *input_ptr * scale;
+ output_ptr++;
+ input_ptr++;
+ }
+}
+
+inline void ComputeInterpolationValues(const float value, const float scale,
+ const bool half_pixel_centers, int32_t input_size,
+ float *scaled_value, int32_t *lower_bound,
+ int32_t *upper_bound)
+{
+ if (half_pixel_centers)
+ {
+ *scaled_value = (value + 0.5f) * scale - 0.5f;
+ }
+ else
+ {
+ *scaled_value = value * scale;
+ }
+ float scaled_value_floor = std::floor(*scaled_value);
+ *lower_bound = std::max(static_cast<int32_t>(scaled_value_floor), static_cast<int32_t>(0));
+ *upper_bound = std::min(static_cast<int32_t>(std::ceil(*scaled_value)), input_size - 1);
+}
+
+inline void ResizeBilinearGeneric(int32_t batches, int32_t input_height, int32_t input_width,
+ int32_t depth, int32_t output_height, int32_t output_width,
+ float height_scale, float width_scale, const Shape &input_shape,
+ const float *input_data, float *output_data,
+ const bool half_pixel_centers)
+{
+ memset(output_data, 0, batches * output_height * output_width * depth * sizeof(float));
+
+ int32_t output_offset = 0;
+ for (int b = 0; b < batches; ++b)
+ {
+ for (int y = 0; y < output_height; ++y)
+ {
+ float input_y;
+ int32_t y0, y1;
+ ComputeInterpolationValues(y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
+ &y1);
+ for (int x = 0; x < output_width; ++x)
+ {
+ float input_x;
+ int32_t x0, x1;
+ ComputeInterpolationValues(x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
+ &x1);
+ float *output_ptr = &output_data[output_offset];
+
+ // Run kernel on the 4 corners of the bilinear resize algorithm.
+ int32_t input_offset = Offset(input_shape, b, y0, x0, 0);
+ float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
+ const float *input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_shape, b, y0, x1, 0);
+ scale = (1 - (input_y - y0)) * (input_x - x0);
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_shape, b, y1, x0, 0);
+ scale = (input_y - y0) * (1 - (input_x - x0));
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ input_offset = Offset(input_shape, b, y1, x1, 0);
+ scale = (input_y - y0) * (input_x - x0);
+ input_ptr = &input_data[input_offset];
+ ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
+
+ output_offset += depth;
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void ResizeBilinearGenericSmallChannel(int32_t batches, int32_t input_height,
+ int32_t input_width, int32_t depth,
+ int32_t output_height, int32_t output_width,
+ float height_scale, float width_scale,
+ const Shape &input_shape, const T *input_data,
+ T *output_data, const bool half_pixel_centers)
+{
+ T *output_ptr = &output_data[0];
+ for (int b = 0; b < batches; ++b)
+ {
+ for (int y = 0; y < output_height; ++y)
+ {
+ float input_y;
+ int32_t y0, y1;
+ ComputeInterpolationValues(y, height_scale, half_pixel_centers, input_height, &input_y, &y0,
+ &y1);
+ for (int x = 0; x < output_width; ++x)
+ {
+ float input_x;
+ int32_t x0, x1;
+ ComputeInterpolationValues(x, width_scale, half_pixel_centers, input_width, &input_x, &x0,
+ &x1);
+
+ int32_t input_offset[4] = {
+ Offset(input_shape, b, y0, x0, 0), Offset(input_shape, b, y0, x1, 0),
+ Offset(input_shape, b, y1, x0, 0), Offset(input_shape, b, y1, x1, 0)};
+ float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
+ (1 - (input_y - y0)) * (input_x - x0),
+ (input_y - y0) * (1 - (input_x - x0)), (input_y - y0) * (input_x - x0)};
+
+ for (int d = 0; d < depth; d++)
+ {
+ const T *input_ptr = &input_data[d];
+ *output_ptr++ = static_cast<T>(
+ input_ptr[input_offset[0]] * scale[0] + input_ptr[input_offset[1]] * scale[1] +
+ input_ptr[input_offset[2]] * scale[2] + input_ptr[input_offset[3]] * scale[3]);
+ }
+ }
+ }
+ }
+}
+
+inline void ResizeBilinear(ResizeBilinearParams &params, const Shape &input_shape,
+                           const float *input_data, const Shape &output_shape, float *output_data)
+{
+ int32_t batches = static_cast<int32_t>(MatchingDim(input_shape, 0, output_shape, 0));
+ int32_t input_height = input_shape.Dims(1);
+ int32_t input_width = input_shape.Dims(2);
+ int32_t depth = static_cast<int32_t>(MatchingDim(input_shape, 3, output_shape, 3));
+
+ // Specialize for 2x2 upsample.
+ if (!params.align_corners && !params.half_pixel_centers &&
+ params.output_height == 2 * input_height && params.output_width == 2 * input_width)
+ {
+ ResizeBilinear2x2(batches, input_height, input_width, depth, params.output_height,
+ params.output_width, input_shape, input_data, output_shape, output_data);
+ }
+ else
+ {
+ float height_scale = static_cast<float>(input_height) / params.output_height;
+ float width_scale = static_cast<float>(input_width) / params.output_width;
+ if (params.align_corners && params.output_height > 1)
+ {
+ height_scale = static_cast<float>(input_height - 1) / (params.output_height - 1);
+ }
+ if (params.align_corners && params.output_width > 1)
+ {
+ width_scale = static_cast<float>(input_width - 1) / (params.output_width - 1);
+ }
+
+ ResizeBilinearGeneric(batches, input_height, input_width, depth, params.output_height,
+ params.output_width, height_scale, width_scale, input_shape, input_data,
+ output_data, params.half_pixel_centers);
+ }
+}
+
+inline void ResizeBilinear(ResizeBilinearParams &params, const Shape &input_shape,
+                           const uint8_t *input_data, const Shape &output_shape, uint8_t *output_data)
+{
+ int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32_t input_height = input_shape.Dims(1);
+ int32_t input_width = input_shape.Dims(2);
+ int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ float height_scale = (params.align_corners && params.output_height > 1)
+ ? (static_cast<float>(input_height - 1) / (params.output_height - 1))
+ : (static_cast<float>(input_height) / params.output_height);
+
+ float width_scale = (params.align_corners && params.output_width > 1)
+ ? (static_cast<float>(input_width - 1) / (params.output_width - 1))
+ : (static_cast<float>(input_width) / params.output_width);
+
+ ResizeBilinearGenericSmallChannel<uint8_t>(
+ batches, input_height, input_width, depth, params.output_height, params.output_width,
+ height_scale, width_scale, input_shape, input_data, output_data, params.half_pixel_centers);
+}
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_RESIZEBILINEAR_H__
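
A minimal float usage sketch for the resize above, upsampling a 2x2 NHWC input by 2x so the fast path is taken. The wrapper function is illustrative; it assumes ResizeBilinearParams is default-constructible with the four fields referenced in the code, and that Shape supports brace initialization.

#include "cker/operation/ResizeBilinear.h"

void ResizeBilinearExample()
{
  nnfw::cker::ResizeBilinearParams params;
  params.output_height = 4;
  params.output_width = 4;
  params.align_corners = false;
  params.half_pixel_centers = false;

  // NHWC 1x2x2x1 input upsampled to 1x4x4x1; this hits the 2x2 specialization.
  const nnfw::cker::Shape input_shape{1, 2, 2, 1};
  const nnfw::cker::Shape output_shape{1, 4, 4, 1};
  const float input[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float output[16] = {};

  nnfw::cker::ResizeBilinear(params, input_shape, input, output_shape, output);
}
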
diff --git a/compute/cker/include/cker/operation/SpaceToDepth.h b/compute/cker/include/cker/operation/SpaceToDepth.h
new file mode 100644
index 000000000..ef679315e
--- /dev/null
+++ b/compute/cker/include/cker/operation/SpaceToDepth.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_SPACE_TO_DEPTH_H__
+#define __NNFW_CKER_SPACE_TO_DEPTH_H__
+
+#include "cker/Shape.h"
+#include "cker/Types.h"
+
+namespace nnfw
+{
+namespace cker
+{
+
+template <typename T>
+inline void SpaceToDepth(const SpaceToDepthParams &params, const Shape &unextended_input_shape,
+ const T *input_data, const Shape &unextended_output_shape, T *output_data)
+{
+ assert(unextended_input_shape.DimensionsCount() <= 4);
+ assert(unextended_output_shape.DimensionsCount() <= 4);
+ const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
+ const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+
+ const int input_depth = input_shape.Dims(3);
+ const int batch_size = input_shape.Dims(0);
+
+  // Number of contiguous values that we can copy in one iteration.
+ const int stride = params.block_size * input_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch)
+ {
+ for (int out_h = 0; out_h < output_height; ++out_h)
+ {
+ T *output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
+ for (int offset_h = 0; offset_h < params.block_size; ++offset_h)
+ {
+ T *dst = output_ptr;
+ for (int out_w = 0; out_w < output_width; ++out_w)
+ {
+ memcpy(dst, input_data, stride * sizeof(T));
+ input_data += stride;
+ dst += output_depth;
+ }
+ output_ptr += stride;
+ }
+ }
+ }
+}
+
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_SPACE_TO_DEPTH_H__
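
A small sketch of the rearrangement above with block_size 2, turning each 2x2 spatial block into one pixel of depth 4. The wrapper function is illustrative; SpaceToDepthParams is assumed default-constructible and Shape brace initialization is assumed.

#include "cker/operation/SpaceToDepth.h"

void SpaceToDepthExample()
{
  nnfw::cker::SpaceToDepthParams params;
  params.block_size = 2;

  // NHWC 1x2x2x1 -> 1x1x1x4; the output is {1, 2, 3, 4}.
  const nnfw::cker::Shape input_shape{1, 2, 2, 1};
  const nnfw::cker::Shape output_shape{1, 1, 1, 4};
  const float input[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float output[4] = {};

  nnfw::cker::SpaceToDepth(params, input_shape, input, output_shape, output);
}
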
diff --git a/compute/cker/include/cker/operation/SplitV.h b/compute/cker/include/cker/operation/SplitV.h
new file mode 100644
index 000000000..9e46f4b04
--- /dev/null
+++ b/compute/cker/include/cker/operation/SplitV.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_SPLIT_V_H__
+#define __NNFW_CKER_SPLIT_V_H__
+
+#include "cker/Shape.h"
+#include "cker/Types.h"
+
+namespace nnfw
+{
+namespace cker
+{
+
+template <typename Scalar>
+void SplitV(const SplitVParams &params, const Shape &input_shape, const Scalar *input_data,
+ std::vector<nnfw::cker::Shape> &output_shapes, Scalar *const *output_data)
+{
+ const int split_dimensions = input_shape.DimensionsCount();
+ int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
+ int outputs_count = params.num_split;
+
+ int64_t split_size = 0;
+
+ for (int i = 0; i < outputs_count; i++)
+ {
+    assert(output_shapes[i].DimensionsCount() == split_dimensions);
+ for (int j = 0; j < split_dimensions; j++)
+ {
+ if (j != axis)
+ {
+ MatchingDim(output_shapes[i], j, input_shape, j);
+ }
+ }
+ split_size += output_shapes[i].Dims(axis);
+ }
+
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i)
+ {
+ outer_size *= input_shape.Dims(i);
+ }
+ // For all output arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < split_dimensions; ++i)
+ {
+ base_inner_size *= input_shape.Dims(i);
+ }
+
+ const Scalar *input_ptr = input_data;
+ int copy_size = 0;
+ for (int k = 0; k < outer_size; k++)
+ {
+ for (int i = 0; i < outputs_count; ++i)
+ {
+ copy_size = output_shapes[i].Dims(axis) * base_inner_size;
+ memcpy(output_data[i] + k * copy_size, input_ptr, copy_size * sizeof(Scalar));
+ input_ptr += copy_size;
+ }
+ }
+}
+
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_SPLIT_V_H__
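
A minimal sketch of splitting a 2x5 tensor along axis 1 into 2x2 and 2x3 pieces with the routine above. The wrapper function is illustrative; SplitVParams is assumed default-constructible with the axis and num_split fields referenced in the code, and Shape brace initialization is assumed.

#include <vector>

#include "cker/operation/SplitV.h"

void SplitVExample()
{
  nnfw::cker::SplitVParams params;
  params.axis = 1;
  params.num_split = 2;

  const nnfw::cker::Shape input_shape{2, 5};
  std::vector<nnfw::cker::Shape> output_shapes{nnfw::cker::Shape{2, 2}, nnfw::cker::Shape{2, 3}};
  const float input[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
  float out0[4] = {};
  float out1[6] = {};
  float *outputs[2] = {out0, out1};

  // out0 receives {0, 1, 5, 6}; out1 receives {2, 3, 4, 7, 8, 9}.
  nnfw::cker::SplitV(params, input_shape, input, output_shapes, outputs);
}
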
diff --git a/compute/cker/include/cker/operation/StatelessRandomUniform.h b/compute/cker/include/cker/operation/StatelessRandomUniform.h
new file mode 100644
index 000000000..d5952ae23
--- /dev/null
+++ b/compute/cker/include/cker/operation/StatelessRandomUniform.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_STATELESS_RANDOM_UNIFORM_H__
+#define __NNFW_CKER_STATELESS_RANDOM_UNIFORM_H__
+
+#include "cker/Types.h"
+#include "cker/Shape.h"
+#include "cker/Utils.h"
+
+#include "cker/eigen/EigenSupport.h"
+
+#include "cker/operation/Helper/Tensor.h"
+#include "cker/operation/Helper/PhiloxRandom.h"
+#include "cker/operation/Helper/RandomOpCpu.h"
+#include "cker/operation/Helper/RandomDistributions.h"
+
+namespace nnfw
+{
+namespace cker
+{
+
+inline void GenerateKey(Tensor seed, random::PhiloxRandom::Key *out_key,
+                        random::PhiloxRandom::ResultType *out_counter)
+{
+ // Grab the two seeds
+ uint32_t seed0;
+ uint32_t seed1;
+
+ const auto seed_vals = seed.flat<int32_t>();
+
+ seed0 = seed_vals(0);
+ seed1 = seed_vals(1);
+ // Scramble the seeds so that the user doesn't need to worry about which
+ // part of the seed needs to be strong.
+ (*out_key)[0] = 0x3ec8f720;
+ (*out_key)[1] = 0x02461e29;
+ (*out_counter)[0] = static_cast<uint32_t>(seed0);
+ (*out_counter)[1] = (*out_counter)[3] = 0;
+ (*out_counter)[2] = static_cast<uint32_t>(seed1);
+ const auto mix = random::PhiloxRandom(*out_counter, *out_key)();
+ (*out_key)[0] = mix[0];
+ (*out_key)[1] = mix[1];
+ (*out_counter)[0] = (*out_counter)[1] = 0;
+ (*out_counter)[2] = mix[2];
+ (*out_counter)[3] = mix[3];
+}
+
+template <typename Device, class Distribution>
+void Fill(random::PhiloxRandom random, Tensor *output)
+{
+ // Build distribution
+ typedef typename Distribution::ResultElementType T;
+
+ auto flat = output->flat<T>();
+ // Reuse the compute kernels from the stateful random ops
+ functor::FillPhiloxRandom<Device, Distribution>()(random, flat.data(), flat.size(),
+ Distribution());
+}
+
+inline void StatelessRandomUniform(const Shape &shape_shape, const int *shape_data,
+ const Shape &seed_shape, const int *seed_data,
+ const Shape &output_shape, float *output_data)
+{
+ Tensor shape_t;
+ Tensor seed_t;
+
+ shape_t.shape.ReplaceWith(shape_shape.DimensionsCount(), shape_shape.DimsData());
+ shape_t.buffer = (void *)shape_data;
+
+ seed_t.shape.ReplaceWith(seed_shape.DimensionsCount(), seed_shape.DimsData());
+ seed_t.buffer = (void *)seed_data;
+
+ Tensor output_t;
+ output_t.shape.ReplaceWith(output_shape.DimensionsCount(), output_shape.DimsData());
+ output_t.buffer = output_data;
+
+ random::PhiloxRandom::Key key;
+ random::PhiloxRandom::ResultType counter;
+
+ GenerateKey(seed_t, &key, &counter);
+
+ Fill<Eigen::ThreadPoolDevice, random::UniformDistribution<random::PhiloxRandom, float>>(
+ random::PhiloxRandom(counter, key), &output_t);
+}
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_STATELESS_RANDOM_UNIFORM_H__
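
Finally, a minimal sketch of the top-level entry point above, requesting a [2, 3] tensor of uniform values from a fixed two-element seed. The wrapper function and seed values are illustrative, and Shape brace initialization is assumed.

#include "cker/operation/StatelessRandomUniform.h"

void StatelessRandomUniformExample()
{
  // The same seeds always reproduce the same uniform values.
  const int shape_data[2] = {2, 3};
  const int seed_data[2] = {123, 456};
  float output[6] = {};

  const nnfw::cker::Shape shape_shape{2};  // the shape tensor holds two values
  const nnfw::cker::Shape seed_shape{2};   // the seed tensor holds two values
  const nnfw::cker::Shape output_shape{2, 3};

  nnfw::cker::StatelessRandomUniform(shape_shape, shape_data, seed_shape, seed_data, output_shape,
                                     output);
}
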
diff --git a/compute/cker/include/cker/ruy/RuySupport.h b/compute/cker/include/cker/ruy/RuySupport.h
index 432b181bd..9612dd517 100644
--- a/compute/cker/include/cker/ruy/RuySupport.h
+++ b/compute/cker/include/cker/ruy/RuySupport.h
@@ -22,11 +22,6 @@
#include <ruy/context.h>
#include "cker/Types.h"
-namespace
-{
-const int kDefaultNumThreadpoolThreads = 4;
-}
-
namespace nnfw
{
namespace cker
@@ -34,42 +29,6 @@ namespace cker
namespace ruy_support
{
-struct RuyContext
-{
-public:
- RuyContext() : ruy_context_(new ruy::Context)
- {
- SetMaxNumThreads(onert::util::getConfigInt(onert::util::config::RUY_THREADS));
-#ifdef USE_RUY_GEMV
- ruy_context_->cache_policy = ruy::kCacheLHSOnNarrowMul;
-#endif
- };
-
- ruy::Context *ruy_context() const { return ruy_context_.get(); }
-
- static inline RuyContext &GetRuyContext()
- {
- static thread_local RuyContext instance;
- return instance;
- }
-
- void SetMaxNumThreads(int max_num_threads)
- {
- const int target_num_threads =
- max_num_threads > -1 ? max_num_threads : kDefaultNumThreadpoolThreads;
- ruy_context_->max_num_threads = target_num_threads;
- }
-
-private:
- const std::unique_ptr<ruy::Context> ruy_context_;
-};
-
-inline ruy::Context *GetRuyContext()
-{
- auto &ctx = RuyContext::GetRuyContext();
- return ctx.ruy_context();
-}
-
template <typename Scalar, typename DataPointer>
void MakeRuyMatrix(const MatrixParams<Scalar> &params, DataPointer data_ptr,
ruy::Matrix<Scalar> *dst)