summaryrefslogtreecommitdiff
path: root/c10
diff options
context:
space:
mode:
authorIurii Zdebskyi <iuriiz@fb.com>2019-04-03 07:22:38 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-03 07:28:24 -0700
commit48f70ea0a2a3a408c14642419271999663b65ed9 (patch)
treedb0117e7f20895d6a8d32346ec45dde06a8c3ef5 /c10
parent7349dbb7ce09e0810980cbb2aeb0bbd9aa0757ad (diff)
downloadpytorch-48f70ea0a2a3a408c14642419271999663b65ed9.tar.gz
pytorch-48f70ea0a2a3a408c14642419271999663b65ed9.tar.bz2
pytorch-48f70ea0a2a3a408c14642419271999663b65ed9.zip
Added numpy conversion (#18505)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18505 ghimport-source-id: f3c9b9251e5793f9e192f587194ddfebb45facc1 Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18505 [WIP]Added numpy conversion** * #18166 Bool Tensor for CUDA Differential Revision: D14646403 fbshipit-source-id: 79d39d692c778ce1981c1d35b1c33e3d93111041
Diffstat (limited to 'c10')
-rw-r--r--c10/core/ScalarType.h28
1 files changed, 17 insertions, 11 deletions
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index 65ce4e89ca..2d852d1043 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -25,7 +25,7 @@ _(double,Double,d) /* 7 */ \
_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
_(std::complex<double>,ComplexDouble,z) /* 10 */ \
-_(bool,Bool,i) /* 11 */
+_(bool,Bool,i) /* 11 */
// If you want to support ComplexHalf for real, replace occurrences
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
@@ -193,19 +193,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
if (isComplexType(a) || isComplexType(b)) {
AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
}
+
+ // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add
+ // undefined as we are not sure what is the corrent values for the type promotions in complex type cases.
static constexpr ScalarType _promoteTypesLookup
[static_cast<int>(ScalarType::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)] = {
- /* u1 i1 i2 i4 i8 f2 f4 f8 b1 */
- /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
- /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
- /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
- /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
- /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
- /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
- /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
- /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
- /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
+ /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */
+ /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
+ /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
+ /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
+ /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
+ /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
+ /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
+ /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
+ /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
+ /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+ /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+ /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+ /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}