summaryrefslogtreecommitdiff
path: root/torch/_utils.py
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-09-30 16:37:07 -0400
committerGitHub <noreply@github.com>2016-09-30 16:37:07 -0400
commit11b38a68956d006b99a4cf78885de3070f60c41b (patch)
tree4ebdb9ad773714f12aeff57163c692b7111a1efc /torch/_utils.py
parenta1f5fe6a8f47ddb3d79c8492e248762883e80214 (diff)
downloadpytorch-11b38a68956d006b99a4cf78885de3070f60c41b.tar.gz
pytorch-11b38a68956d006b99a4cf78885de3070f60c41b.tar.bz2
pytorch-11b38a68956d006b99a4cf78885de3070f60c41b.zip
Add more functions to autograd
Diffstat (limited to 'torch/_utils.py')
-rw-r--r--torch/_utils.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/torch/_utils.py b/torch/_utils.py
index 0ad2ee2924..9adba52e80 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -35,3 +35,20 @@ def _import_dotted_name(name):
for component in components[1:]:
obj = getattr(obj, component)
return obj
+
+
+# Taken from python 3.5 docs
+def _accumulate(iterable):
+ 'Return running totals'
+ # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
+ # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
+ it = iter(iterable)
+ try:
+ total = next(it)
+ except StopIteration:
+ return
+ yield total
+ for element in it:
+ total += element
+ yield total
+