-rw-r--r-- | aten/src/ATen/native/EmbeddingBag.cpp      | 223
-rw-r--r-- | aten/src/ATen/native/cuda/EmbeddingBag.cu  |   1
-rw-r--r-- | aten/src/ATen/native/native_functions.yaml |   2
-rw-r--r-- | caffe2/perfkernels/embedding_lookup.h      |   3
-rw-r--r-- | test/test_nn.py                            |  19
-rw-r--r-- | tools/autograd/derivatives.yaml            |   2
6 files changed, 211 insertions, 39 deletions
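Note: this change adds a CPU fast path for EmbeddingBag by routing contiguous float32 'sum'-mode bags through caffe2::EmbeddingLookup. A minimal sketch of the call shape it targets (hypothetical sizes, using the public nn.EmbeddingBag API):

import torch
import torch.nn as nn

# Hypothetical sizes; float32 weights plus 'sum' mode with contiguous
# tensors is the case the new fast path covers.
bag = nn.EmbeddingBag(1000, 64, mode='sum')  # weight is float32 by default
indices = torch.randint(0, 1000, (512,))     # flat indices for all bags
offsets = torch.arange(0, 512, 8)            # 64 bags of 8 indices each
out = bag(indices, offsets)                  # shape: (64, 64)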
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 9fa2bdb953..80bbf9c118 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -5,6 +5,8 @@
 #include <TH/THBlasUtils.h>
 
+#include <caffe2/perfkernels/embedding_lookup.h>
+
 #include <cstring>
 #include <iostream>
 #include <memory>
 
@@ -24,22 +26,32 @@ namespace {
 namespace at { namespace native {
 
-static void make_offset2bag(const Tensor &offsets, const Tensor &indices,
-                            Tensor &offset2bag) {
+static void make_offset2bag(const Tensor &offsets, const Tensor &indices, Tensor& offset2bag) {
   offset2bag.index_add_(
       0, offsets, at::ones_like(offsets)); // offset2bag = [1 0 1 0 1]
   offset2bag[0] -= 1;                      // offset2bag = [0 0 1 0 1]
   offset2bag = offset2bag.cumsum(0);       // offset2bag = [0 0 1 1 2]
 }
 
+namespace {
+
+bool isFastPathIndexSelect(const Tensor& src, Tensor& output) {
+  return src.scalar_type() == kFloat && src.stride(1) == 1 && output.stride(1) == 1;
+}
+
+bool isFastPathIndexSelectScale(const Tensor& src, const Tensor& scale, Tensor& output) {
+  return src.scalar_type() == kFloat && src.stride(1) == 1 && output.stride(1) == 1 && scale.stride(0) == 1;
+}
+
 // This function combines index_select (using select_indices as the index) and
 // index_add (using add_indices as the index), without creating an intermediary
 // tensor to hold the selected embeddings
 template<typename T>
-static void index_select_add(const Tensor &select_indices,
+void index_select_add(const Tensor &select_indices,
                       const Tensor &add_indices,
                       const Tensor &src,
-                      Tensor &output) {
+                      Tensor &output,
+                      const Tensor& /*offsets*/) {
   AT_ASSERT(select_indices.numel() == add_indices.numel());
   auto add_indices_data = add_indices.data<int64_t>();
   auto select_indices_data = select_indices.data<int64_t>();
@@ -59,6 +71,57 @@ static void index_select_add(const Tensor &select_indices,
   }
 }
 
+template<>
+void index_select_add<float>(const Tensor &select_indices,
+                             const Tensor &add_indices,
+                             const Tensor &src,
+                             Tensor &output,
+                             const Tensor& offsets) {
+  int64_t ddim = src.size(1);
+  auto src_data = src.data<float>();
+  auto select_indices_data = select_indices.data<int64_t>();
+  auto output_data = output.data<float>();
+
+  if (isFastPathIndexSelect(src, output)) {
+    auto accessor = offsets.accessor<int64_t, 1>();
+    std::vector<int> lengths;
+
+    int64_t lower = accessor[0];
+    for (size_t i = 1; i < offsets.numel(); ++i) {
+      lengths.push_back(accessor[i] - lower);
+      lower = accessor[i];
+    }
+    lengths.push_back(select_indices.numel() - lower);
+
+    caffe2::EmbeddingLookup(
+        /*block_size=*/ddim,
+        /*output_size=*/lengths.size(),
+        /*index_size=*/select_indices.numel(),
+        /*data_size=*/src.size(0),
+        /*input=*/src_data,
+        /*indices=*/select_indices_data,
+        /*lengths=*/lengths.data(),
+        /*weights=*/nullptr,
+        /*scale_bias=*/nullptr,
+        /*normalize_by_lengths=*/false,
+        /*out=*/output_data
+    );
+  } else {
+    AT_ASSERT(select_indices.numel() == add_indices.numel());
+    auto add_indices_data = add_indices.data<int64_t>();
+    auto src_stride0 = src.stride(0);
+    auto src_stride1 = src.stride(1);
+    auto output_stride0 = output.stride(0);
+    auto output_stride1 = output.stride(1);
+    auto numel = add_indices.numel();
+    for (int64_t i = 0; i < numel; i++) {
+      THBlas_axpy<float>(ddim, 1,
+          src_data + src_stride0 * select_indices_data[i], src_stride1,
+          output_data + output_stride0 * add_indices_data[i], output_stride1);
+    }
+  }
+}
+
 // This function fuses the following three fns:
 // index_select (using select_indices as the index)
 // mul (scaling by per_sample_weights)
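For reference, the offsets-to-lengths conversion that the float specialization performs before handing off to caffe2::EmbeddingLookup can be restated in Python as follows (a sketch mirroring the loop above, not part of the patch):

import torch

# lengths[i] = offsets[i+1] - offsets[i]; the final bag runs to the end
# of indices, mirroring the loop in index_select_add<float>.
def offsets_to_lengths(offsets, num_indices):
    lengths = []
    lower = offsets[0].item()
    for i in range(1, offsets.numel()):
        lengths.append(offsets[i].item() - lower)
        lower = offsets[i].item()
    lengths.append(num_indices - lower)
    return lengths

print(offsets_to_lengths(torch.tensor([0, 2, 3]), 5))  # [2, 1, 2]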
@@ -68,7 +131,8 @@ static void index_select_scale_add(const Tensor &select_indices,
                                    const Tensor &add_indices,
                                    const Tensor &scale,
                                    const Tensor &src,
-                                   Tensor &output) {
+                                   Tensor &output,
+                                   const Tensor& /*offsets*/) {
   AT_ASSERT(select_indices.numel() == add_indices.numel());
   auto add_indices_data = add_indices.data<int64_t>();
   auto select_indices_data = select_indices.data<int64_t>();
@@ -84,7 +148,6 @@ static void index_select_scale_add(const Tensor &select_indices,
   auto* scale_data = scale.data<T>();
   auto scale_stride = scale.stride(0);
 
-  // XXX: We could make this faster via vectorization
   for (int64_t i = 0; i < numel; i++) {
     auto* src_base = src_data + src_stride0 * select_indices_data[i];
     auto* output_base = output_data + output_stride0 * add_indices_data[i];
@@ -95,6 +158,67 @@ static void index_select_scale_add(const Tensor &select_indices,
   }
 }
 
+template<>
+void index_select_scale_add<float>(const Tensor &select_indices,
+                                   const Tensor &add_indices,
+                                   const Tensor &scale,
+                                   const Tensor &src,
+                                   Tensor &output,
+                                   const Tensor& offsets) {
+  int64_t ddim = src.size(1);
+  auto* scale_data = scale.data<float>();
+  auto select_indices_data = select_indices.data<int64_t>();
+  auto src_data = src.data<float>();
+  auto output_data = output.data<float>();
+
+  if (isFastPathIndexSelectScale(src, scale, output)) {
+    auto accessor = offsets.accessor<int64_t, 1>();
+    std::vector<int> lengths;
+
+    int64_t lower = accessor[0];
+    for (size_t i = 1; i < offsets.numel(); ++i) {
+      lengths.push_back(accessor[i] - lower);
+      lower = accessor[i];
+    }
+    lengths.push_back(select_indices.numel() - lower);
+
+    caffe2::EmbeddingLookup(
+        /*block_size=*/ddim,
+        /*output_size=*/lengths.size(),
+        /*index_size=*/select_indices.numel(),
+        /*data_size=*/src.size(0),
+        /*input=*/src_data,
+        /*indices=*/select_indices_data,
+        /*lengths=*/lengths.data(),
+        /*weights=*/scale_data,
+        /*scale_bias=*/nullptr,
+        /*normalize_by_lengths=*/false,
+        /*out=*/output_data
+    );
+  } else {
+    AT_ASSERT(select_indices.numel() == add_indices.numel());
+    auto add_indices_data = add_indices.data<int64_t>();
+    auto src_stride0 = src.stride(0);
+    auto src_stride1 = src.stride(1);
+    auto output_stride0 = output.stride(0);
+    auto output_stride1 = output.stride(1);
+    auto scale_stride = scale.stride(0);
+    auto numel = add_indices.numel();
+
+    for (int64_t i = 0; i < numel; i++) {
+      auto* src_base = src_data + src_stride0 * select_indices_data[i];
+      auto* output_base = output_data + output_stride0 * add_indices_data[i];
+      auto scale = scale_data[i * scale_stride];
+      for (int64_t j = 0; j < ddim; j++) {
+        output_base[j * output_stride1] += src_base[j * src_stride1] * scale;
+      }
+    }
+  }
+}
+
+} // namespace
+
 static void make_bag_size(const Tensor &offsets, const Tensor &indices,
                           const int64_t mode, Tensor &bag_size) {
   if (mode == MODE_MEAN || mode == MODE_MAX) {
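The offset2bag mapping that the fast path avoids building (see make_offset2bag above) assigns each index position to its bag. A Python restatement of that helper, reproducing the [1 0 1 0 1] -> [0 0 1 1 2] trace from the inline comments (illustrative sketch only):

import torch

def make_offset2bag(offsets, indices):
    # Mark the start of each bag, then cumsum to get a bag id per index.
    # The extra trailing entry absorbs a possible empty last bag.
    o2b = torch.zeros(indices.numel() + 1, dtype=torch.long)
    o2b.index_add_(0, offsets, torch.ones_like(offsets))  # [1 0 1 0 1]
    o2b[0] -= 1                                           # [0 0 1 0 1]
    o2b = o2b.cumsum(0)                                   # [0 0 1 1 2]
    return o2b[:indices.numel()]                          # drop the padding entry

# Four indices, bags starting at 0, 2, 4 (the last bag is empty):
print(make_offset2bag(torch.tensor([0, 2, 4]), torch.arange(4)))  # tensor([0, 0, 1, 1])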
@@ -230,27 +354,43 @@ _embedding_bag_cpu(const Tensor &weight, const Tensor &indices,
   auto bag_size = at::zeros(offsets.sizes(), indices.options());
   make_bag_size(offsets, indices, mode, bag_size);
 
-  // If the last entries are empty, that the last offsets are irrelevant as they
-  // won't change anything in the assignment of ID -> bag, but index_add would
-  // throw out of bounds error. So to keep it simple we just add one more
-  // entry to the end then get rid of it after make_offset2bag.
-  auto offset2bag = at::zeros(
-     {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
-
-  make_offset2bag(offsets, indices, offset2bag);
+  auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.options());
 
-  offset2bag.resize_({indices.sizes()[0]});
+  // To save compute, if we are going to go down the fast path case for the 'sum'
+  // mode, we skip calculating offset2bag, since it is not going to be used.
+  auto fast_path_sum = [&weight, &per_sample_weights, &output]() {
+    if (per_sample_weights.defined()) {
+      return isFastPathIndexSelectScale(weight, per_sample_weights, output);
+    } else {
+      return isFastPathIndexSelect(weight, output);
+    }
+  };
 
-  auto output = at::zeros({offsets.size(0), weight.size(1)}, weight.options());
+  // Use an empty 0-element tensor as a sentinel that we have skipped the
+  // creation of offset2bag because autograd chokes when trying to use an
+  // undefined tensor as an input to a backward op.
+  Tensor offset2bag = at::empty({0}, offsets.options());
+  if (mode == MODE_MEAN || mode == MODE_MAX || !fast_path_sum()) {
+    // If the last entries are empty, then the last offsets are irrelevant as they
+    // won't change anything in the assignment of ID -> bag, but index_add would
+    // throw an out-of-bounds error. So to keep it simple we just add one more
+    // entry to the end, then get rid of it after make_offset2bag.
+    offset2bag = at::zeros(
+       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+    make_offset2bag(offsets, indices, offset2bag);
+
+    offset2bag.resize_({indices.sizes()[0]});
+  }
 
   if (mode == MODE_MEAN || mode == MODE_SUM) {
     AT_DISPATCH_FLOATING_TYPES(weight.scalar_type(), "embedding_bag_cpu", [&]() {
       if (per_sample_weights.defined()) {
         AT_ASSERT(mode == MODE_SUM);
         index_select_scale_add<scalar_t>(
-          indices, offset2bag, per_sample_weights, weight, output);
+          indices, offset2bag, per_sample_weights, weight, output, offsets);
       } else {
-        index_select_add<scalar_t>(indices, offset2bag, weight, output);
+        index_select_add<scalar_t>(indices, offset2bag, weight, output, offsets);
       }
     });
     auto ret = apply_bag_size(offsets, indices, mode, output, bag_size);
@@ -286,17 +426,29 @@ Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
   auto offsets_arg = TensorArg(offsets, "offsets", 1);
   checkScalarType("embedding_bag", offsets_arg, kLong);
   checkContiguous("embedding_bag", offsets_arg);
-  auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
-  checkScalarType("embedding_bag", offset2bag_arg, kLong);
-  checkContiguous("embedding_bag", offset2bag_arg);
+
+  Tensor offset2bag_;
+  if (indices.numel() != 0 && offset2bag.numel() == 0) {
+    offset2bag_ = at::zeros(
+       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+    make_offset2bag(offsets, indices, offset2bag_);
+
+    offset2bag_.resize_({indices.sizes()[0]});
+  } else {
+    auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
+    checkScalarType("embedding_bag", offset2bag_arg, kLong);
+    checkContiguous("embedding_bag", offset2bag_arg);
+    offset2bag_ = offset2bag;
+  }
 
   if (sparse) {
     return at::_embedding_bag_sparse_backward(
-        grad, indices, offsets, offset2bag, bag_size_, num_weights,
+        grad, indices, offsets, offset2bag_, bag_size_, num_weights,
         scale_grad_by_freq, mode, per_sample_weights);
   } else {
     return at::_embedding_bag_dense_backward(
-        grad, indices, offsets, offset2bag, bag_size_, max_indices_, num_weights,
+        grad, indices, offsets, offset2bag_, bag_size_, max_indices_, num_weights,
         scale_grad_by_freq, mode, per_sample_weights);
   }
 }
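Backward now reconstructs offset2bag from offsets whenever the forward pass skipped it (the 0-element sentinel). Viewed from Python, the recomputed mapping is equivalent to repeating each bag id by its bag length — a sketch using current torch APIs (torch.diff, repeat_interleave), not code from this patch:

import torch

def rebuild_offset2bag(offsets, num_indices):
    # Bag lengths from offsets, then each bag id repeated length times.
    lengths = torch.diff(offsets, append=torch.tensor([num_indices]))
    return torch.repeat_interleave(torch.arange(offsets.numel()), lengths)

print(rebuild_offset2bag(torch.tensor([0, 2, 3]), 5))  # tensor([0, 0, 1, 2, 2])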
@@ -440,7 +592,6 @@ Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indi
   // contiguous here due to the checks in _embedding_bag_backward above.
   // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
   // for more details.
-
   auto grad = grad_.contiguous();
   auto grad_arg = TensorArg(grad, "grad_", 1);
   checkScalarTypes("embedding_bag", grad_arg, {kFloat, kDouble});
@@ -467,6 +618,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
     const Tensor& grad,
     const Tensor& weight,  // NB: embedding table, not per_sample_weights
     const Tensor& indices,
+    const Tensor& offsets,
     const Tensor& offset2bag,
     int64_t mode) {
   AT_CHECK(
@@ -487,9 +639,21 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarType("embedding_bag", indices_arg, kLong);
   checkContiguous("embedding_bag", indices_arg);
-  auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
-  checkScalarType("embedding_bag", offset2bag_arg, kLong);
-  checkContiguous("embedding_bag", offset2bag_arg);
+
+  Tensor offset2bag_;
+  if (indices.numel() != 0 && offset2bag.numel() == 0) {
+    offset2bag_ = at::zeros(
+       {indices.sizes()[0] + 1}, indices.options()); // offset2bag = [0 0 0 0 0]
+
+    make_offset2bag(offsets, indices, offset2bag_);
+
+    offset2bag_.resize_({indices.sizes()[0]});
+  } else {
+    auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
+    checkScalarType("embedding_bag", offset2bag_arg, kLong);
+    checkContiguous("embedding_bag", offset2bag_arg);
+    offset2bag_ = offset2bag;
+  }
 
   auto grad_data = grad.data<scalar_t>();
   auto grad_stride0 = grad.stride(0);
@@ -503,7 +667,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
 
   // The following are contiguous
   auto output_data = output.data<scalar_t>();
-  auto offset2bag_data = offset2bag.data<int64_t>();
+  auto offset2bag_data = offset2bag_.data<int64_t>();
 
   // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
   parallel_for(0, num_samples, 64, [&](int64_t begin, int64_t end) {
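For orientation, the quantity this kernel computes in 'sum' mode is grad_psw[i] = <grad[offset2bag[i]], weight[indices[i]]>, i.e. the dot product of each bag's output gradient with the embedding row that sample selected. A one-line Python equivalent (a sketch assuming contiguous inputs, not the kernel itself):

import torch

def per_sample_weights_grad(grad, weight, indices, offset2bag):
    # d(out)/d(psw[i]): dot the containing bag's output grad with the
    # embedding row selected by indices[i].
    return (grad[offset2bag] * weight[indices]).sum(dim=1)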
@@ -524,12 +688,13 @@ Tensor _embedding_bag_per_sample_weights_backward_cpu(
     const Tensor& grad,
     const Tensor& weight,  // NB: embedding table, not per_sample_weights
     const Tensor& indices,
+    const Tensor& offsets,
     const Tensor& offset2bag,
     int64_t mode) {
   return AT_DISPATCH_FLOATING_TYPES(
     grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu", [&]() {
       return _embedding_bag_per_sample_weights_backward_cpu_template<scalar_t>(
-          grad, weight, indices, offset2bag, mode);
+          grad, weight, indices, offsets, offset2bag, mode);
     }
   );
 }
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index 8e0f6e8865..ca9132ad90 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -469,6 +469,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda(
     const Tensor& grad,
     const Tensor& weight,  // NB: embedding table, not per_sample_weights
     const Tensor& indices,
+    const Tensor& offsets,
     const Tensor& offset2bag,
     int64_t mode) {
   AT_CHECK(
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6ebd0d53fd..428982c58c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -675,7 +675,7 @@
     CPU: _embedding_bag_dense_backward_cpu
     CUDA: _embedding_bag_dense_backward_cuda
 
-- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offset2bag, int mode) -> Tensor
+- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode) -> Tensor
   dispatch:
     CPU: _embedding_bag_per_sample_weights_backward_cpu
     CUDA: _embedding_bag_per_sample_weights_backward_cuda
diff --git a/caffe2/perfkernels/embedding_lookup.h b/caffe2/perfkernels/embedding_lookup.h
index 1d0cd2abfa..37830d69c8 100644
--- a/caffe2/perfkernels/embedding_lookup.h
+++ b/caffe2/perfkernels/embedding_lookup.h
@@ -28,6 +28,9 @@ namespace caffe2 {
  *      if (normalize_weights && lengths[i] > 0)
  *          for (k = 0..block_size-1)
  *              out[i*block_size + k] /= lengths[i]
+ *
+ * TODO: make this API also take "offsets" rather than "lengths" to match the
+ * API for PyTorch's EmbeddingBag
  */
 template <
     typename IndexType,
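A Python restatement of the EmbeddingLookup semantics documented above (lengths-based, as the TODO notes; an illustrative sketch of the doc comment's pseudocode, not the vectorized kernel):

import torch

def embedding_lookup_reference(data, indices, lengths, weights=None,
                               normalize_by_lengths=False):
    # out[i] = sum over bag i of weights[pos] * data[indices[pos]],
    # optionally divided by the bag length.
    out = torch.zeros(len(lengths), data.size(1), dtype=data.dtype)
    pos = 0
    for i, n in enumerate(lengths):
        for _ in range(n):
            w = 1.0 if weights is None else weights[pos]
            out[i] += w * data[indices[pos]]
            pos += 1
        if normalize_by_lengths and n > 0:
            out[i] /= n
    return out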
diff --git a/test/test_nn.py b/test/test_nn.py
index 2366806802..622d78eb95 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2383,7 +2383,7 @@ class TestNN(NNTestCase):
         self.assertEqual(per_sample_weights.grad, per_sample_weights_reference.grad,
                          dtype2prec[dtype])
 
-    def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double):
+    def _test_EmbeddingBag(self, cuda, mode, sparse, dtype=torch.double, test_backward=True):
         # check a known test example
         device = torch.device("cuda") if cuda else torch.device("cpu")
         es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, dtype)
@@ -2472,7 +2472,7 @@ class TestNN(NNTestCase):
 
         # now compare EmbeddingBag vs Embedding + Sum/Mean, for constant bag length
         N, D, B, L = random.randint(1, 100), random.randint(1, 100), random.randint(1, 50), random.randint(1, 50)
-        kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype)
+        kwargs = dict(mode=mode, sparse=sparse, device=device, dtype=dtype, test_backward=test_backward)
         self._test_EmbeddingBag_vs_Embedding(N, D, B, L, **kwargs)
         for max_norm in (None, 3):
             for p in itertools.product([1, 2], repeat=4):
@@ -2563,12 +2563,15 @@ class TestNN(NNTestCase):
             F.conv2d(x,
                      torch.randn(1, 16, 1, 1, device="cuda"))
 
     def test_embedding_bag(self):
-        self._test_EmbeddingBag(False, 'sum', False)
-        self._test_EmbeddingBag(False, 'mean', False)
-        self._test_EmbeddingBag(False, 'max', False)
-
-        self._test_EmbeddingBag(False, 'sum', True)
-        self._test_EmbeddingBag(False, 'mean', True)
+        for dtype in [torch.double, torch.float]:
+            # TODO: figure out why backward on float breaks
+            test_backward = dtype is not torch.float
+            self._test_EmbeddingBag(False, 'sum', False, test_backward=test_backward, dtype=dtype)
+            self._test_EmbeddingBag(False, 'mean', False, test_backward=test_backward, dtype=dtype)
+            self._test_EmbeddingBag(False, 'max', False, test_backward=test_backward, dtype=dtype)
+
+            self._test_EmbeddingBag(False, 'sum', True, test_backward=test_backward, dtype=dtype)
+            self._test_EmbeddingBag(False, 'mean', True, test_backward=test_backward, dtype=dtype)
 
     @staticmethod
     def _embedding_bag_reference_impl(input, weight, offsets=None, mode='sum',
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 984555197d..159a75d091 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -955,7 +955,7 @@
   indices: non_differentiable
   offsets: non_differentiable
   weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse, per_sample_weights)
-  per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, result1, mode)
+  per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode)
 
 - name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
   indices: non_differentiable
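Taken together with the derivatives.yaml change, per_sample_weights gradients now flow through the offsets-aware backward. A quick end-to-end sketch (assumes the current F.embedding_bag argument order; per_sample_weights requires mode='sum'):

import torch
import torch.nn.functional as F

weight = torch.randn(5, 3, requires_grad=True)
indices = torch.tensor([0, 1, 2, 3])
offsets = torch.tensor([0, 2])
psw = torch.rand(4, requires_grad=True)

out = F.embedding_bag(indices, weight, offsets, mode='sum', per_sample_weights=psw)
out.sum().backward()
print(psw.grad.shape)  # torch.Size([4])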