author     Igor Fedan <ifedan@fb.com>  2019-04-02 13:18:20 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-02 13:27:00 -0700
commit     2e97c82470966df6942f364102690460ea58403e (patch)
tree       417e4fd9d7f5b6d430ddcc762ebae037015439d6
parent     3027e783b1cba83546dab50610a9d244bc2632b2 (diff)
torch.cross' dim default changed to c10::optional instead of int=-1 (#17582)
Summary: Argument dim=-1 didn't work for torch.cross. The signature of torch.cross has been changed to take c10::optional<int64_t> dim instead of int64_t. Per the documentation, "If dim is not given, it defaults to the first dimension found with the size 3," and if dim is specified (even a negative value), the corresponding dimension is used. Fixes #17229

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17582
Differential Revision: D14483063
Pulled By: ifedan
fbshipit-source-id: f9699093ec401cb185fd33ca4563c8a46cdcd746
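A minimal sketch of the resulting user-facing behavior, mirroring the test_cross_with_and_without_dim test added below (tensor shapes are illustrative):

    import torch

    x = torch.rand(100, 3)
    y = torch.rand(100, 3)

    # With this change, omitting dim no longer means dim=-1; it means
    # "first dimension of size 3". For these shapes all three calls agree.
    res1 = torch.cross(x, y, dim=1)   # explicit positive dim
    res2 = torch.cross(x, y, dim=-1)  # negative dims now wrap correctly
    res3 = torch.cross(x, y)          # defaults to first dim of size 3
    assert torch.equal(res1, res2) and torch.equal(res1, res3)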
-rw-r--r--  aten/src/ATen/Declarations.cwrap                 9
-rw-r--r--  aten/src/ATen/core/Tensor.h                      2
-rw-r--r--  aten/src/ATen/core/TensorMethods.h               2
-rw-r--r--  aten/src/ATen/core/Type.h                        2
-rw-r--r--  aten/src/ATen/native/Cross.cpp                  54
-rw-r--r--  aten/src/ATen/native/Cross.h                    13
-rw-r--r--  aten/src/ATen/native/LegacyDefinitions.cpp       8
-rw-r--r--  aten/src/ATen/native/cpu/CrossKernel.cpp        78
-rw-r--r--  aten/src/ATen/native/cuda/CrossKernel.cu        15
-rw-r--r--  aten/src/ATen/native/native_functions.yaml       4
-rw-r--r--  aten/src/TH/generic/THTensorMath.h               1
-rw-r--r--  aten/src/TH/generic/THTensorMoreMath.cpp        49
-rw-r--r--  aten/src/THC/generic/THCTensorMathPointwise.cu  18
-rw-r--r--  aten/src/THC/generic/THCTensorMathPointwise.h    2
-rw-r--r--  test/test_torch.py                              29
-rw-r--r--  tools/autograd/derivatives.yaml                  2
16 files changed, 203 insertions, 85 deletions
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 8b8ffd3cd1..b59bacf0e2 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -1846,18 +1846,19 @@
     - arg: THTensor* tensor
 ]]
 [[
-  name: _th_cross
-  cname: cross
+  name: _th_cross_kernel
+  cname: crossKernel
   variants:
     - function
+  backends:
+    - CUDA
   return: argument 0
   arguments:
     - arg: THTensor* result
       output: True
     - THTensor* self
     - THTensor* other
-    - arg: long dim
-      default: -1
+    - arg: int64_t dim
 ]]
[[
name: _th_diag
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index ea3f2b50aa..add127c85f 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -663,7 +663,7 @@ class CAFFE2_API Tensor {
Tensor & exponential_(double lambd=1, Generator * generator=nullptr);
Tensor & geometric_(double p, Generator * generator=nullptr);
Tensor diag(int64_t diagonal=0) const;
- Tensor cross(const Tensor & other, int64_t dim=-1) const;
+ Tensor cross(const Tensor & other, c10::optional<int64_t> dim=c10::nullopt) const;
Tensor triu(int64_t diagonal=0) const;
Tensor tril(int64_t diagonal=0) const;
Tensor trace() const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 2a05ce7185..1605afffdc 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -1060,7 +1060,7 @@ inline Tensor & Tensor::geometric_(double p, Generator * generator) {
inline Tensor Tensor::diag(int64_t diagonal) const {
return type().diag(*this, diagonal);
}
-inline Tensor Tensor::cross(const Tensor & other, int64_t dim) const {
+inline Tensor Tensor::cross(const Tensor & other, c10::optional<int64_t> dim) const {
return type().cross(*this, other, dim);
}
inline Tensor Tensor::triu(int64_t diagonal) const {
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index dcdd533187..e50af839ef 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -541,7 +541,7 @@ struct CAFFE2_API Type {
virtual Tensor & exponential_(Tensor & self, double lambd, Generator * generator) const = 0;
virtual Tensor & geometric_(Tensor & self, double p, Generator * generator) const = 0;
virtual Tensor diag(const Tensor & self, int64_t diagonal) const = 0;
- virtual Tensor cross(const Tensor & self, const Tensor & other, int64_t dim) const = 0;
+ virtual Tensor cross(const Tensor & self, const Tensor & other, c10::optional<int64_t> dim) const = 0;
virtual Tensor triu(const Tensor & self, int64_t diagonal) const = 0;
virtual Tensor tril(const Tensor & self, int64_t diagonal) const = 0;
virtual Tensor trace(const Tensor & self) const = 0;
diff --git a/aten/src/ATen/native/Cross.cpp b/aten/src/ATen/native/Cross.cpp
new file mode 100644
index 0000000000..8788969797
--- /dev/null
+++ b/aten/src/ATen/native/Cross.cpp
@@ -0,0 +1,54 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/NativeFunctions.h>
+
+#include <ATen/native/Cross.h>
+
+namespace at { namespace native {
+
+DEFINE_DISPATCH(cross_stub);
+
+Tensor cross(const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
+ Tensor out = at::empty_like(input);
+ native::cross_out(out, input, other, dimension);
+ return out;
+}
+
+Tensor & cross_out(Tensor & out, const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
+ auto device_res = out.type().device_type();
+ AT_CHECK(device_res == kCPU || device_res == kCUDA, "cross only supports CPU and CUDA devices, out got: ", device_res);
+ auto device1 = input.type().device_type();
+ AT_CHECK(device1 == kCPU || device1 == kCUDA, "cross only supports CPU and CUDA devices, input got: ", device1);
+ auto device2 = other.type().device_type();
+ AT_CHECK(device2 == kCPU || device2 == kCUDA, "cross only supports CPU and CUDA devices, other got: ", device2);
+ AT_CHECK(device_res == device1, "out and input must have the same device type. out: ", device_res, " input: ", device1);
+ AT_CHECK(device1 == device2, "input and other must have the same device type. input: ", device1, " other: ", device2);
+ AT_CHECK(!out.is_cuda() || out.get_device() == input.get_device(), "device of out (", out.get_device(), ") must match device of input (", input.get_device(), ")");
+ AT_CHECK(!input.is_cuda() || input.get_device() == other.get_device(), "device of input (", input.get_device(), ") must match device of other (", other.get_device(), ")");
+ AT_CHECK(input.dim() == other.dim(), "inconsistent tensors dimensions input: ", input.dim(), " other: ", other.dim());
+ AT_CHECK(input.sizes() == other.sizes(), "inconsistent tensors sizes input: ", input.sizes(), " other: ", other.sizes());
+
+ int64_t dim = -1;
+ if(!dimension.has_value()) {
+ for(int64_t i = 0; i < input.dim(); i++) {
+ if(input.size(i) == 3) {
+ dim = i;
+ break;
+ }
+ }
+ AT_CHECK(dim >= 0, "no dimension of size 3 in input");
+ } else {
+ dim = maybe_wrap_dim(dimension.value(), input.dim());
+ AT_CHECK(input.size(dim) == 3, "dimension ", dimension.value(), " does not have size 3");
+ }
+
+ if (out.sizes() != input.sizes()) {
+ out.resize_as_(input);
+ }
+
+ cross_stub(device1, out, input, other, dim);
+ return out;
+}
+
+}} // namespace at::native
+
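The dimension-resolution rule implemented in cross_out above, sketched in Python for reference (an editor's illustration of the same rules, not code from this commit; the manual wrap stands in for maybe_wrap_dim):

    def resolve_cross_dim(sizes, dim=None):
        # If dim is omitted, pick the first dimension of size 3.
        if dim is None:
            for i, s in enumerate(sizes):
                if s == 3:
                    return i
            raise RuntimeError("no dimension of size 3 in input")
        # Otherwise wrap a negative dim, then require size 3 there.
        wrapped = dim + len(sizes) if dim < 0 else dim
        if not 0 <= wrapped < len(sizes):
            raise IndexError("Dimension out of range")
        if sizes[wrapped] != 3:
            raise RuntimeError("dimension %d does not have size 3" % dim)
        return wrapped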
diff --git a/aten/src/ATen/native/Cross.h b/aten/src/ATen/native/Cross.h
new file mode 100644
index 0000000000..35f9886b2b
--- /dev/null
+++ b/aten/src/ATen/native/Cross.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at { namespace native {
+
+using cross_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const int64_t d);
+
+DECLARE_DISPATCH(cross_fn, cross_stub);
+
+}} // namespace at::native
+
diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp
index 79ee5aab5a..6315d39576 100644
--- a/aten/src/ATen/native/LegacyDefinitions.cpp
+++ b/aten/src/ATen/native/LegacyDefinitions.cpp
@@ -272,14 +272,6 @@ Tensor diag(const Tensor & self, int64_t diagonal) {
return at::legacy::th::_th_diag(self, diagonal);
}
-Tensor & cross_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim) {
- return at::legacy::th::_th_cross_out(result, self, other, dim);
-}
-
-Tensor cross(const Tensor & self, const Tensor & other, int64_t dim) {
- return at::legacy::th::_th_cross(self, other, dim);
-}
-
Tensor trace(const Tensor & self) {
return at::legacy::th::_th_trace(self);
}
diff --git a/aten/src/ATen/native/cpu/CrossKernel.cpp b/aten/src/ATen/native/cpu/CrossKernel.cpp
new file mode 100644
index 0000000000..9d51fc61f9
--- /dev/null
+++ b/aten/src/ATen/native/cpu/CrossKernel.cpp
@@ -0,0 +1,78 @@
+#include <ATen/native/Cross.h>
+
+#include <numeric>
+#include <iterator>
+#include <algorithm>
+#include <vector>
+
+#include <ATen/Dispatch.h>
+#include <ATen/Parallel.h>
+#include <ATen/cpu/vml.h>
+namespace at { namespace native { namespace {
+
+template<typename scalar_t>
+static void apply_cross(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
+ int64_t total = a.numel() / 3;
+ int64_t a_stride = a.stride(dim);
+ int64_t b_stride = b.stride(dim);
+ int64_t r_stride = result.stride(dim);
+
+ scalar_t *a_ptr = a.data<scalar_t>();
+ scalar_t *b_ptr = b.data<scalar_t>();
+ scalar_t *r_ptr = result.data<scalar_t>();
+
+ parallel_for(0, total, internal::GRAIN_SIZE, [&](int64_t s, int64_t e) {
+ const int64_t a_dim = a.dim();
+ std::vector<int64_t> position_in_dims(a_dim);
+ int64_t index_in_curr_dim = s;
+ int64_t a_start = 0;
+ int64_t b_start = 0;
+ int64_t r_start = 0;
+ for (int64_t i = 0; i < a.dim(); i++) {
+ if (i == dim) continue;
+ position_in_dims[i] = index_in_curr_dim % a.size(i);
+ a_start += (index_in_curr_dim % a.size(i)) * a.stride(i);
+ b_start += (index_in_curr_dim % b.size(i)) * b.stride(i);
+ r_start += (index_in_curr_dim % result.size(i)) * result.stride(i);
+ index_in_curr_dim = index_in_curr_dim / a.size(i);
+ }
+
+ while (s < e) {
+ r_ptr[r_start+0*r_stride] = a_ptr[a_start+1*a_stride]*b_ptr[b_start+2*b_stride] - a_ptr[a_start+2*a_stride]*b_ptr[b_start+1*b_stride];
+ r_ptr[r_start+1*r_stride] = a_ptr[a_start+2*a_stride]*b_ptr[b_start+0*b_stride] - a_ptr[a_start+0*a_stride]*b_ptr[b_start+2*b_stride];
+ r_ptr[r_start+2*r_stride] = a_ptr[a_start+0*a_stride]*b_ptr[b_start+1*b_stride] - a_ptr[a_start+1*a_stride]*b_ptr[b_start+0*b_stride];
+ s++;
+
+ for (int i = 0; i < a.dim(); i++) {
+ if (i == dim) {
+ continue;
+ }
+ position_in_dims[i]++;
+ a_start += a.stride(i);
+ b_start += b.stride(i);
+ r_start += result.stride(i);
+ if (position_in_dims[i] == a.size(i) && i != a.dim()-1) {
+ a_start -= position_in_dims[i] * a.stride(i);
+ b_start -= position_in_dims[i] * b.stride(i);
+ r_start -= position_in_dims[i] * result.stride(i);
+ position_in_dims[i] = 0;
+ } else {
+ break;
+ }
+ }
+ }
+ });
+}
+
+static void cross_kernel_impl(Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
+ AT_DISPATCH_ALL_TYPES(result.scalar_type(), "cross", [&]() {
+ apply_cross<scalar_t>(result, a, b, dim);
+ });
+}
+
+} // anonymous namespace
+
+REGISTER_DISPATCH(cross_stub, &cross_kernel_impl);
+
+}} // namespace at::native
+
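What the strided loop above computes can be stated compactly in Python (a reference expression for checking against the kernel, not part of the commit):

    import torch

    def cross_ref(a, b, dim):
        # Slice out the three components along `dim` and combine them
        # pairwise, as the kernel does via strides.
        a0, a1, a2 = a.unbind(dim)
        b0, b1, b2 = b.unbind(dim)
        return torch.stack((a1 * b2 - a2 * b1,
                            a2 * b0 - a0 * b2,
                            a0 * b1 - a1 * b0), dim=dim)

The kernel avoids materializing these slices: it converts each chunk's linear element index into per-dimension start offsets once, then advances the offsets odometer-style as it walks the remaining elements.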
diff --git a/aten/src/ATen/native/cuda/CrossKernel.cu b/aten/src/ATen/native/cuda/CrossKernel.cu
new file mode 100644
index 0000000000..abac08689b
--- /dev/null
+++ b/aten/src/ATen/native/cuda/CrossKernel.cu
@@ -0,0 +1,15 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/LegacyTHFunctions.h>
+#include <ATen/native/Cross.h>
+
+namespace at { namespace native {
+
+void cross_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const int64_t dim) {
+ at::legacy::th::_th_cross_kernel_out(result, x1, x2, dim);
+}
+
+REGISTER_DISPATCH(cross_stub, &cross_kernel_impl);
+
+}}
+
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index dc6153b442..38215975cb 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3513,10 +3513,10 @@
matches_jit_signature: True
variants: method, function
-- func: cross(Tensor self, Tensor other, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)
+- func: cross(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
-- func: cross(Tensor self, Tensor other, int dim=-1) -> Tensor
+- func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
matches_jit_signature: True
variants: method, function
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 27083a0d8d..3ab999b91b 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -80,7 +80,6 @@ TH_API void THTensor_(cumsum)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(cumprod)(THTensor *r_, THTensor *t, int dimension);
TH_API void THTensor_(sign)(THTensor *r_, THTensor *t);
TH_API accreal THTensor_(trace)(THTensor *t);
-TH_API void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension);
TH_API void THTensor_(cmax)(THTensor *r, THTensor *t, THTensor *src);
TH_API void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src);
diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index 4661ef234f..8461192bac 100644
--- a/aten/src/TH/generic/THTensorMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorMoreMath.cpp
@@ -390,53 +390,6 @@ accreal THTensor_(trace)(THTensor *t)
return sum;
}
-void THTensor_(cross)(THTensor *r_, THTensor *a, THTensor *b, int dimension)
-{
- int i;
-
- if(THTensor_(nDimensionLegacyNoScalars)(a) != THTensor_(nDimensionLegacyNoScalars)(b))
- THError("inconsistent tensor dimension %dD, %dD",
- THTensor_(nDimensionLegacyNoScalars)(a), THTensor_(nDimensionLegacyNoScalars)(b));
-
- for(i = 0; i < a->dim(); i++)
- {
- if(THTensor_(size)(a, i) != THTensor_(size)(b, i)) {
- THDescBuff ba = THTensor_(sizeDesc)(a);
- THDescBuff bb = THTensor_(sizeDesc)(b);
- THError("inconsistent tensor sizes %s, %s", ba.str, bb.str);
- }
- }
-
- if(dimension < 0)
- {
- for(i = 0; i < THTensor_(nDimensionLegacyNoScalars)(a); i++)
- {
- if(THTensor_sizeLegacyNoScalars(a, i) == 3)
- {
- dimension = i;
- break;
- }
- }
- if(dimension < 0) {
- THDescBuff ba = THTensor_(sizeDesc)(a);
- THError("no dimension of size 3 in a: %s", ba.str);
- }
- }
-
- THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyNoScalars)(a), 3, "dimension %d out of range",
- dimension);
- THArgCheck(THTensor_sizeLegacyNoScalars(a, dimension) == 3, 3, "dimension %d does not have size 3",
- dimension);
-
- THTensor_(resizeAs)(r_, a);
-
- TH_TENSOR_DIM_APPLY3(scalar_t, a, scalar_t, b, scalar_t, r_, dimension,
- TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
- r__data[0*r__stride] = a_data[1*a_stride]*b_data[2*b_stride] - a_data[2*a_stride]*b_data[1*b_stride];
- r__data[1*r__stride] = a_data[2*a_stride]*b_data[0*b_stride] - a_data[0*a_stride]*b_data[2*b_stride];
- r__data[2*r__stride] = a_data[0*a_stride]*b_data[1*b_stride] - a_data[1*a_stride]*b_data[0*b_stride];);
-}
-
void THTensor_(cmax)(THTensor *r, THTensor *t, THTensor *src) {
THTensor_(resizeAs)(r, t);
TH_TENSOR_APPLY3(scalar_t, r, scalar_t, t, scalar_t, src,
@@ -1047,7 +1000,7 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb)
}
#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \
- void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
+ void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
{ \
THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \
TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t, \
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index 0c128a51b5..6a2ee33ed1 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPointwise.cu
@@ -111,26 +111,10 @@ void THCTensor_(clamp)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t min_value, scalar_t max_value)
THCudaCheck(cudaGetLastError());
}
-void THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y, int dimension)
+void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *x, THCTensor *y, int dimension)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self, x, y));
- int i;
- int nd = x->dim();
- ptrdiff_t nelem = THCTensor_(nElement)(state, x);
- THArgCheck(nd == y->dim(), 1, "tensors must have same number of dimensions");
- for (i = 0; i < nd; i++) {
- THArgCheck(THCTensor_(size)(state, x, i) == THCTensor_(size)(state, y, i), 1, "dimension %i of x and y does not match", i);
- if (dimension < 0 && THCTensor_(size)(state, x, i) == 3) {
- dimension = i;
- }
- }
-
- THArgCheck(dimension >= 0 && dimension < nd, 3, "dimension %d out of range", dimension+1);
- THArgCheck(THCTensor_(size)(state, x, dimension) == 3, 3,
- "dimension %d does not have size 3", dimension+1);
- THCTensor_(resizeAs)(state, self, x);
-
int64_t sx = THCTensor_(stride)(state, x, dimension);
int64_t sy = THCTensor_(stride)(state, y, dimension);
int64_t so = THCTensor_(stride)(state, self, dimension);
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h
index 78559f5eeb..5539e8ed1b 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.h
+++ b/aten/src/THC/generic/THCTensorMathPointwise.h
@@ -47,7 +47,7 @@ THC_API void THCTensor_(neg)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(abs)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(sign)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(clamp)(THCState *state, THCTensor *self, THCTensor *src, scalar_t min_value, scalar_t max_value);
-THC_API void THCTensor_(cross)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension);
+THC_API void THCTensor_(crossKernel)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2, int dimension);
THC_API void THCTensor_(cadd)(THCState *state, THCTensor *self, THCTensor *src1, scalar_t value, THCTensor *src2);
THC_API void THCTensor_(csub)(THCState *state, THCTensor *self, THCTensor *src1, scalar_t value, THCTensor *src2);
diff --git a/test/test_torch.py b/test/test_torch.py
index 5de9045543..7d44307325 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2367,6 +2367,35 @@ class _TestTorchMixin(object):
torch.cross(x, y, out=res2)
self.assertEqual(res1, res2)
+ def test_cross_with_and_without_dim(self):
+ x = torch.rand(100, 3)
+ y = torch.rand(100, 3)
+ res1 = torch.cross(x, y, dim=1)
+ res2 = torch.cross(x, y, dim=-1)
+ res3 = torch.cross(x, y)
+ self.assertEqual(res1, res2)
+ self.assertEqual(res1, res3)
+
+ def test_cross_validation(self):
+ self.assertRaisesRegex(
+ RuntimeError, "inconsistent tensors dimensions",
+ lambda: torch.cross(torch.rand(100, 3), torch.rand(100, 3, 10)))
+ self.assertRaisesRegex(
+ RuntimeError, "inconsistent tensors sizes",
+ lambda: torch.cross(torch.rand(5, 3), torch.rand(3, 5)))
+ self.assertRaisesRegex(
+ RuntimeError, "no dimension of size 3 in input",
+ lambda: torch.cross(torch.rand(5, 4), torch.rand(5, 4)))
+ self.assertRaisesRegex(
+ RuntimeError, "dimension 0 does not have size 3",
+ lambda: torch.cross(torch.rand(5, 4, 3), torch.rand(5, 4, 3), dim=0))
+ self.assertRaisesRegex(
+ RuntimeError, "dimension -1 does not have size 3",
+ lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-1))
+ self.assertRaisesRegex(
+ IndexError, "Dimension out of range",
+ lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-5))
+
def test_zeros(self):
res1 = torch.zeros(100, 100)
res2 = torch.Tensor()
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 425a5b5ac8..bea70b5495 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -234,7 +234,7 @@
- name: cosh(Tensor self)
self: grad * self.sinh()
-- name: cross(Tensor self, Tensor other, int64_t dim)
+- name: cross(Tensor self, Tensor other, int64_t? dim)
self: other.cross(grad, dim)
other: grad.cross(self, dim)
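A quick numerical check of the backward formulas declared above (an illustrative sketch; shapes are arbitrary):

    import torch

    x = torch.rand(5, 3, requires_grad=True)
    y = torch.rand(5, 3, requires_grad=True)
    g = torch.rand(5, 3)

    torch.cross(x, y, dim=1).backward(g)
    # Per the derivatives entry: d/dself = other.cross(grad, dim)
    # and d/dother = grad.cross(self, dim).
    assert torch.allclose(x.grad, torch.cross(y, g, dim=1))
    assert torch.allclose(y.grad, torch.cross(g, x, dim=1))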