diff options
Diffstat (limited to 'compute/ncnn/src/srcn/sgemm_pack.cc')
-rw-r--r-- | compute/ncnn/src/srcn/sgemm_pack.cc | 2316 |
1 files changed, 0 insertions, 2316 deletions
/*
 * Pack an (mb x kb) sub-block of a row-major, non-transposed LHS matrix into
 * the column-panel layout consumed by the SGEMM micro-kernel: the output is a
 * sequence of panels of mr rows; within a panel, for each k in [0, kb) the mr
 * row entries of column k are stored contiguously (i.e. each mr x kb tile is
 * transposed into kb groups of mr values).
 *
 * mr      - micro-kernel row blocking factor (panel height)
 * mb      - number of rows in this block (may not be a multiple of mr)
 * kb      - number of columns (depth) in this block
 * stride  - row stride (in floats) of the source matrix
 * lhs_ptr - top-left element of the source block
 * plhs_ptr- destination packed buffer, size >= ((mb + mr - 1) / mr) * mr * kb
 *
 * Fixes relative to the previous revision:
 *  - mr==24 / mr==16 scalar remainder loops wrote plhs_ptr[12] = lhs_temp[0]
 *    instead of lhs_temp[12 * stride] (copy-paste bug).
 *  - the aarch64 mr==6 asm path (marked "TODO: 4--->6") stored only 4 of the
 *    6 rows per k-step.
 *  - unsupported mr values fell through the switch and packed nothing.
 * This version is a portable reference implementation that is correct for any
 * mr; the NEON/inline-asm fast paths were removed and should be reintroduced
 * behind this reference (validated against it) if profiling demands.
 */
