-rw-r--r--  aten/src/ATen/native/EmbeddingBag.cpp        223
-rw-r--r--  aten/src/ATen/native/cuda/EmbeddingBag.cu      1
-rw-r--r--  aten/src/ATen/native/native_functions.yaml     2
-rw-r--r--  caffe2/perfkernels/embedding_lookup.h           3
-rw-r--r--  test/test_nn.py                                19
-rw-r--r--  tools/autograd/derivatives.yaml                 2
6 files changed, 211 insertions, 39 deletions
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 9fa2bdb953..80bbf9c118 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -5,6 +5,8 @@
#include <TH/THBlasUtils.h>
+#include <caffe2/perfkernels/embedding_lookup.h>
+
#include <cstring>
#include <iostream>
#include <memory>
@@ -24,22 +26,32 @@ namespace {
namespace at {
namespace native {
-static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
- Tensor &offset2bag) {
+static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) {
offset2bag.index_add_(
0, offsets, at::ones_like(offsets)); // offset2bag = [1 0 1 0 1]
offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1]
offset2bag = offset2bag.cumsum(0); // offset2bag = [0 0 1 1 2]
}
+namespace {
+
+bool isFastPathIndexSelect(const Tensor& src, Tensor& output) {
+ return src.scalar_type() == kFloat && src.stride(1) == 1 && output.stride(1) == 1;
+}
+
+bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor& output) {
+ return src.scalar_type() == kFloat && src.stride(1) == 1 && output.stride(1) == 1 && scale.stride(0) == 1;
+}
+
// This function combines index_select (using select_indices as the index) and
// index_add (using add_indices as the index), without creating an intermediary
// tensor to hold the selected embeddings
template<typename T>
-static void index_select_add(const Tensor &select_indices,
+void index_select_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &src,
- Tensor &output) {
+ Tensor &output,
+ const Tensor& /*offsets*/) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto add_indices_data = add_indices.data<int64_t>();
auto select_indices_data = select_indices.data<int64_t>();
@@ -59,6 +71,57 @@ static void index_select_add(const Tensor &select_indices,
}
}
+template<>
+void index_select_add<float>(const Tensor &select_indices,
+ const Tensor &add_indices,
+ const Tensor &src,
+ Tensor &output,
+ const Tensor& offsets) {
+ int64_t ddim = src.size(1);
+ auto src_data = src.data<float>();
+ auto select_indices_data = select_indices.data<int64_t>();
+ auto output_data = output.data<float>();
+
+ if (isFastPathIndexSelect(src, output)) {
+ auto accessor = offsets.accessor<int64_t, 1>();
+ std::vector<int> lengths;
+
+ int64_t lower = accessor[0];
+ for (size_t i = 1; i < offsets.numel(); ++i) {
+ lengths.push_back(accessor[i] - lower);
+ lower = accessor[i];
+ }
+ lengths.push_back(select_indices.numel() - lower);
+
+ caffe2::EmbeddingLookup(
+ /*block_size=*/ddim,
+ /*output_size=*/lengths.size(),
+ /*index_size=*/select_indices.numel(),
+ /*data_size=*/src.size(0),
+ /*input=*/src_data,
+ /*indices=*/select_indices_data,
+ /*lengths=*/lengths.data(),
+ /*weights=*/nullptr,
+ /*scale_bias=*/nullptr,
+ /*normalize_by_lengths=*/false,
+ /*out=*/output_data
+ );
+ } else {
+ AT_ASSERT(select_indices.numel() == add_indices.numel());
+ auto add_indices_data = add_indices.data<int64_t>();
+ auto src_stride0 = src.stride(0);
+ auto src_stride1 = src.stride(1);
+ auto output_stride0 = output.stride(0);
+ auto output_stride1 = output.stride(1);
+ auto numel = add_indices.numel();
+ for (int64_t i = 0; i < numel; i++) {
+ THBlas_axpy<float>(ddim, 1,
+ src_data + src_stride0 * select_indices_data[i], src_stride1,
+ output_data + output_stride0 * add_indices_data[i], output_stride1);
+ }
+ }
+}
+
// This function fuses the following three fns:
// index_select (using select_indices as the index)
// mul (scaling by per_sample_weights)
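The float specialization above reaches the caffe2 fast path by first converting PyTorch-style `offsets` (the position in `indices` where each bag starts) into the `lengths` array that `caffe2::EmbeddingLookup` expects. A minimal standalone sketch of that conversion, using made-up values and plain vectors rather than `at::Tensor` (it assumes `offsets` is non-empty, as the code above does):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // offsets[i] is where bag i starts inside `indices` (PyTorch style).
  std::vector<int64_t> offsets = {0, 2, 5};
  const int64_t num_indices = 7;  // total number of indices across all bags

  // lengths[i] is the number of indices in bag i (caffe2 style).
  std::vector<int> lengths;
  int64_t lower = offsets[0];
  for (size_t i = 1; i < offsets.size(); ++i) {
    lengths.push_back(static_cast<int>(offsets[i] - lower));
    lower = offsets[i];
  }
  // The last bag runs from the final offset to the end of `indices`.
  lengths.push_back(static_cast<int>(num_indices - lower));

  for (int len : lengths) std::cout << len << ' ';  // prints: 2 3 2
  std::cout << '\n';
  return 0;
}
```

The last bag's length has to be derived from the total number of indices, since PyTorch offsets carry no explicit end marker.
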
@@ -68,7 +131,8 @@ static void index_select_scale_add(const Tensor &select_indices,
const Tensor &add_indices,
const Tensor &scale,
const Tensor &src,
- Tensor &output) {
+ Tensor &output,
+ const Tensor& /*offsets*/) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto add_indices_data = add_indices.data<int64_t>();
auto select_indices_data = select_indices.data<int64_t>();
@@ -84,7 +148,6 @@ static void index_select_scale_add(const Tensor &select_indices,
auto* scale_data = scale.data<T>();
auto scale_stride = scale.stride(0);
- // XXX: We could make this faster via vectorization
for (int64_t i = 0; i < numel; i++) {
auto* src_base = src_data + src_stride0 * select_indices_data[i];
auto* output_base = output_data + output_stride0 * add_indices_data[i];
@@ -95,6 +158,67 @@ static void index_select_scale_add(const Tensor &select_indices,
}
}
+template<>
+void index_select_scale_add<float>(const Tensor &select_indices,
+ const Tensor &add_indices,
+ const Tensor &scale,
+ const Tensor &src,
+ Tensor &output,
+ const Tensor& offsets) {
+ int64_t ddim = src.size(1);
+ auto* scale_data = scale.data<float>();
+ auto select_indices_data = select_indices.data<int64_t>();
+ auto src_data = src.data<float>();
+ auto output_data = output.data<float>();
+
+ if (isFastPathIndexSelectScale(src, scale, output)) {
+ auto accessor = offsets.accessor<int64_t, 1>();
+ std::vector<int> lengths;
+
+ int64_t lower = accessor[0];
+ for (size_t i = 1; i < offsets.numel(); ++i) {
+ lengths.push_back(accessor[i] - lower);
+ lower = accessor[i];
+ }
+ lengths.push_back(select_indices.numel() - lower);
+
+ caffe2::EmbeddingLookup(
+ /*block_size=*/ddim,
+ /*output_size=*/lengths.size(),
+ /*index_size=*/select_indices.numel(),
+ /*data_size=*/src.size(0),
+ /*input=*/src_data,
+ /*indices=*/select_indices_data,
+ /*lengths=*/lengths.data(),
+ /*weights=*/scale_data,
+ /*scale_bias=*/nullptr,
+ /*normalize_by_lengths=*/false,
+ /*out=*/output_data
+ );
+ } else {
+ AT_ASSERT(select_indices.numel() == add_indices.numel());
+ auto add_indices_data = add_indices.data<int64_t>();
+ auto src_stride0 = src.stride(0);
+ auto src_stride1 = src.stride(1);
+ auto output_stride0 = output.stride(0);
+ auto output_stride1 = output.stride(1);
+ auto scale_stride = scale.stride(0);
+ auto numel = add_indices.numel();
+
+
+ for (int64_t i = 0; i < numel; i++) {
+ auto* src_base = src_data + src_stride0 * select_indices_data[i];
+ auto* output_base = output_data + output_stride0 * add_indices_data[i];
+ auto scale = scale_data[i * scale_stride];
+ for (int64_t j = 0; j < ddim; j++) {
+ output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
+ }
+ }
+ }
+}
+
+} // namespace
+
static void make_bag_size(const Tensor &offsets, const Tensor &indices,
const int64_t mode, Tensor &bag_size) {
if (mode == MODE_MEAN || mode == MODE_MAX) {
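Both the fast path and the strided fallback above compute the same quantity: for each index `i`, the row `weight[indices[i]]` is scaled by `per_sample_weights[i]` and accumulated into the bag given by `offset2bag[i]` (or by the `lengths` segmentation on the fast path). A small self-contained sketch of that reduction with invented data, kept deliberately naive:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t dim = 2;                       // embedding dimension
  std::vector<float> weight = {                // 4 x 2 embedding table
      1.f, 2.f,
      3.f, 4.f,
      5.f, 6.f,
      7.f, 8.f};
  std::vector<int64_t> indices = {0, 2, 1};    // which rows to gather
  std::vector<int64_t> offset2bag = {0, 0, 1}; // which bag each index lands in
  std::vector<float> scale = {1.f, 0.5f, 2.f}; // per_sample_weights
  std::vector<float> out(2 * dim, 0.f);        // 2 bags

  for (size_t i = 0; i < indices.size(); ++i) {
    const float* row = weight.data() + indices[i] * dim;
    float* bag = out.data() + offset2bag[i] * dim;
    for (int64_t j = 0; j < dim; ++j) {
      bag[j] += scale[i] * row[j];
    }
  }

  // bag 0 = 1*row0 + 0.5*row2 = [3.5, 5],  bag 1 = 2*row1 = [6, 8]
  for (float v : out) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```

In sum mode with `per_sample_weights` undefined, the same loop runs with an implicit scale of 1, which is what the unscaled `index_select_add` specialization handles.
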
@@ -230,27 +354,43 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
auto bag_size = at::zeros(offsets.sizes(), indices.options());
make_bag_size(offsets, indices, mode, bag_size);
- // If the last entries are empty, that the last offsets are irrelevant as they
- // won't change anything in the assignment of ID -> bag, but index_add would
- // throw out of bounds error. So to keep it simple we just add one more
- // entry to the end then get rid of it after make_offset2bag.
- auto offset2bag = at::zeros(
- {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
-
- make_offset2bag(offsets, indices, offset2bag);
+ auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.options());
- offset2bag.resize_({indices.sizes()[0]});
+ // To save compute, if we are going to go down the fast path case for the 'sum'
+ // mode, we skip calculating offset2bag, since it is not going to be used.
+ auto fast_path_sum = [&weight, &per_sample_weights, &output]() {
+ if (per_sample_weights.defined()) {
+ return isFastPathIndexSelectScale(weight, per_sample_weights, output);
+ } else {
+ return isFastPathIndexSelect(weight, output);
+ }
+ };
- auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.options());
+ // Use an empty 0-element tensor as a sentinel that we have skipped the
+ // creation of offset2bag because autograd chokes when trying to use an
+ // undefined tensor as an input to a backward op.
+ Tensor offset2bag = at::empty({0}, offsets.options());
+ if (mode == MODE_MEAN || mode == MODE_MAX || !fast_path_sum()) {
+ // If the last entries are empty, then the last offsets are irrelevant as they
+ // won't change anything in the assignment of ID -> bag, but index_add would
+ // throw an out-of-bounds error. So to keep it simple we just add one more
+ // entry to the end, then get rid of it after make_offset2bag.
+ offset2bag = at::zeros(
+ {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+ make_offset2bag(offsets, indices, offset2bag);
+
+ offset2bag.resize_({indices.sizes()[0]});
+ }
if (mode == MODE_MEAN || mode == MODE_SUM) {
AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
if (per_sample_weights.defined()) {
AT_ASSERT(mode == MODE_SUM);
index_select_scale_add<scalar_t>(
- indices, offset2bag, per_sample_weights, weight, output);
+ indices, offset2bag, per_sample_weights, weight, output, offsets);
} else {
- index_select_add<scalar_t>(indices, offset2bag, weight, output);
+ index_select_add<scalar_t>(indices, offset2bag, weight, output, offsets);
}
});
auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
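The forward pass now builds `offset2bag` only when something will actually read it (mean/max mode, or a sum that cannot take the fast path); otherwise the 0-element tensor created above is returned as a sentinel, and the backward functions rebuild the mapping from `offsets` on demand. For reference, a standalone sketch of what `make_offset2bag` computes, reproducing the `[0 0 1 1 2]` example from its comments with plain vectors:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example from make_offset2bag's comments: 5 indices, bags start at 0, 2, 4.
  std::vector<int64_t> offsets = {0, 2, 4};
  const int64_t num_indices = 5;

  // offset2bag = [0 0 0 0 0]
  std::vector<int64_t> offset2bag(num_indices, 0);

  // index_add_(0, offsets, ones): offset2bag = [1 0 1 0 1]
  for (int64_t off : offsets) {
    offset2bag[off] += 1;
  }

  // offset2bag[0] -= 1: offset2bag = [0 0 1 0 1]
  offset2bag[0] -= 1;

  // cumsum(0): offset2bag = [0 0 1 1 2]
  for (int64_t i = 1; i < num_indices; ++i) {
    offset2bag[i] += offset2bag[i - 1];
  }

  for (int64_t b : offset2bag) std::cout << b << ' ';  // prints: 0 0 1 1 2
  std::cout << '\n';
  return 0;
}
```

`offset2bag[i]` is the bag that index `i` belongs to, which is exactly the piece of information the fast sum path can do without.
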
@@ -286,17 +426,29 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
auto offsets_arg = TensorArg(offsets, "offsets", 1);
checkScalarType("embedding_bag", offsets_arg, kLong);
checkContiguous("embedding_bag", offsets_arg);
- auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
- checkScalarType("embedding_bag", offset2bag_arg, kLong);
- checkContiguous("embedding_bag", offset2bag_arg);
+
+ Tensor offset2bag_;
+ if (indices.numel() != 0 && offset2bag.numel() == 0) {
+ offset2bag_ = at::zeros(
+ {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+ make_offset2bag(offsets, indices, offset2bag_);
+
+ offset2bag_.resize_({indices.sizes()[0]});
+ } else {
+ auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
+ checkScalarType("embedding_bag", offset2bag_arg, kLong);
+ checkContiguous("embedding_bag", offset2bag_arg);
+ offset2bag_ = offset2bag;
+ }
if (sparse) {
return at::_embedding_bag_sparse_backward(
- grad, indices, offsets, offset2bag, bag_size_, num_weights,
+ grad, indices, offsets, offset2bag_, bag_size_, num_weights,
scale_grad_by_freq, mode, per_sample_weights);
} else {
return at::_embedding_bag_dense_backward(
- grad, indices, offsets, offset2bag, bag_size_, max_indices_, num_weights,
+ grad, indices, offsets, offset2bag_, bag_size_, max_indices_, num_weights,
scale_grad_by_freq, mode, per_sample_weights);
}
}
@@ -440,7 +592,6 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi
// contiguous here due to the checks in _embedding_bag_backward above.
// Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
// for more details.
-
auto grad = grad_.contiguous();
auto grad_arg = TensorArg(grad, "grad_", 1);
checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});
@@ -467,6 +618,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
const Tensor& grad,
const Tensor& weight, // NB: embedding table, not per_sample_weights
const Tensor& indices,
+ const Tensor& offsets,
const Tensor& offset2bag,
int64_t mode) {
AT_CHECK(
@@ -487,9 +639,21 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarType("embedding_bag", indices_arg, kLong);
checkContiguous("embedding_bag", indices_arg);
- auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
- checkScalarType("embedding_bag", offset2bag_arg, kLong);
- checkContiguous("embedding_bag", offset2bag_arg);
+
+ Tensor offset2bag_;
+ if (indices.numel() != 0 && offset2bag.numel() == 0) {
+ offset2bag_ = at::zeros(
+ {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+ make_offset2bag(offsets, indices, offset2bag_);
+
+ offset2bag_.resize_({indices.sizes()[0]});
+ } else {
+ auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
+ checkScalarType("embedding_bag", offset2bag_arg, kLong);
+ checkContiguous("embedding_bag", offset2bag_arg);
+ offset2bag_ = offset2bag;
+ }
auto grad_data = grad.data<scalar_t>();
auto grad_stride0 = grad.stride(0);
@@ -503,7 +667,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
// The following are contiguous
auto output_data = output.data<scalar_t>();
- auto offset2bag_data = offset2bag.data<int64_t>();
+ auto offset2bag_data = offset2bag_.data<int64_t>();
// XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) {
@@ -524,12 +688,13 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu(
const Tensor& grad,
const Tensor& weight, // NB: embedding table, not per_sample_weights
const Tensor& indices,
+ const Tensor& offsets,
const Tensor& offset2bag,
int64_t mode) {
return AT_DISPATCH_FLOATING_TYPES(
grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu", [&]() {
return _embedding_bag_per_sample_weights_backward_cpu_template<scalar_t>(
- grad, weight, indices, offset2bag, mode);
+ grad, weight, indices, offsets, offset2bag, mode);
}
);
}
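For sum mode, the per-sample-weights gradient this backward produces is conceptually a dot product between each bag's incoming gradient and the corresponding gathered embedding row, which is why the function needs either `offset2bag` or, with this change, `offsets` to recover the index-to-bag mapping. A rough standalone sketch with invented data (plain arrays rather than `at::Tensor`):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t dim = 2;
  std::vector<float> weight = {1.f, 2.f,   // 3 x 2 embedding table
                               3.f, 4.f,
                               5.f, 6.f};
  std::vector<int64_t> indices = {0, 2, 1};
  std::vector<int64_t> offset2bag = {0, 0, 1};  // bag id per index
  std::vector<float> grad = {0.1f, 0.2f,        // dL/d(output), 2 bags x dim
                             0.3f, 0.4f};

  // dL/d(per_sample_weights[i]) = <grad[bag(i)], weight[indices[i]]>
  std::vector<float> grad_psw(indices.size(), 0.f);
  for (size_t i = 0; i < indices.size(); ++i) {
    const float* g = grad.data() + offset2bag[i] * dim;
    const float* w = weight.data() + indices[i] * dim;
    for (int64_t j = 0; j < dim; ++j) {
      grad_psw[i] += g[j] * w[j];
    }
  }

  for (float v : grad_psw) std::cout << v << ' ';  // prints: 0.5 1.7 2.5
  std::cout << '\n';
  return 0;
}
```
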
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index 8e0f6e8865..ca9132ad90 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -469,6 +469,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
const Tensor& grad,
const Tensor& weight, // NB: embedding table, not per_sample_weights
const Tensor& indices,
+ const Tensor& offsets,
const Tensor& offset2bag,
int64_t mode) {
AT_CHECK(
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6ebd0d53fd..428982c58c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -675,7 +675,7 @@
CPU: _embedding_bag_dense_backward_cpu
CUDA: _embedding_bag_dense_backward_cuda
-- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offset2bag, int mode) -> Tensor
+- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode) -> Tensor
dispatch:
CPU: _embedding_bag_per_sample_weights_backward_cpu
CUDA: _embedding_bag_per_sample_weights_backward_cuda
diff --git a/caffe2/perfkernels/embedding_lookup.h b/caffe2/perfkernels/embedding_lookup.h
index 1d0cd2abfa..37830d69c8 100644
--- a/caffe2/perfkernels/embedding_lookup.h
+++ b/caffe2/perfkernels/embedding_lookup.h
@@ -28,6 +28,9 @@ namespace caffe2 {
* if (normalize_weights && lengths[i] > 0)
* for (k = 0..block_size-1)
* out[i*block_size + k] /= lengths[i]
+ *
+ * TODO: make this API also take "offsets" rather than "lengths" to match the
+ * API for PyTorch's EmbeddingBag
*/
template <
typename IndexType,
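The header's doc comment (partially shown above) specifies the operation in terms of `lengths`-delimited segments, which is what the TODO about accepting `offsets` refers to. A naive single-threaded reference of that contract, covering only the float, no-`scale_bias` case; `embedding_lookup_ref` is a name invented for this sketch, not the real perfkernel entry point, which dispatches to vectorized implementations:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// For each output row i: sum (optionally weighted) embedding rows for the
// next lengths[i] indices, then optionally divide by lengths[i].
void embedding_lookup_ref(
    int64_t block_size,
    int64_t output_size,
    const float* input,      // [data_size, block_size]
    const int64_t* indices,  // [index_size]
    const int* lengths,      // [output_size], sums to index_size
    const float* weights,    // nullptr, or one weight per index
    bool normalize_by_lengths,
    float* out) {            // [output_size, block_size]
  int64_t current = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* out_row = out + i * block_size;
    std::fill(out_row, out_row + block_size, 0.f);
    for (int j = 0; j < lengths[i]; ++j, ++current) {
      const float w = weights ? weights[current] : 1.f;
      const float* in_row = input + indices[current] * block_size;
      for (int64_t k = 0; k < block_size; ++k) {
        out_row[k] += w * in_row[k];
      }
    }
    if (normalize_by_lengths && lengths[i] > 0) {
      for (int64_t k = 0; k < block_size; ++k) {
        out_row[k] /= static_cast<float>(lengths[i]);
      }
    }
  }
}

int main() {
  std::vector<float> table = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // 3 rows, dim 2
  std::vector<int64_t> indices = {0, 2, 1};
  std::vector<int> lengths = {2, 1};                           // two bags
  std::vector<float> out(2 * 2, 0.f);

  embedding_lookup_ref(2, 2, table.data(), indices.data(), lengths.data(),
                       /*weights=*/nullptr, /*normalize_by_lengths=*/false,
                       out.data());

  for (float v : out) std::cout << v << ' ';  // prints: 6 8 3 4
  std::cout << '\n';
  return 0;
}
```
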
diff --git a/test/test_nn.py b/test/test_nn.py
index 2366806802..622d78eb95 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2383,7 +2383,7 @@ class TestNN(NNTestCase):
self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
dtype2prec[dtype])
- def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
+ def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double, test_backward=True):
# check a known test example
device = torch.device("cuda") if cuda else torch.device("cpu")
es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype)
@@ -2472,7 +2472,7 @@ class TestNN(NNTestCase):
# now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
- kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype)
+ kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype, test_backward=test_backward)
self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
for max_norm in (None, 3):
for p in itertools.product([1, 2], repeat=4):
@@ -2563,12 +2563,15 @@ class TestNN(NNTestCase):
F.conv2d(x, torch.randn(1, 16, 1, 1, device="cuda"))
def test_embedding_bag(self):
- self._test_EmbeddingBag(False, 'sum', False)
- self._test_EmbeddingBag(False, 'mean', False)
- self._test_EmbeddingBag(False, 'max', False)
-
- self._test_EmbeddingBag(False, 'sum', True)
- self._test_EmbeddingBag(False, 'mean', True)
+ for dtype in [torch.double, torch.float]:
+ # TODO: figure out why backward on float breaks
+ test_backward = dtype is not torch.float
+ self._test_EmbeddingBag(False, 'sum', False, test_backward=test_backward, dtype=dtype)
+ self._test_EmbeddingBag(False, 'mean', False, test_backward=test_backward, dtype=dtype)
+ self._test_EmbeddingBag(False, 'max', False, test_backward=test_backward, dtype=dtype)
+
+ self._test_EmbeddingBag(False, 'sum', True, test_backward=test_backward, dtype=dtype)
+ self._test_EmbeddingBag(False, 'mean', True, test_backward=test_backward, dtype=dtype)
@staticmethod
def _embedding_bag_reference_impl(input, weight, offsets=None, mode='sum',
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 984555197d..159a75d091 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -955,7 +955,7 @@
indices: non_differentiable
offsets: non_differentiable
weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights)
- per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, result1, mode)
+ per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode)
- name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
indices: non_differentiable