diff options
author | Sebastian Messmer <messmer@fb.com> | 2019-01-10 16:06:27 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-10 16:22:22 -0800 |
commit | d408324350fa77e72d0be43d7302ac135d25d292 (patch) | |
tree | 6da82456fa49b4f4be8cb0f34ca8db33889b76af /c10/util | |
parent | 6b64052e20c934eed6527c03b544544a4758847c (diff) | |
download | pytorch-d408324350fa77e72d0be43d7302ac135d25d292.tar.gz pytorch-d408324350fa77e72d0be43d7302ac135d25d292.tar.bz2 pytorch-d408324350fa77e72d0be43d7302ac135d25d292.zip |
Move files to/from c10/core and c10/util (#15316)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15316
This starts cleaning up the files in c10 according to the module structure we decided on.
Move to c10/util:
- Half.h, Half-inl.h, Half.cpp, bitcasts.h
Move to c10/core:
- Device.h, Device.cpp
- DeviceType.h, DeviceType.cpp
i-am-not-moving-c2-to-c10
Reviewed By: dzhulgakov
Differential Revision: D13498493
fbshipit-source-id: dfcf1c490474a12ab950c72ca686b8ad86428f63
Diffstat (limited to 'c10/util')
-rw-r--r-- | c10/util/Half-inl.h | 285 | ||||
-rw-r--r-- | c10/util/Half.cpp | 16 | ||||
-rw-r--r-- | c10/util/Half.h | 462 | ||||
-rw-r--r-- | c10/util/bitcasts.h | 45 | ||||
-rw-r--r-- | c10/util/typeid.h | 16 |
5 files changed, 816 insertions, 8 deletions
diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h new file mode 100644 index 0000000000..966d55f1f4 --- /dev/null +++ b/c10/util/Half-inl.h @@ -0,0 +1,285 @@ +#pragma once + +#include <cstring> +#include <limits> +#include <c10/macros/Macros.h> + +#ifdef __CUDACC__ +#include <cuda_fp16.h> +#endif + +#ifdef __HIPCC__ +#include <hip/hip_fp16.h> +#endif + +namespace c10 { + +/// Constructors + +inline C10_HOST_DEVICE Half::Half(float value) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + x = __half_as_short(__float2half(value)); +#else + x = detail::fp16_ieee_from_fp32_value(value); +#endif +} + +/// Implicit conversions + +inline C10_HOST_DEVICE Half::operator float() const { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + return __half2float(*reinterpret_cast<const __half*>(&x)); +#else + return detail::fp16_ieee_to_fp32_value(x); +#endif +} + +#if defined(__CUDACC__) || defined(__HIPCC__) +inline C10_HOST_DEVICE Half::Half(const __half& value) { + x = *reinterpret_cast<const unsigned short*>(&value); +} +inline C10_HOST_DEVICE Half::operator __half() const { + return *reinterpret_cast<const __half*>(&x); +} +#endif + +// CUDA intrinsics + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 350) +inline __device__ Half __ldg(const Half* ptr) { + return __ldg(reinterpret_cast<const __half*>(ptr)); +} +#endif + +/// Arithmetic + +inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { + return static_cast<float>(a) + static_cast<float>(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { + return static_cast<float>(a) - static_cast<float>(b); +} + +inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { + return static_cast<float>(a) * static_cast<float>(b); +} + +inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) { + return static_cast<float>(a) / static_cast<float>(b); +} + +inline C10_HOST_DEVICE Half operator-(const Half& a) { + return 
-static_cast<float>(a); +} + +inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Half a, float b) { + return static_cast<float>(a) + b; +} +inline C10_HOST_DEVICE float operator-(Half a, float b) { + return static_cast<float>(a) - b; +} +inline C10_HOST_DEVICE float operator*(Half a, float b) { + return static_cast<float>(a) * b; +} +inline C10_HOST_DEVICE float operator/(Half a, float b) { + return static_cast<float>(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Half b) { + return a + static_cast<float>(b); +} +inline C10_HOST_DEVICE float operator-(float a, Half b) { + return a - static_cast<float>(b); +} +inline C10_HOST_DEVICE float operator*(float a, Half b) { + return a * static_cast<float>(b); +} +inline C10_HOST_DEVICE float operator/(float a, Half b) { + return a / static_cast<float>(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { + return a += static_cast<float>(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { + return a -= static_cast<float>(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { + return a *= static_cast<float>(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { + return a /= static_cast<float>(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Half a, double b) { + return static_cast<double>(a) + b; +} +inline C10_HOST_DEVICE double operator-(Half a, double b) { + return static_cast<double>(a) - b; +} +inline C10_HOST_DEVICE double operator*(Half a, double b) { + return static_cast<double>(a) 
* b; +} +inline C10_HOST_DEVICE double operator/(Half a, double b) { + return static_cast<double>(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Half b) { + return a + static_cast<double>(b); +} +inline C10_HOST_DEVICE double operator-(double a, Half b) { + return a - static_cast<double>(b); +} +inline C10_HOST_DEVICE double operator*(double a, Half b) { + return a * static_cast<double>(b); +} +inline C10_HOST_DEVICE double operator/(double a, Half b) { + return a / static_cast<double>(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Half operator+(Half a, int b) { + return a + static_cast<Half>(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int b) { + return a - static_cast<Half>(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int b) { + return a * static_cast<Half>(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int b) { + return a / static_cast<Half>(b); +} + +inline C10_HOST_DEVICE Half operator+(int a, Half b) { + return static_cast<Half>(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int a, Half b) { + return static_cast<Half>(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int a, Half b) { + return static_cast<Half>(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int a, Half b) { + return static_cast<Half>(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { + return a + static_cast<Half>(b); +} +inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { + return a - static_cast<Half>(b); +} +inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { + return a * static_cast<Half>(b); +} +inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { + return a / static_cast<Half>(b); +} + +inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { + return static_cast<Half>(a) + b; +} +inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { + return static_cast<Half>(a) - b; +} +inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { + return 
static_cast<Half>(a) * b; +} +inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { + return static_cast<Half>(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Half to float. + +} // namespace c10 + +namespace std { + +template <> +class numeric_limits<c10::Half> { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits<float>::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits<float>::has_denorm_loss; + static constexpr auto round_style = numeric_limits<float>::round_style; + static constexpr bool is_iec559 = true; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 11; + static constexpr int digits10 = 3; + static constexpr int max_digits10 = 5; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = numeric_limits<float>::traps; + static constexpr auto tinyness_before = + numeric_limits<float>::tinyness_before; + static constexpr c10::Half min() { + return c10::Half(0x0400, c10::Half::from_bits); + } + static constexpr c10::Half lowest() { + return c10::Half(0xFBFF, c10::Half::from_bits); + } + static constexpr c10::Half max() { + return c10::Half(0x7BFF, c10::Half::from_bits); + } + static constexpr c10::Half epsilon() { + return c10::Half(0x1400, c10::Half::from_bits); + } + static constexpr c10::Half round_error() { + return c10::Half(0x3800, c10::Half::from_bits); + } + static constexpr c10::Half 
infinity() { + return c10::Half(0x7C00, c10::Half::from_bits); + } + static constexpr c10::Half quiet_NaN() { + return c10::Half(0x7E00, c10::Half::from_bits); + } + static constexpr c10::Half signaling_NaN() { + return c10::Half(0x7D00, c10::Half::from_bits); + } + static constexpr c10::Half denorm_min() { + return c10::Half(0x0001, c10::Half::from_bits); + } +}; + +} // namespace std diff --git a/c10/util/Half.cpp b/c10/util/Half.cpp new file mode 100644 index 0000000000..76e36ea596 --- /dev/null +++ b/c10/util/Half.cpp @@ -0,0 +1,16 @@ +#include <c10/util/Half.h> + +#include <iostream> + +namespace c10 { + +static_assert( + std::is_standard_layout<Half>::value, + "c10::Half must be standard layout."); + +std::ostream& operator<<(std::ostream& out, const Half& value) { + out << (float)value; + return out; +} + +} // namespace c10 diff --git a/c10/util/Half.h b/c10/util/Half.h new file mode 100644 index 0000000000..9732490f84 --- /dev/null +++ b/c10/util/Half.h @@ -0,0 +1,462 @@ +#pragma once + +/// Defines the Half type (half-precision floating-point) including conversions +/// to standard C types and basic arithmetic operations. Note that arithmetic +/// operations are implemented by converting to floating point and +/// performing the operation in float32, instead of using CUDA half intrinsics. +/// Most uses of this type within ATen are memory bound, including the +/// element-wise kernels, and the half intrinsics aren't efficient on all GPUs. +/// If you are writing a compute bound kernel, you can use the CUDA half +/// intrinsics directly on the Half type from device code. 
+ +#include <c10/util/bitcasts.h> +#include <c10/macros/Macros.h> +#include <c10/util/C++17.h> + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include <cmath> +#include <cstdint> +#elif !defined(__OPENCL_VERSION__) +#include <math.h> +#include <stdint.h> +#endif + +#ifdef _MSC_VER +#include <intrin.h> +#endif + +#include <complex> +#include <cstring> +#include <iosfwd> +#include <limits> +#include <sstream> +#include <stdexcept> +#include <string> +#include <utility> + +#ifdef __CUDACC__ +#include <cuda_fp16.h> +#endif + +#ifdef __HIPCC__ +#include <hip/hip_fp16.h> +#endif + +namespace c10 { + +namespace detail { + + /* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format, in bit representation. + * + * @note The implementation doesn't use any floating-point operations. + */ + static inline uint32_t fp16_ieee_to_fp32_bits(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. 
+ */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the bits 0-30 of the 32-bit word: + * + * +---+-----+------------+-------------------+ + * | 0 |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + */ + const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF); + /* + * Renorm shift is the number of bits to shift mantissa left to make the half-precision number normalized. + * If the initial number is normalized, some of its high 6 bits (sign == 0 and 5-bit exponent) equals one. + * In this case renorm_shift == 0. If the number is denormalized, renorm_shift > 0. Note that if we shift + * denormalized nonsign by renorm_shift, the unit bit of mantissa will shift into exponent, turning the + * biased exponent into 1, and making mantissa normalized (i.e. without leading 1). + */ +#ifdef _MSC_VER + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(nonsign); +#endif + renorm_shift = renorm_shift > 5 ? renorm_shift - 5 : 0; + /* + * Iff half-precision number has biased exponent of 31 (0x1F), the addition overflows + * it into bit 31, and the subsequent shift turns the high 9 bits + * into 1. Thus inf_nan_mask == 0x7F800000 if the half-precision number + * had biased exponent of 31 (i.e. was NaN or infinity) 0x00000000 otherwise + */ + const int32_t inf_nan_mask = + ((int32_t)(nonsign + 0x04000000) >> 8) & INT32_C(0x7F800000); + /* + * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31 + * into 1. Otherwise, bit 31 remains 0. 
The signed shift right by 31 + * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask == + * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h) + * 0x00000000 otherwise + */ + const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31; + /* + * 1. Shift nonsign left by renorm_shift to normalize it (if the input + * was denormal) + * 2. Shift nonsign right by 3 so the exponent (5 bits originally) + * becomes an 8-bit field and 10-bit mantissa shifts into the 10 high + * bits of the 23-bit mantissa of IEEE single-precision number. + * 3. Add 0x70 to the exponent (starting at bit 23) to compensate for the + * difference in exponent bias (0x7F for single-precision number less 0xF + * for half-precision number). + * 4. Subtract renorm_shift from the exponent (starting at bit 23) to + * account for renormalization. As renorm_shift is less than 0x70, this + * can be combined with step 3. + * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the + * input was NaN or infinity. + * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent + * into zero if the input was zero. + * 7. Combine with the sign of the input number. + */ + return sign | + ((((nonsign << renorm_shift >> 3) + ((0x70 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + } + + /* + * Convert a 16-bit floating-point number in IEEE half-precision format, in bit representation, to + * a 32-bit floating-point number in IEEE single-precision format. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. 
+ */ + static inline float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0 - zero bits. + */ + const uint32_t w = (uint32_t) h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; + + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become mantissa and exponent + * of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). 
We do this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by 0x70 due to difference in exponent bias. + * The floating-point multiplication hardware would ensure that Inf and NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where biased exponent == 0). However, they also do not + * operate on denormal inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + // const float exp_scale = 0x1.0p-112f; + uint32_t scale_bits = (uint32_t) 15 << 23; + float exp_scale_val; + std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val)); + const float exp_scale = exp_scale_val; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + /* + * Convert denormalized half-precision inputs into single-precision results (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has non-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the same mantissa as the half-precision input + * and with an exponent which would scale the corresponding mantissa bits to 2**(-24). 
+ * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa of the input denormalized half-precision + * number causes a change of the constructed single-precision number by 2**(-24), i.e. the same amount. + * + * The last step is to adjust the bias of the constructed single-precision number. When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + /* + * - Choose either results of conversion of input as a normalized number, or as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits 27-31, therefore if it is smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + } + + /* + * Convert a 32-bit floating-point number in IEEE single-precision format to a 16-bit floating-point number in + * IEEE half-precision format, in bit representation. + * + * @note The implementation relies on IEEE-like (no assumption about rounding mode and no operations on denormals) + * floating-point operations and bitcasts between integer and floating-point variables. 
+ */ + static inline uint16_t fp16_ieee_from_fp32_value(float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + uint32_t scale_to_inf_bits = (uint32_t) 239 << 23; + uint32_t scale_to_zero_bits = (uint32_t) 17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign); + } + +} // namespace detail + +struct alignas(2) Half { + unsigned short x; + + struct from_bits_t {}; + static constexpr from_bits_t from_bits = from_bits_t(); + + // HIP wants __host__ __device__ tag, CUDA does not +#ifdef __HIP_PLATFORM_HCC__ + C10_HOST_DEVICE Half() = default; +#else + Half() = default; +#endif + + constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){}; + inline C10_HOST_DEVICE Half(float value); + inline C10_HOST_DEVICE operator float() const; + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline C10_HOST_DEVICE Half(const __half& value); + inline C10_HOST_DEVICE operator __half() const; +#endif +}; + +// This is just a placeholder for whatever complex representation we +// end up deciding to use for half-precision complex numbers. +struct alignas(4) ComplexHalf { + Half real_; + Half imag_; + ComplexHalf() = default; + Half real() const { + return real_; + } + Half imag() const { + return imag_; + } + inline ComplexHalf(std::complex<float> value) + : real_(value.real()), imag_(value.imag()) {} + inline operator std::complex<float>() const { + return {real_, imag_}; + } +}; + +template <typename T> +struct is_complex_t : public std::false_type {}; + +template <typename T> +struct is_complex_t<std::complex<T>> : public std::true_type {}; + +template <> +struct is_complex_t<ComplexHalf> : public std::true_type {}; + +// Extract double from std::complex<double>; is identity otherwise +// TODO: Write in more idiomatic C++17 +template <typename T> +struct scalar_value_type { + using type = T; +}; +template <typename T> +struct scalar_value_type<std::complex<T>> { + using type = T; +}; +template <> +struct scalar_value_type<ComplexHalf> { + using type = Half; +}; + +// The old implementation of Converter as a function made nvcc's head explode +// when we added std::complex on top of the specializations for CUDA-only types +// like __half, so I rewrote it as a templated class (so, no more 
overloads, +// just (partial) specialization). + +template <typename To, typename From, typename Enable = void> +struct Converter { + To operator()(From f) { + return static_cast<To>(f); + } +}; + +template <typename To, typename From> +To convert(From from) { + return Converter<To, From>()(from); +} + +template <typename To, typename FromV> +struct Converter< + To, + std::complex<FromV>, + typename std::enable_if< + c10::guts::negation<is_complex_t<To>>::value>::type> { + To operator()(std::complex<FromV> f) { + return static_cast<To>(f.real()); + } +}; + +// In some versions of MSVC, there will be a compiler error when building. +// C4146: unary minus operator applied to unsigned type, result still unsigned +// It can be addressed by disabling the following warning. +#ifdef _MSC_VER +#pragma warning( push ) +#pragma warning( disable : 4146 ) +#endif + +// skip isnan and isinf check for integral types +template <typename To, typename From> +typename std::enable_if<std::is_integral<From>::value, bool>::type overflows( + From f) { + using limit = std::numeric_limits<typename scalar_value_type<To>::type>; + if (!limit::is_signed && std::numeric_limits<From>::is_signed) { + // allow for negative numbers to wrap using two's complement arithmetic. + // For example, with uint8, this allows for `a - b` to be treated as + // `a + 255 * b`. 
+ return f > limit::max() || + (f < 0 && -static_cast<uint64_t>(f) > limit::max()); + } else { + return f < limit::lowest() || f > limit::max(); + } +} + +#ifdef _MSC_VER +#pragma warning( pop ) +#endif + +template <typename To, typename From> +typename std::enable_if<std::is_floating_point<From>::value, bool>::type +overflows(From f) { + using limit = std::numeric_limits<typename scalar_value_type<To>::type>; + if (limit::has_infinity && std::isinf(static_cast<double>(f))) { + return false; + } + if (!limit::has_quiet_NaN && (f != f)) { + return true; + } + return f < limit::lowest() || f > limit::max(); +} + +template <typename To, typename From> +typename std::enable_if<is_complex_t<From>::value, bool>::type overflows( + From f) { + // casts from complex to real are considered to overflow if the + // imaginary component is non-zero + if (!is_complex_t<To>::value && f.imag() != 0) { + return true; + } + // Check for overflow componentwise + // (Technically, the imag overflow check is guaranteed to be false + // when !is_complex_t<To>, but any optimizer worth its salt will be + // able to figure it out.) 
+ return overflows< + typename scalar_value_type<To>::type, + typename From::value_type>(f.real()) || + overflows< + typename scalar_value_type<To>::type, + typename From::value_type>(f.imag()); +} + +template <typename To, typename From> +To checked_convert(From f, const char* name) { + if (overflows<To, From>(f)) { + std::ostringstream oss; + oss << "value cannot be converted to type " << name + << " without overflow: " << f; + throw std::domain_error(oss.str()); + } + return convert<To, From>(f); +} + +C10_API std::ostream& operator<<(std::ostream& out, const Half& value); + +} // namespace c10 + +#include <c10/util/Half-inl.h> diff --git a/c10/util/bitcasts.h b/c10/util/bitcasts.h new file mode 100644 index 0000000000..eb8b502eb5 --- /dev/null +++ b/c10/util/bitcasts.h @@ -0,0 +1,45 @@ +#pragma once + +#if defined(__cplusplus) && (__cplusplus >= 201103L) +#include <cstdint> +#elif !defined(__OPENCL_VERSION__) +#include <stdint.h> +#endif + +namespace c10 { +namespace detail { + +static inline float fp32_from_bits(uint32_t w) { +#if defined(__OPENCL_VERSION__) + return as_float(w); +#elif defined(__CUDA_ARCH__) + return __uint_as_float((unsigned int)w); +#elif defined(__INTEL_COMPILER) + return _castu32_f32(w); +#else + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; +#endif +} + +static inline uint32_t fp32_to_bits(float f) { +#if defined(__OPENCL_VERSION__) + return as_uint(f); +#elif defined(__CUDA_ARCH__) + return (uint32_t)__float_as_uint(f); +#elif defined(__INTEL_COMPILER) + return _castf32_u32(f); +#else + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +#endif +} + +} // namespace detail +} // namespace c10 diff --git a/c10/util/typeid.h b/c10/util/typeid.h index fe37f4219e..448f44ec36 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -18,7 +18,7 @@ #include <exception> #include "c10/util/Backtrace.h" -#include "c10/Half.h" +#include "c10/util/Half.h" #include 
"c10/macros/Macros.h" #include "c10/util/C++17.h" #include "c10/util/Exception.h" @@ -430,15 +430,15 @@ class C10_API TypeMeta { // variable template. '-Wpragmas' and '-Wunknown-warning-option' has to be // disabled for compilers that don't know '-Wundefined-var-template' and // would error at our attempt to disable it. -#ifndef _MSC_VER -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wpragmas" -# pragma GCC diagnostic ignored "-Wunknown-warning-option" -# pragma GCC diagnostic ignored "-Wundefined-var-template" +#ifndef _MSC_VER +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wpragmas" +# pragma GCC diagnostic ignored "-Wunknown-warning-option" +# pragma GCC diagnostic ignored "-Wundefined-var-template" #endif return TypeMeta(_typeMetaDataInstance<T>()); -#ifndef _MSC_VER -# pragma GCC diagnostic pop +#ifndef _MSC_VER +# pragma GCC diagnostic pop #endif } |