author     Peter Goldsborough <psag@fb.com>  2018-01-19 15:37:01 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-01-19 15:44:34 -0800
commit     d401c26d633e49de201d1faf6e3e84e25d185b60 (patch)
tree       2aa8a2e9978b5fc44bc8d669b6ba41618566ab18 /caffe2/perfkernels
parent     8dc0702af5b141aceb6291f54a2746cc26427dac (diff)
Add FusedEmbeddingLookup
Summary: Updates the perfkernel codebase to implement embedding lookup for our new fused storage format, where each row in the data matrix stores the quantized values *and* the scale and bias. msmelyan, see this as my best-effort attempt at updating the perfkernels for the fused storage; let me know if any of it is grossly wrong. I also don't know whether we need to update any of the prefetching logic. Note that we have to keep the old code around for a while, until we get rid of the old operations that use separate `scale_bias` storage.

Reviewed By: kennyhorror

Differential Revision: D6710843

fbshipit-source-id: b485ef2389f526c5db1260cac9d4be3fc8df0979
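For orientation, here is a minimal standalone sketch (not part of this diff) of the fused rowwise layout the kernels below consume, under the assumption that each row stores its quantized values followed by a float scale and a float bias; the helper name `dequantize_fused_row` is hypothetical and purely illustrative.

// Illustrative sketch only -- not part of this diff.
// Assumed per-row layout: [ uint8 quantized[0 .. block_size) | float scale | float bias ],
// i.e. one row occupies block_size + 2 * sizeof(float) bytes.
#include <cstdint>
#include <cstring>

// Hypothetical helper: dequantize a single fused row into 32-bit floats.
void dequantize_fused_row(const std::uint8_t* row, int block_size, float* out) {
  float scale = 0.f;
  float bias = 0.f;
  // Scale and bias live in the trailing bytes of the row.
  std::memcpy(&scale, row + block_size, sizeof(float));
  std::memcpy(&bias, row + block_size + sizeof(float), sizeof(float));
  for (int j = 0; j < block_size; ++j) {
    out[j] = static_cast<float>(row[j]) * scale + bias;
  }
}

This is also why the generated kernels below stride through the data with fused_block_size = block_size + 2 (floats) or block_size + 4 (halfs) per row while only accumulating the first block_size values of each row.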
Diffstat (limited to 'caffe2/perfkernels')
-rw-r--r--  caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc  2760
-rw-r--r--  caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc        170
-rw-r--r--  caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h          67
-rw-r--r--  caffe2/perfkernels/hp_emblookup_codegen.py                        188
4 files changed, 3113 insertions, 72 deletions
diff --git a/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc
new file mode 100644
index 0000000000..7277df7842
--- /dev/null
+++ b/caffe2/perfkernels/embedding_lookup_fused_8bit_rowwise_avx2.cc
@@ -0,0 +1,2760 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//// --------------------------
+//// ATTENTION:
+//// THIS CODE IS AUTOGENERATED
+//// BY caffe2/caffe2/perfkernels/hp_emblookup_codegen.py
+//// DO NOT MODIFY!!!
+//// --------------------------
+
+#include <caffe2/core/common.h>
+#include <caffe2/core/types.h>
+#include <immintrin.h>
+
+namespace caffe2 {
+
+void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float* input,
+ const int32_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int32_t prefdist_T0 = 16;
+ const int32_t fused_block_size = block_size + 2;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
+ _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+ vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
+ _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+ vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ op[j] += wgt * ip[j];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
+
+void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float* input,
+ const int64_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int64_t prefdist_T0 = 16;
+ const int64_t fused_block_size = block_size + 2;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
+ _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
+ vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
+ _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
+ vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
+ _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
+ vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
+ _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
+ vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ op[j] += wgt * ip[j];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
+
+void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float16* input,
+ const int32_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int32_t prefdist_T0 = 16;
+ const int32_t fused_block_size = block_size + 4;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
+ vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
+ vop72);
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
+ vop80);
+ // skip unnecessary prefetch of (&ip_next_T0[80])
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
+ vop88);
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
+ vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
+ vop104);
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
+ vop112);
+ // skip unnecessary prefetch of (&ip_next_T0[112])
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
+ vop120);
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&ip[j]))),
+ _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ float16 vtmp1[8] CAFFE2_ALIGNED(64);
+ for (; j < block_size; j++) {
+ vtmp1[0] = ip[j];
+ __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
+ op[j] += wgt * ((float*)(&vtmp2))[0];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
+
+void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const float16* input,
+ const int64_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int64_t prefdist_T0 = 16;
+ const int64_t fused_block_size = block_size + 4;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
+ vop64);
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
+ vop72);
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
+ vop80);
+ // skip unnecessary prefetch of (&ip_next_T0[80])
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
+ vop88);
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
+ vop96);
+ _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
+ vop104);
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
+ vop112);
+ // skip unnecessary prefetch of (&ip_next_T0[112])
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
+ vop120);
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
+ vop32);
+ _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
+ vop40);
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
+ vop48);
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
+ vop56);
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
+ vop16);
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
+ vop24);
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
+ vop0);
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
+ vop8);
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const float16* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtph_ps(_mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(&ip[j]))),
+ _mm256_loadu_ps(&op[j])));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ float16 vtmp1[8] CAFFE2_ALIGNED(64);
+ for (; j < block_size; j++) {
+ vtmp1[0] = ip[j];
+ __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
+ op[j] += wgt * ((float*)(&vtmp2))[0];
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
+
+void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const uint8_t* input,
+ const int32_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int32_t prefdist_T0 = 16;
+ const int32_t fused_block_size = block_size + 8;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
+ _mm256_add_ps(vop16, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
+ _mm256_add_ps(vop24, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
+ _mm256_add_ps(vop32, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[32])
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
+ _mm256_add_ps(vop40, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
+ _mm256_add_ps(vop48, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
+ _mm256_add_ps(vop56, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
+ _mm256_add_ps(vop64, vbio));
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
+ _mm256_add_ps(vop72, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
+ _mm256_add_ps(vop80, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[80])
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
+ _mm256_add_ps(vop88, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
+ _mm256_add_ps(vop96, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[96])
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
+ _mm256_add_ps(vop104, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
+ _mm256_add_ps(vop112, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[112])
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
+ _mm256_add_ps(vop120, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
+ _mm256_add_ps(vop16, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
+ _mm256_add_ps(vop24, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
+ _mm256_add_ps(vop32, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[32])
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
+ _mm256_add_ps(vop40, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
+ _mm256_add_ps(vop48, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
+ _mm256_add_ps(vop56, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
+ _mm256_add_ps(vop16, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
+ _mm256_add_ps(vop24, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int32_t dataInd = 0;
+ for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int32_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int32_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(&ip[j])))),
+ _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ op[j] += wgt * ((float)ip[j]) + bio;
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
+
+void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const uint8_t* input,
+ const int64_t* indices,
+ const int* lengths,
+ const float* weights,
+ bool normalize_by_lengths,
+ float* out) {
+ const int64_t prefdist_T0 = 16;
+ const int64_t fused_block_size = block_size + 8;
+ if (block_size == 128) {
+ // unrolling 16 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ __m256 vop64 = _mm256_setzero_ps();
+ __m256 vop72 = _mm256_setzero_ps();
+ __m256 vop80 = _mm256_setzero_ps();
+ __m256 vop88 = _mm256_setzero_ps();
+ __m256 vop96 = _mm256_setzero_ps();
+ __m256 vop104 = _mm256_setzero_ps();
+ __m256 vop112 = _mm256_setzero_ps();
+ __m256 vop120 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
+ _mm256_add_ps(vop16, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
+ _mm256_add_ps(vop24, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
+ _mm256_add_ps(vop32, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[32])
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
+ _mm256_add_ps(vop40, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
+ _mm256_add_ps(vop48, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
+ _mm256_add_ps(vop56, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ vop64 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
+ _mm256_add_ps(vop64, vbio));
+ _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
+ vop72 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
+ _mm256_add_ps(vop72, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[72])
+ vop80 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
+ _mm256_add_ps(vop80, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[80])
+ vop88 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
+ _mm256_add_ps(vop88, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[88])
+ vop96 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
+ _mm256_add_ps(vop96, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[96])
+ vop104 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
+ _mm256_add_ps(vop104, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[104])
+ vop112 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
+ _mm256_add_ps(vop112, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[112])
+ vop120 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
+ _mm256_add_ps(vop120, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[120])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ _mm256_storeu_ps(&op[64], vop64);
+ _mm256_storeu_ps(&op[72], vop72);
+ _mm256_storeu_ps(&op[80], vop80);
+ _mm256_storeu_ps(&op[88], vop88);
+ _mm256_storeu_ps(&op[96], vop96);
+ _mm256_storeu_ps(&op[104], vop104);
+ _mm256_storeu_ps(&op[112], vop112);
+ _mm256_storeu_ps(&op[120], vop120);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
+ _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
+ _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
+ _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
+ _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
+ _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
+ _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
+ _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
+ }
+ }
+ } else if (block_size == 64) {
+ // unrolling 8 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ __m256 vop32 = _mm256_setzero_ps();
+ __m256 vop40 = _mm256_setzero_ps();
+ __m256 vop48 = _mm256_setzero_ps();
+ __m256 vop56 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
+ _mm256_add_ps(vop16, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
+ _mm256_add_ps(vop24, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ vop32 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
+ _mm256_add_ps(vop32, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[32])
+ vop40 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
+ _mm256_add_ps(vop40, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[40])
+ vop48 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
+ _mm256_add_ps(vop48, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[48])
+ vop56 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
+ _mm256_add_ps(vop56, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[56])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ _mm256_storeu_ps(&op[32], vop32);
+ _mm256_storeu_ps(&op[40], vop40);
+ _mm256_storeu_ps(&op[48], vop48);
+ _mm256_storeu_ps(&op[56], vop56);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
+ _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
+ _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
+ _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
+ }
+ }
+ } else if (block_size == 32) {
+ // unrolling 4 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ __m256 vop16 = _mm256_setzero_ps();
+ __m256 vop24 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ vop16 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
+ _mm256_add_ps(vop16, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[16])
+ vop24 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
+ _mm256_add_ps(vop24, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[24])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ _mm256_storeu_ps(&op[16], vop16);
+ _mm256_storeu_ps(&op[24], vop24);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
+ _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
+ }
+ }
+ } else if (block_size == 16) {
+ // unrolling 2 times
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ __m256 vop0 = _mm256_setzero_ps();
+ __m256 vop8 = _mm256_setzero_ps();
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ vop0 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
+ _mm256_add_ps(vop0, vbio));
+ _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
+ vop8 = _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+ _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
+ _mm256_add_ps(vop8, vbio));
+ // skip unnecessary prefetch of (&ip_next_T0[8])
+ }
+ if (normalize_by_lengths == false) {
+ _mm256_storeu_ps(&op[0], vop0);
+ _mm256_storeu_ps(&op[8], vop8);
+ } else if (lengths[rangeIndex]) {
+ __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
+ _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
+ _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
+ }
+ }
+ } else {
+ // generic code
+ int64_t dataInd = 0;
+ for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
+ float* op = &out[rangeIndex * block_size];
+ TIndex j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(op + j, _mm256_setzero_ps());
+ }
+ for (; j < block_size; j++) {
+ op[j] = 0.0f;
+ }
+ for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
+ ++dataInd) {
+ const int64_t idx = indices[dataInd];
+ CAFFE_ENFORCE(
+ idx >= 0 && idx < data_size,
+ "Index ",
+ dataInd,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ float wgt = 1.f;
+ float bio;
+ if (weights) {
+ wgt = weights[dataInd];
+ }
+ const float* scale_bias = reinterpret_cast<const float*>(
+ &input[idx * fused_block_size + block_size]);
+ bio = wgt * scale_bias[1];
+ wgt = wgt * scale_bias[0];
+ __m256 vbio = _mm256_set1_ps(bio);
+ __m256 vwgt = _mm256_set1_ps(wgt);
+ const uint8_t* ip = &input[idx * fused_block_size];
+ const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
+ ? (dataInd + prefdist_T0)
+ : dataInd;
+ const int64_t idx_pref_T0 = indices[next_T0];
+ CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
+ const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j],
+ _mm256_fmadd_ps(
+ vwgt,
+ _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
+ reinterpret_cast<const __m128i*>(&ip[j])))),
+ _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
+ _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
+ }
+ for (; j < block_size; j++) {
+ op[j] += wgt * ((float)ip[j]) + bio;
+ }
+ }
+ if (normalize_by_lengths && lengths[rangeIndex]) {
+ float len_inv = 1.0f / lengths[rangeIndex];
+ __m256 vlen_inv = _mm256_set1_ps(len_inv);
+ j = 0;
+ for (; j + 8 <= block_size; j += 8) {
+ _mm256_storeu_ps(
+ &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
+ }
+ for (; j < block_size; j++) {
+ op[j] = len_inv * op[j];
+ }
+ }
+ }
+ }
+}
+
+} // namespace caffe2
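For reference, every fused 8-bit kernel in the file above folds the per-index weight into both the scale and the bias before its vector loop (wgt = wgt * scale_bias[0], bio = wgt * scale_bias[1]), so a single fused multiply-add per lane produces weight * (scale * q + bias); prefetches are issued only at cache-line boundaries, which is what the "skip unnecessary prefetch" comments mark. A minimal scalar sketch of one row's contribution, assuming the row layout used here (block_size uint8_t values followed by a float scale and a float bias); the helper name is illustrative only:

#include <cstdint>

void accumulate_fused_row_scalar(
    const std::uint8_t* row,  // start of one fused row
    std::int64_t block_size,  // number of quantized values in the row
    float weight,             // per-index weight (1.f when weights == nullptr)
    float* out) {             // accumulator of length block_size
  const float* scale_bias = reinterpret_cast<const float*>(row + block_size);
  // The kernels fold the weight into both terms up front:
  //   wgt = weight * scale,  bio = weight * bias,
  // so each lane computes wgt * q + bio == weight * (scale * q + bias).
  const float wgt = weight * scale_bias[0];
  const float bio = weight * scale_bias[1];
  for (std::int64_t k = 0; k < block_size; ++k) {
    out[k] += wgt * static_cast<float>(row[k]) + bio;
  }
}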
diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc
new file mode 100644
index 0000000000..7aa756bdeb
--- /dev/null
+++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.cc
@@ -0,0 +1,170 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"
+
+#include "caffe2/core/types.h"
+#include "caffe2/perfkernels/common.h"
+#include "caffe2/perfkernels/typed_axpy.h"
+#include "caffe2/utils/cpuid.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+// Base implementation: a portable scalar fallback used when no vectorized
+// (e.g. AVX2+FMA) kernel is available for this platform.
+template <typename IndexType, typename InType, typename OutType>
+static void Fused8BitRowwiseEmbeddingLookupGenericSlow(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const InType* input,
+ const IndexType* indices,
+ const int* lengths,
+ const float* weights, // optional, can be null for sum reducer
+ bool normalize_by_lengths,
+ OutType* out) {
+  // block_size is the number of input elements per row (excluding scale and
+  // bias); fused_block_size is the stride of an entire row in units of
+  // InType, including the trailing scale and bias.
+ const auto scale_bias_offset = 8 / sizeof(InType);
+ const TIndex fused_block_size = block_size + scale_bias_offset;
+ TIndex current = 0;
+ for (int m = 0; m < output_size; ++m) {
+ memset(out, 0, sizeof(OutType) * block_size);
+ EigenVectorArrayMap<OutType> out_vector(out, block_size);
+ for (int i = 0; i < lengths[m]; ++i) {
+ CAFFE_ENFORCE_LT(current, index_size);
+ TIndex idx = indices[current];
+ CAFFE_ENFORCE(
+ 0 <= idx && idx < data_size,
+ "Index ",
+ current,
+ " is out of bounds: ",
+ idx,
+ ", range 0 to ",
+ data_size);
+ CAFFE_ENFORCE_LT(idx, data_size);
+#ifdef __GNUC__
+ if (current + 1 < index_size) {
+ __builtin_prefetch(
+ input + fused_block_size * indices[current + 1], 0, 1);
+ }
+#endif // __GNUC__
+
+ const float* scale_bias = reinterpret_cast<const float*>(
+ input + fused_block_size * indices[current] + block_size);
+
+ const float weight = weights ? weights[current] : 1.0f;
+ const float scale = weight * scale_bias[0];
+ const float bias = weight * scale_bias[1];
+
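+    // TypedAxpy accumulates scale * row into out (axpy: out += a * x); adding
+    // the broadcast bias below then completes weight * (scale * q + bias).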
+ TypedAxpy<InType, OutType>(
+ block_size, scale, input + fused_block_size * indices[current], out);
+
+ out_vector += bias;
+
+ ++current;
+ }
+ if (normalize_by_lengths && lengths[m]) {
+ // hack: context is not really used
+ math::Scale<OutType, CPUContext>(
+ block_size, 1.f / lengths[m], out, out, nullptr);
+ }
+ out += block_size;
+ }
+ CAFFE_ENFORCE_EQ(
+ current,
+ index_size,
+ "Your input seems to be incorrect: the sum of lengths values should be "
+ "the size of the indices tensor, but it appears not.");
+}
+
+// Proxy back to generic implementation
+#define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION( \
+ IndexType, InType, OutType) \
+ void \
+ Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType##__base( \
+ const TIndex block_size, \
+ const TIndex output_size, \
+ const TIndex index_size, \
+ const TIndex data_size, \
+ const InType* input, \
+ const IndexType* indices, \
+ const int* lengths, \
+ const float* weights, \
+ bool normalize_by_lengths, \
+ OutType* out) { \
+ Fused8BitRowwiseEmbeddingLookupGenericSlow<IndexType, InType, OutType>( \
+ block_size, \
+ output_size, \
+ index_size, \
+ data_size, \
+ input, \
+ indices, \
+ lengths, \
+ weights, \
+ normalize_by_lengths, \
+ out); \
+ } \
+ template <> \
+ void Fused8BitRowwiseEmbeddingLookup( \
+ const TIndex block_size, \
+ const TIndex output_size, \
+ const TIndex index_size, \
+ const TIndex data_size, \
+ const InType* input, \
+ const IndexType* indices, \
+ const int* lengths, \
+ const float* weights, \
+ bool normalize_by_lengths, \
+ OutType* out) { \
+ const int32_t one = 1; \
+ CAFFE_ENFORCE_EQ( \
+ reinterpret_cast<const uint8_t*>(&one)[0], \
+ 1, \
+ "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \
+ AVX2_FMA_DO( \
+ Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType, \
+ block_size, \
+ output_size, \
+ index_size, \
+ data_size, \
+ input, \
+ indices, \
+ lengths, \
+ weights, \
+ normalize_by_lengths, \
+ out); \
+ BASE_DO( \
+ Fused8BitRowwiseEmbeddingLookup_##IndexType##_##InType##_##OutType, \
+ block_size, \
+ output_size, \
+ index_size, \
+ data_size, \
+ input, \
+ indices, \
+ lengths, \
+ weights, \
+ normalize_by_lengths, \
+ out); \
+ }
+
+FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t, uint8_t, float);
+FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t, uint8_t, float);
+
+#undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION
+
+} // namespace caffe2
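A hypothetical call into the templated entry point declared in the header below. As the macro names suggest, AVX2_FMA_DO runs the __avx2_fma kernel when the CPU supports AVX2 and FMA and BASE_DO falls back to the generic __base version, while the CAFFE_ENFORCE_EQ on the byte pattern of `one` rejects big-endian platforms, since the kernels reinterpret raw row bytes as floats. The data values below are made up and only illustrate the expected shapes (sum(lengths) == index_size; each row holds block_size quantized bytes plus 8 bytes of scale and bias):

#include <cstdint>
#include <vector>

#include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h"

void example_lookup() {
  const caffe2::TIndex block_size = 16;             // quantized values per row
  const caffe2::TIndex data_size = 4;               // rows in the table
  const caffe2::TIndex fused_row = block_size + 8;  // plus scale and bias bytes
  std::vector<std::uint8_t> input(data_size * fused_row, 0);  // scales/biases left at 0 here
  std::vector<std::int64_t> indices = {0, 2, 3};              // index_size == 3
  std::vector<int> lengths = {2, 1};                          // output_size == 2; sums to 3
  std::vector<float> out(lengths.size() * block_size, 0.f);

  caffe2::Fused8BitRowwiseEmbeddingLookup<std::int64_t, std::uint8_t, float>(
      block_size,
      lengths.size(),  // output_size
      indices.size(),  // index_size
      data_size,
      input.data(),
      indices.data(),
      lengths.data(),
      nullptr,         // weights: plain (unweighted) sum
      false,           // normalize_by_lengths
      out.data());
}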
diff --git a/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h
new file mode 100644
index 0000000000..5251331e3e
--- /dev/null
+++ b/caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "caffe2/core/common.h"
+
+namespace caffe2 {
+
+/**
+ * Embedding lookup with reduction.
+ *
+ * `input` of size data_size * (block_size + 8B)
+ * `indices` of size index_size
+ * `lengths` of size output_size
+ * `weights` nullptr or array of size index_size
+ * `out` of size output_size * block_size
+ * sum(lengths[i]) == index_size
+ *
+ * Note that block_size should be the number of quantized values per row in the
+ * data, i.e. excluding the scale and bias. Each (fused) row is assumed to
+ * hold these block_size values followed by 8 extra bytes: a 4-byte float
+ * scale and a 4-byte float bias.
+ *
+ * Behavior is roughly equivalent to pseudocode:
+ *
+ * pos = 0
+ * fused_block_size = block_size + 8B // quantized values and scale and bias
+ * for (i = 0..output_size-1)
+ * for (k = 0..block_size-1)
+ * out[i*block_size + k] = 0
+ * for (j = 0..lengths[i]-1)
+ *     scale, bias = the two floats stored at
+ *         input + indices[pos]*fused_block_size + block_size
+ *     for (k = 0..block_size-1)
+ *       out[i*block_size + k] +=
+ *           (scale * input[indices[pos]*(fused_block_size) + k] + bias) *
+ *           (weights ? weights[pos] : 1.0)
+ * pos += 1
+ *   if (normalize_by_lengths && lengths[i] > 0)
+ * for (k = 0..block_size-1)
+ * out[i*block_size + k] /= lengths[i]
+ *
+ */
+
+template <typename IndexType, typename InType, typename OutType>
+void Fused8BitRowwiseEmbeddingLookup(
+ const TIndex block_size,
+ const TIndex output_size,
+ const TIndex index_size,
+ const TIndex data_size,
+ const InType* input,
+ const IndexType* indices,
+ const int* lengths,
+ const float* weights, // optional, can be null for non-weighted sum
+ bool normalize_by_lengths,
+ OutType* out);
+} // namespace caffe2
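A minimal sketch of the row layout the comment above describes, for uint8_t input: block_size quantized bytes, then the float scale at element offset block_size, then the float bias four bytes later, which is exactly where the kernels read scale_bias[0] and scale_bias[1]. The packing helper below is illustrative only; how the scale and bias are chosen is left to the quantization operator that produces this format.

#include <cstdint>
#include <cstring>

void pack_fused_row(
    const std::uint8_t* quantized,  // block_size quantized values
    std::int64_t block_size,
    float scale,
    float bias,
    std::uint8_t* fused_row) {      // destination buffer of size block_size + 8
  std::memcpy(fused_row, quantized, block_size);
  std::memcpy(fused_row + block_size, &scale, sizeof(float));
  std::memcpy(fused_row + block_size + sizeof(float), &bias, sizeof(float));
}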
diff --git a/caffe2/perfkernels/hp_emblookup_codegen.py b/caffe2/perfkernels/hp_emblookup_codegen.py
index 87c656a8be..28afc95841 100644
--- a/caffe2/perfkernels/hp_emblookup_codegen.py
+++ b/caffe2/perfkernels/hp_emblookup_codegen.py
@@ -20,47 +20,40 @@ from __future__ import unicode_literals
import argparse
import sys
+sizeof = {'float': 4, 'float16': 2, 'uint8_t': 1}
-def unroll(uf, IndexType, InType, OutType, use_weights, isa):
-
- def sizeof(InType):
- size = 0
- if InType == "float":
- size = 4
- elif InType == "float16":
- size = 2
- elif InType == "uint8_t":
- size = 1
- else:
- assert False
-
- return size
+def unroll(uf, IndexType, InType, OutType, use_weights, isa, fused):
def compute(regid, InType, use_weights, isa, prefetch):
code = []
if InType == "float":
- code.append("vop%d = _mm256_fmadd_ps(vwgt, \
- _mm256_loadu_ps(ip + (%d)), vop%d);" % (regid, regid, regid))
-
+ code.append(
+ "vop%d = _mm256_fmadd_ps(vwgt, \
+ _mm256_loadu_ps(ip + (%d)), vop%d);"
+ % (regid, regid, regid)
+ )
elif InType == "float16":
- code.append("vop%d = _mm256_fmadd_ps(vwgt, \
+ code.append(
+ "vop%d = _mm256_fmadd_ps(vwgt, \
_mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (%d)))), \
vop%d);"
- % (regid, regid, regid))
+ % (regid, regid, regid)
+ )
elif InType == "uint8_t":
- code.append("vop%d = _mm256_fmadd_ps(vwgt, \
+ code.append(
+ "vop%d = _mm256_fmadd_ps(vwgt, \
_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (%d))))), \
_mm256_add_ps(vop%d, vbio));"
- % (regid, regid, regid))
+ % (regid, regid, regid)
+ )
else:
assert False
-
- if prefetch == True:
+ if prefetch:
code.append("_mm_prefetch((&ip_next_T0[%d]), _MM_HINT_T0);" % (regid))
else:
- code.append("// skip unecassery prefetch of (&ip_next_T0[%d])" % (regid))
+ code.append("// skip unnecessary prefetch of (&ip_next_T0[%d])" % (regid))
return code
@@ -87,8 +80,16 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa):
code.append("if (weights) {")
code.append("wgt = weights[dataInd];")
code.append("}")
- code.append("bio = wgt * scale_bias[2 * idx + 1];");
- code.append("wgt = wgt * scale_bias[2 * idx];");
+ if fused:
+ code.append(
+ 'const float* scale_bias = reinterpret_cast<'
+ 'const float*>(&input[idx * fused_block_size + block_size]);'
+ )
+ code.append("bio = wgt * scale_bias[1];")
+ code.append("wgt = wgt * scale_bias[0];")
+ else:
+ code.append("bio = wgt * scale_bias[2 * idx + 1];")
+ code.append("wgt = wgt * scale_bias[2 * idx];")
code.append("__m256 vbio = _mm256_set1_ps(bio);")
else:
code.append(OutType + " wgt = 1.f;")
@@ -97,20 +98,25 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa):
code.append("}")
code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
- code.append("const " + InType + " *ip = &input[idx * block_size];")
- code.append("const " + IndexType +
- " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;");
+ code.append("const {} *ip = &input[idx * fused_block_size];".format(InType))
+ code.append(
+ 'const {} next_T0 = (dataInd < index_size - prefdist_T0)'
+ ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType)
+ )
code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
code.append(
"CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
- code.append("const " + InType +
- " *ip_next_T0 = &input[idx_pref_T0 * block_size];")
+
+ code.append(
+ 'const {} *ip_next_T0 = &input[idx_pref_T0'
+ ' * fused_block_size];'.format(InType)
+ )
for i in range(0, uf):
j = 8 * i
cachelinesize = 64
- byteoffset = sizeof(InType) * j
- prefetch = ((byteoffset % cachelinesize) == 0)
+ byteoffset = sizeof[InType] * j
+ prefetch = (byteoffset % cachelinesize) == 0
code.extend(compute(j, InType, use_weights, isa, prefetch))
code.append("}")
@@ -133,25 +139,31 @@ def unroll(uf, IndexType, InType, OutType, use_weights, isa):
return code
-def generic(IndexType, InType, OutType, use_weights, isa):
+def generic(IndexType, InType, OutType, use_weights, isa, fused):
def compute(InType, use_weights, isa):
code = []
if InType == "float":
- code.append("_mm256_storeu_ps(&op[j], \
+ code.append(
+ "_mm256_storeu_ps(&op[j], \
_mm256_fmadd_ps(vwgt,_mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])) \
- );")
+ );"
+ )
elif InType == "float16":
- code.append("_mm256_storeu_ps(&op[j], \
+ code.append(
+ "_mm256_storeu_ps(&op[j], \
_mm256_fmadd_ps(vwgt, \
_mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(&ip[j]))), _mm256_loadu_ps(&op[j])) \
- );")
+ );"
+ )
elif InType == "uint8_t":
- code.append("_mm256_storeu_ps(&op[j], \
+ code.append(
+ "_mm256_storeu_ps(&op[j], \
_mm256_fmadd_ps(vwgt, \
_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<const __m128i*>(&ip[j])))), \
_mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio) ) \
- );")
+ );"
+ )
else:
assert False
@@ -188,9 +200,17 @@ def generic(IndexType, InType, OutType, use_weights, isa):
code.append("if (weights) {")
code.append("wgt = weights[dataInd];")
code.append("}")
- code.append("assert (scale_bias);")
- code.append("bio = wgt * scale_bias[2 * idx + 1];");
- code.append("wgt = wgt * scale_bias[2 * idx];");
+ if fused:
+ code.append(
+ 'const float* scale_bias = reinterpret_cast<'
+ 'const float*>(&input[idx * fused_block_size + block_size]);'
+ )
+ code.append("bio = wgt * scale_bias[1];")
+ code.append("wgt = wgt * scale_bias[0];")
+ else:
+ code.append("assert (scale_bias);")
+ code.append("bio = wgt * scale_bias[2 * idx + 1];")
+ code.append("wgt = wgt * scale_bias[2 * idx];")
code.append("__m256 vbio = _mm256_set1_ps(bio);")
else:
code.append(OutType + " wgt = 1.f;")
@@ -199,14 +219,18 @@ def generic(IndexType, InType, OutType, use_weights, isa):
code.append("}")
code.append("__m256 vwgt = _mm256_set1_ps(wgt);")
- code.append("const " + InType + " *ip = &input[idx * block_size];")
- code.append("const " + IndexType +
- " next_T0 = (dataInd < index_size - prefdist_T0) ? (dataInd + prefdist_T0) : dataInd;");
+ code.append("const {} *ip = &input[idx * fused_block_size];".format(InType))
+ code.append(
+ 'const {} next_T0 = (dataInd < index_size - prefdist_T0)'
+ ' ? (dataInd + prefdist_T0) : dataInd;'.format(IndexType)
+ )
code.append("const " + IndexType + " idx_pref_T0 = indices[next_T0];")
code.append(
"CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);")
- code.append("const " + InType +
- " *ip_next_T0 = &input[idx_pref_T0 * block_size];")
+ code.append(
+ "const {} *ip_next_T0 = &input[idx_pref_T0 * fused_block_size];".
+ format(InType)
+ )
# compute and store main loop
code.append("j = 0;")
@@ -215,7 +239,6 @@ def generic(IndexType, InType, OutType, use_weights, isa):
code.append("}")
# leftover
if InType == "float16":
- #code.append("float16 vtmp1[8] __attribute__((aligned(64)));")
code.append("float16 vtmp1[8] CAFFE2_ALIGNED(64);")
code.append("for(; j < block_size; j++) {")
if InType == "float":
@@ -251,14 +274,17 @@ def generic(IndexType, InType, OutType, use_weights, isa):
return code
-
# start main code
parser = argparse.ArgumentParser()
-parser.add_argument('-f', nargs=1, help="file name")
+parser.add_argument('-f', '--filename', help="file name")
+parser.add_argument('--fused', action='store_true')
opts = parser.parse_args()
-filename = "embedding_lookup_avx2.cc"
-if opts.f:
- filename = (opts.f)[0]
+if opts.filename:
+ filename = opts.filename
+elif opts.fused:
+ filename = "embedding_lookup_fused_8bit_rowwise_avx2.cc"
+else:
+ filename = "embedding_lookup_avx2.cc"
fout = open(filename, 'w')
options = [["int32_t", "float", "float"],
@@ -289,14 +315,14 @@ code.append(
*/
""")
code.append("//// --------------------------")
-code.append("//// ATTENTION: ")
+code.append("//// ATTENTION:")
code.append("//// THIS CODE IS AUTOGENERATED")
-code.append("//// BY %s " % (sys.argv[0]))
-code.append("//// DO NOT MODIFY!!! ")
+code.append("//// BY {}".format(sys.argv[0]))
+code.append("//// DO NOT MODIFY!!!")
code.append("//// --------------------------\n\n")
-code.append("#include \"caffe2/core/types.h\"")
-code.append("#include \"caffe2/core/common.h\"")
+code.append("#include <caffe2/core/types.h>")
+code.append("#include <caffe2/core/common.h>")
code.append("#include <immintrin.h>")
code.append("\n")
@@ -304,9 +330,11 @@ code.append("namespace caffe2 {\n")
for o in options:
[IndexType, InType, OutType] = o
- fn = "void EmbeddingLookup_" + IndexType + \
- "_" + InType + "_" + OutType + "__avx2_fma"
- code.append(fn + "(")
+ prefix = 'Fused8BitRowwise' if opts.fused else ''
+ fn = 'void {}EmbeddingLookup_{}_{}_{}__avx2_fma('.format(
+ prefix, IndexType, InType, OutType
+ )
+ code.append(fn)
code.append("const TIndex block_size,")
code.append("const TIndex output_size,")
code.append("const TIndex index_size,")
@@ -315,29 +343,45 @@ for o in options:
code.append("const " + IndexType + "* indices,")
code.append("const int* lengths,")
code.append("const float* weights,")
- code.append("const float* scale_bias,")
+ if not opts.fused:
+ code.append("const float* scale_bias,")
code.append("bool normalize_by_lengths,")
code.append(OutType + "* out)")
code.append("{")
code.append("const " + IndexType + " prefdist_T0 = 16;")
+ # block_size is the number of elements and fused_block_size is the size of
+ # an entire row, including scale and bias.
+ offset = (8 // sizeof[InType]) if opts.fused else 0
+ code.append(
+ "const {} fused_block_size = block_size + {};".
+ format(IndexType, offset)
+ )
+
#code.append("printf(\"calling " + fn + "\\n\");");
- if InType != "uint8_t":
- code.append("CAFFE_ENFORCE(scale_bias == nullptr, \"scale_bias must be nullptr\");");
- else:
- code.append("CAFFE_ENFORCE(scale_bias != nullptr, \"scale_bias must not be nullptr\");");
+ if not opts.fused:
+ if InType != "uint8_t":
+ code.append(
+ 'CAFFE_ENFORCE(scale_bias == nullptr,'
+ ' "scale_bias must be nullptr");'
+ )
+ else:
+ code.append(
+ 'CAFFE_ENFORCE(scale_bias != nullptr,'
+ ' "scale_bias must not be nullptr");'
+ )
code.append("if (block_size == 128) {")
- code.extend(unroll(16, IndexType, InType, OutType, True, "AVX2"))
+ code += unroll(16, IndexType, InType, OutType, True, "AVX2", opts.fused)
code.append("} else if (block_size == 64) {")
- code.extend(unroll(8, IndexType, InType, OutType, True, "AVX2"))
+ code += unroll(8, IndexType, InType, OutType, True, "AVX2", opts.fused)
code.append("} else if (block_size == 32) {")
- code.extend(unroll(4, IndexType, InType, OutType, True, "AVX2"))
+ code += unroll(4, IndexType, InType, OutType, True, "AVX2", opts.fused)
code.append("} else if (block_size == 16) {")
- code.extend(unroll(2, IndexType, InType, OutType, True, "AVX2"))
+ code += unroll(2, IndexType, InType, OutType, True, "AVX2", opts.fused)
code.append("} else {")
code.append("// generic code")
- code.extend(generic(IndexType, InType, OutType, True, "AVX2"))
+ code += generic(IndexType, InType, OutType, True, "AVX2", opts.fused)
code.append("}")