summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-26 21:50:39 -0700
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-05-26 21:50:39 -0700
commit8c609da73666bba762a50ba39e23ca82724fbad4 (patch)
tree98b8f2a9d1043dcc129fd4e936b6ff07fb61ac03 /python
parent017fbd45355ce373db34085f77ab92a9c48af844 (diff)
downloadcaffeonacl-8c609da73666bba762a50ba39e23ca82724fbad4.tar.gz
caffeonacl-8c609da73666bba762a50ba39e23ca82724fbad4.tar.bz2
caffeonacl-8c609da73666bba762a50ba39e23ca82724fbad4.zip
caffe.Net preprocessing members belong to object, not class
Diffstat (limited to 'python')
-rw-r--r--python/caffe/pycaffe.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 8bed7046..5c1512cd 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -14,11 +14,6 @@ import caffe.io
# inheritance) so that nets created by caffe (e.g., by SGDSolver) will
# automatically have the improved interface.
-# Input preprocessing
-Net.mean = {} # input mean (ndarray, input dimensional or broadcastable)
-Net.input_scale = {} # for a model that expects data = input * input_scale
-Net.channel_swap = {} # for RGB -> BGR and the like
-
@property
def _Net_blobs(self):
@@ -187,10 +182,12 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'):
Take
input_: which input to assign this mean.
- mean_f: path to mean .npy
+ mean_f: path to mean .npy with ndarray (input dimensional or broadcastable)
mode: elementwise = use the whole mean (and check dimensions)
channel = channel constant (e.g. mean pixel instead of mean image)
"""
+ if not hasattr(self, 'mean'):
+ self.mean = {}
if input_ not in self.inputs:
raise Exception('Input not in {}'.format(self.inputs))
in_shape = self.blobs[input_].data.shape
@@ -218,6 +215,8 @@ def _Net_set_input_scale(self, input_, scale):
input_: which input to assign this scale factor
scale: scale coefficient
"""
+ if not hasattr(self, 'input_scale'):
+ self.input_scale = {}
if input_ not in self.inputs:
raise Exception('Input not in {}'.format(self.inputs))
self.input_scale[input_] = scale
@@ -233,6 +232,8 @@ def _Net_set_channel_swap(self, input_, order):
order: the order to take the channels.
(2,1,0) maps RGB to BGR for example.
"""
+ if not hasattr(self, 'channel_swap'):
+ self.channel_swap = {}
if input_ not in self.inputs:
raise Exception('Input not in {}'.format(self.inputs))
self.channel_swap[input_] = order