-rw-r--r--  c10/core/ScalarType.h | 28
-rw-r--r--  test/common_utils.py | 4
-rw-r--r--  test/test_torch.py | 22
-rw-r--r--  tools/autograd/templates/python_variable_methods.cpp | 11
-rw-r--r--  torch/csrc/utils/tensor_numpy.cpp | 2
5 files changed, 55 insertions, 12 deletions
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index 65ce4e89ca..2d852d1043 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -25,7 +25,7 @@ _(double,Double,d) /* 7 */ \
_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
_(std::complex<double>,ComplexDouble,z) /* 10 */ \
-_(bool,Bool,i) /* 11 */
+_(bool,Bool,i) /* 11 */
// If you want to support ComplexHalf for real, replace occurrences
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
@@ -193,19 +193,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
if (isComplexType(a) || isComplexType(b)) {
AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
}
+
+ // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, which is why we add
+ // undefined entries: we are not yet sure what the correct type promotion values are for the complex type cases.
static constexpr ScalarType _promoteTypesLookup
[static_cast<int>(ScalarType::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)] = {
- /* u1 i1 i2 i4 i8 f2 f4 f8 b1 */
- /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
- /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
- /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
- /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
- /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
- /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
- /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
- /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
- /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
+ /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */
+ /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
+ /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
+ /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
+ /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
+ /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
+ /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
+ /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
+ /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
+ /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+ /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+ /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+ /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
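
For reference, a minimal Python sketch (not part of this patch, and reproducing only a handful of entries) of how the extended lookup behaves: bool promotes to the other operand's type, while anything involving a complex type is left Undefined until proper rules are decided.

    # Illustrative analogue of _promoteTypesLookup above; not PyTorch source.
    UD = "Undefined"  # stands in for the "ud" entries in the new rows

    PROMOTE = {
        ("bool", "uint8"): "uint8",      # b1 x u1 -> u1
        ("bool", "int8"): "int8",        # b1 x i1 -> i1
        ("bool", "float32"): "float32",  # b1 x f4 -> f4
        ("bool", "bool"): "bool",        # b1 x b1 -> b1
        ("bool", "complex64"): UD,       # complex entries stay Undefined for now
    }

    def promote_types(a, b):
        # Symmetric lookup, mirroring _promoteTypesLookup[a][b] for the sampled entries.
        return PROMOTE.get((a, b)) or PROMOTE.get((b, a))

    assert promote_types("uint8", "bool") == "uint8"
    assert promote_types("complex64", "bool") == UD
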
diff --git a/test/common_utils.py b/test/common_utils.py
index cca6663ca3..6fb0d00387 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -414,6 +414,10 @@ class TestCase(expecttest.TestCase):
self.assertEqual(x.item(), y, prec, message, allow_inf)
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
self.assertEqual(x, y.item(), prec, message, allow_inf)
+ elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_):
+ self.assertEqual(x.item(), y, prec, message, allow_inf)
+ elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_):
+ self.assertEqual(x, y.item(), prec, message, allow_inf)
elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
def assertTensorsEqual(a, b):
super(TestCase, self).assertEqual(a.size(), b.size(), message)
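
A standalone sketch of the comparison the two new branches enable (hypothetical code, assuming a build with this patch): when one side is a tensor and the other is a numpy.bool_, the tensor side is reduced to a Python scalar with .item() before comparing.

    import numpy as np
    import torch

    x = torch.tensor([True], dtype=torch.bool)  # one-element bool tensor
    y = np.bool_(True)                          # numpy bool scalar

    # Mirrors the new branches: compare via the tensor's .item() value.
    assert x.item() == y
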
diff --git a/test/test_torch.py b/test/test_torch.py
index d2e8f386d2..97169ad675 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10001,6 +10001,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
y[0][1] = 3
self.assertTrue(x[0][1] == 3)
+ @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+ def test_to_numpy_bool(self):
+ x = torch.tensor([True, False], dtype=torch.bool)
+ self.assertEqual(x.dtype, torch.bool)
+
+ y = x.numpy()
+ self.assertEqual(y.dtype, np.bool)
+ for i in range(len(x)):
+ self.assertEqual(x[i], y[i])
+
+ x = torch.tensor([True], dtype=torch.bool)
+ self.assertEqual(x.dtype, torch.bool)
+
+ y = x.numpy()
+ self.assertEqual(y.dtype, np.bool)
+ self.assertEqual(x[0], y[0])
+
def test_dlpack_conversion(self):
x = torch.randn(1, 2, 3, 4).type('torch.FloatTensor')
z = from_dlpack(to_dlpack(x))
@@ -10024,6 +10041,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.int8,
np.uint8,
np.longlong,
+ np.bool,
]
for dtype in dtypes:
array = np.array([1, 2, 3, 4], dtype=dtype)
@@ -10075,6 +10093,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.int16,
np.int8,
np.uint8,
+ np.bool,
]
incorrect_byteorder = '>' if sys.byteorder == 'little' else '<'
@@ -10120,7 +10139,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.int64,
np.int32,
np.int16,
- np.uint8
+ np.uint8,
+ np.bool,
]
for dtype in dtypes:
self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())
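
Taken together, the tests above amount to a usage pattern like this (a sketch, assuming a build that includes this patch):

    import numpy as np
    import torch

    x = torch.tensor([True, False], dtype=torch.bool)
    y = x.numpy()                   # kBool -> NPY_BOOL, so this no longer raises
    assert y.dtype == np.bool_      # the tests spell this np.bool, an alias at the time
    for i in range(len(x)):
        assert x[i].item() == y[i]  # element values carry over
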
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index 9b2111f23c..0f043a654e 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -219,6 +219,15 @@ static int64_t dispatch_to_CLong(const Tensor & self) {
return self.item<int64_t>();
}
+static bool dispatch_to_Bool(const Tensor & self) {
+ AutoNoGIL no_gil;
+ OptionalDeviceGuard device_guard(device_of(self));
+ if (self.numel() != 1) {
+ throw ValueError("only one element tensors can be converted to Python scalars");
+ }
+ return self.item<bool>();
+}
+
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
@@ -439,6 +448,8 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args)
return wrap(dispatch_to_CDouble(self_));
} else if (self_.is_complex()) {
return wrap(dispatch_to_CComplexDouble(self_));
+ } else if (self_.scalar_type() == ScalarType::Bool) {
+ return wrap(dispatch_to_Bool(self_));
} else {
return wrap(dispatch_to_CLong(self_));
}
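
A usage sketch of the new item() path (assuming a build with this change): a one-element bool tensor comes back as a Python bool, and anything larger hits the ValueError raised in dispatch_to_Bool.

    import torch

    t = torch.tensor([True], dtype=torch.bool)
    v = t.item()                    # routed through dispatch_to_Bool
    assert isinstance(v, bool) and v

    try:
        torch.tensor([True, False], dtype=torch.bool).item()
    except ValueError as e:
        print(e)  # "only one element tensors can be converted to Python scalars"
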
diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp
index fa0cb54bc7..cf417420fc 100644
--- a/torch/csrc/utils/tensor_numpy.cpp
+++ b/torch/csrc/utils/tensor_numpy.cpp
@@ -156,6 +156,7 @@ static int aten_to_dtype(const ScalarType scalar_type) {
case kShort: return NPY_INT16;
case kChar: return NPY_INT8;
case kByte: return NPY_UINT8;
+ case kBool: return NPY_BOOL;
default:
throw ValueError("Got unsupported ScalarType ", toString(scalar_type));
}
@@ -170,6 +171,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
case NPY_INT16: return kShort;
case NPY_INT8: return kChar;
case NPY_UINT8: return kByte;
+ case NPY_BOOL: return kBool;
default:
// Workaround: MSVC does not support two switch cases that have the same value
if (dtype == NPY_LONGLONG || dtype == NPY_INT64) {
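
Finally, the reverse direction that the extended numpy_dtype_to_aten switch enables (usage sketch, assuming a build with this patch):

    import numpy as np
    import torch

    a = np.array([True, False, True])  # numpy bool array, i.e. NPY_BOOL
    t = torch.from_numpy(a)            # NPY_BOOL now maps to kBool
    assert t.dtype == torch.bool
    assert t[0].item() and not t[1].item()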