summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorshai <shai@magisto.com>2016-02-23 10:42:54 +0200
committershai <shai@magisto.com>2016-02-23 10:42:54 +0200
commit29bb23fc92c5b71c0bf8af0b9e580015da9aedda (patch)
tree05f8906d7efb00154a8ec1d21644a83302cba7cd
parent4541f8900588a335f2d9387a5b03460deba68678 (diff)
downloadcaffeonacl-29bb23fc92c5b71c0bf8af0b9e580015da9aedda.tar.gz
caffeonacl-29bb23fc92c5b71c0bf8af0b9e580015da9aedda.tar.bz2
caffeonacl-29bb23fc92c5b71c0bf8af0b9e580015da9aedda.zip
removing all references to Blob.num property (that assumes Blob is 4D). Replacing it with accessing Blob.shape[0] - for Blobs with num_axes() != 4
-rw-r--r--python/caffe/pycaffe.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 30541107..5020eced 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -98,7 +98,7 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
# Set input according to defined shapes and make arrays single and
# C-contiguous as Caffe expects.
for in_, blob in kwargs.iteritems():
- if blob.shape[0] != self.blobs[in_].num:
+ if blob.shape[0] != self.blobs[in_].shape[0]:
raise Exception('Input is not batch sized')
self.blobs[in_].data[...] = blob
@@ -146,7 +146,7 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
# Set top diffs according to defined shapes and make arrays single and
# C-contiguous as Caffe expects.
for top, diff in kwargs.iteritems():
- if diff.shape[0] != self.blobs[top].num:
+ if diff.shape[0] != self.blobs[top].shape[0]:
raise Exception('Diff is not batch sized')
self.blobs[top].diff[...] = diff
@@ -257,7 +257,7 @@ def _Net_batch(self, blobs):
batch: {blob name: list of blobs} dict for a single batch.
"""
num = len(blobs.itervalues().next())
- batch_size = self.blobs.itervalues().next().num
+ batch_size = self.blobs.itervalues().next().shape[0]
remainder = num % batch_size
num_batches = num / batch_size