diff options
author | Edward Z. Yang <ezyang@mit.edu> | 2018-04-28 07:45:02 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-04-28 07:45:02 -0400 |
commit | 4caea64d729728e3304e9c1e97081a6bdd463913 (patch) | |
tree | b3063eae2c0e9edcfe28635b6d1b0427ac9a1e0b /aten | |
parent | 4667983f0f35a7336f2556ce9eeb6511f6bf3ac0 (diff) | |
download | pytorch-4caea64d729728e3304e9c1e97081a6bdd463913.tar.gz pytorch-4caea64d729728e3304e9c1e97081a6bdd463913.tar.bz2 pytorch-4caea64d729728e3304e9c1e97081a6bdd463913.zip |
Make all of TH and THC C++. (#6913)
Changelist:
- Move *.c to *.cpp
- Change includes of ".c" to ".cpp"
- A bunch of cmake configuration modifying CMAKE_C_FLAGS changed
to CMAKE_CXX_FLAGS or add_compile_options, because if you do CMAKE_C_FLAGS it only applies when you compile C code
- Explicitly cast void* to T* in a number of places
- Delete extern "C" { ... } blocks; instead, properly apply TH_API to everything that should have it (TH_API handles extern "C")
- Stop using stdatomic.h, instead, use <atomic>. This resulted in a bunch of placement-new/delete to be "totally properly correct"
- Refactor of THLongStorageView to not have static constructor methods (since it no longer has a copy/move constructor)
- Documentation about how the TH C interface (and extern C business) works
- Note that THD master_worker mode is dead
- C++ headers in TH libraries are given .hpp suffix, to make it less likely that you'll confuse them with the C-compatible headers (now suffixed .h)
- New function THCStream_stream and THCStream_device to project out fields of THCStream instead of accessing fields directly
- New function THStorage_(retainIfLive), which is equivalent to a retain but only if the refcount is greater than zero.
- In general, I tried to avoid using hpp headers outside of ATen/TH. However, there were a few places where I gave up and depended on the headers for my own sanity. See Note [TH abstraction violation] for all the sites where this occurred. All other sites were refactored to use functions
- Some extra Werror fixes (char* versus const char*)
Diffstat (limited to 'aten')
198 files changed, 810 insertions, 775 deletions
diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index e37db92cc2..9b1345bfb7 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -132,7 +132,7 @@ ENDIF() IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5) MESSAGE(STATUS "Found CUDA with FP16 support, compiling with torch.CudaHalfTensor") LIST(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__") - SET(CMAKE_C_FLAGS "-DCUDA_HAS_FP16=1 ${CMAKE_C_FLAGS}") + add_compile_options(-DCUDA_HAS_FP16=1) ELSE(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5) MESSAGE(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor") ENDIF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5) @@ -162,7 +162,7 @@ IF (APPLE AND CMAKE_COMPILER_IS_GNUCC) IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2) MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)") MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP") - SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas") + add_compile_options(-Wno-unknown-pragmas) SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE) ENDIF () ENDIF () @@ -212,18 +212,18 @@ ENDIF() FIND_PACKAGE(ARM) IF (ASIMD_FOUND) MESSAGE(STATUS "asimd/Neon found with compiler flag : -D__NEON__") - SET(CMAKE_C_FLAGS "-D__NEON__ ${CMAKE_C_FLAGS}") + add_compile_options(-D__NEON__) ELSEIF (NEON_FOUND) MESSAGE(STATUS "Neon found with compiler flag : -mfpu=neon -D__NEON__") - SET(CMAKE_C_FLAGS "-mfpu=neon -D__NEON__ ${CMAKE_C_FLAGS}") + add_compile_options(-mfpu=neon -D__NEON__) ENDIF (ASIMD_FOUND) IF (CORTEXA8_FOUND) MESSAGE(STATUS "Cortex-A8 Found with compiler flag : -mcpu=cortex-a8") - SET(CMAKE_C_FLAGS "-mcpu=cortex-a8 -fprefetch-loop-arrays ${CMAKE_C_FLAGS}") + add_compile_options(-mcpu=cortex-a8 -fprefetch-loop-arrays) ENDIF (CORTEXA8_FOUND) IF (CORTEXA9_FOUND) MESSAGE(STATUS "Cortex-A9 Found with compiler flag : -mcpu=cortex-a9") - SET(CMAKE_C_FLAGS "-mcpu=cortex-a9 ${CMAKE_C_FLAGS}") + add_compile_options(-mcpu=cortex-a9) ENDIF (CORTEXA9_FOUND) IF(UNIX) @@ -264,7 +264,7 @@ IF(HAVE_CPUID_H) }" HAVE_GCC_GET_CPUID) ENDIF() IF(HAVE_GCC_GET_CPUID) - SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DHAVE_GCC_GET_CPUID") + add_compile_options(-DHAVE_GCC_GET_CPUID) ENDIF(HAVE_GCC_GET_CPUID) CHECK_C_SOURCE_COMPILES("#include <stdint.h> @@ -282,21 +282,28 @@ CHECK_C_SOURCE_COMPILES("#include <stdint.h> }" NO_GCC_EBX_FPIC_BUG) IF(NOT NO_GCC_EBX_FPIC_BUG) - SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_GCC_GET_CPUID") + add_compile_options(-DUSE_GCC_GET_CPUID) ENDIF(NOT NO_GCC_EBX_FPIC_BUG) FIND_PACKAGE(SSE) # checks SSE, AVX and AVX2 IF(C_SSE2_FOUND) MESSAGE(STATUS "SSE2 Found") - SET(CMAKE_C_FLAGS "${C_SSE2_FLAGS} -DUSE_SSE2 ${CMAKE_C_FLAGS}") + # TODO: Work out correct way to do this. Note that C_SSE2_FLAGS is often + # empty, in which case it expands to " " flag which is bad + SET(CMAKE_C_FLAGS "${C_SSE2_FLAGS} ${CMAKE_C_FLAGS}") + SET(CMAKE_CXX_FLAGS "${C_SSE2_FLAGS} ${CMAKE_CXX_FLAGS}") + add_compile_options(-DUSE_SSE2) ENDIF(C_SSE2_FOUND) IF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND) - SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} -DUSE_SSE4_1 ${C_SSE4_2_FLAGS} -DUSE_SSE4_2 ${CMAKE_C_FLAGS}") + SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} ${C_SSE4_2_FLAGS} ${CMAKE_C_FLAGS}") + SET(CMAKE_CXX_FLAGS "${C_SSE4_1_FLAGS} ${C_SSE4_2_FLAGS} ${CMAKE_CXX_FLAGS}") + add_compile_options(-DUSE_SSE4_1 -DUSE_SSE4_2) ENDIF() IF(C_SSE3_FOUND) MESSAGE(STATUS "SSE3 Found") - SET(CMAKE_C_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_C_FLAGS}") - SET(CMAKE_CXX_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_CXX_FLAGS}") + SET(CMAKE_C_FLAGS "${C_SSE3_FLAGS} ${CMAKE_C_FLAGS}") + SET(CMAKE_CXX_FLAGS "${C_SSE3_FLAGS} ${CMAKE_CXX_FLAGS}") + add_compile_options(-DUSE_SSE3) ENDIF(C_SSE3_FOUND) # we don't set -mavx and -mavx2 flags globally, but only for specific files @@ -304,12 +311,11 @@ ENDIF(C_SSE3_FOUND) # add USE_AVX and USE_AVX2 macro defines IF(C_AVX_FOUND) MESSAGE(STATUS "AVX Found") - SET(CMAKE_C_FLAGS "-DUSE_AVX ${CMAKE_C_FLAGS}") + add_compile_options(-DUSE_AVX) ENDIF(C_AVX_FOUND) IF(C_AVX2_FOUND) MESSAGE(STATUS "AVX2 Found") - SET(CMAKE_C_FLAGS "-DUSE_AVX2 ${CMAKE_C_FLAGS}") - SET(CMAKE_CXX_FLAGS "-DUSE_AVX2 ${CMAKE_CXX_FLAGS}") + add_compile_options(-DUSE_AVX2) ENDIF(C_AVX2_FOUND) CHECK_C_SOURCE_RUNS(" diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 13670aa73e..b064816643 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -20,32 +20,31 @@ ENDIF() # so we need to set these commands here rather than in src/TH IF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND) IF(MSVC) - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast") ELSE(MSVC) - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.cpp PROPERTIES COMPILE_FLAGS "-O3 -ffast-math") ENDIF(MSVC) ENDIF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND) IF(C_AVX_FOUND) IF(MSVC) - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast ${C_AVX_FLAGS}") - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX ${C_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast ${CXX_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX ${CXX_AVX_FLAGS}") ELSE(MSVC) - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${C_AVX_FLAGS}") - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.c PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.cpp PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${CXX_AVX_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "-O3 ${CXX_AVX_FLAGS}") ENDIF(MSVC) ENDIF(C_AVX_FOUND) IF(C_AVX2_FOUND) IF(MSVC) - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX2 ${C_AVX2_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX2 ${CXX_AVX2_FLAGS}") ELSE(MSVC) - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX2_FLAGS}") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "-O3 ${CXX_AVX2_FLAGS}") ENDIF(MSVC) ENDIF(C_AVX2_FOUND) IF(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang") - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/THAtomic.c PROPERTIES COMPILE_FLAGS "-fno-openmp") - SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/THAllocator.c PROPERTIES COMPILE_FLAGS "-fno-openmp") + SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/THAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp") ENDIF() FILE(GLOB cpu_kernel_cpp_in RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cpu/*.cpp") @@ -332,7 +331,7 @@ ENDIF(NOT MSVC) IF(NOT C_HAS_THREAD) MESSAGE(STATUS "Warning: __thread is not supported, generating thread-unsafe code") ELSE(NOT C_HAS_THREAD) - SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTH_HAVE_THREAD") + add_compile_options(-DTH_HAVE_THREAD) ENDIF(NOT C_HAS_THREAD) if(MKLDNN_FOUND) diff --git a/aten/src/ATen/THLongStorageView.h b/aten/src/ATen/THLongStorageView.h index 0209cdf53d..d6c4c8e285 100644 --- a/aten/src/ATen/THLongStorageView.h +++ b/aten/src/ATen/THLongStorageView.h @@ -1,6 +1,7 @@ #pragma once #include "TH/TH.h" +#include "TH/THStorage.hpp" namespace at { @@ -10,34 +11,63 @@ static inline bool is_noelem_tensor_size(ArrayRef<int64_t> size) { return size.size() == 1 && size[0] == 0; } +enum class THLongStorageViewKind { + SIZE, + // noelem_to_empty is to differentiate strides of empty tensors vs scalars. In ATen, both may have strides [1], + // but in TH an empty tensor should have stride [], while a scalar should have stride [1]. + STRIDE_EMPTY_TENSOR, // noelem_to_empty = true + STRIDE_SCALAR, // noelem_to_empty = false + LENGTH, +}; + // make a fake storage out of a size, pointer pair... // used as an argument where THSize and THStride are passed into TH class THLongStorageView { public: + operator THLongStorage*() { + if (storage.size == 0 && zero_dim_to_null) { + return nullptr; + } + return &storage; + } + + /* + // This is done as an enum, and not as these static constructors, as there + // is no move/copy constructor for THLongStorageView + static THLongStorageView makeFromSize(ArrayRef<int64_t> ref) { return THLongStorageView(ref, true, false, false); } - // noelem_to_empty is to differentiate strides of empty tensors vs scalars. In ATen, both may have strides [1], - // but in TH an empty tensor should have stride [], while a scalar should have stride [1]. static THLongStorageView makeFromStride(ArrayRef<int64_t> ref, bool noelem_to_empty) { return THLongStorageView(ref, false, true, noelem_to_empty); } static THLongStorageView makeFromLength(ArrayRef<int64_t> ref) { return THLongStorageView(ref, false, false, false); } - operator THLongStorage*() { - if (storage.size == 0 && zero_dim_to_null) { - return nullptr; - } - return &storage; - } -private: - // zero_dim_to_one converts an empty ArrayRef into [1] - // zero_dim_to_null converts an empty ArrayRef into a null THLongStorage - // noelem_to_empty makes an ArrayRef of [0] into an empty THLongStorage - THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one, bool zero_dim_to_null, bool noelem_to_empty) - : zero_dim_to_null(zero_dim_to_null) + */ + + THLongStorageView(ArrayRef<int64_t> ref, THLongStorageViewKind kind) + : zero_dim_to_null(false) { + // zero_dim_to_one converts an empty ArrayRef into [1] + // zero_dim_to_null converts an empty ArrayRef into a null THLongStorage + // noelem_to_empty makes an ArrayRef of [0] into an empty THLongStorage + bool zero_dim_to_one = false; + bool noelem_to_empty = false; + switch (kind) { + case THLongStorageViewKind::SIZE: + zero_dim_to_one = true; + break; + case THLongStorageViewKind::STRIDE_EMPTY_TENSOR: + zero_dim_to_null = true; + noelem_to_empty = true; + break; + case THLongStorageViewKind::STRIDE_SCALAR: + zero_dim_to_null = true; + break; + case THLongStorageViewKind::LENGTH: + break; + } if(zero_dim_to_one && ref.size() == 0) { // make storage of size 0 actually a 1-length storage with 1 element // so that our 0-dim tensors get allocated as 1-dim inside TH @@ -57,6 +87,7 @@ private: storage.allocator = nullptr; storage.allocatorContext = nullptr; } +private: int64_t one; THLongStorage storage; bool zero_dim_to_null; diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 266b1505e1..36d8139368 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -233,8 +233,14 @@ CHECKED_CAST = { 'THGenerator*': CodeTemplate( 'check_generator<${Backend}Generator>(${arg_name}, &context->defaultGenerator(backend()))'), - 'THSize*': CodeTemplate('THLongStorageView::makeFromSize(${arg_name})'), - 'THStride*': CodeTemplate('THLongStorageView::makeFromStride(${arg_name}, ${noelem_to_empty})'), + # This is a cast done via direct-construction + 'THSize*': CodeTemplate('THLongStorageView ${result_name}(${arg_name}, THLongStorageViewKind::SIZE);'), + # This is a cast done via direct-construction + 'THStride*': + CodeTemplate( + 'THLongStorageView ${result_name}(${arg_name}, ' + '${noelem_to_empty} ? ' + 'THLongStorageViewKind::STRIDE_EMPTY_TENSOR : THLongStorageViewKind::STRIDE_SCALAR);'), 'real': CodeTemplate('${arg_name}.to${ScalarName}()'), 'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'), 'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}, Tensor, ' @@ -242,6 +248,8 @@ CHECKED_CAST = { 'IntList': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})') } +DIRECT_CONSTRUCTION_CHECKED_CAST = {'THSize*', 'THStride*'} + CHECKED_USE = { 'THTensor*': '{}_->tensor', 'THSTensor*': '{}_->tensor', @@ -271,7 +279,7 @@ ALLOC_WRAP = { CONSTANT_REPLACEMENTS = [ ('AS_REAL', '${AS_REAL}'), ('__storage_size.get\\(\\)', - 'THLongStorageView::makeFromLength(static_cast<int64_t>(storage.size()))'), + 'THLongStorageView(static_cast<int64_t>(storage.size()), THLongStorageViewKind::LENGTH)'), ('__last_dim', 'self.ndimension()-1'), ] @@ -1235,13 +1243,21 @@ def create_derived(backend_type_env, declarations): default_init.append(arg['default_init']) noelem_to_empty = 'is_noelem_tensor_size(size)' if 'size' in seen_names else 'false' - check_cast = CHECKED_CAST[arg['type']].substitute( - env, arg_name=arg['name'], arg_pos=count, - null_okay=null_okay, default_init=default_init, - size=arg.get('size'), - noelem_to_empty=noelem_to_empty) - body.append("auto {}_ = {};".format( - arg['name'], check_cast)) + if arg['type'] in DIRECT_CONSTRUCTION_CHECKED_CAST: + body.append(CHECKED_CAST[arg['type']].substitute( + env, arg_name=arg['name'], arg_pos=count, + null_okay=null_okay, default_init=default_init, + size=arg.get('size'), + noelem_to_empty=noelem_to_empty, + result_name=arg['name'] + '_')) + else: + check_cast = CHECKED_CAST[arg['type']].substitute( + env, arg_name=arg['name'], arg_pos=count, + null_okay=null_okay, default_init=default_init, + size=arg.get('size'), + noelem_to_empty=noelem_to_empty) + body.append("auto {}_ = {};".format( + arg['name'], check_cast)) if drop_argument(arg, option) or replace_with_null(arg): body.append( "(void) {}_; //silence unused warning".format(arg['name'])) diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index edece1c1fe..eec7206f13 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -241,10 +241,12 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations if backend == 'CUDA': env['th_headers'] = [ '#include <THC/THC.h>', + '#include <THC/THCTensor.hpp>', '#include <THCUNN/THCUNN.h>', '#undef THNN_', '#undef THCIndexTensor_', '#include <THCS/THCS.h>', + '#include <THCS/THCSTensor.hpp>', '#undef THCIndexTensor_', ] env['extra_cuda_headers'] = ['#include <ATen/cuda/CUDAHalf.cuh>'] @@ -263,9 +265,11 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations else: env['th_headers'] = [ '#include <TH/TH.h>', + '#include <TH/THTensor.hpp>', '#include <THNN/THNN.h>', '#undef THNN_', '#include <THS/THS.h>', + '#include <THS/THSTensor.hpp>', ] env['extra_cuda_headers'] = [] env['THType'] = scalar_name diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 01d8b82a90..6816cfb700 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -13,7 +13,7 @@ #include <functional> #include "TH/THRandom.h" -#include "TH/THGenerator.h" +#include "TH/THGenerator.hpp" #include "TH/THMath.h" namespace { diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu index 42d239d02b..ee22566555 100644 --- a/aten/src/ATen/native/cuda/Distributions.cu +++ b/aten/src/ATen/native/cuda/Distributions.cu @@ -16,11 +16,9 @@ #include "ATen/native/Distributions.h" -#include <TH/THAtomic.h> - #include <THC/THCGeneral.h> #include <THC/THCTensorRandom.h> -#include <THC/THCGenerator.h> +#include <THC/THCGenerator.hpp> #include <THC/THCApply.cuh> #include <THC/THCNumerics.cuh> @@ -32,7 +30,7 @@ THCGenerator* THCRandom_getGenerator(THCState* state); namespace { std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen, uint64_t increment) { auto gen_ = THCRandom_getGenerator(at::globalContext().thc_state); - uint64_t offset = THAtomicAddLong(&gen_->state.philox_seed_offset, increment); + uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment); return std::make_pair(gen_->state.initial_seed, offset); } diff --git a/aten/src/README.md b/aten/src/README.md index aa3002cb82..ebc7cc99d6 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -107,3 +107,38 @@ function call, e.g., `kernel = THTensor_(newContiguous2D)(k_)`. to call `THError` before performing any allocations, since in some cases we sketchily throw a C++ exception and try to recover (in particular, the test suite does this.) + +## The C interface + +Historically, the Torch libraries were implemented in C. Since then, we have slowly +started rewriting bits of pieces of Torch in C++ (usually because there is some +C++ feature which would be really helpful for writing something.) However, +Torch has *always been*, and *will always be* a library that provides a C ABI +interface, even if, at some point in the future, its internal implementation +is entirely done in a C++ library that heavily uses C++ idioms. (At the moment, +all of the source files are C++, but they are mostly C code that happens to be +compiled as C++). + +In order to achieve this, the `TH_API` macro (called `THC_API` in `THC`) plays +a crucial role: it declares a function as having C-linkage, which means that the +C++ compiler doesn't mangle its name and a C client can link against it. + +As a developer, here is what you need to know: + +1. If you add a function to the public API of Torch, you *must* mark it with + `TH_API` or `THC_API` (depending if you are in CPU or CUDA land). + This will ensure it is built with C-linkage (and on Windows, it + will also ensure that the symbol is exported from the DLL; otherwise it + won't be visible.) + +2. C++ features should ONLY be used in `.cpp` and `.hpp` files, and not in + `.h` files. If you need to use a C++ type in a header file, you should + define this in a separate, C++ only header `.hpp`, and declare it opaquely + in the `.h`. Search for `mutex` for an example of this principle being applied. + (This convention is OPPOSITE from the prevailing convention in PyTorch and + ATen, where C++ headers are defined in `.h` files.) + +Arguably, the "C-compatible" headers should live in a separate directory, +distinct from the C++ code. We think this might be a good thing to do +eventually, and would make the code structure more clear, but we have not +done it at the moment. diff --git a/aten/src/TH/CMakeLists.txt b/aten/src/TH/CMakeLists.txt index ce1d5f3d56..6026c58dc4 100644 --- a/aten/src/TH/CMakeLists.txt +++ b/aten/src/TH/CMakeLists.txt @@ -2,18 +2,18 @@ set(extra_src) # IF ANY SIMD FOUND IF(C_AVX2_FOUND OR C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND) - LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/generic/simd/convolve.c) + LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/generic/simd/convolve.cpp) ENDIF(C_AVX2_FOUND OR C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND) # IF SSE4 FOUND IF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND) - LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/generic/simd/convolve5x5_sse.c) + LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/generic/simd/convolve5x5_sse.cpp) ENDIF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND) # IF AVX FOUND IF(C_AVX_FOUND) - LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/vector/AVX.c) - LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/generic/simd/convolve5x5_avx.c) + LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/vector/AVX.cpp) + LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/generic/simd/convolve5x5_avx.cpp) ENDIF(C_AVX_FOUND) IF(C_AVX2_FOUND) @@ -22,26 +22,28 @@ ENDIF(C_AVX2_FOUND) SET(hdr THGeneral.h THHalf.h THAllocator.h THSize.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h - THLapack.h THLogAdd.h THRandom.h THVector.h THAtomic.h ) + THLapack.h THLogAdd.h THRandom.h THVector.h ) -set(ATen_CPU_SRCS ${ATen_CPU_SRCS} - ${CMAKE_CURRENT_SOURCE_DIR}/THGeneral.c - ${CMAKE_CURRENT_SOURCE_DIR}/THHalf.c - ${CMAKE_CURRENT_SOURCE_DIR}/THAllocator.c - ${CMAKE_CURRENT_SOURCE_DIR}/THSize.c - ${CMAKE_CURRENT_SOURCE_DIR}/THStorage.c +set(ATen_TH_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/THGeneral.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THHalf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THAllocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THSize.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THStorage.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/THBlas.c + ${CMAKE_CURRENT_SOURCE_DIR}/THBlas.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THLapack.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/THLogAdd.c + ${CMAKE_CURRENT_SOURCE_DIR}/THLogAdd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THRandom.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/THFile.c - ${CMAKE_CURRENT_SOURCE_DIR}/THDiskFile.c - ${CMAKE_CURRENT_SOURCE_DIR}/THMemoryFile.c - ${CMAKE_CURRENT_SOURCE_DIR}/THAtomic.c + ${CMAKE_CURRENT_SOURCE_DIR}/THFile.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THDiskFile.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/THMemoryFile.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THVector.cpp ${extra_src} -PARENT_SCOPE) + ) +# Remember that PARENT_SCOPE variables are not in the current scope +set(ATen_TH_SRCS ${ATen_TH_SRCS} PARENT_SCOPE) +set(ATen_CPU_SRCS ${ATen_CPU_SRCS} ${ATen_TH_SRCS} PARENT_SCOPE) ###################################################### @@ -83,8 +85,10 @@ INSTALL(FILES THTensorDimApply.h THTensorMacros.h THVector.h - THAtomic.h THHalf.h + THTensor.hpp + THStorage.hpp + THGenerator.hpp DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH") INSTALL(FILES @@ -94,26 +98,29 @@ INSTALL(FILES DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH/vector") INSTALL(FILES - generic/THBlas.c + generic/THBlas.cpp generic/THBlas.h generic/THLapack.cpp generic/THLapack.h - generic/THStorage.c + generic/THStorage.cpp generic/THStorage.h - generic/THStorageCopy.c + generic/THStorageCopy.cpp generic/THStorageCopy.h generic/THTensor.cpp generic/THTensor.h generic/THTensorConv.cpp generic/THTensorConv.h - generic/THTensorCopy.c + generic/THTensorCopy.cpp generic/THTensorCopy.h - generic/THTensorLapack.c + generic/THTensorLapack.cpp generic/THTensorLapack.h - generic/THTensorMath.c + generic/THTensorMath.cpp generic/THTensorMath.h generic/THTensorRandom.cpp generic/THTensorRandom.h generic/THVectorDispatch.cpp generic/THVector.h + # See Note [TH abstraction violation] + generic/THStorage.hpp + generic/THTensor.hpp DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH/generic") diff --git a/aten/src/TH/TH.h b/aten/src/TH/TH.h index 11f208c4b1..08bdde867c 100644 --- a/aten/src/TH/TH.h +++ b/aten/src/TH/TH.h @@ -8,7 +8,6 @@ #include "THLapack.h" #endif -#include "THAtomic.h" #include "THVector.h" #include "THLogAdd.h" #include "THRandom.h" diff --git a/aten/src/TH/THAllocator.c b/aten/src/TH/THAllocator.cpp index 92f3cdaff0..30ae542e79 100644 --- a/aten/src/TH/THAllocator.c +++ b/aten/src/TH/THAllocator.cpp @@ -1,17 +1,9 @@ #include "THAllocator.h" -#include "THAtomic.h" -/* needed for ATOMIC_INT_LOCK_FREE */ -/* cannot go in THAtomic.h because of interactions with OpenMP giving - sorry not implemented errors */ -#if defined(USE_C11_ATOMICS) -#include <stdatomic.h> +#include <atomic> #if ATOMIC_INT_LOCK_FREE == 2 #define TH_ATOMIC_IPC_REFCOUNT 1 #endif -#elif defined(USE_MSC_ATOMICS) || defined(USE_GCC_ATOMICS) -#define TH_ATOMIC_IPC_REFCOUNT 1 -#endif /* stuff for mapped files */ #ifdef _WIN32 @@ -66,14 +58,14 @@ typedef struct { int refcount; } THMapInfo; -char * unknown_filename = "filename not specified"; +const char * unknown_filename = "filename not specified"; #ifdef _WIN32 -char * unknown_eventname = "eventname not specified"; +const char * unknown_eventname = "eventname not specified"; #endif THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int flags) { - THMapAllocatorContext *ctx = THAlloc(sizeof(THMapAllocatorContext)); + THMapAllocatorContext *ctx = static_cast<THMapAllocatorContext*>(THAlloc(sizeof(THMapAllocatorContext))); if (!(flags & TH_ALLOCATOR_MAPPED_SHARED) && !(flags & TH_ALLOCATOR_MAPPED_SHAREDMEM)) flags &= ~TH_ALLOCATOR_MAPPED_NOCREATE; @@ -81,22 +73,24 @@ THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int flags THError("TH_ALLOCATOR_MAPPED_EXCLUSIVE flag requires opening the file " "in shared mode"); - if (filename) { - ctx->filename = THAlloc(strlen(filename)+1); - strcpy(ctx->filename, filename); + if (!filename) { + filename = unknown_filename; + } + ctx->filename = static_cast<char*>(THAlloc(strlen(filename)+1)); + strcpy(ctx->filename, filename); #ifdef _WIN32 - char *suffixname = "_event"; + if (filename == unknown_filename) { + size_t namelen = strlen(unknown_eventname)+1; + ctx->eventname = static_cast<char*>(THAlloc(namelen)); + strcpy(ctx->eventname, unknown_eventname); + } else { + const char *suffixname = "_event"; size_t namelen = strlen(filename)+1+strlen(suffixname); - ctx->eventname = THAlloc(namelen); + ctx->eventname = static_cast<char*>(THAlloc(namelen)); strcpy(ctx->eventname, ctx->filename); strcat(ctx->eventname, suffixname); -#endif - } else { - ctx->filename = unknown_filename; -#ifdef _WIN32 - ctx->eventname = unknown_eventname; -#endif } +#endif ctx->flags = flags; ctx->size = 0; @@ -142,12 +136,10 @@ ptrdiff_t THMapAllocatorContext_size(THMapAllocatorContext *ctx) void THMapAllocatorContext_free(THMapAllocatorContext *ctx) { - if (ctx->filename != unknown_filename) { - THFree(ctx->filename); + THFree(ctx->filename); #ifdef _WIN32 - THFree(ctx->eventname); + THFree(ctx->eventname); #endif - } THFree(ctx); } @@ -178,7 +170,7 @@ static void *_map_alloc(void* ctx_, ptrdiff_t size) if (size == 0) return NULL; - THMapAllocatorContext *ctx = ctx_; + THMapAllocatorContext *ctx = static_cast<THMapAllocatorContext*>(ctx_); void *data = NULL; #ifdef _WIN32 @@ -451,7 +443,7 @@ static void THMapAllocator_free(void* ctx_, void* data) { if (data == NULL) return; - THMapAllocatorContext *ctx = ctx_; + THMapAllocatorContext *ctx = static_cast<THMapAllocatorContext *>(ctx_); #ifdef _WIN32 if ((ctx->flags & TH_ALLOCATOR_MAPPED_KEEPFD) || (ctx->flags & TH_ALLOCATOR_MAPPED_SHAREDMEM)) @@ -514,7 +506,7 @@ static void THMapAllocator_free(void* ctx, void* data) { #if (defined(_WIN32) || defined(HAVE_MMAP)) && defined(TH_ATOMIC_IPC_REFCOUNT) static void * THRefcountedMapAllocator_alloc(void *_ctx, ptrdiff_t size) { - THMapAllocatorContext *ctx = _ctx; + THMapAllocatorContext *ctx = static_cast<THMapAllocatorContext *>(_ctx); if (ctx->flags & TH_ALLOCATOR_MAPPED_FROMFD) THError("THRefcountedMapAllocator doesn't support TH_ALLOCATOR_MAPPED_FROMFD flag"); @@ -544,7 +536,7 @@ static void * THRefcountedMapAllocator_alloc(void *_ctx, ptrdiff_t size) { if (ctx->flags & TH_ALLOCATOR_MAPPED_EXCLUSIVE) map_info->refcount = 1; else - THAtomicIncrementRef(&map_info->refcount); + map_info->refcount++; return (void*)data; } @@ -555,11 +547,11 @@ static void *THRefcountedMapAllocator_realloc(void* ctx, void* ptr, ptrdiff_t si } static void THRefcountedMapAllocator_free(void* ctx_, void* data) { - THMapAllocatorContext *ctx = ctx_; + THMapAllocatorContext *ctx = static_cast<THMapAllocatorContext *>(ctx_); #ifdef _WIN32 THMapInfo *info = (THMapInfo*)(((char*)data) - TH_ALLOC_ALIGNMENT); - if (THAtomicDecrementRef(&info->refcount)) { + if (--info->refcount == 0) { SetEvent(ctx->event); } if(UnmapViewOfFile(((char*)data) - TH_ALLOC_ALIGNMENT) == 0) @@ -567,7 +559,7 @@ static void THRefcountedMapAllocator_free(void* ctx_, void* data) { #else /* _WIN32 */ THMapInfo *info = (THMapInfo*)(((char*)data) - TH_ALLOC_ALIGNMENT); - if (THAtomicDecrementRef(&info->refcount)) { + if (--info->refcount == 0) { #ifdef HAVE_SHM_UNLINK if (shm_unlink(ctx->filename) == -1) THError("could not unlink the shared memory file %s", ctx->filename); @@ -585,13 +577,13 @@ static void THRefcountedMapAllocator_free(void* ctx_, void* data) { void THRefcountedMapAllocator_incref(THMapAllocatorContext *ctx, void *data) { THMapInfo *map_info = (THMapInfo*)(((char*)data) - TH_ALLOC_ALIGNMENT); - THAtomicIncrementRef(&map_info->refcount); + ++map_info->refcount; } int THRefcountedMapAllocator_decref(THMapAllocatorContext *ctx, void *data) { THMapInfo *map_info = (THMapInfo*)(((char*)data) - TH_ALLOC_ALIGNMENT); - return THAtomicDecrementRef(&map_info->refcount); + return --map_info->refcount == 0; } #else diff --git a/aten/src/TH/THAtomic.c b/aten/src/TH/THAtomic.c deleted file mode 100644 index 16f0ddb480..0000000000 --- a/aten/src/TH/THAtomic.c +++ /dev/null @@ -1,265 +0,0 @@ -#include "THAtomic.h" - -/* - Note: I thank Leon Bottou for his useful comments. - Ronan. -*/ - -#if defined(USE_C11_ATOMICS) -#include <stdatomic.h> -#endif - -#if defined(USE_MSC_ATOMICS) -#include <intrin.h> -#include <assert.h> -#endif - -#if !defined(USE_MSC_ATOMICS) && !defined(USE_GCC_ATOMICS) && defined(USE_PTHREAD_ATOMICS) -#include <pthread.h> -static pthread_mutex_t ptm = PTHREAD_MUTEX_INITIALIZER; -#endif - -void THAtomicSet(int32_t volatile *a, int32_t newvalue) -{ -#if defined(USE_C11_ATOMICS) - atomic_store(a, newvalue); -#elif defined(USE_MSC_ATOMICS) - assert(sizeof(int) == sizeof(int32_t)); - _InterlockedExchange((int32_t*)a, newvalue); -#elif defined(USE_GCC_ATOMICS) - __sync_lock_test_and_set(a, newvalue); -#else - int32_t oldvalue; - do { - oldvalue = *a; - } while (!THAtomicCompareAndSwap(a, oldvalue, newvalue)); -#endif -} - -int THAtomicGet(int32_t volatile *a) -{ -#if defined(USE_C11_ATOMICS) - return atomic_load(a); -#else - int32_t value; - do { - value = *a; - } while (!THAtomicCompareAndSwap(a, value, value)); - return value; -#endif -} - -int THAtomicAdd(int32_t volatile *a, int32_t value) -{ -#if defined(USE_C11_ATOMICS) - return atomic_fetch_add(a, value); -#elif defined(USE_MSC_ATOMICS) - return _InterlockedExchangeAdd((int32_t*)a, value); -#elif defined(USE_GCC_ATOMICS) - return __sync_fetch_and_add(a, value); -#else - int32_t oldvalue; - do { - oldvalue = *a; - } while (!THAtomicCompareAndSwap(a, oldvalue, (oldvalue + value))); - return oldvalue; -#endif -} - -void THAtomicIncrementRef(int32_t volatile *a) -{ - THAtomicAdd(a, 1); -} - -int THAtomicDecrementRef(int32_t volatile *a) -{ - return (THAtomicAdd(a, -1) == 1); -} - -int THAtomicCompareAndSwap(int32_t volatile *a, int32_t oldvalue, int32_t newvalue) -{ -#if defined(USE_C11_ATOMICS) - return atomic_compare_exchange_strong(a, &oldvalue, newvalue); -#elif defined(USE_MSC_ATOMICS) - return (_InterlockedCompareExchange((int32_t*)a, (int32_t)newvalue, (int32_t)oldvalue) == (int32_t)oldvalue); -#elif defined(USE_GCC_ATOMICS) - return __sync_bool_compare_and_swap(a, oldvalue, newvalue); -#elif defined(USE_PTHREAD_ATOMICS) - int32_t ret = 0; - pthread_mutex_lock(&ptm); - if(*a == oldvalue) { - *a = newvalue; - ret = 1; - } - pthread_mutex_unlock(&ptm); - return ret; -#else -#warning THAtomic is not thread safe - if(*a == oldvalue) { - *a = newvalue; - return 1; - } - else - return 0; -#endif -} - -void THAtomicSetLong(int64_t volatile *a, int64_t newvalue) -{ -#if defined(USE_C11_ATOMICS) - atomic_store(a, newvalue); -#elif defined(USE_MSC_ATOMICS) - _InterlockedExchange64(a, newvalue); -#elif defined(USE_GCC_ATOMICS) - __sync_lock_test_and_set(a, newvalue); -#else - int64_t oldvalue; - do { - oldvalue = *a; - } while (!THAtomicCompareAndSwapLong(a, oldvalue, newvalue)); -#endif -} - -int64_t THAtomicGetLong(int64_t volatile *a) -{ -#if defined(USE_C11_ATOMICS) - return atomic_load(a); -#else - int64_t value; - do { - value = *a; - } while (!THAtomicCompareAndSwapLong(a, value, value)); - return value; -#endif -} - -int64_t THAtomicAddLong(int64_t volatile *a, int64_t value) -{ -#if defined(USE_C11_ATOMICS) - return atomic_fetch_add(a, value); -#elif defined(USE_MSC_ATOMICS) - return _InterlockedExchangeAdd64(a, value); -#elif defined(USE_GCC_ATOMICS) - return __sync_fetch_and_add(a, value); -#else - int64_t oldvalue; - do { - oldvalue = *a; - } while (!THAtomicCompareAndSwapLong(a, oldvalue, (oldvalue + value))); - return oldvalue; -#endif -} - -int64_t THAtomicCompareAndSwapLong(int64_t volatile *a, int64_t oldvalue, int64_t newvalue) -{ -#if defined(USE_C11_ATOMICS) - return atomic_compare_exchange_strong(a, &oldvalue, newvalue); -#elif defined(USE_MSC_ATOMICS) - return (_InterlockedCompareExchange64(a, newvalue, oldvalue) == oldvalue); -#elif defined(USE_GCC_ATOMICS) - return __sync_bool_compare_and_swap(a, oldvalue, newvalue); -#elif defined(USE_PTHREAD_ATOMICS) - int64_t ret = 0; - pthread_mutex_lock(&ptm); - if(*a == oldvalue) { - *a = newvalue; - ret = 1; - } - pthread_mutex_unlock(&ptm); - return ret; -#else -#warning THAtomic is not thread safe - if(*a == oldvalue) { - *a = newvalue; - return 1; - } - else - return 0; -#endif -} - -void THAtomicSetPtrdiff(ptrdiff_t volatile *a, ptrdiff_t newvalue) -{ -#if defined(USE_C11_ATOMICS) - atomic_store(a, newvalue); -#elif defined(USE_MSC_ATOMICS) -#ifdef _WIN64 - _InterlockedExchange64(a, newvalue); -#else - _InterlockedExchange(a, newvalue); -#endif -#elif defined(USE_GCC_ATOMICS) - __sync_lock_test_and_set(a, newvalue); -#else - ptrdiff_t oldvalue; - do { - oldvalue = *a; - } while (!THAtomicCompareAndSwapPtrdiff(a, oldvalue, newvalue)); -#endif -} - -ptrdiff_t THAtomicGetPtrdiff(ptrdiff_t volatile *a) -{ -#if defined(USE_C11_ATOMICS) - return atomic_load(a); -#else - ptrdiff_t value; - do { - value = *a; - } while (!THAtomicCompareAndSwapPtrdiff(a, value, value)); - return value; -#endif -} - -ptrdiff_t THAtomicAddPtrdiff(ptrdiff_t volatile *a, ptrdiff_t value) -{ -#if defined(USE_C11_ATOMICS) - return atomic_fetch_add(a, value); -#elif defined(USE_MSC_ATOMICS) -#ifdef _WIN64 - return _InterlockedExchangeAdd64(a, value); -#else - return _InterlockedExchangeAdd(a, value); -#endif -#elif defined(USE_GCC_ATOMICS) - return __sync_fetch_and_add(a, value); -#else - ptrdiff_t oldvalue; - do { - oldvalue = *a; - } while (!THAtomicCompareAndSwapPtrdiff(a, oldvalue, (oldvalue + value))); - return oldvalue; -#endif -} - -ptrdiff_t THAtomicCompareAndSwapPtrdiff(ptrdiff_t volatile *a, ptrdiff_t oldvalue, ptrdiff_t newvalue) -{ -#if defined(USE_C11_ATOMICS) - return atomic_compare_exchange_strong(a, &oldvalue, newvalue); -#elif defined(USE_MSC_ATOMICS) -#ifdef _WIN64 - return (_InterlockedCompareExchange64(a, newvalue, oldvalue) == oldvalue); -#else - return (_InterlockedCompareExchange(a, newvalue, oldvalue) == oldvalue); -#endif -#elif defined(USE_GCC_ATOMICS) - return __sync_bool_compare_and_swap(a, oldvalue, newvalue); -#elif defined(USE_PTHREAD_ATOMICS) - ptrdiff_t ret = 0; - pthread_mutex_lock(&ptm); - if(*a == oldvalue) { - *a = newvalue; - ret = 1; - } - pthread_mutex_unlock(&ptm); - return ret; -#else -#warning THAtomic is not thread safe - if(*a == oldvalue) { - *a = newvalue; - return 1; - } - else - return 0; -#endif -} diff --git a/aten/src/TH/THAtomic.h b/aten/src/TH/THAtomic.h deleted file mode 100644 index 8431664b23..0000000000 --- a/aten/src/TH/THAtomic.h +++ /dev/null @@ -1,118 +0,0 @@ -#ifndef TH_ATOMIC_INC -#define TH_ATOMIC_INC - -#include "THGeneral.h" - -/****************************************************************************** - * Atomic operations for TH - * Five backends are integrated: - * - C11 atomic operations - * - MSVC intrinsics - * - GCC intrinsics - * - Pthread if none of the above is available - * - Unsafe mode in none of the above is available - ******************************************************************************/ - - -/****************************************************************************** - * all-purpose functions - ******************************************************************************/ - -/* - * *a = newvalue -*/ -TH_API void THAtomicSet(int32_t volatile *a, int32_t newvalue); - -/* - * return *a -*/ -TH_API int32_t THAtomicGet(int32_t volatile *a); - -/* - * *a += value, - * return previous *a -*/ -TH_API int32_t THAtomicAdd(int32_t volatile *a, int32_t value); - -/* - * check if (*a == oldvalue) - * if true: set *a to newvalue, return 1 - * if false: return 0 -*/ -TH_API int32_t THAtomicCompareAndSwap(int32_t volatile *a, int32_t oldvalue, int32_t newvalue); - - -/****************************************************************************** - * refcounting functions - ******************************************************************************/ - -/* - * *a++ -*/ -TH_API void THAtomicIncrementRef(int32_t volatile *a); - -/* - * *a--, - * return 1 if *a == 0 after the operation, 0 otherwise -*/ -TH_API int32_t THAtomicDecrementRef(int32_t volatile *a); - - - -/****************************************************************************** - * functions for long type - ******************************************************************************/ - -/* - * *a = newvalue -*/ -TH_API void THAtomicSetLong(int64_t volatile *a, int64_t newvalue); - -/* - * return *a -*/ -TH_API int64_t THAtomicGetLong(int64_t volatile *a); - -/* - * *a += value, - * return previous *a -*/ -TH_API int64_t THAtomicAddLong(int64_t volatile *a, int64_t value); - -/* - * check if (*a == oldvalue) - * if true: set *a to newvalue, return 1 - * if false: return 0 -*/ -TH_API int64_t THAtomicCompareAndSwapLong(int64_t volatile *a, int64_t oldvalue, int64_t newvalue); - - - -/****************************************************************************** - * functions for ptrdiff_t type - ******************************************************************************/ - -/* - * *a = newvalue -*/ -TH_API void THAtomicSetPtrdiff(ptrdiff_t volatile *a, ptrdiff_t newvalue); - -/* - * return *a -*/ -TH_API ptrdiff_t THAtomicGetPtrdiff(ptrdiff_t volatile *a); - -/* - * *a += value, - * return previous *a -*/ -TH_API ptrdiff_t THAtomicAddPtrdiff(ptrdiff_t volatile *a, ptrdiff_t value); - -/* - * check if (*a == oldvalue) - * if true: set *a to newvalue, return 1 - * if false: return 0 -*/ -TH_API ptrdiff_t THAtomicCompareAndSwapPtrdiff(ptrdiff_t volatile *a, ptrdiff_t oldvalue, ptrdiff_t newvalue); - -#endif diff --git a/aten/src/TH/THBlas.c b/aten/src/TH/THBlas.cpp index 35618b26a1..7523c9e3cd 100644 --- a/aten/src/TH/THBlas.c +++ b/aten/src/TH/THBlas.cpp @@ -1,4 +1,4 @@ #include "THBlas.h" -#include "generic/THBlas.c" +#include "generic/THBlas.cpp" #include "THGenerateAllTypes.h" diff --git a/aten/src/TH/THDiskFile.c b/aten/src/TH/THDiskFile.cpp index ee06b8bdf8..258ad2cbca 100644 --- a/aten/src/TH/THDiskFile.c +++ b/aten/src/TH/THDiskFile.cpp @@ -105,7 +105,7 @@ size_t fread__(void *ptr, size_t size, size_t nitems, FILE *stream) { \ if(sizeof(TYPE) > 1) \ { \ - char *buffer = THAlloc(sizeof(TYPE)*n); \ + char *buffer = static_cast<char*>(THAlloc(sizeof(TYPE)*n)); \ THDiskFile_reverseMemory(buffer, data, sizeof(TYPE), n); \ nwrite = fwrite(buffer, sizeof(TYPE), n, dfself->handle); \ THFree(buffer); \ @@ -396,7 +396,7 @@ static ssize_t THDiskFile_readLong(THFile *self, int64_t *data, ssize_t n) else /* if(dfself->longSize == 8) */ { int big_endian = !THDiskFile_isLittleEndianCPU(); - int32_t *buffer = THAlloc(8*n); + int32_t *buffer = static_cast<int32_t*>(THAlloc(8*n)); nread = fread__(buffer, 8, n, dfself->handle); ssize_t i; for(i = nread; i > 0; i--) @@ -449,14 +449,14 @@ static ssize_t THDiskFile_writeLong(THFile *self, int64_t *data, ssize_t n) } else { - char *buffer = THAlloc(sizeof(int64_t)*n); + char *buffer = static_cast<char*>(THAlloc(sizeof(int64_t)*n)); THDiskFile_reverseMemory(buffer, data, sizeof(int64_t), n); nwrite = fwrite(buffer, sizeof(int64_t), n, dfself->handle); THFree(buffer); } } else if(dfself->longSize == 4) { - int32_t *buffer = THAlloc(4*n); + int32_t *buffer = static_cast<int32_t*>(THAlloc(4*n)); ssize_t i; for(i = 0; i < n; i++) buffer[i] = (int32_t) data[i]; @@ -468,7 +468,7 @@ static ssize_t THDiskFile_writeLong(THFile *self, int64_t *data, ssize_t n) else /* if(dfself->longSize == 8) */ { int big_endian = !THDiskFile_isLittleEndianCPU(); - int32_t *buffer = THAlloc(8*n); + int32_t *buffer = static_cast<int32_t*>(THAlloc(8*n)); ssize_t i; for(i = 0; i < n; i++) { @@ -517,7 +517,7 @@ static ssize_t THDiskFile_readString(THFile *self, const char *format, char **st if(format[1] == 'a') { - char *p = THAlloc(TBRS_BSZ); + char *p = static_cast<char*>(THAlloc(TBRS_BSZ)); ssize_t total = TBRS_BSZ; ssize_t pos = 0; @@ -526,7 +526,7 @@ static ssize_t THDiskFile_readString(THFile *self, const char *format, char **st if(total-pos == 0) /* we need more space! */ { total += TBRS_BSZ; - p = THRealloc(p, total); + p = static_cast<char*>(THRealloc(p, total)); } pos += fread(p+pos, 1, total-pos, dfself->handle); if (pos < total) /* eof? */ @@ -548,7 +548,7 @@ static ssize_t THDiskFile_readString(THFile *self, const char *format, char **st } else { - char *p = THAlloc(TBRS_BSZ); + char *p = static_cast<char*>(THAlloc(TBRS_BSZ)); ssize_t total = TBRS_BSZ; ssize_t pos = 0; ssize_t size; @@ -558,7 +558,7 @@ static ssize_t THDiskFile_readString(THFile *self, const char *format, char **st if(total-pos <= 1) /* we can only write '\0' in there! */ { total += TBRS_BSZ; - p = THRealloc(p, total); + p = static_cast<char*>(THRealloc(p, total)); } if (fgets(p+pos, (int) (total-pos), dfself->handle) == NULL) /* eof? */ { @@ -677,10 +677,10 @@ THFile *THDiskFile_new(const char *name, const char *mode, int isQuiet) THError("cannot open <%s> in mode %c%c", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); } - self = THAlloc(sizeof(THDiskFile)); + self = static_cast<THDiskFile*>(THAlloc(sizeof(THDiskFile))); self->handle = handle; - self->name = THAlloc(strlen(name)+1); + self->name = static_cast<char*>(THAlloc(strlen(name)+1)); strcpy(self->name, name); self->isNativeEncoding = 1; self->longSize = 0; @@ -781,10 +781,10 @@ THFile *THPipeFile_new(const char *name, const char *mode, int isQuiet) THError("cannot open <%s> in mode %c%c. This might be because eg the executable doesn't exist, but it could also be because you are out of memory.", name, (isReadable ? 'r' : ' '), (isWritable ? 'w' : ' ')); } - self = THAlloc(sizeof(THDiskFile)); + self = static_cast<THDiskFile*>(THAlloc(sizeof(THDiskFile))); self->handle = handle; - self->name = THAlloc(strlen(name)+1); + self->name = static_cast<char*>(THAlloc(strlen(name)+1)); strcpy(self->name, name); self->isNativeEncoding = 1; self->longSize = 0; diff --git a/aten/src/TH/THFile.c b/aten/src/TH/THFile.cpp index 649c8543da..e310715adc 100644 --- a/aten/src/TH/THFile.c +++ b/aten/src/TH/THFile.cpp @@ -1,4 +1,5 @@ #include "THFile.h" +#include "THStorage.hpp" #include "THFilePrivate.h" #define IMPLEMENT_THFILE_RW(TYPEC, TYPE) \ diff --git a/aten/src/TH/THGeneral.c b/aten/src/TH/THGeneral.cpp index 00c209aec2..667d7fbf25 100644 --- a/aten/src/TH/THGeneral.c +++ b/aten/src/TH/THGeneral.cpp @@ -1,5 +1,4 @@ #include "THGeneral.h" -#include "THAtomic.h" #ifdef _OPENMP #include <omp.h> @@ -23,9 +22,9 @@ #ifdef TH_BLAS_MKL // this is the C prototype, while mkl_set_num_threads is the fortran prototype -extern void MKL_Set_Num_Threads(int); +TH_EXTERNC void MKL_Set_Num_Threads(int); // this is the C prototype, while mkl_get_max_threads is the fortran prototype -extern int MKL_Get_Max_Threads(void); +TH_EXTERNC int MKL_Get_Max_Threads(void); #endif /* Torch Error Handling */ diff --git a/aten/src/TH/THGenerator.h b/aten/src/TH/THGenerator.hpp index 6fa4dd1749..f1e6914b69 100644 --- a/aten/src/TH/THGenerator.h +++ b/aten/src/TH/THGenerator.hpp @@ -1,5 +1,7 @@ -#ifndef TH_GENERATOR -#define TH_GENERATOR +#pragma once + +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] #include <mutex> @@ -25,5 +27,3 @@ struct THGenerator { std::mutex mutex; /* mutex for using this generator */ THGeneratorState gen_state; }; - -#endif diff --git a/aten/src/TH/THHalf.c b/aten/src/TH/THHalf.cpp index 1c46c59a99..1c46c59a99 100644 --- a/aten/src/TH/THHalf.c +++ b/aten/src/TH/THHalf.cpp diff --git a/aten/src/TH/THLapack.h b/aten/src/TH/THLapack.h index cf2cee504c..614d15f940 100644 --- a/aten/src/TH/THLapack.h +++ b/aten/src/TH/THLapack.h @@ -21,14 +21,7 @@ if (info < 0) { \ THError(fmt, func, info, ##__VA_ARGS__); \ } -#ifdef __cplusplus -extern "C" { -#endif - #include "generic/THLapack.h" #include "THGenerateAllTypes.h" -#ifdef __cplusplus -} -#endif #endif diff --git a/aten/src/TH/THLogAdd.c b/aten/src/TH/THLogAdd.cpp index 4b14f85402..4b14f85402 100644 --- a/aten/src/TH/THLogAdd.c +++ b/aten/src/TH/THLogAdd.cpp diff --git a/aten/src/TH/THMemoryFile.c b/aten/src/TH/THMemoryFile.cpp index ca86f374e9..b0c2b75cfd 100644 --- a/aten/src/TH/THMemoryFile.c +++ b/aten/src/TH/THMemoryFile.cpp @@ -1,5 +1,7 @@ #include "THMemoryFile.h" +#include "THStorage.hpp" #include "THFilePrivate.h" +#include "THDiskFile.h" #include "stdint.h" #ifndef _WIN32 @@ -353,8 +355,6 @@ READ_WRITE_METHODS(double, Double, nByteWritten = snprintf((char*) mfself->storage->data+mfself->position, mfself->storage->size-mfself->position, "%.17g", data[i]), 1) -int THDiskFile_isLittleEndianCPU(void); - static ssize_t THMemoryFile_readLong(THFile *self, int64_t *data, ssize_t n) { THMemoryFile *mfself = (THMemoryFile*)self; @@ -527,7 +527,7 @@ static ssize_t THMemoryFile_writeLong(THFile *self, int64_t *data, ssize_t n) static int8_t* THMemoryFile_cloneString(const int8_t *str, ssize_t size) { - int8_t *cstr = THAlloc(size); + int8_t *cstr = static_cast<int8_t*>(THAlloc(size)); memcpy(cstr, str, size); return cstr; } @@ -665,7 +665,7 @@ THFile *THMemoryFile_newWithStorage(THCharStorage *storage, const char *mode) storage->data[0] = '\0'; } - mfself = THAlloc(sizeof(THMemoryFile)); + mfself = static_cast<THMemoryFile*>(THAlloc(sizeof(THMemoryFile))); mfself->storage = storage; mfself->size = (storage ? storage->size-1 : 0); diff --git a/aten/src/TH/THRandom.cpp b/aten/src/TH/THRandom.cpp index a5c35fdea7..8755f774f1 100644 --- a/aten/src/TH/THRandom.cpp +++ b/aten/src/TH/THRandom.cpp @@ -1,6 +1,6 @@ #include "THGeneral.h" #include "THRandom.h" -#include "THGenerator.h" +#include "THGenerator.hpp" #ifndef _WIN32 #include <fcntl.h> diff --git a/aten/src/TH/THRandom.h b/aten/src/TH/THRandom.h index b7a827f49e..5460d330d1 100644 --- a/aten/src/TH/THRandom.h +++ b/aten/src/TH/THRandom.h @@ -3,13 +3,10 @@ #include "THGeneral.h" -#ifdef __cplusplus -extern "C" { -#endif #define _MERSENNE_STATE_N 624 #define _MERSENNE_STATE_M 397 -/* Struct definition is moved to THGenerator.h, because THRandom.h +/* Struct definition is moved to THGenerator.hpp, because THRandom.h needs to be C-compatible in order to be included in C FFI extensions. */ typedef struct THGenerator THGenerator; typedef struct THGeneratorState THGeneratorState; @@ -82,7 +79,5 @@ TH_API int THRandom_geometric(THGenerator *_generator, double p); /* Returns true with probability $p$ and false with probability $1-p$ (p > 0). */ TH_API int THRandom_bernoulli(THGenerator *_generator, double p); -#ifdef __cplusplus -} -#endif + #endif diff --git a/aten/src/TH/THSize.c b/aten/src/TH/THSize.cpp index 2eb00393a7..2eb00393a7 100644 --- a/aten/src/TH/THSize.c +++ b/aten/src/TH/THSize.cpp diff --git a/aten/src/TH/THStorage.c b/aten/src/TH/THStorage.cpp index 37df9888e0..4206846077 100644 --- a/aten/src/TH/THStorage.c +++ b/aten/src/TH/THStorage.cpp @@ -1,16 +1,15 @@ -#include "THAtomic.h" -#include "THStorage.h" +#include "THStorage.hpp" -#include "generic/THStorage.c" +#include "generic/THStorage.cpp" #include "THGenerateAllTypes.h" -#include "generic/THStorage.c" +#include "generic/THStorage.cpp" #include "THGenerateHalfType.h" -#include "generic/THStorageCopy.c" +#include "generic/THStorageCopy.cpp" #include "THGenerateAllTypes.h" -#include "generic/THStorageCopy.c" +#include "generic/THStorageCopy.cpp" #include "THGenerateHalfType.h" @@ -56,7 +55,7 @@ int THLongStorage_inferSize2(THLongStorage *output, int64_t *sizesA, int64_t dim THArgCheck(dimsB, 1, "Can't expand empty tensor b"); ptrdiff_t ndim = dimsA > dimsB ? dimsA : dimsB; - int64_t *expandedSizes = THAlloc(sizeof(int64_t)*ndim); + int64_t *expandedSizes = static_cast<int64_t*>(THAlloc(sizeof(int64_t)*ndim)); for (int64_t i = ndim - 1; i >= 0; --i) { int64_t offset = ndim - 1 - i; @@ -92,7 +91,7 @@ int THLongStorage_inferSizeN(THLongStorage *output, int n, int64_t **sizes, int6 ndim = dims[ j ] > ndim ? dims[ j ] : ndim; } - int64_t *expandedSizes = THAlloc(sizeof(int64_t)*ndim); + int64_t *expandedSizes = static_cast<int64_t*>(THAlloc(sizeof(int64_t)*ndim)); for (int64_t i = ndim - 1; i >= 0; --i) { expandedSizes[ i ] = 1; @@ -121,8 +120,8 @@ int THLongStorage_inferExpandGeometry(int64_t *tensorSizes, int64_t *tensorStrid char *error_buffer, int buffer_len) { ptrdiff_t ndim = THLongStorage_size(sizes); - int64_t *expandedSizesCalc = THAlloc(sizeof(int64_t)*ndim); - int64_t *expandedStridesCalc = THAlloc(sizeof(int64_t)*ndim); + int64_t *expandedSizesCalc = static_cast<int64_t*>(THAlloc(sizeof(int64_t)*ndim)); + int64_t *expandedStridesCalc = static_cast<int64_t*>(THAlloc(sizeof(int64_t)*ndim)); // create a new geometry for the tensors for (int64_t i = ndim - 1; i >= 0; --i) { diff --git a/aten/src/TH/THStorage.hpp b/aten/src/TH/THStorage.hpp new file mode 100644 index 0000000000..bc4f6bfa6c --- /dev/null +++ b/aten/src/TH/THStorage.hpp @@ -0,0 +1,14 @@ +#pragma once + +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] + +#include "THStorage.h" + +#include <atomic> + +#include "generic/THStorage.hpp" +#include "THGenerateAllTypes.h" + +#include "generic/THStorage.hpp" +#include "THGenerateHalfType.h" diff --git a/aten/src/TH/THTensor.cpp b/aten/src/TH/THTensor.cpp index bd16b9c61e..04e462ac93 100644 --- a/aten/src/TH/THTensor.cpp +++ b/aten/src/TH/THTensor.cpp @@ -1,8 +1,8 @@ #include <cmath> #include <float.h> -#include "THAtomic.h" -#include "THTensor.h" +#include <atomic> +#include "THTensor.hpp" #include "THVector.h" #include "generic/simd/simd.h" @@ -18,20 +18,20 @@ #include "generic/THTensor.cpp" #include "THGenerateHalfType.h" -#include "generic/THTensorCopy.c" +#include "generic/THTensorCopy.cpp" #include "THGenerateAllTypes.h" -#include "generic/THTensorCopy.c" +#include "generic/THTensorCopy.cpp" #include "THGenerateHalfType.h" #include "generic/THTensorRandom.cpp" #include "THGenerateAllTypes.h" -#include "generic/THTensorMath.c" +#include "generic/THTensorMath.cpp" #include "THGenerateAllTypes.h" #include "generic/THTensorConv.cpp" #include "THGenerateAllTypes.h" -#include "generic/THTensorLapack.c" +#include "generic/THTensorLapack.cpp" #include "THGenerateFloatTypes.h" diff --git a/aten/src/TH/THTensor.h b/aten/src/TH/THTensor.h index 924583b2ae..176eee4c36 100644 --- a/aten/src/TH/THTensor.h +++ b/aten/src/TH/THTensor.h @@ -7,9 +7,6 @@ #define THTensor TH_CONCAT_3(TH,Real,Tensor) #define THTensor_(NAME) TH_CONCAT_4(TH,Real,Tensor_,NAME) -#ifdef __cplusplus -extern "C" { -#endif /* basics */ #include "generic/THTensor.h" #include "THGenerateAllTypes.h" @@ -41,7 +38,4 @@ extern "C" { /* lapack support */ #include "generic/THTensorLapack.h" #include "THGenerateFloatTypes.h" -#ifdef __cplusplus -} -#endif #endif diff --git a/aten/src/TH/THTensor.hpp b/aten/src/TH/THTensor.hpp new file mode 100644 index 0000000000..8b51d2e654 --- /dev/null +++ b/aten/src/TH/THTensor.hpp @@ -0,0 +1,15 @@ +#pragma once + +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] + +#include "THTensor.h" +#include "THStorage.hpp" + +#include <atomic> + +#include "generic/THTensor.hpp" +#include "THGenerateAllTypes.h" + +#include "generic/THTensor.hpp" +#include "THGenerateHalfType.h" diff --git a/aten/src/TH/THVector.cpp b/aten/src/TH/THVector.cpp index 51f398cc09..3460d17f4b 100644 --- a/aten/src/TH/THVector.cpp +++ b/aten/src/TH/THVector.cpp @@ -3,16 +3,16 @@ #include "generic/simd/simd.h" #ifdef __NEON__ -#include "vector/NEON.c" +#include "vector/NEON.cpp" #endif #ifdef __PPC64__ -#include "vector/VSX.c" +#include "vector/VSX.cpp" #endif #if defined(USE_SSE2) || defined(USE_SSE3) || defined(USE_SSSE3) \ || defined(USE_SSE4_1) || defined(USE_SSE4_2) -#include "vector/SSE.c" +#include "vector/SSE.cpp" #endif #if defined(USE_AVX) diff --git a/aten/src/TH/THVector.h b/aten/src/TH/THVector.h index 5b609950e4..8054f648e8 100644 --- a/aten/src/TH/THVector.h +++ b/aten/src/TH/THVector.h @@ -6,14 +6,9 @@ #define THVector_(NAME) TH_CONCAT_4(TH,Real,Vector_,NAME) -#ifdef __cplusplus -extern "C" { -#endif /* We are going to use dynamic dispatch, and want only to generate declarations * of the vector functions */ #include "generic/THVector.h" #include "THGenerateAllTypes.h" -#ifdef __cplusplus -} -#endif + #endif // TH_VECTOR_INC diff --git a/aten/src/TH/generic/THBlas.c b/aten/src/TH/generic/THBlas.cpp index b4f7c113ca..d06ae6a9d8 100644 --- a/aten/src/TH/generic/THBlas.c +++ b/aten/src/TH/generic/THBlas.cpp @@ -1,5 +1,5 @@ #ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/THBlas.c" +#define TH_GENERIC_FILE "generic/THBlas.cpp" #else diff --git a/aten/src/TH/generic/THLapack.h b/aten/src/TH/generic/THLapack.h index b464dd2d2b..fe64daed5c 100644 --- a/aten/src/TH/generic/THLapack.h +++ b/aten/src/TH/generic/THLapack.h @@ -22,19 +22,19 @@ TH_API void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int /* Positive Definite matrices */ /* Cholesky factorization */ -void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info); +TH_API void THLapack_(potrf)(char uplo, int n, real *a, int lda, int *info); /* Matrix inverse based on Cholesky factorization */ -void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info); +TH_API void THLapack_(potri)(char uplo, int n, real *a, int lda, int *info); /* Solve A*X = B with a symmetric positive definite matrix A using the Cholesky factorization */ -void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info); +TH_API void THLapack_(potrs)(char uplo, int n, int nrhs, real *a, int lda, real *b, int ldb, int *info); /* Cholesky factorization with complete pivoting. */ -void THLapack_(pstrf)(char uplo, int n, real *a, int lda, int *piv, int *rank, real tol, real *work, int *info); +TH_API void THLapack_(pstrf)(char uplo, int n, real *a, int lda, int *piv, int *rank, real tol, real *work, int *info); /* QR decomposition */ -void THLapack_(geqrf)(int m, int n, real *a, int lda, real *tau, real *work, int lwork, int *info); +TH_API void THLapack_(geqrf)(int m, int n, real *a, int lda, real *tau, real *work, int lwork, int *info); /* Build Q from output of geqrf */ -void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *work, int lwork, int *info); +TH_API void THLapack_(orgqr)(int m, int n, int k, real *a, int lda, real *tau, real *work, int lwork, int *info); /* Multiply Q with a matrix from output of geqrf */ -void THLapack_(ormqr)(char side, char trans, int m, int n, int k, real *a, int lda, real *tau, real *c, int ldc, real *work, int lwork, int *info); +TH_API void THLapack_(ormqr)(char side, char trans, int m, int n, int k, real *a, int lda, real *tau, real *c, int ldc, real *work, int lwork, int *info); #endif diff --git a/aten/src/TH/generic/THStorage.c b/aten/src/TH/generic/THStorage.cpp index 70c596e630..761cdf38fd 100644 --- a/aten/src/TH/generic/THStorage.c +++ b/aten/src/TH/generic/THStorage.cpp @@ -1,7 +1,9 @@ #ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/THStorage.c" +#define TH_GENERIC_FILE "generic/THStorage.cpp" #else +#include <new> + real* THStorage_(data)(const THStorage *self) { return self->data; @@ -31,10 +33,10 @@ THStorage* THStorage_(newWithAllocator)(ptrdiff_t size, THAllocator *allocator, void *allocatorContext) { - THStorage *storage = THAlloc(sizeof(THStorage)); - storage->data = allocator->malloc(allocatorContext, sizeof(real)*size); + THStorage *storage = static_cast<THStorage*>(THAlloc(sizeof(THStorage))); + storage->data = static_cast<real*>(allocator->malloc(allocatorContext, sizeof(real)*size)); storage->size = size; - storage->refcount = 1; + new (&storage->refcount) std::atomic<int>(1); storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; storage->allocator = allocator; storage->allocatorContext = allocatorContext; @@ -104,7 +106,20 @@ void THStorage_(clearFlag)(THStorage *storage, const char flag) void THStorage_(retain)(THStorage *storage) { if(storage && (storage->flag & TH_STORAGE_REFCOUNTED)) - THAtomicIncrementRef(&storage->refcount); + ++storage->refcount; +} + +int THStorage_(retainIfLive)(THStorage *storage) +{ + // TODO: Check if TH_STORAGE_REFCOUNTED? + int refcount = storage->refcount.load(); + while (refcount > 0) { + if (storage->refcount.compare_exchange_strong(refcount, refcount + 1)) { + return 1; + } + refcount = storage->refcount.load(); + } + return 0; } void THStorage_(free)(THStorage *storage) @@ -112,9 +127,9 @@ void THStorage_(free)(THStorage *storage) if(!storage) return; - if((storage->flag & TH_STORAGE_REFCOUNTED) && (THAtomicGet(&storage->refcount) > 0)) + if((storage->flag & TH_STORAGE_REFCOUNTED) && (storage->refcount.load() > 0)) { - if(THAtomicDecrementRef(&storage->refcount)) + if(--storage->refcount == 0) { if(storage->flag & TH_STORAGE_FREEMEM) { storage->allocator->free(storage->allocatorContext, storage->data); @@ -122,6 +137,7 @@ void THStorage_(free)(THStorage *storage) if(storage->flag & TH_STORAGE_VIEW) { THStorage_(free)(storage->view); } + storage->refcount.~atomic<int>(); THFree(storage); } } @@ -136,7 +152,7 @@ THStorage* THStorage_(newWithData)(real *data, ptrdiff_t size) THStorage* THStorage_(newWithDataAndAllocator)(real* data, ptrdiff_t size, THAllocator* allocator, void* allocatorContext) { - THStorage *storage = THAlloc(sizeof(THStorage)); + THStorage *storage = static_cast<THStorage*>(THAlloc(sizeof(THStorage))); storage->data = data; storage->size = size; storage->refcount = 1; @@ -157,9 +173,9 @@ void THStorage_(resize)(THStorage *storage, ptrdiff_t size) if (size == 0) { storage->data = NULL; } else { - storage->data = storage->allocator->malloc( + storage->data = static_cast<real*>(storage->allocator->malloc( storage->allocatorContext, - sizeof(real)*size); + sizeof(real)*size)); } storage->size = size; if (old_data != NULL) { @@ -173,10 +189,10 @@ void THStorage_(resize)(THStorage *storage, ptrdiff_t size) storage->allocator->free(storage->allocatorContext, old_data); } } else { - storage->data = storage->allocator->realloc( + storage->data = static_cast<real*>(storage->allocator->realloc( storage->allocatorContext, storage->data, - sizeof(real)*size); + sizeof(real)*size)); storage->size = size; } } else { diff --git a/aten/src/TH/generic/THStorage.h b/aten/src/TH/generic/THStorage.h index 3dd214b339..3213205212 100644 --- a/aten/src/TH/generic/THStorage.h +++ b/aten/src/TH/generic/THStorage.h @@ -21,16 +21,8 @@ #define TH_STORAGE_FREEMEM 4 #define TH_STORAGE_VIEW 8 -typedef struct THStorage -{ - real *data; - ptrdiff_t size; - int refcount; - char flag; - THAllocator *allocator; - void *allocatorContext; - struct THStorage *view; -} THStorage; +// Struct definition is moved to THStorage.hpp (so this file stays C compatible) +typedef struct THStorage THStorage; TH_API real* THStorage_(data)(const THStorage*); TH_API ptrdiff_t THStorage_(size)(const THStorage*); @@ -63,6 +55,9 @@ TH_API void THStorage_(clearFlag)(THStorage *storage, const char flag); TH_API void THStorage_(retain)(THStorage *storage); TH_API void THStorage_(swap)(THStorage *storage1, THStorage *storage2); +/* used by StorageSharing */ +TH_API int THStorage_(retainIfLive)(THStorage *storage); + /* might differ with other API (like CUDA) */ TH_API void THStorage_(free)(THStorage *storage); TH_API void THStorage_(resize)(THStorage *storage, ptrdiff_t size); diff --git a/aten/src/TH/generic/THStorage.hpp b/aten/src/TH/generic/THStorage.hpp new file mode 100644 index 0000000000..4d698e976d --- /dev/null +++ b/aten/src/TH/generic/THStorage.hpp @@ -0,0 +1,16 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THStorage.hpp" +#else + +typedef struct THStorage +{ + real *data; + ptrdiff_t size; + std::atomic<int> refcount; + char flag; + THAllocator *allocator; + void *allocatorContext; + struct THStorage *view; +} THStorage; + +#endif diff --git a/aten/src/TH/generic/THStorageCopy.c b/aten/src/TH/generic/THStorageCopy.cpp index ce4b57eaff..30bdd5c7de 100644 --- a/aten/src/TH/generic/THStorageCopy.c +++ b/aten/src/TH/generic/THStorageCopy.cpp @@ -1,5 +1,5 @@ #ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/THStorageCopy.c" +#define TH_GENERIC_FILE "generic/THStorageCopy.cpp" #else void THStorage_(rawCopy)(THStorage *storage, real *src) diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index 1c5f3a6c32..10ee760d10 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -2,6 +2,8 @@ #define TH_GENERIC_FILE "generic/THTensor.cpp" #else +#include <new> + /**** access methods ****/ THStorage *THTensor_(storage)(const THTensor *self) { @@ -697,7 +699,7 @@ ptrdiff_t THTensor_(nElement)(const THTensor *self) void THTensor_(retain)(THTensor *self) { if(self->flag & TH_TENSOR_REFCOUNTED) - THAtomicIncrementRef(&self->refcount); + ++self->refcount; } void THTensor_(free)(THTensor *self) @@ -707,12 +709,13 @@ void THTensor_(free)(THTensor *self) if(self->flag & TH_TENSOR_REFCOUNTED) { - if(THAtomicDecrementRef(&self->refcount)) + if(--self->refcount == 0) { THFree(self->size); THFree(self->stride); if(self->storage) THStorage_(free)(self->storage); + self->refcount.~atomic<int>(); THFree(self); } } @@ -730,7 +733,7 @@ void THTensor_(freeCopyTo)(THTensor *self, THTensor *dst) static void THTensor_(rawInit)(THTensor *self) { - self->refcount = 1; + new (&self->refcount) std::atomic<int>(1); self->storage = THStorage_(new)(); self->storageOffset = 0; self->size = NULL; diff --git a/aten/src/TH/generic/THTensor.h b/aten/src/TH/generic/THTensor.h index 8aafd669c3..6b9bf8dcd6 100644 --- a/aten/src/TH/generic/THTensor.h +++ b/aten/src/TH/generic/THTensor.h @@ -6,21 +6,8 @@ #define TH_TENSOR_REFCOUNTED 1 -typedef struct THTensor -{ - int64_t *size; - int64_t *stride; - int nDimension; - - // Note: storage->size may be greater than the recorded size - // of a tensor - THStorage *storage; - ptrdiff_t storageOffset; - int refcount; - - char flag; - -} THTensor; +// Struct definition moved to THTensor.hpp +typedef struct THTensor THTensor; /**** access methods ****/ diff --git a/aten/src/TH/generic/THTensor.hpp b/aten/src/TH/generic/THTensor.hpp new file mode 100644 index 0000000000..c7bd01d803 --- /dev/null +++ b/aten/src/TH/generic/THTensor.hpp @@ -0,0 +1,21 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/THTensor.hpp" +#else + +typedef struct THTensor +{ + int64_t *size; + int64_t *stride; + int nDimension; + + // Note: storage->size may be greater than the recorded size + // of a tensor + THStorage *storage; + ptrdiff_t storageOffset; + std::atomic<int> refcount; + + char flag; + +} THTensor; + +#endif diff --git a/aten/src/TH/generic/THTensorCopy.c b/aten/src/TH/generic/THTensorCopy.cpp index 675ed8b80c..abdcd5cd91 100644 --- a/aten/src/TH/generic/THTensorCopy.c +++ b/aten/src/TH/generic/THTensorCopy.cpp @@ -1,5 +1,5 @@ #ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/THTensorCopy.c" +#define TH_GENERIC_FILE "generic/THTensorCopy.cpp" #else #ifndef _WIN32 diff --git a/aten/src/TH/generic/THTensorLapack.c b/aten/src/TH/generic/THTensorLapack.cpp index cf866d72c8..66ec1f0280 100644 --- a/aten/src/TH/generic/THTensorLapack.c +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -1,5 +1,5 @@ #ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/THTensorLapack.c" +#define TH_GENERIC_FILE "generic/THTensorLapack.cpp" #else /* diff --git a/aten/src/TH/generic/THTensorMath.c b/aten/src/TH/generic/THTensorMath.cpp index 841ea1a5b8..3f985e031f 100644 --- a/aten/src/TH/generic/THTensorMath.c +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -1,5 +1,5 @@ #ifndef TH_GENERIC_FILE -#define TH_GENERIC_FILE "generic/THTensorMath.c" +#define TH_GENERIC_FILE "generic/THTensorMath.cpp" #else #ifndef NAN @@ -453,7 +453,7 @@ void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index) // Exceptions must not be thrown across OpenMP parallel sections, so we // record the position of the invalid index and throw the exception after the // loop. - int64_t invalidIdxPos = -1; + std::atomic<int64_t> invalidIdxPos(-1); ptrdiff_t i; #pragma omp parallel for if(nIndices > TH_OMP_OVERHEAD_THRESHOLD) private(i) @@ -467,7 +467,8 @@ void THTensor_(take)(THTensor *r_, THTensor *src, THLongTensor *index) dst_data[i] = src_data[THTensor_(dataOffset)(src, idx)]; } } else { - THAtomicCompareAndSwapLong(&invalidIdxPos, -1, i); + int64_t tmp = -1; + invalidIdxPos.compare_exchange_strong(tmp, i); } } diff --git a/aten/src/TH/generic/THTensorRandom.cpp b/aten/src/TH/generic/THTensorRandom.cpp index 8f3f676227..d4c42be55e 100644 --- a/aten/src/TH/generic/THTensorRandom.cpp +++ b/aten/src/TH/generic/THTensorRandom.cpp @@ -2,7 +2,7 @@ #define TH_GENERIC_FILE "generic/THTensorRandom.cpp" #else -#include "THGenerator.h" +#include "THGenerator.hpp" void THTensor_(random)(THTensor *self, THGenerator *_generator) { diff --git a/aten/src/TH/generic/THVector.h b/aten/src/TH/generic/THVector.h index c8ceab4598..887482c528 100644 --- a/aten/src/TH/generic/THVector.h +++ b/aten/src/TH/generic/THVector.h @@ -2,6 +2,7 @@ #define TH_GENERIC_FILE "generic/THVector.h" #else +// Opaque C++ struct struct THGenerator; TH_API void THVector_(fill)(real *x, const real c, const ptrdiff_t n); diff --git a/aten/src/TH/generic/simd/convolve.c b/aten/src/TH/generic/simd/convolve.cpp index 326be70f73..326be70f73 100644 --- a/aten/src/TH/generic/simd/convolve.c +++ b/aten/src/TH/generic/simd/convolve.cpp diff --git a/aten/src/TH/generic/simd/convolve5x5_avx.c b/aten/src/TH/generic/simd/convolve5x5_avx.cpp index 560474ba53..560474ba53 100644 --- a/aten/src/TH/generic/simd/convolve5x5_avx.c +++ b/aten/src/TH/generic/simd/convolve5x5_avx.cpp diff --git a/aten/src/TH/generic/simd/convolve5x5_sse.c b/aten/src/TH/generic/simd/convolve5x5_sse.cpp index 9de9a4a4c0..9de9a4a4c0 100644 --- a/aten/src/TH/generic/simd/convolve5x5_sse.c +++ b/aten/src/TH/generic/simd/convolve5x5_sse.cpp diff --git a/aten/src/TH/vector/AVX.c b/aten/src/TH/vector/AVX.cpp index b7d5dd1d64..b7d5dd1d64 100644 --- a/aten/src/TH/vector/AVX.c +++ b/aten/src/TH/vector/AVX.cpp diff --git a/aten/src/TH/vector/AVX.h b/aten/src/TH/vector/AVX.h index 71f82fd400..505678ae4a 100644 --- a/aten/src/TH/vector/AVX.h +++ b/aten/src/TH/vector/AVX.h @@ -1,28 +1,23 @@ #ifndef TH_AVX_H #define TH_AVX_H +#include "THGeneral.h" #include <stddef.h> -#ifdef __cplusplus -extern "C" { -#endif -void THDoubleVector_copy_AVX(double *y, const double *x, const ptrdiff_t n); -void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n); -void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const ptrdiff_t n); -void THDoubleVector_divs_AVX(double *y, const double *x, const double c, const ptrdiff_t n); -void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y, const ptrdiff_t n); -void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const ptrdiff_t n); -void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n); -void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const ptrdiff_t n); -void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n); -void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n); -void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n); -void THFloatVector_divs_AVX(float *y, const float *x, const float c, const ptrdiff_t n); -void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, const ptrdiff_t n); -void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdiff_t n); -void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, const float c, const ptrdiff_t n); -void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdiff_t n); -#ifdef __cplusplus -} -#endif +TH_API void THDoubleVector_copy_AVX(double *y, const double *x, const ptrdiff_t n); +TH_API void THDoubleVector_fill_AVX(double *x, const double c, const ptrdiff_t n); +TH_API void THDoubleVector_cdiv_AVX(double *z, const double *x, const double *y, const ptrdiff_t n); +TH_API void THDoubleVector_divs_AVX(double *y, const double *x, const double c, const ptrdiff_t n); +TH_API void THDoubleVector_cmul_AVX(double *z, const double *x, const double *y, const ptrdiff_t n); +TH_API void THDoubleVector_muls_AVX(double *y, const double *x, const double c, const ptrdiff_t n); +TH_API void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n); +TH_API void THDoubleVector_adds_AVX(double *y, const double *x, const double c, const ptrdiff_t n); +TH_API void THFloatVector_copy_AVX(float *y, const float *x, const ptrdiff_t n); +TH_API void THFloatVector_fill_AVX(float *x, const float c, const ptrdiff_t n); +TH_API void THFloatVector_cdiv_AVX(float *z, const float *x, const float *y, const ptrdiff_t n); +TH_API void THFloatVector_divs_AVX(float *y, const float *x, const float c, const ptrdiff_t n); +TH_API void THFloatVector_cmul_AVX(float *z, const float *x, const float *y, const ptrdiff_t n); +TH_API void THFloatVector_muls_AVX(float *y, const float *x, const float c, const ptrdiff_t n); +TH_API void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, const float c, const ptrdiff_t n); +TH_API void THFloatVector_adds_AVX(float *y, const float *x, const float c, const ptrdiff_t n); #endif diff --git a/aten/src/TH/vector/AVX2.h b/aten/src/TH/vector/AVX2.h index 06b3b61de2..1c281d8c50 100644 --- a/aten/src/TH/vector/AVX2.h +++ b/aten/src/TH/vector/AVX2.h @@ -1,23 +1,19 @@ #ifndef TH_AVX2_H #define TH_AVX2_H +#include "THGeneral.h" + #include <stdint.h> #include <stddef.h> -#ifdef __cplusplus -extern "C" { -#endif struct THGenerator; -void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n); -void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n); -void THFloatVector_normal_fill_AVX2(float *data, +TH_API void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n); +TH_API void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n); +TH_API void THFloatVector_normal_fill_AVX2(float *data, const int64_t size, struct THGenerator *generator, const float mean, const float stddev); -void THFloatVector_sigmoid_AVX2(float *y, const float *x, const ptrdiff_t n); -#ifdef __cplusplus -} -#endif +TH_API void THFloatVector_sigmoid_AVX2(float *y, const float *x, const ptrdiff_t n); #endif diff --git a/aten/src/TH/vector/NEON.c b/aten/src/TH/vector/NEON.cpp index 3966acefa7..3966acefa7 100644 --- a/aten/src/TH/vector/NEON.c +++ b/aten/src/TH/vector/NEON.cpp diff --git a/aten/src/TH/vector/SSE.c b/aten/src/TH/vector/SSE.cpp index d026935ab0..d026935ab0 100644 --- a/aten/src/TH/vector/SSE.c +++ b/aten/src/TH/vector/SSE.cpp diff --git a/aten/src/TH/vector/VSX.c b/aten/src/TH/vector/VSX.cpp index f01718c0f4..f01718c0f4 100644 --- a/aten/src/TH/vector/VSX.c +++ b/aten/src/TH/vector/VSX.cpp diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index 42e0cfa2ce..ecd9730f44 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -12,7 +12,7 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double) foreach(THC_FILE TensorSort TensorMathCompareT TensorMathPointwise TensorMathCompare TensorMathReduce TensorMasked) if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu") FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu" - "#include \"../THC${THC_FILE}.cuh\"\n#include \"../generic/THC${THC_FILE}.cu\"\n#include \"../THCGenerate${THC_TYPE}Type.h\"\n") + "#include \"../THC${THC_FILE}.cuh\"\n#include \"THCTensor.hpp\"\n#include \"THCStream.hpp\"\n#include \"../generic/THC${THC_FILE}.cu\"\n#include \"../THCGenerate${THC_TYPE}Type.h\"\n") endif() LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}${THC_TYPE}.cu") endforeach() @@ -111,19 +111,24 @@ INSTALL(FILES THCThrustAllocator.cuh THCTensorMode.cuh THCTensorTopK.cuh + THCCachingAllocator.h + THCGenerator.hpp + THCTensor.hpp + THCStream.hpp + THCStorage.hpp DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC") INSTALL(FILES - generic/THCStorage.c + generic/THCStorage.cpp generic/THCStorage.cu generic/THCStorage.h - generic/THCTensor.c + generic/THCTensor.cpp generic/THCTensor.cu generic/THCTensor.h - generic/THCStorageCopy.c + generic/THCStorageCopy.cpp generic/THCStorageCopy.cu generic/THCStorageCopy.h - generic/THCTensorCopy.c + generic/THCTensorCopy.cpp generic/THCTensorCopy.cu generic/THCTensorCopy.h generic/THCTensorMasked.h @@ -159,4 +164,7 @@ INSTALL(FILES generic/THCTensorMode.cu generic/THCTensorTopK.h generic/THCTensorTopK.cu + # See Note [TH abstraction violation] + generic/THCStorage.hpp + generic/THCTensor.hpp DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/THC/generic") diff --git a/aten/src/THC/THCAllocator.c b/aten/src/THC/THCAllocator.cpp index 554a379d7e..554a379d7e 100644 --- a/aten/src/THC/THCAllocator.c +++ b/aten/src/THC/THCAllocator.cpp diff --git a/aten/src/THC/THCCachingAllocator.cpp b/aten/src/THC/THCCachingAllocator.cpp index 35631df8d8..e39d31b143 100644 --- a/aten/src/THC/THCCachingAllocator.cpp +++ b/aten/src/THC/THCCachingAllocator.cpp @@ -1,4 +1,5 @@ #include "THCCachingAllocator.h" +#include "THCStream.hpp" #include <cuda_runtime_api.h> #include <algorithm> diff --git a/aten/src/THC/THCCachingHostAllocator.cpp b/aten/src/THC/THCCachingHostAllocator.cpp index 867fda62bf..122860a9dc 100644 --- a/aten/src/THC/THCCachingHostAllocator.cpp +++ b/aten/src/THC/THCCachingHostAllocator.cpp @@ -1,4 +1,5 @@ #include "THCCachingHostAllocator.h" +#include "THCStream.hpp" #include <cuda_runtime_api.h> #include <deque> diff --git a/aten/src/THC/THCGeneral.cpp b/aten/src/THC/THCGeneral.cpp index 6939178c07..114b967f7d 100644 --- a/aten/src/THC/THCGeneral.cpp +++ b/aten/src/THC/THCGeneral.cpp @@ -1,4 +1,5 @@ #include "THCGeneral.h" +#include "THCStream.hpp" #include "TH.h" #include "THCAllocator.h" #include "THCCachingHostAllocator.h" @@ -774,8 +775,8 @@ cudaError_t THCudaMemGetInfoCached(THCState *state, size_t* freeBytes, size_t* #undef MIN_GLOBAL_SCRATCH_SPACE_PER_SM_STREAM #undef MIN_GLOBAL_SCRATCH_SPACE_PER_DEVICE -#include "THCStorage.c" -#include "THCAllocator.c" +#include "THCStorage.cpp" +#include "THCAllocator.cpp" /* from THCHalf.h */ diff --git a/aten/src/THC/THCGenerator.h b/aten/src/THC/THCGenerator.hpp index 0eeb5f64c5..ea5d1ba347 100644 --- a/aten/src/THC/THCGenerator.h +++ b/aten/src/THC/THCGenerator.hpp @@ -1,6 +1,9 @@ -#ifndef THC_GENERATOR_INC -#define THC_GENERATOR_INC +#pragma once +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] + +#include <atomic> #include <mutex> typedef struct THCGeneratorState { @@ -8,12 +11,10 @@ typedef struct THCGeneratorState { struct mtgp32_kernel_params *kernel_params; int initf; uint64_t initial_seed; - int64_t philox_seed_offset; + std::atomic<int64_t> philox_seed_offset; } THCGeneratorState; struct THCGenerator { std::mutex mutex; /* mutex for using this generator */ THCGeneratorState state; }; - -#endif diff --git a/aten/src/THC/THCStorage.c b/aten/src/THC/THCStorage.cpp index 669efa823e..254cedd4e2 100644 --- a/aten/src/THC/THCStorage.c +++ b/aten/src/THC/THCStorage.cpp @@ -1,8 +1,9 @@ -#include "THCStorage.h" +#include "THCStorage.hpp" #include "THCGeneral.h" -#include "THAtomic.h" #include "THCHalf.h" -#include "generic/THCStorage.c" +#include <new> + +#include "generic/THCStorage.cpp" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCStorage.cu b/aten/src/THC/THCStorage.cu index 5555c6f194..fe7635c00c 100644 --- a/aten/src/THC/THCStorage.cu +++ b/aten/src/THC/THCStorage.cu @@ -1,4 +1,4 @@ -#include "THCStorage.h" +#include "THCStorage.hpp" #include "THCThrustAllocator.cuh" #include <thrust/device_ptr.h> diff --git a/aten/src/THC/THCStorage.hpp b/aten/src/THC/THCStorage.hpp new file mode 100644 index 0000000000..1ca064c7ba --- /dev/null +++ b/aten/src/THC/THCStorage.hpp @@ -0,0 +1,11 @@ +#pragma once + +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] + +#include "THCStorage.h" + +#include <atomic> + +#include "generic/THCStorage.hpp" +#include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCStorageCopy.cpp b/aten/src/THC/THCStorageCopy.cpp index ee9bf8157e..9e42df5e1b 100644 --- a/aten/src/THC/THCStorageCopy.cpp +++ b/aten/src/THC/THCStorageCopy.cpp @@ -1,6 +1,7 @@ #include "THCStorageCopy.h" +#include "THCTensor.hpp" #include "THCTensorCopy.h" -#include "generic/THCStorageCopy.c" +#include "generic/THCStorageCopy.cpp" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCStorageCopy.cu b/aten/src/THC/THCStorageCopy.cu index 56641888bf..8d7c869c12 100644 --- a/aten/src/THC/THCStorageCopy.cu +++ b/aten/src/THC/THCStorageCopy.cu @@ -3,6 +3,8 @@ #include "THCHalf.h" #include "THCTensorCopy.h" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/THCStorageCopy.cu" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCStream.cpp b/aten/src/THC/THCStream.cpp index 49fe680a37..465597c39b 100644 --- a/aten/src/THC/THCStream.cpp +++ b/aten/src/THC/THCStream.cpp @@ -1,8 +1,7 @@ -#include "THCStream.h" +#include "THCStream.hpp" #include <mutex> #include <cuda_runtime_api.h> -#include "THAtomic.h" #define MAX_DEVICES 256 static THCStream default_streams[MAX_DEVICES]; @@ -32,6 +31,9 @@ THC_API THCStream* THCStream_defaultStream(int device) return &default_streams[device]; } +THC_API cudaStream_t THCStream_stream(THCStream* self) { return self->stream; } +THC_API int THCStream_device(THCStream* self) { return self->device; } + THCStream* THCStream_newWithPriority(int flags, int priority) { THCStream* self = (THCStream*) malloc(sizeof(THCStream)); @@ -46,7 +48,7 @@ void THCStream_free(THCStream* self) if (!self || !self->stream) { return; } - if (THAtomicDecrementRef(&self->refcount)) { + if (--self->refcount == 0) { THCudaCheckWarn(cudaStreamDestroy(self->stream)); free(self); } @@ -55,6 +57,6 @@ void THCStream_free(THCStream* self) void THCStream_retain(THCStream* self) { if (self->stream) { - THAtomicIncrementRef(&self->refcount); + self->refcount++; } } diff --git a/aten/src/THC/THCStream.h b/aten/src/THC/THCStream.h index 6ccb057204..4b29685994 100644 --- a/aten/src/THC/THCStream.h +++ b/aten/src/THC/THCStream.h @@ -4,15 +4,11 @@ #include <cuda_runtime_api.h> #include "THCGeneral.h" -struct THCStream -{ - cudaStream_t stream; - int device; - int refcount; -}; - +struct THCStream; THC_API THCStream* THCStream_new(int flags); +THC_API cudaStream_t THCStream_stream(THCStream* self); +THC_API int THCStream_device(THCStream* self); THC_API THCStream* THCStream_defaultStream(int device); THC_API THCStream* THCStream_newWithPriority(int flags, int priority); THC_API void THCStream_free(THCStream* self); diff --git a/aten/src/THC/THCStream.hpp b/aten/src/THC/THCStream.hpp new file mode 100644 index 0000000000..35bb937841 --- /dev/null +++ b/aten/src/THC/THCStream.hpp @@ -0,0 +1,14 @@ +#pragma once + +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] + +#include <atomic> +#include "THCStream.h" + +struct THCStream +{ + cudaStream_t stream; + int device; + std::atomic<int> refcount; +}; diff --git a/aten/src/THC/THCTensor.cpp b/aten/src/THC/THCTensor.cpp index 3bcf69d72b..d749ccb713 100644 --- a/aten/src/THC/THCTensor.cpp +++ b/aten/src/THC/THCTensor.cpp @@ -1,7 +1,8 @@ #include "THCGeneral.h" -#include "THCTensor.h" +#include "THCTensor.hpp" #include "THCTensorCopy.h" -#include "THAtomic.h" -#include "generic/THCTensor.c" +#include <new> + +#include "generic/THCTensor.cpp" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCTensor.cu b/aten/src/THC/THCTensor.cu index 1e6fc20732..34de80f096 100644 --- a/aten/src/THC/THCTensor.cu +++ b/aten/src/THC/THCTensor.cu @@ -1,4 +1,5 @@ -#include "THCTensor.h" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/THCTensor.cu" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCTensor.hpp b/aten/src/THC/THCTensor.hpp new file mode 100644 index 0000000000..b82279df73 --- /dev/null +++ b/aten/src/THC/THCTensor.hpp @@ -0,0 +1,13 @@ +#pragma once + +// STOP!!! Thinking of including this header directly? Please +// read Note [TH abstraction violation] + +#include "THCTensor.h" +#include "THTensor.hpp" +#include "THCStorage.hpp" + +#include <atomic> + +#include "generic/THCTensor.hpp" +#include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCTensorConv.cu b/aten/src/THC/THCTensorConv.cu index 41e6457e61..1963d9d7a7 100644 --- a/aten/src/THC/THCTensorConv.cu +++ b/aten/src/THC/THCTensorConv.cu @@ -2,6 +2,8 @@ #include "THCTensorMath.h" #include "THCTensorCopy.h" #include "THCGeneral.h" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include <stdio.h> /* diff --git a/aten/src/THC/THCTensorCopy.cpp b/aten/src/THC/THCTensorCopy.cpp index 920e78567e..98d25520cf 100644 --- a/aten/src/THC/THCTensorCopy.cpp +++ b/aten/src/THC/THCTensorCopy.cpp @@ -1,5 +1,7 @@ #include "THCTensorCopy.h" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "THCCachingHostAllocator.h" -#include "generic/THCTensorCopy.c" +#include "generic/THCTensorCopy.cpp" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCTensorIndex.cu b/aten/src/THC/THCTensorIndex.cu index 841c772ae0..a52a3ccdfd 100644 --- a/aten/src/THC/THCTensorIndex.cu +++ b/aten/src/THC/THCTensorIndex.cu @@ -12,6 +12,8 @@ #include "THCAtomics.cuh" #include "THCThrustAllocator.cuh" #include "THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include <thrust/device_ptr.h> #include <thrust/sort.h> #include <algorithm> // for std::min diff --git a/aten/src/THC/THCTensorMath.cu b/aten/src/THC/THCTensorMath.cu index 1afa10f050..2a56270792 100644 --- a/aten/src/THC/THCTensorMath.cu +++ b/aten/src/THC/THCTensorMath.cu @@ -5,6 +5,8 @@ #include "THCNumerics.cuh" #include "THCTensorMath.cuh" #include "THCThrustAllocator.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include <thrust/copy.h> #include <thrust/count.h> diff --git a/aten/src/THC/THCTensorMathBlas.cu b/aten/src/THC/THCTensorMathBlas.cu index 0804d641a3..5551b0cce8 100644 --- a/aten/src/THC/THCTensorMathBlas.cu +++ b/aten/src/THC/THCTensorMathBlas.cu @@ -3,6 +3,8 @@ #include "THCBlas.h" #include "THCTensorCopy.h" #include "THCNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/THCTensorMathBlas.cu" #include "THCGenerateAllTypes.h" diff --git a/aten/src/THC/THCTensorMathMagma.cu b/aten/src/THC/THCTensorMathMagma.cu index cac5d73469..4aa6249794 100644 --- a/aten/src/THC/THCTensorMathMagma.cu +++ b/aten/src/THC/THCTensorMathMagma.cu @@ -2,6 +2,8 @@ #include "THCTensorMath.h" #include "THCTensorCopy.h" #include "THCTensorMathMagma.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include <algorithm> #ifdef USE_MAGMA diff --git a/aten/src/THC/THCTensorMathPairwise.cu b/aten/src/THC/THCTensorMathPairwise.cu index a208f29c2d..19434f31f8 100644 --- a/aten/src/THC/THCTensorMathPairwise.cu +++ b/aten/src/THC/THCTensorMathPairwise.cu @@ -5,6 +5,7 @@ #include "THCApply.cuh" #include "THCNumerics.cuh" #include "THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" template <typename T> struct TensorAddConstantOp { diff --git a/aten/src/THC/THCTensorMathReduce.cu b/aten/src/THC/THCTensorMathReduce.cu index c0be9c8127..afaf1113ee 100644 --- a/aten/src/THC/THCTensorMathReduce.cu +++ b/aten/src/THC/THCTensorMathReduce.cu @@ -1,4 +1,5 @@ #include "THCTensorMathReduce.cuh" +#include "THCTensor.hpp" THC_API int THCudaByteTensor_logicalAndAll(THCState *state, THCudaByteTensor *self) { diff --git a/aten/src/THC/THCTensorMode.cu b/aten/src/THC/THCTensorMode.cu index aa6c6284c3..52a5ce2508 100644 --- a/aten/src/THC/THCTensorMode.cu +++ b/aten/src/THC/THCTensorMode.cu @@ -2,6 +2,8 @@ #include "THCThrustAllocator.cuh" #include "THCTensorTypeUtils.cuh" #include "THCReduceApplyUtils.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include <thrust/device_ptr.h> #include <thrust/sort.h> #include <thrust/inner_product.h> diff --git a/aten/src/THC/THCTensorRandom.cpp b/aten/src/THC/THCTensorRandom.cpp index 404b5ff10e..703871bd54 100644 --- a/aten/src/THC/THCTensorRandom.cpp +++ b/aten/src/THC/THCTensorRandom.cpp @@ -1,5 +1,5 @@ #include "THCTensorRandom.h" -#include "THCGenerator.h" +#include "THCGenerator.hpp" #include <random> #include <curand.h> diff --git a/aten/src/THC/THCTensorRandom.cu b/aten/src/THC/THCTensorRandom.cu index 51d282abe2..432138493d 100644 --- a/aten/src/THC/THCTensorRandom.cu +++ b/aten/src/THC/THCTensorRandom.cu @@ -5,7 +5,7 @@ #include "THCTensorMath.h" #include "THCReduceApplyUtils.cuh" #include "THCTensorRandom.cuh" -#include "THCGenerator.h" +#include "THCGenerator.hpp" #include <thrust/functional.h> #include <curand.h> diff --git a/aten/src/THC/generated/THCTensorMaskedByte.cu b/aten/src/THC/generated/THCTensorMaskedByte.cu index 802f873838..87e96d4dd9 100644 --- a/aten/src/THC/generated/THCTensorMaskedByte.cu +++ b/aten/src/THC/generated/THCTensorMaskedByte.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateByteType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedChar.cu b/aten/src/THC/generated/THCTensorMaskedChar.cu index 3fb9fd7cc8..875de0782c 100644 --- a/aten/src/THC/generated/THCTensorMaskedChar.cu +++ b/aten/src/THC/generated/THCTensorMaskedChar.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateCharType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedDouble.cu b/aten/src/THC/generated/THCTensorMaskedDouble.cu index 063de42390..e5151fa9d5 100644 --- a/aten/src/THC/generated/THCTensorMaskedDouble.cu +++ b/aten/src/THC/generated/THCTensorMaskedDouble.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateDoubleType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedFloat.cu b/aten/src/THC/generated/THCTensorMaskedFloat.cu index 08da574a84..5768db0006 100644 --- a/aten/src/THC/generated/THCTensorMaskedFloat.cu +++ b/aten/src/THC/generated/THCTensorMaskedFloat.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateFloatType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedHalf.cu b/aten/src/THC/generated/THCTensorMaskedHalf.cu index caebd6ca65..9aceeb0316 100644 --- a/aten/src/THC/generated/THCTensorMaskedHalf.cu +++ b/aten/src/THC/generated/THCTensorMaskedHalf.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateHalfType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedInt.cu b/aten/src/THC/generated/THCTensorMaskedInt.cu index 1b4d1d516b..de9b1bdb42 100644 --- a/aten/src/THC/generated/THCTensorMaskedInt.cu +++ b/aten/src/THC/generated/THCTensorMaskedInt.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateIntType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedLong.cu b/aten/src/THC/generated/THCTensorMaskedLong.cu index 5fadbba635..a87b9109b0 100644 --- a/aten/src/THC/generated/THCTensorMaskedLong.cu +++ b/aten/src/THC/generated/THCTensorMaskedLong.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateLongType.h" diff --git a/aten/src/THC/generated/THCTensorMaskedShort.cu b/aten/src/THC/generated/THCTensorMaskedShort.cu index e1f68234aa..6da7249afa 100644 --- a/aten/src/THC/generated/THCTensorMaskedShort.cu +++ b/aten/src/THC/generated/THCTensorMaskedShort.cu @@ -1,3 +1,5 @@ #include "../THCTensorMasked.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMasked.cu" #include "../THCGenerateShortType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareByte.cu b/aten/src/THC/generated/THCTensorMathCompareByte.cu index 4312d73eb1..3149e64dea 100644 --- a/aten/src/THC/generated/THCTensorMathCompareByte.cu +++ b/aten/src/THC/generated/THCTensorMathCompareByte.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateByteType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareChar.cu b/aten/src/THC/generated/THCTensorMathCompareChar.cu index 0356a745bb..dc867f9eff 100644 --- a/aten/src/THC/generated/THCTensorMathCompareChar.cu +++ b/aten/src/THC/generated/THCTensorMathCompareChar.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateCharType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareDouble.cu b/aten/src/THC/generated/THCTensorMathCompareDouble.cu index 59e406c944..24035dbc0a 100644 --- a/aten/src/THC/generated/THCTensorMathCompareDouble.cu +++ b/aten/src/THC/generated/THCTensorMathCompareDouble.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateDoubleType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareFloat.cu b/aten/src/THC/generated/THCTensorMathCompareFloat.cu index 2efa6672f1..a2915ca995 100644 --- a/aten/src/THC/generated/THCTensorMathCompareFloat.cu +++ b/aten/src/THC/generated/THCTensorMathCompareFloat.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateFloatType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareHalf.cu b/aten/src/THC/generated/THCTensorMathCompareHalf.cu index d07e6d7646..1eb849bcd4 100644 --- a/aten/src/THC/generated/THCTensorMathCompareHalf.cu +++ b/aten/src/THC/generated/THCTensorMathCompareHalf.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateHalfType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareInt.cu b/aten/src/THC/generated/THCTensorMathCompareInt.cu index d1a58f1aa9..39a3d14409 100644 --- a/aten/src/THC/generated/THCTensorMathCompareInt.cu +++ b/aten/src/THC/generated/THCTensorMathCompareInt.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateIntType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareLong.cu b/aten/src/THC/generated/THCTensorMathCompareLong.cu index ab70999e97..7e3cd25d64 100644 --- a/aten/src/THC/generated/THCTensorMathCompareLong.cu +++ b/aten/src/THC/generated/THCTensorMathCompareLong.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateLongType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareShort.cu b/aten/src/THC/generated/THCTensorMathCompareShort.cu index e264c0c2ce..8d05507146 100644 --- a/aten/src/THC/generated/THCTensorMathCompareShort.cu +++ b/aten/src/THC/generated/THCTensorMathCompareShort.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompare.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompare.cu" #include "../THCGenerateShortType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTByte.cu b/aten/src/THC/generated/THCTensorMathCompareTByte.cu index 3069ea4ea8..538eed3069 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTByte.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTByte.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateByteType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTChar.cu b/aten/src/THC/generated/THCTensorMathCompareTChar.cu index c536fa08e5..350a9fc32a 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTChar.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTChar.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateCharType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTDouble.cu b/aten/src/THC/generated/THCTensorMathCompareTDouble.cu index 65391600e2..6bf16c542d 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTDouble.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTDouble.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateDoubleType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTFloat.cu b/aten/src/THC/generated/THCTensorMathCompareTFloat.cu index f85726032a..b41741cc7c 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTFloat.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTFloat.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateFloatType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTHalf.cu b/aten/src/THC/generated/THCTensorMathCompareTHalf.cu index a3118311c3..3b51ddccfc 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTHalf.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTHalf.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateHalfType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTInt.cu b/aten/src/THC/generated/THCTensorMathCompareTInt.cu index 3168b2b4fe..977de58b67 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTInt.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTInt.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateIntType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTLong.cu b/aten/src/THC/generated/THCTensorMathCompareTLong.cu index 4566960126..02f91d9f2f 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTLong.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTLong.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateLongType.h" diff --git a/aten/src/THC/generated/THCTensorMathCompareTShort.cu b/aten/src/THC/generated/THCTensorMathCompareTShort.cu index 46bf67a01d..bcc8b170c9 100644 --- a/aten/src/THC/generated/THCTensorMathCompareTShort.cu +++ b/aten/src/THC/generated/THCTensorMathCompareTShort.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathCompareT.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathCompareT.cu" #include "../THCGenerateShortType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseByte.cu b/aten/src/THC/generated/THCTensorMathPointwiseByte.cu index 7f26e88a0d..388c4771cf 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseByte.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseByte.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateByteType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseChar.cu b/aten/src/THC/generated/THCTensorMathPointwiseChar.cu index d19694807b..38f2e54ba3 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseChar.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseChar.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateCharType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseDouble.cu b/aten/src/THC/generated/THCTensorMathPointwiseDouble.cu index 2e9ad7248f..c01b8c902d 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseDouble.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseDouble.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateDoubleType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseFloat.cu b/aten/src/THC/generated/THCTensorMathPointwiseFloat.cu index 061bd7034b..fa0a289746 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseFloat.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseFloat.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateFloatType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseHalf.cu b/aten/src/THC/generated/THCTensorMathPointwiseHalf.cu index 42bef21339..ed15732267 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseHalf.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseHalf.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateHalfType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseInt.cu b/aten/src/THC/generated/THCTensorMathPointwiseInt.cu index daa9cae00a..aec442b671 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseInt.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseInt.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateIntType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseLong.cu b/aten/src/THC/generated/THCTensorMathPointwiseLong.cu index d5e38a7c12..6d6aefce0d 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseLong.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseLong.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateLongType.h" diff --git a/aten/src/THC/generated/THCTensorMathPointwiseShort.cu b/aten/src/THC/generated/THCTensorMathPointwiseShort.cu index 6867ce2908..c087a38a14 100644 --- a/aten/src/THC/generated/THCTensorMathPointwiseShort.cu +++ b/aten/src/THC/generated/THCTensorMathPointwiseShort.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathPointwise.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathPointwise.cu" #include "../THCGenerateShortType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceByte.cu b/aten/src/THC/generated/THCTensorMathReduceByte.cu index 3806f4e334..b97c36725e 100644 --- a/aten/src/THC/generated/THCTensorMathReduceByte.cu +++ b/aten/src/THC/generated/THCTensorMathReduceByte.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateByteType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceChar.cu b/aten/src/THC/generated/THCTensorMathReduceChar.cu index 5afe07666b..976da32ab1 100644 --- a/aten/src/THC/generated/THCTensorMathReduceChar.cu +++ b/aten/src/THC/generated/THCTensorMathReduceChar.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateCharType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceDouble.cu b/aten/src/THC/generated/THCTensorMathReduceDouble.cu index e1bb7c4e67..b6d4eeb423 100644 --- a/aten/src/THC/generated/THCTensorMathReduceDouble.cu +++ b/aten/src/THC/generated/THCTensorMathReduceDouble.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateDoubleType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceFloat.cu b/aten/src/THC/generated/THCTensorMathReduceFloat.cu index d0fdd5d49a..f341f41a8a 100644 --- a/aten/src/THC/generated/THCTensorMathReduceFloat.cu +++ b/aten/src/THC/generated/THCTensorMathReduceFloat.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateFloatType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceHalf.cu b/aten/src/THC/generated/THCTensorMathReduceHalf.cu index f4d9d99784..77bab66664 100644 --- a/aten/src/THC/generated/THCTensorMathReduceHalf.cu +++ b/aten/src/THC/generated/THCTensorMathReduceHalf.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateHalfType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceInt.cu b/aten/src/THC/generated/THCTensorMathReduceInt.cu index 98dd6a4031..18f6c2284f 100644 --- a/aten/src/THC/generated/THCTensorMathReduceInt.cu +++ b/aten/src/THC/generated/THCTensorMathReduceInt.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateIntType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceLong.cu b/aten/src/THC/generated/THCTensorMathReduceLong.cu index 6c47b5d90f..87a0ef7056 100644 --- a/aten/src/THC/generated/THCTensorMathReduceLong.cu +++ b/aten/src/THC/generated/THCTensorMathReduceLong.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateLongType.h" diff --git a/aten/src/THC/generated/THCTensorMathReduceShort.cu b/aten/src/THC/generated/THCTensorMathReduceShort.cu index de2117a2ee..49ce082beb 100644 --- a/aten/src/THC/generated/THCTensorMathReduceShort.cu +++ b/aten/src/THC/generated/THCTensorMathReduceShort.cu @@ -1,3 +1,5 @@ #include "../THCTensorMathReduce.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorMathReduce.cu" #include "../THCGenerateShortType.h" diff --git a/aten/src/THC/generated/THCTensorSortByte.cu b/aten/src/THC/generated/THCTensorSortByte.cu index 6103c4850b..cacebcb7f1 100644 --- a/aten/src/THC/generated/THCTensorSortByte.cu +++ b/aten/src/THC/generated/THCTensorSortByte.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateByteType.h" diff --git a/aten/src/THC/generated/THCTensorSortChar.cu b/aten/src/THC/generated/THCTensorSortChar.cu index bf10336fcc..774c406f6c 100644 --- a/aten/src/THC/generated/THCTensorSortChar.cu +++ b/aten/src/THC/generated/THCTensorSortChar.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateCharType.h" diff --git a/aten/src/THC/generated/THCTensorSortDouble.cu b/aten/src/THC/generated/THCTensorSortDouble.cu index 577af85804..b06144fff2 100644 --- a/aten/src/THC/generated/THCTensorSortDouble.cu +++ b/aten/src/THC/generated/THCTensorSortDouble.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateDoubleType.h" diff --git a/aten/src/THC/generated/THCTensorSortFloat.cu b/aten/src/THC/generated/THCTensorSortFloat.cu index dd84b46fd8..f9c07bcbdd 100644 --- a/aten/src/THC/generated/THCTensorSortFloat.cu +++ b/aten/src/THC/generated/THCTensorSortFloat.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateFloatType.h" diff --git a/aten/src/THC/generated/THCTensorSortHalf.cu b/aten/src/THC/generated/THCTensorSortHalf.cu index e2025f2428..85e381dca7 100644 --- a/aten/src/THC/generated/THCTensorSortHalf.cu +++ b/aten/src/THC/generated/THCTensorSortHalf.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateHalfType.h" diff --git a/aten/src/THC/generated/THCTensorSortInt.cu b/aten/src/THC/generated/THCTensorSortInt.cu index af7a153a82..fe5e8f448b 100644 --- a/aten/src/THC/generated/THCTensorSortInt.cu +++ b/aten/src/THC/generated/THCTensorSortInt.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateIntType.h" diff --git a/aten/src/THC/generated/THCTensorSortLong.cu b/aten/src/THC/generated/THCTensorSortLong.cu index c65ca268d7..ae7d94f1fd 100644 --- a/aten/src/THC/generated/THCTensorSortLong.cu +++ b/aten/src/THC/generated/THCTensorSortLong.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateLongType.h" diff --git a/aten/src/THC/generated/THCTensorSortShort.cu b/aten/src/THC/generated/THCTensorSortShort.cu index 03e1a9a725..cbcbad2b18 100644 --- a/aten/src/THC/generated/THCTensorSortShort.cu +++ b/aten/src/THC/generated/THCTensorSortShort.cu @@ -1,3 +1,5 @@ #include "../THCTensorSort.cuh" +#include "THCTensor.hpp" +#include "THCStream.hpp" #include "../generic/THCTensorSort.cu" #include "../THCGenerateShortType.h" diff --git a/aten/src/THC/generic/THCStorage.c b/aten/src/THC/generic/THCStorage.cpp index b0e64d1e1d..b371b82352 100644 --- a/aten/src/THC/generic/THCStorage.c +++ b/aten/src/THC/generic/THCStorage.cpp @@ -1,5 +1,5 @@ #ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "generic/THCStorage.c" +#define THC_GENERIC_FILE "generic/THCStorage.cpp" #else real* THCStorage_(data)(THCState *state, const THCStorage *self) @@ -61,7 +61,7 @@ THCStorage* THCStorage_(newWithAllocator)(THCState *state, ptrdiff_t size, THCStorage *storage = (THCStorage*)THAlloc(sizeof(THCStorage)); memset(storage, 0, sizeof(THCStorage)); - storage->refcount = 1; + new (&storage->refcount) std::atomic<int>(1); storage->flag = TH_STORAGE_REFCOUNTED | TH_STORAGE_RESIZABLE | TH_STORAGE_FREEMEM; storage->allocator = allocator; storage->allocatorContext = allocatorContext; @@ -169,7 +169,20 @@ void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char fla void THCStorage_(retain)(THCState *state, THCStorage *self) { if(self && (self->flag & TH_STORAGE_REFCOUNTED)) - THAtomicIncrementRef(&self->refcount); + self->refcount++; +} + +int THCStorage_(retainIfLive)(THCState *state, THCStorage *storage) +{ + // TODO: Check if THC_STORAGE_REFCOUNTED? + int refcount = storage->refcount.load(); + while (refcount > 0) { + if (storage->refcount.compare_exchange_strong(refcount, refcount + 1)) { + return 1; + } + refcount = storage->refcount.load(); + } + return 0; } void THCStorage_(free)(THCState *state, THCStorage *self) @@ -177,7 +190,7 @@ void THCStorage_(free)(THCState *state, THCStorage *self) if(!(self->flag & TH_STORAGE_REFCOUNTED)) return; - if (THAtomicDecrementRef(&self->refcount)) + if (--self->refcount == 0) { if(self->flag & TH_STORAGE_FREEMEM) { THCudaCheck( @@ -186,6 +199,7 @@ void THCStorage_(free)(THCState *state, THCStorage *self) if(self->flag & TH_STORAGE_VIEW) { THCStorage_(free)(state, self->view); } + self->refcount.~atomic<int>(); THFree(self); } } diff --git a/aten/src/THC/generic/THCStorage.h b/aten/src/THC/generic/THCStorage.h index e768ec6ff1..f3a936c781 100644 --- a/aten/src/THC/generic/THCStorage.h +++ b/aten/src/THC/generic/THCStorage.h @@ -6,18 +6,7 @@ #define TH_STORAGE_RESIZABLE 2 #define TH_STORAGE_FREEMEM 4 -typedef struct THCStorage -{ - real *data; - ptrdiff_t size; - int refcount; - char flag; - THCDeviceAllocator *allocator; - void *allocatorContext; - struct THCStorage *view; - int device; -} THCStorage; - +typedef struct THCStorage THCStorage; THC_API real* THCStorage_(data)(THCState *state, const THCStorage*); THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*); @@ -51,6 +40,9 @@ THC_API void THCStorage_(setFlag)(THCState *state, THCStorage *storage, const ch THC_API void THCStorage_(clearFlag)(THCState *state, THCStorage *storage, const char flag); THC_API void THCStorage_(retain)(THCState *state, THCStorage *storage); +/* used by StorageSharing */ +THC_API int THCStorage_(retainIfLive)(THCState *state, THCStorage *storage); + THC_API void THCStorage_(free)(THCState *state, THCStorage *storage); THC_API void THCStorage_(resize)(THCState *state, THCStorage *storage, ptrdiff_t size); THC_API void THCStorage_(fill)(THCState *state, THCStorage *storage, real value); diff --git a/aten/src/THC/generic/THCStorage.hpp b/aten/src/THC/generic/THCStorage.hpp new file mode 100644 index 0000000000..871f7690ef --- /dev/null +++ b/aten/src/THC/generic/THCStorage.hpp @@ -0,0 +1,17 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/THCStorage.hpp" +#else + +typedef struct THCStorage +{ + real *data; + ptrdiff_t size; + std::atomic<int> refcount; + char flag; + THCDeviceAllocator *allocator; + void *allocatorContext; + struct THCStorage *view; + int device; +} THCStorage; + +#endif diff --git a/aten/src/THC/generic/THCStorageCopy.c b/aten/src/THC/generic/THCStorageCopy.cpp index ac2e0b05f7..9352d1b997 100644 --- a/aten/src/THC/generic/THCStorageCopy.c +++ b/aten/src/THC/generic/THCStorageCopy.cpp @@ -1,5 +1,5 @@ #ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "generic/THCStorageCopy.c" +#define THC_GENERIC_FILE "generic/THCStorageCopy.cpp" #else void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *src) diff --git a/aten/src/THC/generic/THCTensor.c b/aten/src/THC/generic/THCTensor.cpp index f94d8023e6..3a127fffaa 100644 --- a/aten/src/THC/generic/THCTensor.c +++ b/aten/src/THC/generic/THCTensor.cpp @@ -1,5 +1,5 @@ #ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "generic/THCTensor.c" +#define THC_GENERIC_FILE "generic/THCTensor.cpp" #else /**** access methods ****/ @@ -695,7 +695,7 @@ ptrdiff_t THCTensor_(nElement)(THCState *state, const THCTensor *self) void THCTensor_(retain)(THCState *state, THCTensor *self) { if(self->flag & TH_TENSOR_REFCOUNTED) - THAtomicIncrementRef(&self->refcount); + self->refcount++; } void THCTensor_(free)(THCState *state, THCTensor *self) @@ -705,12 +705,13 @@ void THCTensor_(free)(THCState *state, THCTensor *self) if(self->flag & TH_TENSOR_REFCOUNTED) { - if(THAtomicDecrementRef(&self->refcount)) + if(--self->refcount == 0) { THFree(self->size); THFree(self->stride); if(self->storage) THCStorage_(free)(state, self->storage); + self->refcount.~atomic<int>(); THFree(self); } } @@ -728,7 +729,7 @@ void THCTensor_(freeCopyTo)(THCState *state, THCTensor *self, THCTensor *dst) static void THCTensor_(rawInit)(THCState *state, THCTensor *self) { - self->refcount = 1; + new (&self->refcount) std::atomic<int>(1); self->storage = THCStorage_(new)(state); self->storageOffset = 0; self->size = NULL; diff --git a/aten/src/THC/generic/THCTensor.h b/aten/src/THC/generic/THCTensor.h index d6f5bbaaf0..e2ccf50eed 100644 --- a/aten/src/THC/generic/THCTensor.h +++ b/aten/src/THC/generic/THCTensor.h @@ -4,19 +4,7 @@ #define TH_TENSOR_REFCOUNTED 1 -typedef struct THCTensor -{ - int64_t *size; - int64_t *stride; - int nDimension; - - THCStorage *storage; - ptrdiff_t storageOffset; - int refcount; - - char flag; - -} THCTensor; +typedef struct THCTensor THCTensor; /**** access methods ****/ diff --git a/aten/src/THC/generic/THCTensor.hpp b/aten/src/THC/generic/THCTensor.hpp new file mode 100644 index 0000000000..ebffb56f47 --- /dev/null +++ b/aten/src/THC/generic/THCTensor.hpp @@ -0,0 +1,19 @@ +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/THCTensor.hpp" +#else + +typedef struct THCTensor +{ + int64_t *size; + int64_t *stride; + int nDimension; + + THCStorage *storage; + ptrdiff_t storageOffset; + std::atomic<int> refcount; + + char flag; + +} THCTensor; + +#endif diff --git a/aten/src/THC/generic/THCTensorCopy.c b/aten/src/THC/generic/THCTensorCopy.cpp index f28f8b81ce..21662935d2 100644 --- a/aten/src/THC/generic/THCTensorCopy.c +++ b/aten/src/THC/generic/THCTensorCopy.cpp @@ -1,5 +1,5 @@ #ifndef THC_GENERIC_FILE -#define THC_GENERIC_FILE "generic/THCTensorCopy.c" +#define THC_GENERIC_FILE "generic/THCTensorCopy.cpp" #else /* specific methods */ diff --git a/aten/src/THCS/THCSTensor.cpp b/aten/src/THCS/THCSTensor.cpp index b6cae0ef84..83de99b339 100644 --- a/aten/src/THCS/THCSTensor.cpp +++ b/aten/src/THCS/THCSTensor.cpp @@ -1,4 +1,4 @@ -#include "THCSTensor.h" +#include "THCSTensor.hpp" #include "generic/THCSTensor.cpp" #include "THCSGenerateAllTypes.h" diff --git a/aten/src/THCS/THCSTensor.cu b/aten/src/THCS/THCSTensor.cu index da098e60e3..1493f18a7c 100644 --- a/aten/src/THCS/THCSTensor.cu +++ b/aten/src/THCS/THCSTensor.cu @@ -1,4 +1,5 @@ -#include "THCSTensor.h" +#include "THCSTensor.hpp" +#include "THCTensor.hpp" #include "THCApply.cuh" #include "THCTensorSort.cuh" #include "THCTensorMathPointwise.cuh" diff --git a/aten/src/THCS/THCSTensor.hpp b/aten/src/THCS/THCSTensor.hpp new file mode 100644 index 0000000000..1c6b536a4d --- /dev/null +++ b/aten/src/THCS/THCSTensor.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "THCSTensor.h" +#include "THCTensor.hpp" +#include "THTensor.hpp" + +#include "generic/THCSTensor.hpp" +#include "THCSGenerateAllTypes.h" diff --git a/aten/src/THCS/generic/THCSTensor.cpp b/aten/src/THCS/generic/THCSTensor.cpp index b8875ed800..c6e112cbca 100644 --- a/aten/src/THCS/generic/THCSTensor.cpp +++ b/aten/src/THCS/generic/THCSTensor.cpp @@ -74,7 +74,7 @@ static void THCSTensor_(rawInit)(THCState *state, THCSTensor *self) self->coalesced = 0; self->nnz = 0; // self->flag = TH_TENSOR_REFCOUNTED; - self->refcount = 1; + new (&self->refcount) std::atomic<int>(1); } THCSTensor* THCSTensor_(rawResize)(THCState *state, THCSTensor *self, int nDimI, int nDimV, int64_t *size) { @@ -381,18 +381,19 @@ void THCSTensor_(free)(THCState *state, THCSTensor *self) { if(!self) return; - if(THAtomicDecrementRef(&self->refcount)) + if(--self->refcount == 0) { THFree(self->size); THCIndexTensor_(free)(state, self->indices); THCTensor_(free)(state, self->values); + self->refcount.~atomic<int>(); THFree(self); } } void THCSTensor_(retain)(THCState *state, THCSTensor *self) { - THAtomicIncrementRef(&self->refcount); + self->refcount++; } int THCSTensor_(checkGPU)(THCState *state, unsigned int nSparseTensors, unsigned int nTensors, ...) diff --git a/aten/src/THCS/generic/THCSTensor.cu b/aten/src/THCS/generic/THCSTensor.cu index 368c7a7fad..a0f68ad426 100644 --- a/aten/src/THCS/generic/THCSTensor.cu +++ b/aten/src/THCS/generic/THCSTensor.cu @@ -3,6 +3,7 @@ #else #include "THCThrustAllocator.cuh" +#include "THCTensor.hpp" #include <thrust/device_ptr.h> #include <thrust/device_vector.h> #include <thrust/gather.h> diff --git a/aten/src/THCS/generic/THCSTensor.h b/aten/src/THCS/generic/THCSTensor.h index fa4871fd42..c9172a978d 100644 --- a/aten/src/THCS/generic/THCSTensor.h +++ b/aten/src/THCS/generic/THCSTensor.h @@ -2,22 +2,7 @@ #define THCS_GENERIC_FILE "generic/THCSTensor.h" #else -typedef struct THCSTensor -{ // Stored in COO format, indices + values - int64_t *size; - ptrdiff_t nnz; - int nDimensionI; // dimension of indices - int nDimensionV; // dimension of values - - // 2-D tensor of nDim x nnz of indices. May have nnz dim bigger than nnz - // as buffer, so we keep track of both - THCIndexTensor *indices; - THCTensor *values; - // Some math operations can only be performed on ordered sparse tensors - int coalesced; - int refcount; - -} THCSTensor; +typedef struct THCSTensor THCSTensor; /**** access methods ****/ TH_API int THCSTensor_(nDimension)(THCState *state, const THCSTensor *self); diff --git a/aten/src/THCS/generic/THCSTensor.hpp b/aten/src/THCS/generic/THCSTensor.hpp new file mode 100644 index 0000000000..a004b84ae4 --- /dev/null +++ b/aten/src/THCS/generic/THCSTensor.hpp @@ -0,0 +1,24 @@ +#ifndef THCS_GENERIC_FILE +#define THCS_GENERIC_FILE "generic/THCSTensor.hpp" +#else + +#include <atomic> + +typedef struct THCSTensor +{ // Stored in COO format, indices + values + int64_t *size; + ptrdiff_t nnz; + int nDimensionI; // dimension of indices + int nDimensionV; // dimension of values + + // 2-D tensor of nDim x nnz of indices. May have nnz dim bigger than nnz + // as buffer, so we keep track of both + THCIndexTensor *indices; + THCTensor *values; + // Some math operations can only be performed on ordered sparse tensors + int coalesced; + std::atomic<int> refcount; + +} THCSTensor; + +#endif diff --git a/aten/src/THCUNN/BatchNormalization.cu b/aten/src/THCUNN/BatchNormalization.cu index 865323a16a..03531b3e84 100644 --- a/aten/src/THCUNN/BatchNormalization.cu +++ b/aten/src/THCUNN/BatchNormalization.cu @@ -2,6 +2,7 @@ #include "common.h" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/Col2Im.cu b/aten/src/THCUNN/Col2Im.cu index 17cd18488e..d7fd995de4 100644 --- a/aten/src/THCUNN/Col2Im.cu +++ b/aten/src/THCUNN/Col2Im.cu @@ -1,6 +1,8 @@ #include "THCUNN.h" #include "common.h" #include "im2col.h" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" diff --git a/aten/src/THCUNN/Im2Col.cu b/aten/src/THCUNN/Im2Col.cu index 91fa6ed12b..95bdcd4e8b 100644 --- a/aten/src/THCUNN/Im2Col.cu +++ b/aten/src/THCUNN/Im2Col.cu @@ -4,6 +4,8 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/Im2Col.cu" #include "THCGenerateFloatTypes.h" diff --git a/aten/src/THCUNN/IndexLinear.cu b/aten/src/THCUNN/IndexLinear.cu index 2729f92772..2422af9730 100644 --- a/aten/src/THCUNN/IndexLinear.cu +++ b/aten/src/THCUNN/IndexLinear.cu @@ -2,6 +2,8 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #define divup(a, b) ((a) + (b) - 1) / (b) const int THREADS_PER_BLOCK = 256; diff --git a/aten/src/THCUNN/LogSoftMax.cu b/aten/src/THCUNN/LogSoftMax.cu index 23338d9463..8eaaa21c6f 100644 --- a/aten/src/THCUNN/LogSoftMax.cu +++ b/aten/src/THCUNN/LogSoftMax.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "THCHalf.h" +#include "THCTensor.hpp" #include "SoftMaxCommon.cuh" diff --git a/aten/src/THCUNN/LookupTableBag.cu b/aten/src/THCUNN/LookupTableBag.cu index bf3aa32bfc..c2ba9f5208 100644 --- a/aten/src/THCUNN/LookupTableBag.cu +++ b/aten/src/THCUNN/LookupTableBag.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "common.h" +#include "THCTensor.hpp" #include "THCThrustAllocator.cuh" #include <thrust/device_ptr.h> diff --git a/aten/src/THCUNN/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/MultiLabelMarginCriterion.cu index a8dc15e211..13b432c15c 100644 --- a/aten/src/THCUNN/MultiLabelMarginCriterion.cu +++ b/aten/src/THCUNN/MultiLabelMarginCriterion.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCReduceApplyUtils.cuh" #include "THCHalf.h" diff --git a/aten/src/THCUNN/MultiMarginCriterion.cu b/aten/src/THCUNN/MultiMarginCriterion.cu index 89e07ac6f7..c2fa213462 100644 --- a/aten/src/THCUNN/MultiMarginCriterion.cu +++ b/aten/src/THCUNN/MultiMarginCriterion.cu @@ -2,6 +2,8 @@ #include "common.h" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #define MULTIMARGIN_THREADS 128 diff --git a/aten/src/THCUNN/PReLU.cu b/aten/src/THCUNN/PReLU.cu index 395e4a1ec5..cdc6b2b71a 100644 --- a/aten/src/THCUNN/PReLU.cu +++ b/aten/src/THCUNN/PReLU.cu @@ -2,6 +2,7 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include <THC/THCApply.cuh> +#include "THCTensor.hpp" #include "common.h" diff --git a/aten/src/THCUNN/SoftMax.cu b/aten/src/THCUNN/SoftMax.cu index 1e58f7d38e..615ef29e68 100644 --- a/aten/src/THCUNN/SoftMax.cu +++ b/aten/src/THCUNN/SoftMax.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "THCHalf.h" +#include "THCTensor.hpp" #include "SoftMaxCommon.cuh" diff --git a/aten/src/THCUNN/SparseLinear.cu b/aten/src/THCUNN/SparseLinear.cu index 9110bbcaca..cd9b659085 100644 --- a/aten/src/THCUNN/SparseLinear.cu +++ b/aten/src/THCUNN/SparseLinear.cu @@ -1,6 +1,7 @@ #include "THCUNN.h" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" #include <cusparse.h> diff --git a/aten/src/THCUNN/SpatialAdaptiveAveragePooling.cu b/aten/src/THCUNN/SpatialAdaptiveAveragePooling.cu index fe3d6e2b18..2c671dad5a 100644 --- a/aten/src/THCUNN/SpatialAdaptiveAveragePooling.cu +++ b/aten/src/THCUNN/SpatialAdaptiveAveragePooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" diff --git a/aten/src/THCUNN/SpatialAdaptiveMaxPooling.cu b/aten/src/THCUNN/SpatialAdaptiveMaxPooling.cu index 01de3a8186..b49e86f8a7 100644 --- a/aten/src/THCUNN/SpatialAdaptiveMaxPooling.cu +++ b/aten/src/THCUNN/SpatialAdaptiveMaxPooling.cu @@ -2,6 +2,7 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" +#include "THCTensor.hpp" #define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit diff --git a/aten/src/THCUNN/SpatialAveragePooling.cu b/aten/src/THCUNN/SpatialAveragePooling.cu index 5f77e06567..ce9941a623 100644 --- a/aten/src/THCUNN/SpatialAveragePooling.cu +++ b/aten/src/THCUNN/SpatialAveragePooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "common.h" diff --git a/aten/src/THCUNN/SpatialConvolutionLocal.cu b/aten/src/THCUNN/SpatialConvolutionLocal.cu index e5b1f98748..17801d52b1 100644 --- a/aten/src/THCUNN/SpatialConvolutionLocal.cu +++ b/aten/src/THCUNN/SpatialConvolutionLocal.cu @@ -4,6 +4,8 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/SpatialConvolutionLocal.cu" #include "THCGenerateFloatTypes.h" diff --git a/aten/src/THCUNN/SpatialConvolutionMM.cu b/aten/src/THCUNN/SpatialConvolutionMM.cu index 2a88047a16..4a59acb297 100644 --- a/aten/src/THCUNN/SpatialConvolutionMM.cu +++ b/aten/src/THCUNN/SpatialConvolutionMM.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "im2col.h" diff --git a/aten/src/THCUNN/SpatialCrossMapLRN.cu b/aten/src/THCUNN/SpatialCrossMapLRN.cu index cd37320b18..cd6f081b13 100644 --- a/aten/src/THCUNN/SpatialCrossMapLRN.cu +++ b/aten/src/THCUNN/SpatialCrossMapLRN.cu @@ -1,6 +1,8 @@ #include "THCUNN.h" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "common.h" template <typename Dtype, typename Acctype> diff --git a/aten/src/THCUNN/SpatialDepthwiseConvolution.cu b/aten/src/THCUNN/SpatialDepthwiseConvolution.cu index 84c283d192..a8c9e71a0d 100644 --- a/aten/src/THCUNN/SpatialDepthwiseConvolution.cu +++ b/aten/src/THCUNN/SpatialDepthwiseConvolution.cu @@ -2,6 +2,7 @@ // port from Caffe #include "THCUNN.h" +#include "THCTensor.hpp" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" #include "THCNumerics.cuh" diff --git a/aten/src/THCUNN/SpatialDilatedConvolution.cu b/aten/src/THCUNN/SpatialDilatedConvolution.cu index a4a8e382cc..b8e96024fd 100644 --- a/aten/src/THCUNN/SpatialDilatedConvolution.cu +++ b/aten/src/THCUNN/SpatialDilatedConvolution.cu @@ -4,6 +4,8 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/SpatialDilatedConvolution.cu" #include "THCGenerateFloatTypes.h" diff --git a/aten/src/THCUNN/SpatialDilatedMaxPooling.cu b/aten/src/THCUNN/SpatialDilatedMaxPooling.cu index 167076c233..e97b4ba496 100644 --- a/aten/src/THCUNN/SpatialDilatedMaxPooling.cu +++ b/aten/src/THCUNN/SpatialDilatedMaxPooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "common.h" diff --git a/aten/src/THCUNN/SpatialFullDilatedConvolution.cu b/aten/src/THCUNN/SpatialFullDilatedConvolution.cu index 77d9811be2..61e1fe5910 100644 --- a/aten/src/THCUNN/SpatialFullDilatedConvolution.cu +++ b/aten/src/THCUNN/SpatialFullDilatedConvolution.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "im2col.h" +#include "THCTensor.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" diff --git a/aten/src/THCUNN/SpatialMaxUnpooling.cu b/aten/src/THCUNN/SpatialMaxUnpooling.cu index 8990907179..56488fdfee 100644 --- a/aten/src/THCUNN/SpatialMaxUnpooling.cu +++ b/aten/src/THCUNN/SpatialMaxUnpooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" template <typename Dtype> diff --git a/aten/src/THCUNN/SpatialReflectionPadding.cu b/aten/src/THCUNN/SpatialReflectionPadding.cu index 3fd27516cb..96472eed08 100644 --- a/aten/src/THCUNN/SpatialReflectionPadding.cu +++ b/aten/src/THCUNN/SpatialReflectionPadding.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/SpatialReplicationPadding.cu b/aten/src/THCUNN/SpatialReplicationPadding.cu index 3d2dfe2bdc..f63c2090d5 100644 --- a/aten/src/THCUNN/SpatialReplicationPadding.cu +++ b/aten/src/THCUNN/SpatialReplicationPadding.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/SpatialSubSampling.cu b/aten/src/THCUNN/SpatialSubSampling.cu index 914590758a..bb04846622 100644 --- a/aten/src/THCUNN/SpatialSubSampling.cu +++ b/aten/src/THCUNN/SpatialSubSampling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" diff --git a/aten/src/THCUNN/SpatialUpSamplingBilinear.cu b/aten/src/THCUNN/SpatialUpSamplingBilinear.cu index 11f37b4654..07daa0e9fe 100644 --- a/aten/src/THCUNN/SpatialUpSamplingBilinear.cu +++ b/aten/src/THCUNN/SpatialUpSamplingBilinear.cu @@ -1,6 +1,7 @@ // Adapted from interp.cpp from Caffe util by Pauline Luc // Originally developed by George Papandreou #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "linear_upsampling.h" #include "THCDeviceTensor.cuh" diff --git a/aten/src/THCUNN/SpatialUpSamplingNearest.cu b/aten/src/THCUNN/SpatialUpSamplingNearest.cu index 057f0b1b0f..0f39cc8322 100644 --- a/aten/src/THCUNN/SpatialUpSamplingNearest.cu +++ b/aten/src/THCUNN/SpatialUpSamplingNearest.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "common.h" +#include "THCTensor.hpp" #include <thrust/transform.h> #include <thrust/reduce.h> diff --git a/aten/src/THCUNN/TemporalConvolution.cu b/aten/src/THCUNN/TemporalConvolution.cu index f4e9c697cf..af12169d7a 100644 --- a/aten/src/THCUNN/TemporalConvolution.cu +++ b/aten/src/THCUNN/TemporalConvolution.cu @@ -2,6 +2,7 @@ #include "common.h" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" #include "generic/TemporalConvolution.cu" #include "THCGenerateFloatTypes.h" diff --git a/aten/src/THCUNN/TemporalMaxPooling.cu b/aten/src/THCUNN/TemporalMaxPooling.cu index 384e409ff7..2508f83517 100644 --- a/aten/src/THCUNN/TemporalMaxPooling.cu +++ b/aten/src/THCUNN/TemporalMaxPooling.cu @@ -3,6 +3,8 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #define TEMPORAL_MAX_POOLING_THREADS 1024 diff --git a/aten/src/THCUNN/TemporalReflectionPadding.cu b/aten/src/THCUNN/TemporalReflectionPadding.cu index ccfa002c26..4dd4da84c0 100644 --- a/aten/src/THCUNN/TemporalReflectionPadding.cu +++ b/aten/src/THCUNN/TemporalReflectionPadding.cu @@ -5,6 +5,8 @@ #include "THCDeviceUtils.cuh" #include "THCReduceApplyUtils.cuh" #include <THC/THCApply.cuh> +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" diff --git a/aten/src/THCUNN/TemporalReplicationPadding.cu b/aten/src/THCUNN/TemporalReplicationPadding.cu index 8eed759d2d..2c812bda8d 100644 --- a/aten/src/THCUNN/TemporalReplicationPadding.cu +++ b/aten/src/THCUNN/TemporalReplicationPadding.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/TemporalRowConvolution.cu b/aten/src/THCUNN/TemporalRowConvolution.cu index dc3b18c348..745fef8075 100644 --- a/aten/src/THCUNN/TemporalRowConvolution.cu +++ b/aten/src/THCUNN/TemporalRowConvolution.cu @@ -4,6 +4,8 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" +#include "THCTensor.hpp" +#include "THCStorage.hpp" #include "generic/TemporalRowConvolution.cu" diff --git a/aten/src/THCUNN/TemporalUpSamplingLinear.cu b/aten/src/THCUNN/TemporalUpSamplingLinear.cu index 98e4f28339..89b0c37b1f 100644 --- a/aten/src/THCUNN/TemporalUpSamplingLinear.cu +++ b/aten/src/THCUNN/TemporalUpSamplingLinear.cu @@ -1,6 +1,7 @@ // Adapted from interp.cpp from Caffe util by Pauline Luc // Originally developed by George Papandreou #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "linear_upsampling.h" #include "THCDeviceTensor.cuh" diff --git a/aten/src/THCUNN/TemporalUpSamplingNearest.cu b/aten/src/THCUNN/TemporalUpSamplingNearest.cu index f5dd5d9647..c69492cc0e 100644 --- a/aten/src/THCUNN/TemporalUpSamplingNearest.cu +++ b/aten/src/THCUNN/TemporalUpSamplingNearest.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "common.h" +#include "THCTensor.hpp" #include <thrust/transform.h> #include <thrust/reduce.h> diff --git a/aten/src/THCUNN/VolumetricAdaptiveAveragePooling.cu b/aten/src/THCUNN/VolumetricAdaptiveAveragePooling.cu index a909184f81..84e2c7f706 100644 --- a/aten/src/THCUNN/VolumetricAdaptiveAveragePooling.cu +++ b/aten/src/THCUNN/VolumetricAdaptiveAveragePooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" diff --git a/aten/src/THCUNN/VolumetricAdaptiveMaxPooling.cu b/aten/src/THCUNN/VolumetricAdaptiveMaxPooling.cu index 3ac9032dd0..0f0575c6ea 100644 --- a/aten/src/THCUNN/VolumetricAdaptiveMaxPooling.cu +++ b/aten/src/THCUNN/VolumetricAdaptiveMaxPooling.cu @@ -2,6 +2,7 @@ #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" #include "THCAtomics.cuh" +#include "THCTensor.hpp" #define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit diff --git a/aten/src/THCUNN/VolumetricAveragePooling.cu b/aten/src/THCUNN/VolumetricAveragePooling.cu index cdc3d04ed4..610127c177 100644 --- a/aten/src/THCUNN/VolumetricAveragePooling.cu +++ b/aten/src/THCUNN/VolumetricAveragePooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/VolumetricConvolution.cu b/aten/src/THCUNN/VolumetricConvolution.cu index 78f45f93f2..b45f7510b1 100644 --- a/aten/src/THCUNN/VolumetricConvolution.cu +++ b/aten/src/THCUNN/VolumetricConvolution.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCHalf.h" #include "THCHalfAutoNumerics.cuh" diff --git a/aten/src/THCUNN/VolumetricDilatedConvolution.cu b/aten/src/THCUNN/VolumetricDilatedConvolution.cu index d82e02dc2b..8a32c70b67 100644 --- a/aten/src/THCUNN/VolumetricDilatedConvolution.cu +++ b/aten/src/THCUNN/VolumetricDilatedConvolution.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "vol2col.h" #include "THCHalf.h" diff --git a/aten/src/THCUNN/VolumetricDilatedMaxPooling.cu b/aten/src/THCUNN/VolumetricDilatedMaxPooling.cu index 2b07349821..8ded1856ba 100644 --- a/aten/src/THCUNN/VolumetricDilatedMaxPooling.cu +++ b/aten/src/THCUNN/VolumetricDilatedMaxPooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu b/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu index 47173f2463..c5c7196bac 100644 --- a/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu +++ b/aten/src/THCUNN/VolumetricFullDilatedConvolution.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "vol2col.h" #include "THCHalf.h" diff --git a/aten/src/THCUNN/VolumetricMaxUnpooling.cu b/aten/src/THCUNN/VolumetricMaxUnpooling.cu index 83f2aeb92d..eac3b2d17a 100644 --- a/aten/src/THCUNN/VolumetricMaxUnpooling.cu +++ b/aten/src/THCUNN/VolumetricMaxUnpooling.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/VolumetricReplicationPadding.cu b/aten/src/THCUNN/VolumetricReplicationPadding.cu index bac505a3d7..27ea3ecad3 100644 --- a/aten/src/THCUNN/VolumetricReplicationPadding.cu +++ b/aten/src/THCUNN/VolumetricReplicationPadding.cu @@ -1,4 +1,5 @@ #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "THCDeviceTensor.cuh" #include "THCDeviceTensorUtils.cuh" diff --git a/aten/src/THCUNN/VolumetricUpSamplingNearest.cu b/aten/src/THCUNN/VolumetricUpSamplingNearest.cu index 9139d0e99a..4aea7edae9 100644 --- a/aten/src/THCUNN/VolumetricUpSamplingNearest.cu +++ b/aten/src/THCUNN/VolumetricUpSamplingNearest.cu @@ -1,5 +1,6 @@ #include "THCUNN.h" #include "common.h" +#include "THCTensor.hpp" #include <thrust/transform.h> #include <thrust/reduce.h> diff --git a/aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu b/aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu index 03d506ac56..0f353b91ac 100644 --- a/aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu +++ b/aten/src/THCUNN/VolumetricUpSamplingTrilinear.cu @@ -1,6 +1,7 @@ // Adapted from interp.cpp from Caffe util by Pauline Luc // Originally developed by George Papandreou #include "THCUNN.h" +#include "THCTensor.hpp" #include "common.h" #include "linear_upsampling.h" #include "THCDeviceTensor.cuh" diff --git a/aten/src/THNN/generic/ClassNLLCriterion.c b/aten/src/THNN/generic/ClassNLLCriterion.c index e55ef09cd3..eae712f53d 100644 --- a/aten/src/THNN/generic/ClassNLLCriterion.c +++ b/aten/src/THNN/generic/ClassNLLCriterion.c @@ -34,7 +34,7 @@ void THNN_(ClassNLLCriterion_updateOutput)( int batch_size = THTensor_(size)(input, 0); THTensor_(resize1d)(output, batch_size); - int invalid_target = -1; // We cannot throw an exception inside omp parallel + std::atomic<int> invalid_target(-1); // We cannot throw an exception inside omp parallel int i; #pragma omp parallel for private(i) for (i = 0; i < batch_size; i++) { @@ -48,12 +48,13 @@ void THNN_(ClassNLLCriterion_updateOutput)( real cur_weight = weights ? THTensor_fastGet1d(weights, cur_target) : 1.0f; THTensor_fastSet1d(output, i, -THTensor_fastGet2d(input, i, cur_target) * cur_weight); } else { - THAtomicCompareAndSwap(&invalid_target, -1, cur_target); + int tmp = -1; + invalid_target.compare_exchange_strong(tmp, cur_target); } } - if (invalid_target >= 0) { - THError("Target %d out of bounds", invalid_target); + if (invalid_target.load() >= 0) { + THError("Target %d out of bounds", invalid_target.load()); } return; diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp index ad276f9e5f..4cb146d475 100644 --- a/aten/src/THNN/init.cpp +++ b/aten/src/THNN/init.cpp @@ -1,6 +1,8 @@ #include "TH.h" #include "THNN.h" +#include "THTensor.hpp" + #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) #define nn_(NAME) TH_CONCAT_3(nn_, Real, NAME) diff --git a/aten/src/THS/THSTensor.cpp b/aten/src/THS/THSTensor.cpp index a862579e49..433b5e7f16 100644 --- a/aten/src/THS/THSTensor.cpp +++ b/aten/src/THS/THSTensor.cpp @@ -1,4 +1,6 @@ -#include "THSTensor.h" +#include "THSTensor.hpp" + +#include <new> #include "generic/THSTensor.cpp" #include "THSGenerateAllTypes.h" diff --git a/aten/src/THS/THSTensor.hpp b/aten/src/THS/THSTensor.hpp new file mode 100644 index 0000000000..69e6efaeb9 --- /dev/null +++ b/aten/src/THS/THSTensor.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "THSTensor.h" +#include "THTensor.hpp" + +#include <atomic> + +#include "generic/THSTensor.hpp" +#include "THSGenerateAllTypes.h" diff --git a/aten/src/THS/generic/THSTensor.cpp b/aten/src/THS/generic/THSTensor.cpp index 36a6f83c77..043ce013c6 100644 --- a/aten/src/THS/generic/THSTensor.cpp +++ b/aten/src/THS/generic/THSTensor.cpp @@ -65,7 +65,7 @@ THTensor *THSTensor_(newValues)(const THSTensor *self) { /*** Helper methods ***/ static void THSTensor_(rawInit)(THSTensor *self) { - self->refcount = 1; + new (&self->refcount) std::atomic<int>(1); self->size = NULL; self->indices = THLongTensor_new(); self->values = THTensor_(new)(); @@ -547,18 +547,19 @@ void THSTensor_(free)(THSTensor *self) { if(!self) return; - if(THAtomicDecrementRef(&self->refcount)) + if(--self->refcount == 0) { THFree(self->size); THLongTensor_free(self->indices); THTensor_(free)(self->values); + self->refcount.~atomic<int>(); THFree(self); } } void THSTensor_(retain)(THSTensor *self) { - THAtomicIncrementRef(&self->refcount); + self->refcount++; } #endif diff --git a/aten/src/THS/generic/THSTensor.h b/aten/src/THS/generic/THSTensor.h index bab3ba049c..4c05c801de 100644 --- a/aten/src/THS/generic/THSTensor.h +++ b/aten/src/THS/generic/THSTensor.h @@ -2,24 +2,8 @@ #define THS_GENERIC_FILE "generic/THSTensor.h" #else -typedef struct THSTensor -{ // Stored in COO format, indices + values - int64_t *size; - ptrdiff_t nnz; - int nDimensionI; // dimension of indices - int nDimensionV; // dimension of values - - // 2-D tensor of nDim x nnz of indices. May have nnz dim bigger than nnz - // as buffer, so we keep track of both - THLongTensor *indices; - THTensor *values; - // A sparse tensor is 'coalesced' if every index occurs at most once in - // the indices tensor, and the indices are in sorted order. - // Most math operations can only be performed on ordered sparse tensors - int coalesced; - int refcount; - -} THSTensor; +// Moved to THSTensor.hpp +typedef struct THSTensor THSTensor; /**** access methods ****/ TH_API int THSTensor_(nDimension)(const THSTensor *self); diff --git a/aten/src/THS/generic/THSTensor.hpp b/aten/src/THS/generic/THSTensor.hpp new file mode 100644 index 0000000000..2676a2e9f3 --- /dev/null +++ b/aten/src/THS/generic/THSTensor.hpp @@ -0,0 +1,24 @@ +#ifndef THS_GENERIC_FILE +#define THS_GENERIC_FILE "generic/THSTensor.hpp" +#else + +typedef struct THSTensor +{ // Stored in COO format, indices + values + int64_t *size; + ptrdiff_t nnz; + int nDimensionI; // dimension of indices + int nDimensionV; // dimension of values + + // 2-D tensor of nDim x nnz of indices. May have nnz dim bigger than nnz + // as buffer, so we keep track of both + THLongTensor *indices; + THTensor *values; + // A sparse tensor is 'coalesced' if every index occurs at most once in + // the indices tensor, and the indices are in sorted order. + // Most math operations can only be performed on ordered sparse tensors + int coalesced; + std::atomic<int> refcount; + +} THSTensor; + +#endif |