diff options
author | sh1r0 <sh1r0@users.noreply.github.com> | 2015-10-09 00:31:05 +0800 |
---|---|---|
committer | sh1r0 <sh1r0@users.noreply.github.com> | 2015-10-09 00:37:52 +0800 |
commit | c65ba61bdf273604c3edcd24ba7a80cc3835441a (patch) | |
tree | 4b3b31758634129400407fc92a64e282434766cb /python | |
parent | 04c7c368095a9b3244ac9cf8afde9272482a9b32 (diff) | |
download | caffeonacl-c65ba61bdf273604c3edcd24ba7a80cc3835441a.tar.gz caffeonacl-c65ba61bdf273604c3edcd24ba7a80cc3835441a.tar.bz2 caffeonacl-c65ba61bdf273604c3edcd24ba7a80cc3835441a.zip |
Remove the 4D constraint of blobproto IO in python
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/io.py | 12 |
1 files changed, 4 insertions, 8 deletions
diff --git a/python/caffe/io.py b/python/caffe/io.py index 0cad7211..40b7ac1e 100644 --- a/python/caffe/io.py +++ b/python/caffe/io.py @@ -21,22 +21,18 @@ def blobproto_to_array(blob, return_diff=False): unless return_diff is True, in which case we will return the diff. """ if return_diff: - return np.array(blob.diff).reshape( - blob.num, blob.channels, blob.height, blob.width) + return np.array(blob.diff).reshape(*blob.shape.dim) else: - return np.array(blob.data).reshape( - blob.num, blob.channels, blob.height, blob.width) + return np.array(blob.data).reshape(*blob.shape.dim) def array_to_blobproto(arr, diff=None): - """Converts a 4-dimensional array to blob proto. If diff is given, also + """Converts a N-dimensional array to blob proto. If diff is given, also convert the diff. You need to make sure that arr and diff have the same shape, and this function does not do sanity check. """ - if arr.ndim != 4: - raise ValueError('Incorrect array shape.') blob = caffe_pb2.BlobProto() - blob.num, blob.channels, blob.height, blob.width = arr.shape + blob.shape.dim.extend(arr.shape) blob.data.extend(arr.astype(float).flat) if diff is not None: blob.diff.extend(diff.astype(float).flat) |