Diffstat (limited to 'compute')
 compute/ARMComputeEx/src/runtime/CL/functions/CLSplitVEx.cpp                    |   2
 compute/cker/CMakeLists.txt                                                     |   4
 compute/cker/include/cker/PortableTensorUtils.h                                 |  54
 compute/cker/include/cker/Shape.h                                               |   2
 compute/cker/include/cker/Types.h                                               |   9
 compute/cker/include/cker/eigen/eigen_gemm_eigen.h                              |  95
 compute/cker/include/cker/operation/Conv.h                                      |  15
 compute/cker/include/cker/operation/DepthwiseConv.h                             |   1
 compute/cker/include/cker/operation/Einsum.h                                    |   6
 compute/cker/include/cker/operation/FullyConnected.h                            |  39
 compute/cker/include/cker/operation/optimized/Gemm.h                            | 100
 compute/cker/include/cker/operation/reference/Conv.h                            |  85
 compute/cker/include/cker/operation/reference/integer_ops/DepthwiseConvHybrid.h | 122
 compute/cker/include/cker/train/operation/FullyConnected.h                      |  49
 compute/cker/include/cker/train/operation/Loss.h                                |  77
 compute/cker/include/cker/train/operation/ReLU.h                                |  50
 compute/cker/src/train/FullyConnected.test.cc                                   |  83
 compute/cker/src/train/Loss.test.cc                                             | 201
 compute/cker/src/train/Relu.test.cc                                             | 107
 compute/ruy/include/ruy/Shape.h                                                 |   2
 20 files changed, 1094 insertions(+), 9 deletions(-)
diff --git a/compute/ARMComputeEx/src/runtime/CL/functions/CLSplitVEx.cpp b/compute/ARMComputeEx/src/runtime/CL/functions/CLSplitVEx.cpp
index 73f5f6eb1..bca4d5cb6 100644
--- a/compute/ARMComputeEx/src/runtime/CL/functions/CLSplitVEx.cpp
+++ b/compute/ARMComputeEx/src/runtime/CL/functions/CLSplitVEx.cpp
@@ -174,7 +174,7 @@ void CLSplitVEx::configure(const ICLTensor *input, const ICLTensor *size_splits,
// Extract output tensor info
std::vector<ITensorInfo *> outputs_info;
- for (auto &output : _outputs)
+ for (auto &&output : _outputs)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
outputs_info.emplace_back(output->info());
diff --git a/compute/cker/CMakeLists.txt b/compute/cker/CMakeLists.txt
index ce328b685..d464dccae 100644
--- a/compute/cker/CMakeLists.txt
+++ b/compute/cker/CMakeLists.txt
@@ -12,6 +12,10 @@ if(PROFILE_RUY)
target_link_libraries(nnfw_lib_cker INTERFACE ruy_profiler)
endif(PROFILE_RUY)
+if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
+ target_compile_definitions(nnfw_lib_cker INTERFACE CKER_X86_PLATFORM)
+endif(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
+
target_include_directories(nnfw_lib_cker INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}/include)
# Workaround to avoid warning
diff --git a/compute/cker/include/cker/PortableTensorUtils.h b/compute/cker/include/cker/PortableTensorUtils.h
index 2a58a2ec9..7e4b01a01 100644
--- a/compute/cker/include/cker/PortableTensorUtils.h
+++ b/compute/cker/include/cker/PortableTensorUtils.h
@@ -144,6 +144,60 @@ inline void PortableSymmetricQuantizeFloats(const float *values, const int size,
}
}
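+// Asymmetric (affine) int8 quantization: derive a per-buffer scale and zero point
+// from the observed min/max of `values`, then quantize each element into [-128, 127].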
+inline void PortableAsymmetricQuantizeFloats(const float *values, const int size,
+ int8_t *quantized_values, float *scaling_factor,
+ int32_t *offset)
+{
+ /* Copied from TensorFlow PortableAsymmetricQuantizeFloats */
+ const int32_t kMinScale = -128;
+ const int32_t kMaxScale = 127;
+ const double qmin_double = kMinScale;
+ const double qmax_double = kMaxScale;
+ const auto minmax = std::minmax_element(values, values + size);
+ const double rmin = static_cast<double>(std::min(0.0f, *minmax.first));
+ const double rmax = static_cast<double>(std::max(0.0f, *minmax.second));
+ if (rmin == rmax)
+ {
+ memset(quantized_values, 0, size * sizeof(int8_t));
+ *scaling_factor = 1;
+ *offset = 0;
+ return;
+ }
+ else
+ {
+ double scale = (rmax - rmin) / (qmax_double - qmin_double);
+ const double zero_point_from_min = qmin_double - rmin / scale;
+ const double zero_point_from_max = qmax_double - rmax / scale;
+ const double zero_point_from_min_error = std::abs(qmin_double) + std::abs(rmin / scale);
+ const double zero_point_from_max_error = std::abs(qmax_double) + std::abs(rmax / scale);
+ const double zero_point_double = zero_point_from_min_error < zero_point_from_max_error
+ ? zero_point_from_min
+ : zero_point_from_max;
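+ // Nudge the zero point to the nearest representable int8 value, clamping it to
+ // the ends of the quantized range.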
+ int8_t nudged_zero_point = 0;
+ if (zero_point_double <= qmin_double)
+ {
+ nudged_zero_point = kMinScale;
+ }
+ else if (zero_point_double >= qmax_double)
+ {
+ nudged_zero_point = kMaxScale;
+ }
+ else
+ {
+ nudged_zero_point = static_cast<int8_t>(round(zero_point_double));
+ }
+ *scaling_factor = scale;
+ *offset = nudged_zero_point;
+ }
+ const float scaling_factor_inv = 1.0f / *scaling_factor;
+ for (int i = 0; i < size; ++i)
+ {
+ const int32_t quantized_value =
+ static_cast<int32_t>(std::round(*offset + values[i] * scaling_factor_inv));
+ quantized_values[i] = std::min(kMaxScale, std::max(kMinScale, quantized_value));
+ }
+}
+
inline void PortableMatrixBatchVectorMultiplyAccumulate(const int8_t *__restrict__ matrix,
const int m_rows, const int m_cols,
const int8_t *__restrict__ vectors,
diff --git a/compute/cker/include/cker/Shape.h b/compute/cker/include/cker/Shape.h
index 86caf7d18..9269ce9aa 100644
--- a/compute/cker/include/cker/Shape.h
+++ b/compute/cker/include/cker/Shape.h
@@ -156,7 +156,7 @@ public:
const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
Resize(dimensions_count);
int32_t *data = DimsData();
- for (auto it : src_iterable)
+ for (auto &&it : src_iterable)
{
*data = it;
++data;
diff --git a/compute/cker/include/cker/Types.h b/compute/cker/include/cker/Types.h
index 495c89440..3fd0cf5b6 100644
--- a/compute/cker/include/cker/Types.h
+++ b/compute/cker/include/cker/Types.h
@@ -258,9 +258,12 @@ struct FullyConnectedParams
// uint8, etc, activation params.
int32_t quantized_activation_min;
int32_t quantized_activation_max;
- // float activation params - no one use this params, but ruy might use them later.
- // float float_activation_min;
- // float float_activation_max;
+ // float activation params
+ float float_activation_min;
+ float float_activation_max;
+ // Mark the operands as cacheable if they are unchanging, e.g. weights.
+ bool lhs_cacheable;
+ bool rhs_cacheable;
// FullyConnectedWeightsFormat weights_format;
};
diff --git a/compute/cker/include/cker/eigen/eigen_gemm_eigen.h b/compute/cker/include/cker/eigen/eigen_gemm_eigen.h
new file mode 100644
index 000000000..d4f8fc09d
--- /dev/null
+++ b/compute/cker/include/cker/eigen/eigen_gemm_eigen.h
@@ -0,0 +1,95 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_EIGEN_EIGEN_GEMM_EIGEN_H__
+#define __NNFW_CKER_EIGEN_EIGEN_GEMM_EIGEN_H__
+
+// See b/131835803: in TFLite code, because eigen_spatial_convolutions.h does
+// #define Eigen EigenForTFLite, it is difficult to have any #include of Eigen
+// headers in a header file, as that results in name classes (compilation
+// errors) depending on the order in which these headers are #included.
+// So we have moved the #include of Eigen here, in a .cc file, where we have
+// control over the header #include sequence.
+// #include "third_party/eigen3/Eigen/Core"
+// #include "tensorflow/lite/kernels/cpu_backend_context.h"
+// #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
+// #include "tensorflow/lite/kernels/internal/common.h"
+// #include "cker/eigen/eigen_convolution_helpers.h"
+#include "cker/operation/Common.h"
+#include "cker/Types.h"
+
+#include <Eigen/Core>
+
+namespace nnfw
+{
+namespace cker
+{
+namespace detail
+{
+
+// tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_eigen.h and cpu_backend_gemm_eigen.cc
+struct GemmImplUsingEigen
+{
+ static void Run(const MatrixParams<float> &lhs_params, const float *lhs_data,
+ const MatrixParams<float> &rhs_params, const float *rhs_data,
+ const MatrixParams<float> &dst_params, float *dst_data,
+ const GemmParams<float, float> &params)
+ {
+ // This code assumes specific storage orders, encoded in these Eigen types.
+ // These assumptions have been checked by TF_LITE_ASSERT's in the public
+ // Gemm entry point already, before the implementation gets to this point.
+ using EigenMatrixMapRowMajorConst =
+ Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
+ using EigenMatrixMapColMajorConst =
+ Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>;
+ using EigenMatrixMapColMajorMutable =
+ Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>;
+
+ EigenMatrixMapRowMajorConst eigen_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
+ EigenMatrixMapColMajorConst eigen_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
+ EigenMatrixMapColMajorMutable eigen_dst(dst_data, dst_params.rows, dst_params.cols);
+
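+ // Special-case matrix*vector and vector*matrix products so Eigen can use its
+ // faster GEMV kernels; otherwise fall back to a full GEMM.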
+ if (rhs_params.cols == 1)
+ {
+ eigen_dst.col(0).noalias() = eigen_lhs * eigen_rhs.col(0);
+ }
+ else if (lhs_params.rows == 1)
+ {
+ eigen_dst.row(0).noalias() = eigen_lhs.row(0) * eigen_rhs;
+ }
+ else
+ {
+ eigen_dst.noalias() = eigen_lhs * eigen_rhs;
+ }
+
+ if (params.bias)
+ {
+ BiasAndClamp(params.clamp_min, params.clamp_max, dst_params.rows, params.bias,
+ dst_params.rows * dst_params.cols, dst_data);
+ }
+ else
+ {
+ eigen_dst = eigen_dst.cwiseMin(params.clamp_max).cwiseMax(params.clamp_min);
+ }
+ }
+};
+
+} // namespace detail
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_EIGEN_EIGEN_GEMM_EIGEN_H__
diff --git a/compute/cker/include/cker/operation/Conv.h b/compute/cker/include/cker/operation/Conv.h
index 7cd54dcd5..2572b51ee 100644
--- a/compute/cker/include/cker/operation/Conv.h
+++ b/compute/cker/include/cker/operation/Conv.h
@@ -207,6 +207,21 @@ private:
std::vector<int32_t> _per_channel_output_multiplier;
std::vector<int> _per_channel_output_shift;
};
+
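+// Per-call scratch buffers for hybrid (float activations, int8 weights) convolution:
+// the quantized input plus one scaling factor and zero-point offset per batch entry.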
+struct ConvHybridTempArena
+{
+ ConvHybridTempArena(int batch_size, int input_size)
+ {
+ input_quantized.resize(input_size);
+ // TODO: Optimize the case of batch_size = 1
+ input_scaling_factors.resize(batch_size);
+ input_offsets.resize(batch_size);
+ }
+ std::vector<int8_t> input_quantized;
+ std::vector<float> input_scaling_factors;
+ std::vector<int32_t> input_offsets;
+};
+
} // namespace cker
} // namespace nnfw
diff --git a/compute/cker/include/cker/operation/DepthwiseConv.h b/compute/cker/include/cker/operation/DepthwiseConv.h
index ed1f93d44..c926ec4f1 100644
--- a/compute/cker/include/cker/operation/DepthwiseConv.h
+++ b/compute/cker/include/cker/operation/DepthwiseConv.h
@@ -26,6 +26,7 @@
#include "cker/operation/optimized/DepthwiseConvUint8.h"
#include "cker/operation/optimized/integer_ops/DepthwiseConvInt8.h"
#include "cker/operation/reference/integer_ops/DepthwiseConvUInt8.h"
+#include "cker/operation/reference/integer_ops/DepthwiseConvHybrid.h"
#include "cker/CpuBackendThreadpool.h"
namespace nnfw
diff --git a/compute/cker/include/cker/operation/Einsum.h b/compute/cker/include/cker/operation/Einsum.h
index 6721a7508..bb9f88f8d 100644
--- a/compute/cker/include/cker/operation/Einsum.h
+++ b/compute/cker/include/cker/operation/Einsum.h
@@ -274,7 +274,7 @@ public:
}
for (int i = 0; i < num_inputs; ++i)
{
- for (int label : free_labels[i])
+ for (auto &&label : free_labels[i])
{
result_labels.push_back(label);
result_shape_dims.push_back(label_to_dim_sizes[label]);
@@ -300,7 +300,7 @@ public:
{
// We inflated the output. Modify result labels accordingly.
Labels inflated_labels;
- for (int label : result_labels)
+ for (auto &&label : result_labels)
{
inflated_labels.insert(inflated_labels.end(), output_label_counts[label], label);
}
@@ -775,7 +775,7 @@ private:
Shape inflated_shape;
std::vector<int32_t> strided_shape_dims;
std::vector<int32_t> inflated_shape_dims;
- for (int label : labels)
+ for (auto &&label : labels)
{
const int32_t count = label_counts[label];
const int current_axis =
diff --git a/compute/cker/include/cker/operation/FullyConnected.h b/compute/cker/include/cker/operation/FullyConnected.h
index b7d27e85d..71a2f19ef 100644
--- a/compute/cker/include/cker/operation/FullyConnected.h
+++ b/compute/cker/include/cker/operation/FullyConnected.h
@@ -21,6 +21,7 @@
#include <ruy/context.h>
#include "cker/operation/FullyConnectedDense16x1.h"
#include "cker/operation/FullyConnectedSparse16x1.h"
+#include "cker/operation/optimized/Gemm.h"
#include "cker/Shape.h"
#include "cker/Types.h"
#include "cker/Utils.h"
@@ -58,6 +59,42 @@ public:
std::vector<int32_t> accum_scratch;
};
+#if defined(CKER_X86_PLATFORM)
+
+// From tensorflow/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+inline void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape,
+ const float *input_data, const Shape &weights_shape,
+ const float *weights_data, const Shape &,
+ const float *optional_bias_data, const Shape &output_shape,
+ float *output_data)
+{
+ const int dims_count = weights_shape.DimensionsCount();
+ const int input_rows = weights_shape.Dims(dims_count - 1);
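+ // Express the fully-connected layer as a GEMM: the row-major weight matrix is the
+ // LHS and the flattened, column-major input batch is the RHS.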
+ MatrixParams<float> rhs_params;
+ rhs_params.order = Order::kColMajor;
+ rhs_params.rows = input_rows;
+ rhs_params.cols = input_shape.FlatSize() / input_rows;
+ rhs_params.cache_policy = optimized::DefaultCachePolicy(params.rhs_cacheable);
+
+ MatrixParams<float> lhs_params;
+ lhs_params.order = Order::kRowMajor;
+ lhs_params.cols = weights_shape.Dims(dims_count - 1);
+ lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
+ lhs_params.cache_policy = optimized::DefaultCachePolicy(params.lhs_cacheable);
+ MatrixParams<float> dst_params;
+ dst_params.order = Order::kColMajor;
+ dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
+ dst_params.cols = FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
+ GemmParams<float, float> gemm_params;
+ gemm_params.bias = optional_bias_data;
+ gemm_params.clamp_min = params.float_activation_min;
+ gemm_params.clamp_max = params.float_activation_max;
+ optimized::Gemm(lhs_params, weights_data, rhs_params, input_data, dst_params, output_data,
+ gemm_params);
+}
+
+#else // CKER_X86_PLATFORM
+
inline void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape,
const float *input_data, const Shape &weights_shape,
const float *weights_data, const Shape &, const float *bias_data,
@@ -89,6 +126,8 @@ inline void FullyConnected(const FullyConnectedParams &params, const Shape &inpu
}
}
+#endif // CKER_X86_PLATFORM
+
inline void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape,
const uint8_t *input_data, const Shape &filter_shape,
const uint8_t *filter_data, const Shape &bias_shape,
diff --git a/compute/cker/include/cker/operation/optimized/Gemm.h b/compute/cker/include/cker/operation/optimized/Gemm.h
new file mode 100644
index 000000000..cfebef452
--- /dev/null
+++ b/compute/cker/include/cker/operation/optimized/Gemm.h
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_OPTIMIZED_GEMM_H__
+#define __NNFW_CKER_OPTIMIZED_GEMM_H__
+
+#include "cker/eigen/eigen_gemm_eigen.h"
+#include "cker/Shape.h"
+#include "cker/Types.h"
+
+#include <ruy/context.h>
+
+namespace nnfw
+{
+namespace cker
+{
+namespace optimized
+{
+
+#if defined(CKER_X86_PLATFORM)
+
+/* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_x86.h */
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
+ QuantizationFlavor quantization_flavor>
+struct GemmImplX86
+{
+ static void Run(const MatrixParams<LhsScalar> &, const LhsScalar *,
+ const MatrixParams<RhsScalar> &, const RhsScalar *,
+ const MatrixParams<DstScalar> &, DstScalar *,
+ const GemmParams<AccumScalar, DstScalar, quantization_flavor> &)
+ {
+ static_assert(
+ std::is_floating_point<LhsScalar>::value && std::is_floating_point<RhsScalar>::value &&
+ std::is_floating_point<AccumScalar>::value && std::is_floating_point<DstScalar>::value &&
+ quantization_flavor != QuantizationFlavor::kFloatingPoint,
+ "GemmImplX86 does not supported types other than float yet.");
+ }
+};
+
+// For float, defer to eigen for now.
+template <> struct GemmImplX86<float, float, float, float, QuantizationFlavor::kFloatingPoint>
+{
+ static void Run(const MatrixParams<float> &lhs_params, const float *lhs_data,
+ const MatrixParams<float> &rhs_params, const float *rhs_data,
+ const MatrixParams<float> &dst_params, float *dst_data,
+ const GemmParams<float, float, QuantizationFlavor::kFloatingPoint> &params)
+ {
+ detail::GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
+ dst_data, params);
+ }
+};
+
+/* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
+/* GEMM dispatch implementation for x86.
+ */
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
+ QuantizationFlavor quantization_flavor>
+struct GemmImpl : GemmImplX86<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>
+{
+};
+
+/* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
+template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
+ QuantizationFlavor quantization_flavor>
+void Gemm(const MatrixParams<LhsScalar> &lhs_params, const LhsScalar *lhs_data,
+ const MatrixParams<RhsScalar> &rhs_params, const RhsScalar *rhs_data,
+ const MatrixParams<DstScalar> &dst_params, DstScalar *dst_data,
+ const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
+{
+ // Generic case: dispatch to any backend as a general GEMM.
+ GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>::Run(
+ lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params);
+}
+
+// From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_params.h
+inline CachePolicy DefaultCachePolicy(bool is_constant_data)
+{
+ return is_constant_data ? CachePolicy::kCacheIfLargeSpeedup : CachePolicy::kNeverCache;
+}
+#endif // CKER_X86_PLATFORM
+
+} // namespace optimized
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_OPTIMIZED_GEMM_H__
diff --git a/compute/cker/include/cker/operation/reference/Conv.h b/compute/cker/include/cker/operation/reference/Conv.h
index 8bfd4694e..e316083a5 100644
--- a/compute/cker/include/cker/operation/reference/Conv.h
+++ b/compute/cker/include/cker/operation/reference/Conv.h
@@ -311,6 +311,91 @@ inline void Conv(const ConvParams &params, const int32_t *output_multiplier,
}
}
+// Slightly modified from tflite 2.13.0 HybridConvPerChannel
+// im2col and im2col_shape are removed since they are not used in the reference kernel.
+inline void HybridConvPerChannel(const ConvParams &params, float *scaling_factors_ptr,
+ const Shape &input_shape, const int8_t *input_data,
+ const Shape &filter_shape, const int8_t *filter_data,
+ const Shape &bias_shape, const float *bias_data,
+ const Shape &output_shape, float *output_data,
+ const float *per_channel_scale, const int32_t *input_offset)
+
+{
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ assert(input_shape.DimensionsCount() == 4);
+ assert(filter_shape.DimensionsCount() == 4);
+ assert(output_shape.DimensionsCount() == 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ if (bias_data)
+ {
+ assert(bias_shape.FlatSize() == output_depth);
+ UNUSED_RELEASE(bias_shape);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_input_depth = filter_shape.Dims(3);
+ const int groups = input_depth / filter_input_depth;
+ assert(input_depth % filter_input_depth == 0);
+ const int filters_per_group = output_depth / groups;
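+ // Grouped convolution: output channels are partitioned into `groups`, and each
+ // group reads only its own slice of `filter_input_depth` input channels.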
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ for (int batch = 0; batch < batches; ++batch)
+ {
+ for (int out_y = 0; out_y < output_height; ++out_y)
+ {
+ for (int out_x = 0; out_x < output_width; ++out_x)
+ {
+ for (int out_channel = 0; out_channel < output_depth; ++out_channel)
+ {
+ auto group = out_channel / filters_per_group;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32_t acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y)
+ {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x)
+ {
+ for (int in_channel = 0; in_channel < filter_input_depth; ++in_channel)
+ {
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height))
+ {
+ int32_t input_val = input_data[Offset(input_shape, batch, in_y, in_x,
+ in_channel + group * filter_input_depth)];
+ int32_t filter_val =
+ filter_data[Offset(filter_shape, out_channel, filter_y, filter_x, in_channel)];
+ acc += filter_val * (input_val - input_offset[batch]);
+ }
+ }
+ }
+ }
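+ // Dequantize: scale the int32 accumulator by the per-channel filter scale
+ // and the per-batch input scaling factor.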
+ float acc_float = acc * per_channel_scale[out_channel] * scaling_factors_ptr[batch];
+ if (bias_data)
+ {
+ acc_float += bias_data[out_channel];
+ }
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
+ ActivationFunctionWithMinMax(acc_float, output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
} // namespace reference
} // namespace cker
} // namespace nnfw
diff --git a/compute/cker/include/cker/operation/reference/integer_ops/DepthwiseConvHybrid.h b/compute/cker/include/cker/operation/reference/integer_ops/DepthwiseConvHybrid.h
new file mode 100644
index 000000000..9fc58ad3b
--- /dev/null
+++ b/compute/cker/include/cker/operation/reference/integer_ops/DepthwiseConvHybrid.h
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_REFERENCE_DEPTHWISE_CONV_HYBRID_H__
+#define __NNFW_CKER_REFERENCE_DEPTHWISE_CONV_HYBRID_H__
+
+#include "cker/Shape.h"
+#include "cker/Types.h"
+#include "cker/Utils.h"
+
+namespace nnfw
+{
+namespace cker
+{
+namespace reference_integer_ops
+{
+
+inline void DepthwiseConvHybridPerChannel(const DepthwiseConvParams &params,
+ float *scaling_factors_ptr, const Shape &input_shape,
+ const int8_t *input_data, const Shape &filter_shape,
+ const int8_t *filter_data, const Shape &bias_shape,
+ const float *bias_data, const Shape &output_shape,
+ float *output_data, const float *per_channel_scale,
+ int32_t *input_offset)
+{
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+
+ // Check dimensions of the tensors.
+ assert(input_shape.DimensionsCount() == 4);
+ assert(filter_shape.DimensionsCount() == 4);
+ assert(output_shape.DimensionsCount() == 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int bias_depth = bias_shape.FlatSize();
+ UNUSED_RELEASE(output_depth);
+ UNUSED_RELEASE(bias_shape);
+ assert(output_depth == input_depth * depth_multiplier);
+ assert(bias_depth == output_depth);
+
+ for (int batch = 0; batch < batches; ++batch)
+ {
+ for (int out_y = 0; out_y < output_height; ++out_y)
+ {
+ for (int out_x = 0; out_x < output_width; ++out_x)
+ {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel)
+ {
+ for (int m = 0; m < depth_multiplier; ++m)
+ {
+ const int output_channel = m + in_channel * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32_t acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y)
+ {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x)
+ {
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ // Zero padding by omitting the areas outside the image.
+ const bool is_point_inside_image =
+ (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height);
+ if (is_point_inside_image)
+ {
+ int32_t input_val =
+ input_data[Offset(input_shape, batch, in_y, in_x, in_channel)];
+ int32_t filter_val =
+ filter_data[Offset(filter_shape, 0, filter_y, filter_x, output_channel)];
+ acc += filter_val * (input_val - input_offset[batch]);
+ }
+ }
+ }
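+ // Dequantize the accumulator with the per-channel filter scale and the
+ // per-batch input scaling factor.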
+ float acc_float = static_cast<float>(acc);
+ acc_float *= per_channel_scale[output_channel] * scaling_factors_ptr[batch];
+ if (bias_data && output_channel < bias_depth)
+ {
+ acc_float += bias_data[output_channel];
+ }
+ output_data[Offset(output_shape, batch, out_y, out_x, output_channel)] =
+ ActivationFunctionWithMinMax(acc_float, output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+ }
+}
+
+} // namespace reference_integer_ops
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_REFERENCE_DEPTHWISE_CONV_HYBRID_H__
diff --git a/compute/cker/include/cker/train/operation/FullyConnected.h b/compute/cker/include/cker/train/operation/FullyConnected.h
new file mode 100644
index 000000000..b0255d287
--- /dev/null
+++ b/compute/cker/include/cker/train/operation/FullyConnected.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_TRAIN_OPERATION_FULLY_CONNECTED_H__
+#define __NNFW_CKER_TRAIN_OPERATION_FULLY_CONNECTED_H__
+
+#include "cker/eigen/Utils.h"
+#include "cker/Shape.h"
+
+namespace nnfw
+{
+namespace cker
+{
+namespace train
+{
+
+template <typename T>
+inline void FullyConnectedBiasGrad(const Shape &incoming_shape, const T *incoming_data,
+ const Shape &grad_shape, T *grad_data)
+{
+ const auto bias_size = grad_shape.FlatSize();
+ if (bias_size != incoming_shape.Dims(incoming_shape.DimensionsCount() - 1) ||
+ bias_size != grad_shape.Dims(0))
+ throw std::runtime_error("cker::FullyConnectedBiasGrad: Unmatched shape");
+
+ const auto in_mat = MapAsMatrixWithLastDimAsRows(incoming_data, incoming_shape);
+ auto grad_mat = MapAsMatrixWithLastDimAsRows(grad_data, grad_shape);
+
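+ // With the last dimension mapped to rows, each row holds one bias element and each
+ // column one batch entry, so a row-wise sum accumulates the gradient over the batch.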
+ grad_mat = in_mat.rowwise().sum();
+}
+
+} // namespace train
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_TRAIN_OPERATION_FULLY_CONNECTED_H__
diff --git a/compute/cker/include/cker/train/operation/Loss.h b/compute/cker/include/cker/train/operation/Loss.h
new file mode 100644
index 000000000..94f49ff07
--- /dev/null
+++ b/compute/cker/include/cker/train/operation/Loss.h
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_TRAIN_OPERATION_LOSS_H__
+#define __NNFW_CKER_TRAIN_OPERATION_LOSS_H__
+
+#include "cker/Shape.h"
+#include "cker/eigen/Utils.h"
+
+namespace nnfw
+{
+namespace cker
+{
+namespace train
+{
+
+template <typename T>
+inline void MSE(const Shape &y_pred_shape, const T *y_pred_data, const Shape &y_true_shape,
+ const T *y_true_data, const Shape &output_shape, T *output_data)
+{
+ // TODO Consider Reduction
+ if (output_shape != Shape{1})
+ throw std::runtime_error("cker::MSE: output_shape != Shape{1}");
+ if (y_pred_shape != y_true_shape)
+ throw std::runtime_error("cker::MSE: y_pred_shape != y_true_shape");
+
+ const auto y_pred = MapAsMatrixWithLastDimAsRows(y_pred_data, y_pred_shape);
+ const auto y_true = MapAsMatrixWithLastDimAsRows(y_true_data, y_true_shape);
+
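+ // Accumulate the squared error in double precision, then average over all elements.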
+ double squared_sum = 0.0;
+ for (size_t c = 0; c < (size_t)y_pred.cols(); ++c)
+ {
+ for (size_t r = 0; r < (size_t)y_pred.rows(); ++r)
+ {
+ double error = y_pred.coeff(r, c) - y_true.coeff(r, c);
+ squared_sum += (error * error);
+ }
+ }
+
+ auto size = y_pred.cols() * y_pred.rows();
+ output_data[0] = (T)(squared_sum / size);
+}
+
+template <typename T>
+inline void MSEGrad(const Shape &y_pred_shape, const T *y_pred_data, const Shape &y_true_shape,
+ const T *y_true_data, const Shape &grad_shape, T *grad_data)
+{
+ if (y_pred_shape != y_true_shape)
+ throw std::runtime_error("cker::MSEGrad: y_pred_shape != y_true_shape");
+ if (y_pred_shape != grad_shape)
+ throw std::runtime_error("cker::MSEGrad: y_pred_shape != grad_shape");
+
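+ // Gradient of the mean squared error w.r.t. each prediction:
+ // d(MSE)/d(y_pred_i) = 2 * (y_pred_i - y_true_i) / N.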
+ const int size = grad_shape.FlatSize();
+ for (int i = 0; i < size; ++i)
+ {
+ grad_data[i] = static_cast<T>(-2 * (y_true_data[i] - y_pred_data[i]) / size);
+ }
+}
+
+} // namespace train
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_TRAIN_OPERATION_LOSS_H__
diff --git a/compute/cker/include/cker/train/operation/ReLU.h b/compute/cker/include/cker/train/operation/ReLU.h
new file mode 100644
index 000000000..32cf7fa9c
--- /dev/null
+++ b/compute/cker/include/cker/train/operation/ReLU.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_CKER_TRAIN_OPERATION_RELU_H__
+#define __NNFW_CKER_TRAIN_OPERATION_RELU_H__
+
+#include "cker/Shape.h"
+#include "cker/eigen/Utils.h"
+
+#include <Eigen/Core>
+
+namespace nnfw
+{
+namespace cker
+{
+namespace train
+{
+
+inline void ReLUGrad(const Shape &output_shape, const float *output_data,
+ const Shape &incoming_shape, const float *incoming_data,
+ const Shape &grad_shape, float *grad_data)
+{
+ const auto output_map = MapAsVector(output_data, output_shape);
+ const auto incoming_map = MapAsVector(incoming_data, incoming_shape);
+ auto grad_map = MapAsVector(grad_data, grad_shape);
+
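+ // Backpropagate through ReLU: pass the incoming gradient only where the forward
+ // output was positive; elsewhere the gradient is zero.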
+ if (output_shape == incoming_shape && output_shape == grad_shape)
+ grad_map.array() = incoming_map.array() * (output_map.array() > 0.0f).template cast<float>();
+ else
+ throw std::runtime_error("cker::ReLUGrad: Unsupported shape");
+}
+
+} // namespace train
+} // namespace cker
+} // namespace nnfw
+
+#endif // __NNFW_CKER_TRAIN_OPERATION_RELU_H__
diff --git a/compute/cker/src/train/FullyConnected.test.cc b/compute/cker/src/train/FullyConnected.test.cc
new file mode 100644
index 000000000..37c2d4a97
--- /dev/null
+++ b/compute/cker/src/train/FullyConnected.test.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cker/train/operation/FullyConnected.h>
+
+#include <gtest/gtest.h>
+#include <vector>
+
+TEST(CKer_Operation, FullyConnectedBiasGrad)
+{
+ {
+ // Shape: {2, 4}
+ std::vector<float> incoming_backward = {-1, 2, -3, 4, 5, -6, -7, 8};
+ // Shape: {4}
+ std::vector<float> expected_bias_backward = {4, -4, -10, 12};
+ std::vector<float> bias_backward(4);
+
+ nnfw::cker::train::FullyConnectedBiasGrad(
+ nnfw::cker::Shape{2, 4}, incoming_backward.data(),
+ nnfw::cker::Shape{static_cast<int>(bias_backward.size())}, bias_backward.data());
+
+ for (size_t i = 0; i < bias_backward.size(); ++i)
+ ASSERT_EQ(bias_backward[i], expected_bias_backward[i]);
+ }
+
+ {
+ // Shape: {3, 3}
+ std::vector<float> incoming_backward = {-1, 2, -3, 4, 5, -6, -7, 8, 9};
+ // Shape: {3}
+ std::vector<float> expected_bias_backward = {-4, 15, 0};
+ std::vector<float> bias_backward(3);
+
+ nnfw::cker::train::FullyConnectedBiasGrad(
+ nnfw::cker::Shape{3, 3}, incoming_backward.data(),
+ nnfw::cker::Shape{static_cast<int>(bias_backward.size())}, bias_backward.data());
+
+ for (size_t i = 0; i < bias_backward.size(); ++i)
+ ASSERT_EQ(bias_backward[i], expected_bias_backward[i]);
+ }
+
+ {
+ // Shape: {1, 2, 2, 3}
+ std::vector<float> incoming_backward = {-1, 2, -3, 4, 5, -6, -7, 8, 9, -10, -11, 12};
+ // Shape: {3}
+ std::vector<float> expected_bias_backward = {-14, 4, 12};
+ std::vector<float> bias_backward(3);
+
+ nnfw::cker::train::FullyConnectedBiasGrad(
+ nnfw::cker::Shape{1, 2, 2, 3}, incoming_backward.data(),
+ nnfw::cker::Shape{static_cast<int>(bias_backward.size())}, bias_backward.data());
+
+ for (size_t i = 0; i < bias_backward.size(); ++i)
+ ASSERT_EQ(bias_backward[i], expected_bias_backward[i]);
+ }
+}
+
+TEST(CKer_Operation, neg_FullyConnectedBiasGrad)
+{
+ {
+ // Unmatched shape
+ // Shape: {2, 4}
+ std::vector<float> incoming_backward = {-1, 2, -3, 4, 5, -6, -7, 8};
+ // Shape: {3}
+ std::vector<float> bias_backward(3);
+ EXPECT_ANY_THROW(nnfw::cker::train::FullyConnectedBiasGrad(
+ nnfw::cker::Shape{2, 4}, incoming_backward.data(),
+ nnfw::cker::Shape{static_cast<int>(bias_backward.size())},
+ bias_backward.data()));
+ }
+}
diff --git a/compute/cker/src/train/Loss.test.cc b/compute/cker/src/train/Loss.test.cc
new file mode 100644
index 000000000..98568f47a
--- /dev/null
+++ b/compute/cker/src/train/Loss.test.cc
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cker/train/operation/Loss.h>
+
+#include <gtest/gtest.h>
+#include <vector>
+
+TEST(CKer_Operation, LossMSE)
+{
+ {
+ // Shape: {1, 10} -> m_rows:10, m_cols:1
+ std::vector<int> y_pred = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+ std::vector<int> y_true = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<int> output(1);
+ std::vector<int> expected = {1};
+
+ nnfw::cker::train::MSE(nnfw::cker::Shape{1, 10}, y_pred.data(), nnfw::cker::Shape{1, 10},
+ y_true.data(), nnfw::cker::Shape{1}, output.data());
+
+ EXPECT_EQ(output[0], expected[0]);
+ }
+
+ {
+ // Shape: {1, 10} -> m_rows:10, m_cols:1
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.};
+ std::vector<float> output(1);
+ std::vector<float> expected = {1.0};
+
+ nnfw::cker::train::MSE(nnfw::cker::Shape{1, 10}, y_pred.data(), nnfw::cker::Shape{1, 10},
+ y_true.data(), nnfw::cker::Shape{1}, output.data());
+
+ EXPECT_FLOAT_EQ(output[0], expected[0]);
+ }
+
+ {
+ // Shape: {2, 3} -> m_rows:3, m_cols:2
+ std::vector<float> y_pred = {27.2, 31.8, 51.9, 10.2, 34.2, 12.4};
+ std::vector<float> y_true = {31.3, 40.3, 29.7, 12.9, 25.8, 11.9};
+ std::vector<float> output(1);
+ std::vector<float> expected = {110.0};
+
+ nnfw::cker::train::MSE(nnfw::cker::Shape{2, 3}, y_pred.data(), nnfw::cker::Shape{2, 3},
+ y_true.data(), nnfw::cker::Shape{1}, output.data());
+
+ EXPECT_FLOAT_EQ(output[0], expected[0]);
+ }
+
+ {
+ // Shape: {2, 3, 4} -> m_rows:4, m_cols:6
+ std::vector<float> y_pred = {1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
+ 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.};
+ std::vector<float> y_true = {1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.,
+ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.};
+ std::vector<float> output(1);
+ std::vector<float> expected = {2.1666667};
+
+ nnfw::cker::train::MSE(nnfw::cker::Shape{2, 3, 4}, y_pred.data(), nnfw::cker::Shape{2, 3, 4},
+ y_true.data(), nnfw::cker::Shape{1}, output.data());
+
+ EXPECT_FLOAT_EQ(output[0], expected[0]);
+ }
+}
+
+TEST(CKer_Operation, neg_LossMSE)
+{
+ {
+ // Invalid expected value
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.};
+ std::vector<float> output(1);
+ std::vector<float> expected = {-1.0};
+
+ nnfw::cker::train::MSE(nnfw::cker::Shape{1, 10}, y_pred.data(), nnfw::cker::Shape{1, 10},
+ y_true.data(), nnfw::cker::Shape{1}, output.data());
+
+ EXPECT_NE(output[0], expected[0]);
+ }
+
+ {
+ // Invalid output shape
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.};
+ std::vector<float> output(3);
+ std::vector<float> expected = {1.0};
+
+ EXPECT_ANY_THROW(nnfw::cker::train::MSE(nnfw::cker::Shape{2, 3, 4}, y_pred.data(),
+ nnfw::cker::Shape{2, 3, 4}, y_true.data(),
+ nnfw::cker::Shape{3}, output.data()));
+ }
+
+ {
+ // Different y_pred and y_true shape
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5.};
+ std::vector<float> output(1);
+ std::vector<float> expected = {1.0};
+
+ EXPECT_ANY_THROW(nnfw::cker::train::MSE(nnfw::cker::Shape{2, 3, 4}, y_pred.data(),
+ nnfw::cker::Shape{2, 3}, y_true.data(),
+ nnfw::cker::Shape{1}, output.data()));
+ }
+}
+
+TEST(CKer_Operation, LossMSEGrad)
+{
+ {
+ // Shape: {1, 10} -> m_rows:10, m_cols:1
+ std::vector<int> y_pred = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
+ std::vector<int> y_true = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ std::vector<int> deriv_y_pred(10);
+ std::vector<int> expected = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+
+ nnfw::cker::train::MSEGrad(nnfw::cker::Shape{1, 10}, y_pred.data(), nnfw::cker::Shape{1, 10},
+ y_true.data(), nnfw::cker::Shape{1, 10}, deriv_y_pred.data());
+
+ for (size_t i = 0; i < deriv_y_pred.size(); ++i)
+ EXPECT_EQ(deriv_y_pred[i], expected[i]);
+ }
+
+ {
+ // Shape: {1, 10} -> m_rows:10, m_cols:1
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.};
+ std::vector<float> deriv_y_pred(10);
+ std::vector<float> expected = {0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2};
+
+ nnfw::cker::train::MSEGrad(nnfw::cker::Shape{1, 10}, y_pred.data(), nnfw::cker::Shape{1, 10},
+ y_true.data(), nnfw::cker::Shape{1, 10}, deriv_y_pred.data());
+
+ for (size_t i = 0; i < deriv_y_pred.size(); ++i)
+ EXPECT_FLOAT_EQ(deriv_y_pred[i], expected[i]);
+ }
+
+ {
+ // Shape: {2, 3} -> m_rows:3, m_cols:2
+ std::vector<float> y_pred = {27.2, 31.8, 51.9, 10.2, 34.2, 12.4};
+ std::vector<float> y_true = {31.3, 40.3, 29.7, 12.9, 25.8, 11.9};
+ std::vector<float> deriv_y_pred(6);
+ std::vector<float> expected = {-1.3666667, -2.8333333, 7.4, -0.9, 2.8, 0.1666667};
+
+ nnfw::cker::train::MSEGrad(nnfw::cker::Shape{2, 3}, y_pred.data(), nnfw::cker::Shape{2, 3},
+ y_true.data(), nnfw::cker::Shape{2, 3}, deriv_y_pred.data());
+
+ for (size_t i = 0; i < deriv_y_pred.size(); ++i)
+ EXPECT_FLOAT_EQ(deriv_y_pred[i], expected[i]);
+ }
+}
+
+TEST(CKer_Operation, neg_LossMSEGrad)
+{
+ {
+ // Invalid expected value
+ std::vector<float> y_pred = {27.2, 31.8, 51.9, 10.2, 34.2, 12.4};
+ std::vector<float> y_true = {31.3, 40.3, 29.7, 12.9, 25.8, 11.9};
+ std::vector<float> deriv_y_pred(6);
+ std::vector<float> expected = {1., 1., 1., 1., 1., 1.};
+
+ nnfw::cker::train::MSEGrad(nnfw::cker::Shape{2, 3}, y_pred.data(), nnfw::cker::Shape{2, 3},
+ y_true.data(), nnfw::cker::Shape{2, 3}, deriv_y_pred.data());
+
+ for (size_t i = 0; i < deriv_y_pred.size(); ++i)
+ EXPECT_NE(deriv_y_pred[i], expected[i]);
+ }
+
+ {
+ // Different y_pred and y_true shape
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5.};
+ std::vector<float> deriv_y_pred(10);
+
+ EXPECT_ANY_THROW(nnfw::cker::train::MSEGrad(nnfw::cker::Shape{1, 10}, y_pred.data(),
+ nnfw::cker::Shape{2, 3}, y_true.data(),
+ nnfw::cker::Shape{1, 10}, deriv_y_pred.data()));
+ }
+
+ {
+ // Different y_pred and deriv_y_pred shape
+ std::vector<float> y_pred = {1., 2., 3., 4., 5., 6., 7., 8., 9., 10.};
+ std::vector<float> y_true = {0., 1., 2., 3., 4., 5., 6., 7., 8., 9.};
+ std::vector<float> deriv_y_pred(6);
+
+ EXPECT_ANY_THROW(nnfw::cker::train::MSEGrad(nnfw::cker::Shape{1, 10}, y_pred.data(),
+ nnfw::cker::Shape{1, 10}, y_true.data(),
+ nnfw::cker::Shape{2, 3}, deriv_y_pred.data()));
+ }
+}
diff --git a/compute/cker/src/train/Relu.test.cc b/compute/cker/src/train/Relu.test.cc
new file mode 100644
index 000000000..d94411038
--- /dev/null
+++ b/compute/cker/src/train/Relu.test.cc
@@ -0,0 +1,107 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <cker/operation/ReLU.h>
+#include <cker/train/operation/ReLU.h>
+
+#include <gtest/gtest.h>
+#include <vector>
+
+namespace
+{
+
+template <typename T> class ReluOpVerifier
+{
+public:
+ ReluOpVerifier(const std::vector<T> &input, const std::vector<T> &expected_output,
+ const std::vector<T> &backprop_output,
+ const std::vector<T> &expected_backprop_input)
+ : _input{input}, _expected_output{expected_output}, _backprop_output{backprop_output},
+ _expected_backprop_input{expected_backprop_input}
+ {
+ EXPECT_TRUE(input.size() == expected_output.size());
+ _output.resize(_expected_output.size());
+ _backprop_input.resize(_expected_backprop_input.size());
+ }
+
+public:
+ void verifyExpected()
+ {
+ nnfw::cker::ReLU(nnfw::cker::Shape{static_cast<int>(_input.size())}, _input.data(),
+ nnfw::cker::Shape{static_cast<int>(_output.size())}, _output.data());
+
+ for (size_t i = 0; i < _output.size(); ++i)
+ ASSERT_EQ(_output[i], _expected_output[i]);
+
+ if (_backprop_output.size() > 0)
+ {
+ nnfw::cker::train::ReLUGrad(
+ nnfw::cker::Shape{static_cast<int>(_output.size())}, _output.data(),
+ nnfw::cker::Shape{static_cast<int>(_backprop_output.size())}, _backprop_output.data(),
+ nnfw::cker::Shape{static_cast<int>(_backprop_input.size())}, _backprop_input.data());
+
+ for (size_t i = 0; i < _backprop_input.size(); ++i)
+ ASSERT_EQ(_backprop_input[i], _expected_backprop_input[i]);
+ }
+ }
+
+private:
+ std::vector<T> _input;
+ std::vector<T> _output;
+ std::vector<T> _expected_output;
+ std::vector<T> _backprop_output;
+ std::vector<T> _backprop_input;
+ std::vector<T> _expected_backprop_input;
+};
+
+} // namespace
+
+TEST(CKer_Operation, ReLU)
+{
+ {
+ std::vector<float> input_forward = {-1, 2, 3, -4};
+ std::vector<float> expected_forward = {0, 2, 3, 0};
+ std::vector<float> incoming_backward = {-5, 6, -7, 8};
+ std::vector<float> expected_backward = {0, 6, -7, 0};
+ ReluOpVerifier<float> verifier{input_forward, expected_forward, incoming_backward,
+ expected_backward};
+ verifier.verifyExpected();
+ }
+
+ {
+ std::vector<float> input_forward = {0, -1, 2, 3, -4, 5, 6, -7};
+ std::vector<float> expected_forward = {0, 0, 2, 3, 0, 5, 6, 0};
+ std::vector<float> incoming_backward = {8, -9, 10, 11, -12, -13, 14, -15};
+ std::vector<float> expected_backward = {0, 0, 10, 11, 0, -13, 14, 0};
+ ReluOpVerifier<float> verifier{input_forward, expected_forward, incoming_backward,
+ expected_backward};
+ verifier.verifyExpected();
+ }
+}
+
+TEST(CKer_Operation, neg_ReLU)
+{
+ {
+ // Unmatched shape
+ std::vector<float> input_forward = {0, -1, 2, 3, -4};
+ std::vector<float> expected_forward = {0, 0, 2, 3, 0};
+ std::vector<float> incoming_backward = {-5, 6, -7, 8};
+ std::vector<float> expected_backward = {0, 6, -7, 0};
+ ReluOpVerifier<float> verifier{input_forward, expected_forward, incoming_backward,
+ expected_backward};
+ EXPECT_ANY_THROW(verifier.verifyExpected());
+ }
+}
diff --git a/compute/ruy/include/ruy/Shape.h b/compute/ruy/include/ruy/Shape.h
index 981c5b4de..151a67377 100644
--- a/compute/ruy/include/ruy/Shape.h
+++ b/compute/ruy/include/ruy/Shape.h
@@ -156,7 +156,7 @@ public:
const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
Resize(dimensions_count);
int32_t *data = DimsData();
- for (auto it : src_iterable)
+ for (auto &&it : src_iterable)
{
*data = it;
++data;