summaryrefslogtreecommitdiff
path: root/runtimes/libs/srcn/src/winograd.h
diff options
context:
space:
mode:
Diffstat (limited to 'runtimes/libs/srcn/src/winograd.h')
-rw-r--r--runtimes/libs/srcn/src/winograd.h148
1 files changed, 148 insertions, 0 deletions
diff --git a/runtimes/libs/srcn/src/winograd.h b/runtimes/libs/srcn/src/winograd.h
new file mode 100644
index 000000000..5ad8f1126
--- /dev/null
+++ b/runtimes/libs/srcn/src/winograd.h
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __NNFW_SRCN_WINOGRAD_H__
+#define __NNFW_SRCN_WINOGRAD_H__
+
+namespace nnfw
+{
+namespace srcn
+{
+
+struct winograd_para_3x3s1
+{
+ static const int M = 3 + 4 - 1;
+ static const int N = 3;
+
+ static const double *getG()
+ {
+ static const double G[M * N] = {
+ 1. / 4., 0, 0, -1. / 6., -1. / 6., -1. / 6., -1. / 6., 1. / 6., -1. / 6.,
+ 1. / 24., 1. / 12., 1. / 6., 1. / 24., -1. / 12., 1. / 6., 0, 0, 1,
+ };
+ return G;
+ }
+
+ static const double *getA()
+ {
+ static const double A[M * (M - N + 1)] = {
+ 1, 0, 0, 0, 1, 1, 1, 1, 1, -1, 1, -1, 1, 2, 4, 8, 1, -2, 4, -8, 0, 0, 0, 1,
+ };
+ return A;
+ }
+
+ static const double *getB()
+ {
+ static const double B[M * M] = {
+ 4, 0, 0, 0, 0, 0, 0, -4, 4, -2, 2, 4, -5, -4, -4, -1, -1, 0,
+ 0, 1, -1, 2, -2, -5, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1,
+ };
+ return B;
+ };
+};
+
+struct winograd_para_3x3s1_2
+{
+ static const int M = 3 + 2 - 1;
+ static const int N = 3;
+
+ static const double *getG()
+ {
+ static const double G[M * N] = {
+ 1, 0, 0, 1. / 2., 1. / 2., 1. / 2., 1. / 2., -1. / 2., 1. / 2., 0, 0, 1,
+ };
+ return G;
+ }
+
+ static const double *getA()
+ {
+ static const double A[M * (M - N + 1)] = {
+ 1, 0, 1, 1, 1, -1, 0, 1,
+ };
+ return A;
+ }
+
+ static const double *getB()
+ {
+ static const double B[M * M] = {
+ 1, 0, 0, 0, 0, 1, -1, -1, -1, 1, 1, 0, 0, 0, 0, 1,
+ };
+ return B;
+ };
+};
+
+struct winograd_para_5x5s1
+{
+ static const int M = 5 + 4 - 1;
+ static const int N = 5;
+
+ static const double *getG()
+ {
+ static const double G[M * N] = {
+ 1, 0, 0, 0, 0, -2. / 9., -2. / 9., -2. / 9.,
+ -2. / 9., -2. / 9., -2. / 9., 2. / 9., -2. / 9., 2. / 9., -2. / 9., 1. / 90.,
+ 1. / 45., 2. / 45., 4. / 45., 8. / 45., 1. / 90., -1. / 45., 2. / 45., -4. / 45.,
+ 8. / 45., 4. / 45., 2. / 45., 1. / 45., 1. / 90., 1. / 180., 4. / 45., -2. / 45.,
+ 1. / 45., -1. / 90., 1. / 180., 0, 0, 0, 0, 1,
+ };
+ return G;
+ }
+
+ static const double *getA()
+ {
+ static const double A[M * (M - N + 1)] = {1, 0, 0, 0, 1, 1, 1, 1, 1, -1, 1, -1, 1, 2, 4, 8,
+ 1, -2, 4, -8, 8, 4, 2, 1, 8, -4, 2, -1, 0, 0, 0, 1};
+ return A;
+ }
+
+ static const double *getB()
+ {
+ static const double B[M * M] = {
+ 1, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+ -1, 1. / 2, -1. / 2, 2, -2, -1, -21. / 4, 1, 1, 1. / 4,
+ 1. / 4, 4, 4, 0, 0, -17. / 4, 17. / 4, -5. / 2, 5. / 2, -5. / 2,
+ 5. / 2, 21. / 4, 21. / 4, -17. / 4, -17. / 4, -5. / 4, -5. / 4, -5, -5, 0,
+ 0, 1, -1, 2, -2, 1. / 2, -1. / 2, -21. / 4, -1, 1,
+ 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1,
+ };
+ return B;
+ }
+};
+
+static void kronecker_product(float *out, const double *in1, const double *in2, int m, int n, int p,
+ int q)
+{
+ for (int i = 0; i < m; ++i)
+ {
+ for (int j = 0; j < n; ++j)
+ {
+ for (int k = 0; k < p; ++k)
+ {
+ for (int l = 0; l < q; ++l)
+ {
+ out[(p * i + k) * n * q + q * j + l] = in1[n * i + j] * in2[k * q + l];
+ /* compute in double precision and then convert it back to Dtype for accuracy */
+ }
+ }
+ }
+ }
+}
+
+} // namespace srcn
+} // namespace nnfw
+
+#endif // __NNFW_SRCN_WINOGRAD_H__