diff options
author | gchanan <gregchanan@gmail.com> | 2018-04-16 23:52:59 -0400 |
---|---|---|
committer | Edward Z. Yang <ezyang@mit.edu> | 2018-04-16 23:52:59 -0400 |
commit | 5ed3f3347a5684a2b6208546e34c9a13771e77ab (patch) | |
tree | d3db48be704db088d4b33a026d3b98c464127b58 /torch | |
parent | dd91d57c3f5647fb4ac63ac4325a42224f9a3028 (diff) | |
download | pytorch-5ed3f3347a5684a2b6208546e34c9a13771e77ab.tar.gz pytorch-5ed3f3347a5684a2b6208546e34c9a13771e77ab.tar.bz2 pytorch-5ed3f3347a5684a2b6208546e34c9a13771e77ab.zip |
Add dtypes (with reasonable defaults) to sum, prod, cumsum, cumprod. (#6573)
* Add dtypes (with reasonable defaults) to sum, prod, cumsum, cumprod.
This adds optional dtypes to torch.sum, torch.prod, torch.cumsum, torch.cumprod.
By default, the dtype is torch.float64 for integral types, and the dtype of the input for floating point types.
* Don't use optional<ScalarType>, because the jit can't handle it yet.
Instead, we manually build the overloads. This is fairly painful because of default arguments, but should be easy to pull out once the jit can handle optional<ScalarType>.
* Fix keepdim with out parameters.
* Fix _cudnn_rnn_flatten_weight.
* If dtype is provided to an out function, make sure it matches the dtype of the result.
* Fix typo.
Diffstat (limited to 'torch')
-rw-r--r-- | torch/csrc/utils/python_arg_parser.cpp | 1 | ||||
-rw-r--r-- | torch/csrc/utils/python_arg_parser.h | 6 |
2 files changed, 7 insertions, 0 deletions
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index b1ffea24f9..1291294aec 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -24,6 +24,7 @@ static std::unordered_map<std::string, ParameterType> type_map = { {"Storage", ParameterType::STORAGE}, {"PyObject*", ParameterType::PYOBJECT}, {"ScalarType", ParameterType::SCALARTYPE}, + {"optional<ScalarType>", ParameterType::SCALARTYPE}, {"Layout", ParameterType::LAYOUT}, {"Device", ParameterType::DEVICE}, {"String", ParameterType::STRING}, diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 032ea187c9..9cd055b978 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -95,6 +95,7 @@ struct PythonArgs { inline std::unique_ptr<at::Storage> storage(int i); inline at::ScalarType scalartype(int i); inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype); + inline at::optional<at::ScalarType> scalartypeOptional(int i); inline const THPLayout& layout(int i); inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout); inline Device device(int i); @@ -272,6 +273,11 @@ inline at::ScalarType PythonArgs::scalartype(int i) { return reinterpret_cast<THPDtype*>(args[i])->scalar_type; } +inline at::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) { + if (!args[i]) return at::nullopt; + return scalartype(i); +} + inline const THPLayout& PythonArgs::layout(int i) { if (!args[i]) return *signature.params[i].default_layout; return *reinterpret_cast<THPLayout*>(args[i]); |