diff options
author | Sergio <sguada@gmail.com> | 2014-10-07 14:14:50 -0700 |
---|---|---|
committer | Sergio <sguada@gmail.com> | 2014-10-15 17:03:07 -0700 |
commit | 3744598ed39d26d6219cd0369dd6923e2c747c99 (patch) | |
tree | 757e28b7a5bc935354617c0243a2bfc4fed2896e /src/caffe/layers/window_data_layer.cpp | |
parent | e9d6e5a0b22a9f4768b8c04c9031ee8adb822ece (diff) | |
download | caffeonacl-3744598ed39d26d6219cd0369dd6923e2c747c99.tar.gz caffeonacl-3744598ed39d26d6219cd0369dd6923e2c747c99.tar.bz2 caffeonacl-3744598ed39d26d6219cd0369dd6923e2c747c99.zip |
Speed up WindowDataLayer and add mean_values
Diffstat (limited to 'src/caffe/layers/window_data_layer.cpp')
-rw-r--r-- | src/caffe/layers/window_data_layer.cpp | 65 |
1 files changed, 46 insertions, 19 deletions
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index 8e656155..fc0ffc88 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -170,15 +170,30 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, this->prefetch_label_.Reshape(batch_size, 1, 1, 1); // data mean - if (this->transform_param_.has_mean_file()) { + has_mean_file_ = this->transform_param_.has_mean_file(); + has_mean_values_ = this->transform_param_.mean_value_size() > 0; + if (has_mean_file_) { const string& mean_file = this->transform_param_.mean_file(); LOG(INFO) << "Loading mean file from" << mean_file; BlobProto blob_proto; ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); data_mean_.FromProto(blob_proto); - } else { - data_mean_.Reshape(1, channels, crop_size, crop_size); + } + if (has_mean_values_) { + CHECK(has_mean_file_ == false) << + "Cannot specify mean_file and mean_value at the same time"; + for (int c = 0; c < this->transform_param_.mean_value_size(); ++c) { + mean_values_.push_back(this->transform_param_.mean_value(c)); + } + CHECK(mean_values_.size() == 1 || mean_values_.size() == channels) << + "Specify either 1 mean_value or as many as channels: " << channels; + if (channels > 1 && mean_values_.size() == 1) { + // Replicate the mean_value for simplicity + for (int c = 1; c < channels; ++c) { + mean_values_.push_back(mean_values_[0]); + } + } } } @@ -211,10 +226,14 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { const bool mirror = this->transform_param_.mirror(); const float fg_fraction = this->layer_param_.window_data_param().fg_fraction(); - const Dtype* mean = this->data_mean_.cpu_data(); - const int mean_off = (this->data_mean_.width() - crop_size) / 2; - const int mean_width = this->data_mean_.width(); - const int mean_height = this->data_mean_.height(); + Dtype* mean = NULL; + int mean_off, mean_width, mean_height; + if (this->has_mean_file_) { + mean = this->data_mean_.mutable_cpu_data(); + mean_off = (this->data_mean_.width() - crop_size) / 2; + mean_width = this->data_mean_.width(); + mean_height = this->data_mean_.height(); + } cv::Size cv_crop_size(crop_size, crop_size); const string& crop_mode = this->layer_param_.window_data_param().crop_mode(); @@ -357,18 +376,26 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { } // copy the warped window into top_data - for (int c = 0; c < channels; ++c) { - for (int h = 0; h < cv_cropped_img.rows; ++h) { - for (int w = 0; w < cv_cropped_img.cols; ++w) { - Dtype pixel = - static_cast<Dtype>(cv_cropped_img.at<cv::Vec3b>(h, w)[c]); - - top_data[((item_id * channels + c) * crop_size + h + pad_h) - * crop_size + w + pad_w] - = (pixel - - mean[(c * mean_height + h + mean_off + pad_h) - * mean_width + w + mean_off + pad_w]) - * scale; + for (int h = 0; h < cv_cropped_img.rows; ++h) { + const uchar* ptr = cv_cropped_img.ptr<uchar>(h); + int img_index = 0; + for (int w = 0; w < cv_cropped_img.cols; ++w) { + for (int c = 0; c < channels; ++c) { + int top_index = ((item_id * channels + c) * crop_size + h + pad_h) + * crop_size + w + pad_w; + // int top_index = (c * height + h) * width + w; + Dtype pixel = static_cast<Dtype>(ptr[img_index++]); + if (this->has_mean_file_) { + int mean_index = (c * mean_height + h + mean_off + pad_h) + * mean_width + w + mean_off + pad_w; + top_data[top_index] = (pixel - mean[mean_index]) * scale; + } else { + if (this->has_mean_values_) { + top_data[top_index] = (pixel - this->mean_values_[c]) * scale; + } else { + top_data[top_index] = pixel * scale; + } + } } } } |