diff options
-rw-r--r-- | aten/src/ATen/core/ATenCoreTest.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/ATenGeneral.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/Half-inl.h | 100 | ||||
-rw-r--r-- | aten/src/ATen/core/Half.h | 14 | ||||
-rw-r--r-- | aten/src/ATen/core/IdWrapper.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/Macros.h | 35 | ||||
-rw-r--r-- | aten/src/ATen/core/OptionsGuard.cpp | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/OptionsGuard.h | 8 | ||||
-rw-r--r-- | aten/src/ATen/core/SmallVector.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/TensorAccessor.h | 91 | ||||
-rw-r--r-- | aten/src/ATen/core/TensorTypeId.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/TensorTypeIdRegistration.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/UniqueVoidPtr.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/aten_interned_strings.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/core/interned_strings.h | 4 | ||||
-rw-r--r-- | aten/src/ATen/core/typeid.h | 2 | ||||
-rw-r--r-- | aten/src/ATen/cuda/Array.h | 18 | ||||
-rw-r--r-- | aten/src/ATen/cuda/detail/OffsetCalculator.cuh | 4 | ||||
-rw-r--r-- | aten/src/ATen/native/cuda/Reduce.cuh | 32 | ||||
-rw-r--r-- | c10/macros/Macros.h | 35 |
20 files changed, 198 insertions, 163 deletions
diff --git a/aten/src/ATen/core/ATenCoreTest.h b/aten/src/ATen/core/ATenCoreTest.h index 93f894ea66..0a45902317 100644 --- a/aten/src/ATen/core/ATenCoreTest.h +++ b/aten/src/ATen/core/ATenCoreTest.h @@ -1,6 +1,6 @@ #pragma once -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> namespace at { diff --git a/aten/src/ATen/core/ATenGeneral.h b/aten/src/ATen/core/ATenGeneral.h index cb946c93c9..618f987f4a 100644 --- a/aten/src/ATen/core/ATenGeneral.h +++ b/aten/src/ATen/core/ATenGeneral.h @@ -1,3 +1,3 @@ #pragma once -#include "ATen/core/Macros.h" +#include "c10/macros/Macros.h" diff --git a/aten/src/ATen/core/Half-inl.h b/aten/src/ATen/core/Half-inl.h index e63243f563..18582a7524 100644 --- a/aten/src/ATen/core/Half-inl.h +++ b/aten/src/ATen/core/Half-inl.h @@ -2,7 +2,7 @@ #include <cstring> #include <limits> -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> #ifdef __CUDACC__ #include <cuda_fp16.h> @@ -16,7 +16,7 @@ namespace at { /// Constructors -inline AT_HOST_DEVICE Half::Half(float value) { +inline C10_HOST_DEVICE Half::Half(float value) { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) x = __half_as_short(__float2half(value)); #else @@ -26,7 +26,7 @@ inline AT_HOST_DEVICE Half::Half(float value) { /// Implicit conversions -inline AT_HOST_DEVICE Half::operator float() const { +inline C10_HOST_DEVICE Half::operator float() const { #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) return __half2float(*reinterpret_cast<const __half*>(&x)); #else @@ -35,10 +35,10 @@ inline AT_HOST_DEVICE Half::operator float() const { } #if defined(__CUDACC__) || defined(__HIPCC__) -inline AT_HOST_DEVICE Half::Half(const __half& value) { +inline C10_HOST_DEVICE Half::Half(const __half& value) { x = *reinterpret_cast<const unsigned short*>(&value); } -inline AT_HOST_DEVICE Half::operator __half() const { +inline C10_HOST_DEVICE Half::operator __half() const { return *reinterpret_cast<const __half*>(&x); } #endif @@ -53,168 +53,168 @@ inline __device__ Half __ldg(const Half* ptr) { /// Arithmetic -inline AT_HOST_DEVICE Half operator+(const Half& a, const Half& b) { +inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) { return static_cast<float>(a) + static_cast<float>(b); } -inline AT_HOST_DEVICE Half operator-(const Half& a, const Half& b) { +inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) { return static_cast<float>(a) - static_cast<float>(b); } -inline AT_HOST_DEVICE Half operator*(const Half& a, const Half& b) { +inline C10_HOST_DEVICE Half operator*(const Half& a, const Half& b) { return static_cast<float>(a) * static_cast<float>(b); } -inline AT_HOST_DEVICE Half operator/(const Half& a, const Half& b) { +inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) { return static_cast<float>(a) / static_cast<float>(b); } -inline AT_HOST_DEVICE Half operator-(const Half& a) { +inline C10_HOST_DEVICE Half operator-(const Half& a) { return -static_cast<float>(a); } -inline AT_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { +inline C10_HOST_DEVICE Half& operator+=(Half& a, const Half& b) { a = a + b; return a; } -inline AT_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { +inline C10_HOST_DEVICE Half& operator-=(Half& a, const Half& b) { a = a - b; return a; } -inline AT_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { +inline C10_HOST_DEVICE Half& operator*=(Half& a, const Half& b) { a = a * b; return a; } -inline AT_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { +inline C10_HOST_DEVICE Half& operator/=(Half& a, const Half& b) { a = a / b; return a; } /// Arithmetic with floats -inline AT_HOST_DEVICE float operator+(Half a, float b) { +inline C10_HOST_DEVICE float operator+(Half a, float b) { return static_cast<float>(a) + b; } -inline AT_HOST_DEVICE float operator-(Half a, float b) { +inline C10_HOST_DEVICE float operator-(Half a, float b) { return static_cast<float>(a) - b; } -inline AT_HOST_DEVICE float operator*(Half a, float b) { +inline C10_HOST_DEVICE float operator*(Half a, float b) { return static_cast<float>(a) * b; } -inline AT_HOST_DEVICE float operator/(Half a, float b) { +inline C10_HOST_DEVICE float operator/(Half a, float b) { return static_cast<float>(a) / b; } -inline AT_HOST_DEVICE float operator+(float a, Half b) { +inline C10_HOST_DEVICE float operator+(float a, Half b) { return a + static_cast<float>(b); } -inline AT_HOST_DEVICE float operator-(float a, Half b) { +inline C10_HOST_DEVICE float operator-(float a, Half b) { return a - static_cast<float>(b); } -inline AT_HOST_DEVICE float operator*(float a, Half b) { +inline C10_HOST_DEVICE float operator*(float a, Half b) { return a * static_cast<float>(b); } -inline AT_HOST_DEVICE float operator/(float a, Half b) { +inline C10_HOST_DEVICE float operator/(float a, Half b) { return a / static_cast<float>(b); } -inline AT_HOST_DEVICE float& operator+=(float& a, const Half& b) { +inline C10_HOST_DEVICE float& operator+=(float& a, const Half& b) { return a += static_cast<float>(b); } -inline AT_HOST_DEVICE float& operator-=(float& a, const Half& b) { +inline C10_HOST_DEVICE float& operator-=(float& a, const Half& b) { return a -= static_cast<float>(b); } -inline AT_HOST_DEVICE float& operator*=(float& a, const Half& b) { +inline C10_HOST_DEVICE float& operator*=(float& a, const Half& b) { return a *= static_cast<float>(b); } -inline AT_HOST_DEVICE float& operator/=(float& a, const Half& b) { +inline C10_HOST_DEVICE float& operator/=(float& a, const Half& b) { return a /= static_cast<float>(b); } /// Arithmetic with doubles -inline AT_HOST_DEVICE double operator+(Half a, double b) { +inline C10_HOST_DEVICE double operator+(Half a, double b) { return static_cast<double>(a) + b; } -inline AT_HOST_DEVICE double operator-(Half a, double b) { +inline C10_HOST_DEVICE double operator-(Half a, double b) { return static_cast<double>(a) - b; } -inline AT_HOST_DEVICE double operator*(Half a, double b) { +inline C10_HOST_DEVICE double operator*(Half a, double b) { return static_cast<double>(a) * b; } -inline AT_HOST_DEVICE double operator/(Half a, double b) { +inline C10_HOST_DEVICE double operator/(Half a, double b) { return static_cast<double>(a) / b; } -inline AT_HOST_DEVICE double operator+(double a, Half b) { +inline C10_HOST_DEVICE double operator+(double a, Half b) { return a + static_cast<double>(b); } -inline AT_HOST_DEVICE double operator-(double a, Half b) { +inline C10_HOST_DEVICE double operator-(double a, Half b) { return a - static_cast<double>(b); } -inline AT_HOST_DEVICE double operator*(double a, Half b) { +inline C10_HOST_DEVICE double operator*(double a, Half b) { return a * static_cast<double>(b); } -inline AT_HOST_DEVICE double operator/(double a, Half b) { +inline C10_HOST_DEVICE double operator/(double a, Half b) { return a / static_cast<double>(b); } /// Arithmetic with ints -inline AT_HOST_DEVICE Half operator+(Half a, int b) { +inline C10_HOST_DEVICE Half operator+(Half a, int b) { return a + static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator-(Half a, int b) { +inline C10_HOST_DEVICE Half operator-(Half a, int b) { return a - static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator*(Half a, int b) { +inline C10_HOST_DEVICE Half operator*(Half a, int b) { return a * static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator/(Half a, int b) { +inline C10_HOST_DEVICE Half operator/(Half a, int b) { return a / static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator+(int a, Half b) { +inline C10_HOST_DEVICE Half operator+(int a, Half b) { return static_cast<Half>(a) + b; } -inline AT_HOST_DEVICE Half operator-(int a, Half b) { +inline C10_HOST_DEVICE Half operator-(int a, Half b) { return static_cast<Half>(a) - b; } -inline AT_HOST_DEVICE Half operator*(int a, Half b) { +inline C10_HOST_DEVICE Half operator*(int a, Half b) { return static_cast<Half>(a) * b; } -inline AT_HOST_DEVICE Half operator/(int a, Half b) { +inline C10_HOST_DEVICE Half operator/(int a, Half b) { return static_cast<Half>(a) / b; } //// Arithmetic with int64_t -inline AT_HOST_DEVICE Half operator+(Half a, int64_t b) { +inline C10_HOST_DEVICE Half operator+(Half a, int64_t b) { return a + static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator-(Half a, int64_t b) { +inline C10_HOST_DEVICE Half operator-(Half a, int64_t b) { return a - static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator*(Half a, int64_t b) { +inline C10_HOST_DEVICE Half operator*(Half a, int64_t b) { return a * static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator/(Half a, int64_t b) { +inline C10_HOST_DEVICE Half operator/(Half a, int64_t b) { return a / static_cast<Half>(b); } -inline AT_HOST_DEVICE Half operator+(int64_t a, Half b) { +inline C10_HOST_DEVICE Half operator+(int64_t a, Half b) { return static_cast<Half>(a) + b; } -inline AT_HOST_DEVICE Half operator-(int64_t a, Half b) { +inline C10_HOST_DEVICE Half operator-(int64_t a, Half b) { return static_cast<Half>(a) - b; } -inline AT_HOST_DEVICE Half operator*(int64_t a, Half b) { +inline C10_HOST_DEVICE Half operator*(int64_t a, Half b) { return static_cast<Half>(a) * b; } -inline AT_HOST_DEVICE Half operator/(int64_t a, Half b) { +inline C10_HOST_DEVICE Half operator/(int64_t a, Half b) { return static_cast<Half>(a) / b; } diff --git a/aten/src/ATen/core/Half.h b/aten/src/ATen/core/Half.h index d5835feed3..1b702cac49 100644 --- a/aten/src/ATen/core/Half.h +++ b/aten/src/ATen/core/Half.h @@ -9,7 +9,7 @@ /// If you are writing a compute bound kernel, you can use the CUDA half /// intrinsics directly on the Half type from device code. -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> #include <c10/util/C++17.h> #include <cmath> @@ -47,18 +47,18 @@ struct alignas(2) Half { // HIP wants __host__ __device__ tag, CUDA does not #ifdef __HIP_PLATFORM_HCC__ - AT_HOST_DEVICE Half() = default; + C10_HOST_DEVICE Half() = default; #else Half() = default; #endif - constexpr AT_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){}; - inline AT_HOST_DEVICE Half(float value); - inline AT_HOST_DEVICE operator float() const; + constexpr C10_HOST_DEVICE Half(unsigned short bits, from_bits_t) : x(bits){}; + inline C10_HOST_DEVICE Half(float value); + inline C10_HOST_DEVICE operator float() const; #if defined(__CUDACC__) || defined(__HIPCC__) - inline AT_HOST_DEVICE Half(const __half& value); - inline AT_HOST_DEVICE operator __half() const; + inline C10_HOST_DEVICE Half(const __half& value); + inline C10_HOST_DEVICE operator __half() const; #endif }; diff --git a/aten/src/ATen/core/IdWrapper.h b/aten/src/ATen/core/IdWrapper.h index 6ca0b934a6..c82db97d7a 100644 --- a/aten/src/ATen/core/IdWrapper.h +++ b/aten/src/ATen/core/IdWrapper.h @@ -1,6 +1,6 @@ #pragma once -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> #include <functional> #include <utility> diff --git a/aten/src/ATen/core/Macros.h b/aten/src/ATen/core/Macros.h index 56840ce7ce..7340f5bb91 100644 --- a/aten/src/ATen/core/Macros.h +++ b/aten/src/ATen/core/Macros.h @@ -1,37 +1,2 @@ #pragma once - -#include <sstream> -#include <string> - #include "c10/macros/Macros.h" - -#if defined(__CUDACC__) || defined(__HIPCC__) -// Designates functions callable from the host (CPU) and the device (GPU) -#define AT_HOST_DEVICE __host__ __device__ -#define AT_DEVICE __device__ -#define AT_HOST __host__ -#else -#define AT_HOST_DEVICE -#define AT_HOST -#define AT_DEVICE -#endif - -#ifdef __HIP_PLATFORM_HCC__ -#define HIP_HOST_DEVICE __host__ __device__ -#else -#define HIP_HOST_DEVICE -#endif - -#if defined(__ANDROID__) -#define AT_ANDROID 1 -#define AT_MOBILE 1 -#elif (defined(__APPLE__) && \ - (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) -#define AT_IOS 1 -#define AT_MOBILE 1 -#elif (defined(__APPLE__) && TARGET_OS_MAC) -#define AT_IOS 1 -#define AT_MOBILE 0 -#else -#define AT_MOBILE 0 -#endif // ANDROID / IOS / MACOS diff --git a/aten/src/ATen/core/OptionsGuard.cpp b/aten/src/ATen/core/OptionsGuard.cpp index 606b4623c4..65d72e1864 100644 --- a/aten/src/ATen/core/OptionsGuard.cpp +++ b/aten/src/ATen/core/OptionsGuard.cpp @@ -7,7 +7,7 @@ namespace at { // In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, // thread_local is not supported. In that case, we don't provide // an OptionsGuard; and force you to pass around options manually. -#if !AT_MOBILE && !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) +#if !C10_MOBILE && !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) DefaultTensorOptions& mutateDefaultTensorOptions() { static thread_local c10::optional<DefaultTensorOptions> options; diff --git a/aten/src/ATen/core/OptionsGuard.h b/aten/src/ATen/core/OptionsGuard.h index f6f714f4a4..641d64cd27 100644 --- a/aten/src/ATen/core/OptionsGuard.h +++ b/aten/src/ATen/core/OptionsGuard.h @@ -1,15 +1,15 @@ #pragma once -#include <ATen/core/TensorOptions.h> #include <ATen/core/DefaultTensorOptions.h> -#include <ATen/core/Macros.h> +#include <ATen/core/TensorOptions.h> +#include <c10/macros/Macros.h> namespace at { /// Returns the current default options. CAFFE2_API const DefaultTensorOptions& getDefaultTensorOptions(); -#if !AT_MOBILE && !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) +#if !C10_MOBILE && !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) /// Get a mutable reference to the current thread local default options. CAFFE2_API DefaultTensorOptions& mutateDefaultTensorOptions(); @@ -38,7 +38,7 @@ struct OptionsGuard { DefaultTensorOptions original_; }; -#else // AT_MOBILE +#else // C10_MOBILE template<typename T = void> struct OptionsGuard { diff --git a/aten/src/ATen/core/SmallVector.h b/aten/src/ATen/core/SmallVector.h index 21b5268c45..787b13fb70 100644 --- a/aten/src/ATen/core/SmallVector.h +++ b/aten/src/ATen/core/SmallVector.h @@ -21,7 +21,7 @@ #pragma once #include <ATen/core/AlignOf.h> -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> #include <algorithm> #include <cassert> diff --git a/aten/src/ATen/core/TensorAccessor.h b/aten/src/ATen/core/TensorAccessor.h index c9f23cbac9..442be6331e 100644 --- a/aten/src/ATen/core/TensorAccessor.h +++ b/aten/src/ATen/core/TensorAccessor.h @@ -1,8 +1,8 @@ #pragma once -#include <cstddef> +#include <c10/macros/Macros.h> #include <stdint.h> -#include <ATen/core/Macros.h> +#include <cstddef> namespace at { @@ -31,18 +31,30 @@ class TensorAccessorBase { public: typedef typename PtrTraits<T>::PtrType PtrType; - AT_HOST_DEVICE TensorAccessorBase(PtrType data_, const int64_t * sizes_, const int64_t * strides_) - : data_(data_), sizes_(sizes_), strides_(strides_) {} - AT_HOST IntList sizes() const { + C10_HOST_DEVICE TensorAccessorBase( + PtrType data_, + const int64_t* sizes_, + const int64_t* strides_) + : data_(data_), sizes_(sizes_), strides_(strides_) {} + C10_HOST IntList sizes() const { return IntList(sizes_,N); } - AT_HOST IntList strides() const { + C10_HOST IntList strides() const { return IntList(strides_,N); } - AT_HOST_DEVICE int64_t stride(int64_t i) const { return strides_[i]; } - AT_HOST_DEVICE int64_t size(int64_t i) const { return sizes_[i]; } - AT_HOST_DEVICE T *data() { return data_; } - AT_HOST_DEVICE const T *data() const { return data_; } + C10_HOST_DEVICE int64_t stride(int64_t i) const { + return strides_[i]; + } + C10_HOST_DEVICE int64_t size(int64_t i) const { + return sizes_[i]; + } + C10_HOST_DEVICE T* data() { + return data_; + } + C10_HOST_DEVICE const T* data() const { + return data_; + } + protected: PtrType data_; const int64_t* sizes_; @@ -58,14 +70,17 @@ class TensorAccessor : public TensorAccessorBase<T,N,PtrTraits> { public: typedef typename PtrTraits<T>::PtrType PtrType; - AT_HOST_DEVICE TensorAccessor(PtrType data_, const int64_t * sizes_, const int64_t * strides_) - : TensorAccessorBase<T,N>(data_,sizes_,strides_) {} + C10_HOST_DEVICE TensorAccessor( + PtrType data_, + const int64_t* sizes_, + const int64_t* strides_) + : TensorAccessorBase<T, N>(data_, sizes_, strides_) {} - AT_HOST_DEVICE TensorAccessor<T,N-1> operator[](int64_t i) { + C10_HOST_DEVICE TensorAccessor<T, N - 1> operator[](int64_t i) { return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); } - AT_HOST_DEVICE const TensorAccessor<T,N-1> operator[](int64_t i) const { + C10_HOST_DEVICE const TensorAccessor<T, N - 1> operator[](int64_t i) const { return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); } }; @@ -75,9 +90,12 @@ class TensorAccessor<T,1,PtrTraits> : public TensorAccessorBase<T,1,PtrTraits> { public: typedef typename PtrTraits<T>::PtrType PtrType; - AT_HOST_DEVICE TensorAccessor(PtrType data_, const int64_t * sizes_, const int64_t * strides_) - : TensorAccessorBase<T,1,PtrTraits>(data_,sizes_,strides_) {} - AT_HOST_DEVICE T & operator[](int64_t i) { + C10_HOST_DEVICE TensorAccessor( + PtrType data_, + const int64_t* sizes_, + const int64_t* strides_) + : TensorAccessorBase<T, 1, PtrTraits>(data_, sizes_, strides_) {} + C10_HOST_DEVICE T& operator[](int64_t i) { return this->data_[this->strides_[0]*i]; } }; @@ -95,14 +113,21 @@ template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPt class PackedTensorAccessorBase { public: typedef typename PtrTraits<T>::PtrType PtrType; - AT_HOST PackedTensorAccessorBase(PtrType data_, const int64_t * sizes_, const int64_t * strides_) - : data_(data_) - { + C10_HOST PackedTensorAccessorBase( + PtrType data_, + const int64_t* sizes_, + const int64_t* strides_) + : data_(data_) { std::copy(sizes_, sizes_ + N, std::begin(this->sizes_)); std::copy(strides_, strides_ + N, std::begin(this->strides_)); } - AT_HOST_DEVICE int64_t stride(int64_t i) const { return strides_[i]; } - AT_HOST_DEVICE int64_t size(int64_t i) const { return sizes_[i]; } + C10_HOST_DEVICE int64_t stride(int64_t i) const { + return strides_[i]; + } + C10_HOST_DEVICE int64_t size(int64_t i) const { + return sizes_[i]; + } + protected: PtrType data_; int64_t sizes_[N]; @@ -114,16 +139,19 @@ class PackedTensorAccessor : public PackedTensorAccessorBase<T,N,PtrTraits> { public: typedef typename PtrTraits<T>::PtrType PtrType; - AT_HOST PackedTensorAccessor(PtrType data_, const int64_t * sizes_, const int64_t * strides_) - : PackedTensorAccessorBase<T,N,PtrTraits>(data_, sizes_, strides_) {}; + C10_HOST PackedTensorAccessor( + PtrType data_, + const int64_t* sizes_, + const int64_t* strides_) + : PackedTensorAccessorBase<T, N, PtrTraits>(data_, sizes_, strides_){}; - AT_DEVICE TensorAccessor<T,N-1> operator[](int64_t i) { + C10_DEVICE TensorAccessor<T, N - 1> operator[](int64_t i) { int64_t* new_sizes = this->sizes_+1; int64_t* new_strides = this->strides_+1; return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i, new_sizes, new_strides); } - AT_DEVICE const TensorAccessor<T,N-1> operator[](int64_t i) const { + C10_DEVICE const TensorAccessor<T, N - 1> operator[](int64_t i) const { int64_t* new_sizes = this->sizes_+1; int64_t* new_strides = this->strides_+1; return TensorAccessor<T,N-1>(this->data_ + this->strides_[0]*i, new_sizes, new_strides); @@ -134,13 +162,16 @@ template<typename T, template <typename U> class PtrTraits> class PackedTensorAccessor<T,1,PtrTraits> : public PackedTensorAccessorBase<T,1,PtrTraits> { public: typedef typename PtrTraits<T>::PtrType PtrType; - AT_HOST PackedTensorAccessor(PtrType data_, const int64_t * sizes_, const int64_t * strides_) - : PackedTensorAccessorBase<T,1,PtrTraits>(data_, sizes_, strides_) {}; + C10_HOST PackedTensorAccessor( + PtrType data_, + const int64_t* sizes_, + const int64_t* strides_) + : PackedTensorAccessorBase<T, 1, PtrTraits>(data_, sizes_, strides_){}; - AT_DEVICE T & operator[](int64_t i) { + C10_DEVICE T& operator[](int64_t i) { return this->data_[this->strides_[0]*i]; } - AT_DEVICE const T& operator[](int64_t i) const { + C10_DEVICE const T& operator[](int64_t i) const { return this->data_[this->strides_[0]*i]; } }; diff --git a/aten/src/ATen/core/TensorTypeId.h b/aten/src/ATen/core/TensorTypeId.h index 4092165b7d..b5382512e9 100644 --- a/aten/src/ATen/core/TensorTypeId.h +++ b/aten/src/ATen/core/TensorTypeId.h @@ -3,7 +3,7 @@ #include <iostream> #include <string> #include "ATen/core/IdWrapper.h" -#include "ATen/core/Macros.h" +#include "c10/macros/Macros.h" namespace at { diff --git a/aten/src/ATen/core/TensorTypeIdRegistration.h b/aten/src/ATen/core/TensorTypeIdRegistration.h index a42051f455..a4dd44d628 100644 --- a/aten/src/ATen/core/TensorTypeIdRegistration.h +++ b/aten/src/ATen/core/TensorTypeIdRegistration.h @@ -8,8 +8,8 @@ * Both must be in the same namespace. */ -#include "ATen/core/Macros.h" #include "ATen/core/TensorTypeId.h" +#include "c10/macros/Macros.h" #include <atomic> #include <mutex> diff --git a/aten/src/ATen/core/UniqueVoidPtr.h b/aten/src/ATen/core/UniqueVoidPtr.h index 59bced8c03..b4acde8957 100644 --- a/aten/src/ATen/core/UniqueVoidPtr.h +++ b/aten/src/ATen/core/UniqueVoidPtr.h @@ -1,7 +1,7 @@ #pragma once #include <memory> -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> namespace at { diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 93512a222b..ff1b7b20c3 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -8,7 +8,7 @@ // To explicitly use interned strings as symbols in your code, you must add // them to this list. -#if !AT_MOBILE +#if !C10_MOBILE #define FORALL_ATEN_BASE_SYMBOLS(_) \ _(aten, RoiPooling2d_backward) \ _(aten, RoiPooling2d_forward) \ diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index bca00cd9c1..031024add7 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -6,11 +6,11 @@ #include <algorithm> #include <ATen/core/aten_interned_strings.h> -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> namespace c10 { -#if !AT_MOBILE +#if !C10_MOBILE #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ _(namespaces, aten) \ diff --git a/aten/src/ATen/core/typeid.h b/aten/src/ATen/core/typeid.h index 211ce19243..21f89389d1 100644 --- a/aten/src/ATen/core/typeid.h +++ b/aten/src/ATen/core/typeid.h @@ -20,7 +20,7 @@ #include "ATen/core/Backtrace.h" #include "ATen/core/Half.h" #include "ATen/core/IdWrapper.h" -#include "ATen/core/Macros.h" +#include "c10/macros/Macros.h" #include "c10/util/C++17.h" #include "c10/util/Exception.h" #include "caffe2/core/macros.h" diff --git a/aten/src/ATen/cuda/Array.h b/aten/src/ATen/cuda/Array.h index 05364aec69..4988210b13 100644 --- a/aten/src/ATen/cuda/Array.h +++ b/aten/src/ATen/cuda/Array.h @@ -2,7 +2,7 @@ // A fixed-size array type usable from CUDA kernels. -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> namespace at { namespace cuda { @@ -10,15 +10,19 @@ template <typename T, int size> struct Array { T data[size]; - AT_HOST_DEVICE T operator[](int i) const { return data[i]; } - AT_HOST_DEVICE T& operator[](int i) { return data[i]; } + C10_HOST_DEVICE T operator[](int i) const { + return data[i]; + } + C10_HOST_DEVICE T& operator[](int i) { + return data[i]; + } - HIP_HOST_DEVICE Array() = default; - HIP_HOST_DEVICE Array(const Array&) = default; - HIP_HOST_DEVICE Array& operator=(const Array&) = default; + C10_HIP_HOST_DEVICE Array() = default; + C10_HIP_HOST_DEVICE Array(const Array&) = default; + C10_HIP_HOST_DEVICE Array& operator=(const Array&) = default; // Fill the array with x. - AT_HOST_DEVICE Array(T x) { + C10_HOST_DEVICE Array(T x) { for (int i = 0; i < size; i++) { data[i] = x; } diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh index 86ac849023..207bbb70bb 100644 --- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh +++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh @@ -2,7 +2,7 @@ #include <array> #include <cstdint> -#include <ATen/core/Macros.h> +#include <c10/macros/Macros.h> #include <ATen/cuda/Array.h> #include <THC/THCIntegerDivider.cuh> @@ -29,7 +29,7 @@ struct OffsetCalculator { } } - AT_HOST_DEVICE offset_type get(uint32_t linear_idx) const { + C10_HOST_DEVICE offset_type get(uint32_t linear_idx) const { offset_type offsets; #pragma unroll for (int arg = 0; arg < NARGS; arg++) { diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 4ffb7a004a..c911a8aff7 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -60,25 +60,25 @@ struct ReduceConfig { return dim3(div_up(num_outputs, step_output), ctas_per_output); } - AT_HOST_DEVICE bool should_warp_reduce() const { + C10_HOST_DEVICE bool should_warp_reduce() const { return input_mult[LANE] != 0; } - AT_HOST_DEVICE bool should_block_reduce() const { + C10_HOST_DEVICE bool should_block_reduce() const { return input_mult[WARP] != 0; } - AT_HOST_DEVICE bool should_global_reduce() const { + C10_HOST_DEVICE bool should_global_reduce() const { return input_mult[CTA] != 0; } - AT_DEVICE bool should_store(int output_idx) const { + C10_DEVICE bool should_store(int output_idx) const { return output_idx < num_outputs && (!should_warp_reduce() || threadIdx.x == 0) && (!should_block_reduce() || threadIdx.y == 0); } - AT_HOST_DEVICE int input_idx() const { + C10_HOST_DEVICE int input_idx() const { int lane = threadIdx.x; int warp = threadIdx.y; int cta2 = blockIdx.y; @@ -87,7 +87,7 @@ struct ReduceConfig { cta2 * input_mult[CTA]); } - AT_HOST_DEVICE int output_idx() const { + C10_HOST_DEVICE int output_idx() const { int lane = threadIdx.x; int warp = threadIdx.y; int cta1 = blockIdx.x; @@ -96,11 +96,11 @@ struct ReduceConfig { cta1 * step_output); } - AT_DEVICE int shared_memory_offset(int offset) const { + C10_DEVICE int shared_memory_offset(int offset) const { return threadIdx.x + (threadIdx.y + offset) * blockDim.x; } - AT_DEVICE int staging_memory_offset(int cta2) const { + C10_DEVICE int staging_memory_offset(int cta2) const { int offset = cta2 + blockIdx.x * gridDim.y; if (!should_warp_reduce()) { offset = threadIdx.x + offset * blockDim.x; @@ -230,7 +230,7 @@ struct ReduceOp { , semaphores(semaphores) { } - AT_DEVICE void run() const { + C10_DEVICE void run() const { int output_idx = config.output_idx(); int input_idx = config.input_idx(); auto base_offsets = output_calc.get(output_idx); @@ -259,7 +259,7 @@ struct ReduceOp { } } - AT_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, int offset) const { + C10_DEVICE Array<scalar_t, vt0> load_inputs(const scalar_t* data, int offset) const { int end = config.num_inputs; int stride = input_calc.strides_[0][0] / sizeof(scalar_t); if (input_calc.dims == 1) { @@ -273,7 +273,7 @@ struct ReduceOp { } } - AT_DEVICE arg_t thread_reduce_once(const scalar_t* data, int offset) const { + C10_DEVICE arg_t thread_reduce_once(const scalar_t* data, int offset) const { auto values = load_inputs(data, offset); arg_t value; @@ -284,7 +284,7 @@ struct ReduceOp { return value; } - AT_DEVICE arg_t thread_reduce(const scalar_t* data) const { + C10_DEVICE arg_t thread_reduce(const scalar_t* data) const { arg_t value = ident; int idx = config.input_idx(); while (idx < config.num_inputs) { @@ -295,7 +295,7 @@ struct ReduceOp { return value; } - AT_DEVICE arg_t warp_reduce(arg_t value) const { + C10_DEVICE arg_t warp_reduce(arg_t value) const { for (int offset = 1; offset < warpSize; offset <<= 1) { arg_t other = WARP_SHFL_DOWN(value, offset); value = op(value, other); @@ -303,7 +303,7 @@ struct ReduceOp { return value; } - AT_DEVICE arg_t block_reduce(arg_t value) const { + C10_DEVICE arg_t block_reduce(arg_t value) const { extern __shared__ char shared_memory[]; arg_t* shared = (arg_t*)shared_memory; shared[config.shared_memory_offset(0)] = value; @@ -319,7 +319,7 @@ struct ReduceOp { return value; } - AT_DEVICE bool mark_block_finished() const { + C10_DEVICE bool mark_block_finished() const { extern __shared__ int is_last_block_done_shared[]; __syncthreads(); @@ -335,7 +335,7 @@ struct ReduceOp { return is_last_block_done; } - AT_DEVICE arg_t global_reduce(arg_t value, scalar_t* out) const { + C10_DEVICE arg_t global_reduce(arg_t value, scalar_t* out) const { arg_t* reduce_buffer = (arg_t*)buffer; bool should_store = config.should_store(config.output_idx()); diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index ff00f62d54..2dcf0062da 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -88,4 +88,39 @@ namespace at {using namespace c10;} #define C10_UNLIKELY(expr) (expr) #endif +#include <sstream> +#include <string> + +#if defined(__CUDACC__) || defined(__HIPCC__) +// Designates functions callable from the host (CPU) and the device (GPU) +#define C10_HOST_DEVICE __host__ __device__ +#define C10_DEVICE __device__ +#define C10_HOST __host__ +#else +#define C10_HOST_DEVICE +#define C10_HOST +#define C10_DEVICE +#endif + +#ifdef __HIP_PLATFORM_HCC__ +#define C10_HIP_HOST_DEVICE __host__ __device__ +#else +#define C10_HIP_HOST_DEVICE +#endif + +#if defined(__ANDROID__) +#define C10_ANDROID 1 +#define C10_MOBILE 1 +#elif ( \ + defined(__APPLE__) && \ + (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) +#define C10_IOS 1 +#define C10_MOBILE 1 +#elif (defined(__APPLE__) && TARGET_OS_MAC) +#define C10_IOS 1 +#define C10_MOBILE 0 +#else +#define C10_MOBILE 0 +#endif // ANDROID / IOS / MACOS + #endif // C10_MACROS_MACROS_H_ |