summary | refs | log | tree | commit | diff
path: root/src/caffe/layers/window_data_layer.cpp
diff options
context:
space:
mode:
author: Sergio <sguada@gmail.com> 2014-10-07 14:14:50 -0700
committer: Sergio <sguada@gmail.com> 2014-10-15 17:03:07 -0700
commit: 3744598ed39d26d6219cd0369dd6923e2c747c99 (patch)
tree: 757e28b7a5bc935354617c0243a2bfc4fed2896e /src/caffe/layers/window_data_layer.cpp
parent: e9d6e5a0b22a9f4768b8c04c9031ee8adb822ece (diff)
download: caffeonacl-3744598ed39d26d6219cd0369dd6923e2c747c99.tar.gz
caffeonacl-3744598ed39d26d6219cd0369dd6923e2c747c99.tar.bz2
caffeonacl-3744598ed39d26d6219cd0369dd6923e2c747c99.zip
Speed up WindowDataLayer and add mean_values
Diffstat (limited to 'src/caffe/layers/window_data_layer.cpp')
-rw-r--r-- src/caffe/layers/window_data_layer.cpp 65
1 files changed, 46 insertions, 19 deletions
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp
index 8e656155..fc0ffc88 100644
--- a/src/caffe/layers/window_data_layer.cpp
+++ b/src/caffe/layers/window_data_layer.cpp
@@ -170,15 +170,30 @@ void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
this->prefetch_label_.Reshape(batch_size, 1, 1, 1);
// data mean
- if (this->transform_param_.has_mean_file()) {
+ has_mean_file_ = this->transform_param_.has_mean_file();
+ has_mean_values_ = this->transform_param_.mean_value_size() > 0;
+ if (has_mean_file_) {
const string& mean_file =
this->transform_param_.mean_file();
LOG(INFO) << "Loading mean file from" << mean_file;
BlobProto blob_proto;
ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
data_mean_.FromProto(blob_proto);
- } else {
- data_mean_.Reshape(1, channels, crop_size, crop_size);
+ }
+ if (has_mean_values_) {
+ CHECK(has_mean_file_ == false) <<
+ "Cannot specify mean_file and mean_value at the same time";
+ for (int c = 0; c < this->transform_param_.mean_value_size(); ++c) {
+ mean_values_.push_back(this->transform_param_.mean_value(c));
+ }
+ CHECK(mean_values_.size() == 1 || mean_values_.size() == channels) <<
+ "Specify either 1 mean_value or as many as channels: " << channels;
+ if (channels > 1 && mean_values_.size() == 1) {
+ // Replicate the mean_value for simplicity
+ for (int c = 1; c < channels; ++c) {
+ mean_values_.push_back(mean_values_[0]);
+ }
+ }
}
}
@@ -211,10 +226,14 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
const bool mirror = this->transform_param_.mirror();
const float fg_fraction =
this->layer_param_.window_data_param().fg_fraction();
- const Dtype* mean = this->data_mean_.cpu_data();
- const int mean_off = (this->data_mean_.width() - crop_size) / 2;
- const int mean_width = this->data_mean_.width();
- const int mean_height = this->data_mean_.height();
+ Dtype* mean = NULL;
+ int mean_off, mean_width, mean_height;
+ if (this->has_mean_file_) {
+ mean = this->data_mean_.mutable_cpu_data();
+ mean_off = (this->data_mean_.width() - crop_size) / 2;
+ mean_width = this->data_mean_.width();
+ mean_height = this->data_mean_.height();
+ }
cv::Size cv_crop_size(crop_size, crop_size);
const string& crop_mode = this->layer_param_.window_data_param().crop_mode();
@@ -357,18 +376,26 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
}
// copy the warped window into top_data
- for (int c = 0; c < channels; ++c) {
- for (int h = 0; h < cv_cropped_img.rows; ++h) {
- for (int w = 0; w < cv_cropped_img.cols; ++w) {
- Dtype pixel =
- static_cast<Dtype>(cv_cropped_img.at<cv::Vec3b>(h, w)[c]);
-
- top_data[((item_id * channels + c) * crop_size + h + pad_h)
- * crop_size + w + pad_w]
- = (pixel
- - mean[(c * mean_height + h + mean_off + pad_h)
- * mean_width + w + mean_off + pad_w])
- * scale;
+ for (int h = 0; h < cv_cropped_img.rows; ++h) {
+ const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
+ int img_index = 0;
+ for (int w = 0; w < cv_cropped_img.cols; ++w) {
+ for (int c = 0; c < channels; ++c) {
+ int top_index = ((item_id * channels + c) * crop_size + h + pad_h)
+ * crop_size + w + pad_w;
+ // int top_index = (c * height + h) * width + w;
+ Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
+ if (this->has_mean_file_) {
+ int mean_index = (c * mean_height + h + mean_off + pad_h)
+ * mean_width + w + mean_off + pad_w;
+ top_data[top_index] = (pixel - mean[mean_index]) * scale;
+ } else {
+ if (this->has_mean_values_) {
+ top_data[top_index] = (pixel - this->mean_values_[c]) * scale;
+ } else {
+ top_data[top_index] = pixel * scale;
+ }
+ }
}
}
}