From 080fab8f6c82daa93b478267f3a604e1bdb8083f Mon Sep 17 00:00:00 2001
From: Misha Smelyanskiy
Date: Wed, 30 Aug 2017 16:11:39 -0700
Subject: Code generator for high-performance embedding look-up kernels, supporting Sum, WeightedSum, and Mean reducers

Summary:
Code generator for high-performance embedding look-up kernels, supporting Sum, WeightedSum, and Mean reducers. Achieves at least a 1.5x speedup for float and over a 2x speedup for float16, compared to the existing code.

These are results on Broadwell, using the sparse_lengths_sum_benchmark.par benchmark.

Old
==============

[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000
Preparing lookup table. 2017-08-08 00:10:23.101848
Preparation finished. 2017-08-08 00:10:27.955680
I0808 00:10:27.955732 30700 net.cc:177] Starting benchmark.
I0808 00:10:27.955759 30700 net.cc:178] Running warmup runs.
I0808 00:10:27.956367 30700 net.cc:188] Main runs.
I0808 00:10:31.839035 30700 net.cc:199] Main run finished. Milliseconds per iter: 0.388264. Iters per second: 2575.56
I0808 00:10:35.704169 30700 net.cc:233] Operator #0 (indices, Python) 0.0583264 ms/iter
I0808 00:10:35.704210 30700 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.327694 ms/iter
I0808 00:10:35.704213 30700 net.cc:237] Time per operator type:
I0808 00:10:35.704217 30700 net.cc:246] 0.327694 SparseLengthsSum
I0808 00:10:35.704221 30700 net.cc:246] 0.0583264 Python

[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000 --dtype float16
Preparing lookup table. 2017-08-08 00:10:59.047159
Preparation finished. 2017-08-08 00:11:05.140565
I0808 00:11:05.140612 31725 net.cc:177] Starting benchmark.
I0808 00:11:05.140635 31725 net.cc:178] Running warmup runs.
I0808 00:11:05.141104 31725 net.cc:188] Main runs.
I0808 00:11:08.371510 31725 net.cc:199] Main run finished. Milliseconds per iter: 0.323039. Iters per second: 3095.6
I0808 00:11:11.671450 31725 net.cc:233] Operator #0 (indices, Python) 0.0609876 ms/iter
I0808 00:11:11.671489 31725 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.26856 ms/iter
I0808 00:11:11.671494 31725 net.cc:237] Time per operator type:
I0808 00:11:11.671497 31725 net.cc:246] 0.26856 SparseLengthsSum
I0808 00:11:11.671500 31725 net.cc:246] 0.0609876 Python

New (Misha's)
==============

[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000
Preparing lookup table. 2017-08-07 23:44:55.897748
Preparation finished. 2017-08-07 23:45:00.708896
I0807 23:45:00.708945 4178361 net.cc:177] Starting benchmark.
I0807 23:45:00.708971 4178361 net.cc:178] Running warmup runs.
I0807 23:45:00.709444 4178361 net.cc:188] Main runs.
I0807 23:45:03.608551 4178361 net.cc:199] Main run finished. Milliseconds per iter: 0.289909. Iters per second: 3449.36
I0807 23:45:06.536182 4178361 net.cc:233] Operator #0 (indices, Python) 0.0572399 ms/iter
I0807 23:45:06.536224 4178361 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.23512 ms/iter
I0807 23:45:06.536228 4178361 net.cc:237] Time per operator type:
I0807 23:45:06.536232 4178361 net.cc:246] 0.23512 SparseLengthsSum
I0807 23:45:06.536236 4178361 net.cc:246] 0.0572399 Python

[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000 --dtype float16
Preparing lookup table. 2017-08-07 23:45:17.191579
Preparation finished. 2017-08-07 23:45:23.173668
I0807 23:45:23.173715 4179316 net.cc:177] Starting benchmark.
I0807 23:45:23.173743 4179316 net.cc:178] Running warmup runs. I0807 23:45:23.174090 4179316 net.cc:188] Main runs. I0807 23:45:24.939749 4179316 net.cc:199] Main run finished. Milliseconds per iter: 0.176564. Iters per second: 5663.67 I0807 23:45:26.698885 4179316 net.cc:233] Operator #0 (indices, Python) 0.0557303 ms/iter I0807 23:45:26.698923 4179316 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.119794 ms/iter I0807 23:45:26.698927 4179316 net.cc:237] Time per operator type: I0807 23:45:26.698931 4179316 net.cc:246] 0.119794 SparseLengthsSum I0807 23:45:26.698935 4179316 net.cc:246] 0.0557303 Python Reviewed By: salexspb Differential Revision: D5582172 fbshipit-source-id: d71f5a55580b734a51b8f30852b75f379acfdaf2 --- caffe2/perfkernels/embedding_lookup_avx2.cc | 1444 ++++++++++++++++++++++++++- caffe2/perfkernels/hp_emblookup_codegen.py | 232 +++++ 2 files changed, 1637 insertions(+), 39 deletions(-) create mode 100644 caffe2/perfkernels/hp_emblookup_codegen.py (limited to 'caffe2/perfkernels') diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc index 1e3d6fa443..8b87f486b2 100644 --- a/caffe2/perfkernels/embedding_lookup_avx2.cc +++ b/caffe2/perfkernels/embedding_lookup_avx2.cc @@ -1,49 +1,1415 @@ +#include +#include "caffe2/core/common.h" #include "caffe2/core/types.h" namespace caffe2 { -// TODO(msmelyan): implement code generator for implementation based on -// following parameters: -// index type: int32, int64 (encoded in function name) -// embedding data type: float16, float32 (encoded in function name) -// output type: float (encoded in function name) -// weighted reduction: whether `weights` is nullptr or not -// normalization (divide by lengths[i]): whether normalize_by_lengths is true -// block size: 32, 64, 128, generic +void EmbeddingLookup_int32_t_float_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float* input, + const int32_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int32_t prefdist_T0 = 16; + if (block_size == 128) { + // unrolling 16 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); + _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); + vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); + _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); + _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); + vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); + _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); + vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); + _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); + _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + 
_mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int32_t start = dataInd; 
dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else { + // generic code + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] += wgt * ip[j]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} -// For now just invoke base implementation (this entire file can be autogenned) -#define EMBEDDING_SPECIALIZATION(IndexType, InType, OutType) \ - void EmbeddingLookup_##IndexType##_##InType##_##OutType##__avx2_fma( \ - const TIndex block_size, \ - const TIndex output_size, \ - const TIndex index_size, \ - const TIndex data_size, \ - const InType* input, \ - const IndexType* indices, \ - const int* lengths, \ - const float* weights, \ - bool normalize_by_lengths, \ - OutType* out) { \ - decltype(EmbeddingLookup_##IndexType##_##InType##_##OutType##__avx2_fma) \ - EmbeddingLookup_##IndexType##_##InType##_##OutType##__base; \ - EmbeddingLookup_##IndexType##_##InType##_##OutType##__base( \ - block_size, \ - output_size, \ - index_size, \ - data_size, \ - input, \ - indices, \ - lengths, \ - weights, \ - normalize_by_lengths, \ - out); \ +void EmbeddingLookup_int64_t_float_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float* input, + const int64_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int64_t prefdist_T0 = 16; + if (block_size == 128) { + // unrolling 16 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); + _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); + vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); + _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); + _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); + vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); + _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); + vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); + _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); + _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + 
_mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int64_t start = dataInd; 
dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? (dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else { + // generic code + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float* ip_next_T0 = &input[idx_pref_T0 * block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + for (; j < block_size; j++) { + op[j] += wgt * ip[j]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } } +} -EMBEDDING_SPECIALIZATION(int32_t, float, float); -EMBEDDING_SPECIALIZATION(int64_t, float, float); -EMBEDDING_SPECIALIZATION(int32_t, float16, float); -EMBEDDING_SPECIALIZATION(int64_t, float16, float); +void EmbeddingLookup_int32_t_float16_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float16* input, + const int32_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int32_t prefdist_T0 = 16; + if (block_size == 128) { + // unrolling 16 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (8)))), + vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (16)))), + vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (24)))), + vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (40)))), + vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (48)))), + vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (56)))), + vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (64)))), + vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (72)))), + vop72); + _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (80)))), + vop80); + _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (88)))), + vop88); + _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (96)))), + vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (104)))), + vop104); + _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (112)))), + vop112); + _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (120)))), + vop120); + _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + 
_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (8)))), + vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (16)))), + vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (24)))), + vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (40)))), + vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (48)))), + vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (56)))), + vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (8)))), + vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (16)))), + vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (24)))), + vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else { + // generic code + int32_t dataInd = 0; + for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int32_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int32_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int32_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps(_mm_loadu_si128( + reinterpret_cast(&ip[j]))), + _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + float16 vtmp1[8] __attribute__((aligned(64))); + for (; j < block_size; j++) { + vtmp1[0] = ip[j]; + __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); + op[j] += wgt * ((float*)(&vtmp2))[0]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} -#undef EMBEDDING_SPECIALIZATION +void EmbeddingLookup_int64_t_float16_float__avx2_fma( + const TIndex block_size, + const TIndex output_size, + const TIndex index_size, + const TIndex data_size, + const float16* input, + const int64_t* indices, + const int* lengths, + const float* weights, + bool normalize_by_lengths, + float* out) { + const int64_t prefdist_T0 = 16; + if (block_size == 128) { + // unrolling 16 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + __m256 vop64 = _mm256_setzero_ps(); + __m256 vop72 = _mm256_setzero_ps(); + __m256 vop80 = _mm256_setzero_ps(); + __m256 vop88 = _mm256_setzero_ps(); + __m256 vop96 = _mm256_setzero_ps(); + __m256 vop104 = _mm256_setzero_ps(); + __m256 vop112 = _mm256_setzero_ps(); + __m256 vop120 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (8)))), + vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (16)))), + vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (24)))), + vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (40)))), + vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (48)))), + vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (56)))), + vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + vop64 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (64)))), + vop64); + _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0); + vop72 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (72)))), + vop72); + _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0); + vop80 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (80)))), + vop80); + _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0); + vop88 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (88)))), + vop88); + _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0); + vop96 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (96)))), + vop96); + _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0); + vop104 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (104)))), + vop104); + _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0); + vop112 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (112)))), + vop112); + _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0); + vop120 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (120)))), + vop120); + _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + _mm256_storeu_ps(&op[64], vop64); + _mm256_storeu_ps(&op[72], vop72); + _mm256_storeu_ps(&op[80], vop80); + _mm256_storeu_ps(&op[88], vop88); + _mm256_storeu_ps(&op[96], vop96); + _mm256_storeu_ps(&op[104], vop104); + _mm256_storeu_ps(&op[112], vop112); + _mm256_storeu_ps(&op[120], vop120); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + 
_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); + _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); + _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); + _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); + _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); + _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); + _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); + _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); + } + } + } else if (block_size == 64) { + // unrolling 8 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + __m256 vop32 = _mm256_setzero_ps(); + __m256 vop40 = _mm256_setzero_ps(); + __m256 vop48 = _mm256_setzero_ps(); + __m256 vop56 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (8)))), + vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (16)))), + vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (24)))), + vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + vop32 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (32)))), + vop32); + _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0); + vop40 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (40)))), + vop40); + _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0); + vop48 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (48)))), + vop48); + _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0); + vop56 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (56)))), + vop56); + _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + _mm256_storeu_ps(&op[32], vop32); + _mm256_storeu_ps(&op[40], vop40); + _mm256_storeu_ps(&op[48], vop48); + _mm256_storeu_ps(&op[56], vop56); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); + _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); + _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); + _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); + } + } + } else if (block_size == 32) { + // unrolling 4 times + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + __m256 vop0 = _mm256_setzero_ps(); + __m256 vop8 = _mm256_setzero_ps(); + __m256 vop16 = _mm256_setzero_ps(); + __m256 vop24 = _mm256_setzero_ps(); + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + vop0 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (0)))), + vop0); + _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0); + vop8 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (8)))), + vop8); + _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0); + vop16 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (16)))), + vop16); + _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0); + vop24 = _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps( + _mm_loadu_si128(reinterpret_cast(ip + (24)))), + vop24); + _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0); + } + if (normalize_by_lengths == false) { + _mm256_storeu_ps(&op[0], vop0); + _mm256_storeu_ps(&op[8], vop8); + _mm256_storeu_ps(&op[16], vop16); + _mm256_storeu_ps(&op[24], vop24); + } else if (lengths[rangeIndex]) { + __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); + _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); + _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); + _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); + _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); + } + } + } else { + // generic code + int64_t dataInd = 0; + for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { + float* op = &out[rangeIndex * block_size]; + TIndex j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps(op + j, _mm256_setzero_ps()); + } + for (; j < block_size; j++) { + op[j] = 0.0f; + } + for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; + ++dataInd) { + const int64_t idx = indices[dataInd]; + float wgt = 1.f; + if (weights) { + wgt = weights[dataInd]; + } + __m256 vwgt = _mm256_set1_ps(wgt); + const float16* ip = &input[idx * block_size]; + const int64_t next_T0 = (dataInd < index_size - prefdist_T0) + ? 
(dataInd + prefdist_T0) + : dataInd; + const int64_t idx_pref_T0 = indices[next_T0]; + assert( + idx >= 0 && idx_pref_T0 >= 0 && idx < data_size && + idx_pref_T0 < data_size); + const float16* ip_next_T0 = &input[idx_pref_T0 * block_size]; + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], + _mm256_fmadd_ps( + vwgt, + _mm256_cvtph_ps(_mm_loadu_si128( + reinterpret_cast(&ip[j]))), + _mm256_loadu_ps(&op[j]))); + _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0); + } + float16 vtmp1[8] __attribute__((aligned(64))); + for (; j < block_size; j++) { + vtmp1[0] = ip[j]; + __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); + op[j] += wgt * ((float*)(&vtmp2))[0]; + } + } + if (normalize_by_lengths && lengths[rangeIndex]) { + float len_inv = 1.0f / lengths[rangeIndex]; + __m256 vlen_inv = _mm256_set1_ps(len_inv); + j = 0; + for (; j + 8 <= block_size; j += 8) { + _mm256_storeu_ps( + &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); + } + for (; j < block_size; j++) { + op[j] = len_inv * op[j]; + } + } + } + } +} } // namespace caffe2 diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py new file mode 100644 index 0000000000..e6b6cf7501 --- /dev/null +++ b/caffe2/perfkernels/hp_emblookup_codegen.py @@ -0,0 +1,232 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +import argparse + + +def unroll(uf, IndexType, InType, OutType, use_weights, isa): + + def compute(regid, InType, use_weights, isa): + code = [] + + if InType == "float": + code.append("vop%d = _mm256_fmadd_ps(vwgt, \ + _mm256_loadu_ps(ip + (%d)), vop%d);" % (regid, regid, regid)) + else: + code.append("vop%d = _mm256_fmadd_ps(vwgt, \ + _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast(ip + (%d)))), \ + vop%d);" + % (regid, regid, regid)) + + code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid)) + + return code + + code = [] + code.append("// unrolling " + str(uf) + " times") + code.append(IndexType + " dataInd = 0;") + code.append("for (" + IndexType + + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {") + code.append(OutType + " *op = &out[rangeIndex * block_size];") + for i in range(0, uf): + j = 8 * i + code.append("__m256 vop" + str(j) + " = _mm256_setzero_ps();") + + # inner loop + code.append("for (" + IndexType + + " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {") + code.append("const " + IndexType + " idx = indices[dataInd];") + code.append(OutType + " wgt = 1.f;") + code.append("if (weights) {") + code.append("wgt = weights[dataInd];") + code.append("}") + code.append("__m256 vwgt = _mm256_set1_ps(wgt);") + code.append("const " + InType + " *ip = &input[idx * block_size];") + code.append("const " + IndexType + + " next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd;"); + code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];") + code.append( + "assert(idx >=0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);") + code.append("const " + InType + + " *ip_next_T0 = &input[idx_pref_T0 * block_size];") + + for i in range(0, uf): + j = 8 * i + code.extend(compute(j, InType, use_weights, isa)) + code.append("}") + + code.append("if (normalize_by_lengths == false) {") + for i in range(0, uf): + j = 8 * i + code.append( + "_mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");") + code.append("} else if (lengths[rangeIndex]) {") + # inv of length + code.append( + "__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);") + for i in range(0, uf): + j = 8 * i + code.append( + "_mm256_storeu_ps(&op[" + str(j) + "], _mm256_mul_ps(" + "vop" + str(j) + ", vlen_inv));") + code.append("}") + + code.append("}") + return code + + +def generic(IndexType, InType, OutType, use_weights, isa): + + def compute(InType, use_weights, isa): + code = [] + if InType == "float": + code.append("_mm256_storeu_ps(&op[j], \ + _mm256_fmadd_ps(vwgt,_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \ + );") + else: + code.append("_mm256_storeu_ps(&op[j], \ + _mm256_fmadd_ps(vwgt, \ + _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j])) \ + );") + + code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);") + + return code + + code = [] + code.append(IndexType + " dataInd = 0;") + code.append("for (" + IndexType + + " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {") + code.append(OutType + " *op = &out[rangeIndex * block_size];") + + # initialize to 0 + code.append("TIndex j = 0;") + code.append("for(; j + 8 <= block_size; j += 8) {") + code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());") + code.append("}") + code.append("for(; j < block_size; j++) {") + code.append("op[j] = 0.0f;") + code.append("}") + + # inner loop + code.append("for (" + IndexType + + " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {") + code.append("const " + IndexType + " idx = indices[dataInd];") + code.append(OutType + " wgt = 1.f;") + code.append("if (weights) {") + code.append("wgt = weights[dataInd];") + code.append("}") + code.append("__m256 vwgt = _mm256_set1_ps(wgt);") + code.append("const " + InType + " *ip = &input[idx * block_size];") + code.append("const " + IndexType + + " next_T0 = (dataInd < index_size - prefdist_T0) ? 
(dataInd + prefdist_T0) : dataInd;"); + code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];") + code.append( + "assert(idx >=0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);") + code.append("const " + InType + + " *ip_next_T0 = &input[idx_pref_T0 * block_size];") + + # compute and store main loop + code.append("j = 0;") + code.append("for(; j + 8 <= block_size; j += 8) {") + code.extend(compute(InType, use_weights, isa)) + code.append("}") + # leftover + if InType == "float16": + code.append("float16 vtmp1[8] __attribute__((aligned(64)));") + code.append("for(; j < block_size; j++) {") + if InType == "float": + code.append("op[j] += wgt * ip[j];") + else: + code.append("vtmp1[0] = ip[j];") + code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));") + code.append("op[j] += wgt * ((float*)(&vtmp2))[0];") + code.append("}") + + code.append("}") + + code.append("if (normalize_by_lengths && lengths[rangeIndex]) {") + code.append("float len_inv = 1.0f / lengths[rangeIndex];") + code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);") + code.append("j = 0;") + code.append("for(; j + 8 <= block_size; j += 8) {") + code.append( + "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));") + code.append("}") + code.append("for(; j < block_size; j++) {") + code.append("op[j] = len_inv * op[j];") + code.append("}") + + code.append("}") + + code.append("}") + return code + + + +# start main code +parser = argparse.ArgumentParser() +parser.add_argument('-f', nargs=1, help="file name") +opts = parser.parse_args() +filename = "embedding_lookup_avx2.cc" +if opts.f: + filename = (opts.f)[0] +fout = open(filename, 'w') + +options = [["int32_t", "float", "float"], + ["int64_t", "float", "float"], + ["int32_t", "float16", "float"], + ["int64_t", "float16", "float"]] + +code = [] +# includes +code.append("#include \"caffe2/core/types.h\"") +code.append("#include \"caffe2/core/common.h\"") +code.append("#include ") +code.append("\n") + +code.append("namespace caffe2 {\n") +for o in options: + [IndexType, InType, OutType] = o + + fn = "void EmbeddingLookup_" + IndexType + \ + "_" + InType + "_" + OutType + "__avx2_fma" + code.append(fn + "(") + code.append("const TIndex block_size,") + code.append("const TIndex output_size,") + code.append("const TIndex index_size,") + code.append("const TIndex data_size,") + code.append("const " + InType + "* input,") + code.append("const " + IndexType + "* indices,") + code.append("const int* lengths,") + code.append("const float* weights,") + code.append("bool normalize_by_lengths,") + code.append(OutType + "* out)") + + code.append("{") + code.append("const " + IndexType + " prefdist_T0 = 16;") + #code.append("printf(\"calling " + fn + "\\n\");"); + + code.append("if (block_size == 128) {") + code.extend(unroll(16, IndexType, InType, OutType, True, "AVX2")) + code.append("} else if (block_size == 64) {") + code.extend(unroll(8, IndexType, InType, OutType, True, "AVX2")) + code.append("} else if (block_size == 32) {") + code.extend(unroll(4, IndexType, InType, OutType, True, "AVX2")) + code.append("} else {") + code.append("// generic code") + code.extend(generic(IndexType, InType, OutType, True, "AVX2")) + code.append("}") + + code.append("}") + + code.append("\n") +code.append("} // namespace caffe2") + +for c in code: + #print(c, file = fout) + fout.write(c + "\n") +fout.close() + + +print("Created " + filename) -- cgit v1.2.3
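Note (not part of the patch): below is a minimal caller-side sketch of how one of the generated kernels is laid out and invoked, under the assumption that caffe2's TIndex is a 64-bit integer. The prototype is copied from the generated embedding_lookup_avx2.cc; everything else (sizes, table contents, the main function) is illustrative only. The segment lengths must sum to the number of indices, and out must hold output_size * block_size floats. Passing a weights array selects the WeightedSum reducer, and normalize_by_lengths = true selects Mean.

// Hedged usage sketch for the generated kernel; compiles against the
// declaration, linking requires embedding_lookup_avx2.cc.
#include <cstdint>
#include <vector>

using TIndex = int64_t;  // assumption: caffe2::TIndex is a 64-bit index type

// Prototype of one generated kernel, as it appears in embedding_lookup_avx2.cc.
void EmbeddingLookup_int32_t_float_float__avx2_fma(
    const TIndex block_size, const TIndex output_size, const TIndex index_size,
    const TIndex data_size, const float* input, const int32_t* indices,
    const int* lengths, const float* weights, bool normalize_by_lengths,
    float* out);

int main() {
  const TIndex block_size = 32;   // embedding dimension (hits the unrolled path)
  const TIndex data_size = 1000;  // rows in the embedding table
  const TIndex output_size = 2;   // number of output segments
  std::vector<float> table(data_size * block_size, 0.5f);

  // Two segments: the first gathers rows {3, 7, 9}, the second rows {1, 4}.
  std::vector<int32_t> indices = {3, 7, 9, 1, 4};
  std::vector<int> lengths = {3, 2};  // must sum to indices.size()
  const TIndex index_size = static_cast<TIndex>(indices.size());

  std::vector<float> out(output_size * block_size);
  // weights == nullptr gives a plain Sum; pass per-index weights for
  // WeightedSum, and set normalize_by_lengths = true for the Mean reducer.
  EmbeddingLookup_int32_t_float_float__avx2_fma(
      block_size, output_size, index_size, data_size, table.data(),
      indices.data(), lengths.data(), /*weights=*/nullptr,
      /*normalize_by_lengths=*/false, out.data());
  return 0;
}

The generated embedding_lookup_avx2.cc itself can be re-emitted by running the new script, python hp_emblookup_codegen.py, whose -f argument overrides the default output filename of embedding_lookup_avx2.cc.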