summaryrefslogtreecommitdiff
path: root/compiler/nest/core/examples/conv2d.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/nest/core/examples/conv2d.cpp')
-rw-r--r--compiler/nest/core/examples/conv2d.cpp152
1 files changed, 152 insertions, 0 deletions
diff --git a/compiler/nest/core/examples/conv2d.cpp b/compiler/nest/core/examples/conv2d.cpp
new file mode 100644
index 000000000..e405af9c0
--- /dev/null
+++ b/compiler/nest/core/examples/conv2d.cpp
@@ -0,0 +1,152 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <nest/Module.h>
+
+int main(int, char **)
+{
+ // This example shows how to specify convolution with IFM(1x3x3) and Kernel(1x1x3x3) with nest
+ // - STRIDE is 1, and there is no padding
+ //
+ // The below code corresponds to the following nest DSL code:
+ // ----------------------------------------------------------------------------------------------
+ // Domain ofm(1, 1, 1)
+ // Domain ifm(1, 3, 3)
+ // Domain ker(1, 1, 3, 3)
+ //
+ // Var ofm_ch : { min = 0, max = 1 }
+ // Var ofm_row : { min = 0, max = 1 }
+ // Var ofm_col : { min = 0, max = 1 }
+ // Var ker_ch : { min = 0, max = 1 }
+ // Var ker_row : { min = 0, max = 3 }
+ // Var ker_col : { min = 0, max = 3 }
+ //
+ // PUSH ifm(ker_ch, ker_row, ker_col) * ker(ofm_ch, ker_ch, ofm_row + ker_row, ofm_col + ker_col)
+ // RET ofm(ofm_ch, ofm_row, ofm_col)
+ // ----------------------------------------------------------------------------------------------
+ //
+ // The first part declares Domain(s) which corresponds to a multi-dimensional array in C-style
+ // (without type). For example, 'Domain ofm(1, 3, 3)' corresponds to the
+ // following C array declaration.
+ // float ofm[1][3][3];
+ // (Here we assume that domain type is 'float')
+ //
+ // The second part declares Var(s) which serves as a loop iteration variable. Basically, each
+ // variable emits one for loop and these loops are nested. As there are 6 variables in the above
+ // example, there will be 6 nested-loops.
+ //
+ // Each variable has a corresponding bound, and the bound of each variable states the starting /
+ // termination condition. For example, 'Var ofm_ch : { min = 0, max = 1 }' will introduce the
+ // following for loop:
+ // ----------------------------------------------------------------------------------------------
+ // for (int ofm_ch = 0; ofm_ch < 1; ++ofm_ch) { ... }
+ // ----------------------------------------------------------------------------------------------
+ //
+ // The last part declares statement(s) which state the computation performed inside these nested
+ // loops. Nest is stack-based. There is a virtual stack inside nested loop, and the evaluation of
+ // each statement will update this stack.
+ //
+ // Each nest code has one return statement (RET). This return statement specifies where to write
+ // the computed result.
+ //
+ // PUSH 'expr' statement evaluates an arithmetic expression (specified by 'expr') and pushes the
+ // numeric result to the stack. When PUSH statement evaluates an arithmetic expression, variables
+ // that do not appear in RET statement are treated as reduction variables. For example,
+ // ker_ch, ker_row, and ker_col do not appear in RET statement. So, PUSH '...' statement in the
+ // above example corresponds to the following nested loops:
+ // ----------------------------------------------------------------------------------------------
+ // float value = 0.0f;
+ //
+ // for (int ker_ch = 0; ker_ch < 1; ++ker_ch) {
+ // for (int ker_row = 0; ker_row < 3; ++ker_row) {
+ // for (int ker_col = 0; ker_col < 3; ++ker_col) {
+ // float ifm_value = ifm[ker_ch][ofm_row + ker_row][ofm_col + ker_col];
+ // float ker_value = ker[ofm_ch][ker_ch][ker_row][ker_col];
+ // value += ifm_value * ker_value;
+ // }
+ // }
+ // }
+ // ----------------------------------------------------------------------------------------------
+ //
+ // In summary, the above nest example corresponds to the following 2D convolution:
+ // ----------------------------------------------------------------------------------------------
+ // float ofm[1][1][1];
+ // float ifm[1][3][3];
+ // float ker[1][1][3][3];
+ //
+ // for (int ofm_ch = 0; ofm_ch < 1; ++ofm_ch) {
+ // for (int ofm_row = 0; ofm_row < 1; ++ofm_row) {
+ // for (int ofm_col = 0; ofm_col < 1; ++ofm_col) {
+ // float value = 0.0f;
+ //
+ // for (int ker_ch = 0; ker_ch < 1; ++ker_ch) {
+ // for (int ker_row = 0; ker_row < 3; ++ker_row) {
+ // for (int ker_col = 0; ker_col < 3; ++ker_col) {
+ // float ifm_value = ifm[ker_ch][ofm_row + ker_row][ofm_col + ker_col];
+ // float ker_value = ker[ofm_ch][ker_ch][ker_row][ker_col];
+ // value += ifm_value * ker_value;
+ // }
+ // }
+ // }
+ //
+ // ofm[ofm_ch][ofm_col][ofm_row] = value;
+ // }
+ // }
+ // }
+ // ----------------------------------------------------------------------------------------------
+ //
+ nest::Module m;
+
+ //
+ // Domains
+ //
+ auto ofm = m.domain().make({1 /*C*/, 1 /*H*/, 1 /*W*/});
+ auto ifm = m.domain().make({1 /*C*/, 3 /*H*/, 3 /*W*/});
+ auto ker = m.domain().make({1 /*N*/, 1 /*C*/, 3 /*H*/, 3 /*W*/});
+
+ //
+ // Variables
+ //
+ auto ofm_ch = m.var().make();
+ auto ofm_row = m.var().make();
+ auto ofm_col = m.var().make();
+
+ auto ker_ch = m.var().make();
+ auto ker_row = m.var().make();
+ auto ker_col = m.var().make();
+
+ // Declare the bound of each variables
+ using nest::Bound;
+
+ m.var().bound(ofm_ch) = Bound{0, 1};
+ m.var().bound(ofm_row) = Bound{0, 1};
+ m.var().bound(ofm_col) = Bound{0, 1};
+
+ m.var().bound(ker_ch) = Bound{0, 1};
+ m.var().bound(ker_row) = Bound{0, 3};
+ m.var().bound(ker_col) = Bound{0, 3};
+
+ //
+ // Statement
+ //
+ auto ifm_value = ifm(ker_ch, ofm_row + ker_row, ofm_col + ker_col);
+ auto ker_value = ker(ofm_ch, ker_ch, ker_row, ker_col);
+
+ m.push(ifm_value * ker_value);
+ m.ret(ofm(ofm_ch, ofm_row, ofm_col));
+
+ return 0;
+}