summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-18 18:25:18 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-19 23:56:16 -0700
commitbf4d7262bd0694a39245d6e6ff23d9ea25f5a5df (patch)
tree6c01b74777871b299e10b37782342cc76ded870b
parent6b85fd006d87c3af538ce679eaaf0ba6b866765e (diff)
downloadcaffe-bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df.tar.gz
caffe-bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df.tar.bz2
caffe-bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df.zip
fix padding for the last batch
-rw-r--r--python/caffe/pycaffe.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 72ae5fbb..9caa21b8 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -325,14 +325,15 @@ def _Net_batch(self, blobs):
# Yield full batches.
for b in range(num_batches):
- for i in [b * batch_size]:
- yield {name: blobs[name][i:i + batch_size] for name in blobs}
+ i = b * batch_size
+ yield {name: blobs[name][i:i + batch_size] for name in blobs}
# Yield last padded batch, if any.
if remainder > 0:
padded_batch = {}
for name in blobs:
- padding = np.zeros((remainder,) + blobs[name].shape[1:])
+ padding = np.zeros((batch_size - remainder,)
+ + blobs[name].shape[1:])
padded_batch[name] = np.concatenate([blobs[name][-remainder:],
padding])
yield padded_batch