summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2018-04-16 23:52:59 -0400
committerEdward Z. Yang <ezyang@mit.edu>2018-04-16 23:52:59 -0400
commit5ed3f3347a5684a2b6208546e34c9a13771e77ab (patch)
treed3db48be704db088d4b33a026d3b98c464127b58 /torch
parentdd91d57c3f5647fb4ac63ac4325a42224f9a3028 (diff)
downloadpytorch-5ed3f3347a5684a2b6208546e34c9a13771e77ab.tar.gz
pytorch-5ed3f3347a5684a2b6208546e34c9a13771e77ab.tar.bz2
pytorch-5ed3f3347a5684a2b6208546e34c9a13771e77ab.zip
Add dtypes (with reasonable defaults) to sum, prod, cumsum, cumprod. (#6573)
* Add dtypes (with reasonable defaults) to sum, prod, cumsum, cumprod. This adds optional dtypes to torch.sum, torch.prod, torch.cumsum, torch.cumprod. By default, the dtype is torch.float64 for integral types, and the dtype of the input for floating point types. * Don't use optional<ScalarType>, because the jit can't handle it yet. Instead, we manually build the overloads. This is fairly painful because of default arguments, but should be easy to pull out once the jit can handle optional<ScalarType>. * Fix keepdim with out parameters. * Fix _cudnn_rnn_flatten_weight. * If dtype is provided to an out function, make sure it matches the dtype of the result. * Fix typo.
Diffstat (limited to 'torch')
-rw-r--r--torch/csrc/utils/python_arg_parser.cpp1
-rw-r--r--torch/csrc/utils/python_arg_parser.h6
2 files changed, 7 insertions, 0 deletions
diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp
index b1ffea24f9..1291294aec 100644
--- a/torch/csrc/utils/python_arg_parser.cpp
+++ b/torch/csrc/utils/python_arg_parser.cpp
@@ -24,6 +24,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"Storage", ParameterType::STORAGE},
{"PyObject*", ParameterType::PYOBJECT},
{"ScalarType", ParameterType::SCALARTYPE},
+ {"optional<ScalarType>", ParameterType::SCALARTYPE},
{"Layout", ParameterType::LAYOUT},
{"Device", ParameterType::DEVICE},
{"String", ParameterType::STRING},
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 032ea187c9..9cd055b978 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -95,6 +95,7 @@ struct PythonArgs {
inline std::unique_ptr<at::Storage> storage(int i);
inline at::ScalarType scalartype(int i);
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
+ inline at::optional<at::ScalarType> scalartypeOptional(int i);
inline const THPLayout& layout(int i);
inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout);
inline Device device(int i);
@@ -272,6 +273,11 @@ inline at::ScalarType PythonArgs::scalartype(int i) {
return reinterpret_cast<THPDtype*>(args[i])->scalar_type;
}
+inline at::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
+ if (!args[i]) return at::nullopt;
+ return scalartype(i);
+}
+
inline const THPLayout& PythonArgs::layout(int i) {
if (!args[i]) return *signature.params[i].default_layout;
return *reinterpret_cast<THPLayout*>(args[i]);