author     David Riazati <davidriazati@fb.com>    2018-12-18 17:25:51 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2018-12-18 17:28:28 -0800
commit     f3bff2d50050bc4300c6c842163eabfc4a01f571 (patch)
tree       5d7a466683240f46826788a8b263b1cd08ab8089
parent     f3cc9b221829945fa614c25c1abdd8328f7e3125 (diff)
Add RNNCell modules to Script standard library (#14695)
Summary: Adds RNNCell modules to the Script standard library. cc apaszke for the argument_spec changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14695
Differential Revision: D13467680
Pulled By: driazati
fbshipit-source-id: 13a14da87714325cc4c3d49e5fde8a850d5d757b
-rw-r--r--  test/test_jit.py                    15
-rw-r--r--  torch/csrc/jit/argument_spec.h       5
-rw-r--r--  torch/csrc/jit/graph_executor.cpp    3
-rw-r--r--  torch/nn/modules/rnn.py             63
4 files changed, 64 insertions(+), 22 deletions(-)
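With this change, nn.RNNCell, nn.LSTMCell, and nn.GRUCell are marked as weak script modules, so they can be called from TorchScript code. A minimal usage sketch of what this enables (the wrapper module, sizes, and loop below are illustrative only and are not part of this patch):

import torch
import torch.nn as nn

class Loop(torch.jit.ScriptModule):
    def __init__(self):
        super(Loop, self).__init__()
        # nn.RNNCell is now a weak module, so assigning it to a ScriptModule
        # compiles its forward for the script interpreter.
        self.cell = nn.RNNCell(10, 20)

    @torch.jit.script_method
    def forward(self, x):
        # type: (Tensor) -> Tensor
        hx = torch.zeros(x.size(1), 20, dtype=x.dtype, device=x.device)
        for i in range(x.size(0)):
            hx = self.cell(x[i], hx)
        return hx

loop = Loop()
print(loop(torch.randn(5, 3, 10)).shape)  # torch.Size([3, 20])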
diff --git a/test/test_jit.py b/test/test_jit.py
index f8d946a256..760f9029b9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10787,6 +10787,21 @@ additional_module_tests = [
        input_size=(S, S),
        extra_args=((S, S),)
    ),
+    dict(
+        module_name='RNNCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='LSTMCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='GRUCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
]
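The new additional_module_tests entries above feed test_jit's generated module tests: each dict names an nn module, its constructor arguments, and the shape of the input to feed it. Roughly, each entry expands into a check along these lines (a hand-written sketch, not the actual test harness; S stands for the suite's symbolic size):

import torch
import torch.nn as nn

S = 5  # stand-in for the test suite's symbolic size

def run_entry(module_name, constructor_args, input_size):
    # Instantiate the module and run it on a random input of the given shape,
    # as the generated test does before comparing eager and scripted outputs.
    module = getattr(nn, module_name)(*constructor_args)
    return module(torch.randn(*input_size))

out = run_entry('RNNCell', (S, S), (S, S))
print(out.shape)  # torch.Size([5, 5])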
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index ec5b337fd4..ec3d988ed5 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -73,13 +73,13 @@ struct ArgumentSpec {
  }
  void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto & arg = args[offset];
+    auto & arg = args.at(offset);
    // Initialize all fields to 0. This is convenient, because e.g.
    // requires_grad() can be checked even on tensors AND will make
    // padding bits all 0s.
    std::memset(&arg, 0, sizeof(ArgumentInfo));
+
    if (input.isTensor()) {
-      JIT_ASSERT(offset < args.size());
      at::Tensor t = input.toTensor();
      if ((arg.defined_ = t.defined())) {
        arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
@@ -96,7 +96,6 @@ struct ArgumentSpec {
        addInput(elem, offset, with_grad);
      }
    } else {
-      JIT_ASSERT(offset < args.size());
      // NB: no need to set is_tensor to false, because we memset the struct to 0 above
      combineHash(arg);
      offset++;
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index cbdf89366f..9d4054651e 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -298,6 +298,9 @@ struct GraphExecutorImpl {
  }
  static size_t countFlatInputs(const TypePtr& ptr) {
+    if (auto optional_type = ptr->cast<OptionalType>()) {
+      return countFlatInputs(optional_type->getElementType());
+    }
    if (auto tuple_type = ptr->cast<TupleType>()) {
      size_t total = 0;
      for (auto & elem : tuple_type->elements()) {
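countFlatInputs sizes the ArgumentSpec by counting flattened inputs per type; since the cell forwards now take Optional hidden states (an Optional[Tuple[Tensor, Tensor]] in the LSTMCell case), the executor has to look through OptionalType so the underlying element type is counted. At the Python level, the kind of signature this supports looks roughly like the following (a sketch that mirrors the unwrap pattern used in rnn.py below; it is not code from this patch):

import torch

@torch.jit.script
def step(x, hx):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    # When no hidden state is given, fall back to zeros; otherwise unwrap
    # the Optional, mirroring how the cell forwards handle hx.
    if hx is None:
        h = torch.zeros_like(x)
    else:
        h = torch.jit._unwrap_optional(hx)
    return x + h

print(step(torch.ones(2, 3), None))
print(step(torch.ones(2, 3), torch.ones(2, 3)))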
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 4d5f2a1e1b..89d9b94c58 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -8,8 +8,9 @@ from .module import Module
from ..parameter import Parameter
from ..utils.rnn import PackedSequence
from .. import init
+from .. import _VF
+from ..._jit_internal import weak_module, weak_script_method
-_VF = torch._C._VariableFunctions
_rnn_impls = {
'LSTM': _VF.lstm,
'GRU': _VF.gru,
@@ -535,6 +536,7 @@ class GRU(RNNBase):
class RNNCellBase(Module):
+    __constants__ = ['input_size', 'hidden_size', 'bias']
    def __init__(self, input_size, hidden_size, bias, num_chunks):
        super(RNNCellBase, self).__init__()
@@ -559,13 +561,16 @@ class RNNCellBase(Module):
        s += ', nonlinearity={nonlinearity}'
        return s.format(**self.__dict__)
+    @weak_script_method
    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))
+    @weak_script_method
    def check_forward_hidden(self, input, hx, hidden_label=''):
+        # type: (Tensor, Tensor, str)
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -582,6 +587,7 @@ class RNNCellBase(Module):
            init.uniform_(weight, -stdv, stdv)
+@weak_module
class RNNCell(RNNCellBase):
r"""An Elman RNN cell with tanh or ReLU non-linearity.
@@ -630,31 +636,41 @@ class RNNCell(RNNCellBase):
hx = rnn(input[i], hx)
output.append(hx)
"""
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
        self.nonlinearity = nonlinearity
+    @weak_script_method
    def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
        self.check_forward_input(input)
        if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
        if self.nonlinearity == "tanh":
-            func = _VF.rnn_tanh_cell
+            ret = _VF.rnn_tanh_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
        elif self.nonlinearity == "relu":
-            func = _VF.rnn_relu_cell
+            ret = _VF.rnn_relu_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
        else:
+            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
-
-        return func(
-            input, hx,
-            self.weight_ih, self.weight_hh,
-            self.bias_ih, self.bias_hh,
-        )
+        return ret
+@weak_module
class LSTMCell(RNNCellBase):
r"""A long short-term memory (LSTM) cell.
@@ -719,20 +735,25 @@ class LSTMCell(RNNCellBase):
    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
+    @weak_script_method
    def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
        self.check_forward_input(input)
        if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-            hx = (hx, hx)
-        self.check_forward_hidden(input, hx[0], '[0]')
-        self.check_forward_hidden(input, hx[1], '[1]')
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            _hx = (zeros, zeros)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx[0], '[0]')
+        self.check_forward_hidden(input, _hx[1], '[1]')
        return _VF.lstm_cell(
-            input, hx,
+            input, _hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
+@weak_module
class GRUCell(RNNCellBase):
r"""A gated recurrent unit (GRU) cell
@@ -789,13 +810,17 @@ class GRUCell(RNNCellBase):
    def __init__(self, input_size, hidden_size, bias=True):
        super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
+    @weak_script_method
    def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
        self.check_forward_input(input)
        if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
        return _VF.gru_cell(
-            input, hx,
+            input, _hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
        )
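Eager-mode behavior is unchanged by the rewrite: hx can still be omitted, in which case the cells build zero hidden states with the input's dtype and device. A quick eager check (shapes chosen arbitrarily for illustration):

import torch
import torch.nn as nn

cell = nn.LSTMCell(10, 20)
x = torch.randn(3, 10)

# No hidden state given: forward creates a zero (h, c) pair internally.
h1, c1 = cell(x)

# Explicit hidden state: pass an (h, c) tuple.
h0 = torch.zeros(3, 20)
c0 = torch.zeros(3, 20)
h2, c2 = cell(x, (h0, c0))

print(h1.shape, c1.shape)  # torch.Size([3, 20]) torch.Size([3, 20])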