-rw-r--r--  test/test_torch.py                                          | 10
-rw-r--r--  tools/autograd/templates/python_torch_functions.cpp         | 40
-rw-r--r--  tools/autograd/templates/python_torch_functions_dispatch.h  | 32

3 files changed, 61 insertions, 21 deletions
diff --git a/test/test_torch.py b/test/test_torch.py
index c099208bed..e97d5cceb6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1161,18 +1161,28 @@ class TestTorch(TestCase):
             res2[i] = max(min_val, min(max_val, res2[i]))
         self.assertEqual(res1, res2)
 
+        out = m1.clone()
+        torch.clamp(m1, min=min_val, max=max_val, out=out)
+        self.assertEqual(out, res1)
+
         res1 = torch.clamp(m1, min=min_val)
         res2 = m1.clone()
         for i in iter_indices(res2):
             res2[i] = max(min_val, res2[i])
         self.assertEqual(res1, res2)
 
+        torch.clamp(m1, min=min_val, out=out)
+        self.assertEqual(out, res1)
+
         res1 = torch.clamp(m1, max=max_val)
         res2 = m1.clone()
         for i in iter_indices(res2):
             res2[i] = min(max_val, res2[i])
         self.assertEqual(res1, res2)
 
+        torch.clamp(m1, max=max_val, out=out)
+        self.assertEqual(out, res1)
+
     def test_pow(self):
         # [res] torch.pow([res,] x)
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 2dc9e0ee26..e2009bac1f 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -43,40 +43,38 @@ static void check_out_type_matches(Tensor result, const THPDtype &dtype, const T
   }
 }
 
-static Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
-  AutoNoGIL no_gil;
-  AutoGPU auto_gpu(self);
-  return self.clamp(min, max);
-}
-static Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
-  AutoNoGIL no_gil;
-  AutoGPU auto_gpu(self);
-  return self.clamp_min(min);
-}
-static Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
-  AutoNoGIL no_gil;
-  AutoGPU auto_gpu(self);
-  return self.clamp_max(max);
-}
-
 // The Python clamp() syntax has to be mapped to one of three C++ functions
 static PyObject * THPVariable_clamp(PyObject* module, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
   static PythonArgParser parser({
-    "clamp(Tensor input, Scalar min=None, Scalar max=None)",
+    "clamp(Tensor input, Scalar min=None, Scalar max=None, *, Tensor out=None)",
   });
-  ParsedArgs<3> parsed_args;
+
+  ParsedArgs<4> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   if (!r.isNone(1) && !r.isNone(2)) {
-    return THPVariable_Wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2)));
+    if (!r.isNone(3)) {
+      return wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2), r.tensor(3)));
+    } else {
+      return wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2)));
+    }
   } else if (!r.isNone(1)) {
-    return THPVariable_Wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1)));
+    if (!r.isNone(3)) {
+      return wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1), r.tensor(3)));
+    } else {
+      return wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1)));
+    }
   } else if (!r.isNone(2)) {
-    return THPVariable_Wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2)));
+    if (!r.isNone(3)) {
+      return wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2), r.tensor(3)));
+    } else {
+      return wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2)));
+    }
   } else {
     throw std::runtime_error("At least one of 'min' or 'max' must not be None");
   }
+  Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }
diff --git a/tools/autograd/templates/python_torch_functions_dispatch.h b/tools/autograd/templates/python_torch_functions_dispatch.h
index dfae2585cf..45c4f79868 100644
--- a/tools/autograd/templates/python_torch_functions_dispatch.h
+++ b/tools/autograd/templates/python_torch_functions_dispatch.h
@@ -32,6 +32,38 @@ static void maybe_initialize_cuda(const at::Type &type) {
   }
 }
 
+// manual dispatch code for clamp
+inline Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp(min, max);
+}
+inline Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_min(min);
+}
+inline Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(self);
+  return self.clamp_max(max);
+}
+inline Tensor & dispatch_clamp(const Tensor & self, Scalar min, Scalar max, Tensor result) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(result);
+  return at::clamp_out(result, self, min, max);
+}
+inline Tensor & dispatch_clamp_min(const Tensor & self, Scalar min, Tensor result) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(result);
+  return at::clamp_min_out(result, self, min);
+}
+inline Tensor & dispatch_clamp_max(const Tensor & self, Scalar max, Tensor result) {
+  AutoNoGIL no_gil;
+  AutoGPU auto_gpu(result);
+  return at::clamp_max_out(result, self, max);
+}
+
 ${py_method_dispatch}
 
 }} // namespace torch::autograd
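A brief usage sketch (mine, not part of the patch, and assuming a build that includes it): with the out= keyword wired up, each of the three clamp forms can write into a preallocated tensor, which is exactly what the new test assertions exercise.

    import torch

    x = torch.randn(10)
    out = x.clone()  # preallocated destination, mirroring the test above

    torch.clamp(x, min=-0.5, max=0.5, out=out)  # both bounds -> clamp_out
    torch.clamp(x, min=-0.5, out=out)           # min only    -> clamp_min_out
    torch.clamp(x, max=0.5, out=out)            # max only    -> clamp_max_out

    # with neither bound, the binding raises:
    # "At least one of 'min' or 'max' must not be None"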
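One design point visible in the new dispatch overloads: the out= variants construct AutoGPU from result rather than self, so the device guard follows the destination tensor that clamp_out/clamp_min_out/clamp_max_out write into, while AutoNoGIL releases the Python GIL around the underlying ATen call in both the returning and the out= variants.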