diff options
author | gchanan <gregchanan@gmail.com> | 2018-04-25 22:46:42 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-25 22:46:42 -0400 |
commit | 3d907ef78e3040422ed079b11eab9dc20f7b12f5 (patch) | |
tree | 12f12b94c22461d0dfb1aae37cf38cddb6358d31 /tools | |
parent | c10da636b5023931d126eab8382f645e21bef026 (diff) | |
download | pytorch-3d907ef78e3040422ed079b11eab9dc20f7b12f5.tar.gz pytorch-3d907ef78e3040422ed079b11eab9dc20f7b12f5.tar.bz2 pytorch-3d907ef78e3040422ed079b11eab9dc20f7b12f5.zip |
Consistently check 'out' variants against specified dtype/layout/device parameters. (#6973)
We were previously doing this in the most common cases, but not consistently.
Diffstat (limited to 'tools')
-rw-r--r-- | tools/autograd/gen_python_functions.py | 7 | ||||
-rw-r--r-- | tools/autograd/templates/python_torch_functions.cpp | 14 |
2 files changed, 13 insertions, 8 deletions
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 8a75bdd299..d14b136e34 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -73,10 +73,9 @@ PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\ if (r.isNone(${out_idx})) { ${call_dispatch} } else { - if (!r.isNone(${type_idx})) { - check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.layout(${layout_idx}), - r.device(${device_idx}), r.isNone(${device_idx})); - } + check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}), + r.layout(${layout_idx}), r.isNone(${layout_idx}), + r.device(${device_idx}), r.isNone(${device_idx})); ${call_dispatch_out} } """) diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 35fb89be20..0731f9d3aa 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -34,11 +34,17 @@ static Tensor set_requires_grad(Tensor self, bool requires_grad) { return self; } -static void check_out_type_matches(Tensor result, ScalarType scalarType, const THPLayout& layout, +static void check_out_type_matches(Tensor result, + ScalarType scalarType, bool scalarType_is_none, + const THPLayout& layout, bool layout_is_none, const Device& device, bool device_is_none) { - auto result_device_type = torch::getDeviceType(result.type()); - auto device_type = device_is_none ? result_device_type : device.type; - const auto& type = torch::getType(scalarType, layout, device_type); + if (scalarType_is_none && layout_is_none && device_is_none) { // common case + return; + } + auto scalarType_arg = scalarType_is_none ? result.type().scalarType() : scalarType; + auto layout_arg = layout_is_none ? *torch::getLayout(result.type().backend()) : layout; + auto device_type_arg = device_is_none ? torch::getDeviceType(result.type()) : device.type; + const auto& type = torch::getType(scalarType_arg, layout_arg, device_type_arg); if (result.type() != type) { AT_ERROR( "type corresponding to %s does not match type of out parameter (%s)", |