summaryrefslogtreecommitdiff
path: root/c10/util
diff options
context:
space:
mode:
authorSebastian Messmer <messmer@fb.com>2019-01-10 16:06:27 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-01-10 16:22:22 -0800
commitd408324350fa77e72d0be43d7302ac135d25d292 (patch)
tree6da82456fa49b4f4be8cb0f34ca8db33889b76af /c10/util
parent6b64052e20c934eed6527c03b544544a4758847c (diff)
downloadpytorch-d408324350fa77e72d0be43d7302ac135d25d292.tar.gz
pytorch-d408324350fa77e72d0be43d7302ac135d25d292.tar.bz2
pytorch-d408324350fa77e72d0be43d7302ac135d25d292.zip
Move files to/from c10/core and c10/util (#15316)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15316 This starts cleaning up the files in c10 according to the module structure we decided on. Move to c10/util: - Half.h, Half-inl.h, Half.cpp, bitcasts.h Move to c10/core: - Device.h, Device.cpp - DeviceType.h, DeviceType.cpp i-am-not-moving-c2-to-c10 Reviewed By: dzhulgakov Differential Revision: D13498493 fbshipit-source-id: dfcf1c490474a12ab950c72ca686b8ad86428f63
Diffstat (limited to 'c10/util')
-rw-r--r--c10/util/Half-inl.h285
-rw-r--r--c10/util/Half.cpp16
-rw-r--r--c10/util/Half.h462
-rw-r--r--c10/util/bitcasts.h45
-rw-r--r--c10/util/typeid.h16
5 files changed, 816 insertions, 8 deletions
diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h
new file mode 100644
index 0000000000..966d55f1f4
--- /dev/null
+++ b/c10/util/Half-inl.h
@@ -0,0 +1,285 @@
+#pragma once
+
+#include <cstring>
+#include <limits>
+#include <c10/macros/Macros.h>
+
+#ifdef __CUDACC__
+#include <cuda_fp16.h>
+#endif
+
+#ifdef __HIPCC__
+#include <hip/hip_fp16.h>
+#endif
+
+namespace c10 {
+
+/// Constructors
+
+inline C10_HOST_DEVICE Half::Half(float value) {
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+ x = __half_as_short(__float2half(value));
+#else
+ x = detail::fp16_ieee_from_fp32_value(value);
+#endif
+}
+
+/// Implicit conversions
+
+inline C10_HOST_DEVICE Half::operator float() const {
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
+ return __half2float(*reinterpret_cast<const __half*>(&x));
+#else
+ return detail::fp16_ieee_to_fp32_value(x);
+#endif
+}
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+inline C10_HOST_DEVICE Half::Half(const __half& value) {
+ x = *reinterpret_cast<const unsigned short*>(&value);
+}
+inline C10_HOST_DEVICE Half::operator __half() const {
+ return *reinterpret_cast<const __half*>(&x);
+}
+#endif
+
+// CUDA intrinsics
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350)
+inline __device__ Half __ldg(const Half* ptr) {
+ return __ldg(reinterpret_cast<const __half*>(ptr));
+}
+#endif
+
+/// Arithmetic
+
+inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
+ return static_cast<float>(a) + static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
+ return static_cast<float>(a) - static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) {
+ return static_cast<float>(a) * static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) {
+ return static_cast<float>(a) / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Half operator-(const Half& a) {
+ return -static_cast<float>(a);
+}
+
+inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) {
+ a = a + b;
+ return a;
+}
+
+inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) {
+ a = a - b;
+ return a;
+}
+
+inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) {
+ a = a * b;
+ return a;
+}
+
+inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) {
+ a = a / b;
+ return a;
+}
+
+/// Arithmetic with floats
+
+inline C10_HOST_DEVICE float operator+(Half a, float b) {
+ return static_cast<float>(a) + b;
+}
+inline C10_HOST_DEVICE float operator-(Half a, float b) {
+ return static_cast<float>(a) - b;
+}
+inline C10_HOST_DEVICE float operator*(Half a, float b) {
+ return static_cast<float>(a) * b;
+}
+inline C10_HOST_DEVICE float operator/(Half a, float b) {
+ return static_cast<float>(a) / b;
+}
+
+inline C10_HOST_DEVICE float operator+(float a, Half b) {
+ return a + static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator-(float a, Half b) {
+ return a - static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator*(float a, Half b) {
+ return a * static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator/(float a, Half b) {
+ return a / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) {
+ return a += static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) {
+ return a -= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) {
+ return a *= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) {
+ return a /= static_cast<float>(b);
+}
+
+/// Arithmetic with doubles
+
+inline C10_HOST_DEVICE double operator+(Half a, double b) {
+ return static_cast<double>(a) + b;
+}
+inline C10_HOST_DEVICE double operator-(Half a, double b) {
+ return static_cast<double>(a) - b;
+}
+inline C10_HOST_DEVICE double operator*(Half a, double b) {
+ return static_cast<double>(a) * b;
+}
+inline C10_HOST_DEVICE double operator/(Half a, double b) {
+ return static_cast<double>(a) / b;
+}
+
+inline C10_HOST_DEVICE double operator+(double a, Half b) {
+ return a + static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator-(double a, Half b) {
+ return a - static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator*(double a, Half b) {
+ return a * static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator/(double a, Half b) {
+ return a / static_cast<double>(b);
+}
+
+/// Arithmetic with ints
+
+inline C10_HOST_DEVICE Half operator+(Half a, int b) {
+ return a + static_cast<Half>(b);
+}
+inline C10_HOST_DEVICE Half operator-(Half a, int b) {
+ return a - static_cast<Half>(b);
+}
+inline C10_HOST_DEVICE Half operator*(Half a, int b) {
+ return a * static_cast<Half>(b);
+}
+inline C10_HOST_DEVICE Half operator/(Half a, int b) {
+ return a / static_cast<Half>(b);
+}
+
+inline C10_HOST_DEVICE Half operator+(int a, Half b) {
+ return static_cast<Half>(a) + b;
+}
+inline C10_HOST_DEVICE Half operator-(int a, Half b) {
+ return static_cast<Half>(a) - b;
+}
+inline C10_HOST_DEVICE Half operator*(int a, Half b) {
+ return static_cast<Half>(a) * b;
+}
+inline C10_HOST_DEVICE Half operator/(int a, Half b) {
+ return static_cast<Half>(a) / b;
+}
+
+//// Arithmetic with int64_t
+
+inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) {
+ return a + static_cast<Half>(b);
+}
+inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) {
+ return a - static_cast<Half>(b);
+}
+inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) {
+ return a * static_cast<Half>(b);
+}
+inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) {
+ return a / static_cast<Half>(b);
+}
+
+inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) {
+ return static_cast<Half>(a) + b;
+}
+inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) {
+ return static_cast<Half>(a) - b;
+}
+inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) {
+ return static_cast<Half>(a) * b;
+}
+inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) {
+ return static_cast<Half>(a) / b;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from c10::Half to float.
+
+} // namespace c10
+
+namespace std {
+
+template <>
+class numeric_limits<c10::Half> {
+ public:
+ static constexpr bool is_specialized = true;
+ static constexpr bool is_signed = true;
+ static constexpr bool is_integer = false;
+ static constexpr bool is_exact = false;
+ static constexpr bool has_infinity = true;
+ static constexpr bool has_quiet_NaN = true;
+ static constexpr bool has_signaling_NaN = true;
+ static constexpr auto has_denorm = numeric_limits<float>::has_denorm;
+ static constexpr auto has_denorm_loss =
+ numeric_limits<float>::has_denorm_loss;
+ static constexpr auto round_style = numeric_limits<float>::round_style;
+ static constexpr bool is_iec559 = true;
+ static constexpr bool is_bounded = true;
+ static constexpr bool is_modulo = false;
+ static constexpr int digits = 11;
+ static constexpr int digits10 = 3;
+ static constexpr int max_digits10 = 5;
+ static constexpr int radix = 2;
+ static constexpr int min_exponent = -13;
+ static constexpr int min_exponent10 = -4;
+ static constexpr int max_exponent = 16;
+ static constexpr int max_exponent10 = 4;
+ static constexpr auto traps = numeric_limits<float>::traps;
+ static constexpr auto tinyness_before =
+ numeric_limits<float>::tinyness_before;
+ static constexpr c10::Half min() {
+ return c10::Half(0x0400, c10::Half::from_bits);
+ }
+ static constexpr c10::Half lowest() {
+ return c10::Half(0xFBFF, c10::Half::from_bits);
+ }
+ static constexpr c10::Half max() {
+ return c10::Half(0x7BFF, c10::Half::from_bits);
+ }
+ static constexpr c10::Half epsilon() {
+ return c10::Half(0x1400, c10::Half::from_bits);
+ }
+ static constexpr c10::Half round_error() {
+ return c10::Half(0x3800, c10::Half::from_bits);
+ }
+ static constexpr c10::Half infinity() {
+ return c10::Half(0x7C00, c10::Half::from_bits);
+ }
+ static constexpr c10::Half quiet_NaN() {
+ return c10::Half(0x7E00, c10::Half::from_bits);
+ }
+ static constexpr c10::Half signaling_NaN() {
+ return c10::Half(0x7D00, c10::Half::from_bits);
+ }
+ static constexpr c10::Half denorm_min() {
+ return c10::Half(0x0001, c10::Half::from_bits);
+ }
+};
+
+} // namespace std
diff --git a/c10/util/Half.cpp b/c10/util/Half.cpp
new file mode 100644
index 0000000000..76e36ea596
--- /dev/null
+++ b/c10/util/Half.cpp
@@ -0,0 +1,16 @@
+#include <c10/util/Half.h>
+
+#include <iostream>
+
+namespace c10 {
+
+static_assert(
+ std::is_standard_layout<Half>::value,
+ "c10::Half must be standard layout.");
+
+std::ostream& operator<<(std::ostream& out, const Half& value) {
+ out << (float)value;
+ return out;
+}
+
+} // namespace c10
diff --git a/c10/util/Half.h b/c10/util/Half.h
new file mode 100644
index 0000000000..9732490f84
--- /dev/null
+++ b/c10/util/Half.h
@@ -0,0 +1,462 @@
+#pragma once
+
+/// Defines the Half type (half-precision floating-point) including conversions
+/// to standard C types and basic arithmetic operations. Note that arithmetic
+/// operations are implemented by converting to floating point and
+/// performing the operation in float32, instead of using CUDA half intrinisics.
+/// Most uses of this type within ATen are memory bound, including the
+/// element-wise kernels, and the half intrinisics aren't efficient on all GPUs.
+/// If you are writing a compute bound kernel, you can use the CUDA half
+/// intrinsics directly on the Half type from device code.
+
+#include <c10/util/bitcasts.h>
+#include <c10/macros/Macros.h>
+#include <c10/util/C++17.h>
+
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+#include <cmath>
+#include <cstdint>
+#elif !defined(__OPENCL_VERSION__)
+#include <math.h>
+#include <stdint.h>
+#endif
+
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+
+#include <complex>
+#include <cstring>
+#include <iosfwd>
+#include <limits>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <utility>
+
+#ifdef __CUDACC__
+#include <cuda_fp16.h>
+#endif
+
+#ifdef __HIPCC__
+#include <hip/hip_fp16.h>
+#endif
+
+namespace c10 {
+
+namespace detail {
+
+ /*
+ * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to
+ * a 32-bit floating-point number in IEEE single-precision format, in bit representation.
+ *
+ * @note The implementation doesn't use any floating-point operations.
+ */
+ static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) {
+ /*
+ * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word:
+ * +---+-----+------------+-------------------+
+ * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+ * +---+-----+------------+-------------------+
+ * Bits 31 26-30 16-25 0-15
+ *
+ * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits.
+ */
+ const uint32_t w = (uint32_t) h << 16;
+ /*
+ * Extract the sign of the input number into the high bit of the 32-bit word:
+ *
+ * +---+----------------------------------+
+ * | S |0000000 00000000 00000000 00000000|
+ * +---+----------------------------------+
+ * Bits 31 0-31
+ */
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ /*
+ * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word:
+ *
+ * +---+-----+------------+-------------------+
+ * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+ * +---+-----+------------+-------------------+
+ * Bits 30 27-31 17-26 0-16
+ */
+ const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
+ /*
+ * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized.
+ * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one.
+ * In this case renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note that if we shift
+ * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the
+ * biased exponent into 1, and making mantissa normalized (i.e. without leading 1).
+ */
+#ifdef _MSC_VER
+ unsigned long nonsign_bsr;
+ _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
+ uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
+#else
+ uint32_t renorm_shift = __builtin_clz(nonsign);
+#endif
+ renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0;
+ /*
+ * Iff half-precision number has exponent of 15, the addition overflows
+ * it into bit 31, and the subsequent shift turns the high 9 bits
+ * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number
+ * had exponent of 15 (i.e. was NaN or infinity) 0x00000000 otherwise
+ */
+ const int32_t inf_nan_mask =
+ ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000);
+ /*
+ * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
+ * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
+ * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
+ * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
+ * 0x00000000 otherwise
+ */
+ const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
+ /*
+ * 1. Shift nonsign left by renorm_shift to normalize it (if the input
+ * was denormal)
+ * 2. Shift nonsign right by 3 so the exponent (5 bits originally)
+ * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high
+ * bits of the 23-bit mantissa of IEEE single-precision number.
+ * 3. Add 0x70 to the exponent (starting at bit 23) to compensate the
+ * different in exponent bias (0x7F for single-precision number less 0xF
+ * for half-precision number).
+ * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
+ * account for renormalization. As renorm_shift is less than 0x70, this
+ * can be combined with step 3.
+ * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
+ * input was NaN or infinity.
+ * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
+ * into zero if the input was zero.
+ * 7. Combine with the sign of the input number.
+ */
+ return sign |
+ ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) |
+ inf_nan_mask) &
+ ~zero_mask);
+ }
+
+ /*
+ * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to
+ * a 32-bit floating-point number in IEEE single-precision format.
+ *
+ * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals)
+ * floating-point operations and bitcasts between integer and floating-point variables.
+ */
+ static inline float fp16_ieee_to_fp32_value(uint16_t h) {
+ /*
+ * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word:
+ * +---+-----+------------+-------------------+
+ * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000|
+ * +---+-----+------------+-------------------+
+ * Bits 31 26-30 16-25 0-15
+ *
+ * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits.
+ */
+ const uint32_t w = (uint32_t) h << 16;
+ /*
+ * Extract the sign of the input number into the high bit of the 32-bit word:
+ *
+ * +---+----------------------------------+
+ * | S |0000000 00000000 00000000 00000000|
+ * +---+----------------------------------+
+ * Bits 31 0-31
+ */
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ /*
+ * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word:
+ *
+ * +-----+------------+---------------------+
+ * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000|
+ * +-----+------------+---------------------+
+ * Bits 27-31 17-26 0-16
+ */
+ const uint32_t two_w = w + w;
+
+ /*
+ * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent
+ * of a single-precision floating-point number:
+ *
+ * S|Exponent | Mantissa
+ * +-+---+-----+------------+----------------+
+ * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000|
+ * +-+---+-----+------------+----------------+
+ * Bits | 23-31 | 0-22
+ *
+ * Next, there are some adjustments to the exponent:
+ * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision
+ * formats (0x7F - 0xF = 0x70)
+ * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number.
+ * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent
+ * of the single-precision output must be 0xFF (max possible value). We do this correction in two steps:
+ * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested
+ * by the difference in the exponent bias (see above).
+ * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of
+ * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias.
+ * The floating-point multiplication hardware would ensure than Inf and NaN would retain their value on at least
+ * partially IEEE754-compliant implementations.
+ *
+ * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not
+ * operate on denormal inputs, and do not produce denormal results.
+ */
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+ // const float exp_scale = 0x1.0p-112f;
+ uint32_t scale_bits = (uint32_t) 15 << 23;
+ float exp_scale_val;
+ std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
+ const float exp_scale = exp_scale_val;
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+ /*
+ * Convert denormalized half-precision inputs into single-precision results (always normalized).
+ * Zero inputs are also handled here.
+ *
+ * In a denormalized number the biased exponent is zero, and mantissa has on-zero bits.
+ * First, we shift mantissa into bits 0-9 of the 32-bit word.
+ *
+ * zeros | mantissa
+ * +---------------------------+------------+
+ * |0000 0000 0000 0000 0000 00|MM MMMM MMMM|
+ * +---------------------------+------------+
+ * Bits 10-31 0-9
+ *
+ * Now, remember that denormalized half-precision numbers are represented as:
+ * FP16 = mantissa * 2**(-24).
+ * The trick is to construct a normalized single-precision number with the same mantissa and thehalf-precision input
+ * and with an exponent which would scale the corresponding mantissa bits to 2**(-24).
+ * A normalized single-precision floating-point number is represented as:
+ * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127)
+ * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision
+ * number causes a change of the constructud single-precision number by 2**(-24), i.e. the same ammount.
+ *
+ * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number
+ * is zero, the constructed single-precision number has the value of
+ * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5
+ * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of
+ * the input half-precision number.
+ */
+ const uint32_t magic_mask = UINT32_C(126) << 23;
+ const float magic_bias = 0.5f;
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+ /*
+ * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the
+ * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if its smaller than 2**27, the
+ * input is either a denormal number, or zero.
+ * - Combine the result of conversion of exponent and mantissa with the sign of the input number.
+ */
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+ const uint32_t result = sign |
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+ return fp32_from_bits(result);
+ }
+
+ /*
+ * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in
+ * IEEE half-precision format, in bit representation.
+ *
+ * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals)
+ * floating-point operations and bitcasts between integer and floating-point variables.
+ */
+ static inline uint16_t fp16_ieee_from_fp32_value(float f) {
+ // const float scale_to_inf = 0x1.0p+112f;
+ // const float scale_to_zero = 0x1.0p-110f;
+ uint32_t scale_to_inf_bits = (uint32_t) 239 << 23;
+ uint32_t scale_to_zero_bits = (uint32_t) 17 << 23;
+ float scale_to_inf_val, scale_to_zero_val;
+ std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
+ std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
+ const float scale_to_inf = scale_to_inf_val;
+ const float scale_to_zero = scale_to_zero_val;
+
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+ const uint32_t w = fp32_to_bits(f);
+ const uint32_t shl1_w = w + w;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+ if (bias < UINT32_C(0x71000000)) {
+ bias = UINT32_C(0x71000000);
+ }
+
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+ const uint32_t bits = fp32_to_bits(base);
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+ const uint32_t nonsign = exp_bits + mantissa_bits;
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+ }
+
+} // namespace detail
+
+struct alignas(2) Half {
+ unsigned short x;
+
+ struct from_bits_t {};
+ static constexpr from_bits_t from_bits = from_bits_t();
+
+ // HIP wants __host__ __device__ tag, CUDA does not
+#ifdef __HIP_PLATFORM_HCC__
+ C10_HOST_DEVICE Half() = default;
+#else
+ Half() = default;
+#endif
+
+ constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){};
+ inline C10_HOST_DEVICE Half(float value);
+ inline C10_HOST_DEVICE operator float() const;
+
+#if defined(__CUDACC__) || defined(__HIPCC__)
+ inline C10_HOST_DEVICE Half(const __half& value);
+ inline C10_HOST_DEVICE operator __half() const;
+#endif
+};
+
+// This is just a placeholder for whatever complex representation we
+// end up deciding to use for half-precision complex numbers.
+struct alignas(4) ComplexHalf {
+ Half real_;
+ Half imag_;
+ ComplexHalf() = default;
+ Half real() const {
+ return real_;
+ }
+ Half imag() const {
+ return imag_;
+ }
+ inline ComplexHalf(std::complex<float> value)
+ : real_(value.real()), imag_(value.imag()) {}
+ inline operator std::complex<float>() const {
+ return {real_, imag_};
+ }
+};
+
+template <typename T>
+struct is_complex_t : public std::false_type {};
+
+template <typename T>
+struct is_complex_t<std::complex<T>> : public std::true_type {};
+
+template <>
+struct is_complex_t<ComplexHalf> : public std::true_type {};
+
+// Extract double from std::complex<double>; is identity otherwise
+// TODO: Write in more idiomatic C++17
+template <typename T>
+struct scalar_value_type {
+ using type = T;
+};
+template <typename T>
+struct scalar_value_type<std::complex<T>> {
+ using type = T;
+};
+template <>
+struct scalar_value_type<ComplexHalf> {
+ using type = Half;
+};
+
+// The old implementation of Converter as a function made nvcc's head explode
+// when we added std::complex on top of the specializations for CUDA-only types
+// like __half, so I rewrote it as a templated class (so, no more overloads,
+// just (partial) specialization).
+
+template <typename To, typename From, typename Enable = void>
+struct Converter {
+ To operator()(From f) {
+ return static_cast<To>(f);
+ }
+};
+
+template <typename To, typename From>
+To convert(From from) {
+ return Converter<To, From>()(from);
+}
+
+template <typename To, typename FromV>
+struct Converter<
+ To,
+ std::complex<FromV>,
+ typename std::enable_if<
+ c10::guts::negation<is_complex_t<To>>::value>::type> {
+ To operator()(std::complex<FromV> f) {
+ return static_cast<To>(f.real());
+ }
+};
+
+// In some versions of MSVC, there will be a compiler error when building.
+// C4146: unary minus operator applied to unsigned type, result still unsigned
+// It can be addressed by disabling the following warning.
+#ifdef _MSC_VER
+#pragma warning( push )
+#pragma warning( disable : 4146 )
+#endif
+
+// skip isnan and isinf check for integral types
+template <typename To, typename From>
+typename std::enable_if<std::is_integral<From>::value, bool>::type overflows(
+ From f) {
+ using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
+ if (!limit::is_signed && std::numeric_limits<From>::is_signed) {
+ // allow for negative numbers to wrap using two's complement arithmetic.
+ // For example, with uint8, this allows for `a - b` to be treated as
+ // `a + 255 * b`.
+ return f > limit::max() ||
+ (f < 0 && -static_cast<uint64_t>(f) > limit::max());
+ } else {
+ return f < limit::lowest() || f > limit::max();
+ }
+}
+
+#ifdef _MSC_VER
+#pragma warning( pop )
+#endif
+
+template <typename To, typename From>
+typename std::enable_if<std::is_floating_point<From>::value, bool>::type
+overflows(From f) {
+ using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
+ if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
+ return false;
+ }
+ if (!limit::has_quiet_NaN && (f != f)) {
+ return true;
+ }
+ return f < limit::lowest() || f > limit::max();
+}
+
+template <typename To, typename From>
+typename std::enable_if<is_complex_t<From>::value, bool>::type overflows(
+ From f) {
+ // casts from complex to real are considered to overflow if the
+ // imaginary component is non-zero
+ if (!is_complex_t<To>::value && f.imag() != 0) {
+ return true;
+ }
+ // Check for overflow componentwise
+ // (Technically, the imag overflow check is guaranteed to be false
+ // when !is_complex_t<To>, but any optimizer worth its salt will be
+ // able to figure it out.)
+ return overflows<
+ typename scalar_value_type<To>::type,
+ typename From::value_type>(f.real()) ||
+ overflows<
+ typename scalar_value_type<To>::type,
+ typename From::value_type>(f.imag());
+}
+
+template <typename To, typename From>
+To checked_convert(From f, const char* name) {
+ if (overflows<To, From>(f)) {
+ std::ostringstream oss;
+ oss << "value cannot be converted to type " << name
+ << " without overflow: " << f;
+ throw std::domain_error(oss.str());
+ }
+ return convert<To, From>(f);
+}
+
+C10_API std::ostream& operator<<(std::ostream& out, const Half& value);
+
+} // namespace c10
+
+#include <c10/util/Half-inl.h>
diff --git a/c10/util/bitcasts.h b/c10/util/bitcasts.h
new file mode 100644
index 0000000000..eb8b502eb5
--- /dev/null
+++ b/c10/util/bitcasts.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+#include <cstdint>
+#elif !defined(__OPENCL_VERSION__)
+#include <stdint.h>
+#endif
+
+namespace c10 {
+namespace detail {
+
+static inline float fp32_from_bits(uint32_t w) {
+#if defined(__OPENCL_VERSION__)
+ return as_float(w);
+#elif defined(__CUDA_ARCH__)
+ return __uint_as_float((unsigned int)w);
+#elif defined(__INTEL_COMPILER)
+ return _castu32_f32(w);
+#else
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } fp32 = {w};
+ return fp32.as_value;
+#endif
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+#if defined(__OPENCL_VERSION__)
+ return as_uint(f);
+#elif defined(__CUDA_ARCH__)
+ return (uint32_t)__float_as_uint(f);
+#elif defined(__INTEL_COMPILER)
+ return _castf32_u32(f);
+#else
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32 = {f};
+ return fp32.as_bits;
+#endif
+}
+
+} // namespace detail
+} // namespace c10
diff --git a/c10/util/typeid.h b/c10/util/typeid.h
index fe37f4219e..448f44ec36 100644
--- a/c10/util/typeid.h
+++ b/c10/util/typeid.h
@@ -18,7 +18,7 @@
#include <exception>
#include "c10/util/Backtrace.h"
-#include "c10/Half.h"
+#include "c10/util/Half.h"
#include "c10/macros/Macros.h"
#include "c10/util/C++17.h"
#include "c10/util/Exception.h"
@@ -430,15 +430,15 @@ class C10_API TypeMeta {
// variable template. '-Wpragmas' and '-Wunknown-warning-option' has to be
// disabled for compilers that don't know '-Wundefined-var-template' and
// would error at our attempt to disable it.
-#ifndef _MSC_VER
-# pragma GCC diagnostic push
-# pragma GCC diagnostic ignored "-Wpragmas"
-# pragma GCC diagnostic ignored "-Wunknown-warning-option"
-# pragma GCC diagnostic ignored "-Wundefined-var-template"
+#ifndef _MSC_VER
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wpragmas"
+# pragma GCC diagnostic ignored "-Wunknown-warning-option"
+# pragma GCC diagnostic ignored "-Wundefined-var-template"
#endif
return TypeMeta(_typeMetaDataInstance<T>());
-#ifndef _MSC_VER
-# pragma GCC diagnostic pop
+#ifndef _MSC_VER
+# pragma GCC diagnostic pop
#endif
}