summaryrefslogtreecommitdiff
path: root/onert-micro/luci-interpreter/pal/common/PALReduceCommon.h
diff options
context:
space:
mode:
Diffstat (limited to 'onert-micro/luci-interpreter/pal/common/PALReduceCommon.h')
-rw-r--r--onert-micro/luci-interpreter/pal/common/PALReduceCommon.h114
1 files changed, 114 insertions, 0 deletions
diff --git a/onert-micro/luci-interpreter/pal/common/PALReduceCommon.h b/onert-micro/luci-interpreter/pal/common/PALReduceCommon.h
new file mode 100644
index 000000000..a5b0e10dd
--- /dev/null
+++ b/onert-micro/luci-interpreter/pal/common/PALReduceCommon.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
+ * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LUCI_INTERPRETER_PAL_TANH_H
+#define LUCI_INTERPRETER_PAL_TANH_H
+
+#include "PALUtils.h"
+
+namespace luci_interpreter_pal
+{
+namespace
+{
// Validates the reduction axes in 'axis' and counts how many distinct
// dimensions they refer to.
//
// num_dims      rank of the tensor being reduced.
// axis          array of 'num_axis' axis indices. An entry may be negative, in
//               which case it counts back from the last dimension, e.g. for
//               num_dims=3, [-3, -2, -1] is the same as [0, 1, 2].
// num_axis      number of entries in 'axis'.
// out_num_axis  receives the number of distinct axes found; entries naming the
//               same dimension are counted once.
//
// Returns true when every entry lies in [-num_dims, num_dims). Returns false
// at the first entry outside that range; '*out_num_axis' then holds the count
// accumulated before the offending entry. For a scalar (num_dims == 0) the
// axes go unused: the function returns true with '*out_num_axis' set to 0.
inline bool resolveAxis(const int num_dims, const int *axis, const int64_t num_axis,
                        int *out_num_axis)
{
  *out_num_axis = 0; // Just in case.
  // Short-circuit axis resolution for scalars; the axis will go unused.
  if (num_dims == 0)
  {
    return true;
  }
  // o(n^2) is fine since the number of axes should be really small, mostly <= 4
  for (int64_t idx = 0; idx < num_axis; ++idx)
  {
    // Map a negative index onto its positive equivalent: p_idx = n_idx + num_dims.
    const int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
    if (current < 0 || current >= num_dims)
    {
      return false;
    }
    // Duplicates are detected by re-normalising the earlier entries rather than
    // by storing resolved values in a fixed-size scratch array, so any number
    // of distinct axes is handled without writing past a local buffer. Every
    // earlier entry has already passed the range check above.
    bool is_dup = false;
    for (int64_t prev = 0; prev < idx; ++prev)
    {
      const int seen = axis[prev] < 0 ? (axis[prev] + num_dims) : axis[prev];
      if (seen == current)
      {
        is_dup = true;
        break;
      }
    }
    if (!is_dup)
    {
      *out_num_axis += 1;
    }
  }
  return true;
}
+
+} // namespace
+
+// Computes the generic value (i.e., sum/max/min/prod) of elements across
+// dimensions given in axis. It needs to pass in init_value and reducer.
+template <typename T>
+inline void ReduceGeneric(const T *input_data, const int *input_dims, const int input_num_dims,
+ T *output_data, const int *axis, const int64_t num_axis_dimensions,
+ T init_value, const int output_flat_size, T reducer(const T, const T))
+{
+ // Return early when input shape has zero dim.
+ for (int i = 0; i < input_num_dims; ++i)
+ {
+ if (input_dims[i] == 0)
+ return;
+ }
+
+ for (size_t idx = 0; idx < output_flat_size; ++idx)
+ {
+ output_data[idx] = init_value;
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!resolveAxis(input_num_dims, axis, num_axis_dimensions, &num_resolved_axis))
+ {
+ return;
+ }
+
+ int temp_index[5];
+ // Reset input iterator.
+ for (int idx = 0; idx < input_num_dims; ++idx)
+ {
+ temp_index[idx] = 0;
+ }
+ // Iterate through input_data.
+ do
+ {
+ size_t input_offset = reducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr);
+ size_t output_offset =
+ reducedOutputOffset(input_num_dims, input_dims, temp_index, num_resolved_axis, axis);
+ output_data[output_offset] = reducer(output_data[output_offset], input_data[input_offset]);
+ } while (nextIndex(input_num_dims, input_dims, temp_index));
+}
+
+} // namespace luci_interpreter_pal
+
+#endif // LUCI_INTERPRETER_PAL_TANH_H