void _pack_rowmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride,
                                const float *lhs_ptr, float *plhs_ptr)
{
  const int nm = mb / mr; /* number of full mr-row panels */
  const int rm = mb % mr; /* rows left over for the final, zero-padded panel */

  /* Full panels: walk the kb columns, gathering the mr strided row entries of
   * each column into one contiguous group. */
  for (int i = 0; i < nm; i++)
  {
    const float *lhs_temp = lhs_ptr;
    for (int j = 0; j < kb; j++)
    {
      for (int r = 0; r < mr; r++)
      {
        plhs_ptr[r] = lhs_temp[r * stride];
      }
      plhs_ptr += mr;
      lhs_temp++; /* next column of the same mr rows */
    }
    lhs_ptr += mr * stride; /* next panel of mr rows */
  }

  /* Partial final panel: copy the rm real rows, zero-pad up to mr so the
   * micro-kernel can always consume full panels. */
  if (rm > 0)
  {
    for (int j = 0; j < kb; j++)
    {
      for (int i = 0; i < rm; i++)
      {
        plhs_ptr[i] = lhs_ptr[i * stride];
      }
      for (int i = rm; i < mr; i++)
      {
        plhs_ptr[i] = 0.f;
      }
      plhs_ptr += mr;
      lhs_ptr++;
    }
  }
}
_pack_rowmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr) -{ - const int nn = nb / nr; - const int rn = nb % nr; - - switch (nr) - { - case 24: - for (int j = 0; j < nn; j++) - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3, q4, q5; - for (int i = 0; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + 8); - q3 = vld1q_f32(rhs_temp + 12); - q4 = vld1q_f32(rhs_temp + 16); - q5 = vld1q_f32(rhs_temp + 20); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - vst1q_f32(prhs_ptr + 16, q4); - vst1q_f32(prhs_ptr + 20, q5); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 16: - for (int j = 0; j < nn; j++) - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3; - for (int i = 0; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + 8); - q3 = vld1q_f32(rhs_temp + 12); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 12: - for (int j = 0; j < nn; j++) - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2; - for (int i = 0; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + 8); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 8: - for (int j = 0; j < nn; j++) - - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3; - - int i = 0; - for (; i + 1 < kb; i += 2) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - q2 = vld1q_f32(rhs_temp + stride); - q3 = vld1q_f32(rhs_temp + stride + 4); - 
vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - - rhs_temp += stride << 1; - prhs_ptr += nr << 1; - } - - for (; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + 4); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 6: - for (int j = 0; j < nn; j++) - - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q2; - float32x2_t q1, q3; - - int i = 0; - for (; i + 1 < kb; i += 2) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1_f32(rhs_temp + 4); - - q2 = vld1q_f32(rhs_temp + stride); - q3 = vld1_f32(rhs_temp + stride + 4); - vst1q_f32(prhs_ptr, q0); - vst1_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 6, q2); - vst1_f32(prhs_ptr + 10, q3); - - rhs_temp += stride << 1; - prhs_ptr += nr << 1; - } - - for (; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1_f32(rhs_temp + 4); - - vst1q_f32(prhs_ptr, q0); - vst1_f32(prhs_ptr + 4, q1); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - case 4: - for (int j = 0; j < nn; j++) - - { - const float *rhs_temp = rhs_ptr; - float32x4_t q0, q1, q2, q3; - - int i = 0; - for (; i + 3 < kb; i += 4) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + stride); - q2 = vld1q_f32(rhs_temp + (stride << 1)); - q3 = vld1q_f32(rhs_temp + (stride * 3)); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - - rhs_temp += stride << 2; - prhs_ptr += nr << 2; - } - for (; i + 1 < kb; i += 2) - { - q0 = vld1q_f32(rhs_temp); - q1 = vld1q_f32(rhs_temp + stride); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - - rhs_temp += stride << 1; - prhs_ptr += nr << 1; - } - for (; i < kb; i++) - { - q0 = vld1q_f32(rhs_temp); - vst1q_f32(prhs_ptr, q0); - - rhs_temp += stride; - prhs_ptr += nr; - } - - rhs_ptr += nr; - } - break; - default: - break; - } - - 
if (rn > 0) - { - for (int i = 0; i < kb; i++) - { - for (int j = 0; j < rn; j++) - { - prhs_ptr[j] = rhs_ptr[j]; - } - for (int j = rn; j < nr; j++) - { - prhs_ptr[j] = 0.f; - } - prhs_ptr += nr; - rhs_ptr += stride; - } - } -} - -void _pack_rowmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, - const float *lhs_ptr, float *plhs_ptr) -{ - _pack_rowmajor_notrans_rhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); -} - -void _pack_rowmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, - const float *rhs_ptr, float *prhs_ptr) -{ - _pack_rowmajor_notrans_lhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); -} - -static inline void _pack_rowmajor_image_subn(const int nr, const int nb, const int stride, - const float *buffer, float *prhs_ptr) -{ - const int nn = nb / nr; - const int rn = nb % nr; - - switch (nr) - { - case 24: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1, q2, q3, q4, q5; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - q2 = vld1q_f32(buffer + 8); - q3 = vld1q_f32(buffer + 12); - q4 = vld1q_f32(buffer + 16); - q5 = vld1q_f32(buffer + 20); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - vst1q_f32(prhs_ptr + 16, q4); - vst1q_f32(prhs_ptr + 20, q5); - prhs_ptr += stride; - buffer += nr; - } - break; - case 16: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1, q2, q3; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - q2 = vld1q_f32(buffer + 8); - q3 = vld1q_f32(buffer + 12); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - vst1q_f32(prhs_ptr + 12, q3); - prhs_ptr += stride; - buffer += nr; - } - break; - case 12: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1, q2; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - q2 = vld1q_f32(buffer + 8); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - vst1q_f32(prhs_ptr + 8, q2); - prhs_ptr += stride; - buffer 
+= nr; - } - break; - case 8: - for (int j = 0; j < nn; j++) - { - float32x4_t q0, q1; - q0 = vld1q_f32(buffer); - q1 = vld1q_f32(buffer + 4); - vst1q_f32(prhs_ptr, q0); - vst1q_f32(prhs_ptr + 4, q1); - prhs_ptr += stride; - buffer += nr; - } - break; - case 6: - for (int j = 0; j < nn; j++) - { - float32x4_t q0; - float32x2_t q1; - q0 = vld1q_f32(buffer); - q1 = vld1_f32(buffer + 4); - vst1q_f32(prhs_ptr, q0); - vst1_f32(prhs_ptr + 4, q1); - prhs_ptr += stride; - buffer += nr; - } - break; - case 4: - for (int j = 0; j < nn; j++) - { - float32x4_t q0; - q0 = vld1q_f32(buffer); - vst1q_f32(prhs_ptr, q0); - prhs_ptr += stride; - buffer += nr; - } - break; - default: - break; - } - - if (rn > 0) - { - for (int j = 0; j < rn; j++) - { - prhs_ptr[j] = buffer[j]; - } - for (int j = rn; j < nr; j++) - { - prhs_ptr[j] = 0.f; - } - } -} - -void _pack_rowmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int outw = output->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - - const int in_row0 = n0 / outw * stride_h; - const int in_col0 = n0 % outw * stride_w; - int seg0 = outw - n0 % outw; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + outw - 1) / outw; - if (seg0) - rows++; - const int segn = (nb - seg0) % outw; - - float row_data[nb]; - - for (int i = k0; i < kb + k0; i++) - { - const int ic = i / (kernel_w * kernel_h); - const int in_row1 = ((i / kernel_w) % kernel_h) * params->dilation_h + in_row0; - const int in_col1 = i % kernel_w * params->dilation_w; - -#ifdef NCNN - const float *input_data = input->data + ic * alignSize(w * h, 16 / sizeof(float)); -#else // NCNN - const float *input_data = 
input->data + ic * w * h; -#endif // NCNN - float *buffer = row_data; - int in_row = in_row1 - pad_h; - - for (int out_rows = rows; out_rows; out_rows--) - { - int cols = (out_rows != 1 || segn == 0) ? outw : segn; - int in_col = in_col1 - pad_w; - if (out_rows == rows) - { - cols = seg0; - in_col += in_col0; - } - if ((unsigned int)in_row < (unsigned int)h) - { - for (int out_col = cols; out_col; out_col--) - { - if ((unsigned int)in_col < (unsigned int)w) - *(buffer++) = input_data[in_row * w + in_col]; - else - *(buffer++) = 0; - in_col += stride_w; - } - } - else - { - for (int out_col = cols; out_col; out_col--) - { - *(buffer++) = 0; - in_col += stride_w; - } - } - - in_row += stride_h; - } - - _pack_rowmajor_image_subn(nr, nb, nr * kb, row_data, prhs_ptr); - prhs_ptr += nr; - } -} - -void _pack_rowmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, - const int n0, convMat_t *input, convMat_t *output, - convParams_t *params, float *prhs_ptr) -{ - const int w = input->w; - const int h = input->h; - const int c = input->c; - -#ifdef NCNN - const int seg_size = alignSize(output->w * output->h, 16 / sizeof(float)); -#else // NCNN - const int seg_size = output->w * output->h; -#endif // NCNN - -#ifdef NCNN - float *data = input->data + (alignSize(w * h, 16 / sizeof(float)) * c) * (n0 / seg_size); -#else // NCNN - float *data = input->data + (w * h * c) * (n0 / seg_size); -#endif // NCNN - - int seg0 = seg_size - n0 % seg_size; - if (seg0 > nb) - seg0 = nb; - int nseg = (nb - seg0 + seg_size - 1) / seg_size; - if (seg0) - nseg++; - const int segn = (nb - seg0) % seg_size; - convMat_t _input = {w, h, c, 1, data}; - - for (int i = 0; i < nseg; i++) - { - const int _nb = (i == 0 ? seg0 : (i == nseg - 1 ? segn : seg_size)); - const int _n0 = (i == 0 ? 
seg_size - seg0 : 0); - - _pack_rowmajor_image_rhs(nr, _nb, kb, k0, _n0, &_input, output, params, prhs_ptr); - -#ifdef NCNN - _input.data += alignSize(w * h, 16 / sizeof(float)) * c; -#else // NCNN - _input.data += w * h * c; -#endif // NCNN - } -} - -void _unpack_rowmajor_image_res(const int mb, const int nb, const int m0, const int n0, - convMat_t *input, convMat_t *output, convParams_t *params, - float *pres_ptr) -{ - const int outw = output->w; - const int outh = output->h; - const int w = input->w; - const int kernel_w = params->kernel_w; - const int kernel_h = params->kernel_h; - const int stride_w = params->stride_w; - const int stride_h = params->stride_h; - const int pad_w = params->pad_w; - const int pad_h = params->pad_h; - - const int out_row0 = n0 / w * stride_h; - const int out_col0 = n0 % w * stride_w; - int seg0 = w - n0 % w; - if (seg0 > nb) - seg0 = nb; - int rows = (nb - seg0 + w - 1) / w; - if (seg0) - rows++; - const int segn = (nb - seg0) % w; - - for (int i = m0; i < mb + m0; i++) - { - const int oc = i / (kernel_w * kernel_h); - const int out_row1 = ((i / kernel_w) % kernel_h) * params->dilation_h + out_row0; - const int out_col1 = i % kernel_w * params->dilation_w; - -#ifdef NCNN - float *output_data = output->data + oc * alignSize(outw * outh, 16 / sizeof(float)); -#else // NCNN - float *output_data = output->data + oc * outw * outh; -#endif // NCNN - int out_row = out_row1 - pad_h; - - for (int in_rows = rows; in_rows; in_rows--) - { - int cols = (in_rows != 1 || segn == 0) ? w : segn; - int out_col = out_col1 - pad_w; - if (in_rows == rows) - { - cols = seg0; - out_col += out_col0; - } - if ((unsigned int)out_row < (unsigned int)outh) - { - for (int in_col = cols; in_col; in_col--) - { - if ((unsigned int)out_col < (unsigned int)outw) - output_data[out_row * outw + out_col] += *pres_ptr++; - else - pres_ptr++; - out_col += stride_w; - } - } - else - { - pres_ptr += cols; - } - out_row += stride_h; - } - } -} - -// TODO:v8 & other case. 
// Transpose-pack an nr x k column-major tile ('buffer', k contiguous floats
// per lane, lanes k apart) into row-interleaved packed form (nr consecutive
// floats per k).  Handles k in chunks of 4 via inline asm (aarch64 ld1/zip,
// armv7 vld1/vzip) and the k & 3 remainder with scalar gathers.
static inline void _pack_colmajor_image_rhs_sub(const int nr, const int k, const float *buffer,
                                                float *prhs_ptr)
{
  int nk = k >> 2;   // number of 4-deep chunks handled by the asm loop
  int rk = k & 0x03; // scalar remainder

  const int _stride = k << 2; // byte stride between lanes (k floats)

  switch (nr)
  {
    case 12:
      if (nk > 0)
      {
#if __aarch64__
        // Load 12 lanes x 4 floats, 4x4-transpose each group of 4 lanes with
        // zip1/zip2, then store interleaved so each output row holds 12 lanes.
        asm volatile("0:\n"
                     "mov x0, %[buffer]\n"

                     "ld1 {v4.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v5.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v6.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v7.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"

                     "zip1 v28.4s, v4.4s, v6.4s\n"
                     "zip2 v30.4s, v4.4s, v6.4s\n"
                     "zip1 v29.4s, v5.4s, v7.4s\n"
                     "zip2 v31.4s, v5.4s, v7.4s\n"
                     "zip1 v4.4s, v28.4s, v29.4s\n"
                     "zip2 v5.4s, v28.4s, v29.4s\n"
                     "zip1 v6.4s, v30.4s, v31.4s\n"
                     "zip2 v7.4s, v30.4s, v31.4s\n"

                     "ld1 {v8.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v9.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v10.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v11.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"

                     "zip1 v28.4s, v8.4s, v10.4s\n"
                     "zip2 v30.4s, v8.4s, v10.4s\n"
                     "zip1 v29.4s, v9.4s, v11.4s\n"
                     "zip2 v31.4s, v9.4s, v11.4s\n"
                     "zip1 v8.4s, v28.4s, v29.4s\n"
                     "zip2 v9.4s, v28.4s, v29.4s\n"
                     "zip1 v10.4s, v30.4s, v31.4s\n"
                     "zip2 v11.4s, v30.4s, v31.4s\n"

                     "ld1 {v12.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v13.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v14.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v15.4s}, [x0]\n"

                     "zip1 v28.4s, v12.4s, v14.4s\n"
                     "zip2 v30.4s, v12.4s, v14.4s\n"
                     "zip1 v29.4s, v13.4s, v15.4s\n"
                     "zip2 v31.4s, v13.4s, v15.4s\n"
                     "zip1 v12.4s, v28.4s, v29.4s\n"
                     "zip2 v13.4s, v28.4s, v29.4s\n"
                     "zip1 v14.4s, v30.4s, v31.4s\n"
                     "zip2 v15.4s, v30.4s, v31.4s\n"

                     "st1 {v4.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v8.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v12.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v5.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v9.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v13.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v6.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v10.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v14.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v7.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v11.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v15.4s}, [%[prhs_ptr]], #16\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
                       "v12", "v13", "v14", "v15", "v28", "v29", "v30", "v31");
#else  // __aarch64__
        asm volatile("0:\n"
                     "mov r0, %[buffer]\n"

                     "vld1.f32 {d8-d9}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d10-d11}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d12-d13}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d14-d15}, [r0]\n"
                     "add r0, r0, %[_stride]\n"

                     "vzip.32 q4, q6\n"
                     "vzip.32 q5, q7\n"
                     "vzip.32 q4, q5\n"
                     "vzip.32 q6, q7\n"

                     "vld1.f32 {d16-d17}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d18-d19}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d20-d21}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d22-d23}, [r0]\n"
                     "add r0, r0, %[_stride]\n"

                     "vzip.32 q8, q10\n"
                     "vzip.32 q9, q11\n"
                     "vzip.32 q8, q9\n"
                     "vzip.32 q10, q11\n"

                     "vld1.f32 {d24-d25}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d26-d27}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d28-d29}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d30-d31}, [r0]\n"

                     "vzip.32 q12, q14\n"
                     "vzip.32 q13, q15\n"
                     "vzip.32 q12, q13\n"
                     "vzip.32 q14, q15\n"

                     "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d16-d17}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d24-d25}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d18-d19}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d26-d27}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d20-d21}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d28-d29}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d22-d23}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d30-d31}, [%[prhs_ptr]]!\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11",
                       "q12", "q13", "q14", "q15");
#endif // __aarch64__
      }

      // Scalar remainder: gather one float from each of the 12 lanes per k.
      // (buffer was advanced by the asm loop; lanes remain k floats apart.)
      for (int j = 0; j < rk; j++)
      {
        prhs_ptr[0] = buffer[0];
        prhs_ptr[1] = buffer[k];
        prhs_ptr[2] = buffer[k << 1];
        prhs_ptr[3] = buffer[3 * k];
        prhs_ptr[4] = buffer[k << 2];
        prhs_ptr[5] = buffer[5 * k];
        prhs_ptr[6] = buffer[6 * k];
        prhs_ptr[7] = buffer[7 * k];
        prhs_ptr[8] = buffer[k << 3];
        prhs_ptr[9] = buffer[9 * k];
        prhs_ptr[10] = buffer[10 * k];
        prhs_ptr[11] = buffer[11 * k];
        prhs_ptr += nr;
        buffer++;
      }
      break;

    case 8:
      if (nk > 0)
      {
#if __aarch64__
        asm volatile("0:\n"
                     "mov x0, %[buffer]\n"

                     "ld1 {v4.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v5.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v6.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v7.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"

                     "zip1 v28.4s, v4.4s, v6.4s\n"
                     "zip2 v30.4s, v4.4s, v6.4s\n"
                     "zip1 v29.4s, v5.4s, v7.4s\n"
                     "zip2 v31.4s, v5.4s, v7.4s\n"
                     "zip1 v4.4s, v28.4s, v29.4s\n"
                     "zip2 v5.4s, v28.4s, v29.4s\n"
                     "zip1 v6.4s, v30.4s, v31.4s\n"
                     "zip2 v7.4s, v30.4s, v31.4s\n"

                     "ld1 {v8.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v9.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v10.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v11.4s}, [x0]\n"

                     "zip1 v28.4s, v8.4s, v10.4s\n"
                     "zip2 v30.4s, v8.4s, v10.4s\n"
                     "zip1 v29.4s, v9.4s, v11.4s\n"
                     "zip2 v31.4s, v9.4s, v11.4s\n"
                     "zip1 v8.4s, v28.4s, v29.4s\n"
                     "zip2 v9.4s, v28.4s, v29.4s\n"
                     "zip1 v10.4s, v30.4s, v31.4s\n"
                     "zip2 v11.4s, v30.4s, v31.4s\n"

                     "st1 {v4.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v8.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v5.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v9.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v6.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v10.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v7.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v11.4s}, [%[prhs_ptr]], #16\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11",
                       "v28", "v29", "v30", "v31");
#else  // __aarch64__
        asm volatile("0:\n"
                     "mov r0, %[buffer]\n"

                     "vld1.f32 {d8-d9}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d10-d11}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d12-d13}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d14-d15}, [r0]\n"
                     "add r0, r0, %[_stride]\n"

                     "vzip.32 q4, q6\n"
                     "vzip.32 q5, q7\n"
                     "vzip.32 q4, q5\n"
                     "vzip.32 q6, q7\n"

                     "vld1.f32 {d16-d17}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d18-d19}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d20-d21}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d22-d23}, [r0]\n"

                     "vzip.32 q8, q10\n"
                     "vzip.32 q9, q11\n"
                     "vzip.32 q8, q9\n"
                     "vzip.32 q10, q11\n"

                     "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d16-d17}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d18-d19}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d20-d21}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d22-d23}, [%[prhs_ptr]]!\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11");
#endif // __aarch64__
      }

      for (int j = 0; j < rk; j++)
      {
        prhs_ptr[0] = buffer[0];
        prhs_ptr[1] = buffer[k];
        prhs_ptr[2] = buffer[k << 1];
        prhs_ptr[3] = buffer[3 * k];
        prhs_ptr[4] = buffer[k << 2];
        prhs_ptr[5] = buffer[5 * k];
        prhs_ptr[6] = buffer[6 * k];
        prhs_ptr[7] = buffer[7 * k];
        prhs_ptr += nr;
        buffer++;
      }
      break;
#if !__aarch64__
    // nr == 6 is only used by the 32-bit ARM kernels.
    case 6:
      if (nk > 0)
      {
        asm volatile("0:\n"
                     "mov r0, %[buffer]\n"

                     "vld1.f32 {d8-d9}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d10-d11}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d12-d13}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d14-d15}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d16-d17}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d18-d19}, [r0]\n"

                     "vzip.32 q4, q6\n"
                     "vzip.32 q5, q7\n"
                     "vzip.32 q4, q5\n"
                     "vzip.32 q6, q7\n"
                     "vzip.32 q8, q9\n"

                     "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d16}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d17}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d18}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d19}, [%[prhs_ptr]]!\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9");
      }

      for (int j = 0; j < rk; j++)
      {
        prhs_ptr[0] = buffer[0];
        prhs_ptr[1] = buffer[k];
        prhs_ptr[2] = buffer[k << 1];
        prhs_ptr[3] = buffer[3 * k];
        prhs_ptr[4] = buffer[k << 2];
        prhs_ptr[5] = buffer[5 * k];
        prhs_ptr += nr;
        buffer++;
      }
      break;
#endif // !__aarch64__
    case 4:
      if (nk > 0)
      {
#if __aarch64__
        asm volatile("0:\n"
                     "mov x0, %[buffer]\n"

                     "ld1 {v4.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v5.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v6.4s}, [x0]\n"
                     "add x0, x0, %[_stride]\n"
                     "ld1 {v7.4s}, [x0]\n"

                     "zip1 v28.4s, v4.4s, v6.4s\n"
                     "zip2 v30.4s, v4.4s, v6.4s\n"
                     "zip1 v29.4s, v5.4s, v7.4s\n"
                     "zip2 v31.4s, v5.4s, v7.4s\n"
                     "zip1 v4.4s, v28.4s, v29.4s\n"
                     "zip2 v5.4s, v28.4s, v29.4s\n"
                     "zip1 v6.4s, v30.4s, v31.4s\n"
                     "zip2 v7.4s, v30.4s, v31.4s\n"

                     "st1 {v4.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v5.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v6.4s}, [%[prhs_ptr]], #16\n"
                     "st1 {v7.4s}, [%[prhs_ptr]], #16\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31");
#else  // __aarch64__
        asm volatile("0:\n"
                     "mov r0, %[buffer]\n"

                     "vld1.f32 {d8-d9}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d10-d11}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d12-d13}, [r0]\n"
                     "add r0, r0, %[_stride]\n"
                     "vld1.f32 {d14-d15}, [r0]\n"

                     "vzip.32 q4, q6\n"
                     "vzip.32 q5, q7\n"
                     "vzip.32 q4, q5\n"
                     "vzip.32 q6, q7\n"

                     "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n"
                     "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n"

                     "subs %[nk], %[nk], #1\n"
                     "add %[buffer], %[buffer], #16\n"
                     "bne 0b\n"
                     : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk)
                     : [_stride] "r"(_stride)
                     : "cc", "memory", "r0", "q4", "q5", "q6", "q7");
#endif // __aarch64__
      }

      for (int j = 0; j < rk; j++)
      {
        prhs_ptr[0] = buffer[0];
        prhs_ptr[1] = buffer[k];
        prhs_ptr[2] = buffer[k << 1];
        prhs_ptr[3] = buffer[3 * k];
        prhs_ptr += nr;
        buffer++;
      }
      break;
    default:
      break;
  }
}

