diff options
author | Richard Zou <zou3519@gmail.com> | 2018-08-29 12:14:46 -0700 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-08-29 12:25:51 -0700 |
commit | 4e446b85fb5e0b5db0951cc068e423d9caf5beef (patch) | |
tree | b0172f872d0ae63acbe670bc253666560909c4e8 /torch/autograd/profiler.py | |
parent | 396dec0e3740fad00461bc0ebcdfae09708693c6 (diff) | |
download | pytorch-4e446b85fb5e0b5db0951cc068e423d9caf5beef.tar.gz pytorch-4e446b85fb5e0b5db0951cc068e423d9caf5beef.tar.bz2 pytorch-4e446b85fb5e0b5db0951cc068e423d9caf5beef.zip |
Make profiler.build_table() O(n) rather than O(n^2) (#10969)
Summary:
Fixes #10851
Speeds up profiling results dramatically.
For the following script:
```
import torch
import time
ITER = 2000
x = torch.randn(1, 1, requires_grad=True)
with torch.autograd.profiler.profile() as prof:
    y = x
    for i in range(ITER):
        y = 3 * y - 2 * y
    y.backward()
start = time.time()
print("Done running. Preparing prof")
x = str(prof)
print("Done preparing prof results")
end = time.time()
print("Elapsed: {}".format(end - start))
```
I get 7s before / 0.13s after these changes.
cc apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10969
Differential Revision: D9556129
Pulled By: zou3519
fbshipit-source-id: 26b421686f8a42cdaace6382567d403e6385dc12
Diffstat (limited to 'torch/autograd/profiler.py')
-rw-r--r-- | torch/autograd/profiler.py | 8 |
1 files changed, 4 insertions, 4 deletions
```diff
diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py
index 75e309ac0f..c1be47ad49 100644
--- a/torch/autograd/profiler.py
+++ b/torch/autograd/profiler.py
@@ -554,11 +554,11 @@ def build_table(events, sort_by=None, header=None):
     header_sep = '-' * max_name_length + (' ' + '-' * col_width) * 5
 
     # Have to use a list because nonlocal is Py3 only...
-    result = ['']
+    result = []
 
     def append(s):
-        result[0] += s
-        result[0] += '\n'
+        result.append(s)
+        result.append('\n')  # Yes, newline after the end as well
 
     # Actual printing
     if header is not None:
@@ -572,4 +572,4 @@ def build_table(events, sort_by=None, header=None):
         append(row_format.format(evt.key, evt.cpu_time_str, evt.cuda_time_str,
                                  evt.count, evt.cpu_time_total_str, evt.cuda_time_total_str))
 
-    return result[0]
+    return ''.join(result)
```