summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 20:17:09 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-14 20:17:16 -0700
commit8af33e8ca2f56a3bef23935990b1c9ed65629918 (patch)
tree8282a8958174b0a5f9867b3b015d02f43136b619 /python
parentaf0b857d5481cf86ff60c59e3eab71d1175ce8c6 (diff)
downloadcaffeonacl-8af33e8ca2f56a3bef23935990b1c9ed65629918.tar.gz
caffeonacl-8af33e8ca2f56a3bef23935990b1c9ed65629918.tar.bz2
caffeonacl-8af33e8ca2f56a3bef23935990b1c9ed65629918.zip
don't squeeze blob arrays for python
Preserve the non-batch dimensions of blob arrays, even for singletons. The forward() and backward() helpers take lists of ndarrays instead of a single ndarray per blob, and lists of ndarrays are likewise returned. Note that for output the blob array could actually be returned as a single ndarray instead of a list.
Diffstat (limited to 'python')
-rw-r--r--python/caffe/pycaffe.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 5275a07f..d965227d 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -77,7 +77,7 @@ def _Net_forward(self, blobs=None, **kwargs):
out_blobs.extend([self.blobs[blob].data for blob in blobs])
out_blob_names = self.outputs + blobs
for out, out_blob in zip(out_blob_names, out_blobs):
- outs[out] = [out_blob[ix, :, :, :].squeeze()
+ outs[out] = [out_blob[ix, :, :, :]
for ix in range(out_blob.shape[0])]
return outs
@@ -118,7 +118,7 @@ def _Net_backward(self, diffs=None, **kwargs):
out_diffs.extend([self.blobs[diff].diff for diff in diffs])
out_diff_names = self.inputs + diffs
for out, out_diff in zip(out_diff_names, out_diffs):
- outs[out] = [out_diff[ix, :, :, :].squeeze()
+ outs[out] = [out_diff[ix, :, :, :]
for ix in range(out_diff.shape[0])]
return outs