// Column-major packing delegates to the row-major routines with the roles of
// LHS and RHS swapped (packing a col-major matrix is packing its row-major
// transpose).
void _pack_colmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride,
                                const float *lhs_ptr, float *plhs_ptr)
{
  _pack_rowmajor_notrans_rhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr);
}

void _pack_colmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride,
                                const float *rhs_ptr, float *prhs_ptr)
{
  _pack_rowmajor_notrans_lhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr);
}

void _pack_colmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride,
                              const float *lhs_ptr, float *plhs_ptr)
{
  _pack_rowmajor_notrans_lhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr);
}

void
_pack_colmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride,
                         const float *rhs_ptr, float *prhs_ptr)
{
  _pack_rowmajor_notrans_rhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr);
}

// im2col-pack an HWC (channel-interleaved) image into the packed RHS layout.
// K is split into channel-sized groups (c0 leading partial, cn trailing
// partial); for each group, nr output pixels are gathered into tmp_data and
// transposed into the packed buffer by _pack_colmajor_image_rhs_sub.  The
// final partial group of pixels (nindex < nr) is zero-padded.
void _pack_colmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0,
                              convMat_t *input, convMat_t *output, convParams_t *params,
                              float *prhs_ptr)
{
  const int w = input->w;
  const int h = input->h;
  const int c = input->c;
  const int outw = output->w;
  const int kernel_w = params->kernel_w;
  const int kernel_h = params->kernel_h;
  const int stride_w = params->stride_w;
  const int stride_h = params->stride_h;
  const int pad_w = params->pad_w;
  const int pad_h = params->pad_h;
  const float *input_data = input->data;

  // Channel segmentation of the K range [k0, k0+kb): c0 channels remain in
  // the first (possibly partial) group, cn in the last, nc groups in total.
  int c0 = c - k0 % c;
  if (c0 > kb)
    c0 = kb;
  int nc = (kb - c0 + c - 1) / c;
  if (c0)
    nc++;
  const int cn = (kb - c0) % c;

  // Output-pixel segmentation of the N range [n0, n0+nb) by output rows.
  int seg0 = outw - n0 % outw;
  if (seg0 > nb)
    seg0 = nb;
  int rows = (nb - seg0 + outw - 1) / outw;
  if (seg0)
    rows++;
  const int segn = (nb - seg0) % outw;

  const int in_row0 = n0 / outw * stride_h;
  const int in_col0 = n0 % outw * stride_w;

  for (int i = 0; i < nc; i++)
  {
    const int channels = (i == 0 && c0 != 0) ? c0 : ((i == nc - 1 && cn != 0) ? cn : c);
    const int c1 = (i == 0) ? k0 % c : 0;

    float tmp_data[channels * nr]; // staging tile: nr pixels x channels (VLA)
    int nindex = 0;                // pixels currently staged in tmp_data
    float *buffer = tmp_data;
    float *prhs_tmp = prhs_ptr;

    // Kernel tap coordinates of this channel group (with dilation).
    const int in_row1 = (k0 / c + i) / kernel_w % kernel_h * params->dilation_h + in_row0;
    const int in_col1 = (k0 / c + i) % kernel_w * params->dilation_w;

    int in_row = in_row1 - pad_h;

    for (int out_rows = rows; out_rows; out_rows--)
    {
      int cols = (out_rows != 1 || segn == 0) ? outw : segn;
      int in_col = in_col1 - pad_w;
      if (out_rows == rows)
      {
        cols = seg0;
        in_col += in_col0;
      }
      // Unsigned compare folds (< 0) and (>= bound) into one test.
      if ((unsigned int)in_row < (unsigned int)h)
      {
        for (int out_col = cols; out_col; out_col--)
        {
          if ((unsigned int)in_col < (unsigned int)w)
          {
            for (int j = c1; j < c1 + channels; j++)
            {
              *(buffer++) = input_data[(in_row * w + in_col) * c + j];
            }
          }
          else
          {
            for (int j = 0; j < channels; j++)
            {
              *(buffer++) = 0;
            }
          }
          in_col += stride_w;

          // A full tile of nr pixels is staged: transpose it into the packed
          // buffer and advance one panel (kb * nr floats).
          nindex++;
          if (nindex == nr)
          {
            nindex = 0;
            buffer = tmp_data;
            _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp);
            prhs_tmp += kb * nr;
          }
        }
      }
      else
      {
        for (int out_col = cols; out_col; out_col--)
        {
          for (int j = 0; j < channels; j++)
          {
            *(buffer++) = 0;
          }
          in_col += stride_w;

          nindex++;
          if (nindex == nr)
          {
            nindex = 0;
            buffer = tmp_data;
            _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp);
            prhs_tmp += kb * nr;
          }
        }
      }

      in_row += stride_h;
    }

    // Trailing partial tile: scalar transpose with zero-padding up to nr.
    if (nindex > 0)
    {
      float *data = tmp_data;
      for (int i = 0; i < channels; i++)
      {
        for (int j = 0; j < nindex; j++)
        {
          prhs_tmp[j] = data[j * channels];
        }
        for (int j = nindex; j < nr; j++)
        {
          prhs_tmp[j] = 0.f;
        }
        prhs_tmp += nr;
        data++;
      }
    }

    prhs_ptr += channels * nr;
  }
}

