author     Elias Ellison <eellison@fb.com>  2018-11-28 19:14:16 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-11-28 19:16:38 -0800
commit     6d63e9dbfffba9f925ac3af5232390a76aa54dce (patch)
tree       7924d57fc8f42429c12a911e3234c7ce49cbd991 /torch/nn
parent     105fa58748076a19682a4c0c9dee878a9575d7ed (diff)
Support Embedding + EmbeddingBag in Script + (Ignore flakey test) (#14509)
Summary: Resubmitting PR #14415. The tests added for Embedding + EmbeddingBag used random numbers as input, which perturbed the random number generator state and caused the flaky test to break. Everything but the last two commits has already been accepted.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14509
Differential Revision: D13247917
Pulled By: eellison
fbshipit-source-id: ea6963c47f666c07687787e2fa82020cddc6aa15
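As an illustration of what this change enables, the sketch below scripts a module that owns an nn.EmbeddingBag. It is a minimal example written against the TorchScript API of this era (torch.jit.ScriptModule and @torch.jit.script_method); the BagOfWords module and its names are illustrative, not taken from the PR's test suite, and whether the exact snippet compiles depends on the PyTorch build containing this commit.

import torch
import torch.nn as nn


class BagOfWords(torch.jit.ScriptModule):
    def __init__(self, vocab_size, dim):
        super(BagOfWords, self).__init__()
        # nn.EmbeddingBag is now a weak module, so it can be compiled when
        # referenced from a script method.
        self.embedding = nn.EmbeddingBag(vocab_size, dim, mode='sum')

    @torch.jit.script_method
    def forward(self, tokens, offsets):
        # type: (Tensor, Tensor) -> Tensor
        return self.embedding(tokens, offsets)


model = BagOfWords(vocab_size=100, dim=16)
tokens = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])  # two bags: tokens[0:4] and tokens[4:8]
print(model(tokens, offsets).shape)  # torch.Size([2, 16])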
Diffstat (limited to 'torch/nn')
-rw-r--r--  torch/nn/functional.py      | 41
-rw-r--r--  torch/nn/modules/sparse.py  | 14
2 files changed, 40 insertions, 15 deletions
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index e7259f4d55..172a46f5c4 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1328,8 +1328,10 @@ def bilinear(input1, input2, weight, bias=None):
     return torch.bilinear(input1, input2, weight, bias)
 
 
-def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
+@torch._jit_internal.weak_script
+def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.,
               scale_grad_by_freq=False, sparse=False):
+    # type: (Tensor, Tensor, Optional[int], Optional[float], float, bool, bool) -> Tensor
     r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
 
     This module is often used to retrieve word embeddings using indices.
@@ -1388,25 +1390,32 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2,
                 [ 0.6262,  0.2438,  0.7471]]])
     """
     if padding_idx is not None:
+        padding_idx = torch.jit._unwrap_optional(padding_idx)
         if padding_idx > 0:
             assert padding_idx < weight.size(0), 'Padding_idx must be within num_embeddings'
         elif padding_idx < 0:
             assert padding_idx >= -weight.size(0), 'Padding_idx must be within num_embeddings'
             padding_idx = weight.size(0) + padding_idx
-    elif padding_idx is None:
+    else:
         padding_idx = -1
     if max_norm is not None:
+        max_norm = torch.jit._unwrap_optional(max_norm)
         # `embedding_renorm_` will call .contiguous() on input anyways, so we
         # call it here and take advantage of the improved locality in the
         # `embedding` call below too.
         input = input.contiguous()
-        with torch.no_grad():
-            torch.embedding_renorm_(weight, input, max_norm, norm_type)
+        # XXX: equivalent to
+        # with torch.no_grad():
+        #     torch.embedding_renorm_
+        # remove once script supports set_grad_enabled
+        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
 
 
+@torch._jit_internal.weak_script
 def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
                   scale_grad_by_freq=False, mode='mean', sparse=False):
+    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], float, bool, str, bool) -> Tensor
     r"""Computes sums, means or maxes of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
 
@@ -1491,26 +1500,27 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
     elif input.dim() == 1:
         if offsets is None:
             raise ValueError("offsets has to be a 1D Tensor but got None")
+        offsets = torch.jit._unwrap_optional(offsets)
         if offsets.dim() != 1:
             raise ValueError("offsets has to be a 1D Tensor")
-        if offsets[0].item() != 0:
+        if int(offsets[0]) != 0:
             raise ValueError("offsets[0] has to be 0, i.e., the first sequence "
                              "in the mini-batch has to start from position 0. "
                              "However, got {}".format(offsets[0].item()))
-        if offsets[-1].item() > input.size(0):
+        if int(offsets[-1]) > input.size(0):
             raise ValueError("offsets[-1] can not be greater than input's length"
                              " ({}), but got offsets[-1] of {}"
                              .format(input.size(0), offsets[-1].item()))
     else:
         raise ValueError("input has to be 1D or 2D Tensor,"
                          " but got Tensor of dimension {}".format(input.dim()))
-
+    offsets = torch.jit._unwrap_optional(offsets)  # TODO: remove once script supports exception control flow
     if mode == 'sum':
-        mode = 0
+        mode_enum = 0
     elif mode == 'mean':
-        mode = 1
+        mode_enum = 1
     elif mode == 'max':
-        mode = 2
+        mode_enum = 2
 
         if scale_grad_by_freq:
             raise ValueError("max mode does not support scaling the gradient by the frequency")
@@ -1519,18 +1529,23 @@ def embedding_bag(input, weight, offsets=None, max_norm=None, norm_type=2,
             raise ValueError("max mode does not support sparse weights")
 
     else:
+        mode_enum = -1  # TODO: remove once script supports exception control flow
         raise ValueError("mode has to be one of sum, mean or max")
 
     if max_norm is not None:
-        with torch.no_grad():
-            torch.embedding_renorm_(weight, input, max_norm, norm_type)
+        max_norm = torch.jit._unwrap_optional(max_norm)
+        # XXX: equivalent to
+        # with torch.no_grad():
+        #     torch.embedding_renorm_
+        # remove once script supports set_grad_enabled
+        torch.no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
 
     ret, _, _, _ = torch.embedding_bag(
         weight,
         input,
         offsets,
         scale_grad_by_freq,
-        mode,
+        mode_enum,
         sparse)
     return ret
 
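The @torch._jit_internal.weak_script decorators and the type comments above are what make these functional forms visible to the script compiler. A hedged sketch of the resulting usage follows (the lookup function is illustrative and not part of the PR; exact behaviour depends on the PyTorch version this commit targets):

import torch
import torch.nn.functional as F


@torch.jit.script
def lookup(weight, indices):
    # type: (Tensor, Tensor) -> Tensor
    # The defaults (padding_idx=None, max_norm=None) rely on the Optional
    # annotations added in this diff; weak_script functions are compiled
    # lazily when referenced from scripted code.
    return F.embedding(indices, weight)


weight = torch.randn(10, 3)
indices = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print(lookup(weight, indices).shape)  # torch.Size([2, 4, 3])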
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index 2ee41035a0..cd3ea4eb21 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -4,8 +4,10 @@ from torch.nn.parameter import Parameter
 from .module import Module
 from .. import functional as F
 from .. import init
+from torch._jit_internal import weak_module, weak_script, weak_script_method
 
 
+@weak_module
 class Embedding(Module):
     r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
 
@@ -75,9 +77,11 @@ class Embedding(Module):
                  [ 0.0000,  0.0000,  0.0000],
                  [-0.1655,  0.9897,  0.0635]]])
     """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
+                     'norm_type', 'scale_grad_by_freq', 'sparse', '_weight']
 
     def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
-                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                  sparse=False, _weight=None):
         super(Embedding, self).__init__()
         self.num_embeddings = num_embeddings
@@ -107,6 +111,7 @@ class Embedding(Module):
             with torch.no_grad():
                 self.weight[self.padding_idx].fill_(0)
 
+    @weak_script_method
     def forward(self, input):
         return F.embedding(
             input, self.weight, self.padding_idx, self.max_norm,
@@ -161,6 +166,7 @@ class Embedding(Module):
         return embedding
 
 
+@weak_module
 class EmbeddingBag(Module):
     r"""Computes sums or means of 'bags' of embeddings, without instantiating the
     intermediate embeddings.
@@ -223,9 +229,11 @@ class EmbeddingBag(Module):
         tensor([[-0.8861, -5.4350, -0.0523],
                 [ 1.1306, -2.5798, -1.0044]])
     """
+    __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
+                     'scale_grad_by_freq', 'mode', 'sparse']
 
     def __init__(self, num_embeddings, embedding_dim,
-                 max_norm=None, norm_type=2, scale_grad_by_freq=False,
+                 max_norm=None, norm_type=2., scale_grad_by_freq=False,
                  mode='mean', sparse=False):
         super(EmbeddingBag, self).__init__()
         self.num_embeddings = num_embeddings
@@ -242,7 +250,9 @@ class EmbeddingBag(Module):
     def reset_parameters(self):
         init.normal_(self.weight)
 
+    @weak_script_method
     def forward(self, input, offsets=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         return F.embedding_bag(input, self.weight, offsets,
                                self.max_norm, self.norm_type,
                                self.scale_grad_by_freq, self.mode, self.sparse)
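A closing note on the __constants__ lists added in sparse.py: attributes that a weak_script_method reads (mode, sparse, padding_idx, and so on) must be declared as compile-time constants so the script compiler can bake them into the graph. The tiny Scale module below is a hypothetical illustration of that mechanism, not code from this commit:

import torch


class Scale(torch.jit.ScriptModule):
    # Without this declaration, self.factor could not be read inside the
    # script method below.
    __constants__ = ['factor']

    def __init__(self, factor):
        super(Scale, self).__init__()
        self.factor = factor

    @torch.jit.script_method
    def forward(self, x):
        return x * self.factor


print(Scale(3.0)(torch.ones(2)))  # tensor([3., 3.])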