author | Peter Goldsborough <psag@fb.com> | 2018-01-19 15:37:01 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-01-19 15:44:34 -0800 |
commit | d401c26d633e49de201d1faf6e3e84e25d185b60 (patch) | |
tree | 2aa8a2e9978b5fc44bc8d669b6ba41618566ab18 /caffe2/perfkernels | |
parent | 8dc0702af5b141aceb6291f54a2746cc26427dac (diff) | |
Add FusedEmbeddingLookup
Summary:
Updates the perfkernel codebase to implement embedding lookup for our new fused storage format, where each row in the data matrix stores the quantized values *and* the scale and bias.
msmelyan, consider this my best-effort attempt at updating the perfkernel code for the fused storage. Let me know if any of it is grossly wrong. I also don't know whether any of the prefetching logic needs to be updated.
Note that we have to keep the old code around for a bit until we get rid of the old operations with separate `scale_bias` storage.
Reviewed By: kennyhorror
Differential Revision: D6710843
fbshipit-source-id: b485ef2389f526c5db1260cac9d4be3fc8df0979
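
The summary above describes the fused storage format only in prose, so here is a minimal illustrative sketch of the row layout the new kernels assume: each row carries its quantized values followed by a float scale and float bias. The `FuseRow` helper and the min/max quantization rule are assumptions made for this sketch, not code from the diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative only: pack one float row into a fused 8-bit rowwise layout of
// [block_size quantized bytes][float scale][float bias], giving a row stride
// of block_size + 8 bytes. The min/max quantization rule is an assumption for
// the sketch, not taken from this diff.
std::vector<uint8_t> FuseRow(const float* row, int64_t block_size) {
  std::vector<uint8_t> fused(block_size + 2 * sizeof(float));
  const float minv = *std::min_element(row, row + block_size);
  const float maxv = *std::max_element(row, row + block_size);
  const float scale = (maxv - minv) / 255.0f;
  const float bias = minv;
  for (int64_t i = 0; i < block_size; ++i) {
    // A scale of zero means the row is constant; store zeros in that case.
    fused[i] = scale > 0.0f
        ? static_cast<uint8_t>((row[i] - bias) / scale + 0.5f)
        : 0;
  }
  // The scale and bias live at the tail of the same row instead of in a
  // separate scale_bias blob, which is what "fused" refers to here.
  std::memcpy(fused.data() + block_size, &scale, sizeof(float));
  std::memcpy(fused.data() + block_size + sizeof(float), &bias, sizeof(float));
  return fused;
}
```

The float and float16 kernels in this diff read rows of the same overall shape but skip the trailing scale/bias columns, which is why their row strides are `block_size + 2` and `block_size + 4` elements respectively (`fused_block_size` in the generated code), versus `block_size + 8` bytes for the uint8 variant.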
Diffstat (limited to 'caffe2/perfkernels')
-rw-r--r-- | caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc | 2760 |
-rw-r--r-- | caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc | 170 |
-rw-r--r-- | caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h | 67 |
-rw-r--r-- | caffe2/perfkernels/hp_emblookup_codegen.py | 188 |
4 files changed, 3113 insertions, 72 deletions
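
As a reading aid for the long autogenerated diff below, here is a hedged scalar model of the float specialization that opens the diff (`Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma`); the other specializations follow the same segmented, optionally weighted reduction. The reference function name is made up, and the real kernel's prefetching, `index_size` parameter, and `CAFFE_ENFORCE` bounds checks are reduced here to a plain assert.

```cpp
#include <cassert>
#include <cstdint>

// Scalar sketch of the generated AVX2 kernel: each of the output_size
// segments reduces lengths[r] consecutive indices into one output row.
void FusedEmbeddingLookupRef(
    int64_t block_size,    // logical width of each output row
    int64_t output_size,   // number of segments / output rows
    int64_t data_size,     // number of rows in the fused table
    const float* input,    // rows of block_size + 2 floats (data, scale, bias)
    const int32_t* indices,
    const int* lengths,
    const float* weights,  // optional per-index weights, may be nullptr
    bool normalize_by_lengths,
    float* out) {
  // Two trailing floats per row hold the scale and bias, so the storage
  // stride is wider than the logical block size.
  const int64_t fused_block_size = block_size + 2;
  int64_t dataInd = 0;
  for (int64_t r = 0; r < output_size; ++r) {
    float* op = out + r * block_size;
    for (int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.0f;
    }
    for (int64_t start = dataInd; dataInd < start + lengths[r]; ++dataInd) {
      const int64_t idx = indices[dataInd];
      assert(idx >= 0 && idx < data_size);  // CAFFE_ENFORCE in the real kernel
      const float wgt = weights ? weights[dataInd] : 1.0f;
      const float* ip = input + idx * fused_block_size;
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] += wgt * ip[j];  // weighted accumulation into the output row
      }
    }
    if (normalize_by_lengths && lengths[r]) {
      const float len_inv = 1.0f / lengths[r];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;  // optional mean over the segment
      }
    }
  }
}
```

The generated kernels below specialize this loop for block sizes 128, 64, 32, and 16 (fully unrolled with FMA and software prefetching of the next indexed row) and fall back to a generic 8-wide loop otherwise.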
diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc new file mode 100644 index 0000000000..7277df7842 --- /dev/null +++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc @@ -0,0 +1,2760 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//// -------------------------- +//// ATTENTION: +//// THIS CODE IS AUTOGENERATED +//// BY caffe2/caffe2/perfkernels/hp_emblookup_codegen.py +//// DO NOT MODIFY!!! +//// -------------------------- + +#include <caffe2/core/common.h> +#include <caffe2/core/types.h> +#include <immintrin.h> + +namespace caffe2 { + +void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float* input, + const int32_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int32_t prefdist_T0 = 16; + const int32_t fused_block_size = block_size + 2; + if (block_size == 128) { + // unrolling 16 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); + _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); + _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, 
vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = 
_mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] += wgt * ip[j]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} + +void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float* input, + const int64_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int64_t prefdist_T0 = 16; + const int64_t fused_block_size = block_size + 2; + if (block_size == 128) { + // unrolling 16 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = 
_mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); + _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); + _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + 
_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] += wgt * ip[j]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} + +void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float16* input, + const int32_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int32_t prefdist_T0 = 16; + const int32_t fused_block_size = block_size + 4; + if (block_size == 128) { + // unrolling 16 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = 
_mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), + vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), + vop48); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), + vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), + vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), + vop72); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), + vop80); + // skip unnecessary prefetch of (&ip_next_T0[80]) + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), + vop88); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), + vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))), + vop104); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), + vop112); + // skip unnecessary prefetch of (&ip_next_T0[112]) + vop120 = _mm256_fmadd_ps( + vwgt, + 
_mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), + vop120); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), + vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), + vop48); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), + vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps(_mm_loadu_si128( + reinterpret_cast<const __m128i*>(&ip[j]))), + _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + float16 vtmp1[8] CAFFE2_ALIGNED(64); + for (; j < block_size; j++) { + vtmp1[0] = ip[j]; + __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); + op[j] += wgt * ((float*)(&vtmp2))[0]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} + +void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float16* input, + const int64_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int64_t prefdist_T0 = 16; + const int64_t fused_block_size = block_size + 4; + if (block_size == 128) { + // unrolling 16 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 
vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), + vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), + vop48); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), + vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), + vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), + vop72); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), + vop80); + // skip unnecessary prefetch of (&ip_next_T0[80]) + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), + vop88); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), + vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + 
(104)))), + vop104); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), + vop112); + // skip unnecessary prefetch of (&ip_next_T0[112]) + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), + vop120); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), + vop40); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), + vop48); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), + vop56); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), + vop16); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), + vop24); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), + vop8); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps(_mm_loadu_si128( + reinterpret_cast<const __m128i*>(&ip[j]))), + _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + float16 vtmp1[8] CAFFE2_ALIGNED(64); + for (; j < block_size; j++) { + vtmp1[0] = ip[j]; + __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); + op[j] += wgt * ((float*)(&vtmp2))[0]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} + +void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const uint8_t* input, + const int32_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int32_t prefdist_T0 = 16; + const int32_t fused_block_size = block_size + 8; + if (block_size == 128) { + // unrolling 16 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 
vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))), + _mm256_add_ps(vop64, vbio)); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))), + _mm256_add_ps(vop72, vbio)); + // skip unnecessary 
prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))), + _mm256_add_ps(vop80, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[80]) + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))), + _mm256_add_ps(vop88, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))), + _mm256_add_ps(vop96, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[96]) + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))), + _mm256_add_ps(vop104, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))), + _mm256_add_ps(vop112, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[112]) + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))), + _mm256_add_ps(vop120, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = 
_mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], 
_mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * 
fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast<const __m128i*>(&ip[j])))), + _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] += wgt * ((float)ip[j]) + bio; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} + +void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const uint8_t* input, + const int64_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int64_t prefdist_T0 = 16; + const int64_t fused_block_size = block_size + 8; + if (block_size == 128) { + // unrolling 16 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))), + _mm256_add_ps(vop64, vbio)); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))), + _mm256_add_ps(vop72, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[72]) + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))), + _mm256_add_ps(vop80, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[80]) + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))), + _mm256_add_ps(vop88, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[88]) + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))), + _mm256_add_ps(vop96, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[96]) + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))), + _mm256_add_ps(vop104, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[104]) + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))), + _mm256_add_ps(vop112, 
vbio)); + // skip unnecessary prefetch of (&ip_next_T0[112]) + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))), + _mm256_add_ps(vop120, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[120]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))), + _mm256_add_ps(vop32, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[32]) + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))), + _mm256_add_ps(vop40, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[40]) + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))), + _mm256_add_ps(vop48, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[48]) + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))), + _mm256_add_ps(vop56, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[56]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + 
float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))), + _mm256_add_ps(vop16, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[16]) + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))), + _mm256_add_ps(vop24, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[24]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else if (block_size == 16) { + // unrolling 2 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))), + _mm256_add_ps(vop0, vbio)); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( + _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))), + _mm256_add_ps(vop8, vbio)); + // skip unnecessary prefetch of (&ip_next_T0[8]) + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + } + } + } else { + // generic code + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + CAFFE_ENFORCE( + idx >= 0 && idx < data_size, + "Index ", + dataInd, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + float wgt = 1.f; + float bio; + if (weights) { + wgt = weights[dataInd]; + } + const float* scale_bias = reinterpret_cast<const float*>( + &input[idx * fused_block_size + block_size]); + bio = wgt * scale_bias[1]; + wgt = wgt * scale_bias[0]; + __m256 vbio = _mm256_set1_ps(bio); + __m256 vwgt = _mm256_set1_ps(wgt); + const uint8_t* ip = &input[idx * fused_block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size); + const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast<const __m128i*>(&ip[j])))), + _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] += wgt * ((float)ip[j]) + bio; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} + +} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc new file mode 100644 index 0000000000..7aa756bdeb --- /dev/null +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc @@ -0,0 +1,170 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" + +#include "caffe2/core/types.h" +#include "caffe2/perfkernels/common.h" +#include "caffe2/perfkernels/typed_axpy.h" +#include "caffe2/utils/cpuid.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +// Base implementation does runtime dispatch for each segment of reduction +template <typename IndexType, typename InType, typename OutType> +static void Fused8BitRowwiseEmbeddingLookupGenericSlow( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const InType* input, + const IndexType* indices, + const int* lengths, + const float* weights, // optional, can be null for sum reducer + bool normalize_by_lengths, + OutType* out) { + // block_size is the number of elements and fused_block_size is the size of + // an entire row, including scale and bias. + const auto scale_bias_offset = 8 / sizeof(InType); + const TIndex fused_block_size = block_size + scale_bias_offset; + TIndex current = 0; + for (int m = 0; m < output_size; ++m) { + memset(out, 0, sizeof(OutType) * block_size); + EigenVectorArrayMap<OutType> out_vector(out, block_size); + for (int i = 0; i < lengths[m]; ++i) { + CAFFE_ENFORCE_LT(current, index_size); + TIndex idx = indices[current]; + CAFFE_ENFORCE( + 0 <= idx && idx < data_size, + "Index ", + current, + " is out of bounds: ", + idx, + ", range 0 to ", + data_size); + CAFFE_ENFORCE_LT(idx, data_size); +#ifdef __GNUC__ + if (current + 1 < index_size) { + __builtin_prefetch( + input + fused_block_size * indices[current + 1], 0, 1); + } +#endif // __GNUC__ + + const float* scale_bias = reinterpret_cast<const float*>( + input + fused_block_size * indices[current] + block_size); + + const float weight = weights ? 
weights[current] : 1.0f; + const float scale = weight * scale_bias[0]; + const float bias = weight * scale_bias[1]; + + TypedAxpy<InType, OutType>( + block_size, scale, input + fused_block_size * indices[current], out); + + out_vector += bias; + + ++current; + } + if (normalize_by_lengths && lengths[m]) { + // hack: context is not really used + math::Scale<OutType, CPUContext>( + block_size, 1.f / lengths[m], out, out, nullptr); + } + out += block_size; + } + CAFFE_ENFORCE_EQ( + current, + index_size, + "Your input seems to be incorrect: the sum of lengths values should be " + "the size of the indices tensor, but it appears not."); +} + +// Proxy back to generic implementation +#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \ + IndexType, InType, OutType) \ + void \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##__base( \ + const TIndex block_size, \ + const TIndex output_size, \ + const TIndex index_size, \ + const TIndex data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + Fused8BitRowwiseEmbeddingLookupGenericSlow<IndexType, InType, OutType>( \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + } \ + template <> \ + void Fused8BitRowwiseEmbeddingLookup( \ + const TIndex block_size, \ + const TIndex output_size, \ + const TIndex index_size, \ + const TIndex data_size, \ + const InType* input, \ + const IndexType* indices, \ + const int* lengths, \ + const float* weights, \ + bool normalize_by_lengths, \ + OutType* out) { \ + const int32_t one = 1; \ + CAFFE_ENFORCE_EQ( \ + reinterpret_cast<const uint8_t*>(&one)[0], \ + 1, \ + "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ + AVX2_FMA_DO( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + BASE_DO( \ + Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + lengths, \ + weights, \ + normalize_by_lengths, \ + out); \ + } + +FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float); +FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float); + +#undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION + +} // namespace caffe2 diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h new file mode 100644 index 0000000000..5251331e3e --- /dev/null +++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h @@ -0,0 +1,67 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "caffe2/core/common.h" + +namespace caffe2 { + +/** + * Embedding lookup with reduction. + * + * `input` of size data_size * (block_size + 8B) + * `indices` of size index_size + * `lengths` of size output_size + * `weights` nullptr or array of size index_size + * `out` of size output_size * block_size + * sum(lengths[i]) == index_size + * + * Note that block_size should be the number of quantized values per row in the + * data, i.e. excluding the scale and bias. The total (fused) block size is + * assumed to be this block_size, plus 4 bytes for scale and 4 bytes for bias. + * + * Behavior is roughly equivalent to pseudocode: + * + * pos = 0 + * fused_block_size = block_size + 8B // quantized values and scale and bias + * for (i = 0..index_size-1) + * for (k = 0..block_size-1) + * out[i*block_size + k] = 0 + * for (j = 0..lengths[i]-1) + * for (k = 0..block_size-1) + * out[i*block_size + k] += input[indices[pos]*(fused_block_size) + k] * + * (weights ? weights[pos] : 1.0) + * pos += 1 + * if (normalize_weights && lengths[i] > 0) + * for (k = 0..block_size-1) + * out[i*block_size + k] /= lengths[i] + * + */ + +template <typename IndexType, typename InType, typename OutType> +void Fused8BitRowwiseEmbeddingLookup( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const InType* input, + const IndexType* indices, + const int* lengths, + const float* weights, // optional, can be null for non-weighted sum + bool normalize_by_lengths, + OutType* out); +} // namespace caffe2 diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py index 87c656a8be..28afc95841 100644 --- a/caffe2/perfkernels/hp_emblookup_codegen.py +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -20,47 +20,40 @@ from __future__ import unicode_literals import argparse import sys +sizeof = {'float': 4, 'float16': 2, 'uint8_t': 1} -def unroll(uf, IndexType, InType, OutType, use_weights, isa): - - def sizeof(InType): - size = 0 - if InType == "float": - size = 4 - elif InType == "float16": - size = 2 - elif InType == "uint8_t": - size = 1 - else: - assert False - - return size +def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused): def compute(regid, InType, use_weights, isa, prefetch): code = [] if InType == "float": - code.append("vop%d = _mm256_fmadd_ps(vwgt, \ - _mm256_loadu_ps(ip + (%d)), vop%d);" % (regid, regid, regid)) - + code.append( + "vop%d = _mm256_fmadd_ps(vwgt, \ + _mm256_loadu_ps(ip + (%d)), vop%d);" + % (regid, regid, regid) + ) elif InType == "float16": - code.append("vop%d = _mm256_fmadd_ps(vwgt, \ + code.append( + "vop%d = _mm256_fmadd_ps(vwgt, \ _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), \ vop%d);" - % (regid, regid, regid)) + % (regid, regid, regid) + ) elif InType == "uint8_t": - code.append("vop%d = _mm256_fmadd_ps(vwgt, \ + code.append( + "vop%d = _mm256_fmadd_ps(vwgt, \ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))), \ _mm256_add_ps(vop%d, vbio));" - % (regid, regid, regid)) + % (regid, regid, regid) + ) else: assert False - - if prefetch == True: + if prefetch: code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid)) else: - code.append("// skip unecassery prefetch of (&ip_next_T0[%d])" % (regid)) + code.append("// skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid)) return code @@ -87,8 +80,16 @@ def unroll(uf, IndexType, InType, OutType, 
use_weights, isa): code.append("if (weights) {") code.append("wgt = weights[dataInd];") code.append("}") - code.append("bio = wgt * scale_bias[2 * idx + 1];"); - code.append("wgt = wgt * scale_bias[2 * idx];"); + if fused: + code.append( + 'const float* scale_bias = reinterpret_cast<' + 'const float*>(&input[idx * fused_block_size + block_size]);' + ) + code.append("bio = wgt * scale_bias[1];") + code.append("wgt = wgt * scale_bias[0];") + else: + code.append("bio = wgt * scale_bias[2 * idx + 1];") + code.append("wgt = wgt * scale_bias[2 * idx];") code.append("__m256 vbio = _mm256_set1_ps(bio);") else: code.append(OutType + " wgt = 1.f;") @@ -97,20 +98,25 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa): code.append("}") code.append("__m256 vwgt = _mm256_set1_ps(wgt);") - code.append("const " + InType + " *ip = &input[idx * block_size];") - code.append("const " + IndexType + - " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;"); + code.append("const {} *ip = &input[idx * fused_block_size];".format(InType)) + code.append( + 'const {} next_T0 = (dataInd < index_size - prefdist_T0)' + ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType) + ) code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];") code.append( "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);") - code.append("const " + InType + - " *ip_next_T0 = &input[idx_pref_T0 * block_size];") + + code.append( + 'const {} *ip_next_T0 = &input[idx_pref_T0' + ' * fused_block_size];'.format(InType) + ) for i in range(0, uf): j = 8 * i cachelinesize = 64 - byteoffset = sizeof(InType) * j - prefetch = ((byteoffset % cachelinesize) == 0) + byteoffset = sizeof[InType] * j + prefetch = (byteoffset % cachelinesize) == 0 code.extend(compute(j, InType, use_weights, isa, prefetch)) code.append("}") @@ -133,25 +139,31 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa): return code -def generic(IndexType, InType, OutType, use_weights, isa): +def generic(IndexType, InType, OutType, use_weights, isa, fused): def compute(InType, use_weights, isa): code = [] if InType == "float": - code.append("_mm256_storeu_ps(&op[j], \ + code.append( + "_mm256_storeu_ps(&op[j], \ _mm256_fmadd_ps(vwgt,_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \ - );") + );" + ) elif InType == "float16": - code.append("_mm256_storeu_ps(&op[j], \ + code.append( + "_mm256_storeu_ps(&op[j], \ _mm256_fmadd_ps(vwgt, \ _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), _mm256_loadu_ps(&op[j])) \ - );") + );" + ) elif InType == "uint8_t": - code.append("_mm256_storeu_ps(&op[j], \ + code.append( + "_mm256_storeu_ps(&op[j], \ _mm256_fmadd_ps(vwgt, \ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(&ip[j])))), \ _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio) ) \ - );") + );" + ) else: assert False @@ -188,9 +200,17 @@ def generic(IndexType, InType, OutType, use_weights, isa): code.append("if (weights) {") code.append("wgt = weights[dataInd];") code.append("}") - code.append("assert (scale_bias);") - code.append("bio = wgt * scale_bias[2 * idx + 1];"); - code.append("wgt = wgt * scale_bias[2 * idx];"); + if fused: + code.append( + 'const float* scale_bias = reinterpret_cast<' + 'const float*>(&input[idx * fused_block_size + block_size]);' + ) + code.append("bio = wgt * scale_bias[1];") + code.append("wgt = wgt * scale_bias[0];") + else: + code.append("assert (scale_bias);") + code.append("bio = wgt * scale_bias[2 * idx + 1];") + 
code.append("wgt = wgt * scale_bias[2 * idx];") code.append("__m256 vbio = _mm256_set1_ps(bio);") else: code.append(OutType + " wgt = 1.f;") @@ -199,14 +219,18 @@ def generic(IndexType, InType, OutType, use_weights, isa): code.append("}") code.append("__m256 vwgt = _mm256_set1_ps(wgt);") - code.append("const " + InType + " *ip = &input[idx * block_size];") - code.append("const " + IndexType + - " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;"); + code.append("const {} *ip = &input[idx * fused_block_size];".format(InType)) + code.append( + 'const {} next_T0 = (dataInd < index_size - prefdist_T0)' + ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType) + ) code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];") code.append( "CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);") - code.append("const " + InType + - " *ip_next_T0 = &input[idx_pref_T0 * block_size];") + code.append( + "const {} *ip_next_T0 = &input[idx_pref_T0 * fused_block_size];". + format(InType) + ) # compute and store main loop code.append("j = 0;") @@ -215,7 +239,6 @@ def generic(IndexType, InType, OutType, use_weights, isa): code.append("}") # leftover if InType == "float16": - #code.append("float16 vtmp1[8] __attribute__((aligned(64)));") code.append("float16 vtmp1[8] CAFFE2_ALIGNED(64);") code.append("for(; j < block_size; j++) {") if InType == "float": @@ -251,14 +274,17 @@ def generic(IndexType, InType, OutType, use_weights, isa): return code - # start main code parser = argparse.ArgumentParser() -parser.add_argument('-f', nargs=1, help="file name") +parser.add_argument('-f', '--filename', help="file name") +parser.add_argument('--fused', action='store_true') opts = parser.parse_args() -filename = "embedding_lookup_avx2.cc" -if opts.f: - filename = (opts.f)[0] +if opts.filename: + filename = opts.filename +elif opts.fused: + filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc" +else: + filename = "embedding_lookup_avx2.cc" fout = open(filename, 'w') options = [["int32_t", "float", "float"], @@ -289,14 +315,14 @@ code.append( */ """) code.append("//// --------------------------") -code.append("//// ATTENTION: ") +code.append("//// ATTENTION:") code.append("//// THIS CODE IS AUTOGENERATED") -code.append("//// BY %s " % (sys.argv[0])) -code.append("//// DO NOT MODIFY!!! 
") +code.append("//// BY {}".format(sys.argv[0])) +code.append("//// DO NOT MODIFY!!!") code.append("//// --------------------------\n\n") -code.append("#include \"caffe2/core/types.h\"") -code.append("#include \"caffe2/core/common.h\"") +code.append("#include <caffe2/core/types.h>") +code.append("#include <caffe2/core/common.h>") code.append("#include <immintrin.h>") code.append("\n") @@ -304,9 +330,11 @@ code.append("namespace caffe2 {\n") for o in options: [IndexType, InType, OutType] = o - fn = "void EmbeddingLookup_" + IndexType + \ - "_" + InType + "_" + OutType + "__avx2_fma" - code.append(fn + "(") + prefix = 'Fused8BitRowwise' if opts.fused else '' + fn = 'void {}EmbeddingLookup_{}_{}_{}__avx2_fma('.format( + prefix, IndexType, InType, OutType + ) + code.append(fn) code.append("const TIndex block_size,") code.append("const TIndex output_size,") code.append("const TIndex index_size,") @@ -315,29 +343,45 @@ for o in options: code.append("const " + IndexType + "* indices,") code.append("const int* lengths,") code.append("const float* weights,") - code.append("const float* scale_bias,") + if not opts.fused: + code.append("const float* scale_bias,") code.append("bool normalize_by_lengths,") code.append(OutType + "* out)") code.append("{") code.append("const " + IndexType + " prefdist_T0 = 16;") + # block_size is the number of elements and fused_block_size is the size of + # an entire row, including scale and bias. + offset = (8 // sizeof[InType]) if opts.fused else 0 + code.append( + "const {} fused_block_size = block_size + {};". + format(IndexType, offset) + ) + #code.append("printf(\"calling " + fn + "\\n\");"); - if InType != "uint8_t": - code.append("CAFFE_ENFORCE(scale_bias == nullptr, \"scale_bias must be nullptr\");"); - else: - code.append("CAFFE_ENFORCE(scale_bias != nullptr, \"scale_bias must not be nullptr\");"); + if not opts.fused: + if InType != "uint8_t": + code.append( + 'CAFFE_ENFORCE(scale_bias == nullptr,' + ' "scale_bias must be nullptr");' + ) + else: + code.append( + 'CAFFE_ENFORCE(scale_bias != nullptr,' + ' "scale_bias must not be nullptr");' + ) code.append("if (block_size == 128) {") - code.extend(unroll(16, IndexType, InType, OutType, True, "AVX2")) + code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused) code.append("} else if (block_size == 64) {") - code.extend(unroll(8, IndexType, InType, OutType, True, "AVX2")) + code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused) code.append("} else if (block_size == 32) {") - code.extend(unroll(4, IndexType, InType, OutType, True, "AVX2")) + code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused) code.append("} else if (block_size == 16) {") - code.extend(unroll(2, IndexType, InType, OutType, True, "AVX2")) + code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused) code.append("} else {") code.append("// generic code") - code.extend(generic(IndexType, InType, OutType, True, "AVX2")) + code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused) code.append("}") |