// Batched variant of _pack_colmajor_image_rhs: the nb output pixels may span
// several images (seg_size pixels each).  Staging state (nindex, buffer,
// prhs_tmp) deliberately carries across image segments so tiles can straddle
// image boundaries.
// NOTE(review): unlike the non-batch version, in_row1/in_col1 here apply no
// dilation factor and no pad_h/pad_w offset — looks like the batch path
// assumes pad == 0 and dilation == 1; confirm against callers.
void _pack_colmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0,
                                    const int n0, convMat_t *input, convMat_t *output,
                                    convParams_t *params, float *prhs_ptr)
{
  const int w = input->w;
  const int h = input->h;
  const int c = input->c;
  const int outw = output->w;
  const int kernel_w = params->kernel_w;
  const int kernel_h = params->kernel_h;
  const int stride_w = params->stride_w;
  const int stride_h = params->stride_h;

  int c0 = c - k0 % c;
  if (c0 > kb)
    c0 = kb;
  int nc = (kb - c0 + c - 1) / c;
  if (c0)
    nc++;
  const int cn = (kb - c0) % c;

  const int seg_size = output->w * output->h; // output pixels per image

  const float *indata = input->data + (w * h * c) * (n0 / seg_size);

  // Per-image segmentation of the N range: bseg0 leading partial segment,
  // bsegn trailing partial segment, bnseg segments total.
  int bseg0 = seg_size - n0 % seg_size;
  if (bseg0 > nb)
    bseg0 = nb;
  int bnseg = (nb - bseg0 + seg_size - 1) / seg_size;
  if (bseg0)
    bnseg++;
  const int bsegn = (nb - bseg0) % seg_size;

  for (int ll = 0; ll < nc; ll++)
  {
    const float *input_data = indata;

    const int channels = (ll == 0 && c0 != 0) ? c0 : ((ll == nc - 1 && cn != 0) ? cn : c);
    const int c1 = (ll == 0) ? k0 % c : 0;

    int nindex = 0;
    float *prhs_tmp = prhs_ptr;
    float tmp_data[channels * nr];
    float *buffer = tmp_data;

    for (int i = 0; i < bnseg; i++)
    {
      const int _nb =
          ((i == 0 && bseg0 != 0) ? bseg0 : ((i == bnseg - 1 && bsegn != 0) ? bsegn : seg_size));
      const int _n0 = (i == 0 ? n0 % seg_size : 0);

      int seg0 = outw - _n0 % outw;
      if (seg0 > _nb)
        seg0 = _nb;
      int rows = (_nb - seg0 + outw - 1) / outw;
      if (seg0)
        rows++;
      const int segn = (_nb - seg0) % outw;

      const int in_row0 = _n0 / outw * stride_h;
      const int in_col0 = _n0 % outw * stride_w;

      const int in_row1 = (k0 / c + ll) / kernel_w % kernel_h + in_row0;
      const int in_col1 = (k0 / c + ll) % kernel_w;

      int in_row = in_row1;

      for (int out_rows = rows; out_rows; out_rows--)
      {
        int cols = (out_rows != 1 || segn == 0) ? outw : segn;
        int in_col = in_col1;
        if (out_rows == rows)
        {
          cols = seg0;
          in_col += in_col0;
        }
        if ((unsigned int)in_row < (unsigned int)h)
        {
          for (int out_col = cols; out_col; out_col--)
          {
            if ((unsigned int)in_col < (unsigned int)w)
            {
              for (int j = c1; j < c1 + channels; j++)
              {
                *(buffer++) = input_data[(in_row * w + in_col) * c + j];
              }
            }
            else
            {
              for (int j = 0; j < channels; j++)
              {
                *(buffer++) = 0;
              }
            }
            in_col += stride_w;

            nindex++;
            if (nindex == nr)
            {
              nindex = 0;
              buffer = tmp_data;
              _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp);
              prhs_tmp += kb * nr;
            }
          }
        }
        else
        {
          for (int out_col = cols; out_col; out_col--)
          {
            for (int j = 0; j < channels; j++)
            {
              *(buffer++) = 0;
            }
            in_col += stride_w;

            nindex++;
            if (nindex == nr)
            {
              nindex = 0;
              buffer = tmp_data;
              _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp);
              prhs_tmp += kb * nr;
            }
          }
        }

        in_row += stride_h;
      }

      input_data += w * h * c; // advance to the next image of the batch
    }

    // Trailing partial tile (may span the last image boundary).
    if (nindex > 0)
    {
      float *data = tmp_data;
      for (int ii = 0; ii < channels; ii++)
      {
        for (int jj = 0; jj < nindex; jj++)
        {
          prhs_tmp[jj] = data[jj * channels];
        }
        for (int jj = nindex; jj < nr; jj++)
        {
          prhs_tmp[jj] = 0.f;
        }
        prhs_tmp += nr;
        data++;
      }
    }

    prhs_ptr += channels * nr;
  }
}

