-rw-r--r--  CONTRIBUTING.md                               |  2
-rw-r--r--  aten/src/ATen/Declarations.cwrap              |  1
-rw-r--r--  aten/src/ATen/native/NativeFunctions.cpp      | 46
-rw-r--r--  aten/src/ATen/native/native_functions.yaml    |  4
-rw-r--r--  aten/src/THC/generic/THCTensorMathMagma.cu    | 50
-rw-r--r--  aten/src/THC/generic/THCTensorMathMagma.h     |  2
-rw-r--r--  docs/source/torch.rst                         |  1
-rw-r--r--  test/test_autograd.py                         | 60
-rw-r--r--  test/test_cuda.py                             |  7
-rw-r--r--  test/test_torch.py                            | 91
-rw-r--r--  tools/autograd/derivatives.yaml               | 19
-rw-r--r--  tools/autograd/gen_python_functions.py        |  4
-rw-r--r--  tools/autograd/templates/Functions.cpp        | 94
-rw-r--r--  tools/autograd/templates/VariableType.cpp     |  9
-rw-r--r--  tools/autograd/templates/VariableType.h       |  1
-rw-r--r--  tools/jit/templates/aten_dispatch.cpp         |  6
-rw-r--r--  torch/_torch_docs.py                          | 17
-rw-r--r--  torch/csrc/autograd/utils/wrap_outputs.h      | 10
-rw-r--r--  torch/csrc/generic/methods/TensorMath.cwrap   |  1
-rw-r--r--  torch/functional.py                           | 19
20 files changed, 424 insertions, 20 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a79fb97f9f..3ad384c20c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -207,7 +207,7 @@ If you are working on the CUDA code, here are some useful CUDA debugging tips:
1. `CUDA_DEBUG=1` will enable CUDA debugging symbols (-g -G). This is particularly
helpful in debugging device code. However, it will slow down the build process,
so use wisely.
-2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debuging friends. Unlike`gdb`,
+2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debugging friends. Unlike `gdb`,
`cuda-gdb` can display actual values in a CUDA tensor (rather than all zeros).
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index dae2fce805..53fee637ca 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -3504,6 +3504,7 @@
- Double
backends:
- CPU
+ - CUDA
variants:
- method
- function
diff --git a/aten/src/ATen/native/NativeFunctions.cpp b/aten/src/ATen/native/NativeFunctions.cpp
index 5ddd7a1b85..70929cd89e 100644
--- a/aten/src/ATen/native/NativeFunctions.cpp
+++ b/aten/src/ATen/native/NativeFunctions.cpp
@@ -270,6 +270,52 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
return self.as_strided_(std::get<0>(g), std::get<1>(g));
}
+// For backward, we save the SVD results.
+// http://www.ics.forth.gr/cvrl/publications/conferences/2000_eccv_SVD_jacobian.pdf
+// But instead of the gesvd SVD A = U(A) Sig(A) V(A)^T, which doesn't specify the
+// signs of the determinants of U and V, we consider det(A) = \prod Sig_(A), where
+// 1. A = U_(A) Sig_(A) V(A)^T
+// 2. Sig_(A) and U_(A) may differ in sign in their first entry/column from
+// their counterparts so that det(U_(A)) * det(V(A)) = +1
+std::tuple<Tensor, Tensor, Tensor, Tensor> _det_with_svd(const Tensor& self) {
+ if (!at::isFloatingType(self.type().scalarType()) ||
+ self.dim() != 2 || self.size(0) != self.size(1)) {
+ std::ostringstream ss;
+ ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D"
+ << "square tensor of floating types";
+ throw std::runtime_error(ss.str());
+ }
+ // check symmetric
+ bool symmetric = self.equal(self.transpose(0, 1));
+
+ auto svd = self.svd(true);
+ auto sigma = std::get<1>(svd);
+ auto u = std::get<0>(svd);
+ auto v = std::get<2>(svd);
+ auto det = sigma.prod();
+ if (!symmetric) {
+ auto qr = self.geqrf();
+ auto a = std::get<0>(qr);
+ auto tau = std::get<1>(qr);
+ // non-zero values in tau represent Householder reflectors, each of which has determinant -1
+ int64_t num_reflectors = tau.nonzero().size(0);
+ auto qr_det = a.diag().prod();
+ if (num_reflectors % 2 == 1) {
+ qr_det = -qr_det;
+ }
+ if ((qr_det < 0).any() ^ (det < 0).any()) { // signs differ: flip the first column of U and the first singular value
+ u.narrow(1, 0, 1).mul_(-1);
+ sigma.narrow(0, 0, 1).mul_(-1);
+ }
+ det = qr_det; // QR is more stable than SVD, so use its determinant anyway
+ }
+ return std::make_tuple(det, u, sigma, v);
+}
+
+Tensor det(const Tensor& self) {
+ return std::get<0>(self._det_with_svd());
+}
+
Tensor stack(TensorList tensors, int64_t dim) {
if (tensors.size() == 0) {
throw std::runtime_error("stack expects a non-empty TensorList");
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9e5a04ae29..06b6af9b09 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -76,6 +76,10 @@
- func: unsqueeze_(Tensor self, int64_t dim) -> Tensor
+- func: _det_with_svd(Tensor self) -> (Tensor, Tensor, Tensor, Tensor)
+
+- func: det(Tensor self) -> Tensor
+
- func: stack(TensorList tensors, int64_t dim=0) -> Tensor
variants: function
diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu
index f746d878b7..1eb39857de 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.cu
+++ b/aten/src/THC/generic/THCTensorMathMagma.cu
@@ -584,6 +584,51 @@ THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, TH
#endif
}
+THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_)
+{
+#ifdef USE_MAGMA
+ THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional");
+
+ THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
+ int64_t m = a->size[0];
+ int64_t n = a->size[1];
+ int64_t k = (m < n ? m : n);
+
+#ifdef MAGMA_V2
+#if defined(THC_REAL_IS_FLOAT)
+ int64_t nb = magma_get_sgeqrf_nb(m, n);
+#else
+ int64_t nb = magma_get_dgeqrf_nb(m, n);
+#endif
+#else
+#if defined(THC_REAL_IS_FLOAT)
+ int64_t nb = magma_get_sgeqrf_nb(m);
+#else
+ int64_t nb = magma_get_dgeqrf_nb(m);
+#endif
+#endif
+
+ real *rtau_data = th_magma_malloc_pinned<real>(k);
+ real *a_data = THCTensor_(data)(state, a);
+
+ int info;
+#if defined(THC_REAL_IS_FLOAT)
+ magma_sgeqrf2_gpu(m, n, a_data, m, rtau_data, &info);
+#else
+ magma_dgeqrf2_gpu(m, n, a_data, m, rtau_data, &info);
+#endif
+
+ if (info != 0)
+ THError("MAGMA geqrf2 : Argument %d : illegal value.", -info);
+
+ THCTensor_(freeCopyTo)(state, a, ra_);
+ THCTensor_(copyArray1d)(state, rtau_, rtau_data, k);
+ magma_free_pinned(rtau_data);
+#else
+ THError(NoMagma(geqrf));
+#endif
+}
+
THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a_)
{
#ifdef USE_MAGMA
@@ -614,6 +659,11 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
real *work_data = THCTensor_(data)(state, work);
int info;
+ // We need to call two different versions of ?geqrf:
+ // ?geqrf_gpu allows fast computation of Q via ?orgqr_gpu, but doesn't give
+ // R properly. Note that the MAGMA documentation for this method is wrong.
+ // http://icl.cs.utk.edu/magma/forum/viewtopic.php?f=2&t=1015&p=2800&hilit=geqrf_gpu#p2800
+ // ?geqrf2_gpu gives correct R, but doesn't allow computation of Q via ?orgqr_gpu
#if defined(THC_REAL_IS_FLOAT)
magma_sgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
#else
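For reference, a hedged Python sketch (not part of this commit; assumes a recent PyTorch) of what the packed geqrf output encodes: R sits in the upper triangle of `a`, and Q is the product of the Householder reflectors described by `(a, tau)`:

# Sketch only: reconstruct Q and R from geqrf's packed output.
import torch

A = torch.randn(5, 5, dtype=torch.float64)
a, tau = torch.geqrf(A)

R = a.triu()                                  # upper-triangular factor
Q = torch.linalg.householder_product(a, tau)  # orthogonal factor built from the reflectors

assert torch.allclose(Q @ R, A)
assert torch.allclose(Q.T @ Q, torch.eye(5, dtype=torch.float64))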
diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h
index 938daea568..2aee308e0c 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.h
+++ b/aten/src/THC/generic/THCTensorMathMagma.h
@@ -15,9 +15,9 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
+THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_);
THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);
-
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
#endif
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index b67efbe34f..90d324230a 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -188,6 +188,7 @@ BLAS and LAPACK Operations
.. autofunction:: ger
.. autofunction:: gesv
.. autofunction:: inverse
+.. autofunction:: det
.. autofunction:: matmul
.. autofunction:: mm
.. autofunction:: mv
diff --git a/test/test_autograd.py b/test/test_autograd.py
index 5b2c4e6fb1..a8fe4dd911 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1898,6 +1898,30 @@ def _make_cov(S):
return torch.mm(L, L.t())
+def random_square_matrix_of_rank(l, rank):
+ assert rank <= l
+ A = torch.randn(l, l)
+ u, s, v = A.svd()
+ for i in range(l):
+ if i >= rank:
+ s[i] = 0
+ elif s[i] == 0:
+ s[i] = 1
+ return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
+def random_symmetric_matrix(l):
+ A = torch.randn(l, l)
+ return A.mm(A.transpose(0, 1))
+
+
+def random_fullrank_matrix_distinct_singular_value(l):
+ A = torch.randn(l, l)
+ u, _, v = A.svd()
+ s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
+ return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
+
+
class dont_convert(tuple):
pass
@@ -1906,7 +1930,6 @@ L = 20
M = 10
S = 5
-
# (name, size, args...)
method_tests = [
('add', (S, S, S), ((S, S, S),)),
@@ -2166,6 +2189,13 @@ method_tests = [
('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
('inverse', (S, S), (), '', (), [skipIfNoLapack]),
+ ('det', (S, S), (), '', (), [skipIfNoLapack]),
+ ('det', lambda: random_symmetric_matrix(S), (), 'symmetric', (), [skipIfNoLapack]),
+ ('det', lambda: random_square_matrix_of_rank(S, S - 2), (), 'dim2_null', (), [skipIfNoLapack]),
+ ('det', lambda: random_square_matrix_of_rank(S, 1), (), 'rank1', (), [skipIfNoLapack]),
+ ('det', lambda: random_square_matrix_of_rank(S, 2), (), 'rank2', (), [skipIfNoLapack]),
+ ('det', lambda: random_fullrank_matrix_distinct_singular_value(S), (), 'distinct_postive_s', (), [skipIfNoLapack]),
+ ('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), (), '', (), [skipIfNoLapack]),
('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
('potrf', _make_cov(S), (True,), '', (), [skipIfNoLapack]),
('eq', (S, S, S), ((S, S, S),)),
@@ -2303,6 +2333,8 @@ def create_input(call_args, requires_grad=True, non_contiguous=False):
return Variable(maybe_non_contig(arg), requires_grad=requires_grad)
elif isinstance(arg, Variable) and non_contiguous:
return Variable(maybe_non_contig(arg.data), requires_grad=arg.requires_grad)
+ elif callable(arg):
+ return map_arg(arg())
else:
return arg
return tuple(map_arg(arg) for arg in call_args)
@@ -2339,6 +2371,19 @@ EXCLUDE_FUNCTIONAL = {
EXCLUDE_GRADCHECK = {
'potrf'
}
+EXCLUDE_GRADGRADCHECK = {
+ 'svd'
+}
+EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
+ # Some of the following det tests pass only because a random matrix has full
+ # rank with high probability, which we can't rely on. So only run the gradgrad
+ # check on test_det_distinct_postive_s.
+ 'test_det',
+ 'test_det_symmetric',
+ 'test_det_dim2_null',
+ 'test_det_rank1',
+ 'test_det_rank2'
+}
def exclude_tensor_method(name, test_name):
@@ -2359,6 +2404,7 @@ def exclude_tensor_method(name, test_name):
'resize_as',
'scatter',
'scatter_add',
+ 'det',
}
if test_name in exclude_all_tensor_method_by_test_name:
return True
@@ -2390,9 +2436,11 @@ def gradgradcheck_method_precision_override(test_name):
return override
-def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable, input_variables):
+def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
+ input_variables, run_gradgradcheck=True):
test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
-
+ if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
+ return
grad_y = generate_gradoutput(output_variable, non_contiguous=True)
gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
if gradgradcheck_precision_override is not None:
@@ -2400,7 +2448,7 @@ def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_vari
rtol = gradgradcheck_precision_override['rtol']
test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y, atol=atol, rtol=rtol))
else:
- test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y,))
+ test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y))
def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
@@ -2413,7 +2461,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
test_case.assertEqual(unpack_variables(output_variable), output_tensor)
if run_grad_checks:
- run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
+ run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
output_variable, f_args_variable)
self_variable = f_args_variable[0]
@@ -2457,7 +2505,7 @@ for test in method_tests:
# TODO: check that both have changed after adding all inplace ops
if not is_inplace and name not in EXCLUDE_GRADCHECK:
- run_grad_and_gradgrad_checks(self, test_name,
+ run_grad_and_gradgrad_checks(self, name, test_name,
lambda *inputs: getattr(inputs[0], name)(*inputs[1:]),
output_variable, (self_variable,) + args_variable)
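The gradgradcheck exclusions above can be illustrated outside the test harness. A hedged sketch (not part of this commit; assumes a recent PyTorch, and the helper `well_conditioned` below is illustrative): grad and gradgrad checks on `det` are only reliable when the input is full rank with distinct singular values, which is exactly what `random_fullrank_matrix_distinct_singular_value` constructs:

# Sketch only: double precision is needed for the finite-difference checks.
import torch
from torch.autograd import gradcheck, gradgradcheck

def well_conditioned(n):
    # full-rank matrix with distinct singular values 1/(n+1), ..., n/(n+1)
    a = torch.randn(n, n, dtype=torch.float64)
    u, _, vh = torch.linalg.svd(a)
    s = torch.arange(1, n + 1, dtype=torch.float64) / (n + 1)
    return u @ torch.diag(s) @ vh

A = well_conditioned(5).requires_grad_()
assert gradcheck(torch.linalg.det, (A,))
assert gradgradcheck(torch.linalg.det, (A,))  # safe here: no zero or repeated singular values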
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 67f8f90255..784ae22beb 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -316,7 +316,8 @@ tests = [
('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),
('qr', large_2d_lapack, lambda t: [], 'big', float_types),
('inverse', new_t(20, 20), lambda t: [], None, float_types),
-
+ ('geqrf', new_t(20, 20), lambda t: [], None, float_types),
+ # TODO: add det here once Variable and Tensor are the same thing
]
# TODO: random functions, cat, gather, scatter, index*, masked*,
@@ -938,6 +939,10 @@ class TestCuda(TestCase):
def _select_broadcastable_dims(dims_full=None):
return TestTorch._select_broadcastable_dims(dims_full)
+ @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
+ def test_det(self):
+ TestTorch._test_det(self, lambda t: t.cuda())
+
def test_broadcast(self):
TestTorch._test_broadcast(self, lambda t: t.cuda())
diff --git a/test/test_torch.py b/test/test_torch.py
index fb6e7f144d..36c9d7948c 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2471,6 +2471,97 @@ class TestTorch(TestCase):
self.assertFalse(MII.is_contiguous(), 'MII is contiguous')
self.assertEqual(MII, MI, 0, 'inverse value in-place')
+ @staticmethod
+ def _test_det(self, conv_fn):
+ def reference_det(M):
+ # naive row reduction
+ M = M.clone()
+ l = M.size(0)
+ multiplier = 1
+ for i in range(l):
+ if M[i, 0] != 0:
+ if i != 0:
+ M[0], M[i] = M[i].clone(), M[0].clone()  # clone: a plain swap would alias rows of the same tensor
+ multiplier = -1
+ break
+ else:
+ return 0
+ for i in range(1, l):
+ row = M[i]
+ for j in range(i):
+ row -= row[j] / M[j, j] * M[j]
+ M[i] = row
+ return M.diag().prod() * multiplier
+
+ # TODO: remove Variable wrapper once Variable and Tensor are the same
+ Variable = torch.autograd.Variable
+
+ eye_det = Variable(conv_fn(torch.eye(5))).det()
+ self.assertEqual(eye_det, eye_det.clone().fill_(1), 1e-8, 'determinant of identity')
+
+ def test(M):
+ M = conv_fn(M)
+ var_M = Variable(M)
+ M_det = var_M.det().data
+
+ self.assertEqual(M_det, M_det.clone().fill_(reference_det(M)), 1e-8, 'determinant')
+ self.assertEqual(M_det, var_M.inverse().det().data.pow_(-1), 1e-8, 'determinant of inverse')
+ self.assertEqual(M_det, var_M.transpose(0, 1).det().data, 1e-8, 'determinant after transpose')
+
+ for x in [0, 2, 4]:
+ for scale in [-2, -0.1, 0, 10]:
+ target = M_det * scale
+ # scale column x (dim 1)
+ M_clone = M.clone()
+ M_clone[:, x] *= scale
+ det = Variable(M_clone).det().data
+ self.assertEqual(target, det, 1e-8, 'determinant after scaling a column')
+ # scale row x (dim 0)
+ M_clone = M.clone()
+ M_clone[x, :] *= scale
+ det = Variable(M_clone).det().data
+ self.assertEqual(target, det, 1e-8, 'determinant after scaling a row')
+
+ for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
+ assert x1 != x2, 'x1 and x2 need to be different for this test'
+ target = M_det.clone().zero_()
+ # duplicate column (dim 1)
+ M_clone = M.clone()
+ M_clone[:, x2] = M_clone[:, x1]
+ det = Variable(M_clone).det().data
+ self.assertEqual(target, det, 1e-8, 'determinant when two columns are the same')
+ # duplicate row (dim 0)
+ M_clone = M.clone()
+ M_clone[x2, :] = M_clone[x1, :]
+ det = Variable(M_clone).det().data
+ self.assertEqual(target, det, 1e-8, 'determinant when two rows are the same')
+
+ for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
+ target = -M_det * scale1 * scale2
+ # combine and swap columns (dim 1)
+ M_clone = M.clone()
+ t = M_clone[:, x1] * scale1
+ M_clone[:, x1] += M_clone[:, x2] * scale2
+ M_clone[:, x2] = t
+ det = Variable(M_clone).det().data
+ self.assertEqual(target, det, 1e-8, 'determinant after exchanging columns')
+ # combine and swap rows (dim 0)
+ M_clone = M.clone()
+ t = M_clone[x1, :] * scale1
+ M_clone[x1, :] += M_clone[x2, :] * scale2
+ M_clone[x2, :] = t
+ det = Variable(M_clone).det().data
+ self.assertEqual(target, det, 1e-8, 'determinant after exchanging rows')
+
+ test(torch.randn(5, 5))
+ r = torch.randn(5, 5)
+ test(r.mm(r.transpose(0, 1))) # symmetric
+ test(torch.randn(5, 5, 5)[:, 2, :]) # non-contiguous
+
+ @skipIfNoLapack
+ def test_det(self):
+ self._test_det(self, lambda x: x)
+
@unittest.skip("Not implemented yet")
def test_conv2(self):
x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100)))
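For readers following `reference_det`, here is the same naive row reduction as a hedged, self-contained sketch over plain Python lists (not part of this commit; `naive_det` is an illustrative name), checked against the 2x2 closed form ad - bc:

# Sketch only: naive row-reduction determinant, valid when the pivots
# encountered during elimination are nonzero.
def naive_det(rows):
    m = [row[:] for row in rows]
    n = len(m)
    sign = 1.0
    # make the first pivot nonzero, swapping rows if needed (one swap flips the sign)
    for i in range(n):
        if m[i][0] != 0:
            if i != 0:
                m[0], m[i] = m[i], m[0]
                sign = -1.0
            break
    else:
        return 0.0
    # eliminate below the diagonal, then multiply the diagonal entries
    for i in range(1, n):
        for j in range(i):
            factor = m[i][j] / m[j][j]
            m[i] = [a - factor * b for a, b in zip(m[i], m[j])]
    prod = sign
    for i in range(n):
        prod *= m[i][i]
    return prod

# det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2
assert abs(naive_det([[1.0, 2.0], [3.0, 4.0]]) + 2.0) < 1e-12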
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index b38a809a3c..7cb973c603 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -16,16 +16,19 @@
# we are going to left-multiply. When the forward returns multiple
# outputs, 'grad' always refers to the first output; you can refer
# to other outputs using 'grads'
-# - Any of the input arguments, tensor or non-tensor
-# - 'output', representing the result of evaluating the forward
-# expression
+# - Any of the input arguments, tensor or non-tensor, including
+# argument names that only appear in Declarations.cwrap, e.g. 'output'.
+# - 'result', representing the result of evaluating the forward
+# expression for ATen native function declarations. If the forward
+# expression outputs a tuple, use 'resultX' instead to access the
+# X-th entry
# - 'grad_input_mask', a std::array<bool, n> (where n is the number
# of differentiable inputs), specifying which inputs actually
# require gradient. (This is only available when multiple
# derivatives are being computed by a single formula.)
#
# If you need a complex expression, e.g., with local variables,
-# write a _backward function in tools/autograd/templates/Function.cpp
+# write a _backward function in tools/autograd/templates/Functions.cpp
# and invoke it from here. By the way, go read
# https://github.com/zdevito/ATen/issues/163; this describes an
# important hazard that occurs when porting backwards from Python to C++
@@ -165,6 +168,9 @@
- name: data_ptr # fallthrough
+- name: _det_with_svd(Tensor self)
+ self: _det_with_svd_backward(grads, self, result0, result1, result2, result3)
+
- name: diag(Tensor self, int64_t diagonal)
self: grad.diag(diagonal)
@@ -443,7 +449,8 @@
# TODO: complicated
# - name: prod(Tensor self, int64_t dim, bool keepdim)
-# - name: prod(Tensor self)
+- name: prod(Tensor self)
+ self: not_implemented("prod")
- name: pstrf(Tensor self, bool upper, Scalar tol)
self: not_implemented("pstrf")
@@ -546,7 +553,7 @@
self: sum_backward(grad, self.sizes(), dim, keepdim)
- name: svd(Tensor self, bool some)
- self: not_implemented("svd")
+ self: svd_backward(grads, self, some, res1, res2, res3)
- name: symeig(Tensor self, bool eigenvectors, bool upper)
self: not_implemented("symeig")
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 0095144de6..1c7686dc2c 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -54,7 +54,9 @@ UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
'Tensor', 'std::tuple<Tensor,Tensor>',
- 'std::tuple<Tensor,Tensor,Tensor>', 'std::vector<Tensor>',
+ 'std::tuple<Tensor,Tensor,Tensor>',
+ 'std::tuple<Tensor,Tensor,Tensor,Tensor>',
+ 'std::vector<Tensor>',
'Scalar', 'bool', 'int64_t', 'void*'
}
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 21a0a7ee1d..d256704299 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -1,5 +1,6 @@
#include "Functions.h"
#include <ATen/WrapDimUtils.h>
+#include <iostream>
// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
@@ -502,6 +503,99 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
}
}
+// https://j-towns.github.io/papers/svd-derivative.pdf
+Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
+ bool some, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
+ auto m = self.size(0);
+ auto n = self.size(1);
+ auto k = sigma.size(0);
+
+ Tensor u, v;
+ if (!some) {
+ // ignore the free subspace
+ u = raw_u.narrow(1, 0, k);
+ v = raw_v.narrow(1, 0, k);
+ } else {
+ u = raw_u;
+ v = raw_v;
+ }
+
+ auto gu = grads[0];
+ auto gsigma = grads[1];
+ auto gv = grads[2];
+ auto im = self.type().eye(m);
+ auto in = self.type().eye(n);
+ auto ut = u.t();
+ auto vt = v.t();
+ auto sigma_mat = sigma.diag();
+ auto sigma_mat_inv = sigma.pow(-1).diag();
+ auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat);
+ auto F = (sigma_expanded_sq - sigma_expanded_sq.t()).pow(-1);
+ auto& long_type = sigma.type().toScalarType(at::kLong);
+ auto diag_indices = long_type.arange(0, F.numel(), k + 1);
+ F.view({-1}).index_fill_(0, diag_indices, 0);
+
+ Tensor u_term, sigma_term, v_term;
+
+ if (gu.defined()) {
+ u_term = u.mm(F.mul(ut.mm(gu) - gu.t().mm(u))).mm(sigma_mat);
+ if (m > k) {
+ u_term = u_term + (im - u.mm(ut)).mm(gu).mm(sigma_mat_inv);
+ }
+ u_term = u_term.mm(vt);
+ } else {
+ u_term = self.type().zeros({1}).expand_as(self);
+ }
+
+ if (gsigma.defined()) {
+ sigma_term = u.mm(gsigma.diag()).mm(vt);
+ } else {
+ sigma_term = self.type().zeros({1}).expand_as(self);
+ }
+
+ if (gv.defined()) {
+ auto gvt = gv.t();
+ v_term = sigma_mat.mm(F.mul(vt.mm(gv) - gvt.mm(v))).mm(vt);
+ if (n > k) {
+ v_term = v_term + sigma_mat_inv.mm(gvt.mm(in - v.mm(vt)));
+ }
+ v_term = u.mm(v_term);
+ } else {
+ v_term = self.type().zeros({1}).expand_as(self);
+ }
+
+ return u_term + sigma_term + v_term;
+}
+
+// Formula:
+// d det / d A_ij = \sum_k (\prod_{l neq k} Sigma_l) U_ik V_jk
+// that is, if det != 0 (so \prod_{l neq k} Sigma_l = det / Sigma_k),
+// d det / d A = U diag(det / Sigma) V^T
+Tensor _det_with_svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
+ const Tensor& det, const Tensor& u, const Tensor& sigma, const Tensor& v) {
+ std::vector<torch::autograd::Variable> svd_grads(grads.begin() + 1, grads.end());
+ auto svd_term = svd_backward(svd_grads, self, true, u, sigma, v);
+
+ auto det_grad = grads[0];
+ auto size = self.size(0);
+ auto null_dim = size - sigma.nonzero().size(0);
+ if (null_dim >= 2) {
+ // \prod_{l neq k} Sigma_l is zero everywhere
+ return svd_term;
+ }
+ if (null_dim == 1) {
+ // only the last singular value is 0
+ // \prod_{l neq k} Sigma_l is zero for every k except the last,
+ // where it equals:
+ auto scale = sigma.narrow(0, 0, size - 1).prod();
+ auto last_u = u.narrow(1, size - 1, 1);
+ auto last_v = v.narrow(1, size - 1, 1);
+ return svd_term + last_u.mm(last_v.transpose(0, 1)).mul_(scale.mul_(det_grad));
+ }
+ // no zero singular values
+ return svd_term + u.mm(sigma.pow(-1).mul_(det.mul(det_grad)).diag()).mm(v.transpose(0, 1));
+}
+
}
${autograd_function_definitions}
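The closed form in the `_det_with_svd_backward` comment (d det / dA = U diag(det / Sigma) V^T, which equals det(A) * A^{-T} when A is invertible) can be confirmed numerically. A hedged sketch, not part of this commit, assuming a recent PyTorch:

# Sketch only: compare autograd's gradient of det with the analytic formulas.
import torch

A = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
det = torch.linalg.det(A)
grad, = torch.autograd.grad(det, A)

u, s, vh = torch.linalg.svd(A)
analytic_svd = u @ torch.diag(det.detach() / s) @ vh          # U diag(det / Sigma) V^T
analytic_inv = det.detach() * torch.linalg.inv(A.detach()).T  # det(A) * A^{-T}

assert torch.allclose(grad, analytic_svd)
assert torch.allclose(grad, analytic_inv)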
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index 584cd4ed4c..aafc37dd6b 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -193,6 +193,15 @@ VariableType::as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) const {
make_variable(std::move(std::get<2>(tensors))));
}
+std::tuple<Variable, Variable, Variable, Variable>
+VariableType::as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensors) const {
+ return std::make_tuple<>(
+ make_variable(std::move(std::get<0>(tensors))),
+ make_variable(std::move(std::get<1>(tensors))),
+ make_variable(std::move(std::get<2>(tensors))),
+ make_variable(std::move(std::get<3>(tensors))));
+}
+
std::vector<Variable> VariableType::as_variable(TensorList tl) const {
std::vector<Variable> variables;
for (auto& t : tl) {
diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h
index 61ada6764e..ce45d6c64a 100644
--- a/tools/autograd/templates/VariableType.h
+++ b/tools/autograd/templates/VariableType.h
@@ -54,6 +54,7 @@ private:
Variable as_variable(Tensor tensor) const;
std::tuple<Variable, Variable> as_variable(std::tuple<Tensor, Tensor> tensor) const;
std::tuple<Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor> tensor) const;
+ std::tuple<Variable, Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensor) const;
std::vector<Variable> as_variable(TensorList tensor) const;
Variable maybe_wrap(Tensor data, const Variable & self, bool inplace) const;
diff --git a/tools/jit/templates/aten_dispatch.cpp b/tools/jit/templates/aten_dispatch.cpp
index cc4b9a9fb0..dcc90058bb 100644
--- a/tools/jit/templates/aten_dispatch.cpp
+++ b/tools/jit/templates/aten_dispatch.cpp
@@ -90,6 +90,12 @@ void pack_list(list_of_retainable & outputs, std::tuple<Tensor, Tensor, Tensor>
outputs.push_back(toRetainableSteal(std::move(std::get<1>(v))));
outputs.push_back(toRetainableSteal(std::move(std::get<2>(v))));
}
+void pack_list(list_of_retainable & outputs, std::tuple<Tensor, Tensor, Tensor, Tensor> v) {
+ outputs.push_back(toRetainableSteal(std::move(std::get<0>(v))));
+ outputs.push_back(toRetainableSteal(std::move(std::get<1>(v))));
+ outputs.push_back(toRetainableSteal(std::move(std::get<2>(v))));
+ outputs.push_back(toRetainableSteal(std::move(std::get<3>(v))));
+}
// A list of functions taking TensorList arguments (where we can't use
// the number of inputs to choose an overload).
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index f2668fdc19..e734834f0e 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -4276,11 +4276,12 @@ svd(input, some=True, out=None) -> (Tensor, Tensor, Tensor)
`U, S, V = torch.svd(A)` returns the singular value decomposition of a
real matrix `A` of size `(n x m)` such that :math:`A = USV'*`.
-`U` is of shape `n x n`
+`U` is of shape `n x min(n, m)`
-`S` is of shape `n x m`
+`S` is a diagonal square matrix of shape `min(n, m) x min(n, m)`, represented as
+a vector of shape `(min(n, m),)` containing its diagonal entries.
-`V` is of shape `m x m`.
+`V` is of shape `m x min(n, m)`.
:attr:`some` represents the number of singular values to be computed.
If `some=True`, it computes some and `some=False` computes all.
@@ -4288,6 +4289,16 @@ If `some=True`, it computes some and `some=False` computes all.
.. note:: Irrespective of the original strides, the returned matrix `U`
will be transposed, i.e. with strides `(1, n)` instead of `(n, 1)`.
+.. note:: Extra care needs to be taken when backpropagating through the `U` and
+ `V` outputs. Such an operation is only stable when :attr:`input` is
+ full rank with all distinct singular values. Otherwise, `NaN` can
+ appear as the gradients are not properly defined. Also, when
+ :attr:`some` = `False`, the gradients on `U[:, min(n, m):]` and
+ `V[:, min(n, m):]` will be ignored as those vectors can be arbitrary
+ bases of the corresponding subspaces.
+
+.. note:: Double backward through :meth:`~torch.svd` is not supported currently.
+
Args:
input (Tensor): the input 2D Tensor
some (bool, optional): controls the number of singular values to be computed
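A hedged, defensive complement to the note above (not part of this commit; assumes a recent PyTorch, and `svd_grads_are_well_defined` is an illustrative helper): the backward formula divides by differences of squared singular values, so it is worth checking that the singular values are nonzero and well separated before relying on gradients through `U` or `V`:

# Sketch only: flag inputs whose U/V gradients are likely to be unstable.
import torch

def svd_grads_are_well_defined(A, tol=1e-6):
    s = torch.linalg.svdvals(A)   # singular values, in descending order
    gaps = s[:-1] - s[1:]
    return bool(gaps.min() > tol) and bool(s.min() > tol)

print(svd_grads_are_well_defined(torch.randn(5, 5, dtype=torch.float64)))  # almost always True
print(svd_grads_are_well_defined(torch.eye(5, dtype=torch.float64)))       # False: all singular values equal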
diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h
index d50412bff5..0b7071d595 100644
--- a/torch/csrc/autograd/utils/wrap_outputs.h
+++ b/torch/csrc/autograd/utils/wrap_outputs.h
@@ -33,6 +33,16 @@ inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor> tensors) {
return r.release();
}
+inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> tensors) {
+ auto r = THPObjectPtr{PyTuple_New(4)};
+ if (!r) throw python_error();
+ PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
+ PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
+ return r.release();
+}
+
inline PyObject* wrap(at::TensorList tl) {
auto r = THPObjectPtr{PyTuple_New(tl.size())};
if (!r) throw python_error();
diff --git a/torch/csrc/generic/methods/TensorMath.cwrap b/torch/csrc/generic/methods/TensorMath.cwrap
index f55fecda79..efff877d03 100644
--- a/torch/csrc/generic/methods/TensorMath.cwrap
+++ b/torch/csrc/generic/methods/TensorMath.cwrap
@@ -2555,6 +2555,7 @@ static const char *R = &__R;
- Double
backends:
- CPU
+ - CUDA
variants:
- method
- function
diff --git a/torch/functional.py b/torch/functional.py
index 41f4a90f78..2f91deb827 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -4,7 +4,7 @@ from operator import mul
from functools import reduce
__all__ = [
- 'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul',
+ 'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul', 'det',
]
@@ -245,3 +245,20 @@ def matmul(tensor1, tensor2, out=None):
raise ValueError("both arguments to __matmul__ need to be at least 1D, "
"but they are {}D and {}D".format(dim_tensor1, dim_tensor2))
+
+
+def det(var):
+ """Calculates determinant of a 2D square Variable.
+
+ .. note::
+ Backward through `det` internally uses SVD results. So double backward
+ through `det` will need to backward through :meth:`~Tensor.svd`. This
+ can be unstable in certain cases. Please see :meth:`~torch.svd` for
+ details.
+
+ Arguments:
+ var (Variable): The input 2D square Variable.
+ """
+ if torch.is_tensor(var):
+ raise ValueError("det is currently only supported on Variable")
+ return var.det()
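A hedged usage sketch for the new function, written against the Variable API this patch targets (in current PyTorch, Variable is no longer needed and the same call is simply `torch.det` or `torch.linalg.det` on a plain tensor):

# Sketch only: compute a determinant and backpropagate through it.
import torch
from torch.autograd import Variable

x = Variable(torch.randn(3, 3), requires_grad=True)
d = torch.det(x)   # equivalently x.det()
d.backward()       # for invertible x, the gradient is det(x) * inverse(x).t()
print(x.grad)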