author    | David Riazati <davidriazati@fb.com> | 2018-12-18 17:25:51 -0800
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-18 17:28:28 -0800
commit    | f3bff2d50050bc4300c6c842163eabfc4a01f571 (patch)
tree      | 5d7a466683240f46826788a8b263b1cd08ab8089
parent    | f3cc9b221829945fa614c25c1abdd8328f7e3125 (diff)
download  | pytorch-f3bff2d50050bc4300c6c842163eabfc4a01f571.tar.gz / .tar.bz2 / .zip
Add RNNCell modules to Script standard library (#14695)
Summary:
Adds the RNNCell, LSTMCell, and GRUCell modules to the script standard library
cc apaszke for argument_spec changes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14695
Differential Revision: D13467680
Pulled By: driazati
fbshipit-source-id: 13a14da87714325cc4c3d49e5fde8a850d5d757b
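
For context, this is roughly what the change enables: `RNNCell`, `LSTMCell`, and `GRUCell` become weak modules, so they can be held on a `torch.jit.ScriptModule` and called from scripted code. A minimal sketch, assuming the 1.0-era `ScriptModule` / `@torch.jit.script_method` API; the wrapper class and sizes are made up for illustration and are not taken from the PR's tests:

```python
import torch
import torch.nn as nn


class CellRunner(torch.jit.ScriptModule):
    """Hypothetical wrapper that unrolls an nn.RNNCell from scripted code."""

    def __init__(self):
        super(CellRunner, self).__init__()
        # RNNCell is now part of the script standard library, so the
        # compiler can inline its (weak-script) forward method.
        self.cell = nn.RNNCell(10, 20)

    @torch.jit.script_method
    def forward(self, x, h):
        # type: (Tensor, Tensor) -> Tensor
        # x: (seq_len, batch, input_size), h: (batch, hidden_size)
        for i in range(x.size(0)):
            h = self.cell(x[i], h)
        return h


runner = CellRunner()
out = runner(torch.randn(5, 3, 10), torch.zeros(3, 20))  # -> shape (3, 20)
```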
-rw-r--r-- | test/test_jit.py                  | 15
-rw-r--r-- | torch/csrc/jit/argument_spec.h    |  5
-rw-r--r-- | torch/csrc/jit/graph_executor.cpp |  3
-rw-r--r-- | torch/nn/modules/rnn.py           | 63

4 files changed, 64 insertions, 22 deletions
```diff
diff --git a/test/test_jit.py b/test/test_jit.py
index f8d946a256..760f9029b9 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10787,6 +10787,21 @@ additional_module_tests = [
         input_size=(S, S),
         extra_args=((S, S),)
     ),
+    dict(
+        module_name='RNNCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='LSTMCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
+    dict(
+        module_name='GRUCell',
+        constructor_args=(S, S),
+        input_size=(S, S),
+    ),
 ]
diff --git a/torch/csrc/jit/argument_spec.h b/torch/csrc/jit/argument_spec.h
index ec5b337fd4..ec3d988ed5 100644
--- a/torch/csrc/jit/argument_spec.h
+++ b/torch/csrc/jit/argument_spec.h
@@ -73,13 +73,13 @@ struct ArgumentSpec {
   }
 
   void addInput(const IValue& input, size_t& offset, bool with_grad) {
-    auto & arg = args[offset];
+    auto & arg = args.at(offset);
     // Initialize all fields to 0. This is convenient, because e.g.
    // requires_grad() can be checked even on tensors AND will make
     // padding bits all 0s.
     std::memset(&arg, 0, sizeof(ArgumentInfo));
+
     if (input.isTensor()) {
-      JIT_ASSERT(offset < args.size());
       at::Tensor t = input.toTensor();
       if ((arg.defined_ = t.defined())) {
         arg.requires_grad_ = with_grad && autograd::Variable(t).requires_grad();
@@ -96,7 +96,6 @@ struct ArgumentSpec {
         addInput(elem, offset, with_grad);
       }
     } else {
-      JIT_ASSERT(offset < args.size());
       // NB: no need to set is_tensor to false, because we memset the struct to 0 above
       combineHash(arg);
       offset++;
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index cbdf89366f..9d4054651e 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -298,6 +298,9 @@ struct GraphExecutorImpl {
   }
 
   static size_t countFlatInputs(const TypePtr& ptr) {
+    if (auto optional_type = ptr->cast<OptionalType>()) {
+      return countFlatInputs(optional_type->getElementType());
+    }
     if (auto tuple_type = ptr->cast<TupleType>()) {
       size_t total = 0;
       for (auto & elem : tuple_type->elements()) {
diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py
index 4d5f2a1e1b..89d9b94c58 100644
--- a/torch/nn/modules/rnn.py
+++ b/torch/nn/modules/rnn.py
@@ -8,8 +8,9 @@ from .module import Module
 from ..parameter import Parameter
 from ..utils.rnn import PackedSequence
 from .. import init
+from .. import _VF
+from ..._jit_internal import weak_module, weak_script_method
 
-_VF = torch._C._VariableFunctions
 _rnn_impls = {
     'LSTM': _VF.lstm,
     'GRU': _VF.gru,
@@ -535,6 +536,7 @@ class GRU(RNNBase):
 
 
 class RNNCellBase(Module):
+    __constants__ = ['input_size', 'hidden_size', 'bias']
 
     def __init__(self, input_size, hidden_size, bias, num_chunks):
         super(RNNCellBase, self).__init__()
@@ -559,13 +561,16 @@ class RNNCellBase(Module):
             s += ', nonlinearity={nonlinearity}'
         return s.format(**self.__dict__)
 
+    @weak_script_method
     def check_forward_input(self, input):
         if input.size(1) != self.input_size:
             raise RuntimeError(
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))
 
+    @weak_script_method
     def check_forward_hidden(self, input, hx, hidden_label=''):
+        # type: (Tensor, Tensor, str)
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -582,6 +587,7 @@ class RNNCellBase(Module):
             init.uniform_(weight, -stdv, stdv)
 
 
+@weak_module
 class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
```
```diff
@@ -630,31 +636,41 @@ class RNNCell(RNNCellBase):
             hx = rnn(input[i], hx)
             output.append(hx)
     """
+    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']
 
     def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh"):
         super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1)
         self.nonlinearity = nonlinearity
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
         if self.nonlinearity == "tanh":
-            func = _VF.rnn_tanh_cell
+            ret = _VF.rnn_tanh_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
         elif self.nonlinearity == "relu":
-            func = _VF.rnn_relu_cell
+            ret = _VF.rnn_relu_cell(
+                input, _hx,
+                self.weight_ih, self.weight_hh,
+                self.bias_ih, self.bias_hh,
+            )
         else:
+            ret = input  # TODO: remove when jit supports exception flow
             raise RuntimeError(
                 "Unknown nonlinearity: {}".format(self.nonlinearity))
-
-        return func(
-            input, hx,
-            self.weight_ih, self.weight_hh,
-            self.bias_ih, self.bias_hh,
-        )
+        return ret
 
 
+@weak_module
 class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.
@@ -719,20 +735,25 @@ class LSTMCell(RNNCellBase):
 
     def __init__(self, input_size, hidden_size, bias=True):
         super(LSTMCell, self).__init__(input_size, hidden_size, bias, num_chunks=4)
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-            hx = (hx, hx)
-        self.check_forward_hidden(input, hx[0], '[0]')
-        self.check_forward_hidden(input, hx[1], '[1]')
+            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+            _hx = (zeros, zeros)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx[0], '[0]')
+        self.check_forward_hidden(input, _hx[1], '[1]')
         return _VF.lstm_cell(
-            input, hx,
+            input, _hx,
             self.weight_ih, self.weight_hh,
             self.bias_ih, self.bias_hh,
         )
 
 
+@weak_module
 class GRUCell(RNNCellBase):
     r"""A gated recurrent unit (GRU) cell
@@ -789,13 +810,17 @@ class GRUCell(RNNCellBase):
     def __init__(self, input_size, hidden_size, bias=True):
         super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3)
 
+    @weak_script_method
     def forward(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         self.check_forward_input(input)
         if hx is None:
-            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
-        self.check_forward_hidden(input, hx)
+            _hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
+        else:
+            _hx = torch.jit._unwrap_optional(hx)
+        self.check_forward_hidden(input, _hx, '')
         return _VF.gru_cell(
-            input, hx,
+            input, _hx,
             self.weight_ih, self.weight_hh,
             self.bias_ih, self.bias_hh,
         )
```
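
A short usage note on the rewritten `forward` methods above: `hx` is now annotated as `Optional[...]`, and when it is omitted each cell builds a zero hidden state with the input's dtype and device instead of calling `input.new_zeros`. The `ret = input` assignment before the `raise` in `RNNCell.forward` exists only because, per the diff's TODO, the script compiler does not yet understand that `raise` ends the branch, so `ret` must be defined on every path. A small eager-mode sketch of the resulting call patterns (the sizes here are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(3, 10)        # (batch, input_size)

gru = nn.GRUCell(10, 20)
h = gru(x)                    # hx=None -> zeros(3, 20) built internally
h = gru(x, h)                 # explicit hidden state

lstm = nn.LSTMCell(10, 20)
h, c = lstm(x)                # hx=None -> (zeros, zeros)
h, c = lstm(x, (h, c))
```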