diff options
author | Jon Long <jonlong@cs.berkeley.edu> | 2016-03-02 16:23:14 -0800 |
---|---|---|
committer | Jon Long <jonlong@cs.berkeley.edu> | 2016-03-02 16:23:14 -0800 |
commit | 559758d0c5c5906633174d392b89c0a7a88dc9f9 (patch) | |
tree | 08d2bde1469bf8c0f9c31aaacc7aeff935a04e9e /python | |
parent | 37d1f915f966954401a49243503076fcc172a027 (diff) | |
parent | 666da79ad2f4d72c804ddadc7b10157e4d04bdd0 (diff) | |
download | caffeonacl-559758d0c5c5906633174d392b89c0a7a88dc9f9.tar.gz caffeonacl-559758d0c5c5906633174d392b89c0a7a88dc9f9.tar.bz2 caffeonacl-559758d0c5c5906633174d392b89c0a7a88dc9f9.zip |
Merge pull request #3716 from ttdt/master
Use six library to ensure pycaffe.py python3 compliance
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/pycaffe.py | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 5020eced..c5c0b824 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -14,6 +14,8 @@ from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \ RMSPropSolver, AdaDeltaSolver, AdamSolver import caffe.io +import six + # We directly update methods from Net here (rather than using composition or # inheritance) so that nets created by caffe (e.g., by SGDSolver) will # automatically have the improved interface. @@ -97,7 +99,7 @@ def _Net_forward(self, blobs=None, start=None, end=None, **kwargs): raise Exception('Input blob arguments do not match net inputs.') # Set input according to defined shapes and make arrays single and # C-contiguous as Caffe expects. - for in_, blob in kwargs.iteritems(): + for in_, blob in six.iteritems(kwargs): if blob.shape[0] != self.blobs[in_].shape[0]: raise Exception('Input is not batch sized') self.blobs[in_].data[...] = blob @@ -145,7 +147,7 @@ def _Net_backward(self, diffs=None, start=None, end=None, **kwargs): raise Exception('Top diff arguments do not match net outputs.') # Set top diffs according to defined shapes and make arrays single and # C-contiguous as Caffe expects. - for top, diff in kwargs.iteritems(): + for top, diff in six.iteritems(kwargs): if diff.shape[0] != self.blobs[top].shape[0]: raise Exception('Diff is not batch sized') self.blobs[top].diff[...] = diff @@ -174,13 +176,13 @@ def _Net_forward_all(self, blobs=None, **kwargs): all_outs = {out: [] for out in set(self.outputs + (blobs or []))} for batch in self._batch(kwargs): outs = self.forward(blobs=blobs, **batch) - for out, out_blob in outs.iteritems(): + for out, out_blob in six.iteritems(outs): all_outs[out].extend(out_blob.copy()) # Package in ndarray. for out in all_outs: all_outs[out] = np.asarray(all_outs[out]) # Discard padding. - pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next()) + pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs))) if pad: for out in all_outs: all_outs[out] = all_outs[out][:-pad] @@ -215,16 +217,16 @@ def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs): for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}): batch_blobs = self.forward(blobs=blobs, **fb) batch_diffs = self.backward(diffs=diffs, **bb) - for out, out_blobs in batch_blobs.iteritems(): + for out, out_blobs in six.iteritems(batch_blobs): all_outs[out].extend(out_blobs.copy()) - for diff, out_diffs in batch_diffs.iteritems(): + for diff, out_diffs in six.iteritems(batch_diffs): all_diffs[diff].extend(out_diffs.copy()) # Package in ndarray. for out, diff in zip(all_outs, all_diffs): all_outs[out] = np.asarray(all_outs[out]) all_diffs[diff] = np.asarray(all_diffs[diff]) # Discard padding at the end and package in ndarray. - pad = len(all_outs.itervalues().next()) - len(kwargs.itervalues().next()) + pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs))) if pad: for out, diff in zip(all_outs, all_diffs): all_outs[out] = all_outs[out][:-pad] @@ -256,10 +258,10 @@ def _Net_batch(self, blobs): ------ batch: {blob name: list of blobs} dict for a single batch. """ - num = len(blobs.itervalues().next()) - batch_size = self.blobs.itervalues().next().shape[0] + num = len(six.next(six.itervalues(blobs))) + batch_size = six.next(six.itervalues(self.blobs)).shape[0] remainder = num % batch_size - num_batches = num / batch_size + num_batches = num // batch_size # Yield full batches. for b in range(num_batches): |