Diffstat (limited to 'compute/ncnn')
36 files changed, 24050 insertions, 0 deletions
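Note (orientation only, not part of the change): the headers added in this diff expose a small nnfw::ncnn::Mat container (include/ncnn/mat.h) and an element-wise entry point ncnn_binary_op() (include/ncnn/layer/binaryop.h) that currently supports only ADD without broadcasting. Before the per-file diff below, here is a minimal, hypothetical usage sketch; the shapes, fill values, and standalone main() are illustrative assumptions, and it presumes the code is built and linked (e.g. through the nnfw_lib_srcn target added in CMakeLists.txt).

#include "ncnn/layer/binaryop.h"
#include "ncnn/mat.h"

int main()
{
  using nnfw::ncnn::Mat;
  using nnfw::ncnn::BinaryOpParam;

  // Two 3-D blobs of identical shape; for this path the kernel does not
  // allocate the output blob, so it is pre-allocated here as well.
  Mat a(4, 4, 2), b(4, 4, 2), c(4, 4, 2);
  a.fill(1.0f);
  b.fill(2.0f);

  BinaryOpParam param; // op_type defaults to BinaryOp::Operation_ADD
  return nnfw::ncnn::ncnn_binary_op(param, a, b, c); // 0 on success
}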
diff --git a/compute/ncnn/CMakeLists.txt b/compute/ncnn/CMakeLists.txt new file mode 100644 index 000000000..a8f50120f --- /dev/null +++ b/compute/ncnn/CMakeLists.txt @@ -0,0 +1,34 @@ +if(NOT BUILD_SRCN_KERNEL) + message(STATUS "SRCN kernel library build: disabled") + return() +else(NOT BUILD_SRCN_KERNEL) + message(STATUS "SRCN kernel library build: OK") +endif() + +# Find and use pre-installed OpenMP +find_package(OpenMP QUIET) +if(NOT OpenMP_FOUND) + return() +endif(NOT OpenMP_FOUND) + +file(GLOB_RECURSE SOURCES src/*.cc) +file(GLOB_RECURSE TESTS src/*_test.cc) +list(REMOVE_ITEM SOURCES ${TESTS}) + +add_library(nnfw_lib_srcn STATIC ${SOURCES}) +target_include_directories(nnfw_lib_srcn PUBLIC include) +if(NOT TARGET OpenMP::OpenMP_CXX) + find_package(Threads REQUIRED) + add_library(OpenMP::OpenMP_CXX IMPORTED INTERFACE) + set_property(TARGET OpenMP::OpenMP_CXX + PROPERTY INTERFACE_COMPILE_OPTIONS ${OpenMP_CXX_FLAGS}) + # Only works if the same flag is passed to the linker; use CMake 3.9+ otherwise (Intel, AppleClang) + set_property(TARGET OpenMP::OpenMP_CXX + PROPERTY INTERFACE_LINK_LIBRARIES ${OpenMP_CXX_FLAGS} Threads::Threads) + +endif() +target_link_libraries(nnfw_lib_srcn PRIVATE OpenMP::OpenMP_CXX) +target_link_libraries(nnfw_lib_srcn PRIVATE nnfw_common) +target_compile_definitions(nnfw_lib_srcn PRIVATE TIZEN) # ANDROID or TIZEN +#target_compile_definitions(nnfw_lib_srcn PRIVATE NCNN) # Enable if ready +set_target_properties(nnfw_lib_srcn PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/compute/ncnn/README.md b/compute/ncnn/README.md new file mode 100644 index 000000000..5c39d249a --- /dev/null +++ b/compute/ncnn/README.md @@ -0,0 +1,9 @@ +### NCNN compute library + +This compute library is based on the NCNN project (https://github.com/Tencent/ncnn) with custom optimizations + +Current base commit: https://github.com/Tencent/ncnn/commit/0219f507b71bdb945d776c8586c162f2c22bba54 + +Added files for custom optimization are placed in +- Headers: include/ncnn/srcn +- Sources: src/srcn diff --git a/compute/ncnn/include/ncnn/layer/binaryop.h b/compute/ncnn/include/ncnn/layer/binaryop.h new file mode 100644 index 000000000..4ccfd94b4 --- /dev/null +++ b/compute/ncnn/include/ncnn/layer/binaryop.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef __NCNN_LAYER_BINARYOP_H__ +#define __NCNN_LAYER_BINARYOP_H__ + +#include "ncnn/mat.h" + +namespace nnfw +{ +namespace ncnn +{ + +enum class BinaryOp +{ + Operation_ADD = 0, + Operation_SUB = 1, + Operation_MUL = 2, + Operation_DIV = 3, + Operation_MAX = 4, + Operation_MIN = 5, + Operation_POW = 6, + Operation_SQUAREDDIFFERENCE = 7 +}; + +struct BinaryOpParam +{ + BinaryOp op_type; + float b; + + BinaryOpParam() : op_type{BinaryOp::Operation_ADD}, b{0.0f} {} +}; + +int ncnn_binary_op(const BinaryOpParam ¶m, const Mat &bottom_blob, const Mat &bottom_blob1, + Mat &top_blob); +// TODO Inplace function porting +// int ncnn_binary_op_inplace(const BinaryParam ¶m, Mat &bottom_top_blob) const; +// int ncnn_binary_op_inplace(const BinaryOpParam ¶m, std::vector<Mat> &bottom_top_blobs) const; + +} // namespace ncnn +} // naemsapce nnfw + +#endif // __NCNN_LAYER_BINARYOP_H__ diff --git a/compute/ncnn/include/ncnn/layer/instance_norm.h b/compute/ncnn/include/ncnn/layer/instance_norm.h new file mode 100644 index 000000000..b7d89281d --- /dev/null +++ b/compute/ncnn/include/ncnn/layer/instance_norm.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#ifndef __NNFW_LAYER_INSTANCE_NORM_H_ +#define __NNFW_LAYER_INSTANCE_NORM_H_ + +#include "ncnn/mat.h" +#ifdef __ARM_NEON +#include <arm_neon.h> +#endif // __ARM_NEON + +namespace nnfw +{ +namespace ncnn +{ + +void ncnn_instance_norm_rowmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int channels, float eps); + +void ncnn_instance_norm_colmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int channels, float eps); + +void ncnn_instance_norm_with_relu_rowmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int channels, float eps, float slope); + +void ncnn_instance_norm_with_relu_colmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int channels, float eps, float slope); + +} // namespace ncnn + +} // namespace nnfw + +#endif // __NNFW_LAYER_INSTANCE_NORM_H_ diff --git a/compute/ncnn/include/ncnn/mat.h b/compute/ncnn/include/ncnn/mat.h new file mode 100644 index 000000000..2a577939d --- /dev/null +++ b/compute/ncnn/include/ncnn/mat.h @@ -0,0 +1,738 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_NCNN_MAT_H__ +#define __NNFW_NCNN_MAT_H__ + +#include <stdlib.h> +#include <string.h> +#if __ARM_NEON +#include <arm_neon.h> +#endif + +namespace nnfw +{ +namespace ncnn +{ + +// the three dimension matrix +class Mat +{ +public: + // empty + Mat(); + // vec + Mat(int w, size_t elemsize = 4); + // image + Mat(int w, int h, size_t elemsize = 4); + // dim + Mat(int w, int h, int c, size_t elemsize = 4); + // copy + Mat(const Mat &m); + // external vec + Mat(int w, void *data, size_t elemsize = 4); + // external image + Mat(int w, int h, void *data, size_t elemsize = 4); + // external dim + Mat(int w, int h, int c, void *data, size_t elemsize = 4); + // release + ~Mat(); + // assign + Mat &operator=(const Mat &m); + // set all + void fill(float v); + template <typename T> void fill(T v); + // deep copy + Mat clone() const; + // reshape vec + Mat reshape(int w) const; + // reshape image + Mat reshape(int w, int h) const; + // reshape dim + Mat reshape(int w, int h, int c) const; + // allocate vec + void create(int w, size_t elemsize = 4); + // allocate image + void create(int w, int h, size_t elemsize = 4); +// allocate dim +#ifdef _MEMORY_TO_TIME_ + void create(int w, int h, int c, size_t elemsize = 4, bool isNew = false); +#else + void create(int w, int h, int c, size_t elemsize = 4); +#endif +#ifdef USE_OPENCL_INSIDE + void create_empity_mat(int _w, int _h, int _c, size_t _elemsize); +#endif + + // refcount++ + void addref(); + // refcount-- + void release(); + + bool empty() const; + size_t total() const; + + // data reference + Mat channel(int c); + const Mat channel(int c) const; + float *row(int y); + const float *row(int y) const; + template <typename T> T *row(int y); + template <typename T> const T *row(int y) const; + + // access raw data + 
template <typename T> operator T *(); + template <typename T> operator const T *() const; + + // convenient access float vec element + float &operator[](int i); + const float &operator[](int i) const; + + enum + { + PIXEL_CONVERT_SHIFT = 16, + PIXEL_FORMAT_MASK = 0x0000ffff, + PIXEL_CONVERT_MASK = 0xffff0000, + + PIXEL_RGB = 1, + PIXEL_BGR = (1 << 1), + PIXEL_GRAY = (1 << 2), + PIXEL_RGBA = (1 << 3), + + PIXEL_RGB2BGR = PIXEL_RGB | (PIXEL_BGR << PIXEL_CONVERT_SHIFT), + PIXEL_RGB2GRAY = PIXEL_RGB | (PIXEL_GRAY << PIXEL_CONVERT_SHIFT), + + PIXEL_BGR2RGB = PIXEL_BGR | (PIXEL_RGB << PIXEL_CONVERT_SHIFT), + PIXEL_BGR2GRAY = PIXEL_BGR | (PIXEL_GRAY << PIXEL_CONVERT_SHIFT), + + PIXEL_GRAY2RGB = PIXEL_GRAY | (PIXEL_RGB << PIXEL_CONVERT_SHIFT), + PIXEL_GRAY2BGR = PIXEL_GRAY | (PIXEL_BGR << PIXEL_CONVERT_SHIFT), + + PIXEL_RGBA2RGB = PIXEL_RGBA | (PIXEL_RGB << PIXEL_CONVERT_SHIFT), + PIXEL_RGBA2BGR = PIXEL_RGBA | (PIXEL_BGR << PIXEL_CONVERT_SHIFT), + PIXEL_RGBA2GRAY = PIXEL_RGBA | (PIXEL_GRAY << PIXEL_CONVERT_SHIFT), + }; + +#ifdef _MEMORY_TO_TIME_ + static void from_pixels(const unsigned char *pixels, Mat &m, int type, int w, int h); + static void from_pixels(const unsigned char *pixels, Mat &m, int type, int w, int h, int top, + int bottom, int left, int right); +#endif // _MEMORY_TO_TIME_ + + // convenient construct from pixel data + static Mat from_pixels(const unsigned char *pixels, int type, int w, int h); + // convenient construct from pixel data and add the padding && only supports same PIXEL_RGB2BGR + // and PIXEL_BGR2RGB now + static Mat from_pixels(const unsigned char *pixels, int type, int w, int h, int top, int bottom, + int left, int right); + // convenient construct from pixel data and resize to specific size + static Mat from_pixels_resize(const unsigned char *pixels, int type, int w, int h, + int target_width, int target_height); + + // convenient export to pixel data + void to_pixels(unsigned char *pixels, int type); + // convenient export to pixel data and cut the padding && only supports same PIXEL_RGB2BGR and + // PIXEL_BGR2RGB now + void to_pixels(unsigned char *pixels, int type, int top, int bottom, int left, int right); + // convenient export to pixel data and resize to specific size + void to_pixels_resize(unsigned char *pixels, int type, int target_width, int target_height); + + // substract channel-wise mean values, then multiply by normalize values, pass 0 to skip + void substract_mean_normalize(const float *mean_vals, const float *norm_vals); + + // convenient construct from half precisoin floating point data + static Mat from_float16(const unsigned short *data, int size); + + // pointer to the data + void *data; + + // pointer to the reference counter + // when points to user-allocated data, the pointer is NULL + int *refcount; + + // element size in bytes + // 4 = float32/int32 + // 2 = float16 + // 1 = int8/uint8 + // 0 = empty + size_t elemsize; + + // the dimensionality + int dims; + + int w; + int h; + int c; + + size_t cstep; +}; + +// misc function +// image pixel bilinear resize +void resize_bilinear_c1(const unsigned char *src, int srcw, int srch, unsigned char *dst, int w, + int h); +void resize_bilinear_c3(const unsigned char *src, int srcw, int srch, unsigned char *dst, int w, + int h); +void resize_bilinear_c4(const unsigned char *src, int srcw, int srch, unsigned char *dst, int w, + int h); + +// mat process +enum +{ + BORDER_CONSTANT = 0, + BORDER_REPLICATE = 1, +}; +void copy_make_border(const Mat &src, Mat &dst, int top, int bottom, int left, int right, 
int type, + float v); +void copy_cut_border(const Mat &src, Mat &dst, int top, int bottom, int left, int right); +void resize_bilinear(const Mat &src, Mat &dst, int w, int h); + +// the alignment of all the allocated buffers +#define MALLOC_ALIGN 16 + +// Aligns a pointer to the specified number of bytes +// ptr Aligned pointer +// n Alignment size that must be a power of two +template <typename _Tp> static inline _Tp *alignPtr(_Tp *ptr, int n = (int)sizeof(_Tp)) +{ + return (_Tp *)(((size_t)ptr + n - 1) & -n); +} + +// Aligns a buffer size to the specified number of bytes +// The function returns the minimum number that is greater or equal to sz and is divisible by n +// sz Buffer size to align +// n Alignment size that must be a power of two +static inline size_t alignSize(size_t sz, int n) { return (sz + n - 1) & -n; } + +static inline void *fastMalloc(size_t size) +{ + unsigned char *udata = (unsigned char *)malloc(size + sizeof(void *) + MALLOC_ALIGN); + if (!udata) + return 0; + unsigned char **adata = alignPtr((unsigned char **)udata + 1, MALLOC_ALIGN); + adata[-1] = udata; + return adata; +} + +static inline void fastFree(void *ptr) +{ + if (ptr) + { + unsigned char *udata = ((unsigned char **)ptr)[-1]; + free(udata); + } +} + +// exchange-add operation for atomic operations on reference counters +#if defined __INTEL_COMPILER && !(defined WIN32 || defined _WIN32) +// atomic increment on the linux version of the Intel(tm) compiler +#define NCNN_XADD(addr, delta) \ + (int)_InterlockedExchangeAdd(const_cast<void *>(reinterpret_cast<volatile void *>(addr)), delta) +#elif defined __GNUC__ +#if defined __clang__ && __clang_major__ >= 3 && !defined __ANDROID__ && \ + !defined __EMSCRIPTEN__ && !defined(__CUDACC__) +#ifdef __ATOMIC_ACQ_REL +#define NCNN_XADD(addr, delta) \ + __c11_atomic_fetch_add((_Atomic(int) *)(addr), delta, __ATOMIC_ACQ_REL) +#else +#define NCNN_XADD(addr, delta) __atomic_fetch_add((_Atomic(int) *)(addr), delta, 4) +#endif +#else +#if defined __ATOMIC_ACQ_REL && !defined __clang__ +// version for gcc >= 4.7 +#define NCNN_XADD(addr, delta) \ + (int)__atomic_fetch_add((unsigned *)(addr), (unsigned)(delta), __ATOMIC_ACQ_REL) +#else +#define NCNN_XADD(addr, delta) (int)__sync_fetch_and_add((unsigned *)(addr), (unsigned)(delta)) +#endif +#endif +#elif defined _MSC_VER && !defined RC_INVOKED +#include <intrin.h> +#define NCNN_XADD(addr, delta) (int)_InterlockedExchangeAdd((long volatile *)addr, delta) +#else +static inline void NCNN_XADD(int *addr, int delta) +{ + int tmp = *addr; + *addr += delta; + return tmp; +} +#endif + +inline Mat::Mat() : data(0), refcount(0), elemsize(0), dims(0), w(0), h(0), c(0), cstep(0) {} + +inline Mat::Mat(int _w, size_t _elemsize) : data(0), refcount(0), dims(0) { create(_w, _elemsize); } + +inline Mat::Mat(int _w, int _h, size_t _elemsize) : data(0), refcount(0), dims(0) +{ + create(_w, _h, _elemsize); +} + +inline Mat::Mat(int _w, int _h, int _c, size_t _elemsize) : data(0), refcount(0), dims(0) +{ + create(_w, _h, _c, _elemsize); +} + +inline Mat::Mat(const Mat &m) + : data(m.data), refcount(m.refcount), elemsize(m.elemsize), dims(m.dims) +{ + if (refcount) + NCNN_XADD(refcount, 1); + + w = m.w; + h = m.h; + c = m.c; + + cstep = m.cstep; +} + +inline Mat::Mat(int _w, void *_data, size_t _elemsize) + : data(_data), refcount(0), elemsize(_elemsize), dims(1) +{ + w = _w; + h = 1; + c = 1; + + cstep = w; +} + +inline Mat::Mat(int _w, int _h, void *_data, size_t _elemsize) + : data(_data), refcount(0), elemsize(_elemsize), dims(2) +{ + w = _w; + 
h = _h; + c = 1; + + cstep = w * h; +} + +inline Mat::Mat(int _w, int _h, int _c, void *_data, size_t _elemsize) + : data(_data), refcount(0), elemsize(_elemsize), dims(3) +{ + w = _w; + h = _h; + c = _c; + + cstep = alignSize(w * h * elemsize, 16) / elemsize; +} + +inline Mat::~Mat() { release(); } + +inline Mat &Mat::operator=(const Mat &m) +{ + if (this == &m) + return *this; + + if (m.refcount) + NCNN_XADD(m.refcount, 1); + + release(); + + data = m.data; + refcount = m.refcount; + elemsize = m.elemsize; + + dims = m.dims; + w = m.w; + h = m.h; + c = m.c; + + cstep = m.cstep; + + return *this; +} + +inline void Mat::fill(float _v) +{ + int size = total(); + float *ptr = (float *)data; + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); +#else + int remain = size; +#endif // __ARM_NEON + +#if __ARM_NEON + float32x4_t _c = vdupq_n_f32(_v); +#if __aarch64__ + if (nn > 0) + { + asm volatile("0: \n" + "subs %w0, %w0, #1 \n" + "st1 {%4.4s}, [%1], #16 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "w"(_c) // %4 + : "cc", "memory"); + } +#else + if (nn > 0) + { + asm volatile("0: \n" + "subs %0, #1 \n" + "vst1.f32 {%e4-%f4}, [%1 :128]!\n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "w"(_c) // %4 + : "cc", "memory"); + } +#endif // __aarch64__ +#endif // __ARM_NEON + for (; remain > 0; remain--) + { + *ptr++ = _v; + } +} + +template <typename T> inline void Mat::fill(T _v) +{ + int size = total(); + T *ptr = (T *)data; + for (int i = 0; i < size; i++) + { + ptr[i] = _v; + } +} + +inline Mat Mat::clone() const +{ + if (empty()) + return Mat(); + + Mat m; + if (dims == 1) + m.create(w, elemsize); + else if (dims == 2) + m.create(w, h, elemsize); + else if (dims == 3) + m.create(w, h, c, elemsize); + + if (total() > 0) + { + memcpy(m.data, data, total() * elemsize); + } + + return m; +} + +inline Mat Mat::reshape(int _w) const +{ + if (w * h * c != _w) + return Mat(); + + if (dims == 3 && cstep != (size_t)w * h) + { + Mat m; + m.create(_w, elemsize); + + // flatten + for (int i = 0; i < c; i++) + { + const void *ptr = (unsigned char *)data + i * cstep * elemsize; + void *mptr = (unsigned char *)m.data + i * w * h * elemsize; + memcpy(mptr, ptr, w * h * elemsize); + } + + return m; + } + + Mat m = *this; + + m.dims = 1; + m.w = _w; + m.h = 1; + m.c = 1; + + m.cstep = _w; + + return m; +} + +inline Mat Mat::reshape(int _w, int _h) const +{ + if (w * h * c != _w * _h) + return Mat(); + + if (dims == 3 && cstep != (size_t)w * h) + { + Mat m; + m.create(_w, _h, elemsize); + + // flatten + for (int i = 0; i < c; i++) + { + const void *ptr = (unsigned char *)data + i * cstep * elemsize; + void *mptr = (unsigned char *)m.data + i * w * h * elemsize; + memcpy(mptr, ptr, w * h * elemsize); + } + + return m; + } + + Mat m = *this; + + m.dims = 2; + m.w = _w; + m.h = _h; + m.c = 1; + + m.cstep = _w * _h; + + return m; +} + +inline Mat Mat::reshape(int _w, int _h, int _c) const +{ + if (w * h * c != _w * _h * _c) + return Mat(); + + if (dims < 3) + { + if ((size_t)_w * _h != alignSize(_w * _h * elemsize, 16) / elemsize) + { + Mat m; + m.create(_w, _h, _c, elemsize); + + // align channel + for (int i = 0; i < _c; i++) + { + const void *ptr = (unsigned char *)data + i * _w * _h * elemsize; + void *mptr = (unsigned char *)m.data + i * m.cstep * m.elemsize; + memcpy(mptr, ptr, _w * _h * elemsize); + } + + return m; + } + } + else if (c != _c) + { + // flatten and then align + Mat tmp = reshape(_w * _h * _c); + return 
tmp.reshape(_w, _h, _c); + } + + Mat m = *this; + + m.dims = 3; + m.w = _w; + m.h = _h; + m.c = _c; + + m.cstep = alignSize(_w * _h * elemsize, 16) / elemsize; + + return m; +} + +inline void Mat::create(int _w, size_t _elemsize) +{ + if (dims == 1 && w == _w && elemsize == _elemsize) + return; + + release(); + + elemsize = _elemsize; + + dims = 1; + w = _w; + h = 1; + c = 1; + + cstep = w; + + if (total() > 0) + { + size_t totalsize = total() * elemsize; + data = fastMalloc(totalsize + (int)sizeof(*refcount)); + refcount = (int *)(((unsigned char *)data) + totalsize); + *refcount = 1; + } +} + +inline void Mat::create(int _w, int _h, size_t _elemsize) +{ + if (dims == 2 && w == _w && h == _h && elemsize == _elemsize) + return; + + release(); + + elemsize = _elemsize; + + dims = 2; + w = _w; + h = _h; + c = 1; + + cstep = w * h; + + if (total() > 0) + { + size_t totalsize = total() * elemsize; + data = fastMalloc(totalsize + (int)sizeof(*refcount)); + refcount = (int *)(((unsigned char *)data) + totalsize); + *refcount = 1; + } +} + +#ifdef _MEMORY_TO_TIME_ +inline void Mat::create(int _w, int _h, int _c, size_t _elemsize, bool isNew) +{ + if (dims == 3 && w == _w && h == _h && c == _c && elemsize == _elemsize) + return; + + if (!isNew && dims == 3) + { + elemsize = _elemsize; + + w = _w; + h = _h; + c = _c; + + cstep = alignSize(w * h * elemsize, 16) / elemsize; + return; + } + + release(); + + elemsize = _elemsize; + + dims = 3; + w = _w; + h = _h; + c = _c; + + cstep = alignSize(w * h * elemsize, 16) / elemsize; + + if (total() > 0) + { + size_t totalsize = total() * elemsize; + data = fastMalloc(totalsize + (int)sizeof(*refcount)); + refcount = (int *)(((unsigned char *)data) + totalsize); + *refcount = 1; + } +} + +#else +inline void Mat::create(int _w, int _h, int _c, size_t _elemsize) +{ + if (dims == 3 && w == _w && h == _h && c == _c && elemsize == _elemsize) + return; + + release(); + + elemsize = _elemsize; + + dims = 3; + w = _w; + h = _h; + c = _c; + + cstep = alignSize(w * h * elemsize, 16) / elemsize; + + if (total() > 0) + { + size_t totalsize = total() * elemsize; + data = fastMalloc(totalsize + (int)sizeof(*refcount)); + refcount = (int *)(((unsigned char *)data) + totalsize); + *refcount = 1; + } +} +#endif //_MEMORY_TO_TIME_ + +#ifdef USE_OPENCL_INSIDE +inline void Mat::create_empity_mat(int _w, int _h, int _c, size_t _elemsize) +{ + if (dims == 3 && w == _w && h == _h && c == _c && elemsize == _elemsize) + return; + + release(); + + elemsize = _elemsize; + + dims = 3; + w = _w; + h = _h; + c = _c; + + cstep = alignSize(w * h * elemsize, 16) / elemsize; + data = NULL; +} +#endif // USE_OPENCL_INSIDE + +inline void Mat::addref() +{ + if (refcount) + NCNN_XADD(refcount, 1); +} + +inline void Mat::release() +{ + if (refcount && NCNN_XADD(refcount, -1) == 1) + fastFree(data); + + data = 0; + + elemsize = 0; + + dims = 0; + w = 0; + h = 0; + c = 0; + + cstep = 0; + + refcount = 0; +} + +inline bool Mat::empty() const { return data == 0 || total() == 0; } + +inline size_t Mat::total() const { return cstep * c; } + +inline Mat Mat::channel(int c) +{ + return Mat(w, h, (unsigned char *)data + cstep * c * elemsize, elemsize); +} + +inline const Mat Mat::channel(int c) const +{ + return Mat(w, h, (unsigned char *)data + cstep * c * elemsize, elemsize); +} + +inline float *Mat::row(int y) { return (float *)data + w * y; } + +inline const float *Mat::row(int y) const { return (const float *)data + w * y; } + +template <typename T> inline T *Mat::row(int y) { return (T *)data + w * 
y; } + +template <typename T> inline const T *Mat::row(int y) const { return (const T *)data + w * y; } + +template <typename T> inline Mat::operator T *() { return (T *)data; } + +template <typename T> inline Mat::operator const T *() const { return (const T *)data; } + +inline float &Mat::operator[](int i) { return ((float *)data)[i]; } + +inline const float &Mat::operator[](int i) const { return ((const float *)data)[i]; } + +} // namespace ncnn +} // namespace nnfw + +#endif // __NNFW_NCNN_MAT_H__ diff --git a/compute/ncnn/include/ncnn/srcn/conv_type.h b/compute/ncnn/include/ncnn/srcn/conv_type.h new file mode 100644 index 000000000..59152a094 --- /dev/null +++ b/compute/ncnn/include/ncnn/srcn/conv_type.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_CONV_TYPE_H__ +#define __NNFW_SRCN_CONV_TYPE_H__ + +namespace nnfw +{ +namespace srcn +{ + +enum convType_t +{ + row_major = 0, + col_major +}; + +struct convMat_t +{ + int w; + int h; + int c; + int n; + float *data; +}; + +struct convParams_t +{ + int kernel_w; + int kernel_h; + int stride_w; + int stride_h; + int dilation_w; + int dilation_h; + int padding; + int pad_w; + int pad_h; +}; + +struct winogradParams_t +{ + int kernel_w; + int kernel_h; + int stride_w; + int stride_h; + int dilation_w; + int dilation_h; + int batch; + int w; + int h; + int inch; + int outch; + int num_threads; + convType_t conv_type; + float *weight_data; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_TYPE_H__ diff --git a/compute/ncnn/include/ncnn/srcn/srcn_conv.h b/compute/ncnn/include/ncnn/srcn/srcn_conv.h new file mode 100644 index 000000000..11130c0db --- /dev/null +++ b/compute/ncnn/include/ncnn/srcn/srcn_conv.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __NNFW_SRCN_CONV_H__ +#define __NNFW_SRCN_CONV_H__ + +#include "conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +int check_winograd(winogradParams_t ¶ms); + +float *trans_weight2winograd(winogradParams_t ¶ms, unsigned int *size = NULL); + +void winograd_release(float *winograd_weight); + +void srcn_convolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, const float *winograd_weight, int num_threads, + convType_t conv_type); + +void srcn_deconvolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, int num_threads, convType_t conv_type); + +void *trans_weight2sparse(const convMat_t &weights_mat); + +void sparse_release(const int outch, void *ptr); + +void srcn_sparse_convolution2D(const convMat_t &in_mat, convMat_t &out_mat, + const convParams_t &in_param, const void *sparse_weight, + int number_threas, convType_t conv_type); + +void srcn_batch_convolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, + convMat_t &out_mat, const convParams_t &in_param, + const float *winograd_weight, int num_threads, convType_t conv_type); + +void srcn_convolution2D_gpu(const convMat_t &in_mat, const convMat_t &weights_mat, + convMat_t &out_mat, const convParams_t &in_param, convType_t conv_type); + +void srcn_convolution2D_dpu(const convMat_t &in_mat, const convMat_t &weights_mat, + convMat_t &out_mat, const convParams_t &in_param, convType_t conv_type); + +void srcn_depthwise_conv(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convMat_t &bias, const convParams_t &in_param, int num_threads, + convType_t conv_type); + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_H__ diff --git a/compute/ncnn/src/layer/arm/neon_mathfun.h b/compute/ncnn/src/layer/arm/neon_mathfun.h new file mode 100644 index 000000000..6e3cb66c8 --- /dev/null +++ b/compute/ncnn/src/layer/arm/neon_mathfun.h @@ -0,0 +1,315 @@ +/* NEON implementation of sin, cos, exp and log + * + * Inspired by Intel Approximate Math library, and based on the + * corresponding algorithms of the cephes math library + */ + +/* Copyright (C) 2011 Julien Pommier + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. 
+ * + * (this is the zlib license) + */ + +#include <arm_neon.h> + +#define c_inv_mant_mask ~0x7f800000u +#define c_cephes_SQRTHF 0.707106781186547524 +#define c_cephes_log_p0 7.0376836292E-2 +#define c_cephes_log_p1 -1.1514610310E-1 +#define c_cephes_log_p2 1.1676998740E-1 +#define c_cephes_log_p3 -1.2420140846E-1 +#define c_cephes_log_p4 +1.4249322787E-1 +#define c_cephes_log_p5 -1.6668057665E-1 +#define c_cephes_log_p6 +2.0000714765E-1 +#define c_cephes_log_p7 -2.4999993993E-1 +#define c_cephes_log_p8 +3.3333331174E-1 +#define c_cephes_log_q1 -2.12194440e-4 +#define c_cephes_log_q2 0.693359375 + +/* natural logarithm computed for 4 simultaneous float + * return NaN for x <= 0 + */ +static inline float32x4_t log_ps(float32x4_t x) +{ + float32x4_t one = vdupq_n_f32(1); + + x = vmaxq_f32(x, vdupq_n_f32(0)); /* force flush to zero on denormal values */ + uint32x4_t invalid_mask = vcleq_f32(x, vdupq_n_f32(0)); + + int32x4_t ux = vreinterpretq_s32_f32(x); + + int32x4_t emm0 = vshrq_n_s32(ux, 23); + + /* keep only the fractional part */ + ux = vandq_s32(ux, vdupq_n_s32(c_inv_mant_mask)); + ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f))); + x = vreinterpretq_f32_s32(ux); + + emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f)); + float32x4_t e = vcvtq_f32_s32(emm0); + + e = vaddq_f32(e, one); + + /* part2: + * if( x < SQRTHF ) { + * e -= 1; + * x = x + x - 1.0; + * } else { x = x - 1.0; } + */ + uint32x4_t mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF)); + float32x4_t tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); + x = vsubq_f32(x, one); + e = vsubq_f32(e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask))); + x = vaddq_f32(x, tmp); + + float32x4_t z = vmulq_f32(x, x); + + float32x4_t y = vdupq_n_f32(c_cephes_log_p0); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7)); + y = vmulq_f32(y, x); + y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8)); + y = vmulq_f32(y, x); + + y = vmulq_f32(y, z); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1)); + y = vaddq_f32(y, tmp); + + tmp = vmulq_f32(z, vdupq_n_f32(0.5f)); + y = vsubq_f32(y, tmp); + + tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2)); + x = vaddq_f32(x, y); + x = vaddq_f32(x, tmp); + x = vreinterpretq_f32_u32( + vorrq_u32(vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN + return x; +} + +#define c_exp_hi 88.3762626647949f +#define c_exp_lo -88.3762626647949f + +#define c_cephes_LOG2EF 1.44269504088896341 +#define c_cephes_exp_C1 0.693359375 +#define c_cephes_exp_C2 -2.12194440e-4 + +#define c_cephes_exp_p0 1.9875691500E-4 +#define c_cephes_exp_p1 1.3981999507E-3 +#define c_cephes_exp_p2 8.3334519073E-3 +#define c_cephes_exp_p3 4.1665795894E-2 +#define c_cephes_exp_p4 1.6666665459E-1 +#define c_cephes_exp_p5 5.0000001201E-1 + +/* exp() computed for 4 float at once */ +static inline float32x4_t exp_ps(float32x4_t x) +{ + float32x4_t tmp, fx; + + float32x4_t one = vdupq_n_f32(1); + x = vminq_f32(x, vdupq_n_f32(c_exp_hi)); + x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo)); + + /* express exp(x) as exp(g + n*log(2)) */ + fx = 
vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF)); + + /* perform a floorf */ + tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx)); + + /* if greater, substract 1 */ + uint32x4_t mask = vcgtq_f32(tmp, fx); + mask = vandq_u32(mask, vreinterpretq_u32_f32(one)); + + fx = vsubq_f32(tmp, vreinterpretq_f32_u32(mask)); + + tmp = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C1)); + float32x4_t z = vmulq_f32(fx, vdupq_n_f32(c_cephes_exp_C2)); + x = vsubq_f32(x, tmp); + x = vsubq_f32(x, z); + + static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1, c_cephes_exp_p2, + c_cephes_exp_p3, c_cephes_exp_p4, c_cephes_exp_p5}; + float32x4_t y = vld1q_dup_f32(cephes_exp_p + 0); + float32x4_t c1 = vld1q_dup_f32(cephes_exp_p + 1); + float32x4_t c2 = vld1q_dup_f32(cephes_exp_p + 2); + float32x4_t c3 = vld1q_dup_f32(cephes_exp_p + 3); + float32x4_t c4 = vld1q_dup_f32(cephes_exp_p + 4); + float32x4_t c5 = vld1q_dup_f32(cephes_exp_p + 5); + + y = vmulq_f32(y, x); + z = vmulq_f32(x, x); + + y = vaddq_f32(y, c1); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c2); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c3); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c4); + y = vmulq_f32(y, x); + y = vaddq_f32(y, c5); + + y = vmulq_f32(y, z); + y = vaddq_f32(y, x); + y = vaddq_f32(y, one); + + /* build 2^n */ + int32x4_t mm; + mm = vcvtq_s32_f32(fx); + mm = vaddq_s32(mm, vdupq_n_s32(0x7f)); + mm = vshlq_n_s32(mm, 23); + float32x4_t pow2n = vreinterpretq_f32_s32(mm); + + y = vmulq_f32(y, pow2n); + return y; +} + +#define c_minus_cephes_DP1 -0.78515625 +#define c_minus_cephes_DP2 -2.4187564849853515625e-4 +#define c_minus_cephes_DP3 -3.77489497744594108e-8 +#define c_sincof_p0 -1.9515295891E-4 +#define c_sincof_p1 8.3321608736E-3 +#define c_sincof_p2 -1.6666654611E-1 +#define c_coscof_p0 2.443315711809948E-005 +#define c_coscof_p1 -1.388731625493765E-003 +#define c_coscof_p2 4.166664568298827E-002 +#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI + +/* evaluation of 4 sines & cosines at once. + * + * The code is the exact rewriting of the cephes sinf function. + * Precision is excellent as long as x < 8192 (I did not bother to + * take into account the special handling they have for greater values + * -- it does not return garbage for arguments over 8192, though, but + * the extra precision is missing). + * + * Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the + * surprising but correct result. + * + * Note also that when you compute sin(x), cos(x) is available at + * almost no extra price so both sin_ps and cos_ps make use of + * sincos_ps.. + */ +static inline void sincos_ps(float32x4_t x, float32x4_t *ysin, float32x4_t *ycos) +{ + // any x + float32x4_t xmm1, xmm2, xmm3, y; + + uint32x4_t emm2; + + uint32x4_t sign_mask_sin, sign_mask_cos; + sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0)); + x = vabsq_f32(x); + + /* scale by 4/Pi */ + y = vmulq_f32(x, vdupq_n_f32(c_cephes_FOPI)); + + /* store the integer part of y in mm0 */ + emm2 = vcvtq_u32_f32(y); + /* j=(j+1) & (~1) (see the cephes sources) */ + emm2 = vaddq_u32(emm2, vdupq_n_u32(1)); + emm2 = vandq_u32(emm2, vdupq_n_u32(~1)); + y = vcvtq_f32_u32(emm2); + + /* get the polynom selection mask + * there is one polynom for 0 <= x <= Pi/4 + * and another one for Pi/4<x<=Pi/2 + * + * Both branches will be computed. 
+ */ + uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2)); + + /* The magic pass: "Extended precision modular arithmetic" + * x = ((x - y * DP1) - y * DP2) - y * DP3; */ + xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1); + xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2); + xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3); + x = vaddq_f32(x, xmm1); + x = vaddq_f32(x, xmm2); + x = vaddq_f32(x, xmm3); + + sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4))); + sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4)); + + /* Evaluate the first polynom (0 <= x <= Pi/4) in y1, + * and the second polynom (Pi/4 <= x <= 0) in y2 */ + float32x4_t z = vmulq_f32(x, x); + float32x4_t y1, y2; + + y1 = vmulq_n_f32(z, c_coscof_p0); + y2 = vmulq_n_f32(z, c_sincof_p0); + y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1)); + y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1)); + y1 = vmulq_f32(y1, z); + y2 = vmulq_f32(y2, z); + y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2)); + y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2)); + y1 = vmulq_f32(y1, z); + y2 = vmulq_f32(y2, z); + y1 = vmulq_f32(y1, z); + y2 = vmulq_f32(y2, x); + y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f))); + y2 = vaddq_f32(y2, x); + y1 = vaddq_f32(y1, vdupq_n_f32(1)); + + /* select the correct result from the two polynoms */ + float32x4_t ys = vbslq_f32(poly_mask, y1, y2); + float32x4_t yc = vbslq_f32(poly_mask, y2, y1); + *ysin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys); + *ycos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc)); +} + +static inline float32x4_t sin_ps(float32x4_t x) +{ + float32x4_t ysin, ycos; + sincos_ps(x, &ysin, &ycos); + return ysin; +} + +static inline float32x4_t cos_ps(float32x4_t x) +{ + float32x4_t ysin, ycos; + sincos_ps(x, &ysin, &ycos); + return ycos; +} + +static inline float32x4_t div_ps(float32x4_t a, float32x4_t b) +{ + float32x4_t reciprocal = vrecpeq_f32(b); + reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); + // reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); + return vmulq_f32(a, reciprocal); +} + +static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) +{ + // pow(x, m) = exp(m * log(x)) + return exp_ps(vmulq_f32(b, log_ps(a))); +} diff --git a/compute/ncnn/src/layer/binaryop.cc b/compute/ncnn/src/layer/binaryop.cc new file mode 100644 index 000000000..a09d55f78 --- /dev/null +++ b/compute/ncnn/src/layer/binaryop.cc @@ -0,0 +1,1640 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ncnn/layer/binaryop.h" +#include <math.h> +#include <algorithm> +#include <functional> +#include <sys/time.h> + +#if __ARM_NEON +#include <arm_neon.h> +#include "arm/neon_mathfun.h" +#endif // __ARM_NEON + +namespace nnfw +{ +namespace ncnn +{ + +template <typename Op> static int binary_op(const Mat &a, const Mat &b, Mat &c) +{ + Op op; + + int w = a.w; + int h = a.h; + int channels = a.c; + int size = w * h; + + int w1 = b.w; + int h1 = b.h; + int channels1 = b.c; + int size1 = w1 * h1; + + if (a.dims == 3) + { + c.create(w, h, channels); + if (c.empty()) + return -100; + + if (b.dims == 3) + { + if (b.w == 1 && b.h == 1) + { + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = a.channel(q); + const float *ptr1 = b.channel(q); + float *outptr = c.channel(q); + + float tt = *ptr1; + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], tt); + } + } + + return 0; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = a.channel(q); + const float *ptr1 = b.channel(q); + float *outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], ptr1[i]); + } + } + + return 0; + } + + if (b.dims == 2) + { +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = a.channel(q); + const float *ptr1 = (const float *)b + h * q; + float *outptr = c.channel(q); + + for (int y = 0; y < h; y++) + { + const float b0 = ptr1[y]; + for (int x = 0; x < w; x++) + { + outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + } + + return 0; + } + + if (b.dims == 1) + { + if (b.w == 1) + { + const float b0 = b[0]; +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = a.channel(q); + float *outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = a.channel(q); + const float b0 = b[q]; + float *outptr = c.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = op(ptr[i], b0); + } + } + + return 0; + } + } + else if (a.dims == 2) + { + if (b.dims == 3) + { + c.create(w1, h1, channels1); + if (c.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float *ptr = (const float *)a + h1 * q; + const float *ptr1 = b.channel(q); + float *outptr = c.channel(q); + + for (int y = 0; y < h1; y++) + { + const float a0 = ptr[y]; + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + } + + return 0; + } + + c.create(w, h); + if (c.empty()) + return -100; + + if (b.dims == 2) + { + for (int i = 0; i < size; i++) + { + c[i] = op(a[i], b[i]); + } + + return 0; + } + + if (b.dims == 1) + { + c.create(w, h); + if (c.empty()) + return -100; + + if (b.w == 1) + { + const float b0 = b[0]; + for (int i = 0; i < size; i++) + { + c[i] = op(a[i], b0); + } + + return 0; + } + + const float *ptr = a; + float *outptr = c; + + for (int y = 0; y < h; y++) + { + const float b0 = b[y]; + for (int 
x = 0; x < w; x++) + { + outptr[x] = op(ptr[x], b0); + } + + ptr += w; + outptr += w; + } + + return 0; + } + } + else if (a.dims == 1) + { + if (a.w == 1) + { + if (b.dims == 3) + { + c.create(w1, h1, channels1); + if (c.empty()) + return -100; + + const float a0 = a[0]; +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float *ptr1 = b.channel(q); + float *outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + + if (b.dims == 2) + { + c.create(w1, h1); + if (c.empty()) + return -100; + + const float a0 = a[0]; + for (int i = 0; i < size1; i++) + { + c[i] = op(a0, b[i]); + } + + return 0; + } + + if (b.dims == 1) + { + c.create(w1); + if (c.empty()) + return -100; + + const float a0 = a[0]; + for (int i = 0; i < size1; i++) + { + c[i] = op(a0, b[i]); + } + + return 0; + } + } + + if (b.dims == 3) + { + c.create(w1, h1, channels1); + if (c.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = a[q]; + const float *ptr1 = b.channel(q); + float *outptr = c.channel(q); + + for (int i = 0; i < size1; i++) + { + outptr[i] = op(a0, ptr1[i]); + } + } + + return 0; + } + + if (b.dims == 2) + { + c.create(w1, h1); + if (c.empty()) + return -100; + + const float *ptr1 = b; + float *outptr = c; + + for (int y = 0; y < h1; y++) + { + const float a0 = a[y]; + for (int x = 0; x < w1; x++) + { + outptr[x] = op(a0, ptr1[x]); + } + + ptr1 += w1; + outptr += w1; + } + + return 0; + } + + if (b.dims == 1) + { + c.create(w); + if (c.empty()) + return -100; + + if (b.w == 1) + { + const float b0 = b[0]; + for (int i = 0; i < size; i++) + { + c[i] = op(a[i], b0); + } + + return 0; + } + + for (int i = 0; i < size; i++) + { + c[i] = op(a[i], b[i]); + } + } + } + + return 0; +} + +template <typename Op> static int binary_op_scalar_inplace(Mat &a, float b) +{ + Op op; + + int w = a.w; + int h = a.h; + int channels = a.c; + int size = w * h; + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + float *ptr = a.channel(q); + + for (int i = 0; i < size; i++) + { + ptr[i] = op(ptr[i], b); + } + } + + return 0; +} + +template <typename T> struct binary_op_max : std::binary_function<T, T, T> +{ + T operator()(const T &x, const T &y) const { return std::max(x, y); } +}; + +template <typename T> struct binary_op_min : std::binary_function<T, T, T> +{ + T operator()(const T &x, const T &y) const { return std::min(x, y); } +}; + +template <typename T> struct binary_op_pow : std::binary_function<T, T, T> +{ + T operator()(const T &x, const T &y) const { return pow(x, y); } +}; + +template <typename T> struct binary_op_SquaredDifference : std::binary_function<T, T, T> +{ + T operator()(const T &x, const T &y) const { return pow((x - y), 2); } +}; + +int ncnn_binary_op(const BinaryOpParam ¶m, const Mat &bottom_blob, const Mat &bottom_blob1, + Mat &top_blob) +{ + int ret = 0; + auto op_type = param.op_type; + // auto b = param.b; + + // Only support add operation, none broadcasting + // Other case, need to remove internal memory allocation and check correctness + if (op_type != BinaryOp::Operation_ADD) + { + throw std::runtime_error{"NYI: Only support ADD operation"}; + } + if (bottom_blob.dims != bottom_blob1.dims) + { + throw std::runtime_error{"NYI: Cannot use broadcasting"}; + } + +// printf("-------------------BinaryOp---------------\n"); + +// printf("op_type = %d, ", op_type); +// printf("in1: (%d, %d, %d), dims = %d, ", bottom_blob.w, bottom_blob.h, 
bottom_blob.c, +// bottom_blob.dims); +// printf("in2: (%d, %d, %d), dims = %d\n", bottom_blob1.w, bottom_blob1.h, bottom_blob1.c, +// bottom_blob1.dims); + +#if __ARM_NEON + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + int size = w * h; + + int w1 = bottom_blob1.w; + int h1 = bottom_blob1.h; + int channels1 = bottom_blob1.c; + int size1 = w1 * h1; + + if (op_type == BinaryOp::Operation_ADD) + { + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { + // Fix for nnfw: disable allocation for output + // top_blob.create(w, h, channels); + if (bottom_blob1.w == 1 && bottom_blob1.h == 1) + { + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + float tt = *ptr1; + + float32x4_t _p2 = vdupq_n_f32(tt); + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + + _p1 = vaddq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 + tt); + in1++; + out++; + } + +#else + float tt = *ptr1; + for (int i = 0; i < size; i++) + { + outptr[i] = (ptr[i] + tt); + } +#endif + } + + ret = 0; + } + else + { + if (size * bottom_blob.elemsize % 16 != 0) + { + throw std::runtime_error{"Unmatched alignment"}; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + _p1 = vaddq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = *in1 + *in2; + in1++; + in2++; + out++; + } + } + } + } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1) + { + ret = binary_op<std::plus<float>>(bottom_blob, bottom_blob1, top_blob); + // return ret; + goto out; + } + float *pt = (float *)bottom_blob1.data; + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float b0 = pt[q]; + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vdupq_n_f32(b0); + + _p1 = vaddq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 + b0); + in1++; + out++; + } + } + } + else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3) + { + top_blob.create(w1, h1, channels1); + if (top_blob.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = bottom_blob[q]; + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size1 >> 2; + int remain = size1 - (nn << 2); + + float *in1 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); 
+ + for (; nn > 0; nn--) + { + float32x4_t _p1 = vdupq_n_f32(a0); + float32x4_t _p2 = vld1q_f32(in1); + + _p1 = vaddq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (a0 + *in1); + in1++; + out++; + } + } + } + else + ret = binary_op<std::plus<float>>(bottom_blob, bottom_blob1, top_blob); + } + +#if 0 // Disable operation except Operation_ADD + + if (op_type == BinaryOp::Operation_SUB) + { + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { + top_blob.create(w, h, channels); + + if (bottom_blob1.w == 1 && bottom_blob1.h == 1) + { + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + float tt = *ptr1; + + float32x4_t _p2 = vdupq_n_f32(tt); + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + + _p1 = vsubq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 - tt); + in1++; + out++; + } + +#else + float tt = *ptr1; + for (int i = 0; i < size; i++) + { + outptr[i] = (ptr[i] - tt); + } +#endif + } + + ret = 0; + } + else + { + top_blob.create(w, h, channels); +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + _p1 = vsubq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = *in1 - *in2; + in1++; + in2++; + out++; + } + } + } + } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1) + { + ret = binary_op<std::minus<float>>(bottom_blob, bottom_blob1, top_blob); + // return ret; + goto out; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float b0 = bottom_blob1[q]; + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vdupq_n_f32(b0); + + _p1 = vsubq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 - b0); + in1++; + out++; + } + } + } + else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3) + { + top_blob.create(w1, h1, channels1); + if (top_blob.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = bottom_blob[q]; + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size1 >> 2; + int remain = size1 - (nn << 2); + + float *in1 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vdupq_n_f32(a0); + float32x4_t _p2 = vld1q_f32(in1); + + _p1 = vsubq_f32(_p1, _p2); 
+ vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (a0 - *in1); + in1++; + out++; + } + } + } + else + ret = binary_op<std::minus<float>>(bottom_blob, bottom_blob1, top_blob); + } + + if (op_type == BinaryOp::Operation_MUL) + { + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { + top_blob.create(w, h, channels); + + if (bottom_blob1.w == 1 && bottom_blob1.h == 1) + { + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + float tt = *ptr1; + + float32x4_t _p2 = vdupq_n_f32(tt); + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 * tt); + in1++; + out++; + } + +#else + float tt = *ptr1; + for (int i = 0; i < size; i++) + { + outptr[i] = (ptr[i] * tt); + } +#endif + } + + ret = 0; + } + else + { +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = *in1 * *in2; + in1++; + in2++; + out++; + } + } + } + } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1) + { + ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob); + // return ret; + goto out; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float b0 = bottom_blob1[q]; + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vdupq_n_f32(b0); + + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 * b0); + in1++; + out++; + } + } + } + else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3) + { + top_blob.create(w1, h1, channels1); + if (top_blob.empty()) + return -100; + + if (bottom_blob.w != bottom_blob1.c) + { + ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob); + goto out; + } + + float *pt = (float *)bottom_blob.data; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = pt[q]; + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size1 >> 2; + int remain = size1 - (nn << 2); + + float *in1 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vdupq_n_f32(a0); + float32x4_t _p2 = vld1q_f32(in1); + + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + 
out += 4; + } + for (; remain > 0; remain--) + { + *out = (a0 * *in1); + in1++; + out++; + } + } + } + else + ret = binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob); + } + + if (op_type == BinaryOp::Operation_DIV) + { + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1 && bottom_blob1.h == 1) + { + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + float tt = *ptr1; + + float32x4_t _p2 = vdupq_n_f32(tt); + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + + float32x4_t _p3 = vrecpeq_f32(_p2); + _p3 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3); + _p1 = vmulq_f32(_p1, _p3); + + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 / tt); + in1++; + out++; + } + +#else + float tt = *ptr1; + for (int i = 0; i < size; i++) + { + outptr[i] = (ptr[i] / tt); + } +#endif + } + + // return 0; + goto out; + } + else + { +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + float32x4_t _p3 = vrecpeq_f32(_p2); + _p2 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3); + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = *in1 / *in2; + in1++; + in2++; + out++; + } + } + } + } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1) + { + ret = binary_op<std::divides<float>>(bottom_blob, bottom_blob1, top_blob); + // return ret; + goto out; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float b0 = bottom_blob1[q]; + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vdupq_n_f32(b0); + + //_p1 = vsubq_f32(_p1, _p2); + float32x4_t _p3 = vrecpeq_f32(_p2); + _p2 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3); + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 / b0); + in1++; + out++; + } + } + } + else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3) + { + top_blob.create(w1, h1, channels1); + if (top_blob.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = bottom_blob[q]; + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size1 >> 2; + int remain = size1 - (nn << 2); + + float *in1 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vdupq_n_f32(a0); + 
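// Note on the NEON division used throughout the DIV branches: 32-bit NEON has no vdivq_f32 on ARMv7, so the code below approximates n / d with a reciprocal estimate refined by one Newton-Raphson step (a sketch of the pattern, not part of the diff):
//   float32x4_t r = vrecpeq_f32(d);        // low-precision estimate of 1/d
//   r = vmulq_f32(vrecpsq_f32(d, r), r);   // refine: r = r * (2 - d * r)
//   float32x4_t q = vmulq_f32(n, r);       // n / d is approximated by n * r
// Each additional vrecpsq_f32/vmulq_f32 pair roughly doubles the precision of r.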
float32x4_t _p2 = vld1q_f32(in1); + + //_p1 = vsubq_f32(_p1, _p2); + float32x4_t _p3 = vrecpeq_f32(_p2); + _p2 = vmulq_f32(vrecpsq_f32(_p2, _p3), _p3); + _p1 = vmulq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (a0 / *in1); + in1++; + out++; + } + } + } + else + ret = binary_op<std::divides<float>>(bottom_blob, bottom_blob1, top_blob); + } + + if (op_type == BinaryOp::Operation_MAX) + ret = binary_op<binary_op_max<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_MIN) + ret = binary_op<binary_op_min<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_POW) + { + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { + top_blob.create(w, h, channels); +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + _p1 = pow_ps(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = pow(*in1, *in2); + in1++; + in2++; + out++; + } + } + } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1) + { + ret = binary_op<binary_op_pow<float>>(bottom_blob, bottom_blob1, top_blob); + // return ret; + goto out; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float b0 = bottom_blob1[q]; + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vdupq_n_f32(b0); + + _p1 = pow_ps(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = pow(*in1, b0); + in1++; + out++; + } + } + } + else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3) + { + top_blob.create(w1, h1, channels1); + if (top_blob.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = bottom_blob[q]; + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size1 >> 2; + int remain = size1 - (nn << 2); + + float *in1 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vdupq_n_f32(a0); + float32x4_t _p2 = vld1q_f32(in1); + + _p1 = pow_ps(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = pow(a0, *in1); + in1++; + out++; + } + } + } + else + ret = binary_op<binary_op_pow<float>>(bottom_blob, bottom_blob1, top_blob); + } + + if (op_type == BinaryOp::Operation_SQUAREDDIFFERENCE) + { + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { + top_blob.create(w, h, channels); + + if (bottom_blob1.w == 1 && bottom_blob1.h == 1) + { + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = 
top_blob.channel(q); + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + float tt = *ptr1; + + float32x4_t _p2 = vdupq_n_f32(tt); + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + + _p1 = vsubq_f32(_p1, _p2); + _p1 = vmulq_f32(_p1, _p1); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + float t2 = *in1 - tt; + *out = t2 * t2; + in1++; + out++; + } + +#else + float tt = *ptr1; + for (int i = 0; i < size; i++) + { + float t2 = (ptr[i] - tt); + outptr[i] = t2 * t2; + } +#endif + } + + ret = 0; + } + else + { +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + _p1 = vsubq_f32(_p1, _p2); + _p1 = vmulq_f32(_p1, _p1); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 - *in2) * (*in1 - *in2); + in1++; + in2++; + out++; + } + } + } + } + else if (bottom_blob.dims == 3 && bottom_blob1.dims == 1) + { + top_blob.create(w, h, channels); + if (bottom_blob1.w == 1) + { + ret = binary_op<binary_op_SquaredDifference<float>>(bottom_blob, bottom_blob1, top_blob); + // return ret; + goto out; + } + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const float *ptr = bottom_blob.channel(q); + const float b0 = bottom_blob1[q]; + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vdupq_n_f32(b0); + + _p1 = vsubq_f32(_p1, _p2); + _p1 = vmulq_f32(_p1, _p1); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (*in1 - b0) * (*in1 - b0); + in1++; + out++; + } + } + } + else if (bottom_blob.dims == 1 && bottom_blob1.dims == 3) + { + top_blob.create(w1, h1, channels1); + if (top_blob.empty()) + return -100; + +#pragma omp parallel for + for (int q = 0; q < channels1; q++) + { + const float a0 = bottom_blob[q]; + const float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size1 >> 2; + int remain = size1 - (nn << 2); + + float *in1 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vdupq_n_f32(a0); + float32x4_t _p2 = vld1q_f32(in1); + + _p1 = vsubq_f32(_p1, _p2); + _p1 = vmulq_f32(_p1, _p1); + vst1q_f32(out, _p1); + in1 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = (a0 - *in1) * (a0 - *in1); + in1++; + out++; + } + } + } + else + ret = binary_op<binary_op_SquaredDifference<float>>(bottom_blob, bottom_blob1, top_blob); + } + +#endif // 0 (Disable operation except Operation_ADD) + +#else + + if (op_type == BinaryOp::Operation_ADD) + ret = binary_op<std::plus<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_SUB) + ret = binary_op<std::minus<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_MUL) + ret = 
binary_op<std::multiplies<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_DIV) + ret = binary_op<std::divides<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_MAX) + ret = binary_op<binary_op_max<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_MIN) + ret = binary_op<binary_op_min<float>>(bottom_blob, bottom_blob1, top_blob); + + if (op_type == BinaryOp::Operation_POW) + ret = binary_op<binary_op_pow<float>>(bottom_blob, bottom_blob1, top_blob); + if (op_type == BinaryOp::Operation_SQUAREDDIFFERENCE) + ret = binary_op<binary_op_SquaredDifference<float>>(bottom_blob, bottom_blob1, top_blob); +#endif + +/* +for (int p = 0; p < top_blob.c && p < 5; p++) +{ + float* outptr = top_blob.channel(p); + printf("channel: %d\n", p); + for (int i = 0; i < 1; i++) + { + for (int j = 0; j < 5; j++) + { + printf("%f ", outptr[j]); + } + printf("\n"); + outptr += top_blob.w; + } +} +printf("----------------------------\n"); +*/ + +out: + return ret; +} + +int ncnn_binary_op_inplace(const BinaryOpParam &param, Mat &bottom_top_blob) +{ + auto op_type = param.op_type; + auto b = param.b; + + // printf("-------------------BinaryOp-----forward_inplace----------\n"); + if (op_type == BinaryOp::Operation_ADD) + return binary_op_scalar_inplace<std::plus<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_SUB) + return binary_op_scalar_inplace<std::minus<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_MUL) + return binary_op_scalar_inplace<std::multiplies<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_DIV) + return binary_op_scalar_inplace<std::divides<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_MAX) + return binary_op_scalar_inplace<binary_op_max<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_MIN) + return binary_op_scalar_inplace<binary_op_min<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_POW) + return binary_op_scalar_inplace<binary_op_pow<float>>(bottom_top_blob, b); + + if (op_type == BinaryOp::Operation_SQUAREDDIFFERENCE) + return binary_op_scalar_inplace<binary_op_SquaredDifference<float>>(bottom_top_blob, b); + + return 0; +} + +int ncnn_binary_op_inplace(const BinaryOpParam &param, Mat &bottom_blob, Mat &bottom_top_blob) +{ + int ret = 0; + + Mat &bottom_blob1 = bottom_top_blob; + Mat &top_blob = bottom_top_blob; + auto op_type = param.op_type; + + if (op_type == BinaryOp::Operation_ADD) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + int size = w * h; + +// Unused variables +// int w1 = bottom_blob1.w; +// int h1 = bottom_blob1.h; +// int channels1 = bottom_blob1.c; +// int size1 = w1 * h1; + +#if __ARM_NEON + + if (bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + float *ptr = bottom_blob.channel(q); + float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + int nn = size >> 2; + int remain = size - (nn << 2); + + float *in1 = const_cast<float *>(ptr); + float *in2 = const_cast<float *>(ptr1); + float *out = const_cast<float *>(outptr); + + for (; nn > 0; nn--) + { + float32x4_t _p1 = vld1q_f32(in1); + float32x4_t _p2 = vld1q_f32(in2); + + _p1 = vaddq_f32(_p1, _p2); + vst1q_f32(out, _p1); + in1 += 4; + in2 += 4; + out += 4; + } + for (; remain > 0; remain--) + { + *out = *in1 + *in2; + in1++; + in2++; + out++; + } + } + } +#else + if
(bottom_blob.dims == 3 && bottom_blob1.dims == 3) + { +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + float *ptr = bottom_blob.channel(q); + float *ptr1 = bottom_blob1.channel(q); + float *outptr = top_blob.channel(q); + + for (int i = 0; i < size; i++) + { + outptr[i] = ptr[i] + ptr1[i]; + } + } + return 0; + } +#endif + } + else + { + return -1; + } + return ret; +} + +} // namespace ncnn +} // namespace ncnn diff --git a/compute/ncnn/src/layer/instance_norm.cc b/compute/ncnn/src/layer/instance_norm.cc new file mode 100644 index 000000000..08c3f2c23 --- /dev/null +++ b/compute/ncnn/src/layer/instance_norm.cc @@ -0,0 +1,371 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
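The kernels in this file implement per-channel instance normalization, y = (x - mean) / sqrt(var + eps) * gamma + beta, in two passes: the first accumulates the sum and the sum of squares to obtain mean and variance, the second applies the normalization folded into a single multiply-add y = a * x + b with a = gamma / sqrt(var + eps) and b = beta - mean * a. A minimal scalar sketch of that folding, for reference only (the helper name is hypothetical and not part of this file):

static void instance_norm_ref(const float *x, float *y, int size, float gamma, float beta, float eps)
{
  // pass 1: mean and variance of one channel
  float sum = 0.f, sqsum = 0.f;
  for (int i = 0; i < size; i++)
  {
    sum += x[i];
    sqsum += x[i] * x[i];
  }
  float mean = sum / size;
  float var = sqsum / size - mean * mean; // E[x^2] - (E[x])^2
  // pass 2: fused scale-and-shift, y = a * x + b
  float a = gamma / sqrtf(var + eps);     // assumes <math.h> for sqrtf
  float b = beta - mean * a;
  for (int i = 0; i < size; i++)
    y[i] = a * x[i] + b;
}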
+ +#include "ncnn/layer/instance_norm.h" +#ifdef _OPENMP +#include <omp.h> +#endif + +#include <math.h> +#include "ncnn/mat.h" +#ifdef __ARM_NEON +#include <arm_neon.h> +#endif // __ARM_NEON + +namespace nnfw +{ +namespace ncnn +{ + +void ncnn_instance_norm_rowmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int channels, float eps) +{ + // x = (x - mean) / (sqrt(var) + eps) * gamma + beta + + int w = in_mat.w; + int h = in_mat.h; + int size = w * h; +#ifdef __ARM_NEON + int nn = size >> 2; + int left4 = size & 3; +#endif + +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { +#ifdef __ARM_NEON + float *in_ptr = in_mat.channel(q); + float *out_ptr = out_mat.channel(q); + float32x4_t _sum = vdupq_n_f32(0.f); + float32x4_t _sq_sum = vdupq_n_f32(0.f); + for (int n = nn; n > 0; n--) + { + float32x4_t _p = vld1q_f32(in_ptr); + _sum = vaddq_f32(_sum, _p); + _p = vmulq_f32(_p, _p); + _sq_sum = vaddq_f32(_sq_sum, _p); + in_ptr += 4; + } + float sum = vgetq_lane_f32(_sum, 0) + vgetq_lane_f32(_sum, 1); + sum += vgetq_lane_f32(_sum, 2); + sum += vgetq_lane_f32(_sum, 3); + float sqsum = vgetq_lane_f32(_sq_sum, 0) + vgetq_lane_f32(_sq_sum, 1); + sqsum += vgetq_lane_f32(_sq_sum, 2); + sqsum += vgetq_lane_f32(_sq_sum, 3); + + for (int left = left4; left > 0; left--) + { + sum += *in_ptr; + sqsum += (*in_ptr) * (*in_ptr); + in_ptr++; + } + + float mean = sum / size; + float var = sqsum / size - mean * mean; + float gamma = gamma_mat[q]; + float beta = beta_mat[q]; + float a = gamma / (sqrt(var + eps)); + float b = -mean * a + beta; + + in_ptr = in_mat.channel(q); + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + for (int n = nn; n > 0; n--) + { + float32x4_t _p = vld1q_f32(in_ptr); + _p = vmulq_f32(_p, _a); + _p = vaddq_f32(_p, _b); + vst1q_f32(out_ptr, _p); + in_ptr += 4; + out_ptr += 4; + } + for (int left = left4; left > 0; left--) + { + *out_ptr = (*in_ptr) * a + b; + in_ptr++; + out_ptr++; + } +#else + float *in_ptr = in_mat.channel(q); + float *out_ptr = out_mat.channel(q); + // mean and var + float sum = 0.f; + float sqsum = 0.f; + for (int i = 0; i < size; i++) + { + sum += in_ptr[i]; + sqsum += in_ptr[i] * in_ptr[i]; + } + float mean = sum / size; + float var = sqsum / size - mean * mean; + + float gamma = gamma_mat[q]; + float beta = beta_mat[q]; + + float a = gamma / (sqrt(var + eps)); + float b = -mean * a + beta; + for (int i = 0; i < size; i++) + { + out_ptr[i] = in_ptr[i] * a + b; + } +#endif + } +} + +void ncnn_instance_norm_colmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int /*channels*/, float eps) +{ + // Treat CHW layout as HWC layout + int h = in_mat.c; + int w = in_mat.h; + int c = in_mat.w; + + int size = w * h; + int total = size * c; + + float sum[c] = {}; + float sqsum[c] = {}; + + float mean[c] = {}; + float var[c] = {}; + float a[c] = {}; + float b[c] = {}; + + float *in_ptr = in_mat.channel(0); + float *out_ptr = out_mat.channel(0); + +#pragma omp parallel for reduction(+ : sum, sqsum) schedule(guided) + for (int i = 0; i < total; i += c) + { + for (int j = 0; j < c; j++) + { + sum[j] += in_ptr[i + j]; + sqsum[j] += in_ptr[i + j] * in_ptr[i + j]; + } + } + + for (int i = 0; i < c; i++) + { + mean[i] = sum[i] / size; + var[i] = sqsum[i] / size - mean[i] * mean[i]; + a[i] = gamma_mat[i] / (sqrt(var[i] + eps)); + b[i] = -mean[i] * a[i] + beta_mat[i]; + } + +#pragma omp parallel for schedule(guided) + for (int i = 0; i < total; i += c) + { + for (int j = 0; j < c; j++) + { + out_ptr[i + j] = in_ptr[i + j] * 
a[j] + b[j]; + } + } +} + +void ncnn_instance_norm_with_relu_rowmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int channels, float eps, float slope) +{ + int w = in_mat.w; + int h = in_mat.h; + int size = w * h; +#ifdef __ARM_NEON + int nn = size >> 2; + int left4 = size & 3; +#endif +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { +#ifdef __ARM_NEON + float *in_ptr = in_mat.channel(q); + float *out_ptr = out_mat.channel(q); + float32x4_t _sum = vdupq_n_f32(0.f); + float32x4_t _sq_sum = vdupq_n_f32(0.f); + for (int n = nn; n > 0; n--) + { + float32x4_t _p = vld1q_f32(in_ptr); + _sum = vaddq_f32(_sum, _p); + _p = vmulq_f32(_p, _p); + _sq_sum = vaddq_f32(_sq_sum, _p); + in_ptr += 4; + } + // float sum = + // vgetq_lane_f32(_sum,0)+vgetq_lane_f32(_sum,1)+vgetq_lane_f32(_sum,2)+vgetq_lane_f32(_sum,3); + // float sqsum = vgetq_lane_f32(_sq_sum,0)+vgetq_lane_f32(_sq_sum,1)+ + // vgetq_lane_f32(_sq_sum,2)+vgetq_lane_f32(_sq_sum,3); + float sum = vgetq_lane_f32(_sum, 0) + vgetq_lane_f32(_sum, 1); + sum += vgetq_lane_f32(_sum, 2); + sum += vgetq_lane_f32(_sum, 3); + float sqsum = vgetq_lane_f32(_sq_sum, 0) + vgetq_lane_f32(_sq_sum, 1); + sqsum += vgetq_lane_f32(_sq_sum, 2); + sqsum += vgetq_lane_f32(_sq_sum, 3); + for (int left = left4; left > 0; left--) + { + sum += *in_ptr; + sqsum += (*in_ptr) * (*in_ptr); + in_ptr++; + } + + float mean = sum / size; + float var = sqsum / size - mean * mean; + float gamma = gamma_mat[q]; + float beta = beta_mat[q]; + float a = gamma / (sqrt(var + eps)); + float b = -mean * a + beta; + // TODO: slope is not used in this NEON path; it currently assumes plain ReLU (slope == 0). + in_ptr = in_mat.channel(q); + float32x4_t _a = vdupq_n_f32(a); + float32x4_t _b = vdupq_n_f32(b); + float32x4_t _zero = vdupq_n_f32(0.f); + for (int n = nn; n > 0; n--) + { + float32x4_t _p = vld1q_f32(in_ptr); + _p = vmulq_f32(_p, _a); + _p = vaddq_f32(_p, _b); + _p = vmaxq_f32(_p, _zero); + vst1q_f32(out_ptr, _p); + in_ptr += 4; + out_ptr += 4; + } + for (int left = left4; left > 0; left--) + { + float temp = (*in_ptr) * a + b; + *out_ptr = temp > 0 ? temp : 0; + in_ptr++; + out_ptr++; + } +#else + float *in_ptr = in_mat.channel(q); + float *out_ptr = out_mat.channel(q); + + // mean and var + float sum = 0.f; + float sqsum = 0.f; + for (int i = 0; i < size; i++) + { + sum += in_ptr[i]; + sqsum += in_ptr[i] * in_ptr[i]; + } + float mean = sum / size; + float var = sqsum / size - mean * mean; + + float gamma = gamma_mat[q]; + float beta = beta_mat[q]; + + float a = gamma / (sqrt(var + eps)); + float b = -mean * a + beta; + + if (slope == 0.f) + { + for (int i = 0; i < size; i++) + { + float temp = in_ptr[i] * a + b; + out_ptr[i] = temp > 0 ? temp : 0; + } + } + else + { + for (int i = 0; i < size; i++) + { + float temp = in_ptr[i] * a + b; + out_ptr[i] = temp > 0 ?
temp : temp * slope; + } + } +#endif + } +} + +void ncnn_instance_norm_with_relu_colmajor(Mat &in_mat, Mat &out_mat, Mat &gamma_mat, Mat &beta_mat, + int /*channels*/, float eps, float slope) +{ + // Treat CHW layout as HWC layout + int h = in_mat.c; + int w = in_mat.h; + int c = in_mat.w; + + int size = w * h; + int total = size * c; + + float sum[c] = {}; + float sqsum[c] = {}; + + float mean[c] = {}; + float var[c] = {}; + float a[c] = {}; + float b[c] = {}; + + float *in_ptr = in_mat.channel(0); + float *out_ptr = out_mat.channel(0); + +#pragma omp parallel for reduction(+ : sum, sqsum) schedule(guided) + for (int i = 0; i < total; i += c) + { + for (int j = 0; j < c; j++) + { + sum[j] += in_ptr[i + j]; + sqsum[j] += in_ptr[i + j] * in_ptr[i + j]; + } + } + + for (int i = 0; i < c; i++) + { + mean[i] = sum[i] / size; + var[i] = sqsum[i] / size - mean[i] * mean[i]; + a[i] = gamma_mat[i] / (sqrt(var[i] + eps)); + b[i] = -mean[i] * a[i] + beta_mat[i]; + } + + if (slope == 0.f) + { +#pragma omp parallel for schedule(guided) + for (int i = 0; i < total; i += c) + { + for (int j = 0; j < c; j++) + { + float temp = in_ptr[i + j] * a[j] + b[j]; + out_ptr[i + j] = temp > 0 ? temp : 0; + } + } + } + else + { +#pragma omp parallel for schedule(guided) + for (int i = 0; i < total; i += c) + { + for (int j = 0; j < c; j++) + { + float temp = in_ptr[i + j] * a[j] + b[j]; + out_ptr[i + j] = temp > 0 ? temp : temp * slope; + } + } + } +} + +} // namespace ncnn + +} // namespace nnfw diff --git a/compute/ncnn/src/mat.cc b/compute/ncnn/src/mat.cc new file mode 100644 index 000000000..568378ef7 --- /dev/null +++ b/compute/ncnn/src/mat.cc @@ -0,0 +1,940 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ncnn/mat.h" + +#if __ARM_NEON +#include <arm_neon.h> +#endif // __ARM_NEON + +// Fix for nnfw: comment out cpu.h +//#include "cpu.h" + +namespace nnfw +{ +namespace ncnn +{ + +void Mat::substract_mean_normalize(const float *mean_vals, const float *norm_vals) +{ + int size = w * h; + + if (mean_vals && !norm_vals) + { +// substract mean only +#pragma omp parallel for + for (int q = 0; q < c; q++) + { + float *ptr = channel(q); // data + cstep * q; + const float mean = mean_vals[q]; + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); +#else + int remain = size; +#endif // __ARM_NEON + +#if __ARM_NEON +#if __aarch64__ + if (nn > 0) + { + asm volatile("dup v1.4s, %w4 \n" + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s}, [%1] \n" + "fsub v0.4s, v0.4s, v1.4s \n" + "subs %w0, %w0, #1 \n" + "st1 {v0.4s}, [%1], #16 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "r"(mean) // %4 + : "cc", "memory", "v0", "v1"); + } +#else + if (nn > 0) + { + asm volatile("vdup.f32 q1, %4 \n" + "0: \n" + "pld [%1, #128] \n" + "vld1.f32 {d0-d1}, [%1 :128] \n" + "vsub.f32 q0, q0, q1 \n" + "subs %0, #1 \n" + "vst1.f32 {d0-d1}, [%1 :128]! \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "r"(mean) // %4 + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ +#endif // __ARM_NEON + for (; remain > 0; remain--) + { + *ptr -= mean; + ptr++; + } + } + } + else if (!mean_vals && norm_vals) + { +// normalize only +#pragma omp parallel for + for (int q = 0; q < c; q++) + { + float *ptr = channel(q); // data + cstep * q; + const float norm = norm_vals[q]; + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); +#else + int remain = size; +#endif // __ARM_NEON + +#if __ARM_NEON +#if __aarch64__ + if (nn > 0) + { + asm volatile("dup v1.4s, %w4 \n" + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s}, [%1] \n" + "fmul v0.4s, v0.4s, v1.4s \n" + "subs %w0, %w0, #1 \n" + "st1 {v0.4s}, [%1], #16 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "r"(norm) // %4 + : "cc", "memory", "v0", "v1"); + } +#else + if (nn > 0) + { + asm volatile("vdup.f32 q1, %4 \n" + "0: \n" + "pld [%1, #128] \n" + "vld1.f32 {d0-d1}, [%1 :128] \n" + "vmul.f32 q0, q0, q1 \n" + "subs %0, #1 \n" + "vst1.f32 {d0-d1}, [%1 :128]! 
\n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "r"(norm) // %4 + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ +#endif // __ARM_NEON + for (; remain > 0; remain--) + { + *ptr *= norm; + ptr++; + } + } + } + else if (mean_vals && norm_vals) + { +// substract mean and normalize +#pragma omp parallel for + for (int q = 0; q < c; q++) + { + float *ptr = channel(q); // data + cstep * q; + const float mean = mean_vals[q]; + const float norm = norm_vals[q]; + +#if __ARM_NEON + int nn = size >> 2; + int remain = size - (nn << 2); +#else + int remain = size; +#endif // __ARM_NEON + +#if __ARM_NEON +#if __aarch64__ + if (nn > 0) + { + asm volatile("dup v1.4s, %w4 \n" + "dup v2.4s, %w5 \n" + "0: \n" + "prfm pldl1keep, [%1, #128] \n" + "ld1 {v0.4s}, [%1] \n" + "fsub v0.4s, v0.4s, v1.4s \n" + "fmul v0.4s, v0.4s, v2.4s \n" + "subs %w0, %w0, #1 \n" + "st1 {v0.4s}, [%1], #16 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "r"(mean), // %4 + "r"(norm) // %5 + : "cc", "memory", "v0", "v1", "v2"); + } +#else + if (nn > 0) + { + asm volatile("vdup.f32 q1, %4 \n" + "vdup.f32 q2, %5 \n" + "0: \n" + "pld [%1, #128] \n" + "vld1.f32 {d0-d1}, [%1 :128] \n" + "vsub.f32 q0, q0, q1 \n" + "vmul.f32 q0, q0, q2 \n" + "subs %0, #1 \n" + "vst1.f32 {d0-d1}, [%1 :128]! \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(ptr) // %1 + : "0"(nn), "1"(ptr), + "r"(mean), // %4 + "r"(norm) // %5 + : "cc", "memory", "q0", "q1", "q2"); + } +#endif // __aarch64__ +#endif // __ARM_NEON + for (; remain > 0; remain--) + { + *ptr = (*ptr - mean) * norm; + ptr++; + } + } + } +} + +// convert half precision floating point to float +static float half2float(unsigned short value) +{ + // 1 : 5 : 10 + unsigned short sign = (value & 0x8000) >> 15; + unsigned short exponent = (value & 0x7c00) >> 10; + unsigned short significand = value & 0x03FF; + + // fprintf(stderr, "%d %d %d\n", sign, exponent, significand); + + // 1 : 8 : 23 + union { + unsigned int u; + float f; + } tmp; + if (exponent == 0) + { + if (significand == 0) + { + // zero + tmp.u = (sign << 31); + } + else + { + // denormal + exponent = 0; + // find non-zero bit + while ((significand & 0x200) == 0) + { + significand <<= 1; + exponent++; + } + significand <<= 1; + significand &= 0x3FF; + tmp.u = (sign << 31) | ((-exponent + (-15 + 127)) << 23) | (significand << 13); + } + } + else if (exponent == 0x1F) + { + // infinity or NaN + tmp.u = (sign << 31) | (0xFF << 23) | (significand << 13); + } + else + { + // normalized + tmp.u = (sign << 31) | ((exponent + (-15 + 127)) << 23) | (significand << 13); + } + + return tmp.f; +} + +Mat Mat::from_float16(const unsigned short *data, int size) +{ + Mat m(size); + if (m.empty()) + return m; + + float *ptr = m; //.data; + +#if __ARM_NEON && (__ARM_FP & 2) + // Fix for nnfw: Alway support vfpv4 + // int nn = cpu_support_arm_vfpv4() ? size >> 2 : 0; + int nn = size >> 2; + int remain = size - (nn << 2); +#else + int remain = size; +#endif // __ARM_NEON + +#if __ARM_NEON && (__ARM_FP & 2) +#if __aarch64__ + if (nn > 0) + { + asm volatile("0: \n" + "ld1 {v0.4h}, [%1], #8 \n" + "fcvtl v1.4s, v0.4h \n" + "subs %w0, %w0, #1 \n" + "st1 {v1.4s}, [%2], #16 \n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(data), // %1 + "=r"(ptr) // %2 + : "0"(nn), "1"(data), "2"(ptr) + : "cc", "memory", "v0", "v1"); + } +#else + if (nn > 0) + { + asm volatile("0: \n" + "pld [%1, #64] \n" + "vld1.s16 {d0}, [%1 :64]! \n" + "vcvt.f32.f16 q1, d0 \n" + "subs %0, #1 \n" + "vst1.f32 {d2-d3}, [%2 :128]! 
\n" + "bne 0b \n" + : "=r"(nn), // %0 + "=r"(data), // %1 + "=r"(ptr) // %2 + : "0"(nn), "1"(data), "2"(ptr) + : "cc", "memory", "q0", "q1"); + } +#endif // __aarch64__ +#endif // __ARM_NEON + for (; remain > 0; remain--) + { + *ptr = half2float(*data); + + data++; + ptr++; + } + + return m; +} + +static void copy_make_border_image(const Mat &src, Mat &dst, int top, int left, int type, float v) +{ + int w = dst.w; + int h = dst.h; + + const float *ptr = src; //.data; + float *outptr = dst; //.data; + + if (type == BORDER_CONSTANT) + { + int y = 0; + // fill top + for (; y < top; y++) + { + int x = 0; + for (; x < w; x++) + { + outptr[x] = v; + } + outptr += w; + } + // fill center + for (; y < (top + src.h); y++) + { + int x = 0; + for (; x < left; x++) + { + outptr[x] = v; + } + if (src.w < 12) + { + for (; x < (left + src.w); x++) + { + outptr[x] = ptr[x - left]; + } + } + else + { + memcpy(outptr + left, ptr, src.w * sizeof(float)); + x += src.w; + } + for (; x < w; x++) + { + outptr[x] = v; + } + ptr += src.w; + outptr += w; + } + // fill bottom + for (; y < h; y++) + { + int x = 0; + for (; x < w; x++) + { + outptr[x] = v; + } + outptr += w; + } + } + else if (type == BORDER_REPLICATE) + { + int y = 0; + // fill top + for (; y < top; y++) + { + int x = 0; + for (; x < left; x++) + { + outptr[x] = ptr[0]; + } + if (src.w < 12) + { + for (; x < (left + src.w); x++) + { + outptr[x] = ptr[x - left]; + } + } + else + { + memcpy(outptr + left, ptr, src.w * sizeof(float)); + x += src.w; + } + for (; x < w; x++) + { + outptr[x] = ptr[src.w - 1]; + } + outptr += w; + } + // fill center + for (; y < (top + src.h); y++) + { + int x = 0; + for (; x < left; x++) + { + outptr[x] = ptr[0]; + } + if (src.w < 12) + { + for (; x < (left + src.w); x++) + { + outptr[x] = ptr[x - left]; + } + } + else + { + memcpy(outptr + left, ptr, src.w * sizeof(float)); + x += src.w; + } + for (; x < w; x++) + { + outptr[x] = ptr[src.w - 1]; + } + ptr += src.w; + outptr += w; + } + // fill bottom + ptr -= src.w; + for (; y < h; y++) + { + int x = 0; + for (; x < left; x++) + { + outptr[x] = ptr[0]; + } + if (src.w < 12) + { + for (; x < (left + src.w); x++) + { + outptr[x] = ptr[x - left]; + } + } + else + { + memcpy(outptr + left, ptr, src.w * sizeof(float)); + x += src.w; + } + for (; x < w; x++) + { + outptr[x] = ptr[src.w - 1]; + } + outptr += w; + } + } +} + +#if defined(_MEMORY_TO_TIME_) && defined(_TIME_TO_MEMORY_) +static void copy_make_border_image_inplace(const Mat &src, Mat &dst, int top, int left, int type, + float v) +{ + int w = dst.w; + int h = dst.h; + + const float *ptr = src; + float *outptr = dst; + + if (type == BORDER_CONSTANT) + { + // fill bottom + int y = src.h + top; + outptr += y * w; + for (; y < h; y++) + { + int x = 0; + for (; x < w; x++) + { + outptr[x] = v; + } + outptr += w; + } + + // fill center + y = src.h + top - 1; + outptr = dst; + outptr += y * w; + ptr += (src.h - 1) * src.w; + + for (; y >= top; y--) + { + int x = left + src.w; + for (; x < w; x++) + { + outptr[x] = v; + } + + x = left + src.w - 1; + + for (; x >= left; x--) + { + outptr[x] = ptr[x - left]; + } + + for (x = 0; x < left; x++) + { + outptr[x] = v; + } + ptr -= src.w; + outptr -= w; + } + + // fill top + y = 0; + outptr = dst; + for (; y < top; y++) + { + int x = 0; + for (; x < w; x++) + { + outptr[x] = v; + } + outptr += w; + } + } +} +#endif // _MEMORY_TO_TIME_ && _TIME_TO_MEMORY_ + +void copy_make_border(const Mat &src, Mat &dst, int top, int bottom, int left, int right, int type, + float v) +{ + int w = 
src.w + left + right; + int h = src.h + top + bottom; + + if (w == src.w && h == src.h) + { + dst = src; + return; + } + + if (src.dims == 2) + { + dst.create(w, h); + if (dst.empty()) + return; + copy_make_border_image(src, dst, top, left, type, v); + } + else if (src.dims == 3) + { + int channels = src.c; + dst.create(w, h, channels); + if (dst.empty()) + return; + + if (src.data != dst.data) + { +// unroll image channel +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const Mat m = src.channel(q); + Mat borderm = dst.channel(q); + + copy_make_border_image(m, borderm, top, left, type, v); + } + } + else + { +#if defined(_MEMORY_TO_TIME_) && defined(_TIME_TO_MEMORY_) + for (int q = channels - 1; q >= 0; q--) + { + Mat m = src.channel(q); + Mat borderm = dst.channel(q); + copy_make_border_image_inplace(m, borderm, top, left, type, v); + } +#else +// unroll image channel +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const Mat m = src.channel(q); + Mat borderm = dst.channel(q); + + copy_make_border_image(m, borderm, top, left, type, v); + } +#endif // _MEMORY_TO_TIME_ && _TIME_TO_MEMORY_ + } + } +} + +static void copy_cut_border_image(const Mat &src, Mat &dst, int top, int left) +{ + int w = dst.w; + int h = dst.h; + + const float *ptr = src.row(top) + left; //.data + src.w * top + left; + float *outptr = dst; //.data; + + for (int y = 0; y < h; y++) + { + if (w < 12) + { + for (int x = 0; x < w; x++) + { + outptr[x] = ptr[x]; + } + } + else + { + memcpy(outptr, ptr, w * sizeof(float)); + } + outptr += w; + ptr += src.w; + } +} + +void copy_cut_border(const Mat &src, Mat &dst, int top, int bottom, int left, int right) +{ + int w = src.w - left - right; + int h = src.h - top - bottom; + +#ifndef _MEMORY_TO_TIME_ + if (w == src.w && h == src.h) + { + dst = src; + return; + } +#endif + + if (src.dims == 2) + { + dst.create(w, h); + if (dst.empty()) + return; + + copy_cut_border_image(src, dst, top, left); + } + else if (src.dims == 3) + { + int channels = src.c; + + dst.create(w, h, channels); + if (dst.empty()) + return; + +#if !defined(_MEMORY_TO_TIME_) || !defined(_TIME_TO_MEMORY_) +// unroll image channel +#pragma omp parallel for +#endif + for (int q = 0; q < channels; q++) + { + const Mat m = src.channel(q); + Mat cutm = dst.channel(q); + + copy_cut_border_image(m, cutm, top, left); + } + } +} + +static void resize_bilinear_image(const Mat &src, Mat &dst, int w, int h) +{ + double scale_x = (double)src.w / w; + double scale_y = (double)src.h / h; + + int *buf = new int[w + h + w * 2 + h * 2]; + + int *xofs = buf; // new int[w]; + int *yofs = buf + w; // new int[h]; + + float *alpha = (float *)(buf + w + h); // new float[w * 2]; + float *beta = (float *)(buf + w + h + w * 2); // new float[h * 2]; + + float fx; + float fy; + int sx; + int sy; + + for (int dx = 0; dx < w; dx++) + { + fx = (float)((dx + 0.5) * scale_x - 0.5); + sx = fx; // cvFloor(fx); + fx -= sx; + + if (sx >= src.w - 1) + { + sx = src.w - 2; + fx = 1.f; + } + + xofs[dx] = sx; + + alpha[dx * 2] = 1.f - fx; + alpha[dx * 2 + 1] = fx; + } + + for (int dy = 0; dy < h; dy++) + { + fy = (float)((dy + 0.5) * scale_y - 0.5); + sy = fy; // cvFloor(fy); + fy -= sy; + + if (sy >= src.h - 1) + { + sy = src.h - 2; + fy = 1.f; + } + + yofs[dy] = sy; + + beta[dy * 2] = 1.f - fy; + beta[dy * 2 + 1] = fy; + } + + // loop body + Mat rowsbuf0(w + 1); + Mat rowsbuf1(w + 1); + float *rows0 = rowsbuf0; + float *rows1 = rowsbuf1; + + int prev_sy1 = -1; + + for (int dy = 0; dy < h; dy++) + { + int sy = 
yofs[dy]; + + if (sy == prev_sy1) + { + // hresize one row + float *rows0_old = rows0; + rows0 = rows1; + rows1 = rows0_old; + const float *S1 = src.row(sy + 1); + + const float *alphap = alpha; + float *rows1p = rows1; + int dx = 0; +#if __ARM_NEON + for (; dx + 1 < w; dx += 2) + { + int sx = xofs[dx]; + int sxn = xofs[dx + 1]; + const float *S1p = S1 + sx; + const float *S1np = S1 + sxn; + + float32x4_t _a = vld1q_f32(alphap); + float32x2_t _S1 = vld1_f32(S1p); + float32x2_t _S1n = vld1_f32(S1np); + + float32x4_t _S1S1n = vcombine_f32(_S1, _S1n); + float32x4_t _ms1 = vmulq_f32(_S1S1n, _a); + float32x2_t _rows1 = vpadd_f32(vget_low_f32(_ms1), vget_high_f32(_ms1)); + + vst1_f32(rows1p + dx, _rows1); + + alphap += 4; + } +#endif // __ARM_NEON + for (; dx < w; dx++) + { + int sx = xofs[dx]; + const float *S1p = S1 + sx; + + float a0 = alphap[0]; + float a1 = alphap[1]; + rows1p[dx] = S1p[0] * a0 + S1p[1] * a1; + + alphap += 2; + } + } + else + { + // hresize two rows + const float *S0 = src.row(sy); + const float *S1 = src.row(sy + 1); + + const float *alphap = alpha; + float *rows0p = rows0; + float *rows1p = rows1; + int dx = 0; +#if __ARM_NEON + for (; dx + 1 < w; dx += 2) + { + int sx = xofs[dx]; + int sxn = xofs[dx + 1]; + const float *S0p = S0 + sx; + const float *S1p = S1 + sx; + const float *S0np = S0 + sxn; + const float *S1np = S1 + sxn; + + float32x4_t _a = vld1q_f32(alphap); + float32x2_t _S0 = vld1_f32(S0p); + float32x2_t _S1 = vld1_f32(S1p); + float32x2_t _S0n = vld1_f32(S0np); + float32x2_t _S1n = vld1_f32(S1np); + + float32x4_t _S0S0n = vcombine_f32(_S0, _S0n); + float32x4_t _S1S1n = vcombine_f32(_S1, _S1n); + float32x4_t _ms0 = vmulq_f32(_S0S0n, _a); + float32x4_t _ms1 = vmulq_f32(_S1S1n, _a); + float32x2_t _rows0 = vpadd_f32(vget_low_f32(_ms0), vget_high_f32(_ms0)); + float32x2_t _rows1 = vpadd_f32(vget_low_f32(_ms1), vget_high_f32(_ms1)); + + vst1_f32(rows0p + dx, _rows0); + vst1_f32(rows1p + dx, _rows1); + + alphap += 4; + } +#endif // __ARM_NEON + for (; dx < w; dx++) + { + int sx = xofs[dx]; + const float *S0p = S0 + sx; + const float *S1p = S1 + sx; + + float a0 = alphap[0]; + float a1 = alphap[1]; + rows0p[dx] = S0p[0] * a0 + S0p[1] * a1; + rows1p[dx] = S1p[0] * a0 + S1p[1] * a1; + + alphap += 2; + } + } + + prev_sy1 = sy + 1; + + // vresize + float b0 = beta[0]; + float b1 = beta[1]; + + float *rows0p = rows0; + float *rows1p = rows1; + float *Dp = dst.row(dy); + +#if __ARM_NEON + int nn = w >> 3; +#else + int nn = 0; +#endif + int remain = w - (nn << 3); + +#if __ARM_NEON + float32x4_t _b0 = vdupq_n_f32(b0); + float32x4_t _b1 = vdupq_n_f32(b1); + for (; nn > 0; nn--) + { + float32x4_t _rows0 = vld1q_f32(rows0p); + float32x4_t _rows1 = vld1q_f32(rows1p); + + float32x4_t _D = vmulq_f32(_rows0, _b0); + _D = vmlaq_f32(_D, _rows1, _b1); + + vst1q_f32(Dp, _D); + + float32x4_t _rows0n = vld1q_f32(rows0p + 4); + float32x4_t _rows1n = vld1q_f32(rows1p + 4); + + float32x4_t _Dn = vmulq_f32(_rows0n, _b0); + _Dn = vmlaq_f32(_Dn, _rows1n, _b1); + + vst1q_f32(Dp + 4, _Dn); + + Dp += 8; + rows0p += 8; + rows1p += 8; + } +#endif // __ARM_NEON + for (; remain; --remain) + { + // D[x] = rows0[x]*b0 + rows1[x]*b1; + *Dp++ = *rows0p++ * b0 + *rows1p++ * b1; + } + + beta += 2; + } + + delete[] buf; +} + +void resize_bilinear(const Mat &src, Mat &dst, int w, int h) +{ + if (w == src.w && h == src.h) + { + dst = src; + return; + } + + if (src.dims == 2) + { + dst.create(w, h); + if (dst.empty()) + return; + + resize_bilinear_image(src, dst, w, h); + } + else if (src.dims == 3) + { + int 
channels = src.c; + + dst.create(w, h, channels); + if (dst.empty()) + return; + +// unroll image channel +#pragma omp parallel for + for (int q = 0; q < channels; q++) + { + const Mat m = src.channel(q); + Mat resizem = dst.channel(q); + + resize_bilinear_image(m, resizem, w, h); + } + } +} + +} // namespace ncnn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/common.h b/compute/ncnn/src/srcn/common.h new file mode 100644 index 000000000..778a17a80 --- /dev/null +++ b/compute/ncnn/src/srcn/common.h @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_COMMON_H__ +#define __NNFW_SRCN_COMMON_H__ + +#include <string.h> +#include <limits> +#include <arm_neon.h> + +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +#define sizeof_RhsScalar 4 +#define sizeof_LhsScalar 4 +#define sizeof_ResScalar 4 + +#define MIN(a, b) (a) > (b) ? (b) : (a) +#define MAX(a, b) (a) > (b) ? (a) : (b) + +enum shardType_t +{ + shardByCol = 0, + shardByRow +}; + +#ifdef TIZEN +#define L1_CACHE_SIZE (16536 * 2) +#define L2_CACHE_SIZE (524288 * 2) +#define L3_CACHE_SIZE (0) // no L3 +#define MAX_K (512) +// single-thread +#define GEN_COL (1440) +// multi-threads +#define MAX_COL (90) +#define MIN_COL (32) +#elif defined ANDROID +#define L1_CACHE_SIZE (16536 * 4) +#define L2_CACHE_SIZE (524288 * 8) +#define L3_CACHE_SIZE (0) //(524288 * 8) //no L3 +#define MAX_K (512 * 2) +// single-thread +#define GEN_COL (1440) +// multi-threads +#if __aarch64__ +#define MAX_COL (1024) +#else +#define MAX_COL (90) +#endif +#define MIN_COL (32) +#endif + +enum +{ + USE_COMMON_KENEL = 0, + USE_12BIT_KERNEL, + USE_NONZERO_KERENL +}; + +template <typename T> static T divup(const T &x, const T &y) +{ + return static_cast<T>((x + y - 1) / y); +} + +#ifdef NCNN +static inline size_t alignSize(size_t sz, int n) { return (sz + n - 1) / n * n; } + +static inline size_t alignBy2(size_t sz) { return (sz + 1) & -2; } +#endif // NCNN + +static inline int32_t BitNot(int32_t a) { return ~a; } + +static inline int32_t MaskIfNonZero(int32_t a) +{ + static int32_t zero = 0; + return a ? 
BitNot(zero) : zero; +} + +static inline int32_t BitAnd(int32_t a, int32_t b) { return a & b; } + +static inline int32_t ShiftRight(int32_t a, int offset) { return a >> offset; } + +static inline int32_t MaskIfLessThan(int32_t a, int32_t b) { return MaskIfNonZero(a < b); } + +static inline int32_t MaskIfGreaterThan(int32_t a, int32_t b) { return MaskIfNonZero(a > b); } + +static inline int32_t Add(int32_t a, int32_t b) { return a + b; } + +static inline int32_t RoundingDivideByPOT(int32_t x, int exponent) +{ + const int32_t mask = (1ll << exponent) - 1; + const int32_t zero = 0; + const int32_t one = 1; + const int32_t remainder = BitAnd(x, mask); + const int32_t threshold = Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); + return Add(ShiftRight(x, exponent), BitAnd(MaskIfGreaterThan(remainder, threshold), one)); +} +static inline int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) +{ + bool overflow = a == b && a == std::numeric_limits<int32_t>::min(); + int64_t a_64(a); + int64_t b_64(b); + int64_t ab_64 = a_64 * b_64; + int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); + int32_t ab_x2_high32 = static_cast<int32_t>((ab_64 + nudge) / (1ll << 31)); + return overflow ? std::numeric_limits<int32_t>::max() : ab_x2_high32; +} + +static inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, + int shift) +{ + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift); +} + +static inline int32x4_t SaturatingRoundingDoublingHighMulV(int32x4_t a, int32x4_t b) +{ + return vqrdmulhq_s32(a, b); +} + +static inline int32x4_t RoundingDivideByPOTV(int32x4_t x, int exponent) +{ + const int32x4_t shift_vec = vdupq_n_s32(-exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift_vec), 31); + const int32x4_t fixed_up_x = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed_up_x, shift_vec); +} + +static inline int32x4_t MultiplyByQuantizedMultiplierV(int32x4_t x, int32_t quantized_multiplier, + int shift) +{ + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + return RoundingDivideByPOTV( + SaturatingRoundingDoublingHighMulV(vrshlq_s32(x, vdupq_n_s32(left_shift)), + vdupq_n_s32(quantized_multiplier)), + right_shift); +} + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_COMMON_H__ diff --git a/compute/ncnn/src/srcn/conv_sgemm_multithreads.cc b/compute/ncnn/src/srcn/conv_sgemm_multithreads.cc new file mode 100644 index 000000000..21083f677 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_sgemm_multithreads.cc @@ -0,0 +1,483 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include "ncnn/srcn/conv_type.h" +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "conv_sgemm_multithreads.h" + +namespace nnfw +{ +namespace srcn +{ + +void conv_sgemm_multithreads::param_init() +{ +#if __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 8; + nr_ = 12; + } + else if (conv_type_ == col_major) + { +#ifdef BATCH_DILATION_FIX + if (out_mat_.n > 1) + { + + mr_ = 24; + nr_ = 4; + } + else +#endif // BATCH_DILATION_FIX + { + if (m_ > n_) + { + mr_ = 24; + nr_ = 4; + } + else + { + mr_ = 12; + nr_ = 8; + } + } + } +#else // __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 6; + nr_ = 8; + } + else if (conv_type_ == col_major) + { + mr_ = 8; + nr_ = 6; + } +#endif // __aarch64__ + int col = n_; + + if (m_ > n_) + { + shard_type_ = shardByRow; + col = m_; + } + else + { + shard_type_ = shardByCol; + } + + int th_base = divup(col, num_threads_); + + th_base = MIN(MAX(th_base, MIN_COL), MAX_COL); + + int k_div = (nr_ * sizeof_RhsScalar); + int k_sub = (mr_ * nr_ * sizeof_ResScalar); + + const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K); + bk_ = MIN(k_cache, k_); + + if (shard_type_ == shardByCol) + { + int m_sub = (bk_ * nr_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + if (L3_CACHE_SIZE) + m_div = (sizeof_LhsScalar * bk_ * 2); + int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, m_); + + bn_ = MIN(th_base, n_); + if (L3_CACHE_SIZE) + { + int n_sub = (bk_ * bm_ * sizeof_RhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, bn_); + } + } + else + { + int n_sub = (bk_ * mr_ * sizeof_LhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + if (L3_CACHE_SIZE) + n_div = (sizeof_LhsScalar * bk_ * 2); + int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, n_); + + bm_ = MIN(th_base, m_); + if (L3_CACHE_SIZE) + { + int m_sub = (bk_ * bn_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, bm_); + } + } + + nm_ = divup(m_, bm_); + nn_ = divup(n_, bn_); + nk_ = divup(k_, bk_); + + rm_ = m_ % bm_; + rn_ = n_ % bn_; + rk_ = k_ % bk_; +} + +conv_sgemm_multithreads::conv_sgemm_multithreads(const convMat_t &in_mat, + const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, int num_threads, + convType_t conv_type) + + : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), + conv_type_(conv_type), num_threads_(num_threads) +{ + m_ = out_mat_.c; +#ifdef NCNN +#ifdef WITH_DPU + np_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); + n_ = (np_ + 1) / 2; +#else // WITH_DPU + n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); +#endif // WITH_DPU +#else // NCNN +#ifdef WITH_DPU + np_ = out_mat_.n * out_mat_.w * out_mat_.h; + n_ = (np_ + 1) / 2; +#else // WITH_DPU + n_ = out_mat_.n * out_mat_.w * out_mat_.h; +#endif // WITH_DPU +#endif // NCNN + k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c; + + param_init(); + + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + if (shard_type_ == shardByCol) + { + plhs_buffer_ = new float[lhs_stride * 1 * nm_]; + prhs_buffer_ = new float[rhs_stride * num_threads_]; + } + else + { + 
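// Packed-panel buffers: LHS panels are mr_-aligned and RHS panels are nr_-aligned, with the
// block sizes bm_/bn_/bk_ chosen in param_init() to fit the L1/L2 (and optional L3) caches.
// With column sharding (branch above) the nm_ packed LHS panels are shared by all threads and
// each thread packs its own private RHS block; with row sharding (this branch) the roles are
// swapped: one private LHS block per thread and nn_ shared packed RHS panels.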
plhs_buffer_ = new float[lhs_stride * num_threads_]; + prhs_buffer_ = new float[rhs_stride * 1 * nn_]; + } + + if (plhs_buffer_ == NULL || prhs_buffer_ == NULL) + { + error_ = 1; + } + + if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || + in_param_.stride_h != 1 || in_param_.padding != 0) + { + need_im2col_ = 1; + } + else + { + need_im2col_ = 0; + } + + omp_set_num_threads(num_threads_); + + error_ = 0; +} + +conv_sgemm_multithreads::~conv_sgemm_multithreads() +{ + if (plhs_buffer_) + delete[] plhs_buffer_; + if (prhs_buffer_) + delete[] prhs_buffer_; +} + +void conv_sgemm_multithreads::run() +{ + if (error_) + return; + + if (shard_type_ == shardByCol && conv_type_ == col_major) + { + compute_colmajor_colshard(); + } + else if (shard_type_ == shardByRow && conv_type_ == col_major) + { + compute_colmajor_rowshard(); + } + else if (shard_type_ == shardByCol && conv_type_ == row_major) + { + compute_rowmajor_colshard(); + } + else if (shard_type_ == shardByRow && conv_type_ == row_major) + { + compute_rowmajor_rowshard(); + } +} + +void conv_sgemm_multithreads::compute_rowmajor_colshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + &plhs_buffer_[i * lhs_stride]); + } + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int thread_num = omp_get_thread_num(); + // float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; + + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { +#ifdef WITH_DPU + _pack_rowmajor_notrans_rhs(nr_, bn, bk, np_, &in_mat_.data[n_ + l * bk_ * np_ + j * bn_], + prhs_ptr); +#else + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + prhs_ptr); +#endif + } + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + +#ifdef WITH_DPU + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, &out_mat_.data[n_ + i * bm_ * np_ + j * bn_], + l, np_, bk); +#else // WITH_DPU + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, + bk); +#endif // WITH_DPU + } + } + } +} + +void conv_sgemm_multithreads::compute_rowmajor_rowshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + &prhs_buffer_[j * rhs_stride]); + } + else + { + _pack_rowmajor_image_rhs_batch( + nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]); + } + } + else + { + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + &prhs_buffer_[j * rhs_stride]); + } + } + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + int thread_num = omp_get_thread_num(); + float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, + &prhs_buffer_[j * rhs_stride], + &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } +} + +void conv_sgemm_multithreads::compute_colmajor_colshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + &plhs_buffer_[i * lhs_stride]); + } + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int thread_num = omp_get_thread_num(); + float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; + + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + prhs_ptr); + } + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, + bk); + } + } + } +} + +void conv_sgemm_multithreads::compute_colmajor_rowshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + &prhs_buffer_[j * rhs_stride]); + } + else + { + _pack_colmajor_image_rhs_batch( + nr_, bn, bk, l * bk_, j * bn_, const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), &prhs_buffer_[j * rhs_stride]); + } + } + else + { + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + &prhs_buffer_[j * rhs_stride]); + } + } + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + int thread_num = omp_get_thread_num(); + float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, + &prhs_buffer_[j * rhs_stride], + &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_sgemm_multithreads.h b/compute/ncnn/src/srcn/conv_sgemm_multithreads.h new file mode 100644 index 000000000..9c9ce7437 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_sgemm_multithreads.h @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __NNFW_SRCN_CONV_SGEMM_MULTITHREADS_H__ +#define __NNFW_SRCN_CONV_SGEMM_MULTITHREADS_H__ + +#include "ncnn/srcn/conv_type.h" +#include "common.h" + +namespace nnfw +{ +namespace srcn +{ + +class conv_sgemm_multithreads +{ +public: + conv_sgemm_multithreads(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, int num_threads, convType_t conv_type); + ~conv_sgemm_multithreads(); + + void run(); + +private: + void param_init(); + + void compute_rowmajor_colshard(); + void compute_rowmajor_rowshard(); + void compute_colmajor_colshard(); + void compute_colmajor_rowshard(); + + const convMat_t in_mat_; + const convMat_t weights_mat_; + convMat_t out_mat_; + const convParams_t in_param_; + convType_t conv_type_; + int num_threads_; + + int m_; + int n_; +#ifdef WITH_DPU + int np_; +#endif + int k_; + + int bm_; + int bn_; + int bk_; + + int rm_; + int rn_; + int rk_; + + int nm_; + int nn_; + int nk_; + + int mr_; + int nr_; + + int need_im2col_; + shardType_t shard_type_; + + float *prhs_buffer_; + float *plhs_buffer_; + + int error_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_SGEMM_MULTITHREADS_H__ diff --git a/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc b/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc new file mode 100644 index 000000000..4cbbf217f --- /dev/null +++ b/compute/ncnn/src/srcn/conv_sgemm_singlethread.cc @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdexcept> + +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "conv_sgemm_singlethread.h" + +namespace nnfw +{ +namespace srcn +{ + +void conv_sgemm_singlethread::param_init() +{ + if (n_ > 3 * m_) + { + shard_type_ = shardByRow; + } + else + { + shard_type_ = shardByCol; + } + +#if __aarch64__ + if (conv_type_ == row_major) + { + if (shard_type_ == shardByRow) + { + mr_ = 8; + nr_ = 12; + } + else + { + mr_ = 12; + nr_ = 8; + } + } + else if (conv_type_ == col_major) + { +#ifndef BATCH_DILATION_FIX + mr_ = 12; + nr_ = 8; +#else // BATCH_DILATION_FIX + // TODO: batch(dilation) + inw * inh + if (out_mat_.n > 1) + { + mr_ = 24; + nr_ = 4; + } + else + { + mr_ = 12; + nr_ = 8; + } +#endif // BATCH_DILATION_FIX + } +#else // __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 6; + nr_ = 8; + } + else if (conv_type_ == col_major) + { + mr_ = 8; + nr_ = 6; + } +#endif // __aarch64__ + + int k_div = (nr_ * sizeof_RhsScalar); + int k_sub = (mr_ * nr_ * sizeof_ResScalar); + + const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), MAX_K); + bk_ = MIN(k_cache, k_); + + if (shard_type_ == shardByCol) + { + int m_sub = (bk_ * nr_ * sizeof_RhsScalar); + int m_cache = divup((L2_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); + bm_ = MIN(m_cache, m_); + + bn_ = MIN(GEN_COL, n_); + if (L3_CACHE_SIZE) + { + int n_sub = (bk_ * bm_ * sizeof_RhsScalar); + int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); + bn_ = MIN(n_cache, bn_); + } + } + else + { + int n_sub = (bk_ * mr_ * sizeof_RhsScalar); + int n_cache = divup((L2_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); + bn_ = MIN(n_cache, n_); + + bm_ = MIN(GEN_COL, m_); + if (L3_CACHE_SIZE) + { + int m_sub = (bk_ * bn_ * sizeof_RhsScalar); + int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); + bm_ = MIN(m_cache, bm_); + } + } + + nm_ = divup(m_, bm_); + nn_ = divup(n_, bn_); + nk_ = divup(k_, bk_); + + rm_ = m_ % bm_; + rn_ = n_ % bn_; + rk_ = k_ % bk_; +} + +conv_sgemm_singlethread::conv_sgemm_singlethread(const convMat_t &in_mat, + const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, convType_t conv_type) + : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), + conv_type_(conv_type) +{ + m_ = out_mat_.c; +#ifdef NCNN + n_ = out_mat_.n * alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); +#else + n_ = out_mat_.n * out_mat_.w * out_mat_.h; +#endif + k_ = in_param_.kernel_h * in_param_.kernel_w * in_mat.c; + + param_init(); + + if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || + in_param_.stride_h != 1 || in_param_.padding != 0 || out_mat_.n > 1) + { + need_im2col_ = 1; + } + else + { + need_im2col_ = 0; + } +} + +conv_sgemm_singlethread::~conv_sgemm_singlethread() {} + +void conv_sgemm_singlethread::run() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float *plhs_ptr = new float[mstride * bk_]; + float *prhs_ptr = new float[nstride * bk_]; + + if (conv_type_ == row_major) + { + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? 
bk_ : rk_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + prhs_ptr); + } + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_rowmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_rowmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + prhs_ptr); + } + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &out_mat_.data[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"Error shrad type!"}; + } + } + else if (conv_type_ == col_major) + { + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + prhs_ptr); + } + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? 
bk_ : rk_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + if (need_im2col_) + { + if (out_mat_.n == 1) + { + _pack_colmajor_image_rhs(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + else + { + _pack_colmajor_image_rhs_batch(nr_, bn, bk, l * bk_, j * bn_, + const_cast<convMat_t *>(&in_mat_), &out_mat_, + const_cast<convParams_t *>(&in_param_), prhs_ptr); + } + } + else + { + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + prhs_ptr); + } + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &out_mat_.data[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"Error shrad type!"}; + } + } + else + { + throw std::runtime_error{"Error conv type!"}; + } + + delete[] plhs_ptr; + delete[] prhs_ptr; +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_sgemm_singlethread.h b/compute/ncnn/src/srcn/conv_sgemm_singlethread.h new file mode 100644 index 000000000..63f8b6e66 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_sgemm_singlethread.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_CONV_SGEMM_SINGLETHREAD_H__ +#define __NNFW_SRCN_CONV_SGEMM_SINGLETHREAD_H__ + +#include "ncnn/srcn/conv_type.h" +#include "common.h" + +namespace nnfw +{ +namespace srcn +{ + +class conv_sgemm_singlethread +{ +public: + conv_sgemm_singlethread(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, convType_t conv_type); + ~conv_sgemm_singlethread(); + + void run(); + +private: + void param_init(); + + const convMat_t in_mat_; + const convMat_t weights_mat_; + convMat_t out_mat_; + const convParams_t in_param_; + convType_t conv_type_; + + int m_; + int n_; + int k_; + + int bm_; + int bn_; + int bk_; + + int rm_; + int rn_; + int rk_; + + int nm_; + int nn_; + int nk_; + + int mr_; + int nr_; + + int need_im2col_; + + shardType_t shard_type_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_SGEMM_SINGLETHREAD_H__ diff --git a/compute/ncnn/src/srcn/conv_sparse.cc b/compute/ncnn/src/srcn/conv_sparse.cc new file mode 100644 index 000000000..10e2a2b93 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_sparse.cc @@ -0,0 +1,271 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include <stdexcept> + +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "conv_sparse.h" + +namespace nnfw +{ +namespace srcn +{ + +void conv_sparse::param_init() +{ +#ifdef NCNN + n_ = alignSize(out_mat_.h * out_mat_.w, 16 / sizeof(float)); +#else + n_ = out_mat_.w * out_mat_.h; +#endif + + bch_ = BCH; + nch_ = (out_mat_.c + bch_ - 1) / bch_; + + rch_ = out_mat_.c % bch_; + + bn_ = MIN(n_, L1_CACHE_SIZE / (sizeof(float) * 2)); + bn_ = MIN(bn_, (L2_CACHE_SIZE / 2 - bch_ * sizeof(weight_data_t)) / ((bch_ + 1) * sizeof(float)) / + num_threads_); + nn_ = (n_ + bn_ - 1) / bn_; + rn_ = n_ % bn_; + + if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || + in_param_.stride_h != 1 || in_param_.padding != 0) + { + need_im2col_ = 1; + } + else + { + need_im2col_ = 0; + } +} + +conv_sparse::conv_sparse(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, + const sparse_weight_t *weights, int num_threads, convType_t conv_type) + : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), weights_(weights), + num_threads_(num_threads), conv_type_(conv_type) +{ + param_init(); +} + +conv_sparse::~conv_sparse() {} + +void conv_sparse::compute_singlethread() +{ + if (need_im2col_) + { + for (int i = 0; i < nch_; i++) + { + const sparse_weight_t *weight_ptr = weights_ + i; + const int mxk = weight_ptr->mxk; + float prhs_ptr[bn_]; + + for (int j = 0; j < nn_; j++) + { + int k = -1; + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + weight_data_t *lhs_ptr = weight_ptr->wdata; + + for (int l = 0; l < mxk; l++) + { + if (k != lhs_ptr->k) + { + k = lhs_ptr->k; + _sparse_pack_rowmajor_image(bn, k, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + prhs_ptr); + } + + // Why n_ = 64 x 64 is too much slower on Tizen??? + _sparse_sgemm_kernel(bn, lhs_ptr->data, prhs_ptr, + &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); + + lhs_ptr++; + } + } + } + } + else + { + for (int i = 0; i < nch_; i++) + { + const sparse_weight_t *weight_ptr = weights_ + i; + const int mxk = weight_ptr->mxk; + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + weight_data_t *lhs_ptr = weight_ptr->wdata; + float *rhs_ptr = in_mat_.data + j * bn_; + + for (int l = 0; l < mxk; l++) + { + // Why n_ = 64 x 64 is too much slower on Tizen??? + _sparse_sgemm_kernel(bn, lhs_ptr->data, rhs_ptr + lhs_ptr->k * n_, + &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); + + lhs_ptr++; + } + } + } + } +} + +void conv_sparse::compute_multithreads() +{ + omp_set_num_threads(num_threads_); + + if (nch_ >= num_threads_ || nch_ >= nn_) + { + if (need_im2col_) + { +#pragma omp parallel for + for (int i = 0; i < nch_; i++) + { + const sparse_weight_t *weight_ptr = weights_ + i; + const int mxk = weight_ptr->mxk; + float prhs_ptr[bn_]; + + for (int j = 0; j < nn_; j++) + { + int k = -1; + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + weight_data_t *lhs_ptr = weight_ptr->wdata; + + for (int l = 0; l < mxk; l++) + { + if (k != lhs_ptr->k) + { + k = lhs_ptr->k; + _sparse_pack_rowmajor_image(bn, k, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + prhs_ptr); + } + + _sparse_sgemm_kernel(bn, lhs_ptr->data, prhs_ptr, + &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); + + lhs_ptr++; + } + } + } + } + else + { +#pragma omp parallel for + for (int i = 0; i < nch_; i++) + { + const sparse_weight_t *weight_ptr = weights_ + i; + const int mxk = weight_ptr->mxk; + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + weight_data_t *lhs_ptr = weight_ptr->wdata; + float *rhs_ptr = in_mat_.data + j * bn_; + + for (int l = 0; l < mxk; l++) + { + _sparse_sgemm_kernel(bn, lhs_ptr->data, rhs_ptr + lhs_ptr->k * n_, + &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); + + lhs_ptr++; + } + } + } + } + } + else + { + if (need_im2col_) + { + for (int i = 0; i < nch_; i++) + { + const sparse_weight_t *weight_ptr = weights_ + i; + const int mxk = weight_ptr->mxk; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int k = -1; + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + weight_data_t *lhs_ptr = weight_ptr->wdata; + float prhs_ptr[bn]; + + for (int l = 0; l < mxk; l++) + { + if (k != lhs_ptr->k) + { + k = lhs_ptr->k; + _sparse_pack_rowmajor_image(bn, k, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), + prhs_ptr); + } + + _sparse_sgemm_kernel(bn, lhs_ptr->data, prhs_ptr, + &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); + + lhs_ptr++; + } + } + } + } + else + { + for (int i = 0; i < nch_; i++) + { + const sparse_weight_t *weight_ptr = weights_ + i; + const int mxk = weight_ptr->mxk; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + weight_data_t *lhs_ptr = weight_ptr->wdata; + float *rhs_ptr = in_mat_.data + j * bn_; + + for (int l = 0; l < mxk; l++) + { + _sparse_sgemm_kernel(bn, lhs_ptr->data, rhs_ptr + lhs_ptr->k * n_, + &out_mat_.data[lhs_ptr->m * n_ + j * bn_]); + + lhs_ptr++; + } + } + } + } + } +} + +void conv_sparse::run() +{ + if (num_threads_ == 1) + compute_singlethread(); + else if (num_threads_ > 1) + compute_multithreads(); + else + throw std::runtime_error{"Invalid thread number."}; +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_sparse.h b/compute/ncnn/src/srcn/conv_sparse.h new file mode 100644 index 000000000..7ac358fd8 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_sparse.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __NNFW_SRCN_CONV_SPARSE_H__ +#define __NNFW_SRCN_CONV_SPARSE_H__ + +#include "ncnn/srcn/conv_type.h" +#include "common.h" + +namespace nnfw +{ +namespace srcn +{ + +#define BCH 128 + +typedef struct +{ + short m; + short k; + float data; +} weight_data_t; + +typedef struct +{ + int mxk; + weight_data_t *wdata; +} sparse_weight_t; + +class conv_sparse +{ +public: + conv_sparse(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, + const sparse_weight_t *weights, int num_threads, convType_t conv_type); + ~conv_sparse(); + + void run(); + +private: + void param_init(); + void compute_singlethread(); + void compute_multithreads(); + + const convMat_t in_mat_; + convMat_t out_mat_; + const convParams_t in_param_; + const sparse_weight_t *weights_; + int num_threads_; + convType_t conv_type_; + + uint32_t n_; + uint32_t bn_; + int rn_; + int nn_; + + int bch_; + int rch_; + int nch_; + + int need_im2col_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_SPARSE_H__ diff --git a/compute/ncnn/src/srcn/conv_winograd.cc b/compute/ncnn/src/srcn/conv_winograd.cc new file mode 100644 index 000000000..69649ea2a --- /dev/null +++ b/compute/ncnn/src/srcn/conv_winograd.cc @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "common.h" +#include "conv_winograd.h" + +namespace std +{ +template <typename Dtype> static inline Dtype max(Dtype a, Dtype b) +{ + if (a > b) + return a; + else + return b; +} +} + +namespace nnfw +{ +namespace srcn +{ + +void conv_winograd::param_init() +{ + if ((in_param_.kernel_w != in_param_.kernel_h) || (in_param_.stride_w != in_param_.stride_h) || + (in_param_.kernel_w != 3 && in_param_.kernel_w != 5) || (in_param_.stride_w != 1) || + (!winograd_weight_)) + { + error_ = 1; + return; + } + + int M, N; + const int w = in_mat_.w; + const int h = in_mat_.h; + const int outw = out_mat_.w; + const int outh = out_mat_.h; + const int pad_w = in_param_.pad_w; + const int pad_h = in_param_.pad_h; + + if (in_param_.kernel_w == 3) + { + M = winograd_para_3x3s1::M; + N = winograd_para_3x3s1::N; + } + else + { + M = winograd_para_5x5s1::M; + N = winograd_para_5x5s1::N; + } + + tile_h_in_ = tile_w_in_ = M; + tile_h_out_ = tile_h_in_ - N + 1; + tile_w_out_ = tile_w_in_ - N + 1; + ntiles_h_ = (std::max(h + pad_h - tile_h_in_ + 1, outh) + tile_h_out_ - 1) / tile_h_out_; + ntiles_w_ = (std::max(w + pad_w - tile_w_in_ + 1, outw) + tile_w_out_ - 1) / tile_w_out_; + + error_ = 0; +} + +conv_winograd::conv_winograd(const convMat_t &in_mat, convMat_t &out_mat, + const convParams_t &in_param, convType_t conv_type, + const float *winograd_weight, int num_threads, int inc_stride, + int outc_stride, int c_stride) + : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), conv_type_(conv_type), + winograd_weight_(winograd_weight), num_threads_(num_threads), inc_stride_(inc_stride), + outc_stride_(outc_stride), c_stride_(c_stride) + +{ + param_init(); +} + +conv_winograd::~conv_winograd() {} + +void conv_winograd::compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, + const int m, const int n, const int k, const float *lhs_data, + const float *rhs_data, float *res_data) +{ + class sgemm_singlethread sgemm(major_type, ltrans, rtrans, m, n, k, lhs_data, rhs_data, res_data, + num_threads_); + + sgemm.run(); +} + +void conv_winograd::winograd_input_im2col(float *col_buff) +{ + const int w = in_mat_.w; + const int h = in_mat_.h; + const float *data = in_mat_.data; + const int channels = in_mat_.c; + const int pad_w = in_param_.pad_w; + const int pad_h = in_param_.pad_h; + + if (conv_type_ == row_major) + { +#ifdef NCNN + const int n = alignSize(inc_stride_, 16 / sizeof(float)); +#else // NCNN + const int n = inc_stride_; +#endif // NCNN + for (int c = 0; c < channels; ++c) + { + for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) + { + for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) + { + for (int y = 0; y < tile_h_in_; ++y) + { + for (int x = 0; x < tile_w_in_; ++x) + { + int in_y = tile_h * tile_h_out_ + y - pad_h; + int in_x = tile_w * tile_w_out_ + x - pad_w; + + if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) + { + col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * + tile_w_in_ + + x] = 0; + } + else + { + col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * + tile_w_in_ + + x] = data[c * n + in_y * w + in_x]; + } + } + } + } + } + } + } + else if (conv_type_ == col_major) + { + for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) + { + for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) + { + for (int y = 0; y < tile_h_in_; ++y) + { + for (int x = 0; x < tile_w_in_; ++x) + { + for (int c = 0; c < channels; ++c) + { + int in_y = tile_h * tile_h_out_ + y - pad_h; + int in_x = tile_w * tile_w_out_ + x - pad_w; + + 
if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) + { + col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * + tile_w_in_ + + x] = 0; + } + else + { + col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_in_ + y) * + tile_w_in_ + + x] = data[c + (in_y * w + in_x) * channels]; + } + } + } + } + } + } + } +} + +void conv_winograd::winograd_output_col2im(const float *col_buff) +{ + int outh = out_mat_.h; + int outw = out_mat_.w; + float *data = out_mat_.data; + int channels = out_mat_.c; + + if (conv_type_ == row_major) + { +#ifdef NCNN + const int n = alignSize(outc_stride_, 16 / sizeof(float)); +#else // NCNN + const int n = outc_stride_; +#endif // NCNN + for (int c = 0; c < channels; ++c) + { + for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) + { + for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) + { + for (int y = 0; y < tile_h_out_; ++y) + { + for (int x = 0; x < tile_w_out_; ++x) + { + int out_y = tile_h * tile_h_out_ + y; + int out_x = tile_w * tile_w_out_ + x; + if (out_y < outh && out_x < outw) + { + data[c * n + out_y * outw + out_x] = + col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_out_ + y) * + tile_w_out_ + + x]; + } + } + } + } + } + } + } + else if (conv_type_ == col_major) + { + for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) + { + for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) + { + for (int y = 0; y < tile_h_out_; ++y) + { + for (int x = 0; x < tile_w_out_; ++x) + { + for (int c = 0; c < channels; ++c) + { + int out_y = tile_h * tile_h_out_ + y; + int out_x = tile_w * tile_w_out_ + x; + if (out_y < outh && out_x < outw) + { + data[c + (out_y * outw + out_x) * c_stride_] = + col_buff[(((c * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * tile_h_out_ + y) * + tile_w_out_ + + x]; + } + } + } + } + } + } + } +} + +void conv_winograd::compute_winograd() +{ + // const int w = in_mat_.w; + // const int h = in_mat_.h; + const int inch = in_mat_.c; + // const int outw = out_mat_.w; + // const int outh = out_mat_.h; + const int outch = out_mat_.c; + const int kernel_size = in_param_.kernel_w; + + int M, N; + const double *A; + const double *B; + + if (kernel_size == 3) + { + M = winograd_para_3x3s1::M; + N = winograd_para_3x3s1::N; + B = winograd_para_3x3s1::getB(); + A = winograd_para_3x3s1::getA(); + } + else + { + M = winograd_para_5x5s1::M; + N = winograd_para_5x5s1::N; + B = winograd_para_5x5s1::getB(); + A = winograd_para_5x5s1::getA(); + } + + /*Step 2: transfer image to winograd domain*/ + float *col_buff = + new float[std::max(outch, inch) * ntiles_h_ * ntiles_w_ * tile_h_in_ * tile_w_in_]; + + int temp1_n = inch * ntiles_h_ * ntiles_w_; + float *temp1_ = + new float[tile_h_in_ * tile_w_in_ * std::max(outch, inch) * ntiles_h_ * ntiles_w_]; + + float *winograd_b = new float[M * M * M * M]; + + if ((NULL == col_buff) || (NULL == temp1_) || (NULL == winograd_b)) + { + delete[] col_buff; + delete[] temp1_; + delete[] winograd_b; + return; + } + + winograd_input_im2col(col_buff); + + kronecker_product(winograd_b, B, B, M, M, M, M); + + compute_sgemm(rowMajor, trans, trans, tile_h_in_ * tile_w_in_, temp1_n, tile_h_in_ * tile_w_in_, + winograd_b, col_buff, temp1_); + + delete[] winograd_b; + + /*Step 3: convolution in winograd domain*/ + for (int j = 0; j < tile_h_in_ * tile_w_in_; ++j) + { + compute_sgemm(rowMajor, notrans, notrans, outch, ntiles_h_ * ntiles_w_, inch, + winograd_weight_ + j * c_stride_ * inch, + temp1_ + j * inch * ntiles_h_ * ntiles_w_, + col_buff + j * outch * ntiles_h_ * ntiles_w_); + } + + 
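+  // Descriptive note on the pipeline above (editorial summary of the existing code,
+  // not new behaviour): at this point col_buff holds, for each of the
+  // tile_h_in_ * tile_w_in_ transform elements, an outch x (ntiles_h_ * ntiles_w_)
+  // product block. Step 2 moved the im2col'd input into the Winograd domain through
+  // the Kronecker product of B with itself, and step 3 ran one small SGEMM per
+  // transform element against the pre-transformed weights. Step 4 below builds the
+  // inverse transform from A, applies it with a final SGEMM, and scatters the
+  // tile_h_out_ x tile_w_out_ output tiles back into out_mat_ via
+  // winograd_output_col2im().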
/*Step 4: transfer back to time domain*/ + float *winograd_a = new float[M * (M - N + 1) * M * (M - N + 1)]; + if (NULL == winograd_a) + { + delete[] col_buff; + delete[] temp1_; + return; + } + kronecker_product(winograd_a, A, A, M, M - N + 1, M, M - N + 1); + compute_sgemm(rowMajor, trans, notrans, outch * ntiles_h_ * ntiles_w_, tile_h_out_ * tile_w_out_, + tile_h_in_ * tile_w_in_, col_buff, winograd_a, temp1_); + delete[] winograd_a; + delete[] col_buff; + + winograd_output_col2im(temp1_); + + delete[] temp1_; +} + +void conv_winograd::run() +{ + if (error_) + return; + + compute_winograd(); +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_winograd.h b/compute/ncnn/src/srcn/conv_winograd.h new file mode 100644 index 000000000..76c2601f2 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_winograd.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_CONV_WINOGRAD_H__ +#define __NNFW_SRCN_CONV_WINOGRAD_H__ + +#include "ncnn/srcn/conv_type.h" +#include "winograd.h" +#include "sgemm_singlethread.h" + +namespace nnfw +{ +namespace srcn +{ + +class conv_winograd +{ +public: + conv_winograd(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, + convType_t conv_type, const float *winograd_weight, int num_threads, int inc_stride, + int outc_stride, int c_stride); + ~conv_winograd(); + + void run(); + +private: + void param_init(); + void compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, const int m, + const int n, const int k, const float *lhs_data, const float *rhs_data, + float *res_data); + void winograd_input_im2col(float *col_buff); + void winograd_output_col2im(const float *col_buff); + void compute_winograd(); + + const convMat_t in_mat_; + convMat_t out_mat_; + const convParams_t in_param_; + convType_t conv_type_; + const float *winograd_weight_; + const int num_threads_; + + int tile_w_in_; + int tile_h_in_; + int tile_w_out_; + int tile_h_out_; + int ntiles_w_; + int ntiles_h_; + + int inc_stride_; + int outc_stride_; + int c_stride_; + + int error_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_WINOGRAD_H__ diff --git a/compute/ncnn/src/srcn/conv_winograd_batch.cc b/compute/ncnn/src/srcn/conv_winograd_batch.cc new file mode 100644 index 000000000..cba45c648 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_winograd_batch.cc @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common.h" +#include "conv_winograd_batch.h" + +namespace std +{ +template <typename Dtype> static inline Dtype max(Dtype a, Dtype b) +{ + if (a > b) + return a; + else + return b; +} +} + +namespace nnfw +{ +namespace srcn +{ + +void conv_winograd_batch::param_init() +{ + if ((in_param_.kernel_w != in_param_.kernel_h) || (in_param_.stride_w != in_param_.stride_h) || + (in_param_.kernel_w != 3 && in_param_.kernel_w != 5) || (in_param_.stride_w != 1) || + (!winograd_weight_)) + { + error_ = 1; + return; + } + + int M, N; + const int w = in_mat_.w; + const int h = in_mat_.h; + const int outw = out_mat_.w; + const int outh = out_mat_.h; + const int pad_w = in_param_.pad_w; + const int pad_h = in_param_.pad_h; + + if (in_param_.kernel_w == 3) + { + if (w == 4) + { + M = winograd_para_3x3s1_2::M; + N = winograd_para_3x3s1_2::N; + } + else + { + M = winograd_para_3x3s1::M; + N = winograd_para_3x3s1::N; + } + } + else + { + M = winograd_para_5x5s1::M; + N = winograd_para_5x5s1::N; + } + + tile_h_in_ = tile_w_in_ = M; + tile_h_out_ = tile_h_in_ - N + 1; + tile_w_out_ = tile_w_in_ - N + 1; + ntiles_h_ = (std::max(h + pad_h - tile_h_in_ + 1, outh) + tile_h_out_ - 1) / tile_h_out_; + ntiles_w_ = (std::max(w + pad_w - tile_w_in_ + 1, outw) + tile_w_out_ - 1) / tile_w_out_; + + error_ = 0; +} + +conv_winograd_batch::conv_winograd_batch(const convMat_t &in_mat, convMat_t &out_mat, + const convParams_t &in_param, convType_t conv_type, + const float *winograd_weight, int num_threads) + : in_mat_(in_mat), out_mat_(out_mat), in_param_(in_param), conv_type_(conv_type), + winograd_weight_(winograd_weight), num_threads_(num_threads) +{ + param_init(); +} + +conv_winograd_batch::~conv_winograd_batch() {} + +void conv_winograd_batch::compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, + sgemmTrans_t rtrans, const int m, const int n, const int k, + const float *lhs_data, const float *rhs_data, + float *res_data) +{ + class sgemm_singlethread sgemm(major_type, ltrans, rtrans, m, n, k, lhs_data, rhs_data, res_data, + num_threads_); + + sgemm.run(); +} + +void conv_winograd_batch::winograd_input_im2col(float *col_buff) +{ + const int w = in_mat_.w; + const int h = in_mat_.h; + const float *data = in_mat_.data; + const int channels = in_mat_.c; + const int batch = in_mat_.n; + const int pad_w = in_param_.pad_w; + const int pad_h = in_param_.pad_h; + + // TODO: row_major + if (conv_type_ == col_major) + { + for (int n = 0; n < batch; n++) + { + for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) + { + for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) + { + for (int y = 0; y < tile_h_in_; ++y) + { + for (int x = 0; x < tile_w_in_; ++x) + { + for (int c = 0; c < channels; ++c) + { + int in_y = tile_h * tile_h_out_ + y - pad_h; + int in_x = tile_w * tile_w_out_ + x - pad_w; + + if (in_y < 0 || in_x < 0 || in_y >= h || in_x >= w) + { + col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * + tile_h_in_ + + y) * + tile_w_in_ + + x] = 0; + } + else + { + col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * + tile_h_in_ + + y) * + tile_w_in_ + + 
x] = data[((n * h + in_y) * w + in_x) * channels + c]; + } + } + } + } + } + } + } + } +} + +void conv_winograd_batch::winograd_output_col2im(const float *col_buff) +{ + int outh = out_mat_.h; + int outw = out_mat_.w; + float *data = out_mat_.data; + int channels = out_mat_.c; + int batch = out_mat_.n; + + // TODO: row_major + if (conv_type_ == col_major) + { + for (int n = 0; n < batch; n++) + { + for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) + { + for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) + { + for (int y = 0; y < tile_h_out_; ++y) + { + for (int x = 0; x < tile_w_out_; ++x) + { + for (int c = 0; c < channels; ++c) + { + int out_y = tile_h * tile_h_out_ + y; + int out_x = tile_w * tile_w_out_ + x; + if (out_y < outh && out_x < outw) + { + data[((n * outh + out_y) * outw + out_x) * channels + c] = + col_buff[((((c * batch + n) * ntiles_h_ + tile_h) * ntiles_w_ + tile_w) * + tile_h_out_ + + y) * + tile_w_out_ + + x]; + } + } + } + } + } + } + } + } +} + +void conv_winograd_batch::compute_winograd() +{ + const int w = in_mat_.w; + // const int h = in_mat_.h; + const int inch = in_mat_.c; + // const int outw = out_mat_.w; + // const int outh = out_mat_.h; + const int outch = out_mat_.c; + const int kernel_size = in_param_.kernel_w; + const int batch = in_mat_.n; + + int M, N; + const double *A; + const double *B; + + if (kernel_size == 3) + { + if (w == 4) + { + M = winograd_para_3x3s1_2::M; + N = winograd_para_3x3s1_2::N; + B = winograd_para_3x3s1_2::getB(); + A = winograd_para_3x3s1_2::getA(); + } + else + { + M = winograd_para_3x3s1::M; + N = winograd_para_3x3s1::N; + B = winograd_para_3x3s1::getB(); + A = winograd_para_3x3s1::getA(); + } + } + else + { + M = winograd_para_5x5s1::M; + N = winograd_para_5x5s1::N; + B = winograd_para_5x5s1::getB(); + A = winograd_para_5x5s1::getA(); + } + + /*Step 2: transfer image to winograd domain*/ + float *col_buff = + new float[std::max(outch, inch) * batch * ntiles_h_ * ntiles_w_ * tile_h_in_ * tile_w_in_]; + + int temp1_n = batch * inch * ntiles_h_ * ntiles_w_; + float *temp1_ = + new float[batch * tile_h_in_ * tile_w_in_ * std::max(outch, inch) * ntiles_h_ * ntiles_w_]; + + float *winograd_b = new float[M * M * M * M]; + + if ((NULL == col_buff) || (NULL == temp1_) || (NULL == winograd_b)) + { + delete[] col_buff; + delete[] temp1_; + delete[] winograd_b; + return; + } + + winograd_input_im2col(col_buff); + + kronecker_product(winograd_b, B, B, M, M, M, M); + + compute_sgemm(rowMajor, trans, trans, tile_h_in_ * tile_w_in_, temp1_n, tile_h_in_ * tile_w_in_, + winograd_b, col_buff, temp1_); + delete[] winograd_b; + + /*Step 3: convolution in winograd domain*/ + for (int j = 0; j < tile_h_in_ * tile_w_in_; ++j) + { + compute_sgemm(rowMajor, notrans, notrans, outch, batch * ntiles_h_ * ntiles_w_, inch, + winograd_weight_ + j * outch * inch, + temp1_ + j * batch * inch * ntiles_h_ * ntiles_w_, + col_buff + j * batch * outch * ntiles_h_ * ntiles_w_); + } + + /*Step 4: transfer back to time domain*/ + float *winograd_a = new float[M * (M - N + 1) * M * (M - N + 1)]; + if (NULL == winograd_a) + { + delete[] col_buff; + delete[] temp1_; + return; + } + + kronecker_product(winograd_a, A, A, M, M - N + 1, M, M - N + 1); + compute_sgemm(rowMajor, trans, notrans, batch * outch * ntiles_h_ * ntiles_w_, + tile_h_out_ * tile_w_out_, tile_h_in_ * tile_w_in_, col_buff, winograd_a, temp1_); + delete[] winograd_a; + delete[] col_buff; + + winograd_output_col2im(temp1_); + + delete[] temp1_; +} + +void conv_winograd_batch::run() +{ + if (error_) + return; 
+ + compute_winograd(); +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/conv_winograd_batch.h b/compute/ncnn/src/srcn/conv_winograd_batch.h new file mode 100644 index 000000000..a022d9c52 --- /dev/null +++ b/compute/ncnn/src/srcn/conv_winograd_batch.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_CONV_WINOGRAD_BATCH_H__ +#define __NNFW_SRCN_CONV_WINOGRAD_BATCH_H__ + +#include "ncnn/srcn/conv_type.h" +#include "winograd.h" +#include "sgemm_singlethread.h" + +namespace nnfw +{ +namespace srcn +{ + +class conv_winograd_batch +{ +public: + conv_winograd_batch(const convMat_t &in_mat, convMat_t &out_mat, const convParams_t &in_param, + convType_t conv_type, const float *winograd_weight, int num_threads); + ~conv_winograd_batch(); + + void run(); + +private: + void param_init(); + void compute_sgemm(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, const int m, + const int n, const int k, const float *lhs_data, const float *rhs_data, + float *res_data); + void winograd_input_im2col(float *col_buff); + void winograd_output_col2im(const float *col_buff); + void compute_winograd(); + + const convMat_t in_mat_; + convMat_t out_mat_; + const convParams_t in_param_; + convType_t conv_type_; + const float *winograd_weight_; + const int num_threads_; + + int tile_w_in_; + int tile_h_in_; + int tile_w_out_; + int tile_h_out_; + int ntiles_w_; + int ntiles_h_; + + int error_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_CONV_WINOGRAD_BATCH_H__ diff --git a/compute/ncnn/src/srcn/deconv_sgemm_multithreads.cc b/compute/ncnn/src/srcn/deconv_sgemm_multithreads.cc new file mode 100644 index 000000000..f3ccf13e5 --- /dev/null +++ b/compute/ncnn/src/srcn/deconv_sgemm_multithreads.cc @@ -0,0 +1,387 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "deconv_sgemm_multithreads.h" + +namespace nnfw +{ +namespace srcn +{ + +void deconv_sgemm_multithreads::param_init() +{ +#if __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 8; + nr_ = 12; + } + else if (conv_type_ == col_major) + { + + mr_ = 12; + nr_ = 8; + } +#else // __aarch64__ + if (conv_type_ == row_major) + { + mr_ = 6; + nr_ = 8; + } + else if (conv_type_ == col_major) + { + mr_ = 8; + nr_ = 6; + } +#endif // __aarch64__ + + int col = n_; + + if (m_ > n_) + { + shard_type_ = shardByRow; + col = m_; + } + else + { + shard_type_ = shardByCol; + } + + int th_base = divup(col, num_threads_); + + th_base = MIN(MAX(th_base, MIN_COL), MAX_COL); + + int k_div = (nr_ * sizeof_RhsScalar); + int k_sub = (mr_ * nr_ * sizeof_ResScalar); + + const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div * 2), MAX_K); + bk_ = MIN(k_cache, k_); + + if (shard_type_ == shardByCol) + { + int m_sub = (bk_ * nr_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + if (L3_CACHE_SIZE) + m_div = (sizeof_LhsScalar * bk_ * 2); + int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, m_); + + bn_ = MIN(th_base, n_); + if (L3_CACHE_SIZE) + { + int n_sub = (bk_ * bm_ * sizeof_RhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + int n_cache = divup((L3_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, bn_); + } + } + else + { + int n_sub = (bk_ * mr_ * sizeof_LhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + if (L3_CACHE_SIZE) + n_div = (sizeof_LhsScalar * bk_ * 2); + int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, n_); + + bm_ = MIN(th_base, m_); + if (L3_CACHE_SIZE) + { + int m_sub = (bk_ * bn_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * num_threads_); + int m_cache = divup((L3_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, bm_); + } + } + + nm_ = divup(m_, bm_); + nn_ = divup(n_, bn_); + nk_ = divup(k_, bk_); + + rm_ = m_ % bm_; + rn_ = n_ % bn_; + rk_ = k_ % bk_; +} + +deconv_sgemm_multithreads::deconv_sgemm_multithreads(const convMat_t &in_mat, + const convMat_t &weights_mat, + convMat_t &out_mat, + const convParams_t &in_param, int num_threads, + convType_t conv_type) + + : in_mat_(in_mat), weights_mat_(weights_mat), out_mat_(out_mat), in_param_(in_param), + conv_type_(conv_type), num_threads_(num_threads) +{ + m_ = in_param_.kernel_h * in_param_.kernel_w * out_mat_.c; +#ifdef NCNN + n_ = alignSize(in_mat_.h * in_mat_.w, 16 / sizeof(float)); +#else // NCNN + n_ = in_mat_.w * in_mat_.h; +#endif // NCNN + k_ = in_mat.c; + + param_init(); + + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + if (shard_type_ == shardByCol) + { + plhs_buffer_ = new float[lhs_stride * 1 * nm_]; + prhs_buffer_ = new float[rhs_stride * num_threads_]; + } + else + { + plhs_buffer_ = new float[lhs_stride * num_threads_]; + prhs_buffer_ = new float[rhs_stride * 1 * nn_]; + } + + pres_buffer_ = new float[bm_ * bn_ * num_threads_]; + + if (plhs_buffer_ == NULL || prhs_buffer_ == NULL || pres_buffer_ == NULL) + { + error_ = 1; + } + + if (in_param_.kernel_w != 1 || in_param_.kernel_h != 1 || in_param_.stride_w != 1 || + in_param_.stride_h != 1 || in_param_.padding != 0) + { + need_col2im_ = 1; + } + else + { + need_col2im_ = 0; + } + + omp_set_num_threads(num_threads_); + + error_ = 0; +} + 
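+// Scratch-buffer layout set up by the constructor above (descriptive comment):
+// when sharding by column, one packed LHS panel (lhs_stride floats) is kept per
+// row block and shared read-only by all threads, while each OpenMP thread packs
+// its own RHS panel; when sharding by row the roles are swapped. pres_buffer_
+// additionally holds one bm_ x bn_ result tile per thread, which the compute_*
+// methods unpack into out_mat_ with the col2im routines whenever kernel size,
+// stride or padding set need_col2im_.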
+deconv_sgemm_multithreads::~deconv_sgemm_multithreads() +{ + if (plhs_buffer_) + delete[] plhs_buffer_; + if (prhs_buffer_) + delete[] prhs_buffer_; + if (pres_buffer_) + delete[] pres_buffer_; +} + +void deconv_sgemm_multithreads::run() +{ + if (error_) + return; + + if (shard_type_ == shardByCol && conv_type_ == col_major) + { + compute_colmajor_colshard(); + } + else if (shard_type_ == shardByRow && conv_type_ == col_major) + { + compute_colmajor_rowshard(); + } + else if (shard_type_ == shardByCol && conv_type_ == row_major) + { + compute_rowmajor_colshard(); + } + else if (shard_type_ == shardByRow && conv_type_ == row_major) + { + compute_rowmajor_rowshard(); + } +} + +void deconv_sgemm_multithreads::compute_rowmajor_colshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + &plhs_buffer_[i * lhs_stride]); + } + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int thread_num = omp_get_thread_num(); + float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; + float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; + + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, pres_ptr, 0, bn, bk); + + if (need_col2im_) + _unpack_rowmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); + } + } + } +} + +void deconv_sgemm_multithreads::compute_rowmajor_rowshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &in_mat_.data[l * bk_ * n_ + j * bn_], + &prhs_buffer_[j * rhs_stride]); + } + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + int thread_num = omp_get_thread_num(); + float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; + + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &weights_mat_.data[l * bk_ * m_ + i * bm_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, + &prhs_buffer_[j * rhs_stride], pres_ptr, 0, bn, bk); + if (need_col2im_) + _unpack_rowmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); + } + } + } +} + +void deconv_sgemm_multithreads::compute_colmajor_colshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + &plhs_buffer_[i * lhs_stride]); + } + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + int thread_num = omp_get_thread_num(); + float *prhs_ptr = &prhs_buffer_[rhs_stride * thread_num]; + float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; + + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, &plhs_buffer_[i * lhs_stride], + prhs_ptr, pres_ptr, 0, bm, bk); + + // Need to add lock? + if (need_col2im_) + _unpack_colmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); + } + } + } +} + +void deconv_sgemm_multithreads::compute_colmajor_rowshard() +{ + int lhs_stride = (bm_ + mr_ - 1) / mr_ * mr_ * bk_; + int rhs_stride = (bn_ + nr_ - 1) / nr_ * nr_ * bk_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + +#pragma omp parallel for + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &in_mat_.data[j * bn_ * k_ + l * bk_], + &prhs_buffer_[j * rhs_stride]); + } + +#pragma omp parallel for + for (int i = 0; i < nm_; i++) + { + int thread_num = omp_get_thread_num(); + float *plhs_ptr = &plhs_buffer_[lhs_stride * thread_num]; + float *pres_ptr = &pres_buffer_[bm_ * bn_ * thread_num]; + + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &weights_mat_.data[i * bm_ * k_ + l * bk_], + plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, + &prhs_buffer_[j * rhs_stride], pres_ptr, 0, bm, bk); + + if (need_col2im_) + _unpack_colmajor_image_res(bm, bn, i * bm_, j * bn_, const_cast<convMat_t *>(&in_mat_), + &out_mat_, const_cast<convParams_t *>(&in_param_), pres_ptr); + } + } + } +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/deconv_sgemm_multithreads.h b/compute/ncnn/src/srcn/deconv_sgemm_multithreads.h new file mode 100644 index 000000000..762f20380 --- /dev/null +++ b/compute/ncnn/src/srcn/deconv_sgemm_multithreads.h @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_DECONV_SGEMM_MULTITHREADS_H__ +#define __NNFW_SRCN_DECONV_SGEMM_MULTITHREADS_H__ + +#include "ncnn/srcn/conv_type.h" +#include "common.h" + +namespace nnfw +{ +namespace srcn +{ + +class deconv_sgemm_multithreads +{ +public: + deconv_sgemm_multithreads(const convMat_t &in_mat, const convMat_t &weights_mat, + convMat_t &out_mat, const convParams_t &in_param, int num_threads, + convType_t conv_type); + ~deconv_sgemm_multithreads(); + + void run(); + +private: + void param_init(); + + void compute_rowmajor_colshard(); + void compute_rowmajor_rowshard(); + void compute_colmajor_colshard(); + void compute_colmajor_rowshard(); + + const convMat_t in_mat_; + const convMat_t weights_mat_; + convMat_t out_mat_; + const convParams_t in_param_; + convType_t conv_type_; + const int num_threads_; + + int m_; + int n_; + int k_; + + int bm_; + int bn_; + int bk_; + + int rm_; + int rn_; + int rk_; + + int nm_; + int nn_; + int nk_; + + int mr_; + int nr_; + + int need_col2im_; + shardType_t shard_type_; + + float *prhs_buffer_; + float *plhs_buffer_; + float *pres_buffer_; + + int error_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_DECONV_SGEMM_MULTITHREADS_H__ diff --git a/compute/ncnn/src/srcn/depthwise_conv.cc b/compute/ncnn/src/srcn/depthwise_conv.cc new file mode 100644 index 000000000..cd092d5ac --- /dev/null +++ b/compute/ncnn/src/srcn/depthwise_conv.cc @@ -0,0 +1,2684 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include <arm_neon.h> +#include <stdlib.h> +#include <string.h> + +#include "common.h" +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +static void depthwise_conv3x3S1_nopad(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? 
bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + float *in_ptr3 = inbuf + 3 * w; + + float *out_ptr0 = outbuf + 0 * outw; + float *out_ptr1 = outbuf + 1 * outw; + + int i; + for (i = 0; i + 1 < outh; i += 2) + { + int nn = (outw >> 2) - 1; + int remain = outw & 0x03; + + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr3], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr3], %[in_ptr3], #16\n" + + "vmla.f32 q12, q0, %e[weight678][0]\n" + "vmla.f32 q13, q2, %e[weight678][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + + "bne 1b\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [in_ptr3] "+r"(in_ptr3), + + [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + float32x4_t input3 = vld1q_f32(in_ptr3); + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + float32x4_t out1 = vmulq_f32(input1, weight012); + out1 = vmlaq_f32(out1, input2, weight345); + out1 = vmlaq_f32(out1, input3, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = 
vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + in_ptr3++; + out_ptr0++; + out_ptr1++; + } + + in_ptr0 += w + 2; + in_ptr1 += w + 2; + in_ptr2 += w + 2; + in_ptr3 += w + 2; + + out_ptr0 += outw; + out_ptr1 += outw; + } + + for (; i < outh; i++) + { + int nn = outw >> 2; + int remain = outw & 0x03; + + if (nn > 0) + { + __asm __volatile("1:\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q14, q0, %e[weight012][0]\n" + "vmla.f32 q14, q2, %e[weight012][1]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vmla.f32 q14, q0, %e[weight345][0]\n" + "vmla.f32 q14, q2, %e[weight345][1]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q14, q0, %e[weight678][0]\n" + "vmla.f32 q14, q2, %e[weight678][1]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + + "bne 1b\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + out_ptr0++; + } + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // !__aarch64__ +} + +static void depthwise_conv3x3S1_padding(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? 
bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + float *in_ptr3 = inbuf + 3 * w; + + float *out_ptr0 = outbuf + 0 * outw; + float *out_ptr1 = outbuf + 1 * outw; + + int i; + for (i = 0; i + 1 < outh; i += 2) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + if (i == 0) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight345][0]\n" + "vmul.f32 q11, q0, %e[weight345][1]\n" + "vmul.f32 q12, q2, %e[weight012][0]\n" + "vmul.f32 q13, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vmla.f32 q10, q2, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q12, q2, %e[weight345][0]\n" + "vmla.f32 q13, q0, %e[weight345][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr2], %[in_ptr2], #12\n" + + "vmla.f32 q12, q2, %e[weight678][0]\n" + "vmla.f32 q13, q0, %e[weight678][1]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight345][0]\n" + "vmul.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q12, q0, %e[weight678][0]\n" + "vmla.f32 q13, q2, %e[weight678][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, 
q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), + [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight345); + out0 = vmlaq_f32(out0, input1, weight678); + + float32x4_t out1 = vmulq_f32(input0, weight012); + out1 = vmlaq_f32(out1, input1, weight345); + out1 = vmlaq_f32(out1, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + out_ptr0++; + out_ptr1++; + } + + in_ptr0 += 1; + in_ptr1 += 1; + in_ptr2 += 1; + in_ptr3 += w; + } + else if (i == outh - 2) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q2, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmul.f32 q12, q2, %e[weight012][0]\n" + "vmul.f32 q13, q0, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr2], %[in_ptr2], #12\n" + + "vmla.f32 q10, q2, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q12, q2, %e[weight345][0]\n" + "vmla.f32 q13, q0, %e[weight345][1]\n" + + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, 
%f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), + [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + float32x4_t out1 = vmulq_f32(input1, weight012); + out1 = vmlaq_f32(out1, input2, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + out_ptr0++; + out_ptr1++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q2, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmul.f32 q12, q2, %e[weight012][0]\n" + "vmul.f32 q13, q0, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + 
"vext.32 q3, q0, q1, #1\n" + "add %[in_ptr2], %[in_ptr2], #12\n" + + "vmla.f32 q10, q2, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q12, q2, %e[weight345][0]\n" + "vmla.f32 q13, q0, %e[weight345][1]\n" + + "pld [%[in_ptr3], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr3], %[in_ptr3], #12\n" + + "vmla.f32 q15, q2, %e[weight678][0]\n" + "vmla.f32 q15, q0, %e[weight678][1]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vand q15, %q[qbias0], %q[qbias0]\n" + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + "vmul.f32 q12, q0, %e[weight012][0]\n" + "vmul.f32 q13, q2, %e[weight012][1]\n" + + "pld [%[in_ptr2], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vmla.f32 q15, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr2], %[in_ptr2], #16\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q2, %e[weight678][1]\n" + "vmla.f32 q12, q0, %e[weight345][0]\n" + "vmla.f32 q13, q2, %e[weight345][1]\n" + + "pld [%[in_ptr3], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr3]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vmla.f32 q15, q3, %f[weight345][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr3], %[in_ptr3], #16\n" + + "vmla.f32 q15, q0, %e[weight678][0]\n" + "vmla.f32 q15, q2, %e[weight678][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q15, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q15, q15, q12\n" + "vadd.f32 q14, q14, q11\n" + "vadd.f32 q15, q15, q13\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "vst1.f32 {d30-d31}, [%[out_ptr1]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [in_ptr3] "+r"(in_ptr3), + + [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. 
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + float32x4_t input3 = vld1q_f32(in_ptr3); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + input3 = vsetq_lane_f32(0.0f, input3, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + float32x4_t out1 = vmulq_f32(input1, weight012); + out1 = vmlaq_f32(out1, input2, weight345); + out1 = vmlaq_f32(out1, input3, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + out1 = vsetq_lane_f32(bias0, out1, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + float32x2_t out11 = vadd_f32(vget_low_f32(out1), vget_high_f32(out1)); + + float32x2_t out01 = vpadd_f32(out00, out11); + + *out_ptr0 = vget_lane_f32(out01, 0); + *out_ptr1 = vget_lane_f32(out01, 1); + + in_ptr0++; + in_ptr1++; + in_ptr2++; + in_ptr3++; + out_ptr0++; + out_ptr1++; + } + in_ptr0 += w + 1; + in_ptr1 += w + 1; + in_ptr2 += w + 1; + in_ptr3 += w + 1; + } + + out_ptr0 += outw; + out_ptr1 += outw; + } + + for (; i < outh; i++) + { + // TODO:if i == 0, pad_top comes here. + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + if (nn > 0) + { + __asm __volatile("vmov.i32 q8, #0\n" + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr0], %[in_ptr0], #12\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q2, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q8, q0, #3\n" + "vext.32 q3, q0, q1, #1\n" + "add %[in_ptr1], %[in_ptr1], #12\n" + + "vmla.f32 q10, q2, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "1:\n" + "add %[in_ptr0], %[in_ptr0], #16\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q2, %e[weight012][1]\n" + + "pld [%[in_ptr1], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + "add %[in_ptr1], %[in_ptr1], #16\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q2, %e[weight345][1]\n" + + "pld [%[in_ptr0], #192]\n" + "vld1.f32 {d0-d2}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q2, q0, q1, #1\n" + "vext.32 q3, q0, q1, #2\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: when nn == 0, pad_left comes here. 
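+      // As the TODO above notes, the implicit left zero-pad is only synthesised in the
+      // assembly prologue (vext.32 against the zeroed q8 register); when nn == 0 that
+      // prologue is skipped, so very small output widths would also need the left padding
+      // applied in this scalar tail.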
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0++; + in_ptr1++; + out_ptr0++; + out_ptr1++; + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_nopad(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + + const int tailstep = w - 2 * outw + w; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = outw >> 2; + int remain = outw & 0x03; + + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + + for (; remain > 0; remain--) + { + float32x4_t input0 = 
vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += tailstep; + in_ptr1 += tailstep; + in_ptr2 += tailstep; + } + } + +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_padding00(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); 
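+          // Bottom output row: the third kernel row falls in the implicit bottom padding, so
+          // only weight012/weight345 are accumulated.  Each weight vector holds just three
+          // valid taps, so lane 3 of the accumulator is overwritten with the bias before the
+          // pairwise horizontal add yields bias plus the six products.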
+ + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // !__aarch64__ +} + +static void depthwise_conv3x3S2_padding01(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? 
bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. 
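+          // In this variant the implicit left zero-pad is built in the assembly prologue: a
+          // zeroed q register is combined with vext.32 to form column -1, and the input
+          // pointer is then advanced by 28 bytes (7 floats) so the steady-state loop sees the
+          // correctly shifted stride-2 windows.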
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr2], %[in_ptr2], #28\n" + + "vmla.f32 q10, q3, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q14, q1, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. 
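+          // The stride-2 kernels use vld2.f32 to deinterleave each input row: q0 receives the
+          // even columns, q1 the odd columns, and one extra lane load plus vext.32 supplies
+          // the third tap, so the assembly above produces four outputs per iteration.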
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } + +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_padding10(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + // TODO: i == 0 && i == outh -1 + if (i == 0) + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight345][0]\n" + "vmul.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) 
+ { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight345); + out0 = vmlaq_f32(out0, input1, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + + in_ptr2 += w; + } + else if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + : [in_ptr0] "+r"(in_ptr0), 
[in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } + +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv3x3S2_padding11(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convMat_t &bias) +{ +#if !__aarch64__ + int w = in_mat.w; + int h = in_mat.h; + int outw = out_mat.w; + int outh = out_mat.h; + int channels = in_mat.c; + +#pragma omp parallel for + for (int c = 0; c < channels; c++) + { + const float *filter = kernel.data + c * 9; +#ifdef NCNN + float *inbuf = in_mat.data + c * alignSize(w * h, 16 / sizeof(float)); + float *outbuf = out_mat.data + c * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *inbuf = in_mat.data + c * w * h; + float *outbuf = out_mat.data + c * outw * outh; +#endif // NCNN + float bias0 = bias.data ? 
bias.data[c] : 0.0f; + + register float32x4_t weight012 asm("q4") = vld1q_f32(filter); + register float32x4_t weight345 asm("q5") = vld1q_f32(filter + 3); + register float32x4_t weight678 asm("q6") = vld1q_f32(filter + 6); + register float32x4_t qbias0 asm("q7") = vdupq_n_f32(bias0); + + float *in_ptr0 = inbuf + 0 * w; + float *in_ptr1 = inbuf + 1 * w; + float *in_ptr2 = inbuf + 2 * w; + + float *out_ptr0 = outbuf + 0 * outw; + + int i; + for (i = 0; i < outh; i++) + { + int nn = (outw >> 2) - 1; + int remain = (outw & 0x03) + 4; + + // TODO: i == 0 && i == outh - 1 + if (i == 0) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight345][0]\n" + "vmul.f32 q11, q0, %e[weight345][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q14, q1, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight345][0]\n" + "vmul.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. 
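+          // Top output row: the first kernel row lies in the implicit top padding, so only
+          // weight345/weight678 contribute here; the i == outh - 1 branch below mirrors this
+          // with weight012/weight345.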
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight345); + out0 = vmlaq_f32(out0, input1, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + + in_ptr2 += w; + } + else if (i == outh - 1) + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. 
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + out_ptr0++; + } + } + else + { + if (nn > 0) + { + __asm __volatile("vmov.i32 q2, #0\n" + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr0], %[in_ptr0], #28\n" + + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q3, %e[weight012][0]\n" + "vmul.f32 q11, q0, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]\n" + "vmla.f32 q14, q1, %f[weight012][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr1], %[in_ptr1], #28\n" + + "vmla.f32 q10, q3, %e[weight345][0]\n" + "vmla.f32 q11, q0, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]\n" + "vmla.f32 q14, q1, %f[weight345][0]\n" + "vext.32 q3, q2, q0, #3\n" + "add %[in_ptr2], %[in_ptr2], #28\n" + + "vmla.f32 q10, q3, %e[weight678][0]\n" + "vmla.f32 q11, q0, %e[weight678][1]\n" + "vmla.f32 q14, q1, %f[weight678][0]\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "beq 2f\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vext.32 q3, q0, q2, #1\n" + + "1:\n" + "vand q14, %q[qbias0], %q[qbias0]\n" + "vmul.f32 q10, q0, %e[weight012][0]\n" + "vmul.f32 q11, q1, %e[weight012][1]\n" + + "pld [%[in_ptr1], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr1]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr1]]\n" + "vmla.f32 q14, q3, %f[weight012][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight345][0]\n" + "vmla.f32 q11, q1, %e[weight345][1]\n" + + "pld [%[in_ptr2], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr2]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr2]]\n" + "vmla.f32 q14, q3, %f[weight345][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vmla.f32 q10, q0, %e[weight678][0]\n" + "vmla.f32 q11, q1, %e[weight678][1]\n" + + "pld [%[in_ptr0], #256]\n" + "vld2.f32 {d0-d3}, [%[in_ptr0]]!\n" + "vld1.f32 {d4[0]}, [%[in_ptr0]]\n" + "vmla.f32 q14, q3, %f[weight678][0]\n" + "vext.32 q3, q0, q2, #1\n" + + "vadd.f32 q14, q14, q10\n" + "vadd.f32 q14, q14, q11\n" + + "subs %[nn], %[nn], #1\n" + "vst1.f32 {d28-d29}, [%[out_ptr0]]!\n" + "bne 1b\n" + "sub %[in_ptr0], %[in_ptr0], #32\n" + "2:\n" + : [in_ptr0] "+r"(in_ptr0), [in_ptr1] "+r"(in_ptr1), + [in_ptr2] "+r"(in_ptr2), [out_ptr0] "+r"(out_ptr0), [nn] "+r"(nn) + : [weight012] "w"(weight012), [weight345] "w"(weight345), + [weight678] "w"(weight678), [qbias0] "w"(qbias0) + : "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", + "q15", "cc", "memory"); + } + for (; remain > 0; remain--) + { + // TODO: if nn == 0, pad_left comes here. 
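+          // Middle output rows accumulate all three kernel rows; as in the other stride-2
+          // tails, each row pointer advances by two input columns per output.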
+ float32x4_t input0 = vld1q_f32(in_ptr0); + float32x4_t input1 = vld1q_f32(in_ptr1); + float32x4_t input2 = vld1q_f32(in_ptr2); + + if (remain == 1) + { + input0 = vsetq_lane_f32(0.0f, input0, 2); + input1 = vsetq_lane_f32(0.0f, input1, 2); + input2 = vsetq_lane_f32(0.0f, input2, 2); + } + + float32x4_t out0 = vmulq_f32(input0, weight012); + out0 = vmlaq_f32(out0, input1, weight345); + out0 = vmlaq_f32(out0, input2, weight678); + + out0 = vsetq_lane_f32(bias0, out0, 3); + + float32x2_t out00 = vadd_f32(vget_low_f32(out0), vget_high_f32(out0)); + + float32x2_t out01 = vpadd_f32(out00, out00); + + *out_ptr0 = vget_lane_f32(out01, 0); + + in_ptr0 += 2; + in_ptr1 += 2; + in_ptr2 += 2; + out_ptr0++; + } + + in_ptr0 += w; + in_ptr1 += w; + in_ptr2 += w; + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)bias; +#endif // __aarch64__ +} + +static void depthwise_conv_colmajor(const convMat_t &in_mat, convMat_t &out_mat, + const convMat_t &kernel, const convParams_t &in_param) +{ +#if __aarch64__ + const int w = in_mat.w; + const int h = in_mat.h; + const int outw = out_mat.w; + const int outh = out_mat.h; + const int channels = out_mat.c; + const int stridew = in_param.stride_w; + const int strideh = in_param.stride_h; + const int padding = in_param.padding; + const int padw = in_param.pad_w; + const int padh = in_param.pad_h; + +#pragma omp parallel for + for (int oh = 0; oh < outh; oh++) + { + const float *input_data0 = in_mat.data + (oh * strideh - padh) * w * channels; + + memset(out_mat.data + oh * outw * channels, 0x00, outw * channels * sizeof(float)); + + for (int kh = 0; kh < in_param.kernel_h; kh++) + { + for (int kw = 0; kw < in_param.kernel_w; kw++) + { + const float *kernel_data = kernel.data + (kh * in_param.kernel_w + kw) * channels; + const float *input_data1 = input_data0 + (kh * w + kw) * channels; + + if (padding && ((oh * strideh + kh < padh) || (oh * strideh + kh >= padh + h))) + { + continue; + } + + int ow = 0; + for (; ow + 3 < outw; /*ow += 4*/) + { + if (((ow + 3) * stridew + kw < padw) || (ow * stridew + kw >= padw + w)) + { + ow += 4; + continue; + } + else if ((ow + 3) * stridew + kw >= padw + w) + { + break; + } + else if (ow * stridew + kw < padw) + { + int delta = (padw - kw) / stridew - ow; + delta += (padw - kw) % stridew ? 
1 : 0; + ow += delta; + continue; + } + + int nn = channels >> 2; + int remain = channels & 0x03; + + const float *input_r0 = input_data1 + (ow * stridew - padw) * channels; + + const float *input_r1 = input_r0 + stridew * channels; + const float *input_r2 = input_r1 + stridew * channels; + const float *input_r3 = input_r2 + stridew * channels; + const float *weights_data = kernel_data; + float *output_r0 = out_mat.data + (oh * outw + ow) * channels; + float *output_r1 = output_r0 + channels; + float *output_r2 = output_r1 + channels; + float *output_r3 = output_r2 + channels; + + if (nn > 0) + { + int _n = (nn + 1) >> 1; + int oddn = nn & 1; + + asm volatile("subs %[_n], %[_n], #1\n" + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + "ld1 {v7.4s}, [%[input_r2]], #16\n" + "ld1 {v8.4s}, [%[input_r3]], #16\n" + "beq 1f\n" + + "0:\n" + "ld1 {v24.4s, v25.4s}, [%[output_r0]]\n" + "ld1 {v26.4s, v27.4s}, [%[output_r1]]\n" + "ld1 {v28.4s, v29.4s}, [%[output_r2]]\n" + "ld1 {v30.4s, v31.4s}, [%[output_r3]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + "ld1 {v12.4s}, [%[input_r2]], #16\n" + "ld1 {v13.4s}, [%[input_r3]], #16\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "fmla v28.4s, v4.4s, v7.4s\n" + "fmla v30.4s, v4.4s, v8.4s\n" + + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + "ld1 {v7.4s}, [%[input_r2]], #16\n" + "ld1 {v8.4s}, [%[input_r3]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "fmla v29.4s, v9.4s, v12.4s\n" + "fmla v31.4s, v9.4s, v13.4s\n" + + "st1 {v24.4s, v25.4s}, [%[output_r0]], #32\n" + "st1 {v26.4s, v27.4s}, [%[output_r1]], #32\n" + "st1 {v28.4s, v29.4s}, [%[output_r2]], #32\n" + "st1 {v30.4s, v31.4s}, [%[output_r3]], #32\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v24.4s}, [%[output_r0]]\n" + "ld1 {v26.4s}, [%[output_r1]]\n" + "ld1 {v28.4s}, [%[output_r2]]\n" + "ld1 {v30.4s}, [%[output_r3]]\n" + "cmp %[oddn], #1\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "fmla v28.4s, v4.4s, v7.4s\n" + "fmla v30.4s, v4.4s, v8.4s\n" + + "st1 {v24.4s}, [%[output_r0]], #16\n" + "st1 {v26.4s}, [%[output_r1]], #16\n" + "st1 {v28.4s}, [%[output_r2]], #16\n" + "st1 {v30.4s}, [%[output_r3]], #16\n" + + "beq 2f\n" + "ld1 {v25.4s}, [%[output_r0]]\n" + "ld1 {v27.4s}, [%[output_r1]]\n" + "ld1 {v29.4s}, [%[output_r2]]\n" + "ld1 {v31.4s}, [%[output_r3]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + "ld1 {v12.4s}, [%[input_r2]], #16\n" + "ld1 {v13.4s}, [%[input_r3]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "fmla v29.4s, v9.4s, v12.4s\n" + "fmla v31.4s, v9.4s, v13.4s\n" + + "st1 {v25.4s}, [%[output_r0]], #16\n" + "st1 {v27.4s}, [%[output_r1]], #16\n" + "st1 {v29.4s}, [%[output_r2]], #16\n" + "st1 {v31.4s}, [%[output_r3]], #16\n" + "2:\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [input_r2] "+r"(input_r2), + [input_r3] "+r"(input_r3), [output_r0] "+r"(output_r0), + [output_r1] "+r"(output_r1), [output_r2] "+r"(output_r2), + [output_r3] "+r"(output_r3), [_n] "+r"(_n) + : [oddn] "r"(oddn) + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v24", "v25", "v26", "v27", "v28", "v29", "v30", 
"v31"); + } + if (remain >= 2) + { + asm volatile( + "ld1 {v24.2s}, [%[output_r0]]\n" + "ld1 {v26.2s}, [%[output_r1]]\n" + "ld1 {v28.2s}, [%[output_r2]]\n" + "ld1 {v30.2s}, [%[output_r3]]\n" + "ld1 {v4.2s}, [%[weights_data]], #8\n" + "ld1 {v5.2s}, [%[input_r0]], #8\n" + + "ld1 {v6.2s}, [%[input_r1]], #8\n" + "ld1 {v7.2s}, [%[input_r2]], #8\n" + "ld1 {v8.2s}, [%[input_r3]], #8\n" + + "fmla v24.2s, v4.2s, v5.2s\n" + "fmla v26.2s, v4.2s, v6.2s\n" + + "fmla v28.2s, v4.2s, v7.2s\n" + "fmla v30.2s, v4.2s, v8.2s\n" + + "st1 {v24.2s}, [%[output_r0]], #8\n" + "st1 {v26.2s}, [%[output_r1]], #8\n" + "st1 {v28.2s}, [%[output_r2]], #8\n" + "st1 {v30.2s}, [%[output_r3]], #8\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [input_r2] "+r"(input_r2), [input_r3] "+r"(input_r3), + [output_r0] "+r"(output_r0), [output_r1] "+r"(output_r1), + [output_r2] "+r"(output_r2), [output_r3] "+r"(output_r3) + : + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v24", "v26", "v28", "v30"); + remain -= 2; + } + + if (remain > 0) + { + *output_r0++ += (*weights_data) * (*input_r0++); + *output_r1++ += (*weights_data++) * (*input_r1++); + *output_r2++ += (*weights_data) * (*input_r2++); + *output_r3++ += (*weights_data++) * (*input_r3++); + } + ow += 4; + } + + for (; ow + 1 < outw; /*ow += 2*/) + { + if (padding) + { + if (((ow + 1) * stridew + kw < padw) || (ow * stridew + kw >= padw + w)) + { + ow += 2; + continue; + } + else if ((ow + 1) * stridew + kw >= padw + w) + { + break; + } + else if (ow * stridew + kw < padw) + { + ow++; + continue; + } + } + + int nn = channels >> 2; + int remain = channels & 0x03; + + const float *input_r0 = input_data1 + (ow * stridew - padw) * channels; + + const float *input_r1 = input_r0 + stridew * channels; + const float *weights_data = kernel_data; + float *output_r0 = out_mat.data + (oh * outw + ow) * channels; + float *output_r1 = output_r0 + channels; + + if (nn > 0) + { + int _n = (nn + 1) >> 1; + int oddn = nn & 1; + + asm volatile("subs %[_n], %[_n], #1\n" + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + "beq 1f\n" + + "0:\n" + "ld1 {v24.4s, v25.4s}, [%[output_r0]]\n" + "ld1 {v26.4s, v27.4s}, [%[output_r1]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_r0]], #16\n" + "ld1 {v6.4s}, [%[input_r1]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "st1 {v24.4s, v25.4s}, [%[output_r0]], #32\n" + "st1 {v26.4s, v27.4s}, [%[output_r1]], #32\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v24.4s}, [%[output_r0]]\n" + "ld1 {v26.4s}, [%[output_r1]]\n" + "cmp %[oddn], #1\n" + + "fmla v24.4s, v4.4s, v5.4s\n" + "fmla v26.4s, v4.4s, v6.4s\n" + + "st1 {v24.4s}, [%[output_r0]], #16\n" + "st1 {v26.4s}, [%[output_r1]], #16\n" + + "beq 2f\n" + "ld1 {v25.4s}, [%[output_r0]]\n" + "ld1 {v27.4s}, [%[output_r1]]\n" + + "ld1 {v9.4s}, [%[weights_data]], #16\n" + "ld1 {v10.4s}, [%[input_r0]], #16\n" + "ld1 {v11.4s}, [%[input_r1]], #16\n" + + "fmla v25.4s, v9.4s, v10.4s\n" + "fmla v27.4s, v9.4s, v11.4s\n" + + "st1 {v25.4s}, [%[output_r0]], #16\n" + "st1 {v27.4s}, [%[output_r1]], #16\n" + "2:\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [output_r0] "+r"(output_r0), + [output_r1] 
"+r"(output_r1), [_n] "+r"(_n) + : [oddn] "r"(oddn) + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + if (remain >= 2) + { + asm volatile("ld1 {v24.2s}, [%[output_r0]]\n" + "ld1 {v26.2s}, [%[output_r1]]\n" + "ld1 {v4.2s}, [%[weights_data]], #8\n" + "ld1 {v5.2s}, [%[input_r0]], #8\n" + + "ld1 {v6.2s}, [%[input_r1]], #8\n" + + "fmla v24.2s, v4.2s, v5.2s\n" + "fmla v26.2s, v4.2s, v6.2s\n" + + "st1 {v24.2s}, [%[output_r0]], #8\n" + "st1 {v26.2s}, [%[output_r1]], #8\n" + : [weights_data] "+r"(weights_data), [input_r0] "+r"(input_r0), + [input_r1] "+r"(input_r1), [output_r0] "+r"(output_r0), + [output_r1] "+r"(output_r1) + : + : "cc", "memory", "v4", "v5", "v6", "v7", "v8", "v24", "v26", "v28", + "v30"); + remain -= 2; + } + + if (remain > 0) + { + *output_r0++ += (*weights_data) * (*input_r0++); + *output_r1++ += (*weights_data++) * (*input_r1++); + } + ow += 2; + } + + for (; ow < outw; ow++) + { + const float *input_data = input_data1 + (ow * stridew - padw) * channels; + + if (padding && ((ow * stridew + kw < padw) || (ow * strideh + kw >= padw + w))) + { + continue; + } + + int nn = channels >> 2; + int remain = channels & 0x03; + + const float *weights_data = kernel_data; + float *output_data = out_mat.data + (oh * outw + ow) * channels; + + if (nn > 0) + { + int _n = (nn + 1) >> 1; + int oddn = nn & 1; + + asm volatile("subs %[_n], %[_n], #1\n" + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_data]], #16\n" + "beq 1f\n" + + "0:\n" + "ld1 {v30.4s, v31.4s}, [%[output_data]]\n" + "ld1 {v6.4s}, [%[weights_data]], #16\n" + "ld1 {v7.4s}, [%[input_data]], #16\n" + "fmla v30.4s, v4.4s, v5.4s\n" + + "ld1 {v4.4s}, [%[weights_data]], #16\n" + "ld1 {v5.4s}, [%[input_data]], #16\n" + "fmla v31.4s, v6.4s, v7.4s\n" + + "st1 {v30.4s, v31.4s}, [%[output_data]], #32\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v30.4s}, [%[output_data]]\n" + "cmp %[oddn], #1\n" + "fmla v30.4s, v4.4s, v5.4s\n" + "st1 {v30.4s}, [%[output_data]], #16\n" + "beq 2f\n" + "ld1 {v31.4s}, [%[output_data]]\n" + "ld1 {v6.4s}, [%[weights_data]], #16\n" + "ld1 {v7.4s}, [%[input_data]], #16\n" + "fmla v31.4s, v6.4s, v7.4s\n" + + "st1 {v31.4s}, [%[output_data]], #16\n" + "2:\n" + : [weights_data] "+r"(weights_data), [input_data] "+r"(input_data), + [output_data] "+r"(output_data), [_n] "+r"(_n) + : [oddn] "r"(oddn) + : "cc", "memory", "v4", "v5", "v30", "v31"); + } + if (remain >= 2) + { + asm volatile("ld1 {v30.2s}, [%[output_data]]\n" + "ld1 {v4.2s}, [%[weights_data]], #8\n" + "ld1 {v5.2s}, [%[input_data]], #8\n" + + "fmla v30.2s, v4.2s, v5.2s\n" + + "st1 {v30.2s}, [%[output_data]], #8\n" + : [weights_data] "+r"(weights_data), [input_data] "+r"(input_data), + [output_data] "+r"(output_data) + : + : "cc", "memory", "v4", "v5", "v30"); + remain -= 2; + } + + if (remain > 0) + { + *output_data++ += (*weights_data++) * (*input_data++); + } + } + } + } + } +#else // __aarch64__ + (void)in_mat; + (void)out_mat; + (void)kernel; + (void)in_param; +#endif // __aarch64__ +} + +void srcn_depthwise_conv(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convMat_t &bias, const convParams_t &in_param, int num_threads, + convType_t conv_type) +{ + omp_set_num_threads(num_threads); + + if (conv_type == col_major) + { + depthwise_conv_colmajor(in_mat, out_mat, weights_mat, in_param); + return; + } + + else if (conv_type == row_major) + { + if (in_param.kernel_w == 3 && in_param.kernel_h == 
3 && in_param.dilation_w == 1 && + in_param.dilation_h == 1) + { + if (in_param.stride_w == 1 && in_param.stride_h == 1) + { + if (in_param.padding == 0) + depthwise_conv3x3S1_nopad(in_mat, out_mat, weights_mat, bias); + else + depthwise_conv3x3S1_padding(in_mat, out_mat, weights_mat, bias); + } + else if (in_param.stride_w == 2 && in_param.stride_h == 2) + { + if (in_param.padding == 0) + depthwise_conv3x3S2_nopad(in_mat, out_mat, weights_mat, bias); + else + { + if (in_param.pad_w == 0 && in_param.pad_h == 0) + depthwise_conv3x3S2_padding00(in_mat, out_mat, weights_mat, bias); + else if (in_param.pad_w == 0 && in_param.pad_h == 1) + depthwise_conv3x3S2_padding10(in_mat, out_mat, weights_mat, bias); + else if (in_param.pad_w == 1 && in_param.pad_h == 0) + depthwise_conv3x3S2_padding01(in_mat, out_mat, weights_mat, bias); + else if (in_param.pad_w == 1 && in_param.pad_h == 1) + depthwise_conv3x3S2_padding11(in_mat, out_mat, weights_mat, bias); + } + } + } + } +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/direct_conv_colmajor.cc b/compute/ncnn/src/srcn/direct_conv_colmajor.cc new file mode 100644 index 000000000..300235222 --- /dev/null +++ b/compute/ncnn/src/srcn/direct_conv_colmajor.cc @@ -0,0 +1,5872 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include <stdlib.h> +#include <arm_neon.h> +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +#if __aarch64__ +static void direct_conv_l(const convMat_t &bottom_blob, convMat_t &top_blob, + const convMat_t &_kernel, const int _stride, const int padding, + const int pad_top, const int pad_left) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + const int kernel_w = _kernel.w; + const int kernel_h = _kernel.h; + + for (int m = 0; m < kernel_w * kernel_h; m++) + { + const float *_kernel0 = _kernel.data + m * inch * outch; + const float *img0 = + bottom_blob.data + (m / kernel_w - pad_top) * w * inch + (m % kernel_w - pad_left) * inch; + +#ifdef _OPENMP +#pragma omp parallel for +#endif // _OPENMP + for (int p = 0; p < outh; p++) + { + float *out0 = top_blob.data + p * outw * outch; + + // clear output + if (m == 0) + { + for (int j = 0; j < outw * outch; j++) + { + *(out0 + j) = 0.f; + } + } + + if (padding) + { + if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) + { + continue; + } + } + + const float *img1 = img0 + p * w * inch * _stride; + + int q = 0; + for (; q + 3 < outw; /*q += 4*/) + { + if (padding) + { + if (((q + 3) * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w) >= pad_left + w) + { + out0 += outch * 4; + img1 += inch * _stride * 4; + q += 4; + continue; + } + else if ((q + 3) * _stride + m % kernel_w >= pad_left + w) + { + break; + } + else if (q * _stride + m % kernel_w < pad_left) + { + int delta = (pad_left - m % kernel_w) / _stride - q; + delta += (pad_left - m % kernel_w) % _stride ? 
1 : 0; + out0 += outch * delta; + img1 += inch * _stride * delta; + q += delta; + continue; + } + } + + const float *_x0 = img1; + const float *_x1 = img1 + inch * _stride; + const float *_x2 = img1 + inch * _stride * 2; + const float *_x3 = img1 + inch * _stride * 3; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); + register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); + register float32x4_t rx2 asm("v16") = vld1q_f32(_x2); + register float32x4_t rx3 asm("v17") = vld1q_f32(_x3); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + float *outptr2 = out0 + outch * 2; + float *outptr3 = out0 + outch * 3; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v30.4s, v8.4s, %[rx2].s[2]\n" + "fmla v31.4s, v8.4s, %[rx3].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + "fmla v30.4s, v9.4s, %[rx2].s[3]\n" + "fmla v31.4s, v9.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v30.4s, v12.4s, %[rx2].s[2]\n" + "fmla v31.4s, v12.4s, %[rx3].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + "fmla v30.4s, v13.4s, %[rx2].s[3]\n" + "fmla v31.4s, v13.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + 
"1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v30.4s, v8.4s, %[rx2].s[2]\n" + "fmla v31.4s, v8.4s, %[rx3].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + "fmla v30.4s, v9.4s, %[rx2].s[3]\n" + "fmla v31.4s, v9.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v30.4s, v12.4s, %[rx2].s[2]\n" + "fmla v31.4s, v12.4s, %[rx3].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + "fmla v30.4s, v13.4s, %[rx2].s[3]\n" + "fmla v31.4s, v13.4s, %[rx3].s[3]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v30.4s, v8.4s, %[rx2].s[2]\n" + "fmla v31.4s, v8.4s, %[rx3].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + "fmla v30.4s, v9.4s, %[rx2].s[3]\n" + "fmla v31.4s, v9.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", 
"v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v30", "v31"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + "ld1 {v30.2s}, [%[outptr2]]\n" + "ld1 {v31.2s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v30.2s, v6.2s, %[rx2].s[0]\n" + "fmla v31.2s, v6.2s, %[rx3].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + "fmla v30.2s, v7.2s, %[rx2].s[1]\n" + "fmla v31.2s, v7.2s, %[rx3].s[1]\n" + "fmla v14.2s, v8.2s, %[rx0].s[2]\n" + "fmla v15.2s, v8.2s, %[rx1].s[2]\n" + "fmla v30.2s, v8.2s, %[rx2].s[2]\n" + "fmla v31.2s, v8.2s, %[rx3].s[2]\n" + "fmla v14.2s, v9.2s, %[rx0].s[3]\n" + "fmla v15.2s, v9.2s, %[rx1].s[3]\n" + "fmla v30.2s, v9.2s, %[rx2].s[3]\n" + "fmla v31.2s, v9.2s, %[rx3].s[3]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + "st1 {v30.2s}, [%[outptr2]], #8\n" + "st1 {v31.2s}, [%[outptr3]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), + + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15", "v30", + "v31"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x1 + 3)); + + *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x2 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x2 + 3)); + + *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x3 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x3 + 3)); + + kernel0++; + outptr0++; + outptr1++; + outptr2++; + outptr3++; + } + + kernel0 += outch * 3; + _x0 += 4; + _x1 += 4; + _x2 += 4; + _x3 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_f32(_x1); + register float32x2_t rx2 asm("v16") = vld1_f32(_x2); + register float32x2_t rx3 asm("v17") = vld1_f32(_x3); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + float *outptr2 = out0 + outch * 2; + float *outptr3 = out0 + outch * 3; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile( + "cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, 
v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", 
"memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15", "v30", "v31"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + "ld1 {v30.2s}, [%[outptr2]]\n" + "ld1 {v31.2s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v30.2s, v6.2s, %[rx2].s[0]\n" + "fmla v31.2s, v6.2s, %[rx3].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + "fmla v30.2s, v7.2s, %[rx2].s[1]\n" + "fmla v31.2s, v7.2s, %[rx3].s[1]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + "st1 {v30.2s}, [%[outptr2]], #8\n" + "st1 {v31.2s}, [%[outptr3]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), + + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v7", "v14", "v15", "v30", "v31"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); + *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)); + *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)); + + kernel0++; + outptr0++; + outptr1++; + outptr2++; + outptr3++; + } + + kernel0 += outch; + _x0 += 2; + _x1 += 2; + _x2 += 2; + _x3 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); + register float32x2_t rx2 asm("v16") = vld1_dup_f32(_x2); + register float32x2_t rx3 asm("v17") = vld1_dup_f32(_x3); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + float *outptr2 = out0 + outch * 2; + float *outptr3 = out0 + outch * 3; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile( + "cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, 
[%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v10", "v14", "v15", "v30", "v31"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + "ld1 {v30.2s}, [%[outptr2]]\n" + "ld1 {v31.2s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v30.2s, v6.2s, %[rx2].s[0]\n" + "fmla v31.2s, v6.2s, %[rx3].s[0]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + "st1 {v30.2s}, [%[outptr2]], #8\n" + "st1 {v31.2s}, [%[outptr3]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [rx0] "w"(rx0), [rx1] "w"(rx1), + + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v14", "v15", "v30", "v31"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + *outptr1 += (*kernel0) * (*_x1); + *outptr2 += (*kernel0) * (*_x2); + *outptr3 += (*kernel0) * (*_x3); + + kernel0++; + outptr0++; + outptr1++; + outptr2++; + outptr3++; + } + + _x0 += 1; + _x1 += 1; + _x2 += 1; + _x3 += 1; + } + + img1 += inch * 4 * _stride; + out0 += outch * 4; + q += 4; + } + + for (; q + 1 < outw; /*q += 2*/) + { + if (padding) + { + if (((q + 1) * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w) >= pad_left + w) + { + out0 += outch * 2; + img1 += inch * _stride * 2; + q += 2; + continue; + } + else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) + { + break; + } + else if (q * _stride + m % kernel_w < pad_left) + { + out0 += outch; + img1 += inch * _stride; + q++; + continue; + } + } + + const float *_x0 = img1; + const float *_x1 = img1 + inch * _stride; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + 
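+ // Two-column inner loop: _x0/_x1 are the input pixels feeding output columns q
+ // and q+1. Four input channels are consumed per iteration; each input scalar is
+ // broadcast over the output channels with fmla, and kernel rows for consecutive
+ // input channels sit outch floats apart (the %[stride] operand below).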
{ + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); + register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 
{v9.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + "fmla v14.2s, v8.2s, %[rx0].s[2]\n" + "fmla v15.2s, v8.2s, %[rx1].s[2]\n" + "fmla v14.2s, v9.2s, %[rx0].s[3]\n" + "fmla v15.2s, v9.2s, %[rx1].s[3]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x1 + 3)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch * 3; + _x0 += 4; + _x1 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + + "st1 
{v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) + : "cc", "memory", "x0", "v6", "v7", "v14", "v15"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch; + _x0 += 2; + _x1 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "fmla v14.4s, 
v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v10", "v14", "v15"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [rx0] "w"(rx0), [rx1] "w"(rx1) + : "cc", "memory", "x0", "v6", "v14", "v15"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + *outptr1 += (*kernel0) * (*_x1); + + kernel0++; + outptr0++; + outptr1++; + } + + _x0 += 1; + _x1 += 1; + } + + img1 += inch * 2 * _stride; + out0 += outch * 2; + q += 2; + } + + for (; q < outw; q++) + { + if (padding) + { + if ((q * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w >= pad_left + w)) + { + img1 += inch * _stride; + out0 += outch; + continue; + } + } + + const float *_x0 = img1; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); + + float *outptr0 = out0; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 
{v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v14.2s, v8.2s, %[rx0].s[2]\n" + "fmla v14.2s, v9.2s, %[rx0].s[3]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + kernel0++; + outptr0++; + } + + kernel0 += outch * 3; + _x0 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_f32(_x0); + + float *outptr0 = out0; + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla 
v14.4s, v7.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) + : "cc", "memory", "x0", "v6", "v7", "v14"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + + kernel0++; + outptr0++; + } + + kernel0 += outch; + _x0 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); + + float *outptr0 = out0; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), 
[outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [rx0] "w"(rx0), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v10", "v14"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [rx0] "w"(rx0) + : "cc", "memory", "x0", "v6", "v14"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + + kernel0++; + outptr0++; + } + + _x0 += 1; + } + + img1 += inch * _stride; + out0 += outch; + } + } + } +} + +static void direct_conv_s(const convMat_t &bottom_blob, convMat_t &top_blob, + const convMat_t &_kernel, const int _stride, const int padding, + const int pad_top, const int pad_left) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + const int kernel_w = _kernel.w; + const int kernel_h = _kernel.h; + +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (int p = 0; p < outh; p++) + { + const float *img0 = bottom_blob.data + (p * _stride - pad_top) * w * inch; + float *out = top_blob.data + p * outw * outch; + + // clear output + for (int j = 0; j < outw * outch; j++) + { + *(out + j) = 0.f; + } + + for (int m = 0; m < kernel_w * kernel_h; m++) + { + if (padding) + { + if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) + { + continue; + } + } + + float *out0 = out; + const float *_kernel0 = _kernel.data + m * inch * outch; + const float *img1 = img0 + (m / kernel_w) * w * inch + (m % kernel_w - pad_left) * inch; + + int q = 0; + for (; q + 3 < outw; /*q += 4*/) + { + if (padding) + { + if (((q + 3) * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w) >= pad_left + w) + { + out0 += outch * 4; + img1 += inch * _stride * 4; + q += 4; + continue; + } + else if ((q + 3) * _stride + m % kernel_w >= pad_left + w) + { + break; + } + else if (q * _stride + m % kernel_w < pad_left) + { + int delta = (pad_left - m % kernel_w) / _stride - q; + delta += (pad_left - m % kernel_w) % _stride ? 
1 : 0; + out0 += outch * delta; + img1 += inch * _stride * delta; + q += delta; + continue; + } + } + + const float *_x0 = img1; + const float *_x1 = img1 + inch * _stride; + const float *_x2 = img1 + inch * _stride * 2; + const float *_x3 = img1 + inch * _stride * 3; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); + register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); + register float32x4_t rx2 asm("v16") = vld1q_f32(_x2); + register float32x4_t rx3 asm("v17") = vld1q_f32(_x3); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + float *outptr2 = out0 + outch * 2; + float *outptr3 = out0 + outch * 3; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v30.4s, v8.4s, %[rx2].s[2]\n" + "fmla v31.4s, v8.4s, %[rx3].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + "fmla v30.4s, v9.4s, %[rx2].s[3]\n" + "fmla v31.4s, v9.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v30.4s, v12.4s, %[rx2].s[2]\n" + "fmla v31.4s, v12.4s, %[rx3].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + "fmla v30.4s, v13.4s, %[rx2].s[3]\n" + "fmla v31.4s, v13.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + 
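/* Label 0 above is the software-pipelined main loop: each pass updates two
   successive groups of four output channels for all four output columns
   (accumulators v14/v15/v30/v31), loading the next group's weights while the
   current group is multiplied.  Label 1 drains the final pair of groups (the one
   already pre-loaded plus the next), label 2 handles the last odd group when
   nn = outch / 4 is odd, and label 3 is the common exit. */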
"1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v30.4s, v8.4s, %[rx2].s[2]\n" + "fmla v31.4s, v8.4s, %[rx3].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + "fmla v30.4s, v9.4s, %[rx2].s[3]\n" + "fmla v31.4s, v9.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v30.4s, v12.4s, %[rx2].s[2]\n" + "fmla v31.4s, v12.4s, %[rx3].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + "fmla v30.4s, v13.4s, %[rx2].s[3]\n" + "fmla v31.4s, v13.4s, %[rx3].s[3]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v30.4s, v8.4s, %[rx2].s[2]\n" + "fmla v31.4s, v8.4s, %[rx3].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + "fmla v30.4s, v9.4s, %[rx2].s[3]\n" + "fmla v31.4s, v9.4s, %[rx3].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", 
"v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v30", "v31"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + "ld1 {v30.2s}, [%[outptr2]]\n" + "ld1 {v31.2s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v30.2s, v6.2s, %[rx2].s[0]\n" + "fmla v31.2s, v6.2s, %[rx3].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + "fmla v30.2s, v7.2s, %[rx2].s[1]\n" + "fmla v31.2s, v7.2s, %[rx3].s[1]\n" + "fmla v14.2s, v8.2s, %[rx0].s[2]\n" + "fmla v15.2s, v8.2s, %[rx1].s[2]\n" + "fmla v30.2s, v8.2s, %[rx2].s[2]\n" + "fmla v31.2s, v8.2s, %[rx3].s[2]\n" + "fmla v14.2s, v9.2s, %[rx0].s[3]\n" + "fmla v15.2s, v9.2s, %[rx1].s[3]\n" + "fmla v30.2s, v9.2s, %[rx2].s[3]\n" + "fmla v31.2s, v9.2s, %[rx3].s[3]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + "st1 {v30.2s}, [%[outptr2]], #8\n" + "st1 {v31.2s}, [%[outptr3]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), + + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15", "v30", + "v31"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x1 + 3)); + + *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x2 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x2 + 3)); + + *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x3 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x3 + 3)); + + kernel0++; + outptr0++; + outptr1++; + outptr2++; + outptr3++; + } + + kernel0 += outch * 3; + _x0 += 4; + _x1 += 4; + _x2 += 4; + _x3 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_f32(_x1); + register float32x2_t rx2 asm("v16") = vld1_f32(_x2); + register float32x2_t rx3 asm("v17") = vld1_f32(_x3); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + float *outptr2 = out0 + outch * 2; + float *outptr3 = out0 + outch * 3; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile( + "cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, 
v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v30.4s, v11.4s, %[rx2].s[1]\n" + "fmla v31.4s, v11.4s, %[rx3].s[1]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v30.4s, v7.4s, %[rx2].s[1]\n" + "fmla v31.4s, v7.4s, %[rx3].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", 
"memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15", "v30", "v31"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + "ld1 {v30.2s}, [%[outptr2]]\n" + "ld1 {v31.2s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v30.2s, v6.2s, %[rx2].s[0]\n" + "fmla v31.2s, v6.2s, %[rx3].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + "fmla v30.2s, v7.2s, %[rx2].s[1]\n" + "fmla v31.2s, v7.2s, %[rx3].s[1]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + "st1 {v30.2s}, [%[outptr2]], #8\n" + "st1 {v31.2s}, [%[outptr3]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), + + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v7", "v14", "v15", "v30", "v31"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); + *outptr2 += (*kernel0) * (*_x2) + (*(kernel0 + outch)) * (*(_x2 + 1)); + *outptr3 += (*kernel0) * (*_x3) + (*(kernel0 + outch)) * (*(_x3 + 1)); + + kernel0++; + outptr0++; + outptr1++; + outptr2++; + outptr3++; + } + + kernel0 += outch; + _x0 += 2; + _x1 += 2; + _x2 += 2; + _x3 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); + register float32x2_t rx2 asm("v16") = vld1_dup_f32(_x2); + register float32x2_t rx3 asm("v17") = vld1_dup_f32(_x3); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + float *outptr2 = out0 + outch * 2; + float *outptr3 = out0 + outch * 3; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile( + "cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, 
[%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v30.4s, v10.4s, %[rx2].s[0]\n" + "fmla v31.4s, v10.4s, %[rx3].s[0]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + "ld1 {v30.4s}, [%[outptr2]]\n" + "ld1 {v31.4s}, [%[outptr3]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v30.4s, v6.4s, %[rx2].s[0]\n" + "fmla v31.4s, v6.4s, %[rx3].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "st1 {v30.4s}, [%[outptr2]], #16\n" + "st1 {v31.4s}, [%[outptr3]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [_n] "+r"(_n), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn), [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v10", "v14", "v15", "v30", "v31"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + "ld1 {v30.2s}, [%[outptr2]]\n" + "ld1 {v31.2s}, [%[outptr3]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v30.2s, v6.2s, %[rx2].s[0]\n" + "fmla v31.2s, v6.2s, %[rx3].s[0]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + "st1 {v30.2s}, [%[outptr2]], #8\n" + "st1 {v31.2s}, [%[outptr3]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [rx0] "w"(rx0), [rx1] "w"(rx1), + + [rx2] "w"(rx2), [rx3] "w"(rx3) + : "cc", "memory", "x0", "v6", "v14", "v15", "v30", "v31"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + *outptr1 += (*kernel0) * (*_x1); + *outptr2 += (*kernel0) * (*_x2); + *outptr3 += (*kernel0) * (*_x3); + + kernel0++; + outptr0++; + outptr1++; + outptr2++; + outptr3++; + } + + _x0 += 1; + _x1 += 1; + _x2 += 1; + _x3 += 1; + } + + img1 += inch * 4 * _stride; + out0 += outch * 4; + q += 4; + } + + for (; q + 1 < outw; /*q += 2*/) + { + if (padding) + { + if (((q + 1) * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w) >= pad_left + w) + { + out0 += outch * 2; + img1 += inch * _stride * 2; + q += 2; + continue; + } + else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) + { + break; + } + else if (q * _stride + m % kernel_w < pad_left) + { + out0 += outch; + img1 += inch * _stride; + q++; + continue; + } + } + + const float *_x0 = img1; + const float *_x1 = img1 + inch * _stride; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + 
{ + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); + register float32x4_t rx1 asm("v5") = vld1q_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v15.4s, v12.4s, %[rx1].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + "fmla v15.4s, v13.4s, %[rx1].s[3]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 
{v9.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v15.4s, v8.4s, %[rx1].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + "fmla v15.4s, v9.4s, %[rx1].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + "fmla v14.2s, v8.2s, %[rx0].s[2]\n" + "fmla v15.2s, v8.2s, %[rx1].s[2]\n" + "fmla v14.2s, v9.2s, %[rx0].s[3]\n" + "fmla v15.2s, v9.2s, %[rx1].s[3]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14", "v15"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x1 + 3)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch * 3; + _x0 += 4; + _x1 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + + "st1 
{v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v15.4s, v11.4s, %[rx1].s[1]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v15.4s, v7.4s, %[rx1].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14", "v15"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v15.2s, v7.2s, %[rx1].s[1]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) + : "cc", "memory", "x0", "v6", "v7", "v14", "v15"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch; + _x0 += 2; + _x1 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); + register float32x2_t rx1 asm("v5") = vld1_dup_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "fmla v14.4s, 
v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v15.4s, v10.4s, %[rx1].s[0]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + "ld1 {v15.4s}, [%[outptr1]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v15.4s, v6.4s, %[rx1].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "st1 {v15.4s}, [%[outptr1]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v10", "v14", "v15"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + "ld1 {v15.2s}, [%[outptr1]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v15.2s, v6.2s, %[rx1].s[0]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + "st1 {v15.2s}, [%[outptr1]], #8\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [rx0] "w"(rx0), [rx1] "w"(rx1) + : "cc", "memory", "x0", "v6", "v14", "v15"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + *outptr1 += (*kernel0) * (*_x1); + + kernel0++; + outptr0++; + outptr1++; + } + + _x0 += 1; + _x1 += 1; + } + + img1 += inch * 2 * _stride; + out0 += outch * 2; + q += 2; + } + + for (; q < outw; q++) + { + if (padding) + { + if ((q * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w >= pad_left + w)) + { + img1 += inch * _stride; + out0 += outch; + continue; + } + } + + const float *_x0 = img1; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("v4") = vld1q_f32(_x0); + + float *outptr0 = out0; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 
{v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v13.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + "fmla v14.4s, v12.4s, %[rx0].s[2]\n" + "fmla v14.4s, v13.4s, %[rx0].s[3]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + "fmla v14.4s, v8.4s, %[rx0].s[2]\n" + "fmla v14.4s, v9.4s, %[rx0].s[3]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v8.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v9.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + "fmla v14.2s, v8.2s, %[rx0].s[2]\n" + "fmla v14.2s, v9.2s, %[rx0].s[3]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) + : "cc", "memory", "x0", "v6", "v7", "v8", "v9", "v14"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + kernel0++; + outptr0++; + } + + kernel0 += outch * 3; + _x0 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_f32(_x0); + + float *outptr0 = out0; + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla 
v14.4s, v7.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + "fmla v14.4s, v11.4s, %[rx0].s[1]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + "fmla v14.4s, v7.4s, %[rx0].s[1]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v7", "v10", "v11", "v14"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + "add x0, x0, %[stride]\n" + "ld1 {v7.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + "fmla v14.2s, v7.2s, %[rx0].s[1]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) + : "cc", "memory", "x0", "v6", "v7", "v14"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + + kernel0++; + outptr0++; + } + + kernel0 += outch; + _x0 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("v4") = vld1_dup_f32(_x0); + + float *outptr0 = out0; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + "beq 1f\n" + + "0:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v10.4s}, [x0]\n" + + "fmla v14.4s, v10.4s, %[rx0].s[0]\n" + + "cmp %[oddn], #1\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + + "bne 3f\n" + + "2:\n" + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "ld1 {v6.4s}, [x0]\n" + + "ld1 {v14.4s}, [%[outptr0]]\n" + + "fmla v14.4s, v6.4s, %[rx0].s[0]\n" + + "st1 {v14.4s}, [%[outptr0]], #16\n" + "3:\n" + : [kernel0] "+r"(kernel0), 
[outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [rx0] "w"(rx0), [oddn] "r"(oddn) + : "cc", "memory", "x0", "v6", "v10", "v14"); + } + + if (remain >= 2) + { + asm volatile("ld1 {v14.2s}, [%[outptr0]]\n" + + "mov x0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "ld1 {v6.2s}, [x0]\n" + + "fmla v14.2s, v6.2s, %[rx0].s[0]\n" + + "st1 {v14.2s}, [%[outptr0]], #8\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [rx0] "w"(rx0) + : "cc", "memory", "x0", "v6", "v14"); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + + kernel0++; + outptr0++; + } + + _x0 += 1; + } + + img1 += inch * _stride; + out0 += outch; + } + } + } +} + +#else // __aarch64__ +static void direct_conv_l(const convMat_t &bottom_blob, convMat_t &top_blob, + const convMat_t &_kernel, const int _stride, const int padding, + const int pad_top, const int pad_left) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + const int kernel_w = _kernel.w; + const int kernel_h = _kernel.h; + + for (int m = 0; m < kernel_w * kernel_h; m++) + { + const float *_kernel0 = _kernel.data + m * inch * outch; + const float *img0 = + bottom_blob.data + (m / kernel_w - pad_top) * w * inch + (m % kernel_w - pad_left) * inch; + +#ifdef _OPENMP +#pragma omp parallel for +#endif // _OPENMP + for (int p = 0; p < outh; p++) + { + float *out0 = top_blob.data + p * outw * outch; + // clear output. + if (m == 0) + { + for (int j = 0; j < outw * outch; j++) + { + *(out0 + j) = 0.f; + } + } + + if (padding) + { + if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) + { + continue; + } + } + + const float *img1 = img0 + p * w * inch * _stride; + + int q = 0; + for (; q + 1 < outw; /*q += 2*/) + { + if (padding) + { + if (((q + 1) * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w) >= pad_left + w) + { + out0 += outch * 2; + img1 += inch * _stride * 2; + q += 2; + continue; + } + else if (q * _stride + m % kernel_w < pad_left) + { + out0 += outch; + img1 += inch * _stride; + q++; + continue; + } + else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) + { + break; + } + } + + const float *_x0 = img1; + const float *_x1 = img1 + inch * _stride; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); + register float32x4_t rx1 asm("q5") = vld1q_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q15, q6, %e[rx1][0]\n" + 
"vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q15, q7, %e[rx1][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q15, q8, %f[rx1][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + "vmla.f32 q15, q9, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q15, q10, %e[rx1][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q15, q11, %e[rx1][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q15, q12, %f[rx1][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + "vmla.f32 q15, q13, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q15, q6, %e[rx1][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q15, q7, %e[rx1][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q15, q8, %f[rx1][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + "vmla.f32 q15, q9, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q15, q10, %e[rx1][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q15, q11, %e[rx1][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q15, q12, %f[rx1][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + "vmla.f32 q15, q13, %f[rx1][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q15, q6, %e[rx1][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q15, q7, %e[rx1][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q15, q8, %f[rx1][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + "vmla.f32 q15, q9, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15"); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + "vld1.f32 {d30}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + "add r0, r0, 
%[stride]\n" + "vld1.f32 {d16}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18}, [r0]\n" + + "vmla.f32 d28, d12, %e[rx0][0]\n" + "vmla.f32 d30, d12, %e[rx1][0]\n" + "vmla.f32 d28, d14, %e[rx0][1]\n" + "vmla.f32 d30, d14, %e[rx1][1]\n" + "vmla.f32 d28, d16, %f[rx0][0]\n" + "vmla.f32 d30, d16, %f[rx1][0]\n" + "vmla.f32 d28, d18, %f[rx0][1]\n" + "vmla.f32 d30, d18, %f[rx1][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + "vst1.f32 {d30}, [%[outptr1]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) +#ifndef _OPENMP + + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x1 + 3)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch * 3; + _x0 += 4; + _x1 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_f32(_x0); + register float32x2_t rx1 asm("d10") = vld1_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + "vmla.f32 q15, q7, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + "vmla.f32 q15, q11, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + "vmla.f32 q15, q7, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + "vmla.f32 q15, q11, %P[rx1][1]\n" + + "cmp 
%[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + "vmla.f32 q15, q7, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + "vld1.f32 {d30}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + "vmla.f32 d30, d12, %P[rx1][0]\n" + "vmla.f32 d28, d14, %P[rx0][1]\n" + "vmla.f32 d30, d14, %P[rx1][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + "vst1.f32 {d30}, [%[outptr1]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch; + _x0 += 2; + _x1 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); + register float32x2_t rx1 asm("d10") = vld1_dup_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + 
"vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q10", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + "vld1.f32 {d30}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + "vmla.f32 d30, d12, %P[rx1][0]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + "vst1.f32 {d30}, [%[outptr1]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [rx0] "w"(rx0), [rx1] "w"(rx1) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + *outptr1 += (*kernel0) * (*_x1); + + kernel0++; + outptr0++; + outptr1++; + } + + _x0 += 1; + _x1 += 1; + } + + img1 += inch * 2 * _stride; + out0 += outch * 2; + q += 2; + } + + for (; q < outw; q++) + { + if (padding) + { + if ((q * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w) >= pad_left + bottom_blob.w) + { + img1 += inch * _stride; + out0 += outch; + continue; + } + } + + const float *_x0 = img1; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); + + float *outptr0 = out0; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, 
[r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18}, [r0]\n" + + "vmla.f32 d28, d12, %e[rx0][0]\n" + "vmla.f32 d28, d14, %e[rx0][1]\n" + "vmla.f32 d28, d16, %f[rx0][0]\n" + "vmla.f32 d28, d18, %f[rx0][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + kernel0++; + outptr0++; + } + + kernel0 += outch * 3; + _x0 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_f32(_x0); + + float *outptr0 = out0; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add 
%[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + "vmla.f32 d28, d14, %P[rx0][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + + kernel0++; + outptr0++; + } + + kernel0 += outch; + _x0 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); + + float *outptr0 = out0; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], 
#16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [rx0] "w"(rx0), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q10", "q14" + +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [rx0] "w"(rx0) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + + kernel0++; + outptr0++; + } + + _x0 += 1; + } + + img1 += inch * _stride; + out0 += outch; + } + } + } +} + +static void direct_conv_s(const convMat_t &bottom_blob, convMat_t &top_blob, + const convMat_t &_kernel, const int _stride, const int padding, + const int pad_top, const int pad_left) +{ + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int inch = bottom_blob.c; + const int outw = top_blob.w; + const int outh = top_blob.h; + const int outch = top_blob.c; + const int kernel_w = _kernel.w; + const int kernel_h = _kernel.h; + +#ifdef _OPENMP +#pragma omp parallel for +#endif // _OPENMP + for (int p = 0; p < outh; p++) + { + const float *img0 = bottom_blob.data + (p * _stride - pad_top) * w * inch; + float *out = top_blob.data + p * outw * outch; + + // clear output. 
+ for (int j = 0; j < outw * outch; j++) + { + *(out + j) = 0.f; + } + + for (int m = 0; m < kernel_w * kernel_h; m++) + { + if (padding) + { + if (((p * _stride + m / kernel_w) < pad_top) || (p * _stride + m / kernel_w >= pad_top + h)) + { + continue; + } + } + + float *out0 = out; + const float *_kernel0 = _kernel.data + m * inch * outch; + const float *img1 = img0 + (m / kernel_w) * w * inch + (m % kernel_w - pad_left) * inch; + + int q = 0; + for (; q + 1 < outw; /*q += 2*/) + { + if (padding) + { + if (((q + 1) * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w >= pad_left + w)) + { + out0 += outch * 2; + img1 += inch * _stride * 2; + q += 2; + continue; + } + else if (q * _stride + m % kernel_w < pad_left) + { + out0 += outch; + img1 += inch * _stride; + q++; + continue; + } + else if ((q + 1) * _stride + m % kernel_w >= pad_left + w) + { + break; + } + } + + const float *_x0 = img1; + const float *_x1 = img1 + inch * _stride; + + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); + register float32x4_t rx1 asm("q5") = vld1q_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q15, q6, %e[rx1][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q15, q7, %e[rx1][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q15, q8, %f[rx1][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + "vmla.f32 q15, q9, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q15, q10, %e[rx1][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q15, q11, %e[rx1][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q15, q12, %f[rx1][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + "vmla.f32 q15, q13, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q15, q6, %e[rx1][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q15, q7, %e[rx1][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q15, q8, %f[rx1][0]\n" + "vmla.f32 q14, q9, 
%f[rx0][1]\n" + "vmla.f32 q15, q9, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q15, q10, %e[rx1][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q15, q11, %e[rx1][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q15, q12, %f[rx1][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + "vmla.f32 q15, q13, %f[rx1][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q15, q6, %e[rx1][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q15, q7, %e[rx1][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q15, q8, %f[rx1][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + "vmla.f32 q15, q9, %f[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15"); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + "vld1.f32 {d30}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18}, [r0]\n" + + "vmla.f32 d28, d12, %e[rx0][0]\n" + "vmla.f32 d30, d12, %e[rx1][0]\n" + "vmla.f32 d28, d14, %e[rx0][1]\n" + "vmla.f32 d30, d14, %e[rx1][1]\n" + "vmla.f32 d28, d16, %f[rx0][0]\n" + "vmla.f32 d30, d16, %f[rx1][0]\n" + "vmla.f32 d28, d18, %f[rx0][1]\n" + "vmla.f32 d30, d18, %f[rx1][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + "vst1.f32 {d30}, [%[outptr1]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" +#else + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x1 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x1 + 3)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch * 3; + _x0 += 4; + _x1 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = 
vld1_f32(_x0); + register float32x2_t rx1 asm("d10") = vld1_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + "vmla.f32 q15, q7, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + "vmla.f32 q15, q11, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + "vmla.f32 q15, q7, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + "vmla.f32 q15, q11, %P[rx1][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + "vmla.f32 q15, q7, %P[rx1][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + "vld1.f32 {d30}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + "vmla.f32 d30, d12, %P[rx1][0]\n" + "vmla.f32 d28, d14, %P[rx0][1]\n" + "vmla.f32 d30, d14, %P[rx1][1]\n" + + "vst1.f32 {d28}, 
[%[outptr0]]!\n" + "vst1.f32 {d30}, [%[outptr1]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [stride] "r"(stride), [rx0] "w"(rx0), [rx1] "w"(rx1) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + *outptr1 += (*kernel0) * (*_x1) + (*(kernel0 + outch)) * (*(_x1 + 1)); + + kernel0++; + outptr0++; + outptr1++; + } + + kernel0 += outch; + _x0 += 2; + _x1 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); + register float32x2_t rx1 asm("d10") = vld1_dup_f32(_x1); + + float *outptr0 = out0; + float *outptr1 = out0 + outch; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q15, q10, %P[rx1][0]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + "vld1.f32 {d30-d31}, [%[outptr1]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q15, q6, %P[rx1][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vst1.f32 {d30-d31}, [%[outptr1]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [_n] "+r"(_n) + : [rx0] "w"(rx0), [rx1] "w"(rx1), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q10", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + "vld1.f32 {d30}, [%[outptr1]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + "vmla.f32 d30, d12, %P[rx1][0]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + "vst1.f32 {d30}, 
[%[outptr1]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1) + : [rx0] "w"(rx0), [rx1] "w"(rx1) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + *outptr1 += (*kernel0) * (*_x1); + + kernel0++; + outptr0++; + outptr1++; + } + + _x0 += 1; + _x1 += 1; + } + + img1 += inch * 2 * _stride; + out0 += outch * 2; + q += 2; + } + + for (; q < outw; q++) + { + if (padding) + { + if ((q * _stride + m % kernel_w < pad_left) || + (q * _stride + m % kernel_w >= pad_left + w)) + { + img1 += inch * _stride; + out0 += outch; + continue; + } + } + + const float *_x0 = img1; + const float *kernel0 = _kernel0; + + int i = 0; + for (; i + 3 < inch; i += 4) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x4_t rx0 asm("q4") = vld1q_f32(_x0); + + float *outptr0 = out0; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + + "vmla.f32 q14, q10, %e[rx0][0]\n" + "vmla.f32 q14, q11, %e[rx0][1]\n" + "vmla.f32 q14, q12, %f[rx0][0]\n" + "vmla.f32 q14, q13, %f[rx0][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add 
r0, r0, %[stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %e[rx0][0]\n" + "vmla.f32 q14, q7, %e[rx0][1]\n" + "vmla.f32 q14, q8, %f[rx0][0]\n" + "vmla.f32 q14, q9, %f[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d16}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d18}, [r0]\n" + + "vmla.f32 d28, d12, %e[rx0][0]\n" + "vmla.f32 d28, d14, %e[rx0][1]\n" + "vmla.f32 d28, d16, %f[rx0][0]\n" + "vmla.f32 d28, d18, %f[rx0][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)) + + (*(kernel0 + outch * 2)) * (*(_x0 + 2)) + + (*(kernel0 + outch * 3)) * (*(_x0 + 3)); + + kernel0++; + outptr0++; + } + + kernel0 += outch * 3; + _x0 += 4; + } + + for (; i + 1 < inch; i += 2) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_f32(_x0); + + float *outptr0 = out0; + + int stride = outch << 2; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + "vmla.f32 q14, q11, %P[rx0][1]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, 
[r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + "vmla.f32 q14, q7, %P[rx0][1]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [stride] "r"(stride), [rx0] "w"(rx0), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q10", "q11", "q14" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + "add r0, r0, %[stride]\n" + "vld1.f32 {d14}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + "vmla.f32 d28, d14, %P[rx0][1]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [stride] "r"(stride), [rx0] "w"(rx0) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0) + (*(kernel0 + outch)) * (*(_x0 + 1)); + + kernel0++; + outptr0++; + } + + kernel0 += outch; + _x0 += 2; + } + + for (; i < inch; i++) + { + int nn = outch >> 2; + int remain = outch & 0x03; + + register float32x2_t rx0 asm("d8") = vld1_dup_f32(_x0); + + float *outptr0 = out0; + + if (nn > 0) + { + int _n = nn >> 1; + int oddn = nn & 1; + + asm volatile("cmp %[_n], #0\n" + "beq 2f\n" + "subs %[_n], %[_n], #1\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "beq 1f\n" + + "0:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "subs %[_n], %[_n], #1\n" + "bne 0b\n" + + "1:\n" + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d20-d21}, [r0]\n" + + "vmla.f32 q14, q10, %P[rx0][0]\n" + + "cmp %[oddn], #1\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + + "bne 3f\n" + + "2:\n" + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #16\n" + "vld1.f32 {d12-d13}, [r0]\n" + + "vld1.f32 {d28-d29}, [%[outptr0]]\n" + + "vmla.f32 q14, q6, %P[rx0][0]\n" + + "vst1.f32 {d28-d29}, [%[outptr0]]!\n" + "3:\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0), [_n] "+r"(_n) + : [rx0] "w"(rx0), [oddn] "r"(oddn) +#ifndef _OPENMP + : "cc", "memory", "r0", "q6", "q10", "q14" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + } + + if (remain >= 2) + { + asm volatile("vld1.f32 {d28}, [%[outptr0]]\n" + + "mov r0, %[kernel0]\n" + "add %[kernel0], %[kernel0], #8\n" + "vld1.f32 {d12}, [r0]\n" + + "vmla.f32 d28, d12, %P[rx0][0]\n" + + "vst1.f32 {d28}, [%[outptr0]]!\n" + : [kernel0] "+r"(kernel0), [outptr0] "+r"(outptr0) + : [rx0] "w"(rx0) +#ifndef _OPENMP + : "cc", 
"memory", "r0", "q6", "q14", "q15" +#else // _OPENMP + : "cc", "memory", "r0", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", + "q14", "q15" +#endif // _OPENMP + ); + remain -= 2; + } + + if (remain == 1) + { + *outptr0 += (*kernel0) * (*_x0); + + kernel0++; + outptr0++; + } + + _x0 += 1; + } + + img1 += inch * _stride; + out0 += outch; + } + } + } +} +#endif // __aarch64__ + +void direct_conv_colmajor(const convMat_t &bottom_blob, convMat_t &top_blob, + const convMat_t &kernel, const convParams_t ¶ms, int num_threads) +{ + omp_set_num_threads(num_threads); + + if (bottom_blob.c * top_blob.c < 256 * 256) + { + direct_conv_s(bottom_blob, top_blob, kernel, params.stride_w, params.padding, params.pad_h, + params.pad_w); + return; + } + + direct_conv_l(bottom_blob, top_blob, kernel, params.stride_w, params.padding, params.pad_h, + params.pad_w); +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/direct_conv_colmajor.h b/compute/ncnn/src/srcn/direct_conv_colmajor.h new file mode 100644 index 000000000..5e15192c9 --- /dev/null +++ b/compute/ncnn/src/srcn/direct_conv_colmajor.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_DIRECT_CONV_COLMAJOR_H__ +#define __NNFW_SRCN_DIRECT_CONV_COLMAJOR_H__ + +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +void direct_conv_colmajor(const convMat_t &, convMat_t &, const convMat_t &, const convParams_t &, + int); + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_DIRECT_CONV_COLMAJOR_H__ diff --git a/compute/ncnn/src/srcn/sgemm_kernel.cc b/compute/ncnn/src/srcn/sgemm_kernel.cc new file mode 100644 index 000000000..90c3641db --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_kernel.cc @@ -0,0 +1,2508 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <arm_neon.h> + +namespace nnfw +{ +namespace srcn +{ + +#if __aarch64__ +static void sgemm_rowmajor_micro_kernel_8x12(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int oddk = (k & 1); + int nk = ((k + 1) / 2) - 1; + + const int nstride = stride << 2; + + __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov x0, %[res_ptr]\n" + "ld1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v11.4s, v12.4s, v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v14.4s, v15.4s, v16.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v17.4s, v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v20.4s, v21.4s, v22.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v23.4s, v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v26.4s, v27.4s, v28.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v29.4s, v30.4s, v31.4s}, [x0]\n" + "cbz %w[nk], 4f\n" + "b 1f\n" + + "0:\n" + "movi v8.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "cbz %w[nk], 4f\n" + + "1:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v9.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v0.s[1]\n" + "fmla v15.4s, v6.4s, v0.s[2]\n" + "fmla v18.4s, v6.4s, v0.s[3]\n" + "fmla v10.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v0.s[1]\n" + "fmla v16.4s, v7.4s, v0.s[2]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v21.4s, v6.4s, v1.s[0]\n" + "fmla v24.4s, v6.4s, v1.s[1]\n" + "fmla v27.4s, v6.4s, v1.s[2]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "fmla v25.4s, v7.4s, v1.s[1]\n" + "subs %w[nk], %w[nk], #1\n" + "fmla v28.4s, 
v7.4s, v1.s[2]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "bne 1b\n" + + "4:\n" + "mov x0, %[res_ptr]\n" + "cbnz %[oddk], 2f\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v5.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[0]\n" + "fmla v10.4s, v7.4s, v0.s[0]\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v0.s[1]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v14.4s, v5.4s, v0.s[2]\n" + "fmla v15.4s, v6.4s, v0.s[2]\n" + "fmla v16.4s, v7.4s, v0.s[2]\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v17.4s, v5.4s, v0.s[3]\n" + "fmla v18.4s, v6.4s, v0.s[3]\n" + "fmla v19.4s, v7.4s, v0.s[3]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + + "fmla v20.4s, v5.4s, v1.s[0]\n" + "fmla v21.4s, v6.4s, v1.s[0]\n" + "fmla v22.4s, v7.4s, v1.s[0]\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v23.4s, v5.4s, v1.s[1]\n" + "fmla v24.4s, v6.4s, v1.s[1]\n" + "fmla v25.4s, v7.4s, v1.s[1]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v26.4s, v5.4s, v1.s[2]\n" + "fmla v27.4s, v6.4s, v1.s[2]\n" + "fmla v28.4s, v7.4s, v1.s[2]\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "b 3f\n" + + "2:\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v11.4s, v2.4s, v0.s[1]\n" + "fmla v12.4s, v3.4s, v0.s[1]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v14.4s, v2.4s, v0.s[2]\n" + "fmla v15.4s, v3.4s, v0.s[2]\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v17.4s, v2.4s, v0.s[3]\n" + "fmla v18.4s, v3.4s, v0.s[3]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + + "fmla v20.4s, v2.4s, v1.s[0]\n" + "fmla v21.4s, v3.4s, v1.s[0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v23.4s, v2.4s, v1.s[1]\n" + "fmla v24.4s, v3.4s, v1.s[1]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "st1 {v23.4s, 
v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v26.4s, v2.4s, v1.s[2]\n" + "fmla v27.4s, v3.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v29.4s, v2.4s, v1.s[3]\n" + "fmla v30.4s, v3.4s, v1.s[3]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + + "3:\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x0]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk) + : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) + : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +static void sgemm_rowmajor_micro_kernel_12x8(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int oddk = (k & 1); + int nk = ((k + 1) / 2) - 1; + + const int nstride = stride << 2; + + __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v4.4s, v5.4s}, [%[rhs_ptr]], #32\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov x0, %[res_ptr]\n" + "ld1 {v8.4s, v9.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v10.4s, v11.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v12.4s, v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v14.4s, v15.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v16.4s, v17.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v20.4s, v21.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v22.4s, v23.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v26.4s, v27.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v28.4s, v29.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v30.4s, v31.4s}, [x0]\n" + "cbz %w[nk], 4f\n" + "b 1f\n" + + "0:\n" + "movi v8.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "cbz %w[nk], 4f\n" + + "1:\n" + "fmla v8.4s, v4.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v12.4s, v4.4s, v0.s[2]\n" + "fmla v14.4s, v4.4s, v0.s[3]\n" + "fmla v9.4s, v5.4s, v0.s[0]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v13.4s, v5.4s, v0.s[2]\n" + "fmla v15.4s, v5.4s, v0.s[3]\n" + + "fmla v16.4s, v4.4s, v1.s[0]\n" + "fmla v18.4s, v4.4s, v1.s[1]\n" + "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" + "fmla v20.4s, v4.4s, v1.s[2]\n" + "fmla v22.4s, v4.4s, v1.s[3]\n" + "fmla v17.4s, v5.4s, v1.s[0]\n" + "fmla v19.4s, v5.4s, v1.s[1]\n" + "fmla v21.4s, v5.4s, v1.s[2]\n" + "fmla v23.4s, v5.4s, v1.s[3]\n" + + "ld1 {v6.4s, v7.4s}, [%[rhs_ptr]], #32\n" + + "fmla v24.4s, v4.4s, v2.s[0]\n" + "fmla v26.4s, v4.4s, v2.s[1]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v28.4s, v4.4s, v2.s[2]\n" + "fmla v30.4s, v4.4s, v2.s[3]\n" + "fmla v25.4s, v5.4s, v2.s[0]\n" + "fmla v27.4s, v5.4s, v2.s[1]\n" + "fmla v29.4s, v5.4s, v2.s[2]\n" + "fmla v31.4s, v5.4s, v2.s[3]\n" + + "fmla v8.4s, v6.4s, v0.s[0]\n" + 
"fmla v10.4s, v6.4s, v0.s[1]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v12.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v0.s[3]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v0.s[3]\n" + + "fmla v16.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v1.s[1]\n" + "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" + "fmla v20.4s, v6.4s, v1.s[2]\n" + "fmla v22.4s, v6.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v1.s[1]\n" + "fmla v21.4s, v7.4s, v1.s[2]\n" + "fmla v23.4s, v7.4s, v1.s[3]\n" + + "ld1 {v4.4s, v5.4s}, [%[rhs_ptr]], #32\n" + + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v26.4s, v6.4s, v2.s[1]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v28.4s, v6.4s, v2.s[2]\n" + "fmla v30.4s, v6.4s, v2.s[3]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "fmla v27.4s, v7.4s, v2.s[1]\n" + "subs %w[nk], %w[nk], #1\n" + "fmla v29.4s, v7.4s, v2.s[2]\n" + "fmla v31.4s, v7.4s, v2.s[3]\n" + "bne 1b\n" + + "4:\n" + "mov x0, %[res_ptr]\n" + "cbnz %[oddk], 2f\n" + + "fmla v8.4s, v4.4s, v0.s[0]\n" + "fmla v9.4s, v5.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v10.4s, v4.4s, v0.s[1]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "fmla v12.4s, v4.4s, v0.s[2]\n" + "fmla v13.4s, v5.4s, v0.s[2]\n" + "fmla v14.4s, v4.4s, v0.s[3]\n" + "fmla v15.4s, v5.4s, v0.s[3]\n" + + "fmla v16.4s, v4.4s, v1.s[0]\n" + "fmla v17.4s, v5.4s, v1.s[0]\n" + "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" + "fmla v18.4s, v4.4s, v1.s[1]\n" + "fmla v19.4s, v5.4s, v1.s[1]\n" + "fmla v20.4s, v4.4s, v1.s[2]\n" + "fmla v21.4s, v5.4s, v1.s[2]\n" + "fmla v22.4s, v4.4s, v1.s[3]\n" + "fmla v23.4s, v5.4s, v1.s[3]\n" + + "ld1 {v6.4s, v7.4s}, [%[rhs_ptr]], #32\n" + + "fmla v24.4s, v4.4s, v2.s[0]\n" + "fmla v25.4s, v5.4s, v2.s[0]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v26.4s, v4.4s, v2.s[1]\n" + "fmla v27.4s, v5.4s, v2.s[1]\n" + "fmla v28.4s, v4.4s, v2.s[2]\n" + "fmla v29.4s, v5.4s, v2.s[2]\n" + "fmla v30.4s, v4.4s, v2.s[3]\n" + "fmla v31.4s, v5.4s, v2.s[3]\n" + + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "st1 {v8.4s, v9.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "st1 {v10.4s, v11.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v12.4s, v6.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v0.s[2]\n" + "st1 {v12.4s, v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v14.4s, v6.4s, v0.s[3]\n" + "fmla v15.4s, v7.4s, v0.s[3]\n" + "st1 {v14.4s, v15.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + + "fmla v16.4s, v6.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v1.s[0]\n" + "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" + "st1 {v16.4s, v17.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v18.4s, v6.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v1.s[1]\n" + "st1 {v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v20.4s, v6.4s, v1.s[2]\n" + "fmla v21.4s, v7.4s, v1.s[2]\n" + "st1 {v20.4s, v21.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v22.4s, v6.4s, v1.s[3]\n" + "fmla v23.4s, v7.4s, v1.s[3]\n" + "st1 {v22.4s, v23.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + + "fmla v24.4s, v6.4s, v2.s[0]\n" + "fmla v25.4s, v7.4s, v2.s[0]\n" + "st1 {v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v26.4s, v6.4s, v2.s[1]\n" + "fmla v27.4s, v7.4s, v2.s[1]\n" + "st1 {v26.4s, v27.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v28.4s, v6.4s, v2.s[2]\n" + "fmla v29.4s, v7.4s, v2.s[2]\n" + "st1 {v28.4s, v29.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v30.4s, v6.4s, 
v2.s[3]\n" + "fmla v31.4s, v7.4s, v2.s[3]\n" + "b 3f\n" + + "2:\n" + "fmla v8.4s, v4.4s, v0.s[0]\n" + "fmla v9.4s, v5.4s, v0.s[0]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "st1 {v8.4s, v9.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v10.4s, v4.4s, v0.s[1]\n" + "fmla v11.4s, v5.4s, v0.s[1]\n" + "st1 {v10.4s, v11.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v12.4s, v4.4s, v0.s[2]\n" + "fmla v13.4s, v5.4s, v0.s[2]\n" + "st1 {v12.4s, v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v14.4s, v4.4s, v0.s[3]\n" + "fmla v15.4s, v5.4s, v0.s[3]\n" + "st1 {v14.4s, v15.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + + "fmla v16.4s, v4.4s, v1.s[0]\n" + "fmla v17.4s, v5.4s, v1.s[0]\n" + "ld1 {v2.4s}, [%[lhs_ptr]], #16\n" + "st1 {v16.4s, v17.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v18.4s, v4.4s, v1.s[1]\n" + "fmla v19.4s, v5.4s, v1.s[1]\n" + "st1 {v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v20.4s, v4.4s, v1.s[2]\n" + "fmla v21.4s, v5.4s, v1.s[2]\n" + "st1 {v20.4s, v21.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v22.4s, v4.4s, v1.s[3]\n" + "fmla v23.4s, v5.4s, v1.s[3]\n" + "st1 {v22.4s, v23.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + + "fmla v24.4s, v4.4s, v2.s[0]\n" + "fmla v25.4s, v5.4s, v2.s[0]\n" + "st1 {v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v26.4s, v4.4s, v2.s[1]\n" + "fmla v27.4s, v5.4s, v2.s[1]\n" + "st1 {v26.4s, v27.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v28.4s, v4.4s, v2.s[2]\n" + "fmla v29.4s, v5.4s, v2.s[2]\n" + "st1 {v28.4s, v29.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v30.4s, v4.4s, v2.s[3]\n" + "fmla v31.4s, v5.4s, v2.s[3]\n" + + "3:\n" + "st1 {v30.4s, v31.4s}, [x0]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk) + : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) + : "x0", "v0", "v1", "v2", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +#ifdef BATCH_DILATION_FIX +static void sgemm_rowmajor_micro_kernel_4x24(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int oddk = (k & 1); + int nk = ((k + 1) / 2) - 1; + + const int nstride = stride << 2; + + __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov x0, %[res_ptr]\n" + "mov x1, x0\n" + "ld1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" + "ld1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "ld1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" + "ld1 {v17.4s, v18.4s, v19.4s}, [x1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "ld1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" + "ld1 {v23.4s, v24.4s, v25.4s}, [x1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "ld1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" + "ld1 {v29.4s, v30.4s, v31.4s}, [x1]\n" + "cbz %w[nk], 4f\n" + "b 1f\n" + + "0:\n" + "movi v8.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + 
"movi v31.4s, #0x0\n" + "cbz %w[nk], 4f\n" + + "1:\n" + "mov x0, v0.d[0]\n" + "cmp x0, #0\n" + "bne 5f\n" + "mov x0, v0.d[1]\n" + "cmp x0, #0\n" + "bne 5f\n" + "add %[rhs_ptr], %[rhs_ptr], #96\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "b 6f\n" + "5:\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v20.4s, v2.4s, v0.s[2]\n" + "fmla v26.4s, v2.4s, v0.s[3]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v21.4s, v3.4s, v0.s[2]\n" + "fmla v27.4s, v3.4s, v0.s[3]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v22.4s, v4.4s, v0.s[2]\n" + "fmla v28.4s, v4.4s, v0.s[3]\n" + + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + + "fmla v11.4s, v5.4s, v0.s[0]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v23.4s, v5.4s, v0.s[2]\n" + "fmla v29.4s, v5.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v0.s[0]\n" + "fmla v18.4s, v6.4s, v0.s[1]\n" + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v30.4s, v6.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v0.s[0]\n" + "fmla v19.4s, v7.4s, v0.s[1]\n" + "fmla v25.4s, v7.4s, v0.s[2]\n" + "fmla v31.4s, v7.4s, v0.s[3]\n" + + "6:\n" + "mov x0, v1.d[0]\n" + "cmp x0, #0\n" + "bne 7f\n" + "mov x0, v1.d[1]\n" + "cmp x0, #0\n" + "bne 7f\n" + "add %[rhs_ptr], %[rhs_ptr], #96\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "b 8f\n" + "7:\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + "fmla v8.4s, v2.4s, v1.s[0]\n" + "fmla v14.4s, v2.4s, v1.s[1]\n" + "fmla v20.4s, v2.4s, v1.s[2]\n" + "fmla v26.4s, v2.4s, v1.s[3]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + "fmla v9.4s, v3.4s, v1.s[0]\n" + "fmla v15.4s, v3.4s, v1.s[1]\n" + "fmla v21.4s, v3.4s, v1.s[2]\n" + "fmla v27.4s, v3.4s, v1.s[3]\n" + "fmla v10.4s, v4.4s, v1.s[0]\n" + "fmla v16.4s, v4.4s, v1.s[1]\n" + "fmla v22.4s, v4.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[3]\n" + + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + + "fmla v11.4s, v5.4s, v1.s[0]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v23.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v1.s[1]\n" + "fmla v24.4s, v6.4s, v1.s[2]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v1.s[1]\n" + "fmla v25.4s, v7.4s, v1.s[2]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + + "8:\n" + "subs %w[nk], %w[nk], #1\n" + "bne 1b\n" + + "4:\n" + "mov x0, %[res_ptr]\n" + "cbnz %[oddk], 2f\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v20.4s, v2.4s, v0.s[2]\n" + "fmla v21.4s, v3.4s, v0.s[2]\n" + "fmla v22.4s, v4.4s, v0.s[2]\n" + "fmla v26.4s, v2.4s, v0.s[3]\n" + "fmla v27.4s, v3.4s, v0.s[3]\n" + "fmla v28.4s, v4.4s, v0.s[3]\n" + + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + + "fmla v11.4s, v5.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v0.s[0]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v18.4s, v6.4s, v0.s[1]\n" + "fmla v19.4s, v7.4s, v0.s[1]\n" + "fmla v23.4s, v5.4s, v0.s[2]\n" + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v25.4s, v7.4s, v0.s[2]\n" + "fmla v29.4s, v5.4s, v0.s[3]\n" + "fmla v30.4s, v6.4s, v0.s[3]\n" + "fmla v31.4s, v7.4s, v0.s[3]\n" + + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + 
+ "fmla v8.4s, v2.4s, v1.s[0]\n" + "fmla v9.4s, v3.4s, v1.s[0]\n" + "fmla v10.4s, v4.4s, v1.s[0]\n" + "mov x1, x0\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" + "fmla v11.4s, v5.4s, v1.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v14.4s, v2.4s, v1.s[1]\n" + "fmla v15.4s, v3.4s, v1.s[1]\n" + "fmla v16.4s, v4.4s, v1.s[1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v1.s[1]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" + "fmla v20.4s, v2.4s, v1.s[2]\n" + "fmla v21.4s, v3.4s, v1.s[2]\n" + "fmla v22.4s, v4.4s, v1.s[2]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" + "fmla v23.4s, v5.4s, v1.s[2]\n" + "fmla v24.4s, v6.4s, v1.s[2]\n" + "fmla v25.4s, v7.4s, v1.s[2]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" + "fmla v26.4s, v2.4s, v1.s[3]\n" + "fmla v27.4s, v3.4s, v1.s[3]\n" + "fmla v28.4s, v4.4s, v1.s[3]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "b 3f\n" + + "2:\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "mov x1, x0\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" + "fmla v11.4s, v5.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v0.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v18.4s, v6.4s, v0.s[1]\n" + "fmla v19.4s, v7.4s, v0.s[1]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" + "fmla v20.4s, v2.4s, v0.s[2]\n" + "fmla v21.4s, v3.4s, v0.s[2]\n" + "fmla v22.4s, v4.4s, v0.s[2]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" + "fmla v23.4s, v5.4s, v0.s[2]\n" + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v25.4s, v7.4s, v0.s[2]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" + "fmla v26.4s, v2.4s, v0.s[3]\n" + "fmla v27.4s, v3.4s, v0.s[3]\n" + "fmla v28.4s, v4.4s, v0.s[3]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" + "fmla v29.4s, v5.4s, v0.s[3]\n" + "fmla v30.4s, v6.4s, v0.s[3]\n" + "fmla v31.4s, v7.4s, v0.s[3]\n" + "3:\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x1]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk) + : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) + : "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} +#else // BATCH_DILATION_FIX +static void sgemm_rowmajor_micro_kernel_4x24(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int oddk = (k & 1); + int nk = ((k + 1) / 2) - 1; + + const int nstride = stride << 2; + + __asm __volatile("ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v4.4s}, [%[rhs_ptr]], 
#16\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov x0, %[res_ptr]\n" + "mov x1, x0\n" + "ld1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" + "ld1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "ld1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" + "ld1 {v17.4s, v18.4s, v19.4s}, [x1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "ld1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" + "ld1 {v23.4s, v24.4s, v25.4s}, [x1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "ld1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" + "ld1 {v29.4s, v30.4s, v31.4s}, [x1]\n" + "cbz %w[nk], 4f\n" + "b 1f\n" + + "0:\n" + "movi v8.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "cbz %w[nk], 4f\n" + + "1:\n" + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v20.4s, v2.4s, v0.s[2]\n" + "fmla v26.4s, v2.4s, v0.s[3]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + "fmla v21.4s, v3.4s, v0.s[2]\n" + "fmla v27.4s, v3.4s, v0.s[3]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v22.4s, v4.4s, v0.s[2]\n" + "fmla v28.4s, v4.4s, v0.s[3]\n" + + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + + "fmla v11.4s, v5.4s, v0.s[0]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v23.4s, v5.4s, v0.s[2]\n" + "fmla v29.4s, v5.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v0.s[0]\n" + "fmla v18.4s, v6.4s, v0.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v30.4s, v6.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v0.s[0]\n" + "fmla v19.4s, v7.4s, v0.s[1]\n" + "fmla v25.4s, v7.4s, v0.s[2]\n" + "fmla v31.4s, v7.4s, v0.s[3]\n" + + "fmla v8.4s, v2.4s, v1.s[0]\n" + "fmla v14.4s, v2.4s, v1.s[1]\n" + "fmla v20.4s, v2.4s, v1.s[2]\n" + "fmla v26.4s, v2.4s, v1.s[3]\n" + "fmla v9.4s, v3.4s, v1.s[0]\n" + "fmla v15.4s, v3.4s, v1.s[1]\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + "fmla v21.4s, v3.4s, v1.s[2]\n" + "fmla v27.4s, v3.4s, v1.s[3]\n" + "fmla v10.4s, v4.4s, v1.s[0]\n" + "fmla v16.4s, v4.4s, v1.s[1]\n" + "fmla v22.4s, v4.4s, v1.s[2]\n" + "fmla v28.4s, v4.4s, v1.s[3]\n" + + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + + "fmla v11.4s, v5.4s, v1.s[0]\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v23.4s, v5.4s, v1.s[2]\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v1.s[1]\n" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + "fmla v24.4s, v6.4s, v1.s[2]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v1.s[1]\n" + "subs %w[nk], %w[nk], #1\n" + "fmla v25.4s, v7.4s, v1.s[2]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "bne 1b\n" + + "4:\n" + "mov x0, %[res_ptr]\n" + "cbnz %[oddk], 2f\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "fmla v20.4s, v2.4s, 
v0.s[2]\n" + "fmla v21.4s, v3.4s, v0.s[2]\n" + "fmla v22.4s, v4.4s, v0.s[2]\n" + "fmla v26.4s, v2.4s, v0.s[3]\n" + "fmla v27.4s, v3.4s, v0.s[3]\n" + "fmla v28.4s, v4.4s, v0.s[3]\n" + + "ld1 {v2.4s, v3.4s, v4.4s}, [%[rhs_ptr]], #48\n" + + "fmla v11.4s, v5.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v0.s[0]\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v18.4s, v6.4s, v0.s[1]\n" + "fmla v19.4s, v7.4s, v0.s[1]\n" + "fmla v23.4s, v5.4s, v0.s[2]\n" + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v25.4s, v7.4s, v0.s[2]\n" + "fmla v29.4s, v5.4s, v0.s[3]\n" + "fmla v30.4s, v6.4s, v0.s[3]\n" + "fmla v31.4s, v7.4s, v0.s[3]\n" + + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v2.4s, v1.s[0]\n" + "fmla v9.4s, v3.4s, v1.s[0]\n" + "fmla v10.4s, v4.4s, v1.s[0]\n" + "mov x1, x0\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" + "fmla v11.4s, v5.4s, v1.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v14.4s, v2.4s, v1.s[1]\n" + "fmla v15.4s, v3.4s, v1.s[1]\n" + "fmla v16.4s, v4.4s, v1.s[1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" + "fmla v17.4s, v5.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v1.s[1]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" + "fmla v20.4s, v2.4s, v1.s[2]\n" + "fmla v21.4s, v3.4s, v1.s[2]\n" + "fmla v22.4s, v4.4s, v1.s[2]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" + "fmla v23.4s, v5.4s, v1.s[2]\n" + "fmla v24.4s, v6.4s, v1.s[2]\n" + "fmla v25.4s, v7.4s, v1.s[2]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" + "fmla v26.4s, v2.4s, v1.s[3]\n" + "fmla v27.4s, v3.4s, v1.s[3]\n" + "fmla v28.4s, v4.4s, v1.s[3]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" + "fmla v29.4s, v5.4s, v1.s[3]\n" + "fmla v30.4s, v6.4s, v1.s[3]\n" + "fmla v31.4s, v7.4s, v1.s[3]\n" + "b 3f\n" + + "2:\n" + "ld1 {v5.4s, v6.4s, v7.4s}, [%[rhs_ptr]], #48\n" + + "fmla v8.4s, v2.4s, v0.s[0]\n" + "fmla v9.4s, v3.4s, v0.s[0]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "mov x1, x0\n" + "st1 {v8.4s, v9.4s, v10.4s}, [x1], #48\n" + "fmla v11.4s, v5.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v0.s[0]\n" + "st1 {v11.4s, v12.4s, v13.4s}, [x1]\n" + "fmla v14.4s, v2.4s, v0.s[1]\n" + "fmla v15.4s, v3.4s, v0.s[1]\n" + "fmla v16.4s, v4.4s, v0.s[1]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v14.4s, v15.4s, v16.4s}, [x1], #48\n" + "fmla v17.4s, v5.4s, v0.s[1]\n" + "fmla v18.4s, v6.4s, v0.s[1]\n" + "fmla v19.4s, v7.4s, v0.s[1]\n" + "st1 {v17.4s, v18.4s, v19.4s}, [x1]\n" + "fmla v20.4s, v2.4s, v0.s[2]\n" + "fmla v21.4s, v3.4s, v0.s[2]\n" + "fmla v22.4s, v4.4s, v0.s[2]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v20.4s, v21.4s, v22.4s}, [x1], #48\n" + "fmla v23.4s, v5.4s, v0.s[2]\n" + "fmla v24.4s, v6.4s, v0.s[2]\n" + "fmla v25.4s, v7.4s, v0.s[2]\n" + "st1 {v23.4s, v24.4s, v25.4s}, [x1]\n" + "fmla v26.4s, v2.4s, v0.s[3]\n" + "fmla v27.4s, v3.4s, v0.s[3]\n" + "fmla v28.4s, v4.4s, v0.s[3]\n" + "add x0, x0, %[nstride]\n" + "mov x1, x0\n" + "st1 {v26.4s, v27.4s, v28.4s}, [x1], #48\n" + "fmla v29.4s, v5.4s, v0.s[3]\n" + "fmla v30.4s, v6.4s, v0.s[3]\n" + "fmla v31.4s, v7.4s, v0.s[3]\n" + "3:\n" + "st1 {v29.4s, v30.4s, v31.4s}, [x1]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk) + : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) + : "x0", "x1", "v0", "v1", 
"v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} +#endif // BATCH_DILATION_FIX + +static void sgemm_rowmajor_micro_kernel_24x4(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int oddk = (k & 1); + int nk = ((k + 1) / 2) - 1; + + const int nstride = stride << 2; + + __asm __volatile("ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" + "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov x0, %[res_ptr]\n" + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v14.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v15.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v16.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v17.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v18.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v20.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v21.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v22.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v23.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v24.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v26.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v27.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v28.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v29.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v30.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "ld1 {v31.4s}, [x0]\n" + "cbz %w[nk], 4f\n" + "b 1f\n" + + "0:\n" + "movi v8.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "movi v11.4s, #0x0\n" + "movi v12.4s, #0x0\n" + "movi v13.4s, #0x0\n" + "movi v14.4s, #0x0\n" + "movi v15.4s, #0x0\n" + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + "cbz %w[nk], 4f\n" + + "1:\n" + "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[1]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v11.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v13.4s, v6.4s, v1.s[1]\n" + "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v17.4s, v6.4s, v2.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v6.4s, v2.s[3]\n" + "ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "fmla v21.4s, v6.4s, v3.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "fmla v23.4s, v6.4s, v3.s[3]\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "fmla v25.4s, v6.4s, v4.s[1]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "fmla v27.4s, v6.4s, v4.s[3]\n" + "fmla v28.4s, v6.4s, v5.s[0]\n" + "fmla v29.4s, v6.4s, v5.s[1]\n" + "fmla v30.4s, v6.4s, v5.s[2]\n" + "fmla v31.4s, v6.4s, v5.s[3]\n" + + "ld1 {v3.4s, v4.4s, 
v5.4s}, [%[lhs_ptr]], #48\n" + "fmla v8.4s, v7.4s, v0.s[0]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v10.4s, v7.4s, v0.s[2]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "fmla v12.4s, v7.4s, v1.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" + "fmla v14.4s, v7.4s, v1.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "fmla v16.4s, v7.4s, v2.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v18.4s, v7.4s, v2.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" + "fmla v20.4s, v7.4s, v3.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "fmla v22.4s, v7.4s, v3.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "fmla v24.4s, v7.4s, v4.s[0]\n" + "fmla v25.4s, v7.4s, v4.s[1]\n" + "fmla v26.4s, v7.4s, v4.s[2]\n" + "fmla v27.4s, v7.4s, v4.s[3]\n" + "fmla v28.4s, v7.4s, v5.s[0]\n" + "fmla v29.4s, v7.4s, v5.s[1]\n" + "subs %w[nk], %w[nk], #1\n" + "fmla v30.4s, v7.4s, v5.s[2]\n" + "fmla v31.4s, v7.4s, v5.s[3]\n" + "bne 1b\n" + + "4:\n" + "mov x0, %[res_ptr]\n" + "cbnz %[oddk], 2f\n" + + "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[1]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v11.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v13.4s, v6.4s, v1.s[1]\n" + "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v15.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v17.4s, v6.4s, v2.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v19.4s, v6.4s, v2.s[3]\n" + "ld1 {v0.4s, v1.4s, v2.4s}, [%[lhs_ptr]], #48\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "fmla v21.4s, v6.4s, v3.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "fmla v23.4s, v6.4s, v3.s[3]\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "fmla v25.4s, v6.4s, v4.s[1]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "fmla v27.4s, v6.4s, v4.s[3]\n" + "fmla v28.4s, v6.4s, v5.s[0]\n" + "fmla v29.4s, v6.4s, v5.s[1]\n" + "fmla v30.4s, v6.4s, v5.s[2]\n" + "fmla v31.4s, v6.4s, v5.s[3]\n" + + "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" + "fmla v8.4s, v7.4s, v0.s[0]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "st1 {v8.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v10.4s, v7.4s, v0.s[2]\n" + "st1 {v9.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "st1 {v10.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v12.4s, v7.4s, v1.s[0]\n" + "st1 {v11.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "st1 {v12.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v14.4s, v7.4s, v1.s[2]\n" + "st1 {v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "st1 {v14.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v16.4s, v7.4s, v2.s[0]\n" + "st1 {v15.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "st1 {v16.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v18.4s, v7.4s, v2.s[2]\n" + "st1 {v17.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "st1 {v18.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v20.4s, v7.4s, v3.s[0]\n" + "st1 {v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "st1 {v20.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v22.4s, v7.4s, v3.s[2]\n" + "st1 {v21.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "st1 {v22.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v24.4s, v7.4s, v4.s[0]\n" + "st1 {v23.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v25.4s, v7.4s, v4.s[1]\n" + "st1 {v24.4s}, [x0]\n" + "add x0, x0, 
%[nstride]\n" + "fmla v26.4s, v7.4s, v4.s[2]\n" + "st1 {v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v27.4s, v7.4s, v4.s[3]\n" + "st1 {v26.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v28.4s, v7.4s, v5.s[0]\n" + "st1 {v27.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v29.4s, v7.4s, v5.s[1]\n" + "st1 {v28.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v30.4s, v7.4s, v5.s[2]\n" + "st1 {v29.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v31.4s, v7.4s, v5.s[3]\n" + "st1 {v30.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "b 3f\n" + + "2:\n" + "ld1 {v3.4s, v4.4s, v5.4s}, [%[lhs_ptr]], #48\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v9.4s, v6.4s, v0.s[1]\n" + "st1 {v8.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "st1 {v9.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v11.4s, v6.4s, v0.s[3]\n" + "st1 {v10.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "st1 {v11.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v13.4s, v6.4s, v1.s[1]\n" + "st1 {v12.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "st1 {v13.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v15.4s, v6.4s, v1.s[3]\n" + "st1 {v14.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "st1 {v15.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v17.4s, v6.4s, v2.s[1]\n" + "st1 {v16.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "st1 {v17.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v19.4s, v6.4s, v2.s[3]\n" + "st1 {v18.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "st1 {v19.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v21.4s, v6.4s, v3.s[1]\n" + "st1 {v20.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "st1 {v21.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v23.4s, v6.4s, v3.s[3]\n" + "st1 {v22.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "st1 {v23.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v25.4s, v6.4s, v4.s[1]\n" + "st1 {v24.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "st1 {v25.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v27.4s, v6.4s, v4.s[3]\n" + "st1 {v26.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v28.4s, v6.4s, v5.s[0]\n" + "st1 {v27.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v29.4s, v6.4s, v5.s[1]\n" + "st1 {v28.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v30.4s, v6.4s, v5.s[2]\n" + "st1 {v29.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "fmla v31.4s, v6.4s, v5.s[3]\n" + "st1 {v30.4s}, [x0]\n" + "add x0, x0, %[nstride]\n" + "3:\n" + "st1 {v31.4s}, [x0]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk) + : [oddk] "r"(oddk), [k0] "r"(k0), [nstride] "r"(nstride) + : "x0", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +#else // __aarch64__ +static void sgemm_rowmajor_micro_kernel_6x8(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int nk = k >> 2; + int rk = k & 3; + + const int nstride = stride << 2; + + if (rk == 0) + { + nk--; + rk = 4; + } + + __asm __volatile("vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov r0, %[res_ptr]\n" + + 
"vld1.f32 {d8-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d12-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d16-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d20-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d24-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d28-d31}, [r0]\n" + "b 1f\n" + + "0:\n" + "vmov.i32 q4, #0\n" + "vmov.i32 q5, #0\n" + "vmov.i32 q6, #0\n" + "pld [%[lhs_ptr], #48]\n" + "vmov.i32 q7, #0\n" + "pld [%[rhs_ptr], #48]\n" + "vmov.i32 q8, #0\n" + "pld [%[lhs_ptr], #112]\n" + "vmov.i32 q9, #0\n" + "pld [%[rhs_ptr], #112]\n" + "vmov.i32 q10, #0\n" + "vmov.i32 q11, #0\n" + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + "pld [%[lhs_ptr], #176]\n" + "vmov.i32 q14, #0\n" + "pld [%[rhs_ptr], #176]\n" + "vmov.i32 q15, #0\n" + + "1:\n" + "cmp %[nk], #0\n" + "beq 6f\n" + "vmla.f32 q4, q2, d0[0]\n" + "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q6, q2, d0[1]\n" + "vmla.f32 q8, q2, d1[0]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q10, q2, d1[1]\n" + "vmla.f32 q12, q2, d2[0]\n" + "vmla.f32 q14, q2, d2[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d0[0]\n" + "vmla.f32 q7, q3, d0[1]\n" + "vmla.f32 q9, q3, d1[0]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q13, q3, d2[0]\n" + "vmla.f32 q15, q3, d2[1]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + + "vmla.f32 q4, q2, d3[0]\n" + "subs %[nk], %[nk], #1\n" + "vmla.f32 q6, q2, d3[1]\n" + "pld [%[lhs_ptr], #208]\n" + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q10, q2, d0[1]\n" + "pld [%[rhs_ptr], #192]\n" + "vmla.f32 q12, q2, d1[0]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d3[0]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q9, q3, d0[0]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vmla.f32 q13, q3, d1[0]\n" + "vmla.f32 q15, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" + + "vmla.f32 q4, q2, d2[0]\n" + "vmla.f32 q6, q2, d2[1]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q8, q2, d3[0]\n" + "vmla.f32 q10, q2, d3[1]\n" + "pld [%[lhs_ptr], #240]\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q14, q2, d0[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d2[0]\n" + "vmla.f32 q7, q3, d2[1]\n" + "pld [%[rhs_ptr], #208]\n" + "vmla.f32 q9, q3, d3[0]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q13, q3, d0[0]\n" + "vmla.f32 q15, q3, d0[1]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + + "vmla.f32 q4, q2, d1[0]\n" + "vmla.f32 q6, q2, d1[1]\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q10, q2, d2[1]\n" + "vmla.f32 q12, q2, d3[0]\n" + "vmla.f32 q14, q2, d3[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d1[0]\n" + "vmla.f32 q7, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q9, q3, d2[0]\n" + "vmla.f32 q11, q3, d2[1]\n" + "vmla.f32 q13, q3, d3[0]\n" + "vmla.f32 q15, q3, d3[1]\n" + "bne 1b\n" + + "6:\n" + "mov r0, %[res_ptr]\n" + "subs %[rk], %[rk], #1\n" + "beq 3f\n" + + "vmla.f32 q4, q2, d0[0]\n" + "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q6, q2, d0[1]\n" + "vmla.f32 q8, q2, d1[0]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q10, q2, d1[1]\n" + "vmla.f32 q12, q2, d2[0]\n" + "subs %[rk], %[rk], #1\n" + "vmla.f32 q14, q2, d2[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d0[0]\n" + "vmla.f32 q7, q3, d0[1]\n" + "vmla.f32 q9, q3, d1[0]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q13, q3, d2[0]\n" + "vmla.f32 q15, q3, d2[1]\n" + 
"vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + "beq 4f\n" + + "vmla.f32 q4, q2, d3[0]\n" + "vmla.f32 q6, q2, d3[1]\n" + "subs %[rk], %[rk], #1\n" + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q10, q2, d0[1]\n" + "vmla.f32 q12, q2, d1[0]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d3[0]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q9, q3, d0[0]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vmla.f32 q13, q3, d1[0]\n" + "vmla.f32 q15, q3, d1[1]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + "beq 5f\n" + + "vld1.32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q2, d2[0]\n" + "vmla.f32 q6, q2, d2[1]\n" + "vmla.f32 q8, q2, d3[0]\n" + "vmla.f32 q10, q2, d3[1]\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q14, q2, d0[1]\n" + "vld1.32 {d4-d5}, [%[rhs_ptr]]!\n" + + "vmla.f32 q5, q3, d2[0]\n" + "vmla.f32 q7, q3, d2[1]\n" + "vmla.f32 q9, q3, d3[0]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vld1.32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q13, q3, d0[0]\n" + "vmla.f32 q15, q3, d0[1]\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + + "vmla.f32 q4, q2, d1[0]\n" + "vmla.f32 q5, q3, d1[0]\n" + "vst1.32 {d8-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q6, q2, d1[1]\n" + "vmla.f32 q7, q3, d1[1]\n" + "vst1.32 {d12-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q3, d2[0]\n" + "vst1.32 {d16-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q10, q2, d2[1]\n" + "vmla.f32 q11, q3, d2[1]\n" + "vst1.32 {d20-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q12, q2, d3[0]\n" + "vmla.f32 q13, q3, d3[0]\n" + "vst1.32 {d24-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q14, q2, d3[1]\n" + "vmla.f32 q15, q3, d3[1]\n" + "b 2f\n" + + "3:\n" + "vld1.32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q3, d0[0]\n" + "vst1.32 {d8-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q6, q2, d0[1]\n" + "vmla.f32 q7, q3, d0[1]\n" + "vst1.32 {d12-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q8, q2, d1[0]\n" + "vld1.32 {d2}, [%[lhs_ptr]]!\n" + "vmla.f32 q9, q3, d1[0]\n" + "vst1.32 {d16-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q10, q2, d1[1]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vst1.32 {d20-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q12, q2, d2[0]\n" + "vmla.f32 q13, q3, d2[0]\n" + "vst1.32 {d24-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q14, q2, d2[1]\n" + "vmla.f32 q15, q3, d2[1]\n" + "b 2f\n" + + "4:\n" + "vmla.f32 q4, q2, d3[0]\n" + "vmla.f32 q5, q3, d3[0]\n" + "vst1.32 {d8-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q6, q2, d3[1]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vst1.32 {d12-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q8, q2, d0[0]\n" + "vmla.f32 q9, q3, d0[0]\n" + "vst1.32 {d16-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q10, q2, d0[1]\n" + "vmla.f32 q11, q3, d0[1]\n" + "vst1.32 {d20-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q12, q2, d1[0]\n" + "vmla.f32 q13, q3, d1[0]\n" + "vst1.32 {d24-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q14, q2, d1[1]\n" + "vmla.f32 q15, q3, d1[1]\n" + "b 2f\n" + + "5:\n" + "vld1.32 {d0}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q2, d2[0]\n" + "vmla.f32 q5, q3, d2[0]\n" + "vst1.32 {d8-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q6, q2, d2[1]\n" + "vmla.f32 q7, q3, d2[1]\n" + "vst1.32 {d12-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q8, q2, d3[0]\n" + "vmla.f32 q9, q3, d3[0]\n" + "vst1.32 {d16-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q10, q2, 
d3[1]\n" + "vmla.f32 q11, q3, d3[1]\n" + "vst1.32 {d20-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q13, q3, d0[0]\n" + "vst1.32 {d24-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q14, q2, d0[1]\n" + "vmla.f32 q15, q3, d0[1]\n" + "2:\n" + "vst1.32 {d28-d31}, [r0]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk), [rk] "+r"(rk) + : [k0] "r"(k0), [nstride] "r"(nstride) + : "r0", "r1", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc"); +} + +static void sgemm_rowmajor_micro_kernel_4x12(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int rk = (k & 1); + int nk = (k + 1) / 2; + + const int nstride = stride << 2; + + asm volatile("vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov r1, %[res_ptr]\n" + + "subs %[nk], %[nk], #1\n" + "mov r0, r1\n" + "vld1.f32 {d8-d9}, [r0]!\n" + "add r1, %[nstride]\n" + "vld1.f32 {d16-d17}, [r0]!\n" + "vld1.f32 {d24-d25}, [r0]\n" + "mov r0, r1\n" + "vld1.f32 {d10-d11}, [r0]!\n" + "add r1, %[nstride]\n" + "vld1.f32 {d18-d19}, [r0]!\n" + "vld1.f32 {d26-d27}, [r0]\n" + "mov r0, r1\n" + "vld1.f32 {d12-d13}, [r0]!\n" + "add r1, %[nstride]\n" + "vld1.f32 {d20-d21}, [r0]!\n" + "vld1.f32 {d28-d29}, [r0]\n" + "mov r0, r1\n" + "vld1.f32 {d14-d15}, [r0]!\n" + "vld1.f32 {d22-d23}, [r0]!\n" + "vld1.f32 {d30-d31}, [r0]\n" + "beq 2f\n" + + "b 1f\n" + + "0:\n" + "veor q4, q4\n" + "subs %[nk],%[nk], #1\n" + "vmov.f32 q8, q4\n" + "vmov.f32 q12, q4\n" + "vmov.f32 q5, q4\n" + "vmov.f32 q9, q4\n" + "vmov.f32 q13, q4\n" + "vmov.f32 q6, q4\n" + "vmov.f32 q10, q4\n" + "vmov.f32 q14, q4\n" + "vmov.f32 q7, q4\n" + "vmov.f32 q11, q4\n" + "vmov.f32 q15, q4\n" + + "beq 2f\n" + + "1:\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + "vmla.f32 q8, q3, d0[0]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q10, q3, d1[0]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q13, q2, d0[1]\n" + "pld [%[lhs_ptr], #208]\n" + "vmla.f32 q14, q2, d1[0]\n" + "pld [%[rhs_ptr], #192]\n" + "vmla.f32 q15, q2, d1[1]\n" + + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + "vmla.f32 q4, q3, d2[0]\n" + "vmla.f32 q5, q3, d2[1]\n" + "vmla.f32 q6, q3, d3[0]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q2, d2[1]\n" + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q10, q2, d3[0]\n" + "vmla.f32 q11, q2, d3[1]\n" + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + "vmla.f32 q12, q3, d2[0]\n" + "vmla.f32 q13, q3, d2[1]\n" + "subs %[nk],%[nk], #1\n" + "pld [%[lhs_ptr], #240]\n" + "vmla.f32 q14, q3, d3[0]\n" + "pld [%[rhs_ptr], #208]\n" + "vmla.f32 q15, q3, d3[1]\n" + "bne 1b\n" + + "2:\n" + "cmp %[rk], #1\n" + "beq 3f\n" + + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + "vmla.f32 q8, q3, d0[0]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vmla.f32 q10, q3, d1[0]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q13, q2, 
d0[1]\n" + "vmla.f32 q14, q2, d1[0]\n" + "vmla.f32 q15, q2, d1[1]\n" + + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + "vld1.f32 {d0-d1}, [%[rhs_ptr]]!\n" + "mov r1, %[res_ptr]\n" + "mov r0, r1\n" + "vmla.f32 q4, q3, d2[0]\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q12, q0, d2[0]\n" + "vst1.f32 {d8-d9}, [r0]!\n" + "add r1, %[nstride]\n" + "vmla.f32 q5, q3, d2[1]\n" + "vst1.f32 {d16-d17}, [r0]!\n" + "vmla.f32 q9, q2, d2[1]\n" + "vst1.f32 {d24-d25}, [r0]\n" + "mov r0, r1\n" + "vmla.f32 q13, q0, d2[1]\n" + "vst1.f32 {d10-d11}, [r0]!\n" + "vmla.f32 q6, q3, d3[0]\n" + "add r1, %[nstride]\n" + "vst1.f32 {d18-d19}, [r0]!\n" + "vmla.f32 q10, q2, d3[0]\n" + "vst1.f32 {d26-d27}, [r0]\n" + "mov r0, r1\n" + "vmla.f32 q14, q0, d3[0]\n" + "vst1.f32 {d12-d13}, [r0]!\n" + "add r1, %[nstride]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vst1.f32 {d20-d21}, [r0]!\n" + "vmla.f32 q11, q2, d3[1]\n" + "vst1.f32 {d28-d29}, [r0]\n" + "mov r0, r1\n" + "vmla.f32 q15, q0, d3[1]\n" + "b 4f\n" + + "3:\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" + "mov r1, %[res_ptr]\n" + "mov r0, r1\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q8, q3, d0[0]\n" + "vmla.f32 q12, q1, d0[0]\n" + "vst1.f32 {d8-d9}, [r0]!\n" + "add r1, %[nstride]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vst1.f32 {d16-d17}, [r0]!\n" + "vmla.f32 q9, q3, d0[1]\n" + "vst1.f32 {d24-d25}, [r0]\n" + "mov r0, r1\n" + "vmla.f32 q13, q1, d0[1]\n" + "vst1.f32 {d10-d11}, [r0]!\n" + "vmla.f32 q6, q2, d1[0]\n" + "add r1, %[nstride]\n" + "vst1.f32 {d18-d19}, [r0]!\n" + "vmla.f32 q10, q3, d1[0]\n" + "vst1.f32 {d26-d27}, [r0]\n" + "mov r0, r1\n" + "vmla.f32 q14, q1, d1[0]\n" + "vst1.f32 {d12-d13}, [r0]!\n" + "add r1, %[nstride]\n" + "vmla.f32 q7, q2, d1[1]\n" + "vst1.f32 {d20-d21}, [r0]!\n" + "vmla.f32 q11, q3, d1[1]\n" + "vst1.f32 {d28-d29}, [r0]\n" + "mov r0, r1\n" + "vmla.f32 q15, q1, d1[1]\n" + + "4:\n" + "vst1.f32 {d14-d15}, [r0]!\n" + "vst1.f32 {d22-d23}, [r0]!\n" + "vst1.f32 {d30-d31}, [r0]\n" + + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk), [rk] "+r"(rk) + : [k0] "r"(k0), [nstride] "r"(nstride) + : "r0", "r1", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc"); +} + +static void sgemm_rowmajor_micro_kernel_12x4(const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k, const int k0, + const int stride) +{ + int rk = (k & 1); + int nk = (k + 1) / 2; + + const int nstride = stride << 2; + + asm volatile("vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + + "cmp %[k0], #0\n" + "beq 0f\n" + + "mov r0, %[res_ptr]\n" + "subs %[nk], %[nk], #1\n" + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d28-d29}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d30-d31}, [r0]\n" + "beq 2f\n" + "b 1f\n" + + "0:\n" + "veor q4, q4\n" + "subs %[nk],%[nk], #1\n" + "vmov.f32 q5, q4\n" + "vmov.f32 q6, q4\n" + "vmov.f32 q7, q4\n" + "vmov.f32 q8, q4\n" + "vmov.f32 q9, q4\n" + 
"vmov.f32 q10, q4\n" + "vmov.f32 q11, q4\n" + "vmov.f32 q12, q4\n" + "vmov.f32 q13, q4\n" + "vmov.f32 q14, q4\n" + "vmov.f32 q15, q4\n" + + "beq 2f\n" + + "1:\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q2, d2[1]\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q10, q2, d3[0]\n" + "vmla.f32 q11, q2, d3[1]\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q13, q2, d0[1]\n" + "pld [%[rhs_ptr], #208]\n" + "vmla.f32 q14, q2, d1[0]\n" + "pld [%[lhs_ptr], #192]\n" + "vmla.f32 q15, q2, d1[1]\n" + + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q3, d2[0]\n" + "vmla.f32 q5, q3, d2[1]\n" + "vmla.f32 q6, q3, d3[0]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q8, q3, d0[0]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vld1.f32 {d4-d5}, [%[rhs_ptr]]!\n" + "vmla.f32 q10, q3, d1[0]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q12, q3, d2[0]\n" + "vmla.f32 q13, q3, d2[1]\n" + "subs %[nk],%[nk], #1\n" + "pld [%[rhs_ptr], #240]\n" + "vmla.f32 q14, q3, d3[0]\n" + "pld [%[lhs_ptr], #208]\n" + "vmla.f32 q15, q3, d3[1]\n" + "bne 1b\n" + + "2:\n" + "cmp %[rk], #1\n" + "beq 3f\n" + + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vmla.f32 q7, q2, d1[1]\n" + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vmla.f32 q9, q2, d2[1]\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vmla.f32 q10, q2, d3[0]\n" + "vmla.f32 q11, q2, d3[1]\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q12, q2, d0[0]\n" + "vmla.f32 q13, q2, d0[1]\n" + "vmla.f32 q14, q2, d1[0]\n" + "vmla.f32 q15, q2, d1[1]\n" + + "mov r0, %[res_ptr]\n" + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q3, d2[0]\n" + "vst1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q5, q3, d2[1]\n" + "vst1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q6, q3, d3[0]\n" + "vst1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q7, q3, d3[1]\n" + "vst1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q8, q3, d0[0]\n" + "vst1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q9, q3, d0[1]\n" + "vst1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q10, q3, d1[0]\n" + "vst1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q11, q3, d1[1]\n" + "vst1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q12, q3, d2[0]\n" + "vst1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q13, q3, d2[1]\n" + "vst1.f32 {d26-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q14, q3, d3[0]\n" + "vst1.f32 {d28-d29}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q15, q3, d3[1]\n" + "b 4f\n" + + "3:\n" + "mov r0, %[res_ptr]\n" + "vld1.f32 {d2-d3}, [%[lhs_ptr]]!\n" + "vmla.f32 q4, q2, d0[0]\n" + "vst1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q5, q2, d0[1]\n" + "vst1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q6, q2, d1[0]\n" + "vst1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q7, q2, d1[1]\n" + "vst1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vld1.f32 {d0-d1}, [%[lhs_ptr]]!\n" + "vmla.f32 q8, q2, d2[0]\n" + "vst1.f32 {d16-d17}, [r0]\n" + "add r0, r0, 
%[nstride]\n" + "vmla.f32 q9, q2, d2[1]\n" + "vst1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q10, q2, d3[0]\n" + "vst1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q11, q2, d3[1]\n" + "vst1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q12, q2, d0[0]\n" + "vst1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q13, q2, d0[1]\n" + "vst1.f32 {d26-d27}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q14, q2, d1[0]\n" + "vst1.f32 {d28-d29}, [r0]\n" + "add r0, r0, %[nstride]\n" + "vmla.f32 q15, q3, d1[1]\n" + + "4:\n" + "vst1.f32 {d30-d31}, [r0]\n" + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), + [nk] "+r"(nk), [rk] "+r"(rk) + : [k0] "r"(k0), [nstride] "r"(nstride) + : "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc"); +} +#endif // __aarch64__ + +typedef void (*sgemm_rowmajoy_micro_kernel_func)(const float *, const float *, float *, const int, + const int, const int); + +static sgemm_rowmajoy_micro_kernel_func sgemm_rowmajoy_micro_kernel_table[12][12] = { + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + { + + 0, 0, 0, 0, 0, +#if !__aarch64__ + sgemm_rowmajor_micro_kernel_4x12, +#else // !__aarch64__ + 0, +#endif // !__aarch64__ + 0, 0, 0, 0, 0, +#if __aarch64__ + sgemm_rowmajor_micro_kernel_4x24 +#else // __aarch64__ + 0 +#endif // __aarch64__ + }, + {0, 0, 0, +#if !__aarch64__ + sgemm_rowmajor_micro_kernel_6x8, +#else // !__aarch64__ + 0, +#endif // !__aarch64__ + 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, +#if __aarch64__ + sgemm_rowmajor_micro_kernel_8x12, +#else // __aarch64__ + 0, +#endif // __aarch64__ + 0, 0, 0, 0, 0, 0 + + }, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + + }, + {0, +#if !__aarch64__ + sgemm_rowmajor_micro_kernel_12x4, +#else // !__aarch64__ + 0, +#endif // !__aarch64__ + 0, +#if __aarch64__ + sgemm_rowmajor_micro_kernel_12x8, +#else // __aarch64__ + 0, +#endif // __aarch64__ + 0, 0, 0, 0, 0, 0, 0, 0 + + }, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + { + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + + }, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + + }, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + + }, + {0, +#if __aarch64__ + sgemm_rowmajor_micro_kernel_24x4, +#else // __aarch64__ + 0, +#endif // __aarch64__ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + + }, + +}; + +void _sgemm_rowmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int nstride, + const int kstride) +{ + const int nm = (mb + mr - 1) / mr; + const int nn = (nb + nr - 1) / nr; + const int rm = mb % mr; + const int rn = nb % nr; + + sgemm_rowmajoy_micro_kernel_func sgemm_rowmajoy_micro_kernel = + sgemm_rowmajoy_micro_kernel_table[mr / 2 - 1][nr / 2 - 1]; + if (!sgemm_rowmajoy_micro_kernel) + return; + + for (int j = 0; j < nn; j++) + { + const int _nr = (j != nn - 1 || rn == 0) ? nr : rn; + for (int i = 0; i < nm; i++) + { + const int _mr = (i != nm - 1 || rm == 0) ? 
mr : rm; + if (_mr == mr && _nr == nr) + { + sgemm_rowmajoy_micro_kernel(&lhs_ptr[i * mr * kstride], &rhs_ptr[j * nr * kstride], + &res_ptr[i * mr * nstride + j * nr], kb, k0, nstride); + } + else + { + float res_micro[mr * nr]; + float *res = &res_ptr[i * mr * nstride + j * nr]; + + sgemm_rowmajoy_micro_kernel(&lhs_ptr[i * mr * kstride], &rhs_ptr[j * nr * kstride], + res_micro, kb, 0, nr); + if (k0 == 0) + { + for (int pi = 0; pi < _mr; pi++) + { + for (int pj = 0; pj < _nr; pj++) + { + res[pi * nstride + pj] = res_micro[pi * nr + pj]; + } + } + } + else + { + for (int pi = 0; pi < _mr; pi++) + { + for (int pj = 0; pj < _nr; pj++) + { + res[pi * nstride + pj] += res_micro[pi * nr + pj]; + } + } + } + } + } + } +} + +void _sgemm_rowmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int nstride, + const int kstride) +{ + const int nm = (mb + mr - 1) / mr; + const int nn = (nb + nr - 1) / nr; + const int rm = mb % mr; + const int rn = nb % nr; + + sgemm_rowmajoy_micro_kernel_func sgemm_rowmajoy_micro_kernel = + sgemm_rowmajoy_micro_kernel_table[mr / 2 - 1][nr / 2 - 1]; + if (!sgemm_rowmajoy_micro_kernel) + return; + + for (int j = 0; j < nm; j++) + { + const int _mr = (j != nm - 1 || rm == 0) ? mr : rm; + for (int i = 0; i < nn; i++) + { + const int _nr = (i != nn - 1 || rn == 0) ? nr : rn; + if (_mr == mr && _nr == nr) + { + sgemm_rowmajoy_micro_kernel(&lhs_ptr[j * mr * kstride], &rhs_ptr[i * nr * kstride], + &res_ptr[j * mr * nstride + i * nr], kb, k0, nstride); + } + else + { + float res_micro[mr * nr]; + float *res = &res_ptr[j * mr * nstride + i * nr]; + + sgemm_rowmajoy_micro_kernel(&lhs_ptr[j * mr * kstride], &rhs_ptr[i * nr * kstride], + res_micro, kb, 0, nr); + if (k0 == 0) + { + for (int pi = 0; pi < _mr; pi++) + { + for (int pj = 0; pj < _nr; pj++) + { + res[pi * nstride + pj] = res_micro[pi * nr + pj]; + } + } + } + else + { + for (int pi = 0; pi < _mr; pi++) + { + for (int pj = 0; pj < _nr; pj++) + { + res[pi * nstride + pj] += res_micro[pi * nr + pj]; + } + } + } + } + } + } +} + +void _sgemm_colmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int mstride, + const int kstride) +{ + _sgemm_rowmajor_macro_kernel_divmn(nr, mr, nb, mb, kb, rhs_ptr, lhs_ptr, res_ptr, k0, mstride, + kstride); +} + +void _sgemm_colmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int mstride, + const int kstride) +{ + _sgemm_rowmajor_macro_kernel_divnm(nr, mr, nb, mb, kb, rhs_ptr, lhs_ptr, res_ptr, k0, mstride, + kstride); +} + +#if __aarch64__ +void _sparse_sgemm_kernel(const int nb, float lhs_data, const float *rhs_ptr, float *res_ptr) +{ + int nn = nb >> 3; + int rn = nb & 7; + + if (nn > 0) + { + asm volatile("mov x0, %[res_ptr]\n" + "dup v0.2d, %[lhs_data]\n" + "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v2.4s}, [x0], #16\n" + + "subs %[nn], %[nn], #1\n" + "beq 2f\n" + + "1:\n" + "ld1 {v4.4s}, [x0], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + "fmla v2.4s, v1.4s, v0.s[0]\n" + "st1 {v2.4s}, [%[res_ptr]], #16\n" + + "ld1 {v2.4s}, [x0], #16\n" + "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" + + "fmla v4.4s, v3.4s, v0.s[0]\n" + "st1 {v4.4s}, [%[res_ptr]], #16\n" + + "subs %[nn], %[nn], #1\n" + "bne 1b\n" + + "2:\n" + "ld1 
{v3.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v4.4s}, [x0], #16\n" + + "fmla v2.4s, v1.4s, v0.s[0]\n" + "st1 {v2.4s}, [%[res_ptr]], #16\n" + + "fmla v4.4s, v3.4s, v0.s[0]\n" + "st1 {v4.4s}, [%[res_ptr]], #16\n" + : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), [nn] "+r"(nn) + : [lhs_data] "r"(lhs_data) + : "x0", "v0", "v1", "v2", "v3", "v4", "cc"); + } + if (rn > 0) + { + int _nn = rn >> 2; + int _rn = rn & 3; + + if (_nn > 0) + { + asm volatile("dup v0.2d, %[lhs_data]\n" + "ld1 {v1.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[res_ptr]]\n" + "fmla v2.4s, v1.4s, v0.s[0]\n" + "st1 {v2.4s}, [%[res_ptr]], #16\n" + : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr) + : [lhs_data] "r"(lhs_data) + : "x0", "x1", "x2", "cc"); + } + if (_rn > 0) + { + for (int i = 0; i < _rn; i++) + { + res_ptr[i] += lhs_data * rhs_ptr[i]; + } + } + } +} + +#else // __aarch64__ +void _sparse_sgemm_kernel(const int nb, float lhs_data, const float *rhs_ptr, float *res_ptr) +{ + int nn = nb >> 3; + int rn = nb & 7; + + if (nn > 0) + { + asm volatile("mov r0, %[res_ptr]\n" + "vdup.32 d0, %[lhs_data]\n" + "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" + "vld1.f32 {d4-d5}, [r0]!\n" + + "subs %[nn], %[nn], #1\n" + "beq 2f\n" + + "1:\n" + "vld1.f32 {d8-d9}, [r0]!\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + + "vmla.f32 q2, q1, d0[0]\n" + "vst1.f32 {d4-d5}, [%[res_ptr]]!\n" + + "vld1.f32 {d4-d5}, [r0]!\n" + "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" + + "vmla.f32 q4, q3, d0[0]\n" + "vst1.f32 {d8-d9}, [%[res_ptr]]!\n" + + "subs %[nn], %[nn], #1\n" + "bne 1b\n" + + "2:\n" + "vld1.f32 {d6-d7}, [%[rhs_ptr]]!\n" + "vld1.f32 {d8-d9}, [r0]!\n" + + "vmla.f32 q2, q1, d0[0]\n" + "vst1.f32 {d4-d5}, [%[res_ptr]]!\n" + + "vmla.f32 q4, q3, d0[0]\n" + "vst1.f32 {d8-d9}, [%[res_ptr]]!\n" + : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr), [nn] "+r"(nn) + : [lhs_data] "r"(lhs_data) + : "r0", "q0", "q1", "q2", "q3", "q4", "cc"); + } + if (rn > 0) + { + int _nn = rn >> 2; + int _rn = rn & 3; + + if (_nn > 0) + { + asm volatile("vdup.32 d0, %[lhs_data]\n" + "vld1.f32 {d2-d3}, [%[rhs_ptr]]!\n" + "vld1.f32 {d4-d5}, [%[res_ptr]]\n" + "vmla.f32 q2, q1, d0[0]\n" + "vst1.f32 {d4-d5}, [%[res_ptr]]!\n" + : [rhs_ptr] "+r"(rhs_ptr), [res_ptr] "+r"(res_ptr) + : [lhs_data] "r"(lhs_data) + : "q0", "q1", "q2", "cc"); + } + if (_rn > 0) + { + for (int i = 0; i < _rn; i++) + { + res_ptr[i] += lhs_data * rhs_ptr[i]; + } + } + } +} +#endif // __aarch64__ + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/sgemm_kernel.h b/compute/ncnn/src/srcn/sgemm_kernel.h new file mode 100644 index 000000000..9e220bc33 --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_kernel.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __NNFW_SRCN_SGEMM_KERNEL_H__ +#define __NNFW_SRCN_SGEMM_KERNEL_H__ + +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +void _sgemm_rowmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int nstride, + const int kstride); + +void _sgemm_rowmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int nstride, + const int kstride); + +void _sgemm_colmajor_macro_kernel_divnm(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int mstride, + const int kstride); + +void _sgemm_colmajor_macro_kernel_divmn(const int mr, const int nr, const int mb, const int nb, + const int kb, const float *lhs_ptr, const float *rhs_ptr, + float *res_ptr, const int k0, const int mstride, + const int kstride); + +void _sparse_sgemm_kernel(const int nb, float lhs_data, const float *rhs_ptr, float *res_ptr); + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_SGEMM_KERNEL_H__ diff --git a/compute/ncnn/src/srcn/sgemm_pack.cc b/compute/ncnn/src/srcn/sgemm_pack.cc new file mode 100644 index 000000000..8767f6c0a --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_pack.cc @@ -0,0 +1,2316 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdlib.h> +#include <arm_neon.h> + +#include "ncnn/srcn/conv_type.h" +#include "common.h" + +namespace nnfw +{ +namespace srcn +{ + +void _pack_rowmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr) +{ + const int nm = mb / mr; + const int rm = mb % mr; + + switch (mr) + { +#if __aarch64__ + case 24: + for (int i = 0; i < nm; i++) + { + int nk = kb >> 2; + int rk = kb & 0x03; + + const float *lhs_temp = lhs_ptr; + const int _stride = stride << 2; + + if (nk > 0) + { + asm volatile("0:\n" + "mov x0, %[lhs_temp]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v8.4s, v10.4s\n" + "zip2 v30.4s, v8.4s, v10.4s\n" + "zip1 v29.4s, v9.4s, v11.4s\n" + "zip2 v31.4s, v9.4s, v11.4s\n" + "zip1 v8.4s, v28.4s, v29.4s\n" + "zip2 v9.4s, v28.4s, v29.4s\n" + "zip1 v10.4s, v30.4s, v31.4s\n" + "zip2 v11.4s, v30.4s, v31.4s\n" + + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v13.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v14.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v15.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v12.4s, v14.4s\n" + "zip2 v30.4s, v12.4s, v14.4s\n" + "zip1 v29.4s, v13.4s, v15.4s\n" + "zip2 v31.4s, v13.4s, v15.4s\n" + "zip1 v12.4s, v28.4s, v29.4s\n" + "zip2 v13.4s, v28.4s, v29.4s\n" + "zip1 v14.4s, v30.4s, v31.4s\n" + "zip2 v15.4s, v30.4s, v31.4s\n" + + "ld1 {v16.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v17.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v18.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v19.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v16.4s, v18.4s\n" + "zip2 v30.4s, v16.4s, v18.4s\n" + "zip1 v29.4s, v17.4s, v19.4s\n" + "zip2 v31.4s, v17.4s, v19.4s\n" + "zip1 v16.4s, v28.4s, v29.4s\n" + "zip2 v17.4s, v28.4s, v29.4s\n" + "zip1 v18.4s, v30.4s, v31.4s\n" + "zip2 v19.4s, v30.4s, v31.4s\n" + + "ld1 {v20.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v21.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v22.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v23.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v20.4s, v22.4s\n" + "zip2 v30.4s, v20.4s, v22.4s\n" + "zip1 v29.4s, v21.4s, v23.4s\n" + "zip2 v31.4s, v21.4s, v23.4s\n" + "zip1 v20.4s, v28.4s, v29.4s\n" + "zip2 v21.4s, v28.4s, v29.4s\n" + "zip1 v22.4s, v30.4s, v31.4s\n" + "zip2 v23.4s, v30.4s, v31.4s\n" + + "ld1 {v24.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v25.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v26.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v27.4s}, [x0]\n" + + "zip1 v28.4s, v24.4s, v26.4s\n" + "zip2 v30.4s, v24.4s, v26.4s\n" + "zip1 v29.4s, v25.4s, v27.4s\n" + "zip2 v31.4s, v25.4s, v27.4s\n" + "zip1 v24.4s, v28.4s, v29.4s\n" + "zip2 v25.4s, v28.4s, v29.4s\n" + "zip1 v26.4s, v30.4s, v31.4s\n" + "zip2 v27.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[plhs_ptr]], #16\n" + "st1 
{v8.4s}, [%[plhs_ptr]], #16\n" + "st1 {v12.4s}, [%[plhs_ptr]], #16\n" + "st1 {v16.4s}, [%[plhs_ptr]], #16\n" + "st1 {v20.4s}, [%[plhs_ptr]], #16\n" + "st1 {v24.4s}, [%[plhs_ptr]], #16\n" + "st1 {v5.4s}, [%[plhs_ptr]], #16\n" + "st1 {v9.4s}, [%[plhs_ptr]], #16\n" + "st1 {v13.4s}, [%[plhs_ptr]], #16\n" + "st1 {v17.4s}, [%[plhs_ptr]], #16\n" + "st1 {v21.4s}, [%[plhs_ptr]], #16\n" + "st1 {v25.4s}, [%[plhs_ptr]], #16\n" + "st1 {v6.4s}, [%[plhs_ptr]], #16\n" + "st1 {v10.4s}, [%[plhs_ptr]], #16\n" + "st1 {v14.4s}, [%[plhs_ptr]], #16\n" + "st1 {v18.4s}, [%[plhs_ptr]], #16\n" + "st1 {v22.4s}, [%[plhs_ptr]], #16\n" + "st1 {v26.4s}, [%[plhs_ptr]], #16\n" + "st1 {v7.4s}, [%[plhs_ptr]], #16\n" + "st1 {v11.4s}, [%[plhs_ptr]], #16\n" + "st1 {v15.4s}, [%[plhs_ptr]], #16\n" + "st1 {v19.4s}, [%[plhs_ptr]], #16\n" + "st1 {v23.4s}, [%[plhs_ptr]], #16\n" + "st1 {v27.4s}, [%[plhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + + for (int j = 0; j < rk; j++) + { + plhs_ptr[0] = lhs_temp[0]; + plhs_ptr[1] = lhs_temp[stride]; + plhs_ptr[2] = lhs_temp[stride << 1]; + plhs_ptr[3] = lhs_temp[3 * stride]; + plhs_ptr[4] = lhs_temp[stride << 2]; + plhs_ptr[5] = lhs_temp[5 * stride]; + plhs_ptr[6] = lhs_temp[6 * stride]; + plhs_ptr[7] = lhs_temp[7 * stride]; + plhs_ptr[8] = lhs_temp[stride << 3]; + plhs_ptr[9] = lhs_temp[9 * stride]; + plhs_ptr[10] = lhs_temp[10 * stride]; + plhs_ptr[11] = lhs_temp[11 * stride]; + plhs_ptr[12] = lhs_temp[0]; + plhs_ptr[13] = lhs_temp[13 * stride]; + plhs_ptr[14] = lhs_temp[14 * stride]; + plhs_ptr[15] = lhs_temp[15 * stride]; + plhs_ptr[16] = lhs_temp[stride << 4]; + plhs_ptr[17] = lhs_temp[17 * stride]; + plhs_ptr[18] = lhs_temp[18 * stride]; + plhs_ptr[19] = lhs_temp[19 * stride]; + plhs_ptr[20] = lhs_temp[20 * stride]; + plhs_ptr[21] = lhs_temp[21 * stride]; + plhs_ptr[22] = lhs_temp[22 * stride]; + plhs_ptr[23] = lhs_temp[23 * stride]; + plhs_ptr += mr; + lhs_temp++; + } + + lhs_ptr += mr * stride; + } + break; + case 16: + for (int i = 0; i < nm; i++) + { + int nk = kb >> 2; + int rk = kb & 0x03; + + const float *lhs_temp = lhs_ptr; + const int _stride = stride << 2; + + if (nk > 0) + { + asm volatile("0:\n" + "mov x0, %[lhs_temp]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v8.4s, v10.4s\n" + "zip2 v30.4s, v8.4s, v10.4s\n" + "zip1 v29.4s, v9.4s, v11.4s\n" + "zip2 v31.4s, v9.4s, v11.4s\n" + "zip1 v8.4s, v28.4s, v29.4s\n" + "zip2 v9.4s, v28.4s, v29.4s\n" + "zip1 v10.4s, v30.4s, v31.4s\n" + "zip2 v11.4s, v30.4s, v31.4s\n" + + "ld1 {v12.4s}, [x0]\n" + "add x0, 
x0, %[_stride]\n" + "ld1 {v13.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v14.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v15.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v12.4s, v14.4s\n" + "zip2 v30.4s, v12.4s, v14.4s\n" + "zip1 v29.4s, v13.4s, v15.4s\n" + "zip2 v31.4s, v13.4s, v15.4s\n" + "zip1 v12.4s, v28.4s, v29.4s\n" + "zip2 v13.4s, v28.4s, v29.4s\n" + "zip1 v14.4s, v30.4s, v31.4s\n" + "zip2 v15.4s, v30.4s, v31.4s\n" + + "ld1 {v16.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v17.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v18.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v19.4s}, [x0]\n" + + "zip1 v28.4s, v16.4s, v18.4s\n" + "zip2 v30.4s, v16.4s, v18.4s\n" + "zip1 v29.4s, v17.4s, v19.4s\n" + "zip2 v31.4s, v17.4s, v19.4s\n" + "zip1 v16.4s, v28.4s, v29.4s\n" + "zip2 v17.4s, v28.4s, v29.4s\n" + "zip1 v18.4s, v30.4s, v31.4s\n" + "zip2 v19.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[plhs_ptr]], #16\n" + "st1 {v8.4s}, [%[plhs_ptr]], #16\n" + "st1 {v12.4s}, [%[plhs_ptr]], #16\n" + "st1 {v16.4s}, [%[plhs_ptr]], #16\n" + "st1 {v5.4s}, [%[plhs_ptr]], #16\n" + "st1 {v9.4s}, [%[plhs_ptr]], #16\n" + "st1 {v13.4s}, [%[plhs_ptr]], #16\n" + "st1 {v17.4s}, [%[plhs_ptr]], #16\n" + "st1 {v6.4s}, [%[plhs_ptr]], #16\n" + "st1 {v10.4s}, [%[plhs_ptr]], #16\n" + "st1 {v14.4s}, [%[plhs_ptr]], #16\n" + "st1 {v18.4s}, [%[plhs_ptr]], #16\n" + "st1 {v7.4s}, [%[plhs_ptr]], #16\n" + "st1 {v11.4s}, [%[plhs_ptr]], #16\n" + "st1 {v15.4s}, [%[plhs_ptr]], #16\n" + "st1 {v19.4s}, [%[plhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v28", "v29", + "v30", "v31"); + } + + for (int j = 0; j < rk; j++) + { + plhs_ptr[0] = lhs_temp[0]; + plhs_ptr[1] = lhs_temp[stride]; + plhs_ptr[2] = lhs_temp[stride << 1]; + plhs_ptr[3] = lhs_temp[3 * stride]; + plhs_ptr[4] = lhs_temp[stride << 2]; + plhs_ptr[5] = lhs_temp[5 * stride]; + plhs_ptr[6] = lhs_temp[6 * stride]; + plhs_ptr[7] = lhs_temp[7 * stride]; + plhs_ptr[8] = lhs_temp[stride << 3]; + plhs_ptr[9] = lhs_temp[9 * stride]; + plhs_ptr[10] = lhs_temp[10 * stride]; + plhs_ptr[11] = lhs_temp[11 * stride]; + plhs_ptr[12] = lhs_temp[0]; + plhs_ptr[13] = lhs_temp[13 * stride]; + plhs_ptr[14] = lhs_temp[14 * stride]; + plhs_ptr[15] = lhs_temp[15 * stride]; + plhs_ptr += mr; + lhs_temp++; + } + + lhs_ptr += mr * stride; + } + break; +#endif // __aarch64__ + case 12: + for (int i = 0; i < nm; i++) + { + int nk = kb >> 2; + int rk = kb & 0x03; + + const float *lhs_temp = lhs_ptr; + const int _stride = stride << 2; + + if (nk > 0) + { +#if __aarch64__ + asm volatile("0:\n" + "mov x0, %[lhs_temp]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, 
%[_stride]\n" + + "zip1 v28.4s, v8.4s, v10.4s\n" + "zip2 v30.4s, v8.4s, v10.4s\n" + "zip1 v29.4s, v9.4s, v11.4s\n" + "zip2 v31.4s, v9.4s, v11.4s\n" + "zip1 v8.4s, v28.4s, v29.4s\n" + "zip2 v9.4s, v28.4s, v29.4s\n" + "zip1 v10.4s, v30.4s, v31.4s\n" + "zip2 v11.4s, v30.4s, v31.4s\n" + + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v13.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v14.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v15.4s}, [x0]\n" + + "zip1 v28.4s, v12.4s, v14.4s\n" + "zip2 v30.4s, v12.4s, v14.4s\n" + "zip1 v29.4s, v13.4s, v15.4s\n" + "zip2 v31.4s, v13.4s, v15.4s\n" + "zip1 v12.4s, v28.4s, v29.4s\n" + "zip2 v13.4s, v28.4s, v29.4s\n" + "zip1 v14.4s, v30.4s, v31.4s\n" + "zip2 v15.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[plhs_ptr]], #16\n" + "st1 {v8.4s}, [%[plhs_ptr]], #16\n" + "st1 {v12.4s}, [%[plhs_ptr]], #16\n" + "st1 {v5.4s}, [%[plhs_ptr]], #16\n" + "st1 {v9.4s}, [%[plhs_ptr]], #16\n" + "st1 {v13.4s}, [%[plhs_ptr]], #16\n" + "st1 {v6.4s}, [%[plhs_ptr]], #16\n" + "st1 {v10.4s}, [%[plhs_ptr]], #16\n" + "st1 {v14.4s}, [%[plhs_ptr]], #16\n" + "st1 {v7.4s}, [%[plhs_ptr]], #16\n" + "st1 {v11.4s}, [%[plhs_ptr]], #16\n" + "st1 {v15.4s}, [%[plhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[lhs_temp]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[_stride]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[_stride]\n" + + "vzip.32 q8, q10\n" + "vzip.32 q9, q11\n" + "vzip.32 q8, q9\n" + "vzip.32 q10, q11\n" + + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d28-d29}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d30-d31}, [r0]\n" + + "vzip.32 q12, q14\n" + "vzip.32 q13, q15\n" + "vzip.32 q12, q13\n" + "vzip.32 q14, q15\n" + + "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" + "vst1.f32 {d16-d17}, [%[plhs_ptr]]!\n" + "vst1.f32 {d24-d25}, [%[plhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" + "vst1.f32 {d18-d19}, [%[plhs_ptr]]!\n" + "vst1.f32 {d26-d27}, [%[plhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" + "vst1.f32 {d20-d21}, [%[plhs_ptr]]!\n" + "vst1.f32 {d28-d29}, [%[plhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" + "vst1.f32 {d22-d23}, [%[plhs_ptr]]!\n" + "vst1.f32 {d30-d31}, [%[plhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + plhs_ptr[0] = lhs_temp[0]; + plhs_ptr[1] = lhs_temp[stride]; + plhs_ptr[2] = lhs_temp[stride << 1]; + plhs_ptr[3] = lhs_temp[3 * stride]; + 
plhs_ptr[4] = lhs_temp[stride << 2]; + plhs_ptr[5] = lhs_temp[5 * stride]; + plhs_ptr[6] = lhs_temp[6 * stride]; + plhs_ptr[7] = lhs_temp[7 * stride]; + plhs_ptr[8] = lhs_temp[stride << 3]; + plhs_ptr[9] = lhs_temp[9 * stride]; + plhs_ptr[10] = lhs_temp[10 * stride]; + plhs_ptr[11] = lhs_temp[11 * stride]; + plhs_ptr += mr; + lhs_temp++; + } + + lhs_ptr += mr * stride; + } + break; + case 8: + for (int i = 0; i < nm; i++) + { + int nk = kb >> 2; + int rk = kb & 0x03; + + const float *lhs_temp = lhs_ptr; + const int _stride = stride << 2; + + if (nk > 0) + { +#if __aarch64__ + asm volatile("0:\n" + "mov x0, %[lhs_temp]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "zip1 v28.4s, v8.4s, v10.4s\n" + "zip2 v30.4s, v8.4s, v10.4s\n" + "zip1 v29.4s, v9.4s, v11.4s\n" + "zip2 v31.4s, v9.4s, v11.4s\n" + "zip1 v8.4s, v28.4s, v29.4s\n" + "zip2 v9.4s, v28.4s, v29.4s\n" + "zip1 v10.4s, v30.4s, v31.4s\n" + "zip2 v11.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[plhs_ptr]], #16\n" + "st1 {v8.4s}, [%[plhs_ptr]], #16\n" + "st1 {v5.4s}, [%[plhs_ptr]], #16\n" + "st1 {v9.4s}, [%[plhs_ptr]], #16\n" + "st1 {v6.4s}, [%[plhs_ptr]], #16\n" + "st1 {v10.4s}, [%[plhs_ptr]], #16\n" + "st1 {v7.4s}, [%[plhs_ptr]], #16\n" + "st1 {v11.4s}, [%[plhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[lhs_temp]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[_stride]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vzip.32 q8, q10\n" + "vzip.32 q9, q11\n" + "vzip.32 q8, q9\n" + "vzip.32 q10, q11\n" + + "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" + "vst1.f32 {d16-d17}, [%[plhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" + "vst1.f32 {d18-d19}, [%[plhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" + "vst1.f32 {d20-d21}, [%[plhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" + "vst1.f32 {d22-d23}, [%[plhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + plhs_ptr[0] = lhs_temp[0]; + plhs_ptr[1] = 
lhs_temp[stride]; + plhs_ptr[2] = lhs_temp[stride << 1]; + plhs_ptr[3] = lhs_temp[3 * stride]; + plhs_ptr[4] = lhs_temp[stride << 2]; + plhs_ptr[5] = lhs_temp[5 * stride]; + plhs_ptr[6] = lhs_temp[6 * stride]; + plhs_ptr[7] = lhs_temp[7 * stride]; + plhs_ptr += mr; + lhs_temp++; + } + + lhs_ptr += mr * stride; + } + break; + case 6: + for (int i = 0; i < nm; i++) + { + int nk = kb >> 2; + int rk = kb & 0x03; + + const float *lhs_temp = lhs_ptr; + const int _stride = stride << 2; + + if (nk > 0) + { +#if __aarch64__ + // TODO: 4--->6 + asm volatile("0:\n" + "mov x0, %[lhs_temp]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v8.4s}, [x0]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[plhs_ptr]], #16\n" + "st1 {v5.4s}, [%[plhs_ptr]], #16\n" + "st1 {v6.4s}, [%[plhs_ptr]], #16\n" + "st1 {v7.4s}, [%[plhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[lhs_temp]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + "vzip.32 q8, q9\n" + + "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" + "vst1.f32 {d16}, [%[plhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" + "vst1.f32 {d17}, [%[plhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" + "vst1.f32 {d18}, [%[plhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" + "vst1.f32 {d19}, [%[plhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + plhs_ptr[0] = lhs_temp[0]; + plhs_ptr[1] = lhs_temp[stride]; + plhs_ptr[2] = lhs_temp[stride << 1]; + plhs_ptr[3] = lhs_temp[3 * stride]; + plhs_ptr[4] = lhs_temp[stride << 2]; + plhs_ptr[5] = lhs_temp[5 * stride]; + plhs_ptr += mr; + lhs_temp++; + } + + lhs_ptr += mr * stride; + } + break; + case 4: + for (int i = 0; i < nm; i++) + { + int nk = kb >> 2; + int rk = kb & 0x03; + + const float *lhs_temp = lhs_ptr; + const int _stride = stride << 2; + + if (nk > 0) + { +#if __aarch64__ + asm volatile("0:\n" + "mov x0, %[lhs_temp]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + 
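// This zip1/zip2 sequence is a 4x4 transpose: once both passes finish,
// v4..v7 each hold one k column taken across the four source rows, which is
// the k-major layout used by the packed LHS panel.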
"zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[plhs_ptr]], #16\n" + "st1 {v5.4s}, [%[plhs_ptr]], #16\n" + "st1 {v6.4s}, [%[plhs_ptr]], #16\n" + "st1 {v7.4s}, [%[plhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[lhs_temp]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + + "vst1.f32 {d8-d9}, [%[plhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[plhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[plhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[plhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[lhs_temp], %[lhs_temp], #16\n" + "bne 0b\n" + : [lhs_temp] "+r"(lhs_temp), [plhs_ptr] "+r"(plhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + plhs_ptr[0] = lhs_temp[0]; + plhs_ptr[1] = lhs_temp[stride]; + plhs_ptr[2] = lhs_temp[stride << 1]; + plhs_ptr[3] = lhs_temp[3 * stride]; + plhs_ptr += mr; + lhs_temp++; + } + + lhs_ptr += mr * stride; + } + break; + default: + break; + } + + if (rm > 0) + { + for (int j = 0; j < kb; j++) + { + for (int i = 0; i < rm; i++) + { + plhs_ptr[i] = lhs_ptr[i * stride]; + } + for (int i = rm; i < mr; i++) + { + plhs_ptr[i] = 0.f; + } + plhs_ptr += mr; + lhs_ptr++; + } + } +} + +void _pack_rowmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr) +{ + const int nn = nb / nr; + const int rn = nb % nr; + + switch (nr) + { + case 24: + for (int j = 0; j < nn; j++) + { + const float *rhs_temp = rhs_ptr; + float32x4_t q0, q1, q2, q3, q4, q5; + for (int i = 0; i < kb; i++) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + 4); + q2 = vld1q_f32(rhs_temp + 8); + q3 = vld1q_f32(rhs_temp + 12); + q4 = vld1q_f32(rhs_temp + 16); + q5 = vld1q_f32(rhs_temp + 20); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + vst1q_f32(prhs_ptr + 12, q3); + vst1q_f32(prhs_ptr + 16, q4); + vst1q_f32(prhs_ptr + 20, q5); + + rhs_temp += stride; + prhs_ptr += nr; + } + + rhs_ptr += nr; + } + break; + case 16: + for (int j = 0; j < nn; j++) + { + const float *rhs_temp = rhs_ptr; + float32x4_t q0, q1, q2, q3; + for (int i = 0; i < kb; i++) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + 4); + q2 = vld1q_f32(rhs_temp + 8); + q3 = vld1q_f32(rhs_temp + 12); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + vst1q_f32(prhs_ptr + 12, q3); + + rhs_temp += stride; + prhs_ptr += nr; + } + + rhs_ptr += nr; + } + break; + case 12: + for (int j = 0; j < nn; j++) + { + const float *rhs_temp = rhs_ptr; + float32x4_t q0, q1, q2; + for (int i = 0; i < kb; i++) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + 4); + q2 = vld1q_f32(rhs_temp + 8); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + + rhs_temp += stride; + prhs_ptr += nr; + } + + rhs_ptr += nr; + } + break; + case 8: + for (int j = 0; j < nn; j++) + + { + const float *rhs_temp = rhs_ptr; 
+ float32x4_t q0, q1, q2, q3; + + int i = 0; + for (; i + 1 < kb; i += 2) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + 4); + q2 = vld1q_f32(rhs_temp + stride); + q3 = vld1q_f32(rhs_temp + stride + 4); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + vst1q_f32(prhs_ptr + 12, q3); + + rhs_temp += stride << 1; + prhs_ptr += nr << 1; + } + + for (; i < kb; i++) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + 4); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + + rhs_temp += stride; + prhs_ptr += nr; + } + + rhs_ptr += nr; + } + break; + case 6: + for (int j = 0; j < nn; j++) + + { + const float *rhs_temp = rhs_ptr; + float32x4_t q0, q2; + float32x2_t q1, q3; + + int i = 0; + for (; i + 1 < kb; i += 2) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1_f32(rhs_temp + 4); + + q2 = vld1q_f32(rhs_temp + stride); + q3 = vld1_f32(rhs_temp + stride + 4); + vst1q_f32(prhs_ptr, q0); + vst1_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 6, q2); + vst1_f32(prhs_ptr + 10, q3); + + rhs_temp += stride << 1; + prhs_ptr += nr << 1; + } + + for (; i < kb; i++) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1_f32(rhs_temp + 4); + + vst1q_f32(prhs_ptr, q0); + vst1_f32(prhs_ptr + 4, q1); + + rhs_temp += stride; + prhs_ptr += nr; + } + + rhs_ptr += nr; + } + break; + case 4: + for (int j = 0; j < nn; j++) + + { + const float *rhs_temp = rhs_ptr; + float32x4_t q0, q1, q2, q3; + + int i = 0; + for (; i + 3 < kb; i += 4) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + stride); + q2 = vld1q_f32(rhs_temp + (stride << 1)); + q3 = vld1q_f32(rhs_temp + (stride * 3)); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + vst1q_f32(prhs_ptr + 12, q3); + + rhs_temp += stride << 2; + prhs_ptr += nr << 2; + } + for (; i + 1 < kb; i += 2) + { + q0 = vld1q_f32(rhs_temp); + q1 = vld1q_f32(rhs_temp + stride); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + + rhs_temp += stride << 1; + prhs_ptr += nr << 1; + } + for (; i < kb; i++) + { + q0 = vld1q_f32(rhs_temp); + vst1q_f32(prhs_ptr, q0); + + rhs_temp += stride; + prhs_ptr += nr; + } + + rhs_ptr += nr; + } + break; + default: + break; + } + + if (rn > 0) + { + for (int i = 0; i < kb; i++) + { + for (int j = 0; j < rn; j++) + { + prhs_ptr[j] = rhs_ptr[j]; + } + for (int j = rn; j < nr; j++) + { + prhs_ptr[j] = 0.f; + } + prhs_ptr += nr; + rhs_ptr += stride; + } + } +} + +void _pack_rowmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr) +{ + _pack_rowmajor_notrans_rhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); +} + +void _pack_rowmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr) +{ + _pack_rowmajor_notrans_lhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); +} + +static inline void _pack_rowmajor_image_subn(const int nr, const int nb, const int stride, + const float *buffer, float *prhs_ptr) +{ + const int nn = nb / nr; + const int rn = nb % nr; + + switch (nr) + { + case 24: + for (int j = 0; j < nn; j++) + { + float32x4_t q0, q1, q2, q3, q4, q5; + q0 = vld1q_f32(buffer); + q1 = vld1q_f32(buffer + 4); + q2 = vld1q_f32(buffer + 8); + q3 = vld1q_f32(buffer + 12); + q4 = vld1q_f32(buffer + 16); + q5 = vld1q_f32(buffer + 20); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + vst1q_f32(prhs_ptr + 12, q3); + vst1q_f32(prhs_ptr + 16, q4); + vst1q_f32(prhs_ptr + 20, q5); + prhs_ptr += stride; + buffer += 
nr; + } + break; + case 16: + for (int j = 0; j < nn; j++) + { + float32x4_t q0, q1, q2, q3; + q0 = vld1q_f32(buffer); + q1 = vld1q_f32(buffer + 4); + q2 = vld1q_f32(buffer + 8); + q3 = vld1q_f32(buffer + 12); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + vst1q_f32(prhs_ptr + 12, q3); + prhs_ptr += stride; + buffer += nr; + } + break; + case 12: + for (int j = 0; j < nn; j++) + { + float32x4_t q0, q1, q2; + q0 = vld1q_f32(buffer); + q1 = vld1q_f32(buffer + 4); + q2 = vld1q_f32(buffer + 8); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + vst1q_f32(prhs_ptr + 8, q2); + prhs_ptr += stride; + buffer += nr; + } + break; + case 8: + for (int j = 0; j < nn; j++) + { + float32x4_t q0, q1; + q0 = vld1q_f32(buffer); + q1 = vld1q_f32(buffer + 4); + vst1q_f32(prhs_ptr, q0); + vst1q_f32(prhs_ptr + 4, q1); + prhs_ptr += stride; + buffer += nr; + } + break; + case 6: + for (int j = 0; j < nn; j++) + { + float32x4_t q0; + float32x2_t q1; + q0 = vld1q_f32(buffer); + q1 = vld1_f32(buffer + 4); + vst1q_f32(prhs_ptr, q0); + vst1_f32(prhs_ptr + 4, q1); + prhs_ptr += stride; + buffer += nr; + } + break; + case 4: + for (int j = 0; j < nn; j++) + { + float32x4_t q0; + q0 = vld1q_f32(buffer); + vst1q_f32(prhs_ptr, q0); + prhs_ptr += stride; + buffer += nr; + } + break; + default: + break; + } + + if (rn > 0) + { + for (int j = 0; j < rn; j++) + { + prhs_ptr[j] = buffer[j]; + } + for (int j = rn; j < nr; j++) + { + prhs_ptr[j] = 0.f; + } + } +} + +void _pack_rowmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *prhs_ptr) +{ + const int w = input->w; + const int h = input->h; + const int outw = output->w; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + + const int in_row0 = n0 / outw * stride_h; + const int in_col0 = n0 % outw * stride_w; + int seg0 = outw - n0 % outw; + if (seg0 > nb) + seg0 = nb; + int rows = (nb - seg0 + outw - 1) / outw; + if (seg0) + rows++; + const int segn = (nb - seg0) % outw; + + float row_data[nb]; + + for (int i = k0; i < kb + k0; i++) + { + const int ic = i / (kernel_w * kernel_h); + const int in_row1 = ((i / kernel_w) % kernel_h) * params->dilation_h + in_row0; + const int in_col1 = i % kernel_w * params->dilation_w; + +#ifdef NCNN + const float *input_data = input->data + ic * alignSize(w * h, 16 / sizeof(float)); +#else // NCNN + const float *input_data = input->data + ic * w * h; +#endif // NCNN + float *buffer = row_data; + int in_row = in_row1 - pad_h; + + for (int out_rows = rows; out_rows; out_rows--) + { + int cols = (out_rows != 1 || segn == 0) ? 
outw : segn; + int in_col = in_col1 - pad_w; + if (out_rows == rows) + { + cols = seg0; + in_col += in_col0; + } + if ((unsigned int)in_row < (unsigned int)h) + { + for (int out_col = cols; out_col; out_col--) + { + if ((unsigned int)in_col < (unsigned int)w) + *(buffer++) = input_data[in_row * w + in_col]; + else + *(buffer++) = 0; + in_col += stride_w; + } + } + else + { + for (int out_col = cols; out_col; out_col--) + { + *(buffer++) = 0; + in_col += stride_w; + } + } + + in_row += stride_h; + } + + _pack_rowmajor_image_subn(nr, nb, nr * kb, row_data, prhs_ptr); + prhs_ptr += nr; + } +} + +void _pack_rowmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, + const int n0, convMat_t *input, convMat_t *output, + convParams_t *params, float *prhs_ptr) +{ + const int w = input->w; + const int h = input->h; + const int c = input->c; + +#ifdef NCNN + const int seg_size = alignSize(output->w * output->h, 16 / sizeof(float)); +#else // NCNN + const int seg_size = output->w * output->h; +#endif // NCNN + +#ifdef NCNN + float *data = input->data + (alignSize(w * h, 16 / sizeof(float)) * c) * (n0 / seg_size); +#else // NCNN + float *data = input->data + (w * h * c) * (n0 / seg_size); +#endif // NCNN + + int seg0 = seg_size - n0 % seg_size; + if (seg0 > nb) + seg0 = nb; + int nseg = (nb - seg0 + seg_size - 1) / seg_size; + if (seg0) + nseg++; + const int segn = (nb - seg0) % seg_size; + convMat_t _input = {w, h, c, 1, data}; + + for (int i = 0; i < nseg; i++) + { + const int _nb = (i == 0 ? seg0 : (i == nseg - 1 ? segn : seg_size)); + const int _n0 = (i == 0 ? seg_size - seg0 : 0); + + _pack_rowmajor_image_rhs(nr, _nb, kb, k0, _n0, &_input, output, params, prhs_ptr); + +#ifdef NCNN + _input.data += alignSize(w * h, 16 / sizeof(float)) * c; +#else // NCNN + _input.data += w * h * c; +#endif // NCNN + } +} + +void _unpack_rowmajor_image_res(const int mb, const int nb, const int m0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *pres_ptr) +{ + const int outw = output->w; + const int outh = output->h; + const int w = input->w; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + + const int out_row0 = n0 / w * stride_h; + const int out_col0 = n0 % w * stride_w; + int seg0 = w - n0 % w; + if (seg0 > nb) + seg0 = nb; + int rows = (nb - seg0 + w - 1) / w; + if (seg0) + rows++; + const int segn = (nb - seg0) % w; + + for (int i = m0; i < mb + m0; i++) + { + const int oc = i / (kernel_w * kernel_h); + const int out_row1 = ((i / kernel_w) % kernel_h) * params->dilation_h + out_row0; + const int out_col1 = i % kernel_w * params->dilation_w; + +#ifdef NCNN + float *output_data = output->data + oc * alignSize(outw * outh, 16 / sizeof(float)); +#else // NCNN + float *output_data = output->data + oc * outw * outh; +#endif // NCNN + int out_row = out_row1 - pad_h; + + for (int in_rows = rows; in_rows; in_rows--) + { + int cols = (in_rows != 1 || segn == 0) ? 
w : segn; + int out_col = out_col1 - pad_w; + if (in_rows == rows) + { + cols = seg0; + out_col += out_col0; + } + if ((unsigned int)out_row < (unsigned int)outh) + { + for (int in_col = cols; in_col; in_col--) + { + if ((unsigned int)out_col < (unsigned int)outw) + output_data[out_row * outw + out_col] += *pres_ptr++; + else + pres_ptr++; + out_col += stride_w; + } + } + else + { + pres_ptr += cols; + } + out_row += stride_h; + } + } +} + +// TODO:v8 & other case. +static inline void _pack_colmajor_image_rhs_sub(const int nr, const int k, const float *buffer, + float *prhs_ptr) +{ + int nk = k >> 2; + int rk = k & 0x03; + + const int _stride = k << 2; + + switch (nr) + { + case 12: + if (nk > 0) + { +#if __aarch64__ + asm volatile("0:\n" + "mov x0, %[buffer]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v11.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v8.4s, v10.4s\n" + "zip2 v30.4s, v8.4s, v10.4s\n" + "zip1 v29.4s, v9.4s, v11.4s\n" + "zip2 v31.4s, v9.4s, v11.4s\n" + "zip1 v8.4s, v28.4s, v29.4s\n" + "zip2 v9.4s, v28.4s, v29.4s\n" + "zip1 v10.4s, v30.4s, v31.4s\n" + "zip2 v11.4s, v30.4s, v31.4s\n" + + "ld1 {v12.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v13.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v14.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v15.4s}, [x0]\n" + + "zip1 v28.4s, v12.4s, v14.4s\n" + "zip2 v30.4s, v12.4s, v14.4s\n" + "zip1 v29.4s, v13.4s, v15.4s\n" + "zip2 v31.4s, v13.4s, v15.4s\n" + "zip1 v12.4s, v28.4s, v29.4s\n" + "zip2 v13.4s, v28.4s, v29.4s\n" + "zip1 v14.4s, v30.4s, v31.4s\n" + "zip2 v15.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[prhs_ptr]], #16\n" + "st1 {v8.4s}, [%[prhs_ptr]], #16\n" + "st1 {v12.4s}, [%[prhs_ptr]], #16\n" + "st1 {v5.4s}, [%[prhs_ptr]], #16\n" + "st1 {v9.4s}, [%[prhs_ptr]], #16\n" + "st1 {v13.4s}, [%[prhs_ptr]], #16\n" + "st1 {v6.4s}, [%[prhs_ptr]], #16\n" + "st1 {v10.4s}, [%[prhs_ptr]], #16\n" + "st1 {v14.4s}, [%[prhs_ptr]], #16\n" + "st1 {v7.4s}, [%[prhs_ptr]], #16\n" + "st1 {v11.4s}, [%[prhs_ptr]], #16\n" + "st1 {v15.4s}, [%[prhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[buffer]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[_stride]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[_stride]\n" + 
"vld1.f32 {d22-d23}, [r0]\n" + "add r0, r0, %[_stride]\n" + + "vzip.32 q8, q10\n" + "vzip.32 q9, q11\n" + "vzip.32 q8, q9\n" + "vzip.32 q10, q11\n" + + "vld1.f32 {d24-d25}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d26-d27}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d28-d29}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d30-d31}, [r0]\n" + + "vzip.32 q12, q14\n" + "vzip.32 q13, q15\n" + "vzip.32 q12, q13\n" + "vzip.32 q14, q15\n" + + "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" + "vst1.f32 {d16-d17}, [%[prhs_ptr]]!\n" + "vst1.f32 {d24-d25}, [%[prhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" + "vst1.f32 {d18-d19}, [%[prhs_ptr]]!\n" + "vst1.f32 {d26-d27}, [%[prhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" + "vst1.f32 {d20-d21}, [%[prhs_ptr]]!\n" + "vst1.f32 {d28-d29}, [%[prhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" + "vst1.f32 {d22-d23}, [%[prhs_ptr]]!\n" + "vst1.f32 {d30-d31}, [%[prhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + prhs_ptr[0] = buffer[0]; + prhs_ptr[1] = buffer[k]; + prhs_ptr[2] = buffer[k << 1]; + prhs_ptr[3] = buffer[3 * k]; + prhs_ptr[4] = buffer[k << 2]; + prhs_ptr[5] = buffer[5 * k]; + prhs_ptr[6] = buffer[6 * k]; + prhs_ptr[7] = buffer[7 * k]; + prhs_ptr[8] = buffer[k << 3]; + prhs_ptr[9] = buffer[9 * k]; + prhs_ptr[10] = buffer[10 * k]; + prhs_ptr[11] = buffer[11 * k]; + prhs_ptr += nr; + buffer++; + } + break; + + case 8: + if (nk > 0) + { +#if __aarch64__ + asm volatile("0:\n" + "mov x0, %[buffer]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "ld1 {v8.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v9.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v10.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v11.4s}, [x0]\n" + + "zip1 v28.4s, v8.4s, v10.4s\n" + "zip2 v30.4s, v8.4s, v10.4s\n" + "zip1 v29.4s, v9.4s, v11.4s\n" + "zip2 v31.4s, v9.4s, v11.4s\n" + "zip1 v8.4s, v28.4s, v29.4s\n" + "zip2 v9.4s, v28.4s, v29.4s\n" + "zip1 v10.4s, v30.4s, v31.4s\n" + "zip2 v11.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[prhs_ptr]], #16\n" + "st1 {v8.4s}, [%[prhs_ptr]], #16\n" + "st1 {v5.4s}, [%[prhs_ptr]], #16\n" + "st1 {v9.4s}, [%[prhs_ptr]], #16\n" + "st1 {v6.4s}, [%[prhs_ptr]], #16\n" + "st1 {v10.4s}, [%[prhs_ptr]], #16\n" + "st1 {v7.4s}, [%[prhs_ptr]], #16\n" + "st1 {v11.4s}, [%[prhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[buffer]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, 
%[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[_stride]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d20-d21}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d22-d23}, [r0]\n" + + "vzip.32 q8, q10\n" + "vzip.32 q9, q11\n" + "vzip.32 q8, q9\n" + "vzip.32 q10, q11\n" + + "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" + "vst1.f32 {d16-d17}, [%[prhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" + "vst1.f32 {d18-d19}, [%[prhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" + "vst1.f32 {d20-d21}, [%[prhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" + "vst1.f32 {d22-d23}, [%[prhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + prhs_ptr[0] = buffer[0]; + prhs_ptr[1] = buffer[k]; + prhs_ptr[2] = buffer[k << 1]; + prhs_ptr[3] = buffer[3 * k]; + prhs_ptr[4] = buffer[k << 2]; + prhs_ptr[5] = buffer[5 * k]; + prhs_ptr[6] = buffer[6 * k]; + prhs_ptr[7] = buffer[7 * k]; + prhs_ptr += nr; + buffer++; + } + break; +#if !__aarch64__ + case 6: + if (nk > 0) + { + asm volatile("0:\n" + "mov r0, %[buffer]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d16-d17}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d18-d19}, [r0]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + "vzip.32 q8, q9\n" + + "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" + "vst1.f32 {d16}, [%[prhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" + "vst1.f32 {d17}, [%[prhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" + "vst1.f32 {d18}, [%[prhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" + "vst1.f32 {d19}, [%[prhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7", "q8", "q9"); + } + + for (int j = 0; j < rk; j++) + { + prhs_ptr[0] = buffer[0]; + prhs_ptr[1] = buffer[k]; + prhs_ptr[2] = buffer[k << 1]; + prhs_ptr[3] = buffer[3 * k]; + prhs_ptr[4] = buffer[k << 2]; + prhs_ptr[5] = buffer[5 * k]; + prhs_ptr += nr; + buffer++; + } + break; +#endif // !__aarch64__ + case 4: + if (nk > 0) + { +#if __aarch64__ + asm volatile("0:\n" + "mov x0, %[buffer]\n" + + "ld1 {v4.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v5.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v6.4s}, [x0]\n" + "add x0, x0, %[_stride]\n" + "ld1 {v7.4s}, [x0]\n" + + "zip1 v28.4s, v4.4s, v6.4s\n" + "zip2 v30.4s, v4.4s, v6.4s\n" + "zip1 v29.4s, v5.4s, v7.4s\n" + "zip2 v31.4s, v5.4s, v7.4s\n" + "zip1 v4.4s, v28.4s, v29.4s\n" + "zip2 v5.4s, v28.4s, v29.4s\n" + "zip1 v6.4s, v30.4s, v31.4s\n" + "zip2 v7.4s, v30.4s, v31.4s\n" + + "st1 {v4.4s}, [%[prhs_ptr]], #16\n" + "st1 {v5.4s}, [%[prhs_ptr]], #16\n" + "st1 {v6.4s}, [%[prhs_ptr]], #16\n" + "st1 {v7.4s}, [%[prhs_ptr]], #16\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] 
"+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "x0", "v4", "v5", "v6", "v7", "v28", "v29", "v30", "v31"); +#else // __aarch64__ + asm volatile("0:\n" + "mov r0, %[buffer]\n" + + "vld1.f32 {d8-d9}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d10-d11}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d12-d13}, [r0]\n" + "add r0, r0, %[_stride]\n" + "vld1.f32 {d14-d15}, [r0]\n" + + "vzip.32 q4, q6\n" + "vzip.32 q5, q7\n" + "vzip.32 q4, q5\n" + "vzip.32 q6, q7\n" + + "vst1.f32 {d8-d9}, [%[prhs_ptr]]!\n" + "vst1.f32 {d10-d11}, [%[prhs_ptr]]!\n" + "vst1.f32 {d12-d13}, [%[prhs_ptr]]!\n" + "vst1.f32 {d14-d15}, [%[prhs_ptr]]!\n" + + "subs %[nk], %[nk], #1\n" + "add %[buffer], %[buffer], #16\n" + "bne 0b\n" + : [buffer] "+r"(buffer), [prhs_ptr] "+r"(prhs_ptr), [nk] "+r"(nk) + : [_stride] "r"(_stride) + : "cc", "memory", "r0", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + + for (int j = 0; j < rk; j++) + { + prhs_ptr[0] = buffer[0]; + prhs_ptr[1] = buffer[k]; + prhs_ptr[2] = buffer[k << 1]; + prhs_ptr[3] = buffer[3 * k]; + prhs_ptr += nr; + buffer++; + } + break; + default: + break; + } +} + +void _pack_colmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr) +{ + _pack_rowmajor_notrans_rhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); +} + +void _pack_colmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr) +{ + _pack_rowmajor_notrans_lhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); +} + +void _pack_colmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr) +{ + _pack_rowmajor_notrans_lhs(mr, mb, kb, stride, lhs_ptr, plhs_ptr); +} + +void _pack_colmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr) +{ + _pack_rowmajor_notrans_rhs(nr, nb, kb, stride, rhs_ptr, prhs_ptr); +} + +void _pack_colmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *prhs_ptr) +{ + const int w = input->w; + const int h = input->h; + const int c = input->c; + const int outw = output->w; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + const float *input_data = input->data; + + int c0 = c - k0 % c; + if (c0 > kb) + c0 = kb; + int nc = (kb - c0 + c - 1) / c; + if (c0) + nc++; + const int cn = (kb - c0) % c; + + int seg0 = outw - n0 % outw; + if (seg0 > nb) + seg0 = nb; + int rows = (nb - seg0 + outw - 1) / outw; + if (seg0) + rows++; + const int segn = (nb - seg0) % outw; + + const int in_row0 = n0 / outw * stride_h; + const int in_col0 = n0 % outw * stride_w; + + for (int i = 0; i < nc; i++) + { + const int channels = (i == 0 && c0 != 0) ? c0 : ((i == nc - 1 && cn != 0) ? cn : c); + const int c1 = (i == 0) ? k0 % c : 0; + + float tmp_data[channels * nr]; + int nindex = 0; + float *buffer = tmp_data; + float *prhs_tmp = prhs_ptr; + + const int in_row1 = (k0 / c + i) / kernel_w % kernel_h * params->dilation_h + in_row0; + const int in_col1 = (k0 / c + i) % kernel_w * params->dilation_w; + + int in_row = in_row1 - pad_h; + + for (int out_rows = rows; out_rows; out_rows--) + { + int cols = (out_rows != 1 || segn == 0) ? 
outw : segn; + int in_col = in_col1 - pad_w; + if (out_rows == rows) + { + cols = seg0; + in_col += in_col0; + } + if ((unsigned int)in_row < (unsigned int)h) + { + for (int out_col = cols; out_col; out_col--) + { + if ((unsigned int)in_col < (unsigned int)w) + { + for (int j = c1; j < c1 + channels; j++) + { + *(buffer++) = input_data[(in_row * w + in_col) * c + j]; + } + } + else + { + for (int j = 0; j < channels; j++) + { + *(buffer++) = 0; + } + } + in_col += stride_w; + + nindex++; + if (nindex == nr) + { + nindex = 0; + buffer = tmp_data; + _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); + prhs_tmp += kb * nr; + } + } + } + else + { + for (int out_col = cols; out_col; out_col--) + { + for (int j = 0; j < channels; j++) + { + *(buffer++) = 0; + } + in_col += stride_w; + + nindex++; + if (nindex == nr) + { + nindex = 0; + buffer = tmp_data; + _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); + prhs_tmp += kb * nr; + } + } + } + + in_row += stride_h; + } + + if (nindex > 0) + { + float *data = tmp_data; + for (int i = 0; i < channels; i++) + { + for (int j = 0; j < nindex; j++) + { + prhs_tmp[j] = data[j * channels]; + } + for (int j = nindex; j < nr; j++) + { + prhs_tmp[j] = 0.f; + } + prhs_tmp += nr; + data++; + } + } + + prhs_ptr += channels * nr; + } +} + +void _pack_colmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, + const int n0, convMat_t *input, convMat_t *output, + convParams_t *params, float *prhs_ptr) +{ + const int w = input->w; + const int h = input->h; + const int c = input->c; + const int outw = output->w; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + + int c0 = c - k0 % c; + if (c0 > kb) + c0 = kb; + int nc = (kb - c0 + c - 1) / c; + if (c0) + nc++; + const int cn = (kb - c0) % c; + + const int seg_size = output->w * output->h; + + const float *indata = input->data + (w * h * c) * (n0 / seg_size); + + int bseg0 = seg_size - n0 % seg_size; + if (bseg0 > nb) + bseg0 = nb; + int bnseg = (nb - bseg0 + seg_size - 1) / seg_size; + if (bseg0) + bnseg++; + const int bsegn = (nb - bseg0) % seg_size; + + for (int ll = 0; ll < nc; ll++) + { + const float *input_data = indata; + + const int channels = (ll == 0 && c0 != 0) ? c0 : ((ll == nc - 1 && cn != 0) ? cn : c); + const int c1 = (ll == 0) ? k0 % c : 0; + + int nindex = 0; + float *prhs_tmp = prhs_ptr; + float tmp_data[channels * nr]; + float *buffer = tmp_data; + + for (int i = 0; i < bnseg; i++) + { + const int _nb = + ((i == 0 && bseg0 != 0) ? bseg0 : ((i == bnseg - 1 && bsegn != 0) ? bsegn : seg_size)); + const int _n0 = (i == 0 ? n0 % seg_size : 0); + + int seg0 = outw - _n0 % outw; + if (seg0 > _nb) + seg0 = _nb; + int rows = (_nb - seg0 + outw - 1) / outw; + if (seg0) + rows++; + const int segn = (_nb - seg0) % outw; + + const int in_row0 = _n0 / outw * stride_h; + const int in_col0 = _n0 % outw * stride_w; + + const int in_row1 = (k0 / c + ll) / kernel_w % kernel_h + in_row0; + const int in_col1 = (k0 / c + ll) % kernel_w; + + int in_row = in_row1; + + for (int out_rows = rows; out_rows; out_rows--) + { + int cols = (out_rows != 1 || segn == 0) ? 
outw : segn; + int in_col = in_col1; + if (out_rows == rows) + { + cols = seg0; + in_col += in_col0; + } + if ((unsigned int)in_row < (unsigned int)h) + { + for (int out_col = cols; out_col; out_col--) + { + if ((unsigned int)in_col < (unsigned int)w) + { + for (int j = c1; j < c1 + channels; j++) + { + *(buffer++) = input_data[(in_row * w + in_col) * c + j]; + } + } + else + { + for (int j = 0; j < channels; j++) + { + *(buffer++) = 0; + } + } + in_col += stride_w; + + nindex++; + if (nindex == nr) + { + nindex = 0; + buffer = tmp_data; + _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); + prhs_tmp += kb * nr; + } + } + } + else + { + for (int out_col = cols; out_col; out_col--) + { + for (int j = 0; j < channels; j++) + { + *(buffer++) = 0; + } + in_col += stride_w; + + nindex++; + if (nindex == nr) + { + nindex = 0; + buffer = tmp_data; + _pack_colmajor_image_rhs_sub(nr, channels, tmp_data, prhs_tmp); + prhs_tmp += kb * nr; + } + } + } + + in_row += stride_h; + } + + input_data += w * h * c; + } + + if (nindex > 0) + { + float *data = tmp_data; + for (int ii = 0; ii < channels; ii++) + { + for (int jj = 0; jj < nindex; jj++) + { + prhs_tmp[jj] = data[jj * channels]; + } + for (int jj = nindex; jj < nr; jj++) + { + prhs_tmp[jj] = 0.f; + } + prhs_tmp += nr; + data++; + } + } + + prhs_ptr += channels * nr; + } +} + +void _unpack_colmajor_image_res(const int mb, const int nb, const int m0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *pres_ptr) +{ + const int w = input->w; + const int outw = output->w; + const int outh = output->h; + const int outc = output->c; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + float *output_data = output->data; + + int c0 = outc - m0 % outc; + if (c0 > mb) + c0 = mb; + int nc = (mb - c0 + outc - 1) / outc; + if (c0) + nc++; + const int cn = (mb - c0) % outc; + + int seg0 = w - n0 % w; + if (seg0 > nb) + seg0 = nb; + int rows = (nb - seg0 + w - 1) / w; + if (seg0) + rows++; + const int segn = (nb - seg0) % w; + + const int out_row0 = n0 / w * stride_h; + const int out_col0 = n0 % w * stride_w; + + for (int i = 0; i < nc; i++) + { + const int channels = (i == 0 && c0 != 0) ? c0 : ((i == nc - 1 && cn != 0) ? cn : outc); + const int c1 = (i == 0) ? m0 % outc : 0; + + float *buffer = pres_ptr; + + const int out_row1 = (m0 / outc + i) / kernel_w % kernel_h * params->dilation_h + out_row0; + const int out_col1 = (m0 / outc + i) % kernel_w * params->dilation_w; + + int out_row = out_row1 - pad_h; + + for (int in_rows = rows; in_rows; in_rows--) + { + int cols = (in_rows != 1 || segn == 0) ? 
w : segn; + int out_col = out_col1 - pad_w; + if (in_rows == rows) + { + cols = seg0; + out_col += out_col0; + } + if ((unsigned int)out_row < (unsigned int)outh) + { + for (int in_col = cols; in_col; in_col--) + { + if ((unsigned int)out_col < (unsigned int)outw) + { + for (int j = c1; j < c1 + channels; j++) + { + // Note:Data competition for multi-threads + //#pragma omp atomic //low performance + output_data[(out_row * outw + out_col) * outc + j] += *(buffer + j - c1); + } + } + buffer += mb; + out_col += stride_w; + } + } + else + { + buffer += cols * mb; + } + out_row += stride_h; + } + + pres_ptr += channels; + } +} + +void _sparse_pack_rowmajor_image(const int nb, const int k0, const int n0, convMat_t *input, + convMat_t *output, convParams_t *params, float *prhs_ptr) +{ + const int w = input->w; + const int h = input->h; + const int outw = output->w; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + + const int in_row0 = n0 / outw * stride_h; + const int in_col0 = n0 % outw * stride_w; + int seg0 = outw - n0 % outw; + if (seg0 > nb) + seg0 = nb; + int rows = (nb - seg0 + outw - 1) / outw; + if (seg0) + rows++; + const int segn = (nb - seg0) % outw; + + const int ic = k0 / (kernel_w * kernel_h); + const int in_row1 = ((k0 / kernel_w) % kernel_h) * params->dilation_h + in_row0; + const int in_col1 = k0 % kernel_w * params->dilation_w; + +#ifdef NCNN + const float *input_data = input->data + ic * alignSize(w * h, 16 / sizeof(float)); +#else // NCNN + const float *input_data = input->data + ic * w * h; +#endif // NCNN + + int in_row = in_row1 - pad_h; + + for (int out_rows = rows; out_rows; out_rows--) + { + int cols = (out_rows != 1 || segn == 0) ? outw : segn; + int in_col = in_col1 - pad_w; + if (out_rows == rows) + { + cols = seg0; + in_col += in_col0; + } + if ((unsigned int)in_row < (unsigned int)h) + { + for (int out_col = cols; out_col; out_col--) + { + if ((unsigned int)in_col < (unsigned int)w) + *(prhs_ptr++) = input_data[in_row * w + in_col]; + else + *(prhs_ptr++) = 0; + in_col += stride_w; + } + } + else + { + for (int out_col = cols; out_col; out_col--) + { + *(prhs_ptr++) = 0; + in_col += stride_w; + } + } + + in_row += stride_h; + } +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/sgemm_pack.h b/compute/ncnn/src/srcn/sgemm_pack.h new file mode 100644 index 000000000..d64843ebb --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_pack.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef __NNFW_SRCN_SGEMM_PACK_H__ +#define __NNFW_SRCN_SGEMM_PACK_H__ + +#include "ncnn/srcn/conv_type.h" + +namespace nnfw +{ +namespace srcn +{ + +void _pack_rowmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr); +void _pack_rowmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr); +void _pack_rowmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr); +void _pack_rowmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr); +void _pack_rowmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *prhs_ptr); +void _pack_rowmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, + const int n0, convMat_t *input, convMat_t *output, + convParams_t *params, float *prhs_ptr); + +void _unpack_rowmajor_image_res(const int mb, const int nb, const int m0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *pres_ptr); + +void _pack_colmajor_notrans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr); +void _pack_colmajor_notrans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr); +void _pack_colmajor_trans_lhs(const int mr, const int mb, const int kb, const int stride, + const float *lhs_ptr, float *plhs_ptr); +void _pack_colmajor_trans_rhs(const int nr, const int nb, const int kb, const int stride, + const float *rhs_ptr, float *prhs_ptr); + +void _pack_colmajor_image_rhs(const int nr, const int nb, const int kb, const int k0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *prhs_ptr); + +void _pack_colmajor_image_rhs_batch(const int nr, const int nb, const int kb, const int k0, + const int n0, convMat_t *input, convMat_t *output, + convParams_t *params, float *prhs_ptr); + +void _unpack_colmajor_image_res(const int mb, const int nb, const int m0, const int n0, + convMat_t *input, convMat_t *output, convParams_t *params, + float *pres_ptr); + +void _sparse_pack_rowmajor_image(const int nb, const int k0, const int n0, convMat_t *input, + convMat_t *output, convParams_t *params, float *prhs_ptr); + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_SGEMM_PACK_H__ diff --git a/compute/ncnn/src/srcn/sgemm_singlethread.cc b/compute/ncnn/src/srcn/sgemm_singlethread.cc new file mode 100644 index 000000000..3de3e1214 --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_singlethread.cc @@ -0,0 +1,689 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdexcept> + +#include "common.h" +#include "sgemm_kernel.h" +#include "sgemm_pack.h" +#include "sgemm_singlethread.h" + +namespace nnfw +{ +namespace srcn +{ + +void sgemm_singlethread::param_init() +{ + if (n_ >= m_) + { + shard_type_ = shardByRow; + } + else + { + shard_type_ = shardByCol; + } + +#if __aarch64__ + if (major_type_ == rowMajor) + { + if (shard_type_ == shardByRow) + { + mr_ = 8; + nr_ = 12; + } + else + { + mr_ = 12; + nr_ = 8; + } + } + else if (major_type_ == colMajor) + { + mr_ = 12; + nr_ = 8; + } +#else // __aarch64__ + if (major_type_ == rowMajor) + { + // This special case is a workaround for a bug; the cause is unknown for now. + if (ltrans_ == notrans && rtrans_ == trans) + { + mr_ = 4; + nr_ = 12; + } + else + { + mr_ = 6; + nr_ = 8; + } + } + else if (major_type_ == colMajor) + { + mr_ = 8; + nr_ = 6; + } +#endif // __aarch64__ + + int k_div = (nr_ * sizeof_RhsScalar); + int k_sub = (mr_ * nr_ * sizeof_ResScalar); + + int gen_col = GEN_COL / cache_div_; + int min_k = MAX_K / cache_div_; + + const int k_cache = MIN(divup((int)(L1_CACHE_SIZE - k_sub), (int)k_div), min_k); + bk_ = MIN(k_cache, k_); + + if (shard_type_ == shardByCol) + { + int m_sub = (bk_ * nr_ * sizeof_RhsScalar); + int m_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_); + if (L3_CACHE_SIZE) + m_div = (sizeof_LhsScalar * bk_ * 2); + int m_cache = divup((L2_CACHE_SIZE - m_sub), m_div); + bm_ = MIN(m_cache, m_); + + bn_ = MIN(gen_col, n_); + if (L3_CACHE_SIZE) + { + int n_sub = (bk_ * bm_ * sizeof_RhsScalar); + int n_cache = divup((L3_CACHE_SIZE - n_sub), (sizeof_LhsScalar * bk_ * 2)); + bn_ = MIN(n_cache, bn_); + } + } + else + { + int n_sub = (bk_ * mr_ * sizeof_RhsScalar); + int n_div = (sizeof_LhsScalar * bk_ * 2 * cache_div_); + if (L3_CACHE_SIZE) + n_div = (sizeof_LhsScalar * bk_ * 2); + int n_cache = divup((L2_CACHE_SIZE - n_sub), n_div); + bn_ = MIN(n_cache, n_); + + bm_ = MIN(gen_col, m_); + if (L3_CACHE_SIZE) + { + int m_sub = (bk_ * bn_ * sizeof_RhsScalar); + int m_cache = divup((L3_CACHE_SIZE - m_sub), (sizeof_LhsScalar * bk_ * 2)); + bm_ = MIN(m_cache, bm_); + } + } + + nm_ = divup(m_, bm_); + nn_ = divup(n_, bn_); + nk_ = divup(k_, bk_); + + rm_ = m_ % bm_; + rn_ = n_ % bn_; + rk_ = k_ % bk_; +} + +sgemm_singlethread::sgemm_singlethread(sgemmType_t major_type, sgemmTrans_t ltrans, + sgemmTrans_t rtrans, const int m, const int n, const int k, + const float *lhs_data, const float *rhs_data, + float *res_data, int cache_div) + : lhs_data_(lhs_data), rhs_data_(rhs_data), res_data_(res_data), major_type_(major_type), + ltrans_(ltrans), rtrans_(rtrans), m_(m), n_(n), k_(k), cache_div_(cache_div) +{ + param_init(); +} + +sgemm_singlethread::~sgemm_singlethread() {} + +void sgemm_singlethread::run() +{ + if (major_type_ == rowMajor) + { + if (ltrans_ == notrans && rtrans_ == notrans) + { + compute_rowmajor_nn(); + } + else if (ltrans_ == notrans && rtrans_ == trans) + { + compute_rowmajor_nt(); + } + else if (ltrans_ == trans && rtrans_ == notrans) + { + compute_rowmajor_tn(); + } + else if (ltrans_ == trans && rtrans_ == trans) + { + compute_rowmajor_tt(); + } + else + { + throw std::runtime_error{"error trans type."}; + } + } + else if (major_type_ == colMajor) + { + if (ltrans_ == notrans && rtrans_ == notrans) + { + compute_colmajor_nn(); + } + else if (ltrans_ == notrans && rtrans_ == trans) + { + compute_colmajor_nt(); + } + else if (ltrans_ == trans && rtrans_ == notrans) + { + compute_colmajor_tn(); + } + else if (ltrans_ == trans && rtrans_ == trans) + { + compute_colmajor_tt(); + } + else + { + 
throw std::runtime_error{"error trans type."}; + } + } + else + { + throw std::runtime_error{"error major type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_nn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_nt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_tn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_rowmajor_notrans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_rowmajor_tt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_rowmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_rowmajor_trans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + + _pack_rowmajor_trans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_rowmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[i * bm_ * n_ + j * bn_], l, n_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_nn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_nt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_lhs(mr_, bm, bk, m_, &lhs_data_[l * bk_ * m_ + i * bm_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_tn() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + + _pack_colmajor_notrans_rhs(nr_, bn, bk, k_, &rhs_data_[j * bn_ * k_ + l * bk_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +void sgemm_singlethread::compute_colmajor_tt() +{ + int mstride = (bm_ + mr_ - 1) / mr_ * mr_; + int nstride = (bn_ + nr_ - 1) / nr_ * nr_; + + float plhs_ptr[mstride * bk_]; + float prhs_ptr[nstride * bk_]; + + if (shard_type_ == shardByCol) + { + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? bn_ : rn_; + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + _sgemm_colmajor_macro_kernel_divnm(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else if (shard_type_ == shardByRow) + { + for (int i = 0; i < nm_; i++) + { + const int bm = (i != nm_ - 1 || rm_ == 0) ? bm_ : rm_; + + for (int l = 0; l < nk_; l++) + { + const int bk = (l != nk_ - 1 || rk_ == 0) ? bk_ : rk_; + + _pack_colmajor_trans_lhs(mr_, bm, bk, k_, &lhs_data_[i * bm_ * k_ + l * bk_], plhs_ptr); + + for (int j = 0; j < nn_; j++) + { + const int bn = (j != nn_ - 1 || rn_ == 0) ? 
bn_ : rn_; + + _pack_colmajor_trans_rhs(nr_, bn, bk, n_, &rhs_data_[l * bk_ * n_ + j * bn_], prhs_ptr); + + _sgemm_colmajor_macro_kernel_divmn(mr_, nr_, bm, bn, bk, plhs_ptr, prhs_ptr, + &res_data_[j * bn_ * m_ + i * bm_], l, m_, bk); + } + } + } + } + else + { + throw std::runtime_error{"error shard type."}; + } +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/sgemm_singlethread.h b/compute/ncnn/src/srcn/sgemm_singlethread.h new file mode 100644 index 000000000..47954e028 --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_singlethread.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_SGEMM_SINGLETHREAD_H__ +#define __NNFW_SRCN_SGEMM_SINGLETHREAD_H__ + +#include "common.h" + +namespace nnfw +{ +namespace srcn +{ + +typedef enum { rowMajor = 0, colMajor } sgemmType_t; + +typedef enum { trans = 0, notrans } sgemmTrans_t; + +class sgemm_singlethread +{ +public: + sgemm_singlethread(sgemmType_t major_type, sgemmTrans_t ltrans, sgemmTrans_t rtrans, const int m, + const int n, const int k, const float *lhs_data, const float *rhs_data, + float *res_data, int cache_div); + ~sgemm_singlethread(); + + void run(); + +private: + void param_init(); + + void compute_rowmajor_nn(); + void compute_rowmajor_nt(); + void compute_rowmajor_tn(); + void compute_rowmajor_tt(); + + void compute_colmajor_nn(); + void compute_colmajor_nt(); + void compute_colmajor_tn(); + void compute_colmajor_tt(); + + const float *lhs_data_; + const float *rhs_data_; + float *res_data_; + + sgemmType_t major_type_; + sgemmTrans_t ltrans_; + sgemmTrans_t rtrans_; + + int m_; + int n_; + int k_; + + int bm_; + int bn_; + int bk_; + + int rm_; + int rn_; + int rk_; + + int nm_; + int nn_; + int nk_; + + int mr_; + int nr_; + + shardType_t shard_type_; + int cache_div_; +}; + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_SGEMM_SINGLETHREAD_H__ diff --git a/compute/ncnn/src/srcn/sgemm_test.cc b/compute/ncnn/src/srcn/sgemm_test.cc new file mode 100644 index 000000000..1b10970bb --- /dev/null +++ b/compute/ncnn/src/srcn/sgemm_test.cc @@ -0,0 +1,1883 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <stdio.h> +#include <stdlib.h> +#include <sys/time.h> +#include <unistd.h> + +#include "ncnn/srcn/conv_type.h" +#include "srcn/srcn_conv.h" +//#include "srcn_sgemm.h" +#include "conv_sgemm_singlethread.h" +#include "conv_sgemm_multithreads.h" +//#include "conv_sgemm_batch.h" +#include "sgemm_singlethread.h" +#include "conv_winograd.h" +#include "winograd.h" + +//#include "conv_gpu.h" +//#include "convolutiondepthwise_3x3.h" + +namespace nnfw +{ +namespace srcn +{ + +static void direct_conv_rowmajor(convMat_t *input, convMat_t *output, convMat_t *filter, + convParams_t *params) +{ + const int w = input->w; + const int h = input->h; + const int inch = input->c; + const int outw = output->w; + const int outh = output->h; + const int outch = output->c; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + const int dilation_w = params->dilation_w; + const int dilation_h = params->dilation_h; + const float *input_data = input->data; + const float *filter_data = filter->data; + float *output_data = output->data; + + for (int out_c = 0; out_c < outch; out_c++) + { + for (int out_row = 0; out_row < outh; out_row++) + { + for (int out_col = 0; out_col < outw; out_col++) + { + const int in_col0 = (out_col * stride_w) - pad_w; + const int in_row0 = (out_row * stride_h) - pad_h; + float sum = 0.f; + for (int in_c = 0; in_c < inch; in_c++) + { + for (int filter_y = 0; filter_y < kernel_h; filter_y++) + { + for (int filter_x = 0; filter_x < kernel_w; filter_x++) + { + const int in_col = in_col0 + filter_x * dilation_w; + const int in_row = in_row0 + filter_y * dilation_h; + + if (((unsigned int)in_col < (unsigned int)w) && + ((unsigned int)in_row < (unsigned int)h)) + { + float input_value = input_data[(in_c * h + in_row) * w + in_col]; + float filter_value = + filter_data[((out_c * inch + in_c) * kernel_h + filter_y) * kernel_w + + filter_x]; + sum += (input_value * filter_value); + } + } + } + } + output_data[(out_c * outh + out_row) * outw + out_col] = sum; + } + } + } +} + +static void direct_deconv_rowmajor(convMat_t *input, convMat_t *output, convMat_t *filter, + convParams_t *params) +{ + const int w = input->w; + const int h = input->h; + const int inch = input->c; + const int outw = output->w; + const int outh = output->h; + const int outch = output->c; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + const int dilation_w = params->dilation_w; + const int dilation_h = params->dilation_h; + const float *input_data = input->data; + const float *filter_data = filter->data; + float *output_data = output->data; + + for (int i = 0; i < outw * outh * outch; i++) + { + output_data[i] = 0; + } + + for (int in_c = 0; in_c < inch; in_c++) + { + for (int in_row = 0; in_row < h; in_row++) + { + for (int in_col = 0; in_col < w; in_col++) + { + const int out_col0 = (in_col * stride_w) - pad_w; + const int out_row0 = (in_row * stride_h) - pad_h; + float in_value = input_data[(in_c * h + in_row) * w + in_col]; + for (int out_c = 0; out_c < outch; out_c++) + { + for (int filter_y = 0; filter_y < kernel_h; filter_y++) + { + for (int filter_x = 0; filter_x < kernel_w; filter_x++) + { + const int out_col = out_col0 + filter_x * 
dilation_w; + const int out_row = out_row0 + filter_y * dilation_h; + + if (((unsigned int)out_col < (unsigned int)outw) && + ((unsigned int)out_row < (unsigned int)outh)) + { + float filter_value = + filter_data[((in_c * outch + out_c) * kernel_h + filter_y) * kernel_w + + filter_x]; + output_data[(out_c * outh + out_row) * outw + out_col] += filter_value * in_value; + } + } + } + } + } + } + } +} + +static void direct_sgemm_rowmajor(int Atrans, int Btrans, int m, int n, int k, float *A, float *B, + float *C) +{ + float *aa, *bb; + + if (Atrans == trans) + { + aa = (float *)malloc(m * k * sizeof(float)); + if (!aa) + return; + + for (int i = 0; i < k; i++) + { + for (int j = 0; j < m; j++) + { + aa[j * k + i] = A[i * m + j]; + } + } + } + else + { + aa = A; + } + + if (Btrans == trans) + { + bb = (float *)malloc(n * k * sizeof(float)); + if (!bb) + return; + + for (int i = 0; i < n; i++) + { + for (int j = 0; j < k; j++) + { + bb[j * n + i] = B[i * k + j]; + } + } + } + else + { + bb = B; + } + + for (int i = 0; i < m; i++) + { + for (int j = 0; j < n; j++) + { + float res = 0.f; + for (int l = 0; l < k; l++) + { + res += aa[i * k + l] * bb[l * n + j]; + } + C[i * n + j] = res; + } + } +} + +/*static void direct_sgemm_kernel(const int k, const int lhs_stride, const int rhs_stride, const int +res_stride, + const float *lhs_ptr, const float *rhs_ptr, float *res_ptr) +{ + int lstride = lhs_stride << 2; + int rstride = rhs_stride << 2; + int estride = res_stride << 2; + int rstep = rstride << 2; + + int nk = (k >> 2) - 1; + + __asm __volatile ( + "movi v16.4s, #0x0\n" + "movi v17.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + "movi v21.4s, #0x0\n" + "movi v22.4s, #0x0\n" + "movi v23.4s, #0x0\n" + "movi v24.4s, #0x0\n" + "movi v25.4s, #0x0\n" + "movi v26.4s, #0x0\n" + "movi v27.4s, #0x0\n" + "movi v28.4s, #0x0\n" + "movi v29.4s, #0x0\n" + "movi v30.4s, #0x0\n" + "movi v31.4s, #0x0\n" + + "mov x0, %[lhs_ptr]\n" + "add %[lhs_ptr], %[lhs_ptr], #16\n" + "ld1 {v0.4s}, [x0]\n" + "add x0, x0, %[lstride]\n" + "ld1 {v1.4s}, [x0]\n" + "add x0, x0, %[lstride]\n" + "ld1 {v2.4s}, [x0]\n" + "add x0, x0, %[lstride]\n" + "ld1 {v3.4s}, [x0]\n" + "add x0, x0, %[lstride]\n" + + "mov x1, %[rhs_ptr]\n" + "add %[rhs_ptr], %[rhs_ptr], %[rstep]\n" + "ld1 {v8.4s, v9.4s}, [x1]\n" + "add x1, x1, %[rstride]\n" + "ld1 {v10.4s, v11.4s}, [x1]\n" + "add x1, x1, %[rstride]\n" + + "1:\n" + "fmla v16.4s, v8.4s, v0.s[0]\n" + "fmla v17.4s, v9.4s, v0.s[0]\n" + "fmla v16.4s, v10.4s, v0.s[1]\n" + "fmla v17.4s, v11.4s, v0.s[1]\n" + "fmla v18.4s, v8.4s, v1.s[0]\n" + "fmla v19.4s, v9.4s, v1.s[0]\n" + "fmla v18.4s, v10.4s, v1.s[1]\n" + "fmla v19.4s, v11.4s, v1.s[1]\n" + "ld1 {v12.4s, v13.4s}, [x1]\n" + "fmla v20.4s, v8.4s, v2.s[0]\n" + "add x1, x1, %[rstride]\n" + "fmla v21.4s, v9.4s, v2.s[0]\n" + "ld1 {v14.4s, v15.4s}, [x1]\n" + "fmla v20.4s, v10.4s, v2.s[1]\n" + "add x1, x1, %[rstride]\n" + "fmla v21.4s, v11.4s, v2.s[1]\n" + "fmla v22.4s, v8.4s, v3.s[0]\n" + "fmla v23.4s, v9.4s, v3.s[0]\n" + "fmla v22.4s, v10.4s, v3.s[1]\n" + "fmla v23.4s, v11.4s, v3.s[1]\n" + + "ld1 {v4.4s}, [x0]\n" + "fmla v16.4s, v12.4s, v0.s[2]\n" + "add x0, x0, %[lstride]\n" + "fmla v17.4s, v13.4s, v0.s[2]\n" + "ld1 {v5.4s}, [x0]\n" + "fmla v16.4s, v14.4s, v0.s[3]\n" + "add x0, x0, %[lstride]\n" + "fmla v17.4s, v15.4s, v0.s[3]\n" + "ld1 {v6.4s}, [x0]\n" + "fmla v18.4s, v12.4s, v1.s[2]\n" + "add x0, x0, %[lstride]\n" + "fmla v19.4s, v13.4s, v1.s[2]\n" + "ld1 {v7.4s}, [x0]\n" + "fmla v18.4s, v14.4s, v1.s[3]\n" + "add x0, x0, 
%[lstride]\n" + "fmla v19.4s, v15.4s, v1.s[3]\n" + "fmla v20.4s, v12.4s, v2.s[2]\n" + "fmla v21.4s, v13.4s, v2.s[2]\n" + "fmla v20.4s, v14.4s, v2.s[3]\n" + "fmla v21.4s, v15.4s, v2.s[3]\n" + "fmla v22.4s, v12.4s, v3.s[2]\n" + "fmla v23.4s, v13.4s, v3.s[2]\n" + "fmla v22.4s, v14.4s, v3.s[3]\n" + "fmla v23.4s, v15.4s, v3.s[3]\n" + + "mov x0, %[lhs_ptr]\n" + "add %[lhs_ptr], %[lhs_ptr], #16\n" + + "fmla v24.4s, v8.4s, v4.s[0]\n" + "fmla v25.4s, v9.4s, v4.s[0]\n" + "ld1 {v0.4s}, [x0]\n" + "fmla v24.4s, v10.4s, v4.s[1]\n" + "add x0, x0, %[lstride]\n" + "fmla v25.4s, v11.4s, v4.s[1]\n" + "ld1 {v1.4s}, [x0]\n" + "fmla v26.4s, v8.4s, v5.s[0]\n" + "add x0, x0, %[lstride]\n" + "fmla v27.4s, v9.4s, v5.s[0]\n" + "ld1 {v2.4s}, [x0]\n" + "fmla v26.4s, v10.4s, v5.s[1]\n" + "add x0, x0, %[lstride]\n" + "fmla v27.4s, v11.4s, v5.s[1]\n" + "ld1 {v3.4s}, [x0]\n" + "fmla v28.4s, v8.4s, v6.s[0]\n" + "add x0, x0, %[lstride]\n" + "fmla v29.4s, v9.4s, v6.s[0]\n" + "fmla v28.4s, v10.4s, v6.s[1]\n" + "fmla v29.4s, v11.4s, v6.s[1]\n" + "fmla v30.4s, v8.4s, v7.s[0]\n" + "fmla v31.4s, v9.4s, v7.s[0]\n" + "fmla v30.4s, v10.4s, v7.s[1]\n" + "fmla v31.4s, v11.4s, v7.s[1]\n" + + "mov x1, %[rhs_ptr]\n" + "add %[rhs_ptr], %[rhs_ptr], %[rstep]\n" + + "fmla v24.4s, v12.4s, v4.s[2]\n" + "fmla v25.4s, v13.4s, v4.s[2]\n" + "ld1 {v8.4s, v9.4s}, [x1]\n" + "fmla v24.4s, v14.4s, v4.s[3]\n" + "add x1, x1, %[rstride]\n" + "fmla v25.4s, v15.4s, v4.s[3]\n" + "ld1 {v10.4s, v11.4s}, [x1]\n" + "fmla v26.4s, v12.4s, v5.s[2]\n" + "add x1, x1, %[rstride]\n" + "fmla v27.4s, v13.4s, v5.s[2]\n" + "fmla v26.4s, v14.4s, v5.s[3]\n" + "fmla v27.4s, v15.4s, v5.s[3]\n" + "fmla v28.4s, v12.4s, v6.s[2]\n" + "fmla v29.4s, v13.4s, v6.s[2]\n" + "fmla v28.4s, v14.4s, v6.s[3]\n" + "fmla v29.4s, v15.4s, v6.s[3]\n" + "fmla v30.4s, v12.4s, v7.s[2]\n" + "fmla v31.4s, v13.4s, v7.s[2]\n" + "subs %w[nk], %w[nk], #1\n" + "fmla v30.4s, v14.4s, v7.s[3]\n" + "fmla v31.4s, v15.4s, v7.s[3]\n" + "bne 1b\n" + + "fmla v16.4s, v8.4s, v0.s[0]\n" + "fmla v17.4s, v9.4s, v0.s[0]\n" + "fmla v16.4s, v10.4s, v0.s[1]\n" + "fmla v17.4s, v11.4s, v0.s[1]\n" + "fmla v18.4s, v8.4s, v1.s[0]\n" + "fmla v19.4s, v9.4s, v1.s[0]\n" + "fmla v18.4s, v10.4s, v1.s[1]\n" + "fmla v19.4s, v11.4s, v1.s[1]\n" + "ld1 {v12.4s, v13.4s}, [x1]\n" + "fmla v20.4s, v8.4s, v2.s[0]\n" + "add x1, x1, %[rstride]\n" + "fmla v21.4s, v9.4s, v2.s[0]\n" + "ld1 {v14.4s, v15.4s}, [x1]\n" + "fmla v20.4s, v10.4s, v2.s[1]\n" + "add x1, x1, %[rstride]\n" + "fmla v21.4s, v11.4s, v2.s[1]\n" + "fmla v22.4s, v8.4s, v3.s[0]\n" + "fmla v23.4s, v9.4s, v3.s[0]\n" + "fmla v22.4s, v10.4s, v3.s[1]\n" + "fmla v23.4s, v11.4s, v3.s[1]\n" + + "ld1 {v4.4s}, [x0]\n" + "fmla v16.4s, v12.4s, v0.s[2]\n" + "add x0, x0, %[lstride]\n" + "fmla v17.4s, v13.4s, v0.s[2]\n" + "ld1 {v5.4s}, [x0]\n" + "fmla v16.4s, v14.4s, v0.s[3]\n" + "add x0, x0, %[lstride]\n" + "fmla v17.4s, v15.4s, v0.s[3]\n" + "ld1 {v6.4s}, [x0]\n" + "fmla v18.4s, v12.4s, v1.s[2]\n" + "add x0, x0, %[lstride]\n" + "fmla v19.4s, v13.4s, v1.s[2]\n" + "ld1 {v7.4s}, [x0]\n" + "fmla v18.4s, v14.4s, v1.s[3]\n" + "add x0, x0, %[lstride]\n" + "fmla v19.4s, v15.4s, v1.s[3]\n" + "fmla v20.4s, v12.4s, v2.s[2]\n" + "fmla v21.4s, v13.4s, v2.s[2]\n" + "fmla v20.4s, v14.4s, v2.s[3]\n" + "fmla v21.4s, v15.4s, v2.s[3]\n" + "fmla v22.4s, v12.4s, v3.s[2]\n" + "fmla v23.4s, v13.4s, v3.s[2]\n" + "fmla v22.4s, v14.4s, v3.s[3]\n" + "fmla v23.4s, v15.4s, v3.s[3]\n" + + "mov x0, %[res_ptr]\n" + "fmla v24.4s, v8.4s, v4.s[0]\n" + "fmla v25.4s, v9.4s, v4.s[0]\n" + "st1 {v16.4s, v17.4s}, [x0]\n" + "add x0, x0, 
%[estride]\n" + "fmla v24.4s, v10.4s, v4.s[1]\n" + "fmla v25.4s, v11.4s, v4.s[1]\n" + "st1 {v18.4s, v19.4s}, [x0]\n" + "add x0, x0, %[estride]\n" + "fmla v26.4s, v8.4s, v5.s[0]\n" + "fmla v27.4s, v9.4s, v5.s[0]\n" + "st1 {v20.4s, v21.4s}, [x0]\n" + "add x0, x0, %[estride]\n" + "fmla v26.4s, v10.4s, v5.s[1]\n" + "fmla v27.4s, v11.4s, v5.s[1]\n" + "st1 {v22.4s, v23.4s}, [x0]\n" + "add x0, x0, %[estride]\n" + "fmla v28.4s, v8.4s, v6.s[0]\n" + "fmla v29.4s, v9.4s, v6.s[0]\n" + "fmla v28.4s, v10.4s, v6.s[1]\n" + "fmla v29.4s, v11.4s, v6.s[1]\n" + "fmla v30.4s, v8.4s, v7.s[0]\n" + "fmla v31.4s, v9.4s, v7.s[0]\n" + "fmla v30.4s, v10.4s, v7.s[1]\n" + "fmla v31.4s, v11.4s, v7.s[1]\n" + + "fmla v24.4s, v12.4s, v4.s[2]\n" + "fmla v25.4s, v13.4s, v4.s[2]\n" + "fmla v24.4s, v14.4s, v4.s[3]\n" + "fmla v25.4s, v15.4s, v4.s[3]\n" + "fmla v26.4s, v12.4s, v5.s[2]\n" + "fmla v27.4s, v13.4s, v5.s[2]\n" + "st1 {v24.4s, v25.4s}, [x0]\n" + "add x0, x0, %[estride]\n" + "fmla v26.4s, v14.4s, v5.s[3]\n" + "fmla v27.4s, v15.4s, v5.s[3]\n" + "fmla v28.4s, v12.4s, v6.s[2]\n" + "fmla v29.4s, v13.4s, v6.s[2]\n" + "st1 {v26.4s, v27.4s}, [x0]\n" + "add x0, x0, %[estride]\n" + "fmla v28.4s, v14.4s, v6.s[3]\n" + "fmla v29.4s, v15.4s, v6.s[3]\n" + "fmla v30.4s, v12.4s, v7.s[2]\n" + "fmla v31.4s, v13.4s, v7.s[2]\n" + "st1 {v28.4s, v29.4s}, [x0]\n" + "add x0, x0, %[estride]\n" + "fmla v30.4s, v14.4s, v7.s[3]\n" + "fmla v31.4s, v15.4s, v7.s[3]\n" + "st1 {v30.4s, v31.4s}, [x0]\n" + :[lhs_ptr] "+r" (lhs_ptr), [rhs_ptr] "+r" (rhs_ptr), [res_ptr] "+r" (res_ptr), + [nk] "+r" (nk) + : [lstride] "r" (lstride), [rstride] "r" (rstride), [estride] "r" (estride), [rstep] "r" +(rstep) + : "x0", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); +}*/ + +static void direct_conv_colmajor(convMat_t *input, convMat_t *output, convMat_t *filter, + convParams_t *params) +{ + const int w = input->w; + const int h = input->h; + const int inch = input->c; + const int outw = output->w; + const int outh = output->h; + const int outch = output->c; + const int kernel_w = params->kernel_w; + const int kernel_h = params->kernel_h; + const int stride_w = params->stride_w; + const int stride_h = params->stride_h; + const int pad_w = params->pad_w; + const int pad_h = params->pad_h; + const int dilation_w = params->dilation_w; + const int dilation_h = params->dilation_h; + const float *input_data = input->data; + const float *filter_data = filter->data; + float *output_data = output->data; + + for (int out_row = 0; out_row < outh; out_row++) + { + for (int out_col = 0; out_col < outw; out_col++) + { + const int in_col0 = (out_col * stride_w) - pad_w; + const int in_row0 = (out_row * stride_h) - pad_h; + + for (int out_c = 0; out_c < outch; out_c++) + { + float sum = 0.f; + for (int filter_y = 0; filter_y < kernel_h; filter_y++) + { + for (int filter_x = 0; filter_x < kernel_w; filter_x++) + { + const int in_col = in_col0 + filter_x * dilation_w; + const int in_row = in_row0 + filter_y * dilation_h; + + if (((unsigned int)in_col < (unsigned int)w) && + ((unsigned int)in_row < (unsigned int)h)) + { + for (int in_c = 0; in_c < inch; in_c++) + { + float input_value = input_data[(in_row * w + in_col) * inch + in_c]; + float filter_value = + filter_data[((filter_y * kernel_w + filter_x) * inch + in_c) * outch + out_c]; + sum += (input_value * filter_value); + } + } + } + } + 
output_data[(out_row * outw + out_col) * outch + out_c] = sum; + } + } + } +} + +static void direct_sgemm_colmajor(int Atrans, int Btrans, int m, int n, int k, float *A, float *B, + float *C) +{ + float *aa, *bb; + + if (Atrans) + { + aa = (float *)malloc(m * k * sizeof(float)); + if (!aa) + return; + + for (int i = 0; i < k; i++) + { + for (int j = 0; j < m; j++) + { + aa[i * m + j] = A[j * k + i]; + } + } + } + else + { + aa = A; + } + + if (Btrans) + { + bb = (float *)malloc(n * k * sizeof(float)); + if (!bb) + return; + + for (int i = 0; i < n; i++) + { + for (int j = 0; j < k; j++) + { + bb[i * k + j] = B[j * n + i]; + } + } + } + else + { + bb = B; + } + + for (int i = 0; i < m; i++) + { + for (int j = 0; j < n; j++) + { + float res = 0.f; + for (int l = 0; l < k; l++) + { + res += bb[j * k + l] * aa[l * m + i]; + } + C[j * m + i] = res; + } + } +} + +#if 0 +static int test_sgemm(int m, int n, int k, int loops) +{ + struct timeval start, end; + float total_time = 0.f; + + const int mb = 180; + const int nb = 1440; + const int kb = 512; + + const int mr = 4; + const int nr = 12; + +#if 0 + const int pm = (m + mr - 1) / mr * mr; + const int pn = (n + nr - 1) / nr * nr; + const int pk = k; +#else + const int pm = (mb + mr - 1) / mr * mr; + const int pn = (nb + nr - 1) / nr * nr; + const int pk = kb; +#endif + const int nm = (m + mb - 1) / mb; + const int nn = (n + nb - 1) / nb; + const int nk = (k + kb - 1) / kb; + + const int rm = m % mb; + const int rn = n % nb; + const int rk = k % kb; + + float *A = (float *)malloc(m * k * sizeof(float)); + if(!A) return 0; + + for(int i = 0 ; i < m * k; i++) + { + A[i] = 0.001 + i * 0.000001; + } + + float *B = (float *)malloc(k * n * sizeof(float)); + if(!B) return 0; + + for(int i = 0 ; i < n * k; i++) + { + B[i] = 0.001 - i * 0.000001; + } + + float *C = (float *)malloc(m * n * sizeof(float)); + if(!C) return 0; + +#if 0 + float *PA = (float *)malloc(pm * pk * sizeof(float)); + if(!PA) return 0; + + float *PB = (float *)malloc(pk * pn * sizeof(float)); + if(!PB) return 0; +#else + float PA[pm * pk]; + float PB[pk * pn]; +#endif + + for(int nloop = 0; nloop < loops; nloop++) + + { + gettimeofday(&start, NULL); + + //pack_rowmajor_notrans_lhs(mr, m, k, k, A, PA); + //pack_rowmajor_notrans_rhs(nr, n, k, n, B, PB); +#if 1 + for (int j = 0; j < nn; j++) + { + const int _nb = (j != nn - 1 || rn == 0) ? nb : rn; + for (int l = 0; l < nk; l++) + { + const int _kb = (l != nk - 1 || rk == 0) ? kb : rk; + pack_rowmajor_notrans_rhs(nr, _nb, _kb, 1, n, &B[l * kb * n + j * nb], PB); + for(int i = 0; i < nm; i++) + { + const int _mb = (i != nm - 1 || rm == 0) ? mb : rm; + pack_rowmajor_notrans_lhs(mr, _mb, _kb, 1, k, &A[i * mb * k + l * kb], PA); + sgemm_rowmajor_macro_kernel_divnm(mr, nr, _mb, _nb, _kb, PA, PB, &C[i * mb * n + j * nb], l, n, _kb); + //sgemm_rowmajor_macro_kernel_divnm(mr, nr, _mb, _nb, _kb, &PA[i * mb * k + l * kb], &PB[l * kb * pn + j * nb], &C[i * mb * n + j * nb], l, n, pk); + } + } + } +#else + for (int j = 0; j < nm; j++) + { + const int _mb = (j != nm - 1 || rm == 0) ? mb : rm; + for (int l = 0; l < nk; l++) + { + const int _kb = (l != nk - 1 || rk == 0) ? kb : rk; + pack_rowmajor_notrans_lhs(mr, _mb, _kb, 1, k, &A[j * mb * k + l * kb], PA); + for(int i = 0; i < nn; i++) + { + const int _nb = (i != nn - 1 || rn == 0) ? 
nb : rn; + pack_rowmajor_notrans_rhs(nr, _nb, _kb, 1, n, &B[l * kb * n + i * nb], PB); + sgemm_rowmajor_macro_kernel_divmn(mr, nr, _mb, _nb, _kb, PA, PB, &C[j * mb * n + i * nb], l, n, _kb); + //sgemm_rowmajor_macro_kernel_divmn(mr, nr, _mb, _nb, _kb, &PA[i * mb * k + l * kb], &PB[l * kb * pn + j * nb], &C[i * mb * n + j * nb], l, n, pk); + } + } + } +#endif + gettimeofday(&end, NULL); + total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec))/1000; + } + + int div = m * n < 16 ? m * n : 16; + int num = m * n > 64 ? 64 : m * n; + + float *c_ptr = &C[0]; + for(int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if((i + 1) % div == 0) printf("\n"); + } + + printf("\n"); + + c_ptr = &C[m * n - num]; + for(int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if((i + 1) % div == 0) printf("\n"); + } + + printf("\n"); + + long long total_size = (long long)m *n * k * 2; + printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops , total_size, (double)total_size/(total_time / loops)/1000000); + + free(A); + free(B); + free(C); + + //free(PA); + //free(PB); + +} +#endif + +static int test_sgemm(int m, int n, int k, int type, int loops) +{ + struct timeval start, end; + float total_time = 0.f; + + // printf("1.\n"); + + float *A = (float *)malloc(m * k * sizeof(float)); + if (!A) + return 0; + + for (int i = 0; i < m * k; i++) + { + A[i] = 0.001 + i * 0.001; // i * 0.000001; + } + + float *B = (float *)malloc(k * n * sizeof(float)); + if (!B) + return 0; + + for (int i = 0; i < n * k; i++) + { + B[i] = 0.001 - i * 0.001; // - i * 0.000001; + } + + float *C = (float *)malloc(m * n * sizeof(float)); + if (!C) + return 0; + + for (int nloop = 0; nloop < loops; nloop++) + + { + gettimeofday(&start, NULL); + + if (type == 0) + { + // direct_sgemm_rowmajor(notrans, notrans, m, n, k, A, B, C); + direct_sgemm_colmajor(notrans, notrans, m, n, k, A, B, C); + } + + else if (type == 1) + { + class sgemm_singlethread my_gemm(colMajor, notrans, notrans, m, n, k, A, B, C, 1); + my_gemm.run(); + } + + /*else if(type == 2) + { + for(int i = 0; i < m / 8; i++) + { + for(int j = 0; j < n / 8; j++) + { + direct_sgemm_kernel(k, k, n, n, A + i * 8 * k, B + j * 8, C + i * 8 * n + j * 8); + } + } + }*/ + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + + int div = m * n < 16 ? m * n : 16; + int num = m * n > 64 ? 
64 : m * n; + + float *c_ptr = &C[0]; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + c_ptr = &C[m * n - num]; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + long long total_size = (long long)m * n * k * 2; + printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, + total_size, (double)total_size / (total_time / loops) / 1000000); + + free(A); + free(B); + free(C); + + return 0; +} + +void weight_tensorflow2caffe(float *out, float *in, int H, int W, int C, int N) +{ // HWCN ---> NCHW + for (int h = 0; h < H; ++h) + { + for (int w = 0; w < W; ++w) + { + for (int c = 0; c < C; ++c) + { + for (int n = 0; n < N; ++n) + { + int index_in = h * W * C * N + w * C * N + c * N + n; + int index_out = n * C * H * W + c * H * W + h * W + w; + // printf("%3d <--- %3d\n", index_out, index_in); + out[index_out] = in[index_in]; + } + } + } + } +} + +void trans_weight2winograd(const convMat_t &_kernel, float **winograd_weight) +{ + const double *G; + const int kernel_size = _kernel.h; + const int channels = _kernel.c; + const int num_output = _kernel.n; + + int tile_h_in_, tile_w_in_; + int M, N; + + /*Step 1: transfer weight to winograd domain*/ + if (kernel_size == 3) + { + M = winograd_para_3x3s1::M; + N = winograd_para_3x3s1::N; + G = winograd_para_3x3s1::getG(); + } + else + { + M = winograd_para_5x5s1::M; + N = winograd_para_5x5s1::N; + G = winograd_para_5x5s1::getG(); + } + + tile_h_in_ = tile_w_in_ = M; + + float *winograd_g = new float[M * M * N * N]; + if (NULL == winograd_g) + return; + kronecker_product(winograd_g, G, G, M, N, M, N); + + *winograd_weight = new float[tile_h_in_ * tile_w_in_ * channels * num_output]; + + if (NULL == *winograd_weight) + return; + + float *weight_data_tran = new float[_kernel.h * _kernel.w * _kernel.c * _kernel.n]; + if (NULL == weight_data_tran) + return; + weight_tensorflow2caffe(weight_data_tran, _kernel.data, kernel_size, kernel_size, channels, + num_output); + + class sgemm_singlethread sgemm(rowMajor, notrans, trans, tile_h_in_ * tile_w_in_, + channels * num_output, kernel_size * kernel_size, winograd_g, + weight_data_tran, *winograd_weight, 1); + + sgemm.run(); + + delete[] weight_data_tran; + + /*With winograd, original weight data is useless.*/ + delete[] winograd_g; +} + +static int test_conv(const int w, const int h, const int kernel_size, const int stride, + const int inch, const int outch, const int padding, const int conv_type, + const int thread_num, const int loops) +{ + struct timeval start, end; + float total_time = 0.f; + + struct timeval start1, end1; + float total_time1 = 0.f; + + const int dilation = 1; + + const int kernel_dilation = dilation * (kernel_size - 1) + 1; + + convMat_t input; + convMat_t output; + convMat_t filter; + convParams_t params; + + int pad_l, pad_r, pad_t, pad_b; + if (padding) + { + int pad_w = kernel_dilation + (w - 1) / stride * stride - w; + int pad_h = kernel_dilation + (h - 1) / stride * stride - h; + pad_l = pad_w / 2; + pad_r = pad_w - pad_l; + pad_t = pad_h / 2; + pad_b = pad_h - pad_t; + } + else + { + pad_l = pad_r = pad_t = pad_b = 0; + } + + input.w = w; + input.h = h; + input.c = inch; + input.n = 1; +#ifdef NCNN + input.data = + (float *)malloc(alignSize(input.w * input.h, 16 / sizeof(float)) * input.c * sizeof(float)); +#else + input.data = (float *)malloc(input.w * input.h * input.c * sizeof(float)); 
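+ // Without NCNN, channel planes are packed back to back as w * h floats; with NCNN (above), each plane is padded to a 16-byte (4-float) boundary via alignSize, presumably to match ncnn's aligned Mat layout, so per-channel indexing elsewhere uses the ic * alignSize(w * h, 16 / sizeof(float)) stride.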
+#endif + + if (!input.data) + return 0; + + output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; + output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; + output.c = outch; + output.n = 1; +#ifdef NCNN + output.data = (float *)malloc(alignSize(output.w * output.h, 16 / sizeof(float)) * output.c * + sizeof(float)); +#else + output.data = (float *)malloc(output.w * output.h * output.c * sizeof(float)); +#endif + + if (!output.data) + return 0; + + for (int i = 0; i < output.w * output.h * output.c; i++) + { + output.data[i] = 0; + } + + filter.w = kernel_size; + filter.h = kernel_size; + filter.c = inch; + filter.n = outch; + filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); + if (!filter.data) + return 0; + + for (int i = 0; i < input.w * input.h * input.c; i++) + { + input.data[i] = 0.001 + i * 0.000001; + } + +#if 1 + for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) + { + filter.data[i] = 0.001 - i * 0.000001; + } +#else + for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) + { + if ((i + 1) % 15 == 0) + filter.data[i] = 0.001 - i * 0.000001; + else + filter.data[i] = 0; + } +#endif + params.kernel_w = kernel_size; + params.kernel_h = kernel_size; + params.stride_w = stride; + params.stride_h = stride; + params.padding = padding; + params.pad_w = pad_l; + params.pad_h = pad_t; + params.dilation_w = dilation; + params.dilation_h = dilation; + + const int m = output.c; + const int n = output.w * output.h; + const int k = params.kernel_h * params.kernel_w * input.c; + + // ocl_context_t context; + size_t local_min[2]; + /** + if(conv_type == 14 || conv_type == 15 || conv_type == 6) + { + if(init_gpu(&context) < 0) return -1; + //if(conv_type ==14 || conv_type == 5) sgemm_ocltune(&context, m, n, (k < 1024 ? 
k : + 1024), local_min); + //else if(conv_type == 6) + { + if(kernel_size == 3) directconv_3x3S1_tune(&context, &input, &filter, &output, + local_min); + else if(kernel_size == 1) directconv_1x1S1_tune(&context, &input, &filter, &output, + local_min); + } + //local_min[0] = 1; local_min[1] = 1; + } + **/ + if (conv_type == 0) + { + for (int nloop = 0; nloop < loops; nloop++) + { + gettimeofday(&start, NULL); + + direct_conv_rowmajor(&input, &output, &filter, ¶ms); + // direct_conv_colmajor(&input, &output, &filter, ¶ms); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + } + else if (conv_type == 1) + { + for (int nloop = 0; nloop < loops; nloop++) + { + // printf("nloop = %d, thread_num = %d\n", nloop, thread_num); + // class srcn_sgemm my_gemm(input, filter, output, params, thread_num, col_major); + gettimeofday(&start, NULL); + + /*if(thread_num == 1) + { + class conv_sgemm_singlethread my_gemm(input, filter, output, params, col_major); + my_gemm.run(); + } + else + { + class conv_sgemm_multithreads my_gemm(input, filter, output, params, thread_num, + col_major); + my_gemm.run(); + }*/ + + srcn_convolution2D(input, filter, output, params, NULL, thread_num, row_major); + + // printf("sync\n"); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + } + else if (conv_type == 2) + { + float *winograd_weight; + + // trans_weight2winograd(filter, &winograd_weight); + + winogradParams_t wparams = {params.kernel_w, + params.kernel_h, + params.stride_w, + params.stride_h, + params.dilation_w, + params.dilation_h, + 1, + w, + h, + input.c, + output.c, + thread_num, + col_major, + filter.data}; + winograd_weight = trans_weight2winograd(wparams); + + for (int nloop = 0; nloop < loops; nloop++) + { + gettimeofday(&start, NULL); + + // class conv_winograd my_sgemm(input, output, params, col_major, winograd_weight, thread_num, + // w * h, n); + // my_sgemm.run(); + + srcn_convolution2D(input, filter, output, params, winograd_weight, thread_num, row_major); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + } + else if (conv_type == 3) + { + void *sparse_weight = trans_weight2sparse(filter); + + for (int nloop = 0; nloop < loops; nloop++) + { + gettimeofday(&start, NULL); + + srcn_sparse_convolution2D(input, output, params, sparse_weight, thread_num, row_major); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + + sparse_release(outch, sparse_weight); + } /** +else if(conv_type == 4) +{ +#if 0 + cl_int err; + convlib::load_opencl("./libmali.so"); + const int mpad = (m + 4 - 1) / 4 * 4; + const int npad = (n + 4 - 1) / 4 * 4; + cl_mem lhs_gpu = convlib::clCreateBuffer(context.context, CL_MEM_READ_WRITE | +CL_MEM_ALLOC_HOST_PTR, mpad * k * sizeof(float), NULL, &err); + if(err != CL_SUCCESS) + { + printf("err = %d@%s:%d\n", err, __FUNCTION__, __LINE__); + return -1; + } + + cl_image_format rhs_format = {CL_RGBA, CL_FLOAT}; + cl_image_desc desc = + { + CL_MEM_OBJECT_IMAGE2D, + (size_t)npad / 4, + (size_t)k, + 0, 0, + 0, + 0, 0, 0, 0 + }; + cl_mem rhs_gpu = convlib::clCreateImage(context.context, CL_MEM_READ_ONLY | +CL_MEM_ALLOC_HOST_PTR, &rhs_format, &desc, NULL, &err); + if(err != CL_SUCCESS) + { + printf("err = %d@%s:%d\n", err, __FUNCTION__, 
__LINE__); + return -1; + } + + cl_mem rhs_gpu = convlib::clCreateBuffer(context.context, CL_MEM_READ_WRITE | +CL_MEM_ALLOC_HOST_PTR, npad * k * sizeof(float), NULL, &err); + if(err != CL_SUCCESS) + { + printf("err = %d@%s:%d\n", err, __FUNCTION__, __LINE__); + return -1;; + } + + cl_mem res_gpu = convlib::clCreateBuffer(context.context, CL_MEM_READ_WRITE | +CL_MEM_ALLOC_HOST_PTR, mpad * npad * sizeof(float), NULL, &err); + if(err != CL_SUCCESS) + { + printf("err = %d@%s:%d\n", err, __FUNCTION__, __LINE__); + return -1; + } +#endif + for(int nloop = 0; nloop < loops + 1; nloop++) + { + gettimeofday(&start, NULL); + + //cl_mem _res_gpu = conv2D_gpu_sgemm(&context, &input, &filter, &output, ¶ms, local_min, +lhs_gpu, rhs_gpu, res_gpu); + + //get_result_gpu(&context, output.data + gpu_data_off, _res_gpu, m, n); + srcn_convolution2D_gpu(input, filter, output, params, row_major); + + gettimeofday(&end, NULL); + + if(nloop > 0) total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 ++ start.tv_usec))/1000; + } +} +else if(conv_type == 5) +{ + + for(int nloop = 0; nloop < loops + 1; nloop++) + { + gettimeofday(&start, NULL); + + //cl_mem res_gpu = conv2D_gpu_sgemm(&context, &input, &filter, &output, ¶ms, local_min); + + //clFlush(context.cmdQueue); + gettimeofday(&start1, NULL); + #if 1 + srcn_convolution2D(input, filter, output, params, NULL, thread_num, row_major + + #endif + //usleep(80 * 1000); + gettimeofday(&end1, NULL); + total_time1 += ((end1.tv_sec * 1000000 + end1.tv_usec) - (start1.tv_sec * 1000000 + +start1.tv_usec))/1000; + + //get_result_gpu(&context, output.data + gpu_data_off, res_gpu, m, n); + + srcn_convolution2D_dpu(input, filter, output, params, row_major); + + gettimeofday(&end, NULL); + if(nloop > 0) total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 ++ start.tv_usec))/1000; + } +} +else if(conv_type == 6) +{ + for(int nloop = 0; nloop < loops; nloop++) + { + gettimeofday(&start, NULL); + + if(kernel_size == 3 && stride == 1 && padding == 0) + { + conv2D_gpu_directconv_3x3S1(&context, &input, &filter, &output, ¶ms, local_min); + } + else if(kernel_size == 1 && stride == 1 && padding == 0) + { + conv2D_gpu_directconv_1x1S1(&context, &input, &filter, &output, ¶ms, local_min); + } + + gettimeofday(&end, NULL); + total_time += ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + +start.tv_usec))/1000; + } +}**/ + + int div = m * n < 16 ? m * n : 16; + int num = m * n > 64 ? 
64 : m * n; + + if (conv_type < 4) + printf("[CPU RESULT]\n"); + else if (conv_type == 4) + printf("[GPU RESULT]\n"); + else if (conv_type == 5) + printf("[DPU RESULT]\n"); + float *c_ptr = output.data; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + c_ptr = &output.data[m * n - num]; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + long long total_size = (long long)m * n * k * 2; + printf( + "AVER Time consuming: %.2fms, CPU Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", + total_time / loops, total_time1 / loops, total_size, + (double)total_size / (total_time / loops) / 1000000); + + free(input.data); + free(output.data); + free(filter.data); + + return 0; +} + +static int test_deconv(const int w, const int h, const int kernel_size, const int stride, + const int inch, const int outch, const int padding, const int conv_type, + const int thread_num, const int loops) +{ + struct timeval start, end; + float total_time = 0.f; + + const int dilation = 1; + + const int kernel_dilation = dilation * (kernel_size - 1) + 1; + + convMat_t input; + convMat_t output; + convMat_t filter; + convParams_t params; + + int pad_l, pad_r, pad_t, pad_b; + if (padding) + { + int pad_w = kernel_dilation - 1; + int pad_h = kernel_dilation - 1; + pad_l = pad_w / 2; + pad_r = pad_w - pad_l; + pad_t = pad_h / 2; + pad_b = pad_h - pad_t; + } + else + { + pad_l = pad_r = pad_t = pad_b = 0; + } + + input.w = w; + input.h = h; + input.c = inch; + input.data = (float *)malloc(input.w * input.h * input.c * sizeof(float)); + if (!input.data) + return 0; + + // output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; + // output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; + output.w = stride * (w - 1) + kernel_dilation - (pad_l + pad_r); + output.h = stride * (h - 1) + kernel_dilation - (pad_t + pad_b); + output.c = outch; + output.data = (float *)malloc(output.w * output.h * output.c * sizeof(float)); + if (!output.data) + return 0; + + filter.w = kernel_size; + filter.h = kernel_size; + filter.c = outch; + filter.n = inch; + filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); + if (!filter.data) + return 0; + + for (int i = 0; i < input.w * input.h * input.c; i++) + { + input.data[i] = 0.001 + i * 0.000001; + } + + for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) + { + filter.data[i] = 0.001 - i * 0.000001; + } + + params.kernel_w = kernel_size; + params.kernel_h = kernel_size; + params.stride_w = stride; + params.stride_h = stride; + params.padding = padding; + params.pad_w = pad_l; + params.pad_h = pad_t; + params.dilation_w = dilation; + params.dilation_h = dilation; + + const int m = params.kernel_h * params.kernel_w * output.c; + const int n = input.w * input.h; + const int k = input.c; + + if (conv_type == 0) + { + for (int nloop = 0; nloop < loops; nloop++) + + { + gettimeofday(&start, NULL); + + direct_deconv_rowmajor(&input, &output, &filter, ¶ms); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + } + else if (conv_type == 1) + { + for (int nloop = 0; nloop < loops; nloop++) + + { + gettimeofday(&start, NULL); + + for (int i = 0; i < output.w * output.h * output.c; i++) + { + output.data[i] = 0; + } + + srcn_deconvolution2D(input, filter, output, params, thread_num, 
row_major); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + } + + const int output_size = output.w * output.h * output.c; + + int div = output_size < 16 ? output_size : 16; + int num = output_size > 64 ? 64 : output_size; + + float *c_ptr = output.data; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + c_ptr = &output.data[output_size - num]; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + long long total_size = (long long)m * n * k * 2; + printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, + total_size, (double)total_size / (total_time / loops) / 1000000); + + free(input.data); + free(output.data); + free(filter.data); + + return 0; +} + +static int test_batch_conv(const int batch, const int w, const int h, const int kernel_size, + const int stride, const int inch, const int outch, const int padding, + const int conv_type, const int thread_num, const int loops) +{ + struct timeval start, end; + float total_time = 0.f; + + const int dilation = 1; + + const int kernel_dilation = dilation * (kernel_size - 1) + 1; + + convMat_t input; + convMat_t output; + convMat_t filter; + convParams_t params; + + int pad_l, pad_r, pad_t, pad_b; + if (padding) + { + int pad_w = kernel_dilation + (w - 1) / stride * stride - w; + int pad_h = kernel_dilation + (h - 1) / stride * stride - h; + pad_l = pad_w / 2; + pad_r = pad_w - pad_l; + pad_t = pad_h / 2; + pad_b = pad_h - pad_t; + } + else + { + pad_l = pad_r = pad_t = pad_b = 0; + } + + input.w = w; + input.h = h; + input.c = inch; + input.n = batch; + input.data = (float *)malloc(input.n * input.w * input.h * input.c * sizeof(float)); + if (!input.data) + return 0; + + output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; + output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; + output.c = outch; + output.n = batch; + output.data = (float *)malloc(output.n * output.w * output.h * output.c * sizeof(float)); + if (!output.data) + return 0; + + filter.w = kernel_size; + filter.h = kernel_size; + filter.c = inch; + filter.n = outch; + filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); + if (!filter.data) + return 0; + + for (int i = 0; i < input.w * input.h * input.c * input.n; i++) + { + input.data[i] = 0.001 + i * 0.000001; + } + + for (int i = 0; i < filter.w * filter.h * filter.c * filter.n; i++) + { + filter.data[i] = 0.001 - i * 0.000001; + } + + params.kernel_w = kernel_size; + params.kernel_h = kernel_size; + params.stride_w = stride; + params.stride_h = stride; + params.padding = padding; + params.pad_w = pad_l; + params.pad_h = pad_t; + params.dilation_w = dilation; + params.dilation_h = dilation; + + const int m = output.c; + const int n = output.w * output.h; + const int k = params.kernel_h * params.kernel_w * input.c; + + if (conv_type == 1) + { + for (int nloop = 0; nloop < loops; nloop++) + + { + // printf("nloop = %d, thread_num = %d\n", nloop, thread_num); + // class srcn_sgemm my_gemm(input, filter, output, params, thread_num, col_major); + + gettimeofday(&start, NULL); + + srcn_batch_convolution2D(input, filter, output, params, NULL, thread_num, col_major); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 
1000; + } + } + else if (conv_type == 2) + { + float *winograd_weight; + + // trans_weight2winograd(filter, &winograd_weight); + + winogradParams_t wparams = {params.kernel_w, + params.kernel_h, + params.stride_w, + params.stride_h, + params.dilation_w, + params.dilation_h, + input.n, + w, + h, + input.c, + output.c, + thread_num, + col_major, + filter.data}; + winograd_weight = trans_weight2winograd(wparams); + + for (int nloop = 0; nloop < loops; nloop++) + + { + gettimeofday(&start, NULL); + + srcn_batch_convolution2D(input, filter, output, params, winograd_weight, thread_num, + col_major); + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + } + } + + int div = m * n < 16 ? m * n : 16; + int num = m * n > 64 ? 64 : m * n; + + float *c_ptr = output.data; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + c_ptr = &output.data[m * n * batch - num]; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + long long total_size = (long long)batch * m * n * k * 2; + printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, + total_size, (double)total_size / (total_time / loops) / 1000000); + + free(input.data); + free(output.data); + free(filter.data); + + return 0; +} + +static int test_depthwise_conv(const int w, const int h, const int kernel_size, const int stride, + const int inch, const int outch, const int padding, + const int conv_type, const int thread_num, const int loops) +{ + if (outch != inch) + return -1; + struct timeval start, end; + float total_time = 0.f; + + const int dilation = 1; + + const int kernel_dilation = dilation * (kernel_size - 1) + 1; + + convMat_t input; + convMat_t output; + convMat_t filter; + convMat_t bias; + convParams_t params; + + int pad_l, pad_r, pad_t, pad_b; + if (padding) + { + int pad_w = kernel_dilation + (w - 1) / stride * stride - w; + int pad_h = kernel_dilation + (h - 1) / stride * stride - h; + pad_l = pad_w / 2; + pad_r = pad_w - pad_l; + pad_t = pad_h / 2; + pad_b = pad_h - pad_t; + } + else + { + pad_l = pad_r = pad_t = pad_b = 0; + } + + input.w = w; + input.h = h; + input.c = inch; + input.n = 1; +#ifdef NCNN + input.data = + (float *)malloc(alignSize(input.w * input.h, 16 / sizeof(float)) * input.c * sizeof(float)); +#else + input.data = (float *)malloc(input.w * input.h * input.c * sizeof(float)); +#endif + if (!input.data) + return 0; + + output.w = (w + pad_l + pad_r - kernel_dilation) / stride + 1; + output.h = (h + pad_t + pad_b - kernel_dilation) / stride + 1; + output.c = outch; + output.n = 1; + +#ifdef NCNN + output.data = (float *)malloc(alignSize(output.w * output.h, 16 / sizeof(float)) * output.c * + sizeof(float)); +#else + output.data = (float *)malloc(output.w * output.h * output.c * sizeof(float)); +#endif + const int gpu_data_off = output.w * output.h * output.c; + if (!output.data) + return 0; + + for (int i = 0; i < output.w * output.h * output.c; i++) + { + output.data[i] = 1.f; + } + + filter.w = kernel_size; + filter.h = kernel_size; + filter.c = 1; + filter.n = outch; + filter.data = (float *)malloc(filter.w * filter.h * filter.c * filter.n * sizeof(float)); + if (!filter.data) + return 0; + + for (int i = 0; i < input.w * input.h * input.c; i++) + { + input.data[i] = 0.001 + i * 0.000001; + } + + for (int i = 0; i < filter.w * filter.h * 
filter.c * filter.n; i++) + { + filter.data[i] = 0.001 - i * 0.000001; + } + + bias.w = outch; + bias.data = (float *)malloc(bias.w * sizeof(float)); + if (!bias.data) + return 0; + for (int i = 0; i < bias.w; i++) + { + bias.data[i] = 0.f; + } + + params.kernel_w = kernel_size; + params.kernel_h = kernel_size; + params.stride_w = stride; + params.stride_h = stride; + params.padding = padding; + params.pad_w = pad_l; + params.pad_h = pad_t; + params.dilation_w = dilation; + params.dilation_h = dilation; + + const int m = output.c; + const int n = output.w * output.h; + const int k = params.kernel_h * params.kernel_w * input.c; + + // ocl_context_t context; + size_t local_min[2] = {4, 4}; + /** + if(conv_type == 1) + { + if(init_gpu(&context) < 0) return -1; + depthwise_conv_3x3S1_tune(&context, &input, &filter, &output, local_min); + }**/ + + gettimeofday(&start, NULL); + if (conv_type == 0) + srcn_depthwise_conv(input, filter, output, bias, params, 4, + row_major); // convdw3x3s1_neon(input, output, filter, filter); + // else if(conv_type == 1) depthwise_conv_gpu3x3S1(&context, &input, &filter, &output, ¶ms, + // local_min); + else if (conv_type == 2) + { + for (int i = 0; i < input.c; i++) + { + convMat_t _input; + convMat_t _output; + convMat_t _filter; + convParams_t _params = params; + + _input.w = input.w; + _input.h = input.h; + _input.c = 1; + _input.n = 1; +#ifdef NCNN + _input.data = input.data + i * alignSize(input.w * input.h, 16 / sizeof(float)); +#else + _input.data = input.data + i * input.w * input.h; +#endif + + _output.w = output.w; + _output.h = output.h; + _output.c = 1; + _output.n = 1; +#ifdef NCNN + _output.data = output.data + i * alignSize(output.w * output.h, 16 / sizeof(float)); +#else + _output.data = output.data + i * output.w * output.h; +#endif + _filter.w = filter.w; + _filter.h = filter.h; + _filter.c = 1; // filter.c; + _filter.n = 1; // filter.n; + _filter.data = filter.data + i * 9; + + srcn_convolution2D(_input, _filter, _output, _params, NULL, 1, row_major); + // direct_conv_rowmajor(&_input, &_output, &_filter, &_params); + } + } + + gettimeofday(&end, NULL); + total_time += + ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)) / 1000; + + int div = m * n < 16 ? m * n : 16; + int num = m * n > 64 ? 
64 : m * n; + + if (conv_type == 0) + printf("[CPU RESULT]\n"); + else if (conv_type == 1) + printf("[GPU RESULT]\n"); + float *c_ptr = output.data; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + c_ptr = &output.data[m * n - num]; + for (int i = 0; i < num; i++) + { + printf("%f ", c_ptr[i]); + if ((i + 1) % div == 0) + printf("\n"); + } + + printf("\n"); + + long long total_size = (long long)m * n * k * 2; + printf("AVER Time consuming: %.2fms, total size: %lld, (GFLOP: %.2f)\n", total_time / loops, + total_size, (double)total_size / (total_time / loops) / 1000000); + + free(input.data); + free(output.data); + free(filter.data); + free(bias.data); + + return 0; +} + +//#define TEST_SGEMM +#define TEST_CONV +//#define TEST_DECONV +//#define TEST_BATCH_CONV +//#define TEST_DEPTHWISE_CONV + +int main(int argc, char **argv) +{ +#ifdef TEST_SGEMM + if (argc < 6) + return 0; + + const int m = atoi(argv[1]); + const int n = atoi(argv[2]); + const int k = atoi(argv[3]); + const int type = atoi(argv[4]); + const int loops = atoi(argv[5]); + + test_sgemm(m, n, k, type, loops); +#elif (defined TEST_CONV) + if (argc < 10) + return 0; + const int w = atoi(argv[1]); + const int h = atoi(argv[2]); + const int kernel_size = atoi(argv[3]); + const int stride = atoi(argv[4]); + const int outch = atoi(argv[5]); + const int inch = atoi(argv[6]); + const int padding = atoi(argv[7]); + const int conv_type = atoi(argv[8]); + const int thread_num = atoi(argv[9]); + int loops = 1; + if (argc > 10) + loops = atoi(argv[10]); + test_conv(w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, loops); +#elif (defined TEST_DECONV) + if (argc < 10) + return 0; + const int w = atoi(argv[1]); + const int h = atoi(argv[2]); + const int kernel_size = atoi(argv[3]); + const int stride = atoi(argv[4]); + const int outch = atoi(argv[5]); + const int inch = atoi(argv[6]); + const int padding = atoi(argv[7]); + const int conv_type = atoi(argv[8]); + const int thread_num = atoi(argv[9]); + int loops = 1; + if (argc > 10) + loops = atoi(argv[10]); + test_deconv(w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, loops); +#elif (defined TEST_BATCH_CONV) + if (argc < 11) + return 0; + const int batch = atoi(argv[1]); + const int w = atoi(argv[2]); + const int h = atoi(argv[3]); + const int kernel_size = atoi(argv[4]); + const int stride = atoi(argv[5]); + const int outch = atoi(argv[6]); + const int inch = atoi(argv[7]); + const int padding = atoi(argv[8]); + const int conv_type = atoi(argv[9]); + const int thread_num = atoi(argv[10]); + int loops = 1; + if (argc > 11) + loops = atoi(argv[11]); + test_batch_conv(batch, w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, + loops); +#elif (defined TEST_DEPTHWISE_CONV) + if (argc < 10) + return 0; + const int w = atoi(argv[1]); + const int h = atoi(argv[2]); + const int kernel_size = atoi(argv[3]); + const int stride = atoi(argv[4]); + const int outch = atoi(argv[5]); + const int inch = atoi(argv[6]); + const int padding = atoi(argv[7]); + const int conv_type = atoi(argv[8]); + const int thread_num = atoi(argv[9]); + int loops = 1; + if (argc > 10) + loops = atoi(argv[10]); + test_depthwise_conv(w, h, kernel_size, stride, inch, outch, padding, conv_type, thread_num, + loops); +#endif + + return 0; +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/srcn_conv.cc b/compute/ncnn/src/srcn/srcn_conv.cc new file mode 100644 
index 000000000..bb8e4f13e --- /dev/null +++ b/compute/ncnn/src/srcn/srcn_conv.cc @@ -0,0 +1,614 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifdef _OPENMP +#include <omp.h> +#endif + +#include "ncnn/srcn/conv_type.h" +#include "common.h" +#include "sgemm_singlethread.h" +#include "conv_sgemm_singlethread.h" +#include "conv_sgemm_multithreads.h" +#include "conv_winograd.h" +#include "direct_conv_colmajor.h" +#include "winograd.h" + +#include "deconv_sgemm_multithreads.h" +#include "conv_sparse.h" +#include "conv_winograd_batch.h" + +namespace nnfw +{ +namespace srcn +{ + +static inline void weight_transfer(float *out, float *in, int H, int W, int C, int N) +{ + // HWCN ---> NCHW + for (int h = 0; h < H; ++h) + { + for (int w = 0; w < W; ++w) + { + for (int c = 0; c < C; ++c) + { + for (int n = 0; n < N; ++n) + { + int index_in = h * W * C * N + w * C * N + c * N + n; + int index_out = n * C * H * W + c * H * W + h * W + w; + out[index_out] = in[index_in]; + } + } + } + } +} + +int check_winograd(winogradParams_t ¶ms) +{ + int winograd_flag = + ((params.kernel_w == params.kernel_h) && (params.stride_w == params.stride_h) && + (params.kernel_w == 3 || params.kernel_w == 5) && (params.stride_w == 1) && + (params.dilation_w == 1) && (params.dilation_h == 1)); + + int winograd_channel_cond = 64 * 64; + int winograd_image_cond = 10 * 10; + +#ifdef TIZEN + if (params.num_threads > 1) + { + winograd_channel_cond = 128 * 128; + winograd_image_cond = 20 * 20; + } +#endif // TIZEN + + winograd_flag &= (params.inch * params.outch >= winograd_channel_cond); + + if (params.w > 0 && params.h > 0 && params.batch == 1) + { + winograd_flag &= (params.w * params.h >= winograd_image_cond); + } + + return winograd_flag; +} + +float *trans_weight2winograd(winogradParams_t ¶ms, unsigned int *size = NULL) +{ + int M, N; + const double *G; + + float *winograd_weight; + + int winograd_channel_cond = 64 * 64; + int winograd_image_cond = 10 * 10; + +#ifdef TIZEN + if (params.num_threads > 1) + { + winograd_channel_cond = 128 * 128; + // int winograd_image_cond = 20 * 20; + } +#endif // TIZEN + + int winograd_flag = + ((params.kernel_w == params.kernel_h) && (params.stride_w == params.stride_h) && + (params.kernel_w == 3 || params.kernel_w == 5) && (params.stride_w == 1) && + (params.dilation_w == 1) && (params.dilation_h == 1)); + if (!winograd_flag) + return NULL; + + winograd_flag = (params.inch * params.outch >= winograd_channel_cond); + + if (!winograd_flag) + return NULL; + + if (params.w > 0 && params.h > 0 && params.batch == 1) + { + winograd_flag &= (params.w * params.h >= winograd_image_cond); + if (!winograd_flag) + return NULL; + } + + const int kernel_size = params.kernel_w; + const int inch = params.inch; + const int outch = params.outch; + float *weight_data = params.weight_data; + + /*Step 1: transfer weight to winograd domain*/ + if (kernel_size == 3) + { + if (params.w == 4 && params.batch > 1) + { + M 
= winograd_para_3x3s1_2::M; + N = winograd_para_3x3s1_2::N; + G = winograd_para_3x3s1_2::getG(); + } + else + { + M = winograd_para_3x3s1::M; + N = winograd_para_3x3s1::N; + G = winograd_para_3x3s1::getG(); + } + } + else + { + M = winograd_para_5x5s1::M; + N = winograd_para_5x5s1::N; + G = winograd_para_5x5s1::getG(); + } + + int tile_h_in_, tile_w_in_; + tile_h_in_ = tile_w_in_ = M; + + if (size) + *size = tile_h_in_ * tile_w_in_ * inch * outch; + + winograd_weight = new float[tile_h_in_ * tile_w_in_ * inch * outch]; + if (!winograd_weight) + return NULL; + + float *winograd_g = new float[M * M * N * N]; + if (!winograd_g) + { + delete[] winograd_weight; + return NULL; + } + + kronecker_product(winograd_g, G, G, M, N, M, N); + + if (params.conv_type == col_major) + { + weight_data = new float[kernel_size * kernel_size * inch * outch]; + if (!weight_data) + { + delete[] winograd_weight; + delete[] winograd_g; + return NULL; + } + weight_transfer(weight_data, params.weight_data, kernel_size, kernel_size, inch, outch); + } + + class sgemm_singlethread sgemm(rowMajor, notrans, trans, tile_h_in_ * tile_w_in_, inch * outch, + kernel_size * kernel_size, winograd_g, weight_data, + winograd_weight, 1); + + sgemm.run(); + + if (params.conv_type == col_major) + delete[] weight_data; + + delete[] winograd_g; + + return winograd_weight; +} + +void winograd_release(float *winograd_weight) +{ + if (winograd_weight) + delete[] winograd_weight; +} + +void srcn_convolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, const float *winograd_weight, int num_threads, + convType_t conv_type) +{ + const int outw = out_mat.w; + const int outh = out_mat.h; + const int inch = in_mat.c; + const int outch = out_mat.c; + + int winograd_flag = + ((in_param.kernel_w == in_param.kernel_h) && (in_param.stride_w == in_param.stride_h) && + (in_param.kernel_w == 3 || in_param.kernel_w == 5) && (in_param.stride_w == 1) && + (winograd_weight) && (in_param.dilation_w == 1) && (in_param.dilation_h == 1)); + + int direct_flag = ((conv_type == col_major) && (in_param.stride_w == in_param.stride_h) && + (in_param.dilation_w == 1) && (in_param.dilation_h == 1)); + + int winograd_image_cond = 10 * 10; + int winograd_channel_cond = 64 * 64; + int direct_image_cond = 4 * 4; + int direct_channel_cond = 16 * 16; + +#ifdef TIZEN + if (num_threads > 1) + { + winograd_image_cond = 20 * 20; + winograd_channel_cond = 128 * 128; + } +#endif + + winograd_flag &= + ((outw * outh >= winograd_image_cond) && (inch * outch >= winograd_channel_cond)); + direct_flag &= ((outw * outh <= direct_image_cond) || (inch * outch <= direct_channel_cond)); + + if (num_threads == 1) + { + if (winograd_flag) + { + class conv_winograd conv(in_mat, out_mat, in_param, conv_type, winograd_weight, num_threads, + in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + } + else if (direct_flag) + { + direct_conv_colmajor(in_mat, out_mat, weights_mat, in_param, num_threads); + } + else + { + class conv_sgemm_singlethread conv(in_mat, weights_mat, out_mat, in_param, conv_type); + conv.run(); + } + } + else if (num_threads > 1) + { + if (winograd_flag) + { + const int npart = num_threads > 4 ? 
4 : num_threads; + + omp_set_num_threads(npart); + + if (conv_type == col_major) + { + if (outch < 512) + { + const int _H = (outh + npart - 1) / npart; + + if (_H < in_param.pad_h) + { + class conv_winograd conv(in_mat, out_mat, in_param, conv_type, winograd_weight, 1, + in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + return; + } + + // const int ih = (_H - 1) * in_param.stride_w + in_param.kernel_w; + // const int oh = _H; + const int nh = (outh + _H - 1) / _H; + int rh = outh % _H; + if (rh == 0) + rh = _H; + +#pragma omp parallel for + for (int i = 0; i < nh; i++) + { + int pad_h_part = 0; + convMat_t in_part; + convMat_t out_part; + const int oh = (i != nh - 1 || rh == 0) ? _H : rh; + const int ih = (oh - 1) * in_param.stride_w + in_param.kernel_w; + + in_part.w = in_mat.w; + in_part.c = inch; + out_part.w = outw; + out_part.c = outch; + in_part.h = ih; + out_part.h = oh; + + int bottom_offset = i * _H - in_param.pad_h; + if (bottom_offset < 0) + { + bottom_offset = 0; + pad_h_part = in_param.pad_h; + } + in_part.data = in_mat.data + bottom_offset * in_mat.w * inch * in_param.stride_w; + if (ih + bottom_offset > in_mat.h) + { + in_part.h = in_mat.h - bottom_offset; + } + + out_part.data = out_mat.data + i * _H * outw * outch; + + convParams_t params = { + in_param.kernel_w, in_param.kernel_h, in_param.stride_w, in_param.stride_h, 1, 1, + in_param.padding, in_param.pad_w, pad_h_part}; + + class conv_winograd conv(in_part, out_part, params, conv_type, winograd_weight, + num_threads, in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + } + } + else + { + const int _OUTC = (outch + npart - 1) / npart; + + const int nc = (outch + _OUTC - 1) / _OUTC; + int rc = out_mat.c % _OUTC; + if (rc == 0) + rc = _OUTC; + +#pragma omp parallel for + for (int i = 0; i < nc; i++) + { + const float *weight_part; + convMat_t out_part; + + const int oc = (i != nc - 1 || rc == 0) ? _OUTC : rc; + + out_part.w = outw; + out_part.h = outh; + out_part.c = oc; + out_part.data = out_mat.data + i * _OUTC; + weight_part = winograd_weight + i * _OUTC * inch; + class conv_winograd conv(in_mat, out_part, in_param, conv_type, weight_part, + num_threads, in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + } + } + } + else if (conv_type == row_major) + { +#ifdef TIZEN + if (outch < 512) +#else // TIZEN + if (outh >= 20) +#endif // TIZEN + { + const int _H = (outh + npart - 1) / npart; + + if (_H < in_param.pad_h) + { + class conv_winograd conv(in_mat, out_mat, in_param, conv_type, winograd_weight, 1, + in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + return; + } + + // const int ih = (_H - 1) * in_param.stride_w + in_param.kernel_w; + // const int oh = _H; + const int nh = (outh + _H - 1) / _H; + int rh = outh % _H; + if (rh == 0) + rh = _H; + +#pragma omp parallel for + for (int i = 0; i < nh; i++) + { + int pad_h_part = 0; + convMat_t in_part; + convMat_t out_part; + const int oh = (i != nh - 1 || rh == 0) ? 
_H : rh; + const int ih = (oh - 1) * in_param.stride_w + in_param.kernel_w; + + in_part.w = in_mat.w; + in_part.c = inch; + out_part.w = outw; + out_part.c = outch; + in_part.h = ih; + out_part.h = oh; + + int bottom_offset = i * _H - in_param.pad_h; + if (bottom_offset < 0) + { + bottom_offset = 0; + pad_h_part = in_param.pad_h; + } + in_part.data = in_mat.data + bottom_offset * in_mat.w * in_param.stride_w; + if (ih + bottom_offset > in_mat.h) + { + in_part.h = in_mat.h - bottom_offset; + } + + out_part.data = out_mat.data + i * _H * outw; + + convParams_t params = { + in_param.kernel_w, in_param.kernel_h, in_param.stride_w, 1, 1, + in_param.stride_h, in_param.padding, in_param.pad_w, pad_h_part}; + + class conv_winograd conv(in_part, out_part, params, conv_type, winograd_weight, + num_threads, in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + } + } + else + { + const int _OUTC = (outch + npart - 1) / npart; + + const int nc = (outch + _OUTC - 1) / _OUTC; + int rc = out_mat.c % _OUTC; + if (rc == 0) + rc = _OUTC; + +#pragma omp parallel for + for (int i = 0; i < nc; i++) + { + const float *weight_part; + convMat_t out_part; + + const int oc = (i != nc - 1 || rc == 0) ? _OUTC : rc; + + out_part.w = outw; + out_part.h = outh; + out_part.c = oc; + out_part.data = out_mat.data + i * _OUTC * outw * outh; + weight_part = winograd_weight + i * _OUTC * inch; + class conv_winograd conv(in_mat, out_part, in_param, conv_type, weight_part, + num_threads, in_mat.w * in_mat.h, outw * outh, outch); + conv.run(); + } + } + } + } + else if (direct_flag) + { + direct_conv_colmajor(in_mat, out_mat, weights_mat, in_param, num_threads); + } + else + { + class conv_sgemm_multithreads conv(in_mat, weights_mat, out_mat, in_param, num_threads, + conv_type); + conv.run(); + } + } +} + +void srcn_deconvolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, convMat_t &out_mat, + const convParams_t &in_param, int num_threads, convType_t conv_type) +{ + class deconv_sgemm_multithreads deconv(in_mat, weights_mat, out_mat, in_param, num_threads, + conv_type); + deconv.run(); +} + +void *trans_weight2sparse(const convMat_t &weights_mat) +{ + const int kernel_w = weights_mat.w; + const int kernel_h = weights_mat.h; + const int inch = weights_mat.c; + const int outch = weights_mat.n; + + const int nch = (outch + BCH - 1) / BCH; + const int rch = outch % BCH; + + const float *data = weights_mat.data; + const int klength = inch * kernel_h * kernel_w; + + sparse_weight_t *sparse_weight = new sparse_weight_t[nch]; + if (!sparse_weight) + return NULL; + + for (int i = 0; i < nch; i++) + { + int _bch = (i != nch - 1 || rch == 0) ? BCH : rch; + sparse_weight_t *sparse_weight_n = &sparse_weight[i]; + sparse_weight_n->mxk = 0; + + for (int j = 0; j < _bch; j++) + { + for (int l = 0; l < klength; l++) + { + float val = *(data + (i * BCH + j) * klength + l); + if (val != 0) + { + sparse_weight_n->mxk++; + } + } + } + } + + for (int i = 0; i < nch; i++) + { + int _bch = (i != nch - 1 || rch == 0) ? 
BCH : rch; + sparse_weight_t *sparse_weight_n = &sparse_weight[i]; + sparse_weight_n->wdata = new weight_data_t[sparse_weight_n->mxk]; + int index = 0; + + for (int l = 0; l < klength; l++) + { + for (int j = 0; j < _bch; j++) + { + float val = *(data + (i * BCH + j) * klength + l); + if (val != 0) + { + sparse_weight_n->wdata[index].m = i * BCH + j; + sparse_weight_n->wdata[index].k = l; + sparse_weight_n->wdata[index++].data = val; + } + } + } + } + + return (void *)sparse_weight; +} + +void sparse_release(const int outch, void *ptr) +{ + sparse_weight_t *sparse_weight = (sparse_weight_t *)ptr; + const int nch = (outch + BCH - 1) / BCH; + + if (!sparse_weight) + return; + + for (int i = 0; i < nch; i++) + { + sparse_weight_t *sparse_weight_n = &sparse_weight[i]; + if (sparse_weight_n->wdata) + delete[] sparse_weight_n->wdata; + } + + if (sparse_weight) + delete[] sparse_weight; +} + +void srcn_sparse_convolution2D(const convMat_t &in_mat, convMat_t &out_mat, + const convParams_t &in_param, const void *sparse_weight, + int number_threas, convType_t conv_type) +{ + class conv_sparse conv(in_mat, out_mat, in_param, (const sparse_weight_t *)sparse_weight, + number_threas, conv_type); + + for (int i = 0; i < out_mat.c * out_mat.h * out_mat.w; i++) + { + *(out_mat.data + i) = 0; + } + + conv.run(); +} + +void srcn_batch_convolution2D(const convMat_t &in_mat, const convMat_t &weights_mat, + convMat_t &out_mat, const convParams_t &in_param, + const float *winograd_weight, int num_threads, convType_t conv_type) +{ + int winograd_flag = (winograd_weight != NULL); + + if (winograd_flag) + { + if (num_threads > 1) + { + omp_set_num_threads(num_threads); + const int batch = in_mat.n; + const int npart = (batch + num_threads - 1) / num_threads; + const int nn = (batch + npart - 1) / npart; + const int rn = batch % npart; + +#pragma omp parallel for + for (int i = 0; i < nn; i++) + { + const int pn = (i != nn - 1 || rn == 0) ? npart : rn; + convMat_t in_mat_part = {in_mat.w, in_mat.h, in_mat.c, pn, + in_mat.data + i * npart * in_mat.w * in_mat.h * in_mat.c}; + convMat_t out_mat_part = {out_mat.w, out_mat.h, out_mat.c, pn, + out_mat.data + i * npart * out_mat.w * out_mat.h * out_mat.c}; + + class conv_winograd_batch conv(in_mat_part, out_mat_part, in_param, conv_type, + winograd_weight, num_threads); + conv.run(); + } + } + else + { + class conv_winograd_batch conv(in_mat, out_mat, in_param, conv_type, winograd_weight, + num_threads); + conv.run(); + } + } + else + { + if (num_threads == 1) + { + class conv_sgemm_singlethread conv(in_mat, weights_mat, out_mat, in_param, conv_type); + conv.run(); + } + else + { + class conv_sgemm_multithreads conv(in_mat, weights_mat, out_mat, in_param, num_threads, + conv_type); + conv.run(); + } + } +} + +} // namespace srcn +} // namespace nnfw diff --git a/compute/ncnn/src/srcn/winograd.h b/compute/ncnn/src/srcn/winograd.h new file mode 100644 index 000000000..5ad8f1126 --- /dev/null +++ b/compute/ncnn/src/srcn/winograd.h @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NNFW_SRCN_WINOGRAD_H__ +#define __NNFW_SRCN_WINOGRAD_H__ + +namespace nnfw +{ +namespace srcn +{ + +struct winograd_para_3x3s1 +{ + static const int M = 3 + 4 - 1; + static const int N = 3; + + static const double *getG() + { + static const double G[M * N] = { + 1. / 4., 0, 0, -1. / 6., -1. / 6., -1. / 6., -1. / 6., 1. / 6., -1. / 6., + 1. / 24., 1. / 12., 1. / 6., 1. / 24., -1. / 12., 1. / 6., 0, 0, 1, + }; + return G; + } + + static const double *getA() + { + static const double A[M * (M - N + 1)] = { + 1, 0, 0, 0, 1, 1, 1, 1, 1, -1, 1, -1, 1, 2, 4, 8, 1, -2, 4, -8, 0, 0, 0, 1, + }; + return A; + } + + static const double *getB() + { + static const double B[M * M] = { + 4, 0, 0, 0, 0, 0, 0, -4, 4, -2, 2, 4, -5, -4, -4, -1, -1, 0, + 0, 1, -1, 2, -2, -5, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, + }; + return B; + }; +}; + +struct winograd_para_3x3s1_2 +{ + static const int M = 3 + 2 - 1; + static const int N = 3; + + static const double *getG() + { + static const double G[M * N] = { + 1, 0, 0, 1. / 2., 1. / 2., 1. / 2., 1. / 2., -1. / 2., 1. / 2., 0, 0, 1, + }; + return G; + } + + static const double *getA() + { + static const double A[M * (M - N + 1)] = { + 1, 0, 1, 1, 1, -1, 0, 1, + }; + return A; + } + + static const double *getB() + { + static const double B[M * M] = { + 1, 0, 0, 0, 0, 1, -1, -1, -1, 1, 1, 0, 0, 0, 0, 1, + }; + return B; + }; +}; + +struct winograd_para_5x5s1 +{ + static const int M = 5 + 4 - 1; + static const int N = 5; + + static const double *getG() + { + static const double G[M * N] = { + 1, 0, 0, 0, 0, -2. / 9., -2. / 9., -2. / 9., + -2. / 9., -2. / 9., -2. / 9., 2. / 9., -2. / 9., 2. / 9., -2. / 9., 1. / 90., + 1. / 45., 2. / 45., 4. / 45., 8. / 45., 1. / 90., -1. / 45., 2. / 45., -4. / 45., + 8. / 45., 4. / 45., 2. / 45., 1. / 45., 1. / 90., 1. / 180., 4. / 45., -2. / 45., + 1. / 45., -1. / 90., 1. / 180., 0, 0, 0, 0, 1, + }; + return G; + } + + static const double *getA() + { + static const double A[M * (M - N + 1)] = {1, 0, 0, 0, 1, 1, 1, 1, 1, -1, 1, -1, 1, 2, 4, 8, + 1, -2, 4, -8, 8, 4, 2, 1, 8, -4, 2, -1, 0, 0, 0, 1}; + return A; + } + + static const double *getB() + { + static const double B[M * M] = { + 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, + -1, 1. / 2, -1. / 2, 2, -2, -1, -21. / 4, 1, 1, 1. / 4, + 1. / 4, 4, 4, 0, 0, -17. / 4, 17. / 4, -5. / 2, 5. / 2, -5. / 2, + 5. / 2, 21. / 4, 21. / 4, -17. / 4, -17. / 4, -5. / 4, -5. / 4, -5, -5, 0, + 0, 1, -1, 2, -2, 1. / 2, -1. / 2, -21. / 4, -1, 1, + 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 0, 0, 0, 1, + }; + return B; + } +}; + +static void kronecker_product(float *out, const double *in1, const double *in2, int m, int n, int p, + int q) +{ + for (int i = 0; i < m; ++i) + { + for (int j = 0; j < n; ++j) + { + for (int k = 0; k < p; ++k) + { + for (int l = 0; l < q; ++l) + { + out[(p * i + k) * n * q + q * j + l] = in1[n * i + j] * in2[k * q + l]; + /* compute in double precision and then convert it back to Dtype for accuracy */ + } + } + } + } +} + +} // namespace srcn +} // namespace nnfw + +#endif // __NNFW_SRCN_WINOGRAD_H__ |
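The weight transform implemented by trans_weight2winograd() is U = (G ⊗ G) · w: kronecker_product() builds the Kronecker product of the Winograd G matrix with itself, and one sgemm call multiplies it against every flattened k×k kernel. Below is a minimal standalone sketch of that math for a single 3×3 kernel using the F(4x4,3x3) G matrix from winograd.h; the kernel values are arbitrary example data and none of the library types are used.

// Sketch of the weight transform: U = (G (x) G) * w for one 3x3 kernel,
// with M = 6, N = 3 as in winograd_para_3x3s1. Standalone, illustrative only.
#include <cstdio>

int main()
{
  const int M = 6, N = 3;
  // G matrix of winograd_para_3x3s1 (same values as in winograd.h)
  const double G[M * N] = {
      1. / 4.,  0,         0,        -1. / 6., -1. / 6., -1. / 6.,
      -1. / 6., 1. / 6.,   -1. / 6., 1. / 24., 1. / 12., 1. / 6.,
      1. / 24., -1. / 12., 1. / 6.,  0,        0,        1,
  };

  // Kronecker product (G (x) G): an (M*M) x (N*N) matrix, same index layout
  // as kronecker_product() above.
  double GG[M * M * N * N];
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      for (int k = 0; k < M; ++k)
        for (int l = 0; l < N; ++l)
          GG[(M * i + k) * N * N + N * j + l] = G[N * i + j] * G[N * k + l];

  // One 3x3 kernel flattened row-major (arbitrary example values).
  const double w[N * N] = {1, 0, -1, 2, 0, -2, 1, 0, -1};

  // U = GG * w -> 36 values, i.e. the 6x6 Winograd-domain tile for this kernel.
  double U[M * M];
  for (int r = 0; r < M * M; ++r)
  {
    double acc = 0.;
    for (int c = 0; c < N * N; ++c)
      acc += GG[r * N * N + c] * w[c];
    U[r] = acc;
  }

  for (int r = 0; r < M; ++r)
  {
    for (int c = 0; c < M; ++c)
      printf("%8.4f ", U[r * M + c]);
    printf("\n");
  }
  return 0;
}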
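A typical calling sequence for the single-image entry point follows the pattern used by the benchmark harness: pre-transform the weights when check_winograd() accepts the shape, pass the result (or NULL, for the sgemm/direct fallbacks) to srcn_convolution2D(), then release it. The sketch below assumes the public srcn header declares these entry points next to the types in conv_type.h; shapes and the batch value are illustrative.

// Hedged calling sketch for the winograd path of srcn_convolution2D(),
// mirroring the benchmark harness; buffers in input/filter/output are
// assumed to be already allocated and filled by the caller.
#include <cstddef>
#include "ncnn/srcn/conv_type.h" // convMat_t / convParams_t / winogradParams_t / convType_t

using namespace nnfw::srcn;

void run_conv(const convMat_t &input, const convMat_t &filter, convMat_t &output,
              const convParams_t &params, int num_threads)
{
  // Field order follows the aggregate initialization used in test_batch_conv().
  winogradParams_t wparams = {params.kernel_w,   params.kernel_h, params.stride_w,
                              params.stride_h,   params.dilation_w, params.dilation_h,
                              1 /* batch */,     input.w,         input.h,
                              input.c,           output.c,        num_threads,
                              row_major,         filter.data};

  float *winograd_weight = NULL;
  if (check_winograd(wparams))
    winograd_weight = trans_weight2winograd(wparams); // may still return NULL

  // With a NULL winograd weight, srcn_convolution2D falls back to the
  // im2col+sgemm or direct-convolution paths internally.
  srcn_convolution2D(input, filter, output, params, winograd_weight, num_threads, row_major);

  winograd_release(winograd_weight);
}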
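trans_weight2sparse() keeps, for each block of BCH output channels, only the non-zero weights as (m, k, value) triples ordered by k, so the sparse kernel can skip zero coefficients entirely; srcn_sparse_convolution2D() zero-fills the output before running because the kernel accumulates into it. A minimal allocate/run/release sketch, under the same header assumptions as the previous sketch:

// Hedged sketch of the sparse-weight path; filter.n is the output channel
// count expected by sparse_release().
#include "ncnn/srcn/conv_type.h"

using namespace nnfw::srcn;

void run_sparse_conv(const convMat_t &input, const convMat_t &filter, convMat_t &output,
                     const convParams_t &params, int num_threads)
{
  // Non-zero weights only, grouped per BCH output channels.
  void *sparse_weight = trans_weight2sparse(filter);
  if (!sparse_weight)
    return;

  srcn_sparse_convolution2D(input, output, params, sparse_weight, num_threads, row_major);

  sparse_release(filter.n, sparse_weight);
}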