diff options
Diffstat (limited to 'compute/ncnn/src/srcn/depthwise_conv.cc')
-rw-r--r-- | compute/ncnn/src/srcn/depthwise_conv.cc | 2684 |
1 files changed, 2684 insertions, 0 deletions
diff --git a/compute/ncnn/src/srcn/depthwise_conv.cc b/compute/ncnn/src/srcn/depthwise_conv.cc new file mode 100644 index 000000000..cd092d5ac --- /dev/null +++ b/compute/ncnn/src/srcn/depthwise_conv.cc @@ -0,0 +1,2684 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include <arm_neon.h> +#include <stdlib.h> +#include <string.h> + +#include "common.h" +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +static void depthwise_conv3x3S1_nopad(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + float *in_ptr3 = inbuf + 3 * w; + + float *out_ptr0 = outbuf + 0 * outw; + float *out_ptr1 = outbuf + 1 * outw; + + int i; + for (i = 0; i + 1 < outh; i += 2) + { + int nn = (outw >> 2) - 1; + int remain = outw & 0x03; + + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr3], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr3], %[in_ptr3], #16\n" + + "vmla.f32 q12, q0, %e[weight678][0]\n" + "vmla.f32 q13, q2, %e[weight678][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + + "bne 1b\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [in_ptr3] "+r"(in_ptr3), + + [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + float32x4_t input3 = vld1q_f32(in_ptr3); + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + float32x4_t out1 = vmulq_f32(input1, weight012); + out1 = vmlaq_f32(out1, input2, weight345); + out1 = vmlaq_f32(out1, input3, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + in_ptr3++; + out_ptr0++; + out_ptr1++; + } + + in_ptr0 += w + 2; + in_ptr1 += w + 2; + in_ptr2 += w + 2; + in_ptr3 += w + 2; + + out_ptr0 += outw; + out_ptr1 += outw; + } + + for (; i < outh; i++) + { + int nn = outw >> 2; + int remain = outw & 0x03; + + if (nn > 0) + { + __asm __volatile("1:\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q14, q0, %e[weight012][0]\n" + "vmla.f32 q14, q2, %e[weight012][1]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vmla.f32 q14, q0, %e[weight345][0]\n" + "vmla.f32 q14, q2, %e[weight345][1]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q14, q0, %e[weight678][0]\n" + "vmla.f32 q14, q2, %e[weight678][1]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + + "bne 1b\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + out_ptr0++; + } + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // !__aarch64__ +} + +static void depthwise_conv3x3S1_padding(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + float *in_ptr3 = inbuf + 3 * w; + + float *out_ptr0 = outbuf + 0 * outw; + float *out_ptr1 = outbuf + 1 * outw; + + int i; + for (i = 0; i + 1 < outh; i += 2) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + if (i == 0) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight345][0]\n" + "vmul.f32 q11, q0, %e[weight345][1]\n" + "vmul.f32 q12, q2, %e[weight012][0]\n" + "vmul.f32 q13, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vmla.f32 q10, q2, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q12, q2, %e[weight345][0]\n" + "vmla.f32 q13, q0, %e[weight345][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr2], %[in_ptr2], #12\n" + + "vmla.f32 q12, q2, %e[weight678][0]\n" + "vmla.f32 q13, q0, %e[weight678][1]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight345][0]\n" + "vmul.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q12, q0, %e[weight678][0]\n" + "vmla.f32 q13, q2, %e[weight678][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), + [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight345); + out0 = vmlaq_f32(out0, input1, weight678); + + float32x4_t out1 = vmulq_f32(input0, weight012); + out1 = vmlaq_f32(out1, input1, weight345); + out1 = vmlaq_f32(out1, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + out_ptr0++; + out_ptr1++; + } + + in_ptr0 += 1; + in_ptr1 += 1; + in_ptr2 += 1; + in_ptr3 += w; + } + else if (i == outh - 2) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q2, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmul.f32 q12, q2, %e[weight012][0]\n" + "vmul.f32 q13, q0, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr2], %[in_ptr2], #12\n" + + "vmla.f32 q10, q2, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q12, q2, %e[weight345][0]\n" + "vmla.f32 q13, q0, %e[weight345][1]\n" + + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), + [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + float32x4_t out1 = vmulq_f32(input1, weight012); + out1 = vmlaq_f32(out1, input2, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + out_ptr0++; + out_ptr1++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q2, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmul.f32 q12, q2, %e[weight012][0]\n" + "vmul.f32 q13, q0, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr2], %[in_ptr2], #12\n" + + "vmla.f32 q10, q2, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q12, q2, %e[weight345][0]\n" + "vmla.f32 q13, q0, %e[weight345][1]\n" + + "pld [%[in_ptr3], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr3], %[in_ptr3], #12\n" + + "vmla.f32 q15, q2, %e[weight678][0]\n" + "vmla.f32 q15, q0, %e[weight678][1]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr3], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr3], %[in_ptr3], #16\n" + + "vmla.f32 q15, q0, %e[weight678][0]\n" + "vmla.f32 q15, q2, %e[weight678][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [in_ptr3] "+r"(in_ptr3), + + [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + float32x4_t input3 = vld1q_f32(in_ptr3); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + input3 = vsetq_lane_f32(0.0f, input3, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + float32x4_t out1 = vmulq_f32(input1, weight012); + out1 = vmlaq_f32(out1, input2, weight345); + out1 = vmlaq_f32(out1, input3, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + in_ptr3++; + out_ptr0++; + out_ptr1++; + } + in_ptr0 += w + 1; + in_ptr1 += w + 1; + in_ptr2 += w + 1; + in_ptr3 += w + 1; + } + + out_ptr0 += outw; + out_ptr1 += outw; + } + + for (; i < outh; i++) + { + // TODO:if i == 0, pad_top comes here. + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vmla.f32 q10, q2, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0++; + in_ptr1++; + out_ptr0++; + out_ptr1++; + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_nopad(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + + const int tailstep = w - 2 * outw + w; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = outw >> 2; + int remain = outw & 0x03; + + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += tailstep; + in_ptr1 += tailstep; + in_ptr2 += tailstep; + } + } + +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_padding00(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // !__aarch64__ +} + +static void depthwise_conv3x3S2_padding01(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr2], %[in_ptr2], #28\n" + + "vmla.f32 q10, q3, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q14, q1, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } + +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_padding10(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + // TODO: i == 0 && i == outh -1 + if (i == 0) + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight345][0]\n" + "vmul.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight345); + out0 = vmlaq_f32(out0, input1, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + + in_ptr2 += w; + } + else if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } + +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_padding11(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + // TODO: i == 0 && i == outh - 1 + if (i == 0) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight345][0]\n" + "vmul.f32 q11, q0, %e[weight345][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q14, q1, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight345][0]\n" + "vmul.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight345); + out0 = vmlaq_f32(out0, input1, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + + in_ptr2 += w; + } + else if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr2], %[in_ptr2], #28\n" + + "vmla.f32 q10, q3, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q14, q1, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv_colmajor(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convParams_t &in_param) +{ +#if __aarch64__ + const int w = in_mat.w; + const int h = in_mat.h; + const int outw = out_mat.w; + const int outh = out_mat.h; + const int channels = out_mat.c; + const int stridew = in_param.stride_w; + const int strideh = in_param.stride_h; + const int padding = in_param.padding; + const int padw = in_param.pad_w; + const int padh = in_param.pad_h; + +#pragma omp parallel for + for (int oh = 0; oh < outh; oh++) + { + const float *input_data0 = in_mat.data + (oh * strideh - padh) * w * channels; + + memset(out_mat.data + oh * outw * channels, 0x00, outw * channels * sizeof(float)); + + for (int kh = 0; kh < in_param.kernel_h; kh++) + { + for (int kw = 0; kw < in_param.kernel_w; kw++) + { + const float *kernel_data = kernel.data + (kh * in_param.kernel_w + kw) * channels; + const float *input_data1 = input_data0 + (kh * w + kw) * channels; + + if (padding && ((oh * strideh + kh < padh) || (oh * strideh + kh >= padh + h))) + { + continue; + } + + int ow = 0; + for (; ow + 3 < outw; /*ow += 4*/) + { + if (((ow + 3) * stridew + kw < padw) || (ow * stridew + kw >= padw + w)) + { + ow += 4; + continue; + } + else if ((ow + 3) * stridew + kw >= padw + w) + { + break; + } + else if (ow * stridew + kw < padw) + { + int delta = (padw - kw) / stridew - ow; + delta += (padw - kw) % stridew ? 1 : 0; + ow += delta; + continue; + } + + int nn = channels >> 2; + int remain = channels & 0x03; + + const float *input_r0 = input_data1 + (ow * stridew - padw) * channels; + + const float *input_r1 = input_r0 + stridew * channels; + const float *input_r2 = input_r1 + stridew * channels; + const float *input_r3 = input_r2 + stridew * channels; + const float *weights_data = kernel_data; + float *output_r0 = out_mat.data + (oh * outw + ow) * channels; + float *output_r1 = output_r0 + channels; + float *output_r2 = output_r1 + channels; + float *output_r3 = output_r2 + channels; + + if (nn > 0) + { + int _n = (nn + 1) >> 1; + int oddn = nn & 1; + + asm volatile("subs %[_n], %[_n], #1\n" + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + "ld1 {v7.4s}, [%[input_r2]], #16\n" + "ld1 {v8.4s}, [%[input_r3]], #16\n" + "beq 1f\n" + + "0:\n" + "ld1 {v24.4s, v25.4s}, [%[output_r0]]\n" + "ld1 {v26.4s, v27.4s}, [%[output_r1]]\n" + "ld1 {v28.4s, v29.4s}, [%[output_r2]]\n" + "ld1 {v30.4s, v31.4s}, [%[output_r3]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + "ld1 {v12.4s}, [%[input_r2]], #16\n" + "ld1 {v13.4s}, [%[input_r3]], #16\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "fmla v28.4s, v4.4s, v7.4s\n" + "fmla v30.4s, v4.4s, v8.4s\n" + + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + "ld1 {v7.4s}, [%[input_r2]], #16\n" + "ld1 {v8.4s}, [%[input_r3]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "fmla v29.4s, v9.4s, v12.4s\n" + "fmla v31.4s, v9.4s, v13.4s\n" + + "st1 {v24.4s, v25.4s}, [%[output_r0]], #32\n" + "st1 {v26.4s, v27.4s}, [%[output_r1]], #32\n" + "st1 {v28.4s, v29.4s}, [%[output_r2]], #32\n" + "st1 {v30.4s, v31.4s}, [%[output_r3]], #32\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v24.4s}, [%[output_r0]]\n" + "ld1 {v26.4s}, [%[output_r1]]\n" + "ld1 {v28.4s}, [%[output_r2]]\n" + "ld1 {v30.4s}, [%[output_r3]]\n" + "cmp %[oddn], #1\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "fmla v28.4s, v4.4s, v7.4s\n" + "fmla v30.4s, v4.4s, v8.4s\n" + + "st1 {v24.4s}, [%[output_r0]], #16\n" + "st1 {v26.4s}, [%[output_r1]], #16\n" + "st1 {v28.4s}, [%[output_r2]], #16\n" + "st1 {v30.4s}, [%[output_r3]], #16\n" + + "beq 2f\n" + "ld1 {v25.4s}, [%[output_r0]]\n" + "ld1 {v27.4s}, [%[output_r1]]\n" + "ld1 {v29.4s}, [%[output_r2]]\n" + "ld1 {v31.4s}, [%[output_r3]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + "ld1 {v12.4s}, [%[input_r2]], #16\n" + "ld1 {v13.4s}, [%[input_r3]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "fmla v29.4s, v9.4s, v12.4s\n" + "fmla v31.4s, v9.4s, v13.4s\n" + + "st1 {v25.4s}, [%[output_r0]], #16\n" + "st1 {v27.4s}, [%[output_r1]], #16\n" + "st1 {v29.4s}, [%[output_r2]], #16\n" + "st1 {v31.4s}, [%[output_r3]], #16\n" + "2:\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [input_r2] "+r"(input_r2), + [input_r3] "+r"(input_r3), [output_r0] "+r"(output_r0), + [output_r1] "+r"(output_r1), [output_r2] "+r"(output_r2), + [output_r3] "+r"(output_r3), [_n] "+r"(_n) + : [oddn] "r"(oddn) + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + if (remain >= 2) + { + asm volatile( + "ld1 {v24.2s}, [%[output_r0]]\n" + "ld1 {v26.2s}, [%[output_r1]]\n" + "ld1 {v28.2s}, [%[output_r2]]\n" + "ld1 {v30.2s}, [%[output_r3]]\n" + "ld1 {v4.2s}, [%[weights_data]], #8\n" + "ld1 {v5.2s}, [%[input_r0]], #8\n" + + "ld1 {v6.2s}, [%[input_r1]], #8\n" + "ld1 {v7.2s}, [%[input_r2]], #8\n" + "ld1 {v8.2s}, [%[input_r3]], #8\n" + + "fmla v24.2s, v4.2s, v5.2s\n" + "fmla v26.2s, v4.2s, v6.2s\n" + + "fmla v28.2s, v4.2s, v7.2s\n" + "fmla v30.2s, v4.2s, v8.2s\n" + + "st1 {v24.2s}, [%[output_r0]], #8\n" + "st1 {v26.2s}, [%[output_r1]], #8\n" + "st1 {v28.2s}, [%[output_r2]], #8\n" + "st1 {v30.2s}, [%[output_r3]], #8\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [input_r2] "+r"(input_r2), [input_r3] "+r"(input_r3), + [output_r0] "+r"(output_r0), [output_r1] "+r"(output_r1), + [output_r2] "+r"(output_r2), [output_r3] "+r"(output_r3) + : + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v24", "v26", "v28", "v30"); + remain -= 2; + } + + if (remain > 0) + { + *output_r0++ += (*weights_data) * (*input_r0++); + *output_r1++ += (*weights_data++) * (*input_r1++); + *output_r2++ += (*weights_data) * (*input_r2++); + *output_r3++ += (*weights_data++) * (*input_r3++); + } + ow += 4; + } + + for (; ow + 1 < outw; /*ow += 2*/) + { + if (padding) + { + if (((ow + 1) * stridew + kw < padw) || (ow * stridew + kw >= padw + w)) + { + ow += 2; + continue; + } + else if ((ow + 1) * stridew + kw >= padw + w) + { + break; + } + else if (ow * stridew + kw < padw) + { + ow++; + continue; + } + } + + int nn = channels >> 2; + int remain = channels & 0x03; + + const float *input_r0 = input_data1 + (ow * stridew - padw) * channels; + + const float *input_r1 = input_r0 + stridew * channels; + const float *weights_data = kernel_data; + float *output_r0 = out_mat.data + (oh * outw + ow) * channels; + float *output_r1 = output_r0 + channels; + + if (nn > 0) + { + int _n = (nn + 1) >> 1; + int oddn = nn & 1; + + asm volatile("subs %[_n], %[_n], #1\n" + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + "beq 1f\n" + + "0:\n" + "ld1 {v24.4s, v25.4s}, [%[output_r0]]\n" + "ld1 {v26.4s, v27.4s}, [%[output_r1]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "st1 {v24.4s, v25.4s}, [%[output_r0]], #32\n" + "st1 {v26.4s, v27.4s}, [%[output_r1]], #32\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v24.4s}, [%[output_r0]]\n" + "ld1 {v26.4s}, [%[output_r1]]\n" + "cmp %[oddn], #1\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "st1 {v24.4s}, [%[output_r0]], #16\n" + "st1 {v26.4s}, [%[output_r1]], #16\n" + + "beq 2f\n" + "ld1 {v25.4s}, [%[output_r0]]\n" + "ld1 {v27.4s}, [%[output_r1]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "st1 {v25.4s}, [%[output_r0]], #16\n" + "st1 {v27.4s}, [%[output_r1]], #16\n" + "2:\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [output_r0] "+r"(output_r0), + [output_r1] "+r"(output_r1), [_n] "+r"(_n) + : [oddn] "r"(oddn) + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + if (remain >= 2) + { + asm volatile("ld1 {v24.2s}, [%[output_r0]]\n" + "ld1 {v26.2s}, [%[output_r1]]\n" + "ld1 {v4.2s}, [%[weights_data]], #8\n" + "ld1 {v5.2s}, [%[input_r0]], #8\n" + + "ld1 {v6.2s}, [%[input_r1]], #8\n" + + "fmla v24.2s, v4.2s, v5.2s\n" + "fmla v26.2s, v4.2s, v6.2s\n" + + "st1 {v24.2s}, [%[output_r0]], #8\n" + "st1 {v26.2s}, [%[output_r1]], #8\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [output_r0] "+r"(output_r0), + [output_r1] "+r"(output_r1) + : + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v24", "v26", "v28", + "v30"); + remain -= 2; + } + + if (remain > 0) + { + *output_r0++ += (*weights_data) * (*input_r0++); + *output_r1++ += (*weights_data++) * (*input_r1++); + } + ow += 2; + } + + for (; ow < outw; ow++) + { + const float *input_data = input_data1 + (ow * stridew - padw) * channels; + + if (padding && ((ow * stridew + kw < padw) || (ow * strideh + kw >= padw + w))) + { + continue; + } + + int nn = channels >> 2; + int remain = channels & 0x03; + + const float *weights_data = kernel_data; + float *output_data = out_mat.data + (oh * outw + ow) * channels; + + if (nn > 0) + { + int _n = (nn + 1) >> 1; + int oddn = nn & 1; + + asm volatile("subs %[_n], %[_n], #1\n" + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_data]], #16\n" + "beq 1f\n" + + "0:\n" + "ld1 {v30.4s, v31.4s}, [%[output_data]]\n" + "ld1 {v6.4s}, [%[weights_data]], #16\n" + "ld1 {v7.4s}, [%[input_data]], #16\n" + "fmla v30.4s, v4.4s, v5.4s\n" + + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_data]], #16\n" + "fmla v31.4s, v6.4s, v7.4s\n" + + "st1 {v30.4s, v31.4s}, [%[output_data]], #32\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v30.4s}, [%[output_data]]\n" + "cmp %[oddn], #1\n" + "fmla v30.4s, v4.4s, v5.4s\n" + "st1 {v30.4s}, [%[output_data]], #16\n" + "beq 2f\n" + "ld1 {v31.4s}, [%[output_data]]\n" + "ld1 {v6.4s}, [%[weights_data]], #16\n" + "ld1 {v7.4s}, [%[input_data]], #16\n" + "fmla v31.4s, v6.4s, v7.4s\n" + + "st1 {v31.4s}, [%[output_data]], #16\n" + "2:\n" + : [weights_data] "+r"(weights_data), [input_data] "+r"(input_data), + [output_data] "+r"(output_data), [_n] "+r"(_n) + : [oddn] "r"(oddn) + : "cc", "memory", "v4", "v5", "v30", "v31"); + } + if (remain >= 2) + { + asm volatile("ld1 {v30.2s}, [%[output_data]]\n" + "ld1 {v4.2s}, [%[weights_data]], #8\n" + "ld1 {v5.2s}, [%[input_data]], #8\n" + + "fmla v30.2s, v4.2s, v5.2s\n" + + "st1 {v30.2s}, [%[output_data]], #8\n" + : [weights_data] "+r"(weights_data), [input_data] "+r"(input_data), + [output_data] "+r"(output_data) + : + : "cc", "memory", "v4", "v5", "v30"); + remain -= 2; + } + + if (remain > 0) + { + *output_data++ += (*weights_data++) * (*input_data++); + } + } + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)in_param; +#endif // __aarch64__ +} + +void srcn_depthwise_conv(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convMat_t &bias, const convParams_t &in_param, int num_threads, + convType_t conv_type) +{ + omp_set_num_threads(num_threads); + + if (conv_type == col_major) + { + depthwise_conv_colmajor(in_mat, out_mat, weights_mat, in_param); + return; + } + + else if (conv_type == row_major) + { + if (in_param.kernel_w == 3 && in_param.kernel_h == 3 && in_param.dilation_w == 1 && + in_param.dilation_h == 1) + { + if (in_param.stride_w == 1 && in_param.stride_h == 1) + { + if (in_param.padding == 0) + depthwise_conv3x3S1_nopad(in_mat, out_mat, weights_mat, bias); + else + depthwise_conv3x3S1_padding(in_mat, out_mat, weights_mat, bias); + } + else if (in_param.stride_w == 2 && in_param.stride_h == 2) + { + if (in_param.padding == 0) + depthwise_conv3x3S2_nopad(in_mat, out_mat, weights_mat, bias); + else + { + if (in_param.pad_w == 0 && in_param.pad_h == 0) + depthwise_conv3x3S2_padding00(in_mat, out_mat, weights_mat, bias); + else if (in_param.pad_w == 0 && in_param.pad_h == 1) + depthwise_conv3x3S2_padding10(in_mat, out_mat, weights_mat, bias); + else if (in_param.pad_w == 1 && in_param.pad_h == 0) + depthwise_conv3x3S2_padding01(in_mat, out_mat, weights_mat, bias); + else if (in_param.pad_w == 1 && in_param.pad_h == 1) + depthwise_conv3x3S2_padding11(in_mat, out_mat, weights_mat, bias); + } + } + } + } +} + +} // namespace srcn +} // namespace nnfw |