diff options
-rw-r--r-- | c10/core/ScalarType.h | 28 | ||||
-rw-r--r-- | test/common_utils.py | 4 | ||||
-rw-r--r-- | test/test_torch.py | 22 | ||||
-rw-r--r-- | tools/autograd/templates/python_variable_methods.cpp | 11 | ||||
-rw-r--r-- | torch/csrc/utils/tensor_numpy.cpp | 2 |
5 files changed, 55 insertions, 12 deletions
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 65ce4e89ca..2d852d1043 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -25,7 +25,7 @@ _(double,Double,d) /* 7 */ \ _(at::ComplexHalf,ComplexHalf,z) /* 8 */ \ _(std::complex<float>,ComplexFloat,z) /* 9 */ \ _(std::complex<double>,ComplexDouble,z) /* 10 */ \ -_(bool,Bool,i) /* 11 */ +_(bool,Bool,i) /* 11 */ // If you want to support ComplexHalf for real, replace occurrences // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But @@ -193,19 +193,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { if (isComplexType(a) || isComplexType(b)) { AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be"); } + + // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add + // undefined as we are not sure what is the corrent values for the type promotions in complex type cases. static constexpr ScalarType _promoteTypesLookup [static_cast<int>(ScalarType::NumOptions)] [static_cast<int>(ScalarType::NumOptions)] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 b1 */ - /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 }, - /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 }, - /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 }, - /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 }, - /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 }, - /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 }, - /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 }, - /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 }, - /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 }, + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */ + /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 }, + /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 }, + /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 }, + /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 }, + /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 }, + /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 }, + /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 }, + /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 }, + /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud }, + /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud }, + /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud }, + /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 }, }; return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)]; } diff --git a/test/common_utils.py b/test/common_utils.py index cca6663ca3..6fb0d00387 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -414,6 +414,10 @@ class TestCase(expecttest.TestCase): self.assertEqual(x.item(), y, prec, message, allow_inf) elif isinstance(y, torch.Tensor) and isinstance(x, Number): self.assertEqual(x, y.item(), prec, message, allow_inf) + elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_): + self.assertEqual(x.item(), y, prec, message, allow_inf) + elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_): + self.assertEqual(x, y.item(), prec, message, allow_inf) elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor): def assertTensorsEqual(a, b): super(TestCase, self).assertEqual(a.size(), b.size(), message) diff --git a/test/test_torch.py b/test/test_torch.py index d2e8f386d2..97169ad675 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10001,6 +10001,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], y[0][1] = 3 self.assertTrue(x[0][1] == 3) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_to_numpy_bool(self): + x = torch.tensor([True, False], dtype=torch.bool) + self.assertEqual(x.dtype, torch.bool) + + y = x.numpy() + self.assertEqual(y.dtype, np.bool) + for i in range(len(x)): + self.assertEqual(x[i], y[i]) + + x = torch.tensor([True], dtype=torch.bool) + self.assertEqual(x.dtype, torch.bool) + + y = x.numpy() + self.assertEqual(y.dtype, np.bool) + self.assertEqual(x[0], y[0]) + def test_dlpack_conversion(self): x = torch.randn(1, 2, 3, 4).type('torch.FloatTensor') z = from_dlpack(to_dlpack(x)) @@ -10024,6 +10041,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], np.int8, np.uint8, np.longlong, + np.bool, ] for dtype in dtypes: array = np.array([1, 2, 3, 4], dtype=dtype) @@ -10075,6 +10093,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], np.int16, np.int8, np.uint8, + np.bool, ] incorrect_byteorder = '>' if sys.byteorder == 'little' else '<' @@ -10120,7 +10139,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], np.int64, np.int32, np.int16, - np.uint8 + np.uint8, + np.bool, ] for dtype in dtypes: self.assertEqual(dtype(42), torch.tensor(dtype(42)).item()) diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 9b2111f23c..0f043a654e 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -219,6 +219,15 @@ static int64_t dispatch_to_CLong(const Tensor & self) { return self.item<int64_t>(); } +static bool dispatch_to_Bool(const Tensor & self) { + AutoNoGIL no_gil; + OptionalDeviceGuard device_guard(device_of(self)); + if (self.numel() != 1) { + throw ValueError("only one element tensors can be converted to Python scalars"); + } + return self.item<bool>(); +} + static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -439,6 +448,8 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args) return wrap(dispatch_to_CDouble(self_)); } else if (self_.is_complex()) { return wrap(dispatch_to_CComplexDouble(self_)); + } else if (self_.scalar_type() == ScalarType::Bool) { + return wrap(dispatch_to_Bool(self_)); } else { return wrap(dispatch_to_CLong(self_)); } diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index fa0cb54bc7..cf417420fc 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -156,6 +156,7 @@ static int aten_to_dtype(const ScalarType scalar_type) { case kShort: return NPY_INT16; case kChar: return NPY_INT8; case kByte: return NPY_UINT8; + case kBool: return NPY_BOOL; default: throw ValueError("Got unsupported ScalarType ", toString(scalar_type)); } @@ -170,6 +171,7 @@ ScalarType numpy_dtype_to_aten(int dtype) { case NPY_INT16: return kShort; case NPY_INT8: return kChar; case NPY_UINT8: return kByte; + case NPY_BOOL: return kBool; default: // Workaround: MSVC does not support two switch cases that have the same value if (dtype == NPY_LONGLONG || dtype == NPY_INT64) { |