summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--test/test_torch.py10
-rw-r--r--tools/autograd/templates/python_torch_functions.cpp40
-rw-r--r--tools/autograd/templates/python_torch_functions_dispatch.h32
3 files changed, 61 insertions, 21 deletions
diff --git a/test/test_torch.py b/test/test_torch.py
index c099208bed..e97d5cceb6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1161,18 +1161,28 @@ class TestTorch(TestCase):
res2[i] = max(min_val, min(max_val, res2[i]))
self.assertEqual(res1, res2)
+ out = m1.clone()
+ torch.clamp(m1, min=min_val, max=max_val, out=out)
+ self.assertEqual(out, res1)
+
res1 = torch.clamp(m1, min=min_val)
res2 = m1.clone()
for i in iter_indices(res2):
res2[i] = max(min_val, res2[i])
self.assertEqual(res1, res2)
+ torch.clamp(m1, min=min_val, out=out)
+ self.assertEqual(out, res1)
+
res1 = torch.clamp(m1, max=max_val)
res2 = m1.clone()
for i in iter_indices(res2):
res2[i] = min(max_val, res2[i])
self.assertEqual(res1, res2)
+ torch.clamp(m1, max=max_val, out=out)
+ self.assertEqual(out, res1)
+
def test_pow(self):
# [res] torch.pow([res,] x)
diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp
index 2dc9e0ee26..e2009bac1f 100644
--- a/tools/autograd/templates/python_torch_functions.cpp
+++ b/tools/autograd/templates/python_torch_functions.cpp
@@ -43,40 +43,38 @@ static void check_out_type_matches(Tensor result, const THPDtype &dtype, const T
}
}
-static Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
- AutoNoGIL no_gil;
- AutoGPU auto_gpu(self);
- return self.clamp(min, max);
-}
-static Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
- AutoNoGIL no_gil;
- AutoGPU auto_gpu(self);
- return self.clamp_min(min);
-}
-static Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
- AutoNoGIL no_gil;
- AutoGPU auto_gpu(self);
- return self.clamp_max(max);
-}
-
// The Python clamp() syntax has to be mapped to one of three C++ functions
static PyObject * THPVariable_clamp(PyObject* module, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
- "clamp(Tensor input, Scalar min=None, Scalar max=None)",
+ "clamp(Tensor input, Scalar min=None, Scalar max=None, *, Tensor out=None)",
});
- ParsedArgs<3> parsed_args;
+
+ ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (!r.isNone(1) && !r.isNone(2)) {
- return THPVariable_Wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2)));
+ if (!r.isNone(3)) {
+ return wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2), r.tensor(3)));
+ } else {
+ return wrap(dispatch_clamp(r.tensor(0), r.scalar(1), r.scalar(2)));
+ }
} else if (!r.isNone(1)) {
- return THPVariable_Wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1)));
+ if (!r.isNone(3)) {
+ return wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1), r.tensor(3)));
+ } else {
+ return wrap(dispatch_clamp_min(r.tensor(0), r.scalar(1)));
+ }
} else if (!r.isNone(2)) {
- return THPVariable_Wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2)));
+ if (!r.isNone(3)) {
+ return wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2), r.tensor(3)));
+ } else {
+ return wrap(dispatch_clamp_max(r.tensor(0), r.scalar(2)));
+ }
} else {
throw std::runtime_error("At least one of 'min' or 'max' must not be None");
}
+ Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
diff --git a/tools/autograd/templates/python_torch_functions_dispatch.h b/tools/autograd/templates/python_torch_functions_dispatch.h
index dfae2585cf..45c4f79868 100644
--- a/tools/autograd/templates/python_torch_functions_dispatch.h
+++ b/tools/autograd/templates/python_torch_functions_dispatch.h
@@ -32,6 +32,38 @@ static void maybe_initialize_cuda(const at::Type &type) {
}
}
+// manual dispatch code for clamp
+inline Tensor dispatch_clamp(const Tensor & self, Scalar min, Scalar max) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(self);
+ return self.clamp(min, max);
+}
+inline Tensor dispatch_clamp_min(const Tensor & self, Scalar min) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(self);
+ return self.clamp_min(min);
+}
+inline Tensor dispatch_clamp_max(const Tensor & self, Scalar max) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(self);
+ return self.clamp_max(max);
+}
+inline Tensor & dispatch_clamp(const Tensor & self, Scalar min, Scalar max, Tensor result) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(result);
+ return at::clamp_out(result, self, min, max);
+}
+inline Tensor & dispatch_clamp_min(const Tensor & self, Scalar min, Tensor result) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(result);
+ return at::clamp_min_out(result, self, min);
+}
+inline Tensor & dispatch_clamp_max(const Tensor & self, Scalar max, Tensor result) {
+ AutoNoGIL no_gil;
+ AutoGPU auto_gpu(result);
+ return at::clamp_max_out(result, self, max);
+}
+
${py_method_dispatch}
}} // namespace torch::autograd