summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorsh1r0 <sh1r0@users.noreply.github.com>2015-10-09 00:31:05 +0800
committersh1r0 <sh1r0@users.noreply.github.com>2015-10-09 00:37:52 +0800
commitc65ba61bdf273604c3edcd24ba7a80cc3835441a (patch)
tree4b3b31758634129400407fc92a64e282434766cb /python
parent04c7c368095a9b3244ac9cf8afde9272482a9b32 (diff)
downloadcaffeonacl-c65ba61bdf273604c3edcd24ba7a80cc3835441a.tar.gz
caffeonacl-c65ba61bdf273604c3edcd24ba7a80cc3835441a.tar.bz2
caffeonacl-c65ba61bdf273604c3edcd24ba7a80cc3835441a.zip
Remove the 4D constraint of blobproto IO in python
Diffstat (limited to 'python')
-rw-r--r--python/caffe/io.py12
1 files changed, 4 insertions, 8 deletions
diff --git a/python/caffe/io.py b/python/caffe/io.py
index 0cad7211..40b7ac1e 100644
--- a/python/caffe/io.py
+++ b/python/caffe/io.py
@@ -21,22 +21,18 @@ def blobproto_to_array(blob, return_diff=False):
unless return_diff is True, in which case we will return the diff.
"""
if return_diff:
- return np.array(blob.diff).reshape(
- blob.num, blob.channels, blob.height, blob.width)
+ return np.array(blob.diff).reshape(*blob.shape.dim)
else:
- return np.array(blob.data).reshape(
- blob.num, blob.channels, blob.height, blob.width)
+ return np.array(blob.data).reshape(*blob.shape.dim)
def array_to_blobproto(arr, diff=None):
- """Converts a 4-dimensional array to blob proto. If diff is given, also
+ """Converts a N-dimensional array to blob proto. If diff is given, also
convert the diff. You need to make sure that arr and diff have the same
shape, and this function does not do sanity check.
"""
- if arr.ndim != 4:
- raise ValueError('Incorrect array shape.')
blob = caffe_pb2.BlobProto()
- blob.num, blob.channels, blob.height, blob.width = arr.shape
+ blob.shape.dim.extend(arr.shape)
blob.data.extend(arr.astype(float).flat)
if diff is not None:
blob.diff.extend(diff.astype(float).flat)