// Accumulate (+=) a column-major result block back into the HWC output image,
// skipping padding positions.  M is segmented by output channels (c0 leading
// partial group, cn trailing), mirroring _pack_colmajor_image_rhs.
void _unpack_colmajor_image_res(const int mb, const int nb, const int m0, const int n0,
                                convMat_t *input, convMat_t *output, convParams_t *params,
                                float *pres_ptr)
{
  const int w = input->w;
  const int outw = output->w;
  const int outh = output->h;
  const int outc = output->c;
  const int kernel_w = params->kernel_w;
  const int kernel_h = params->kernel_h;
  const int stride_w = params->stride_w;
  const int stride_h = params->stride_h;
  const int pad_w = params->pad_w;
  const int pad_h = params->pad_h;
  float *output_data = output->data;

  int c0 = outc - m0 % outc;
  if (c0 > mb)
    c0 = mb;
  int nc = (mb - c0 + outc - 1) / outc;
  if (c0)
    nc++;
  const int cn = (mb - c0) % outc;

  int seg0 = w - n0 % w;
  if (seg0 > nb)
    seg0 = nb;
  int rows = (nb - seg0 + w - 1) / w;
  if (seg0)
    rows++;
  const int segn = (nb - seg0) % w;

  const int out_row0 = n0 / w * stride_h;
  const int out_col0 = n0 % w * stride_w;

  for (int i = 0; i < nc; i++)
  {
    const int channels = (i == 0 && c0 != 0) ? c0 : ((i == nc - 1 && cn != 0) ? cn : outc);
    const int c1 = (i == 0) ? m0 % outc : 0;

    float *buffer = pres_ptr;

    const int out_row1 = (m0 / outc + i) / kernel_w % kernel_h * params->dilation_h + out_row0;
    const int out_col1 = (m0 / outc + i) % kernel_w * params->dilation_w;

    int out_row = out_row1 - pad_h;

    for (int in_rows = rows; in_rows; in_rows--)
    {
      int cols = (in_rows != 1 || segn == 0) ? w : segn;
      int out_col = out_col1 - pad_w;
      if (in_rows == rows)
      {
        cols = seg0;
        out_col += out_col0;
      }
      if ((unsigned int)out_row < (unsigned int)outh)
      {
        for (int in_col = cols; in_col; in_col--)
        {
          if ((unsigned int)out_col < (unsigned int)outw)
          {
            for (int j = c1; j < c1 + channels; j++)
            {
              // Note:Data competition for multi-threads
              //#pragma omp atomic //low performance
              output_data[(out_row * outw + out_col) * outc + j] += *(buffer + j - c1);
            }
          }
          buffer += mb; // one column of the mb-tall result per output pixel
          out_col += stride_w;
        }
      }
      else
      {
        buffer += cols * mb; // whole row out of bounds: skip
      }
      out_row += stride_h;
    }

    pres_ptr += channels;
  }
}

