| field | value | date |
|---|---|---|
| author | Zachary DeVito <zdevito@gmail.com> | 2017-11-16 13:58:09 -0800 |
| committer | GitHub <noreply@github.com> | 2017-11-16 13:58:09 -0800 |
| commit | cc7f09a3727c5095ccbadc16869522ada5e724e0 (patch) | |
| tree | 98c4fa5000c3aa5f1b7a9cfc79314ec7001182d5 /torch | |
| parent | 8f5c0f9678b2484dee1f8374d80f1058b8cab609 (diff) | |
Add cudaEvent support to the profiler (#3734)
* Add cudaEvent support to the profiler
This adds the ability to record CUDA timings in the profiler using
cudaEventRecord. Since it doesn't require nvprof, it is easier to run
than the nvprof path.
This also records a thread id for each event, which makes tracing
results easier to understand (see the usage sketch after this list).
* Add flow arrows from cpu to cuda event
* Fix no cuda build
* Review comments
* Move CUDA checks to one place
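
On the Python side, the change surfaces as a new `use_cuda` flag on the autograd profiler. A minimal usage sketch, assuming a CUDA build; the workload is a placeholder, and the exporter method (`export_chrome_trace`, the `EventList` JSON writer modified below) is assumed to be reachable from the profile object:

```python
import torch
from torch.autograd import profiler

x = torch.randn(128, 128, device='cuda')   # placeholder workload
w = torch.randn(128, 128, device='cuda')

# use_cuda=True records a cudaEvent pair for each range in addition to CPU times
with profiler.profile(use_cuda=True) as prof:
    for _ in range(10):
        y = x.mm(w)

# writes 'X' duration events plus CPU -> CUDA flow arrows for chrome://tracing
prof.export_chrome_trace("trace.json")
```
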
Diffstat (limited to 'torch')
| File | Lines changed |
|---|---|
| torch/autograd/profiler.py | 148 |
| torch/csrc/autograd/init.cpp | 56 |
| torch/csrc/autograd/profiler.cpp | 28 |
| torch/csrc/autograd/profiler.h | 127 |
| torch/csrc/cuda/cuda_check.h | 41 |
| torch/csrc/cudnn/Conv.cpp | 5 |
| torch/csrc/jit/fusion_compiler.cpp | 57 |

7 files changed, 294 insertions, 168 deletions
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 5afb1375a9..b75fa373a5 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -53,16 +53,52 @@ class EventList(list):
         import json
         with open(path, 'w') as f:
             chrome_events = []
+            next_id = 0
             for evt in self:
                 chrome_events.append(dict(
                     name=evt.name,
                     ph='X',
-                    ts=evt.start / 1000,
-                    dur=evt.cpu_time_total / 1000,
-                    tid='Autograd functions',
-                    pid='Autograd functions',
+                    ts=evt.cpu_interval.start,
+                    dur=evt.cpu_interval.elapsed_us(),
+                    tid=evt.thread,
+                    pid='CPU functions',
                     args={},
                 ))
+                for name, cuda_interval in evt.kernels:
+                    # 's' and 'f' draw Flow arrows from
+                    # the CPU launch to the GPU kernel
+                    chrome_events.append(dict(
+                        name=evt.name,
+                        ph='s',
+                        # +1 microsecond so the arrow is drawn inside cpu block
+                        ts=evt.cpu_interval.start + 1,
+                        tid=evt.thread,
+                        pid='CPU functions',
+                        id=next_id,
+                        cat='cpu_to_cuda',
+                        args={},
+                    ))
+                    chrome_events.append(dict(
+                        name=name,
+                        ph='f',
+                        ts=cuda_interval.start,
+                        tid=evt.thread,
+                        pid='CUDA functions',
+                        id=next_id,
+                        cat='cpu_to_cuda',
+                        args={},
+                    ))
+                    chrome_events.append(dict(
+                        name=evt.name,
+                        ph='X',
+                        ts=cuda_interval.start,
+                        dur=cuda_interval.elapsed_us(),
+                        tid=evt.thread,
+                        pid='CUDA functions',
+                        args={},
+                    ))
+                    next_id += 1
+
             json.dump(chrome_events, f)

     def key_averages(self):
@@ -97,6 +133,10 @@ class profile(object):
         enabled (bool, optional): Setting this to False makes this context manager a no-op.
             Default: ``True``.

+        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
+            Adds approximately 4us of overhead to each tensor operation.
+            Default: ``False``
+
     .. warning:
         This context managers should not be called recursively, i.e. at most one
         instance should be enabled at any given time.
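
The exporter above emits three Chrome Trace Event Format records per profiled op: an 'X' (complete) event for the CPU range, an 's'/'f' (flow start/finish) pair that draws an arrow from the CPU launch to the kernel, matched by `id` and `cat`, and an 'X' event for the CUDA range. A stand-alone sketch of that pairing with made-up names and timestamps; the output loads in chrome://tracing:

```python
import json

# hypothetical op with a 40us CPU range launching a 25us kernel
events = [
    dict(name="addmm", ph="X", ts=100, dur=40, tid=0, pid="CPU functions", args={}),
    # flow arrow starts just inside the CPU block ...
    dict(name="addmm", ph="s", ts=101, tid=0, pid="CPU functions",
         id=0, cat="cpu_to_cuda", args={}),
    # ... and finishes at the start of the CUDA block, matched by id/cat
    dict(name="addmm_kernel", ph="f", ts=150, tid=0, pid="CUDA functions",
         id=0, cat="cpu_to_cuda", args={}),
    dict(name="addmm_kernel", ph="X", ts=150, dur=25, tid=0, pid="CUDA functions", args={}),
]

with open("flow_example.json", "w") as f:
    json.dump(events, f)
```
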
@@ -121,8 +161,9 @@ class profile(object):
         N5torch8autograd5CloneE                  4.088us          0.000us
     """

-    def __init__(self, enabled=True):
+    def __init__(self, enabled=True, use_cuda=False):
         self.enabled = enabled
+        self.use_cuda = use_cuda
         self.function_events = None
         if not self.enabled:
             return
@@ -134,7 +175,9 @@
         if self.entered:
             raise RuntimeError("autograd profiler traces are not reentrant")
         self.entered = True
-        torch.autograd._enable_profiler(False)
+        profiler_kind = torch.autograd.ProfilerState.CUDA if self.use_cuda \
+            else torch.autograd.ProfilerState.CPU
+        torch.autograd._enable_profiler(profiler_kind)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -218,7 +261,7 @@ class emit_nvtx(object):
             raise RuntimeError("NVTX annotation context manager is not reentrant")
         self.entered = True
         torch.cuda.synchronize()
-        torch.autograd._enable_profiler(True)
+        torch.autograd._enable_profiler(torch.autograd.ProfilerState.NVTX)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -241,9 +284,9 @@ def load_nvprof(path):

 ################################################################################
 # FunctionEvent

-def format_time(time_ns):
+def format_time(time_us):
     """Defines how to format time in FunctionEvent"""
-    return '{:.3f}us'.format(time_ns / 1000)
+    return '{:.3f}us'.format(time_us)


 def attr_formatter(name):
@@ -269,32 +312,44 @@ class FormattedTimesMixin(object):
         return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count


+class Interval(object):
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+
+    def elapsed_us(self):
+        return self.end - self.start
+
+
 # TODO: record TID too
 class FunctionEvent(FormattedTimesMixin):
     """Profiling information about a single function."""
-    def __init__(self, id, name, start, end):
+    def __init__(self, id, name, thread, cpu_start, cpu_end):
         self.id = id
         self.name = name
-        self.start = start
-        self.end = end
+        self.cpu_interval = Interval(cpu_start, cpu_end)
+        self.thread = thread
         self.kernels = []
         self.count = 1

+    def append_kernel(self, name, start, end):
+        self.kernels.append((name, Interval(start, end)))
+
     @property
     def cuda_time_total(self):
-        return sum(kinfo[1] for kinfo in self.kernels)
+        return sum(kinfo[1].elasped_us() for kinfo in self.kernels)

     @property
     def cpu_time_total(self):
-        return self.end - self.start
+        return self.cpu_interval.elapsed_us()

     @property
     def key(self):
         return self.name

     def __repr__(self):
-        return '<FunctionEvent id={} cpu_time={} cuda_time={} name={}>'.format(
-            self.id, self.cpu_time_str, self.cuda_time_str, self.name)
+        return '<FunctionEvent id={} cpu_time={} cuda_time={} name={} thread={}>'.format(
+            self.id, self.cpu_time_str, self.cuda_time_str, self.name, self.thread)


 class FunctionEventAvg(FormattedTimesMixin):
@@ -339,37 +394,39 @@ class StringTable(defaultdict):
 ################################################################################
 # CPU checkpoints

-Record = namedtuple('Record', ['name', 'timestamp', 'kind'])
-
-
 def parse_cpu_trace(thread_records):
     next_id = 0
-    start_time = None
+    start_record = None
     functions = []
-    function_stack = []
+    record_stack = []
     string_table = StringTable()
-    for r in itertools.chain(*thread_records):
-        record = Record(*r)
-        if record.name == '__start_profile':
-            start_time = record.timestamp
-        if record.kind == 'mark':
+    # '__start_profile' is not guarenteed to be first, so we must find it here
+    for record in itertools.chain(*thread_records):
+        if record.name() == '__start_profile':
+            start_record = record
+            break
+    assert start_record is not None
+
+    for record in itertools.chain(*thread_records):
+        if record.kind() == 'mark':
             continue
-        elif record.kind == 'push':
-            function_stack.append(FunctionEvent(
-                id=next_id, name=string_table[record.name], start=record.timestamp, end=record.timestamp))
+        elif record.kind() == 'push':
+            record_stack.append((next_id, record))
             next_id += 1
-        elif record.kind == 'pop':
-            function_stack[-1].end = record.timestamp
-            functions.append(function_stack.pop())
-
-    # Normalize times
-    if start_time is None:
-        raise RuntimeError('Malformed profile: no start marker')
-    for event in functions:
-        event.start -= start_time
-        event.end -= start_time
-
-    functions.sort(key=lambda evt: evt.start)
+        elif record.kind() == 'pop':
+            function_id, start = record_stack.pop()
+            fe = FunctionEvent(
+                id=function_id,
+                name=string_table[start.name()],
+                thread=start.thread_id(),
+                cpu_start=start_record.cpu_elapsed_us(start),
+                cpu_end=start_record.cpu_elapsed_us(record))
+            if start_record.has_cuda():
+                fe.append_kernel(start.name(),
+                                 start_record.cuda_elapsed_us(start),
+                                 start_record.cuda_elapsed_us(record))
+            functions.append(fe)
+    functions.sort(key=lambda evt: evt.cpu_interval.start)
     return functions
@@ -414,8 +471,9 @@ def parse_nvprof_trace(path):
         unique.see(row['marker_id'])
         evt = FunctionEvent(id=row['marker_id'],
                             name=strings[row['name']],
-                            start=row['start_time'],
-                            end=row['end_time'])
+                            cpu_start=row['start_time'],
+                            cpu_end=row['end_time'],
+                            thread=0)  # TODO: find in sqlite database
         functions.append(evt)
         functions_map[evt.id] = evt
@@ -439,7 +497,9 @@ def parse_nvprof_trace(path):
         unique.see(row['marker_id'], row['runtime_id'])
         assert row['cbid'] == 13  # 13 == Launch
         evt = functions_map[row['marker_id']]
-        evt.kernels.append((row['kernel_name'], row['kernel_end'] - row['kernel_start']))
+        evt.append_kernel(row['kernel_name'],
+                          row['kernel_start'],
+                          row['kernel_end'])

     functions.sort(key=lambda evt: evt.start)
     return functions
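
The reworked `parse_cpu_trace` pairs each 'push' record with its matching 'pop' by keeping a stack of open ranges, and dates everything relative to the `__start_profile` mark. A simplified, stand-alone sketch of the pairing logic, using plain `(kind, name, t)` tuples instead of the real `ProfilerEvent` objects:

```python
def pair_ranges(records):
    """records: iterable of (kind, name, t), kind in {'push', 'pop', 'mark'}."""
    stack, spans = [], []
    for kind, name, t in records:
        if kind == 'push':
            stack.append((name, t))
        elif kind == 'pop':
            name, start = stack.pop()   # LIFO: a pop closes the most recent push
            spans.append((name, start, t))
    return sorted(spans, key=lambda s: s[1])

# nested ranges: 'mul' sits inside 'add'
print(pair_ranges([('push', 'add', 0), ('push', 'mul', 2),
                   ('pop', 'mul', 5), ('pop', 'add', 9)]))
# -> [('add', 0, 9), ('mul', 2, 5)]
```
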
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 32954d5f83..868852fa1b 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -10,44 +10,6 @@
 #define ENSURE_UNREACHABLE __builtin_unreachable();
 #endif

-namespace pybind11 { namespace detail {
-
-template <> struct type_caster<torch::autograd::profiler::EventKind> {
-public:
-  PYBIND11_TYPE_CASTER(torch::autograd::profiler::EventKind, _("torch::autograd::profiler::EventKind"));
-
-  bool load(handle src, bool) {
-    try {
-      auto str = py::cast<std::string>(src);
-      if (str == "push") {
-        value = torch::autograd::profiler::EventKind::PushRange;
-      } else if (str == "pop") {
-        value = torch::autograd::profiler::EventKind::PopRange;
-      } else if (str == "mark") {
-        value = torch::autograd::profiler::EventKind::Mark;
-      } else {
-        return false;
-      }
-    } catch (std::exception& e) {
-      return false;
-    }
-    return true;
-  }
-  static handle cast(torch::autograd::profiler::EventKind src, return_value_policy /* policy */, handle /* parent */) {
-    switch (src) {
-      case torch::autograd::profiler::EventKind::PushRange:
-        return py::cast("push").release();
-      case torch::autograd::profiler::EventKind::PopRange:
-        return py::cast("pop").release();
-      case torch::autograd::profiler::EventKind::Mark:
-        return py::cast("mark").release();
-    }
-    ENSURE_UNREACHABLE
-  }
-};
-
-}} // namespace pybind11::detail
-
 PyObject * THPAutograd_initExtension(PyObject *_unused)
 {
   THPUtils_assert_PyImport("torch.autograd", autograd_module);
@@ -68,17 +30,31 @@ PyObject * THPAutograd_initExtension(PyObject *_unused)
       "StochasticFunction class in torch.autograd module");

   auto m = py::handle(autograd_module).cast<py::module>();
+
+  py::class_<torch::autograd::profiler::Event>(m,"ProfilerEvent")
+    .def("kind",&torch::autograd::profiler::Event::kind)
+    .def("name",&torch::autograd::profiler::Event::name)
+    .def("thread_id",&torch::autograd::profiler::Event::thread_id)
+    .def("cpu_elapsed_us",&torch::autograd::profiler::Event::cpu_elapsed_us)
+    .def("cuda_elapsed_us",&torch::autograd::profiler::Event::cuda_elapsed_us)
+    .def("has_cuda",&torch::autograd::profiler::Event::has_cuda);
+  py::enum_<torch::autograd::profiler::ProfilerState>(m,"ProfilerState")
+    .value("Disabled", torch::autograd::profiler::ProfilerState::Disabled)
+    .value("CPU", torch::autograd::profiler::ProfilerState::CPU)
+    .value("CUDA", torch::autograd::profiler::ProfilerState::CUDA)
+    .value("NVTX", torch::autograd::profiler::ProfilerState::NVTX);
+
   m.def("_enable_profiler", torch::autograd::profiler::enableProfiler);
   m.def("_disable_profiler", torch::autograd::profiler::disableProfiler);
   m.def("_push_range", [](const char *name) {
     using namespace torch::autograd::profiler;
-    if (!profiling) return;
+    if (state == ProfilerState::Disabled) return;
     pushRange(name);
   });
   m.def("_pop_range", []() {
     using namespace torch::autograd::profiler;
-    if (!profiling) return;
+    if (state == ProfilerState::Disabled) return;
     popRange();
   });
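
These bindings are what the Python context manager drives: `ProfilerState` is exposed as an enum, and `_enable_profiler` / `_disable_profiler` pass the C++ types straight through. A rough sketch of the private call sequence; this is an internal API, so the exact return shape may differ:

```python
import torch

torch.autograd._enable_profiler(torch.autograd.ProfilerState.CUDA)
y = torch.randn(8, 8, device='cuda').mm(torch.randn(8, 8, device='cuda'))  # work to record
records = torch.autograd._disable_profiler()   # per-thread lists of ProfilerEvent

for thread_records in records:
    for evt in thread_records:
        print(evt.kind(), evt.name(), evt.thread_id())
```
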
diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp
index 4b6f542991..ce6612c452 100644
--- a/torch/csrc/autograd/profiler.cpp
+++ b/torch/csrc/autograd/profiler.cpp
@@ -3,38 +3,38 @@
 namespace torch { namespace autograd { namespace profiler {

-bool profiling = false;
-bool using_cuda;
+ProfilerState state = ProfilerState::Disabled;
+uint32_t next_thread_id = 0;
 std::mutex all_event_lists_mutex;
 std::list<std::shared_ptr<RangeEventList>> all_event_lists;
 thread_local std::shared_ptr<RangeEventList> event_list;
+thread_local int32_t thread_id;

 void RecordFunction::pushFunctionRange(Function* fn) {
   pushRange(fn->name());
 }

-void enableProfiler(bool use_cuda) {
+void enableProfiler(ProfilerState new_state) {
+  TORCH_ASSERT(new_state != ProfilerState::Disabled);
 #ifndef WITH_CUDA
-  if (use_cuda)
-    throw std::runtime_error("Can't use CUDA profiler - PyTorch was compiled without CUDA");
+  if (new_state == ProfilerState::NVTX)
+    throw std::runtime_error("Can't use NVTX profiler - PyTorch was compiled without CUDA");
 #endif
-  if (profiling) {
-    if (use_cuda != using_cuda)
-      throw std::runtime_error("can't change use_cuda flag while profiler is running");
-    return;
+  if (state != ProfilerState::Disabled && new_state != state) {
+    throw std::runtime_error("can't change kind of profiling (e.g. NVTX to CPU) while profiler is running");
   }
-  profiling = true;
-  using_cuda = use_cuda;
+  state = new_state;
   mark("__start_profile");
 }

 thread_event_lists disableProfiler() {
-  if (!profiling) {
+  if (state == ProfilerState::Disabled) {
     throw std::runtime_error("can't disable profiler when it's not running");
   }
+  ProfilerState old_state = state;
   mark("__stop_profile");
-  profiling = false;
-  if (using_cuda) {
+  state = ProfilerState::Disabled;
+  if (old_state == ProfilerState::NVTX) {
     return thread_event_lists();
   } else {
     thread_event_lists result;
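
The state checks above make the profiler a global, non-reentrant toggle: the kind of profiling cannot change while a trace is being collected. A small sketch of what that means in practice; the nested block is expected to raise rather than silently switch modes:

```python
import torch
from torch.autograd import profiler

with profiler.profile():                          # CPU-only trace
    try:
        with profiler.profile(use_cuda=True):     # try to switch kinds mid-run
            pass
    except RuntimeError as e:
        print("refused:", e)                      # "can't change kind of profiling ..."
```
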
diff --git a/torch/csrc/autograd/profiler.h b/torch/csrc/autograd/profiler.h
index 2f2be14537..0b43b231f9 100644
--- a/torch/csrc/autograd/profiler.h
+++ b/torch/csrc/autograd/profiler.h
@@ -11,8 +11,14 @@
 #include <cstdint>
 #include <string>
 #include <list>
+#include <sstream>
 #include <forward_list>
 #include <tuple>
+#include "ATen/ATen.h"
+#include "torch/csrc/cuda/cuda_check.h"
+#ifdef WITH_CUDA
+#include <cuda_runtime.h>
+#endif

 namespace torch { namespace autograd {

@@ -24,16 +30,83 @@ constexpr inline std::size_t ceilToMultiple(std::size_t a, std::size_t b) {
   return ((a + b - 1) / b) * b;
 }

+inline uint64_t getTime() {
+  using namespace std::chrono;
+  using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
+  return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
+}
+
 enum class EventKind {
   Mark,
   PushRange,
   PopRange
 };

-// NOTE: we don't need a flag saying if an event is a kernel, because it's
-// used only for the CPU-side perf recording.
-using Event = std::tuple<std::string, uint64_t, EventKind>; // (name, time, kind)
+struct Event {
+  Event(EventKind kind, std::string name, uint32_t thread_id, bool record_cuda)
+  : kind_(kind)
+  , name_(std::move(name))
+  , thread_id_(thread_id)
+  , cpu_ns_(getTime()) {
+#ifdef WITH_CUDA
+    if(record_cuda) {
+      TORCH_CUDA_CHECK(cudaEventCreate(&event));
+      TORCH_CUDA_CHECK(cudaEventRecord(event, at::globalContext().getCurrentCUDAStream()));
+    }
+#endif
+  }
+  std::string kind() const {
+    switch(kind_) {
+      case EventKind::Mark: return "mark";
+      case EventKind::PushRange: return "push";
+      case EventKind::PopRange: return "pop";
+    }
+    throw std::runtime_error("unknown EventKind");
+  }
+  const std::string & name() const {
+    return name_;
+  }
+  uint32_t thread_id() const {
+    return thread_id_;
+  }
+  double cpu_elapsed_us(const Event & e) {
+    return (e.cpu_ns_ - cpu_ns_)/(1000.0);
+  }
+  double cuda_elapsed_us(const Event & e) {
+#ifdef WITH_CUDA
+    if(!e.has_cuda() || !has_cuda()) {
+      throw std::logic_error("Events were not recorded for CUDA");
+    }
+    TORCH_CUDA_CHECK(cudaEventSynchronize(event));
+    TORCH_CUDA_CHECK(cudaEventSynchronize(e.event));
+    float ms;
+    TORCH_CUDA_CHECK(cudaEventElapsedTime(&ms, event, e.event));
+    return ms*1000.0;
+#else
+    throw std::logic_error("CUDA not enabled");
+#endif
+  }
+  bool has_cuda() const {
+#ifdef WITH_CUDA
+    return event != nullptr;
+#else
+    return false;
+#endif
+  }
+
+private:
+  EventKind kind_;
+  std::string name_;
+  uint32_t thread_id_;
+  uint64_t cpu_ns_;
+#ifdef WITH_CUDA
+  cudaEvent_t event = nullptr;
+#endif
+};
+
+// a linked-list of fixed sized vectors, to avoid
+// a std::vector resize from taking a large amount of time inside
+// a profiling event
 struct RangeEventList {
   constexpr static std::size_t MB = 1024 * 1024;
   constexpr static std::size_t event_block_size = 16 * MB;
@@ -70,81 +143,85 @@ struct RangeEventList {
   std::forward_list<block_type> blocks;
 };

-extern bool profiling;
-extern bool using_cuda;
+enum class ProfilerState {
+    Disabled,
+    CPU, // CPU-only profiling
+    CUDA, // CPU + CUDA events
+    NVTX,  // only emit NVTX markers
+};
+
+extern ProfilerState state;
+extern uint32_t next_thread_id;
 extern std::mutex all_event_lists_mutex;
 extern std::list<std::shared_ptr<RangeEventList>> all_event_lists;
+
 extern thread_local std::shared_ptr<RangeEventList> event_list;
+extern thread_local int32_t thread_id;

 inline RangeEventList& getEventList() {
   if (!event_list) {
     std::lock_guard<std::mutex> guard(all_event_lists_mutex);
     event_list = std::make_shared<RangeEventList>();
+    thread_id = next_thread_id++;
     all_event_lists.emplace_front(event_list);
   }
   return *event_list;
 }

-inline uint64_t getTime() {
-  using namespace std::chrono;
-  using clock = std::conditional<high_resolution_clock::is_steady, high_resolution_clock, steady_clock>::type;
-  return duration_cast<nanoseconds>(clock::now().time_since_epoch()).count();
-}
-
 inline void mark(std::string name) {
-  if (using_cuda) {
+  if (state == ProfilerState::NVTX) {
 #ifdef WITH_CUDA
     nvtxMarkA(name.c_str());
 #else
-    throw std::logic_error("mark called with use_cuda=True, but compiled without CUDA");
+    throw std::logic_error("mark called with NVTX tracing, but compiled without CUDA");
 #endif
   } else {
-    getEventList().record(std::move(name), getTime(), EventKind::Mark);
+    getEventList().record(EventKind::Mark, std::move(name), thread_id, state == ProfilerState::CUDA);
   }
 }

 inline void pushRange(std::string name) {
-  if (using_cuda) {
+  if (state == ProfilerState::NVTX) {
 #ifdef WITH_CUDA
     nvtxRangePushA(name.c_str());
 #else
-    throw std::logic_error("pushRange called with use_cuda=True, but compiled without CUDA");
+    throw std::logic_error("pushRange called with NVTX tracing, but compiled without CUDA");
 #endif
   } else {
-    getEventList().record(std::move(name), getTime(), EventKind::PushRange);
+    getEventList().record(EventKind::PushRange, std::move(name), thread_id, state == ProfilerState::CUDA);
   }
 }

 inline void popRange() {
-  if (using_cuda) {
+  if (state == ProfilerState::NVTX) {
 #ifdef WITH_CUDA
     nvtxRangePop();
 #else
-    throw std::logic_error("popRange called with use_cuda=True, but compiled without CUDA");
+    throw std::logic_error("popRange called with NVTX tracing, but compiled without CUDA");
 #endif
   } else {
-    getEventList().record(std::string(), getTime(), EventKind::PopRange);
+    getEventList().record(EventKind::PopRange, std::string(), thread_id, state == ProfilerState::CUDA);
   }
 }

 struct RecordFunction {
   explicit RecordFunction(Function *fn) {
-    if (!profiling) return;
+    if (state == ProfilerState::Disabled) return;
     pushFunctionRange(fn);
   }

   explicit RecordFunction(std::string name) {
-    if (!profiling) return;
+    if (state == ProfilerState::Disabled) return;
     pushRange(std::move(name));
   }

   explicit RecordFunction(const char *name) {
-    if (!profiling) return;
+    if (state == ProfilerState::Disabled) return;
     pushRange(name);
   }

   ~RecordFunction() {
-    if (!profiling) return;
+    if (state == ProfilerState::Disabled) return;
     popRange();
   }

@@ -155,7 +232,7 @@ struct RecordFunction {
 using thread_event_lists = std::vector<std::vector<Event>>;
 // NOTE: changing profiler modes is **NOT THREAD SAFE**. You should ensure that
 // there no autograd functions are being executed when these function are used.
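
The new `Event` records a `cudaEvent_t` with `cudaEventRecord` and later derives GPU time from `cudaEventElapsedTime` between two events. The same pattern is available from Python through `torch.cuda.Event`, which is a handy way to sanity-check the numbers the profiler reports; sizes here are arbitrary:

```python
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

a = torch.randn(1024, 1024, device='cuda')
start.record()                 # analogous to cudaEventRecord at range entry
b = a.mm(a)
end.record()                   # ... and at range exit

torch.cuda.synchronize()       # both events must have completed before elapsed_time
print('elapsed: {:.3f} us'.format(start.elapsed_time(end) * 1000.0))  # ms -> us
```
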
-void enableProfiler(bool use_cuda);
+void enableProfiler(ProfilerState state);
 thread_event_lists disableProfiler();

 } // namespace profiler
diff --git a/torch/csrc/cuda/cuda_check.h b/torch/csrc/cuda/cuda_check.h
new file mode 100644
index 0000000000..816b1883dc
--- /dev/null
+++ b/torch/csrc/cuda/cuda_check.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#ifdef WITH_CUDA
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <nvrtc.h>
+
+namespace torch {
+// We're using three CUDA APIs, so define a few helpers for error handling
+static inline void nvrtcCheck(nvrtcResult result,const char * file, int line) {
+  if(result != NVRTC_SUCCESS) {
+    std::stringstream ss;
+    ss << file << ":" << line << ": " << nvrtcGetErrorString(result);
+    throw std::runtime_error(ss.str());
+  }
+}
+#define TORCH_NVRTC_CHECK(result) ::torch::nvrtcCheck(result,__FILE__,__LINE__);
+
+static inline void cuCheck(CUresult result, const char * file, int line) {
+  if(result != CUDA_SUCCESS) {
+    const char * str;
+    cuGetErrorString(result, &str);
+    std::stringstream ss;
+    ss << file << ":" << line << ": " << str;
+    throw std::runtime_error(ss.str());
+  }
+}
+#define TORCH_CU_CHECK(result) ::torch::cuCheck(result,__FILE__,__LINE__);
+
+static inline void cudaCheck(cudaError_t result, const char * file, int line) {
+  if(result != cudaSuccess) {
+    std::stringstream ss;
+    ss << file << ":" << line << ": " << cudaGetErrorString(result);
+    throw std::runtime_error(ss.str());
+  }
+}
+#define TORCH_CUDA_CHECK(result) ::torch::cudaCheck(result,__FILE__,__LINE__);
+
+}
+
+#endif
diff --git a/torch/csrc/cudnn/Conv.cpp b/torch/csrc/cudnn/Conv.cpp
index 3979a8b2f8..542dfb7992 100644
--- a/torch/csrc/cudnn/Conv.cpp
+++ b/torch/csrc/cudnn/Conv.cpp
@@ -1,5 +1,6 @@
 #include "Conv.h"

+#include "torch/csrc/cuda/cuda_check.h"
 #include "THC/THC.h"
 #include "Exceptions.h"
 #include "Types.h"
@@ -98,7 +99,7 @@ BenchmarkCache<cudnnConvolutionBwdFilterAlgo_t> bwd_filter_algos;
 struct Workspace {
   Workspace(THCState* state, size_t size) : state(state), size(size), data(NULL) {
-    CUDA_CHECK(THCudaMalloc(state, &data, size));
+    TORCH_CUDA_CHECK(THCudaMalloc(state, &data, size));
   }
   Workspace(const Workspace&) = delete;
   Workspace(Workspace&&) = default;
@@ -453,7 +454,7 @@ void findAlgorithm(
   cache.insert(conv.params, *algo);

   THCDeviceAllocator* allocator = THCCachingAllocator_get();
-  CUDA_CHECK(allocator->emptyCache(allocator->state));
+  TORCH_CUDA_CHECK(allocator->emptyCache(allocator->state));
 }

 template<typename algo_t>
diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp
index cbdebf6ff5..af6d9d7547 100644
--- a/torch/csrc/jit/fusion_compiler.cpp
+++ b/torch/csrc/jit/fusion_compiler.cpp
@@ -3,6 +3,7 @@
 #include "torch/csrc/jit/code_template.h"
 #include "torch/csrc/jit/resource_guard.h"
 #include "torch/csrc/utils/disallow_copy.h"
+#include "torch/csrc/cuda/cuda_check.h"
 #include "ATen/ATen.h"
 #include <nvrtc.h>
 #include <cuda.h>
@@ -104,36 +105,6 @@ std::ostream& operator<<(std::ostream & out, const TensorDesc & d) {
   return out;
 }

-// We're using three CUDA APIs, so define a few helpers for error handling
-static void nvrtcCheck(nvrtcResult result,const char * file, int line) {
-  if(result != NVRTC_SUCCESS) {
-    std::stringstream ss;
-    ss << file << ":" << line << ": " << nvrtcGetErrorString(result);
-    throw std::runtime_error(ss.str());
-  }
-}
-#define JIT_NVRTC_CHECK(result) nvrtcCheck(result,__FILE__,__LINE__);
-
-static void cuCheck(CUresult result, const char * file, int line) {
-  if(result != CUDA_SUCCESS) {
-    const char * str;
-    cuGetErrorString(result, &str);
-    std::stringstream ss;
-    ss << file << ":" << line << ": " << str;
-    throw std::runtime_error(ss.str());
-  }
-}
-#define JIT_CU_CHECK(result) cuCheck(result,__FILE__,__LINE__);
-
-static void cudaCheck(cudaError_t result, const char * file, int line) {
-  if(result != cudaSuccess) {
-    std::stringstream ss;
-    ss << file << ":" << line << ": " << cudaGetErrorString(result);
-    throw std::runtime_error(ss.str());
-  }
-}
-#define JIT_CUDA_CHECK(result) cudaCheck(result,__FILE__,__LINE__);
-
 ////////////////////////////////////////////////////////////////////////////////
 // Code generation

@@ -328,14 +299,14 @@ CompiledFusionFunction::CompiledFusionFunction(const std::string & name, Annotat
   : name(name)
   , input_desc(agraph.input_desc)
   , output_desc(agraph.output_desc) {
-  JIT_CUDA_CHECK(cudaGetDevice(&device));
-  JIT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
+  TORCH_CUDA_CHECK(cudaGetDevice(&device));
+  TORCH_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));

   std::stringstream cu;
   concat_desc = codegen::emitCompilationUnit(cu, name, agraph);
   compilation_unit = cu.str();
   nvrtcProgram program;
-  JIT_NVRTC_CHECK(nvrtcCreateProgram(&program, compilation_unit.c_str(), NULL, 0, nullptr, nullptr));
+  TORCH_NVRTC_CHECK(nvrtcCreateProgram(&program, compilation_unit.c_str(), NULL, 0, nullptr, nullptr));

   std::string compute = "--gpu-architecture=compute_" + std::to_string(prop.major) + std::to_string(prop.minor);
   std::vector<const char *> args = {"--std=c++11", compute.c_str()};
@@ -349,25 +320,25 @@ CompiledFusionFunction::CompiledFusionFunction(const std::string & name, Annotat
     throw std::runtime_error(cu.str());
   }
   ResourceGuard holdProgram([&] {
-    JIT_NVRTC_CHECK(nvrtcDestroyProgram(&program));
+    TORCH_NVRTC_CHECK(nvrtcDestroyProgram(&program));
   });
-  JIT_NVRTC_CHECK(result);
+  TORCH_NVRTC_CHECK(result);

   size_t ptx_size;
-  JIT_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
+  TORCH_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
   ptx.resize(ptx_size);
-  JIT_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data()));
+  TORCH_NVRTC_CHECK(nvrtcGetPTX(program, ptx.data()));

-  JIT_CU_CHECK(cuModuleLoadData(&module, ptx.data()));
-  JIT_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str()));
+  TORCH_CU_CHECK(cuModuleLoadData(&module, ptx.data()));
+  TORCH_CU_CHECK(cuModuleGetFunction(&function, module, name.c_str()));

-  JIT_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
+  TORCH_CU_CHECK(cuOccupancyMaxActiveBlocksPerMultiprocessor(
     &maxBlocks, function, 128, 0));
   maxBlocks *= prop.multiProcessorCount;
 }

 CompiledFusionFunction::~CompiledFusionFunction() {
-  JIT_CU_CHECK(cuModuleUnload(module));
+  TORCH_CU_CHECK(cuModuleUnload(module));
 }

 namespace {
@@ -500,7 +471,7 @@ void CompiledFusionFunction::launch(uint32_t numel, void ** arguments) {
   // cudaFree(0) accomplishes this.
   cudaFree(0);

-  JIT_CU_CHECK(cuLaunchKernel(
+  TORCH_CU_CHECK(cuLaunchKernel(
     function,
     numBlocks, 1, 1,
     blockSize, 1, 1,
@@ -513,7 +484,7 @@ std::shared_ptr<CompiledFusionFunction> FusionCompiler::getOrCompile(AnnotatedGr
   std::stringstream key;
   key << *agraph.graph << "\n";
   int device;
-  JIT_CUDA_CHECK(cudaGetDevice(&device));
+  TORCH_CUDA_CHECK(cudaGetDevice(&device));
   key << "Device " << device << "\n";
   for(auto & i : agraph.input_desc)
     key << i << "\n";
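
Finally, a sketch of consuming the collected events once the context manager exits. Attribute names (`function_events`, `cpu_time_str`, `thread`, `kernels`) follow the Python changes in this commit; the workload is a placeholder:

```python
import torch
from torch.autograd import profiler

x = torch.randn(64, 64, device='cuda')
with profiler.profile(use_cuda=True) as prof:
    for _ in range(5):
        x = x.mm(x)

# per-call events, already sorted by CPU start time in parse_cpu_trace
for evt in prof.function_events[:5]:
    print(evt.name, evt.cpu_time_str, 'thread', evt.thread)

# kernels recorded against an event are (name, Interval) pairs
name, interval = prof.function_events[0].kernels[0]
print(name, '{:.3f} us on the GPU'.format(interval.elapsed_us()))
```
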