path: root/caffe2/perfkernels
author    Misha Smelyanskiy <msmelyan@fb.com>  2017-08-30 16:11:39 -0700
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2017-08-30 16:22:11 -0700
commit    080fab8f6c82daa93b478267f3a604e1bdb8083f (patch)
tree      df9e4922f288955b27aa4fe4609d678729f52ca2 /caffe2/perfkernels
parent    45e6e71198ab027fca61b6de6596479374a9b76c (diff)
download  pytorch-080fab8f6c82daa93b478267f3a604e1bdb8083f.tar.gz
          pytorch-080fab8f6c82daa93b478267f3a604e1bdb8083f.tar.bz2
          pytorch-080fab8f6c82daa93b478267f3a604e1bdb8083f.zip
Code generator for high-performance embedding look-up kernels, supporting Sum, WeightedSum, and Mean reducers

Summary: Code generator for high-performance embedding look-up kernels, supporting Sum, WeightedSum, and Mean reducers. Achieves at least a 1.5x speedup for float and over a 2x speedup for float16, compared to the existing code.

These are results on Broadwell, using the sparse_lengths_sum_benchmark.par benchmark.

Old
==============
[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000
Preparing lookup table. 2017-08-08 00:10:23.101848
Preparation finished. 2017-08-08 00:10:27.955680
I0808 00:10:27.955732 30700 net.cc:177] Starting benchmark.
I0808 00:10:27.955759 30700 net.cc:178] Running warmup runs.
I0808 00:10:27.956367 30700 net.cc:188] Main runs.
I0808 00:10:31.839035 30700 net.cc:199] Main run finished. Milliseconds per iter: 0.388264. Iters per second: 2575.56
I0808 00:10:35.704169 30700 net.cc:233] Operator #0 (indices, Python) 0.0583264 ms/iter
I0808 00:10:35.704210 30700 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.327694 ms/iter
I0808 00:10:35.704213 30700 net.cc:237] Time per operator type:
I0808 00:10:35.704217 30700 net.cc:246] 0.327694 SparseLengthsSum
I0808 00:10:35.704221 30700 net.cc:246] 0.0583264 Python

[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000 --dtype float16
Preparing lookup table. 2017-08-08 00:10:59.047159
Preparation finished. 2017-08-08 00:11:05.140565
I0808 00:11:05.140612 31725 net.cc:177] Starting benchmark.
I0808 00:11:05.140635 31725 net.cc:178] Running warmup runs.
I0808 00:11:05.141104 31725 net.cc:188] Main runs.
I0808 00:11:08.371510 31725 net.cc:199] Main run finished. Milliseconds per iter: 0.323039. Iters per second: 3095.6
I0808 00:11:11.671450 31725 net.cc:233] Operator #0 (indices, Python) 0.0609876 ms/iter
I0808 00:11:11.671489 31725 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.26856 ms/iter
I0808 00:11:11.671494 31725 net.cc:237] Time per operator type:
I0808 00:11:11.671497 31725 net.cc:246] 0.26856 SparseLengthsSum
I0808 00:11:11.671500 31725 net.cc:246] 0.0609876 Python

New (Misha's)
==============
[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000
Preparing lookup table. 2017-08-07 23:44:55.897748
Preparation finished. 2017-08-07 23:45:00.708896
I0807 23:45:00.708945 4178361 net.cc:177] Starting benchmark.
I0807 23:45:00.708971 4178361 net.cc:178] Running warmup runs.
I0807 23:45:00.709444 4178361 net.cc:188] Main runs.
I0807 23:45:03.608551 4178361 net.cc:199] Main run finished. Milliseconds per iter: 0.289909. Iters per second: 3449.36
I0807 23:45:06.536182 4178361 net.cc:233] Operator #0 (indices, Python) 0.0572399 ms/iter
I0807 23:45:06.536224 4178361 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.23512 ms/iter
I0807 23:45:06.536228 4178361 net.cc:237] Time per operator type:
I0807 23:45:06.536232 4178361 net.cc:246] 0.23512 SparseLengthsSum
I0807 23:45:06.536236 4178361 net.cc:246] 0.0572399 Python

[root@fblearner001.01.ftw1 /home/msmelyan]# numactl -m 0 -C 0 ./sparse_lengths_sum_benchmark.par --iteration 10000 --dtype float16
Preparing lookup table. 2017-08-07 23:45:17.191579
Preparation finished. 2017-08-07 23:45:23.173668
I0807 23:45:23.173715 4179316 net.cc:177] Starting benchmark.
I0807 23:45:23.173743 4179316 net.cc:178] Running warmup runs.
I0807 23:45:23.174090 4179316 net.cc:188] Main runs.
I0807 23:45:24.939749 4179316 net.cc:199] Main run finished. Milliseconds per iter: 0.176564. Iters per second: 5663.67
I0807 23:45:26.698885 4179316 net.cc:233] Operator #0 (indices, Python) 0.0557303 ms/iter
I0807 23:45:26.698923 4179316 net.cc:233] Operator #1 (Y, SparseLengthsSum) 0.119794 ms/iter
I0807 23:45:26.698927 4179316 net.cc:237] Time per operator type:
I0807 23:45:26.698931 4179316 net.cc:246] 0.119794 SparseLengthsSum
I0807 23:45:26.698935 4179316 net.cc:246] 0.0557303 Python

Reviewed By: salexspb

Differential Revision: D5582172

fbshipit-source-id: d71f5a55580b734a51b8f30852b75f379acfdaf2
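To make the calling convention concrete, here is a minimal standalone sketch of invoking one of the generated kernels (not part of this commit; it assumes caffe2::TIndex is a 64-bit signed integer and simply re-declares the kernel prototype exactly as it appears in the diff below):

// Hypothetical harness, not part of the commit. Links against the
// AVX2/FMA kernel defined in embedding_lookup_avx2.cc.
#include <cstdint>
#include <vector>

namespace caffe2 {
using TIndex = int64_t;  // assumption: matches caffe2/core/types.h
void EmbeddingLookup_int32_t_float_float__avx2_fma(
    const TIndex block_size,
    const TIndex output_size,
    const TIndex index_size,
    const TIndex data_size,
    const float* input,
    const int32_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out);
}  // namespace caffe2

int main() {
  const int64_t block_size = 32;   // hits the 4x-unrolled AVX2 path
  const int64_t data_size = 1000;  // rows in the embedding table
  const int64_t output_size = 2;   // number of output segments
  std::vector<float> table(data_size * block_size, 1.0f);
  std::vector<int32_t> indices = {3, 7, 42, 5, 999};  // 5 lookups total
  std::vector<int> lengths = {2, 3};  // segment 0: 2 ids, segment 1: 3 ids
  std::vector<float> out(output_size * block_size, 0.0f);

  // weights == nullptr selects the plain Sum reducer; a per-index weight
  // array gives WeightedSum, and normalize_by_lengths=true gives Mean.
  caffe2::EmbeddingLookup_int32_t_float_float__avx2_fma(
      block_size, output_size, /*index_size=*/indices.size(), data_size,
      table.data(), indices.data(), lengths.data(), /*weights=*/nullptr,
      /*normalize_by_lengths=*/false, out.data());
  return 0;
}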
Diffstat (limited to 'caffe2/perfkernels')
-rw-r--r--  caffe2/perfkernels/embedding_lookup_avx2.cc  1444
-rw-r--r--  caffe2/perfkernels/hp_emblookup_codegen.py    232
2 files changed, 1637 insertions, 39 deletions
diff --git a/caffe2/perfkernels/embedding_lookup_avx2.cc b/caffe2/perfkernels/embedding_lookup_avx2.cc
index 1e3d6fa443..8b87f486b2 100644
--- a/caffe2/perfkernels/embedding_lookup_avx2.cc
+++ b/caffe2/perfkernels/embedding_lookup_avx2.cc
@@ -1,49 +1,1415 @@
+#include <immintrin.h>
+#include "caffe2/core/common.h"
#include "caffe2/core/types.h"
namespace caffe2 {
-// TODO(msmelyan): implement code generator for implementation based on
-// following parameters:
-// index type: int32, int64 (encoded in function name)
-// embedding data type: float16, float32 (encoded in function name)
-// output type: float (encoded in function name)
-// weighted reduction: whether `weights` is nullptr or not
-// normalization (divide by lengths[i]): whether normalize_by_lengths is true
-// block size: 32, 64, 128, generic
+void EmbeddingLookup_int32_t_float_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float* input,
+ const int32_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int32_t prefdist_T0 = 16;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
+ _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0);
+ vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
+ _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+ vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
+ _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0);
+ vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
+ _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0);
+ vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
+ _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+ vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
+ _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ op[j] += wgt * ip[j];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
-// For now just invoke base implementation (this entire file can be autogenned)
-#define EMBEDDING_SPECIALIZATION(IndexType, InType, OutType) \
- void EmbeddingLookup_##IndexType##_##InType##_##OutType##__avx2_fma( \
- const TIndex block_size, \
- const TIndex output_size, \
- const TIndex index_size, \
- const TIndex data_size, \
- const InType* input, \
- const IndexType* indices, \
- const int* lengths, \
- const float* weights, \
- bool normalize_by_lengths, \
- OutType* out) { \
- decltype(EmbeddingLookup_##IndexType##_##InType##_##OutType##__avx2_fma) \
- EmbeddingLookup_##IndexType##_##InType##_##OutType##__base; \
- EmbeddingLookup_##IndexType##_##InType##_##OutType##__base( \
- block_size, \
- output_size, \
- index_size, \
- data_size, \
- input, \
- indices, \
- lengths, \
- weights, \
- normalize_by_lengths, \
- out); \
+void EmbeddingLookup_int64_t_float_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float* input,
+ const int64_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int64_t prefdist_T0 = 16;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
+ _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0);
+ vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
+ _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+ vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
+ _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0);
+ vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
+ _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0);
+ vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
+ _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+ vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
+ _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ op[j] += wgt * ip[j];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
}
+}
-EMBEDDING_SPECIALIZATION(int32_t, float, float);
-EMBEDDING_SPECIALIZATION(int64_t, float, float);
-EMBEDDING_SPECIALIZATION(int32_t, float16, float);
-EMBEDDING_SPECIALIZATION(int64_t, float16, float);
+void EmbeddingLookup_int32_t_float16_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float16* input,
+ const int32_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int32_t prefdist_T0 = 16;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
+ vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
+ vop72);
+ _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0);
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
+ vop80);
+ _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
+ vop88);
+ _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0);
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
+ vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
+ vop104);
+ _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0);
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
+ vop112);
+ _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
+ vop120);
+ _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&ip[j]))),
+ _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ float16 vtmp1[8] __attribute__((aligned(64)));
+ for (; j < block_size; j++) {
+ vtmp1[0] = ip[j];
+ __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
+ op[j] += wgt * ((float*)(&vtmp2))[0];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
-#undef EMBEDDING_SPECIALIZATION
+void EmbeddingLookup_int64_t_float16_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float16* input,
+ const int64_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int64_t prefdist_T0 = 16;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
+ vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
+ vop72);
+ _mm_prefetch((&ip_next_T0[72]), _MM_HINT_T0);
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
+ vop80);
+ _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
+ vop88);
+ _mm_prefetch((&ip_next_T0[88]), _MM_HINT_T0);
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
+ vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
+ vop104);
+ _mm_prefetch((&ip_next_T0[104]), _MM_HINT_T0);
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
+ vop112);
+ _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
+ vop120);
+ _mm_prefetch((&ip_next_T0[120]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ _mm_prefetch((&ip_next_T0[40]), _MM_HINT_T0);
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ _mm_prefetch((&ip_next_T0[56]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ _mm_prefetch((&ip_next_T0[8]), _MM_HINT_T0);
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ _mm_prefetch((&ip_next_T0[24]), _MM_HINT_T0);
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ assert(
+ idx >= 0 && idx_pref_T0 >= 0 && idx < data_size &&
+ idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&ip[j]))),
+ _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ float16 vtmp1[8] __attribute__((aligned(64)));
+ for (; j < block_size; j++) {
+ vtmp1[0] = ip[j];
+ __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
+ op[j] += wgt * ((float*)(&vtmp2))[0];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
} // namespace caffe2
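For reference, every specialization above computes the same segment reduction over rows of the embedding table; only the unroll factor and the float16 conversion differ. A minimal scalar sketch in Python (the helper name embedding_lookup_ref and the flat-list layout are illustrative only, not part of the patch):

def embedding_lookup_ref(block_size, lengths, indices, input_rows,
                         weights=None, normalize_by_lengths=False):
    # input_rows is a flat list of data_size * block_size floats; output
    # row i is reduced from lengths[i] consecutive entries of indices.
    out = []
    data_ind = 0
    for length in lengths:
        acc = [0.0] * block_size
        for _ in range(length):
            idx = indices[data_ind]
            wgt = weights[data_ind] if weights is not None else 1.0
            row = input_rows[idx * block_size:(idx + 1) * block_size]
            acc = [a + wgt * r for a, r in zip(acc, row)]
            data_ind += 1
        if normalize_by_lengths and length:
            acc = [a / length for a in acc]
        out.extend(acc)
    return out

Leaving weights unset gives a plain sum, supplying weights gives a weighted sum, and normalize_by_lengths=True divides each output row by its segment length, matching the branches in the generated kernels.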
diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py
new file mode 100644
index 0000000000..e6b6cf7501
--- /dev/null
+++ b/caffe2/perfkernels/hp_emblookup_codegen.py
@@ -0,0 +1,232 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+import argparse
+
+
+def unroll(uf, IndexType, InType, OutType, use_weights, isa):
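+    """Emit the body of a specialized kernel for block_size == 8 * uf.
+
+    Keeps uf __m256 accumulators (vop0, vop8, ...) live across the inner
+    loop, scales each gathered row by its weight with FMA, and prefetches
+    the row prefdist_T0 indices ahead.
+    """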
+
+ def compute(regid, InType, use_weights, isa):
+ code = []
+
+ if InType == "float":
+ code.append("vop%d = _mm256_fmadd_ps(vwgt, \
+ _mm256_loadu_ps(ip + (%d)), vop%d);" % (regid, regid, regid))
+ else:
+ code.append("vop%d = _mm256_fmadd_ps(vwgt, \
+ _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), \
+ vop%d);"
+ % (regid, regid, regid))
+
+ code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid))
+
+ return code
+
+ code = []
+ code.append("// unrolling " + str(uf) + " times")
+ code.append(IndexType + " dataInd = 0;")
+ code.append("for (" + IndexType +
+ " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
+ code.append(OutType + " *op = &out[rangeIndex * block_size];")
+ for i in range(0, uf):
+ j = 8 * i
+ code.append("__m256 vop" + str(j) + " = _mm256_setzero_ps();")
+
+ # inner loop
+ code.append("for (" + IndexType +
+ " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
+ code.append("const " + IndexType + " idx = indices[dataInd];")
+ code.append(OutType + " wgt = 1.f;")
+ code.append("if (weights) {")
+ code.append("wgt = weights[dataInd];")
+ code.append("}")
+ code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
+ code.append("const " + InType + " *ip = &input[idx * block_size];")
+ code.append("const " + IndexType +
+                " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;")
+ code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
+ code.append(
+ "assert(idx >=0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);")
+ code.append("const " + InType +
+ " *ip_next_T0 = &input[idx_pref_T0 * block_size];")
+
+ for i in range(0, uf):
+ j = 8 * i
+ code.extend(compute(j, InType, use_weights, isa))
+ code.append("}")
+
+ code.append("if (normalize_by_lengths == false) {")
+ for i in range(0, uf):
+ j = 8 * i
+ code.append(
+ "_mm256_storeu_ps(&op[" + str(j) + "], vop" + str(j) + ");")
+ code.append("} else if (lengths[rangeIndex]) {")
+ # inv of length
+ code.append(
+ "__m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);")
+ for i in range(0, uf):
+ j = 8 * i
+ code.append(
+ "_mm256_storeu_ps(&op[" + str(j) + "], _mm256_mul_ps(" + "vop" + str(j) + ", vlen_inv));")
+ code.append("}")
+
+ code.append("}")
+ return code
+
+
+def generic(IndexType, InType, OutType, use_weights, isa):
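+    """Emit the body of the fallback kernel for arbitrary block_size.
+
+    Zero-initializes the output row, accumulates 8 floats per iteration
+    with FMA, and handles any remaining elements in a scalar tail loop.
+    """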
+
+ def compute(InType, use_weights, isa):
+ code = []
+ if InType == "float":
+ code.append("_mm256_storeu_ps(&op[j], \
+ _mm256_fmadd_ps(vwgt,_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \
+ );")
+ else:
+ code.append("_mm256_storeu_ps(&op[j], \
+ _mm256_fmadd_ps(vwgt, \
+ _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), _mm256_loadu_ps(&op[j])) \
+ );")
+
+ code.append("_mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);")
+
+ return code
+
+ code = []
+ code.append(IndexType + " dataInd = 0;")
+ code.append("for (" + IndexType +
+ " rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {")
+ code.append(OutType + " *op = &out[rangeIndex * block_size];")
+
+ # initialize to 0
+ code.append("TIndex j = 0;")
+ code.append("for(; j + 8 <= block_size; j += 8) {")
+ code.append("_mm256_storeu_ps(op + j, _mm256_setzero_ps());")
+ code.append("}")
+ code.append("for(; j < block_size; j++) {")
+ code.append("op[j] = 0.0f;")
+ code.append("}")
+
+ # inner loop
+ code.append("for (" + IndexType +
+ " start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) {")
+ code.append("const " + IndexType + " idx = indices[dataInd];")
+ code.append(OutType + " wgt = 1.f;")
+ code.append("if (weights) {")
+ code.append("wgt = weights[dataInd];")
+ code.append("}")
+ code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
+ code.append("const " + InType + " *ip = &input[idx * block_size];")
+ code.append("const " + IndexType +
+                " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;")
+ code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
+ code.append(
+ "assert(idx >=0 && idx_pref_T0 >= 0 && idx < data_size && idx_pref_T0 < data_size);")
+ code.append("const " + InType +
+ " *ip_next_T0 = &input[idx_pref_T0 * block_size];")
+
+    # vectorized main loop: FMA-accumulate into op[] 8 floats at a time
+ code.append("j = 0;")
+ code.append("for(; j + 8 <= block_size; j += 8) {")
+ code.extend(compute(InType, use_weights, isa))
+ code.append("}")
+ # leftover
+ if InType == "float16":
+ code.append("float16 vtmp1[8] __attribute__((aligned(64)));")
+ code.append("for(; j < block_size; j++) {")
+ if InType == "float":
+ code.append("op[j] += wgt * ip[j];")
+ else:
+ code.append("vtmp1[0] = ip[j];")
+ code.append("__m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));")
+ code.append("op[j] += wgt * ((float*)(&vtmp2))[0];")
+ code.append("}")
+
+ code.append("}")
+
+ code.append("if (normalize_by_lengths && lengths[rangeIndex]) {")
+ code.append("float len_inv = 1.0f / lengths[rangeIndex];")
+ code.append("__m256 vlen_inv = _mm256_set1_ps(len_inv);")
+ code.append("j = 0;")
+ code.append("for(; j + 8 <= block_size; j += 8) {")
+ code.append(
+ "_mm256_storeu_ps(&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));")
+ code.append("}")
+ code.append("for(; j < block_size; j++) {")
+ code.append("op[j] = len_inv * op[j];")
+ code.append("}")
+
+ code.append("}")
+
+ code.append("}")
+ return code
+
+
+
+# start main code
+parser = argparse.ArgumentParser()
+parser.add_argument('-f', nargs=1, help="file name")
+opts = parser.parse_args()
+filename = "embedding_lookup_avx2.cc"
+if opts.f:
+ filename = (opts.f)[0]
+fout = open(filename, 'w')
+
+options = [["int32_t", "float", "float"],
+ ["int64_t", "float", "float"],
+ ["int32_t", "float16", "float"],
+ ["int64_t", "float16", "float"]]
+
+code = []
+# includes
+code.append("#include \"caffe2/core/types.h\"")
+code.append("#include \"caffe2/core/common.h\"")
+code.append("#include <immintrin.h>")
+code.append("\n")
+
+code.append("namespace caffe2 {\n")
+for o in options:
+ [IndexType, InType, OutType] = o
+
+ fn = "void EmbeddingLookup_" + IndexType + \
+ "_" + InType + "_" + OutType + "__avx2_fma"
+ code.append(fn + "(")
+ code.append("const TIndex block_size,")
+ code.append("const TIndex output_size,")
+ code.append("const TIndex index_size,")
+ code.append("const TIndex data_size,")
+ code.append("const " + InType + "* input,")
+ code.append("const " + IndexType + "* indices,")
+ code.append("const int* lengths,")
+ code.append("const float* weights,")
+ code.append("bool normalize_by_lengths,")
+ code.append(OutType + "* out)")
+
+ code.append("{")
+ code.append("const " + IndexType + " prefdist_T0 = 16;")
+ #code.append("printf(\"calling " + fn + "\\n\");");
+
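+    # Dispatch on block_size: 128, 64 and 32 floats map to unroll factors
+    # 16, 8 and 4 (8 floats per __m256 register); anything else falls back
+    # to the generic kernel.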
+ code.append("if (block_size == 128) {")
+ code.extend(unroll(16, IndexType, InType, OutType, True, "AVX2"))
+ code.append("} else if (block_size == 64) {")
+ code.extend(unroll(8, IndexType, InType, OutType, True, "AVX2"))
+ code.append("} else if (block_size == 32) {")
+ code.extend(unroll(4, IndexType, InType, OutType, True, "AVX2"))
+ code.append("} else {")
+ code.append("// generic code")
+ code.extend(generic(IndexType, InType, OutType, True, "AVX2"))
+ code.append("}")
+
+ code.append("}")
+
+ code.append("\n")
+code.append("} // namespace caffe2")
+
+for c in code:
+ #print(c, file = fout)
+ fout.write(c + "\n")
+fout.close()
+
+
+print("Created " + filename)
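As a usage sketch (assuming the generator is run from the repository root; the build integration is not part of this diff), the checked-in kernels can be regenerated through the script's -f flag:

import subprocess

# Rewrites caffe2/perfkernels/embedding_lookup_avx2.cc; omitting -f writes
# to embedding_lookup_avx2.cc in the current directory instead.
subprocess.check_call(
    ["python", "caffe2/perfkernels/hp_emblookup_codegen.py",
     "-f", "caffe2/perfkernels/embedding_lookup_avx2.cc"])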