diff options
-rw-r--r-- | python/caffe/pycaffe.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 5275a07f..d965227d 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -77,7 +77,7 @@ def _Net_forward(self, blobs=None, **kwargs): out_blobs.extend([self.blobs[blob].data for blob in blobs]) out_blob_names = self.outputs + blobs for out, out_blob in zip(out_blob_names, out_blobs): - outs[out] = [out_blob[ix, :, :, :].squeeze() + outs[out] = [out_blob[ix, :, :, :] for ix in range(out_blob.shape[0])] return outs @@ -118,7 +118,7 @@ def _Net_backward(self, diffs=None, **kwargs): out_diffs.extend([self.blobs[diff].diff for diff in diffs]) out_diff_names = self.inputs + diffs for out, out_diff in zip(out_diff_names, out_diffs): - outs[out] = [out_diff[ix, :, :, :].squeeze() + outs[out] = [out_diff[ix, :, :, :] for ix in range(out_diff.shape[0])] return outs |