diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-18 18:25:18 -0700 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-05-19 23:56:16 -0700 |
commit | bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df (patch) | |
tree | 6c01b74777871b299e10b37782342cc76ded870b | |
parent | 6b85fd006d87c3af538ce679eaaf0ba6b866765e (diff) | |
download | caffe-bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df.tar.gz caffe-bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df.tar.bz2 caffe-bf4d7262bd0694a39245d6e6ff23d9ea25f5a5df.zip |
fix padding for the last batch
-rw-r--r-- | python/caffe/pycaffe.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 72ae5fbb..9caa21b8 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -325,14 +325,15 @@ def _Net_batch(self, blobs): # Yield full batches. for b in range(num_batches): - for i in [b * batch_size]: - yield {name: blobs[name][i:i + batch_size] for name in blobs} + i = b * batch_size + yield {name: blobs[name][i:i + batch_size] for name in blobs} # Yield last padded batch, if any. if remainder > 0: padded_batch = {} for name in blobs: - padding = np.zeros((remainder,) + blobs[name].shape[1:]) + padding = np.zeros((batch_size - remainder,) + + blobs[name].shape[1:]) padded_batch[name] = np.concatenate([blobs[name][-remainder:], padding]) yield padded_batch |