From 8af33e8ca2f56a3bef23935990b1c9ed65629918 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Wed, 14 May 2014 20:17:09 -0700 Subject: don't squeeze blob arrays for python Preserve the non-batch dimensions of blob arrays, even for singletons. The forward() and backward() helpers take lists of ndarrays instead of a single ndarray per blob, and lists of ndarrays are likewise returned. Note that for output the blob array could actually be returned as a single ndarray instead of a list. --- python/caffe/pycaffe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'python') diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 5275a07f..d965227d 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -77,7 +77,7 @@ def _Net_forward(self, blobs=None, **kwargs): out_blobs.extend([self.blobs[blob].data for blob in blobs]) out_blob_names = self.outputs + blobs for out, out_blob in zip(out_blob_names, out_blobs): - outs[out] = [out_blob[ix, :, :, :].squeeze() + outs[out] = [out_blob[ix, :, :, :] for ix in range(out_blob.shape[0])] return outs @@ -118,7 +118,7 @@ def _Net_backward(self, diffs=None, **kwargs): out_diffs.extend([self.blobs[diff].diff for diff in diffs]) out_diff_names = self.inputs + diffs for out, out_diff in zip(out_diff_names, out_diffs): - outs[out] = [out_diff[ix, :, :, :].squeeze() + outs[out] = [out_diff[ix, :, :, :] for ix in range(out_diff.shape[0])] return outs -- cgit v1.2.3