diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-09-30 16:37:07 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-09-30 16:37:07 -0400 |
commit | 11b38a68956d006b99a4cf78885de3070f60c41b (patch) | |
tree | 4ebdb9ad773714f12aeff57163c692b7111a1efc /torch/_utils.py | |
parent | a1f5fe6a8f47ddb3d79c8492e248762883e80214 (diff) | |
download | pytorch-11b38a68956d006b99a4cf78885de3070f60c41b.tar.gz pytorch-11b38a68956d006b99a4cf78885de3070f60c41b.tar.bz2 pytorch-11b38a68956d006b99a4cf78885de3070f60c41b.zip |
Add more functions to autograd
Diffstat (limited to 'torch/_utils.py')
-rw-r--r-- | torch/_utils.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/torch/_utils.py b/torch/_utils.py index 0ad2ee2924..9adba52e80 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -35,3 +35,20 @@ def _import_dotted_name(name): for component in components[1:]: obj = getattr(obj, component) return obj + + +# Taken from python 3.5 docs +def _accumulate(iterable): + 'Return running totals' + # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 + # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + yield total + for element in it: + total += element + yield total + |