diff options
author | Elias Ellison <eellison@fb.com> | 2018-11-28 19:14:16 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-11-28 19:16:38 -0800 |
commit | 6d63e9dbfffba9f925ac3af5232390a76aa54dce (patch) | |
tree | 7924d57fc8f42429c12a911e3234c7ce49cbd991 /torch/nn | |
parent | 105fa58748076a19682a4c0c9dee878a9575d7ed (diff) | |
download | pytorch-6d63e9dbfffba9f925ac3af5232390a76aa54dce.tar.gz pytorch-6d63e9dbfffba9f925ac3af5232390a76aa54dce.tar.bz2 pytorch-6d63e9dbfffba9f925ac3af5232390a76aa54dce.zip |
Support Embedding + EmbeddingBag in Script + (Ignore flakey test) (#14509)
Summary:
Resubmitting PR #14415
The tests added for Embedding + EmbeddingBag used random numbers as input, which perturbed the random number generator state and caused a flaky test to break.
Everything but the last two commits has already been accepted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14509
Differential Revision: D13247917
Pulled By: eellison
fbshipit-source-id: ea6963c47f666c07687787e2fa82020cddc6aa15
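As context for the title, here is a minimal sketch of what "Support Embedding + EmbeddingBag in Script" enables. This example is not part of the commit; the class name and sizes are made up for illustration. Once `nn.EmbeddingBag` is registered as a weak module with a typed `forward`, it can be used from a `ScriptModule` of this era without tracing.

```python
import torch
import torch.nn as nn


class BagOfWords(torch.jit.ScriptModule):
    """Hypothetical example module; vocab_size/dim are arbitrary."""

    def __init__(self, vocab_size=100, dim=16):
        super(BagOfWords, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, dim, mode='sum')

    @torch.jit.script_method
    def forward(self, input, offsets):
        # type: (Tensor, Tensor) -> Tensor
        # nn.EmbeddingBag is a weak module after this change, so the call
        # below is compiled by the script compiler.
        return self.embedding(input, offsets)


indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
print(BagOfWords()(indices, offsets).shape)  # torch.Size([2, 16])
```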
Diffstat (limited to 'torch/nn')
-rw-r--r-- | torch/nn/functional.py | 41 |
-rw-r--r-- | torch/nn/modules/sparse.py | 14 |
2 files changed, 40 insertions, 15 deletions
```diff
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index e7259f4d55..172a46f5c4 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1328,8 +1328,10 @@ def bilinear(input1, input2, weight, bias=None):
     return torch.bilinear(input1, input2, weight, bias)
 
 
-def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
+@torch._jit_internal.weak_script
+def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
               scale_grad_by_freq=False, sparse=False):
+    # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
     r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
 
     This module is often used to retrieve word embeddings using indices.
@@ -1388,25 +1390,32 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
                  [ 0.6262,  0.2438,  0.7471]]])
     """
     if padding_idx is not None:
+        padding_idx = torch.jit._unwrap_optional(padding_idx)
         if padding_idx > 0:
             assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings'
         elif padding_idx < 0:
             assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings'
             padding_idx = weight.size(0) + padding_idx
-    elif padding_idx is None:
+    else:
         padding_idx = -1
 
     if max_norm is not None:
+        max_norm = torch.jit._unwrap_optional(max_norm)
         # `embedding_renorm_` will call .contiguous() on input anyways, so we
         # call it here and take advantage of the improved locality in the
         # `embedding` call below too.
         input = input.contiguous()
-        with torch.no_grad():
-            torch.embedding_renorm_(weight, input, max_norm, norm_type)
+        # XXX: equivalent to
+        # with torch.no_grad():
+        #   torch.nembedding_renorm_
+        # remove once script supports set_grad_enabled
+        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 
 
+@torch._jit_internal.weak_script
 def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                   scale_grad_by_freq=False, mode='mean', sparse=False):
+    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
     r"""Computes sums, means or maxes of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
@@ -1491,26 +1500,27 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
     elif input.dim() == 1:
         if offsets is None:
             raise ValueError("offsets has to be a 1D Tensor but got None")
+        offsets = torch.jit._unwrap_optional(offsets)
         if offsets.dim() != 1:
             raise ValueError("offsets has to be a 1D Tensor")
-        if offsets[0].item() != 0:
+        if int(offsets[0]) != 0:
             raise ValueError("offsets[0] has to be 0, i.e., the first sequence "
                              "in the mini-batch has to start from position 0. "
                              "However, got {}".format(offsets[0].item()))
-        if offsets[-1].item() > input.size(0):
+        if int(offsets[-1]) > input.size(0):
             raise ValueError("offsets[-1] can not be greater than input's length"
                              " ({}), but got offsets[-1] of {}"
                              .format(input.size(0), offsets[-1].item()))
     else:
         raise ValueError("input has to be 1D or 2D Tensor,"
                          " but got Tensor of dimension {}".format(input.dim()))
-
+    offsets = torch.jit._unwrap_optional(offsets)  # TODO remove when exception control flow logic
     if mode == 'sum':
-        mode = 0
+        mode_enum = 0
     elif mode == 'mean':
-        mode = 1
+        mode_enum = 1
     elif mode == 'max':
-        mode = 2
+        mode_enum = 2
 
         if scale_grad_by_freq:
             raise ValueError("max mode does not support scaling the gradient by the frequency")
@@ -1519,18 +1529,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
             raise ValueError("max mode does not support sparse weights")
 
     else:
+        mode_enum = -1  # TODO when exception control flow logic
         raise ValueError("mode has to be one of sum, mean or max")
 
     if max_norm is not None:
-        with torch.no_grad():
-            torch.embedding_renorm_(weight, input, max_norm, norm_type)
+        max_norm = torch.jit._unwrap_optional(max_norm)
+        # XXX: equivalent to
+        # with torch.no_grad():
+        #   torch.nembedding_renorm_
+        # remove once script supports set_grad_enabled
+        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
 
     ret, _, _, _ = torch.embedding_bag(
         weight,
         input,
         offsets,
         scale_grad_by_freq,
-        mode,
+        mode_enum,
         sparse)
     return ret
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index 2ee41035a0..cd3ea4eb21 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -4,8 +4,10 @@ from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
 from .. import init
+from torch._jit_internal import weak_module, weak_script, weak_script_method
 
 
+@weak_module
 class Embedding(Module):
     r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
 
@@ -75,9 +77,11 @@ class Embedding(Module):
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
     """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
+                     'norm_type', 'scale_grad_by_freq', 'sparse', '_weight']
 
     def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
-                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                  sparse=False, _weight=None):
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
@@ -107,6 +111,7 @@ class Embedding(Module):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
+    @weak_script_method
     def forward(self, input):
         return F.embedding(
             input, self.weight, self.padding_idx, self.max_norm,
@@ -161,6 +166,7 @@ class Embedding(Module):
         return embedding
 
 
+@weak_module
 class EmbeddingBag(Module):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
@@ -223,9 +229,11 @@ class EmbeddingBag(Module):
         tensor([[-0.8861, -5.4350, -0.0523],
                 [ 1.1306, -2.5798, -1.0044]])
     """
+    __constants__ = ['num_embeddings, embedding_dim', 'max_norm', 'norm_type',
+                     'scale_grad_by_freq', 'mode', 'sparse']
 
     def __init__(self, num_embeddings, embedding_dim,
-                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                  mode='mean', sparse=False):
         super(EmbeddingBag, self).__init__()
         self.num_embeddings = num_embeddings
@@ -242,7 +250,9 @@ class EmbeddingBag(Module):
     def reset_parameters(self):
         init.normal_(self.weight)
 
+    @weak_script_method
     def forward(self, input, offsets=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         return F.embedding_bag(input, self.weight, offsets, self.max_norm,
                                self.norm_type, self.scale_grad_by_freq,
                                self.mode, self.sparse)
```
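Similarly, a usage sketch for the functional path (not part of the diff; the function name is made up): because `F.embedding` now carries the `weak_script` decorator and a `# type:` annotation, it can be called from a `@torch.jit.script` function, with the `Optional` defaults handled via `torch.jit._unwrap_optional` inside the function as shown in the diff above.

```python
import torch
import torch.nn.functional as F


@torch.jit.script
def lookup(indices, weight):
    # type: (Tensor, Tensor) -> Tensor
    # F.embedding is a weak script function after this change, so this
    # call is compiled; padding_idx/max_norm stay at their None defaults.
    return F.embedding(indices, weight)


weight = torch.randn(10, 3)
indices = torch.tensor([[1, 2, 4], [4, 3, 9]])
print(lookup(indices, weight).shape)  # torch.Size([2, 3, 3])
```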