// Pack a single k-slice (k0) of the im2col matrix for nb output pixels,
// without nr-panel interleaving — used by the sparse GEMM path.
void _sparse_pack_rowmajor_image(const int nb, const int k0, const int n0, convMat_t *input,
                                 convMat_t *output, convParams_t *params, float *prhs_ptr)
{
  const int w = input->w;
  const int h = input->h;
  const int outw = output->w;
  const int kernel_w = params->kernel_w;
  const int kernel_h = params->kernel_h;
  const int stride_w = params->stride_w;
  const int stride_h = params->stride_h;
  const int pad_w = params->pad_w;
  const int pad_h = params->pad_h;

  const int in_row0 = n0 / outw * stride_h;
  const int in_col0 = n0 % outw * stride_w;
  int seg0 = outw - n0 % outw;
  if (seg0 > nb)
    seg0 = nb;
  int rows = (nb - seg0 + outw - 1) / outw;
  if (seg0)
    rows++;
  const int segn = (nb - seg0) % outw;

  // Decompose k0 into (channel, kernel row, kernel col), with dilation.
  const int ic = k0 / (kernel_w * kernel_h);
  const int in_row1 = ((k0 / kernel_w) % kernel_h) * params->dilation_h + in_row0;
  const int in_col1 = k0 % kernel_w * params->dilation_w;

#ifdef NCNN
  const float *input_data = input->data + ic * alignSize(w * h, 16 / sizeof(float));
#else  // NCNN
  const float *input_data = input->data + ic * w * h;
#endif // NCNN

  int in_row = in_row1 - pad_h;

  for (int out_rows = rows; out_rows; out_rows--)
  {
    int cols = (out_rows != 1 || segn == 0) ? outw : segn;
    int in_col = in_col1 - pad_w;
    if (out_rows == rows)
    {
      cols = seg0;
      in_col += in_col0;
    }
    if ((unsigned int)in_row < (unsigned int)h)
    {
      for (int out_col = cols; out_col; out_col--)
      {
        if ((unsigned int)in_col < (unsigned int)w)
          *(prhs_ptr++) = input_data[in_row * w + in_col];
        else
          *(prhs_ptr++) = 0;
        in_col += stride_w;
      }
    }
    else
    {
      for (int out_col = cols; out_col; out_col--)
      {
        *(prhs_ptr++) = 0;
        in_col += stride_w;
      }
    }

    in_row += stride_h;
  }
}

} // namespace srcn
} // namespace nnfw