//// -------------------------- //// ATTENTION: //// THIS CODE IS AUTOGENERATED //// BY hp_emblookup_codegen.py //// DO NOT MODIFY!!! //// -------------------------- #include #include namespace caffe2 { template static bool EmbeddingLookup_int32_t_float_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { const int prefdist_T0 = 16; const int fused_block_size = block_size + 0; int dataInd = 0; if (block_size == 128) { // unrolling 16 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); __m256 vop64 = _mm256_setzero_ps(); __m256 vop72 = _mm256_setzero_ps(); __m256 vop80 = _mm256_setzero_ps(); __m256 vop88 = _mm256_setzero_ps(); __m256 vop96 = _mm256_setzero_ps(); __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); _mm_prefetch( reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); _mm_prefetch( reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); _mm_prefetch( reinterpret_cast(&ip_next_T0[80]), _MM_HINT_T0); vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); _mm_prefetch( reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); _mm_prefetch( reinterpret_cast(&ip_next_T0[112]), _MM_HINT_T0); vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); _mm256_storeu_ps(&op[64], vop64); _mm256_storeu_ps(&op[72], vop72); _mm256_storeu_ps(&op[80], vop80); _mm256_storeu_ps(&op[88], vop88); _mm256_storeu_ps(&op[96], vop96); _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); } } } else if (block_size == 64) { // unrolling 8 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); _mm_prefetch( reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); _mm_prefetch( reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); } } } else if (block_size == 32) { // unrolling 4 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); _mm_prefetch( reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); } } } else if (block_size == 16) { // unrolling 2 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); } } } else { // generic code for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps(op + j, _mm256_setzero_ps()); } for (; j < block_size; j++) { op[j] = 0.0f; } if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ip[j]; } } if (normalize_by_lengths && lengths[rangeIndex]) { float len_inv = 1.0f / lengths[rangeIndex]; __m256 vlen_inv = _mm256_set1_ps(len_inv); j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); } for (; j < block_size; j++) { op[j] = len_inv * op[j]; } } } } return dataInd == index_size; } bool EmbeddingLookup_int32_t_float_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int32_t_float_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } bool EmbeddingLookup_int32_t_float_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int32_t_float_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } template static bool EmbeddingLookup_int64_t_float_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 0; int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); __m256 vop64 = _mm256_setzero_ps(); __m256 vop72 = _mm256_setzero_ps(); __m256 vop80 = _mm256_setzero_ps(); __m256 vop88 = _mm256_setzero_ps(); __m256 vop96 = _mm256_setzero_ps(); __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); _mm_prefetch( reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); _mm_prefetch( reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80); _mm_prefetch( reinterpret_cast(&ip_next_T0[80]), _MM_HINT_T0); vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96); _mm_prefetch( reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112); _mm_prefetch( reinterpret_cast(&ip_next_T0[112]), _MM_HINT_T0); vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); _mm256_storeu_ps(&op[64], vop64); _mm256_storeu_ps(&op[72], vop72); _mm256_storeu_ps(&op[80], vop80); _mm256_storeu_ps(&op[88], vop88); _mm256_storeu_ps(&op[96], vop96); _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); } } } else if (block_size == 64) { // unrolling 8 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); _mm_prefetch( reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48); _mm_prefetch( reinterpret_cast(&ip_next_T0[48]), _MM_HINT_T0); vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); } } } else if (block_size == 32) { // unrolling 4 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16); _mm_prefetch( reinterpret_cast(&ip_next_T0[16]), _MM_HINT_T0); vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); } } } else if (block_size == 16) { // unrolling 2 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); } } } else { // generic code for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps(op + j, _mm256_setzero_ps()); } for (; j < block_size; j++) { op[j] = 0.0f; } if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const float* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j]))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ip[j]; } } if (normalize_by_lengths && lengths[rangeIndex]) { float len_inv = 1.0f / lengths[rangeIndex]; __m256 vlen_inv = _mm256_set1_ps(len_inv); j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); } for (; j < block_size; j++) { op[j] = len_inv * op[j]; } } } } return dataInd == index_size; } bool EmbeddingLookup_int64_t_float_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int64_t_float_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } bool EmbeddingLookup_int64_t_float_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const float* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int64_t_float_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } template static bool EmbeddingLookup_int32_t_half_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { const int prefdist_T0 = 16; const int fused_block_size = block_size + 0; int dataInd = 0; if (block_size == 128) { // unrolling 16 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); __m256 vop64 = _mm256_setzero_ps(); __m256 vop72 = _mm256_setzero_ps(); __m256 vop80 = _mm256_setzero_ps(); __m256 vop88 = _mm256_setzero_ps(); __m256 vop96 = _mm256_setzero_ps(); __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (16)))), vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (24)))), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (40)))), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (48)))), vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (56)))), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (64)))), vop64); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (72)))), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (80)))), vop80); // skip unnecessary prefetch of (&ip_next_T0[80]) vop88 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (88)))), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (96)))), vop96); _mm_prefetch( reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (104)))), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (112)))), vop112); // skip unnecessary prefetch of (&ip_next_T0[112]) vop120 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (120)))), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); _mm256_storeu_ps(&op[64], vop64); _mm256_storeu_ps(&op[72], vop72); _mm256_storeu_ps(&op[80], vop80); _mm256_storeu_ps(&op[88], vop88); _mm256_storeu_ps(&op[96], vop96); _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); } } } else if (block_size == 64) { // unrolling 8 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (16)))), vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (24)))), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (40)))), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (48)))), vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (56)))), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); } } } else if (block_size == 32) { // unrolling 4 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (16)))), vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (24)))), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); } } } else if (block_size == 16) { // unrolling 2 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); } } } else { // generic code for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps(op + j, _mm256_setzero_ps()); } for (; j < block_size; j++) { op[j] = 0.0f; } if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps(_mm_loadu_si128( reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j]))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } alignas(64) at::Half vtmp1[8]; for (; j < block_size; j++) { vtmp1[0] = ip[j]; __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); op[j] += wgt * ((float*)(&vtmp2))[0]; } } if (normalize_by_lengths && lengths[rangeIndex]) { float len_inv = 1.0f / lengths[rangeIndex]; __m256 vlen_inv = _mm256_set1_ps(len_inv); j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); } for (; j < block_size; j++) { op[j] = len_inv * op[j]; } } } } return dataInd == index_size; } bool EmbeddingLookup_int32_t_half_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int32_t_half_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } bool EmbeddingLookup_int32_t_half_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int32_t_half_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } template static bool EmbeddingLookup_int64_t_half_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 0; int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); __m256 vop64 = _mm256_setzero_ps(); __m256 vop72 = _mm256_setzero_ps(); __m256 vop80 = _mm256_setzero_ps(); __m256 vop88 = _mm256_setzero_ps(); __m256 vop96 = _mm256_setzero_ps(); __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (16)))), vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (24)))), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (40)))), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (48)))), vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (56)))), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (64)))), vop64); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (72)))), vop72); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (80)))), vop80); // skip unnecessary prefetch of (&ip_next_T0[80]) vop88 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (88)))), vop88); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (96)))), vop96); _mm_prefetch( reinterpret_cast(&ip_next_T0[96]), _MM_HINT_T0); vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (104)))), vop104); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (112)))), vop112); // skip unnecessary prefetch of (&ip_next_T0[112]) vop120 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (120)))), vop120); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); _mm256_storeu_ps(&op[64], vop64); _mm256_storeu_ps(&op[72], vop72); _mm256_storeu_ps(&op[80], vop80); _mm256_storeu_ps(&op[88], vop88); _mm256_storeu_ps(&op[96], vop96); _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); } } } else if (block_size == 64) { // unrolling 8 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (16)))), vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (24)))), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (32)))), vop32); _mm_prefetch( reinterpret_cast(&ip_next_T0[32]), _MM_HINT_T0); vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (40)))), vop40); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (48)))), vop48); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (56)))), vop56); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); } } } else if (block_size == 32) { // unrolling 4 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (16)))), vop16); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (24)))), vop24); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); } } } else if (block_size == 16) { // unrolling 2 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (0)))), vop0); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps( _mm_loadu_si128(reinterpret_cast(ip + (8)))), vop8); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); } } } else { // generic code for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps(op + j, _mm256_setzero_ps()); } for (; j < block_size; j++) { op[j] = 0.0f; } if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } __m256 vwgt = _mm256_set1_ps(wgt); const at::Half* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, _mm256_cvtph_ps(_mm_loadu_si128( reinterpret_cast(&ip[j]))), _mm256_loadu_ps(&op[j]))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } alignas(64) at::Half vtmp1[8]; for (; j < block_size; j++) { vtmp1[0] = ip[j]; __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1)); op[j] += wgt * ((float*)(&vtmp2))[0]; } } if (normalize_by_lengths && lengths[rangeIndex]) { float len_inv = 1.0f / lengths[rangeIndex]; __m256 vlen_inv = _mm256_set1_ps(len_inv); j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); } for (; j < block_size; j++) { op[j] = len_inv * op[j]; } } } } return dataInd == index_size; } bool EmbeddingLookup_int64_t_half_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int64_t_half_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } bool EmbeddingLookup_int64_t_half_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const at::Half* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int64_t_half_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } template static bool EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { const int prefdist_T0 = 16; const int fused_block_size = block_size + 0; int dataInd = 0; if (block_size == 128) { // unrolling 16 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); __m256 vop64 = _mm256_setzero_ps(); __m256 vop72 = _mm256_setzero_ps(); __m256 vop80 = _mm256_setzero_ps(); __m256 vop88 = _mm256_setzero_ps(); __m256 vop96 = _mm256_setzero_ps(); __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (16))))), _mm256_add_ps(vop16, vbio)); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (24))))), _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (32))))), _mm256_add_ps(vop32, vbio)); // skip unnecessary prefetch of (&ip_next_T0[32]) vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (40))))), _mm256_add_ps(vop40, vbio)); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (48))))), _mm256_add_ps(vop48, vbio)); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (56))))), _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (64))))), _mm256_add_ps(vop64, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (72))))), _mm256_add_ps(vop72, vbio)); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (80))))), _mm256_add_ps(vop80, vbio)); // skip unnecessary prefetch of (&ip_next_T0[80]) vop88 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (88))))), _mm256_add_ps(vop88, vbio)); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (96))))), _mm256_add_ps(vop96, vbio)); // skip unnecessary prefetch of (&ip_next_T0[96]) vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (104))))), _mm256_add_ps(vop104, vbio)); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (112))))), _mm256_add_ps(vop112, vbio)); // skip unnecessary prefetch of (&ip_next_T0[112]) vop120 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (120))))), _mm256_add_ps(vop120, vbio)); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); _mm256_storeu_ps(&op[64], vop64); _mm256_storeu_ps(&op[72], vop72); _mm256_storeu_ps(&op[80], vop80); _mm256_storeu_ps(&op[88], vop88); _mm256_storeu_ps(&op[96], vop96); _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); } } } else if (block_size == 64) { // unrolling 8 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (16))))), _mm256_add_ps(vop16, vbio)); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (24))))), _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (32))))), _mm256_add_ps(vop32, vbio)); // skip unnecessary prefetch of (&ip_next_T0[32]) vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (40))))), _mm256_add_ps(vop40, vbio)); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (48))))), _mm256_add_ps(vop48, vbio)); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (56))))), _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); } } } else if (block_size == 32) { // unrolling 4 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (16))))), _mm256_add_ps(vop16, vbio)); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (24))))), _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); } } } else if (block_size == 16) { // unrolling 2 times for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); } } } else { // generic code for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps(op + j, _mm256_setzero_ps()); } for (; j < block_size; j++) { op[j] = 0.0f; } if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( reinterpret_cast(&ip[j])))), _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ((float)ip[j]) + bio; } } if (normalize_by_lengths && lengths[rangeIndex]) { float len_inv = 1.0f / lengths[rangeIndex]; __m256 vlen_inv = _mm256_set1_ps(len_inv); j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); } for (; j < block_size; j++) { op[j] = len_inv * op[j]; } } } } return dataInd == index_size; } bool EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } bool EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, const int* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } template static bool EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { const int64_t prefdist_T0 = 16; const int64_t fused_block_size = block_size + 0; int64_t dataInd = 0; if (block_size == 128) { // unrolling 16 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); __m256 vop64 = _mm256_setzero_ps(); __m256 vop72 = _mm256_setzero_ps(); __m256 vop80 = _mm256_setzero_ps(); __m256 vop88 = _mm256_setzero_ps(); __m256 vop96 = _mm256_setzero_ps(); __m256 vop104 = _mm256_setzero_ps(); __m256 vop112 = _mm256_setzero_ps(); __m256 vop120 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (16))))), _mm256_add_ps(vop16, vbio)); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (24))))), _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (32))))), _mm256_add_ps(vop32, vbio)); // skip unnecessary prefetch of (&ip_next_T0[32]) vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (40))))), _mm256_add_ps(vop40, vbio)); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (48))))), _mm256_add_ps(vop48, vbio)); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (56))))), _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) vop64 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (64))))), _mm256_add_ps(vop64, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[64]), _MM_HINT_T0); vop72 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (72))))), _mm256_add_ps(vop72, vbio)); // skip unnecessary prefetch of (&ip_next_T0[72]) vop80 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (80))))), _mm256_add_ps(vop80, vbio)); // skip unnecessary prefetch of (&ip_next_T0[80]) vop88 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (88))))), _mm256_add_ps(vop88, vbio)); // skip unnecessary prefetch of (&ip_next_T0[88]) vop96 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (96))))), _mm256_add_ps(vop96, vbio)); // skip unnecessary prefetch of (&ip_next_T0[96]) vop104 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (104))))), _mm256_add_ps(vop104, vbio)); // skip unnecessary prefetch of (&ip_next_T0[104]) vop112 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (112))))), _mm256_add_ps(vop112, vbio)); // skip unnecessary prefetch of (&ip_next_T0[112]) vop120 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (120))))), _mm256_add_ps(vop120, vbio)); // skip unnecessary prefetch of (&ip_next_T0[120]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); _mm256_storeu_ps(&op[64], vop64); _mm256_storeu_ps(&op[72], vop72); _mm256_storeu_ps(&op[80], vop80); _mm256_storeu_ps(&op[88], vop88); _mm256_storeu_ps(&op[96], vop96); _mm256_storeu_ps(&op[104], vop104); _mm256_storeu_ps(&op[112], vop112); _mm256_storeu_ps(&op[120], vop120); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv)); _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv)); _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv)); _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv)); _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv)); _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv)); _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv)); _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv)); } } } else if (block_size == 64) { // unrolling 8 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); __m256 vop32 = _mm256_setzero_ps(); __m256 vop40 = _mm256_setzero_ps(); __m256 vop48 = _mm256_setzero_ps(); __m256 vop56 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (16))))), _mm256_add_ps(vop16, vbio)); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (24))))), _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) vop32 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (32))))), _mm256_add_ps(vop32, vbio)); // skip unnecessary prefetch of (&ip_next_T0[32]) vop40 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (40))))), _mm256_add_ps(vop40, vbio)); // skip unnecessary prefetch of (&ip_next_T0[40]) vop48 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (48))))), _mm256_add_ps(vop48, vbio)); // skip unnecessary prefetch of (&ip_next_T0[48]) vop56 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (56))))), _mm256_add_ps(vop56, vbio)); // skip unnecessary prefetch of (&ip_next_T0[56]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); _mm256_storeu_ps(&op[32], vop32); _mm256_storeu_ps(&op[40], vop40); _mm256_storeu_ps(&op[48], vop48); _mm256_storeu_ps(&op[56], vop56); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv)); _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv)); _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv)); _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv)); } } } else if (block_size == 32) { // unrolling 4 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); __m256 vop16 = _mm256_setzero_ps(); __m256 vop24 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) vop16 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (16))))), _mm256_add_ps(vop16, vbio)); // skip unnecessary prefetch of (&ip_next_T0[16]) vop24 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (24))))), _mm256_add_ps(vop24, vbio)); // skip unnecessary prefetch of (&ip_next_T0[24]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); _mm256_storeu_ps(&op[16], vop16); _mm256_storeu_ps(&op[24], vop24); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv)); _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv)); } } } else if (block_size == 16) { // unrolling 2 times for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; __m256 vop0 = _mm256_setzero_ps(); __m256 vop8 = _mm256_setzero_ps(); if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; vop0 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (0))))), _mm256_add_ps(vop0, vbio)); _mm_prefetch( reinterpret_cast(&ip_next_T0[0]), _MM_HINT_T0); vop8 = _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32( _mm_loadl_epi64(reinterpret_cast(ip + (8))))), _mm256_add_ps(vop8, vbio)); // skip unnecessary prefetch of (&ip_next_T0[8]) } if (!normalize_by_lengths || lengths[rangeIndex] == 0) { _mm256_storeu_ps(&op[0], vop0); _mm256_storeu_ps(&op[8], vop8); } else { __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]); _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv)); _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv)); } } } else { // generic code for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) { float* op = &out[rangeIndex * block_size]; int64_t j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps(op + j, _mm256_setzero_ps()); } for (; j < block_size; j++) { op[j] = 0.0f; } if (dataInd + lengths[rangeIndex] > index_size) { return false; } for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex]; ++dataInd) { const int64_t idx = indices[dataInd]; if (idx < 0 || idx >= data_size) { return false; } float wgt = 1.f; float bio; if (weights) { wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]; } bio = wgt * scale_bias[2 * idx + 1]; wgt = wgt * scale_bias[2 * idx]; __m256 vbio = _mm256_set1_ps(bio); __m256 vwgt = _mm256_set1_ps(wgt); const uint8_t* ip = &input[idx * fused_block_size]; const int64_t next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd; const int64_t idx_pref_T0 = indices[next_T0]; if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) { return false; } const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size]; j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_fmadd_ps( vwgt, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64( reinterpret_cast(&ip[j])))), _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio))); _mm_prefetch( reinterpret_cast(&ip_next_T0[j]), _MM_HINT_T0); } for (; j < block_size; j++) { op[j] += wgt * ((float)ip[j]) + bio; } } if (normalize_by_lengths && lengths[rangeIndex]) { float len_inv = 1.0f / lengths[rangeIndex]; __m256 vlen_inv = _mm256_set1_ps(len_inv); j = 0; for (; j + 8 <= block_size; j += 8) { _mm256_storeu_ps( &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv)); } for (; j < block_size; j++) { op[j] = len_inv * op[j]; } } } } return dataInd == index_size; } bool EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } bool EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma( const int64_t block_size, const int64_t output_size, const int64_t index_size, const int64_t data_size, const uint8_t* input, const int64_t* indices, const int* lengths, const float* weights, const float* scale_bias, bool normalize_by_lengths, float* out) { return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma( block_size, output_size, index_size, data_size, input, indices, lengths, weights, scale_bias, normalize_by_lengths, out); } } // namespace caffe2