author    gchanan <gregchanan@gmail.com>  2018-04-25 22:46:42 -0400
committer GitHub <noreply@github.com>  2018-04-25 22:46:42 -0400
commit    3d907ef78e3040422ed079b11eab9dc20f7b12f5 (patch)
tree      12f12b94c22461d0dfb1aae37cf38cddb6358d31 /tools
parent    c10da636b5023931d126eab8382f645e21bef026 (diff)
Consistently check 'out' variants against specified dtype/layout/device parameters. (#6973)
We previously did this check in the most common cases, but not consistently.
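
To illustrate the user-visible effect (a hedged sketch in Python against the torch API of this era; torch.zeros stands in for any out-variant that also accepts dtype/layout/device):

import torch

out = torch.empty(3)  # a CPU float32 tensor

# A dtype mismatch on out= was already caught before this change:
try:
    torch.zeros(3, out=out, dtype=torch.float64)
except RuntimeError as e:
    print(e)  # "type ... does not match type of out parameter ..."

# A device (or layout) mismatch alone, however, was previously skipped,
# because the generated check only ran when dtype was passed. After this
# change it raises as well (assuming a CUDA build):
try:
    torch.zeros(3, out=out, device='cuda')
except RuntimeError as e:
    print(e)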
Diffstat (limited to 'tools')
-rw-r--r--  tools/autograd/gen_python_functions.py              |  7
-rw-r--r--  tools/autograd/templates/python_torch_functions.cpp | 14
2 files changed, 13 insertions(+), 8 deletions(-)
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 8a75bdd299..d14b136e34 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -73,10 +73,9 @@ PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
${call_dispatch}
} else {
- if (!r.isNone(${type_idx})) {
- check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.layout(${layout_idx}),
- r.device(${device_idx}), r.isNone(${device_idx}));
- }
+ check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
+ r.layout(${layout_idx}), r.isNone(${layout_idx}),
+ r.device(${device_idx}), r.isNone(${device_idx}));
${call_dispatch_out}
}
""")
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 35fb89be20..0731f9d3aa 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -34,11 +34,17 @@ static Tensor set_requires_grad(Tensor self, bool requires_grad) {
return self;
}
-static void check_out_type_matches(Tensor result, ScalarType scalarType, const THPLayout& layout,
+static void check_out_type_matches(Tensor result,
+ ScalarType scalarType, bool scalarType_is_none,
+ const THPLayout& layout, bool layout_is_none,
const Device& device, bool device_is_none) {
- auto result_device_type = torch::getDeviceType(result.type());
- auto device_type = device_is_none ? result_device_type : device.type;
- const auto& type = torch::getType(scalarType, layout, device_type);
+ if (scalarType_is_none && layout_is_none && device_is_none) { // common case
+ return;
+ }
+ auto scalarType_arg = scalarType_is_none ? result.type().scalarType() : scalarType;
+ auto layout_arg = layout_is_none ? *torch::getLayout(result.type().backend()) : layout;
+ auto device_type_arg = device_is_none ? torch::getDeviceType(result.type()) : device.type;
+ const auto& type = torch::getType(scalarType_arg, layout_arg, device_type_arg);
if (result.type() != type) {
AT_ERROR(
"type corresponding to %s does not match type of out parameter (%s)",