From 8c609da73666bba762a50ba39e23ca82724fbad4 Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Mon, 26 May 2014 21:50:39 -0700 Subject: caffe.Net preprocessing members belong to object, not class --- python/caffe/pycaffe.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'python') diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 8bed7046..5c1512cd 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -14,11 +14,6 @@ import caffe.io # inheritance) so that nets created by caffe (e.g., by SGDSolver) will # automatically have the improved interface. -# Input preprocessing -Net.mean = {} # input mean (ndarray, input dimensional or broadcastable) -Net.input_scale = {} # for a model that expects data = input * input_scale -Net.channel_swap = {} # for RGB -> BGR and the like - @property def _Net_blobs(self): @@ -187,10 +182,12 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'): Take input_: which input to assign this mean. - mean_f: path to mean .npy + mean_f: path to mean .npy with ndarray (input dimensional or broadcastable) mode: elementwise = use the whole mean (and check dimensions) channel = channel constant (e.g. mean pixel instead of mean image) """ + if not hasattr(self, 'mean'): + self.mean = {} if input_ not in self.inputs: raise Exception('Input not in {}'.format(self.inputs)) in_shape = self.blobs[input_].data.shape @@ -218,6 +215,8 @@ def _Net_set_input_scale(self, input_, scale): input_: which input to assign this scale factor scale: scale coefficient """ + if not hasattr(self, 'input_scale'): + self.input_scale = {} if input_ not in self.inputs: raise Exception('Input not in {}'.format(self.inputs)) self.input_scale[input_] = scale @@ -233,6 +232,8 @@ def _Net_set_channel_swap(self, input_, order): order: the order to take the channels. (2,1,0) maps RGB to BGR for example. """ + if not hasattr(self, 'channel_swap'): + self.channel_swap = {} if input_ not in self.inputs: raise Exception('Input not in {}'.format(self.inputs)) self.channel_swap[input_] = order -- cgit v1.2.3