diff options
author | Kai Li <kaili_kloud@163.com> | 2014-08-28 16:55:56 +0800 |
---|---|---|
committer | Kai Li <kaili_kloud@163.com> | 2014-09-03 13:25:21 +0800 |
commit | b794cf9433233ec002dc454fa40286bbd3f4ecfe (patch) | |
tree | fd1cdf0d553436c1f215e9d4813e824a56af602a /src/caffe/layers/window_data_layer.cpp | |
parent | 6833dc0b7acf51a9bd30986df27e4c00fdfac741 (diff) | |
download | caffeonacl-b794cf9433233ec002dc454fa40286bbd3f4ecfe.tar.gz caffeonacl-b794cf9433233ec002dc454fa40286bbd3f4ecfe.tar.bz2 caffeonacl-b794cf9433233ec002dc454fa40286bbd3f4ecfe.zip |
Simplify the WindowDataLayer using the base class
Diffstat (limited to 'src/caffe/layers/window_data_layer.cpp')
-rw-r--r-- | src/caffe/layers/window_data_layer.cpp | 95 |
1 files changed, 31 insertions, 64 deletions
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index ab12c61a..f4349825 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -1,7 +1,6 @@ #include <stdint.h> #include <algorithm> -#include <fstream> // NOLINT(readability/streams) #include <map> #include <string> #include <utility> @@ -11,11 +10,12 @@ #include "opencv2/highgui/highgui.hpp" #include "opencv2/imgproc/imgproc.hpp" +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" #include "caffe/layer.hpp" #include "caffe/util/io.hpp" #include "caffe/util/math_functions.hpp" #include "caffe/util/rng.hpp" -#include "caffe/vision_layers.hpp" // caffe.proto > LayerParameter > WindowDataParameter // 'source' field specifies the window_file @@ -29,8 +29,8 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { // At each iteration, sample N windows where N*p are foreground (object) // windows and N*(1-p) are background (non-object) windows - Dtype* top_data = prefetch_data_.mutable_cpu_data(); - Dtype* top_label = prefetch_label_.mutable_cpu_data(); + Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); const Dtype scale = this->layer_param_.window_data_param().scale(); const int batch_size = this->layer_param_.window_data_param().batch_size(); const int crop_size = this->layer_param_.window_data_param().crop_size(); @@ -38,17 +38,17 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { const bool mirror = this->layer_param_.window_data_param().mirror(); const float fg_fraction = this->layer_param_.window_data_param().fg_fraction(); - const Dtype* mean = data_mean_.cpu_data(); - const int mean_off = (data_mean_.width() - crop_size) / 2; - const int mean_width = data_mean_.width(); - const int mean_height = data_mean_.height(); + const Dtype* mean = this->data_mean_.cpu_data(); + const int mean_off = (this->data_mean_.width() - crop_size) / 2; + const int mean_width = this->data_mean_.width(); + const int mean_height = this->data_mean_.height(); cv::Size cv_crop_size(crop_size, crop_size); const string& crop_mode = this->layer_param_.window_data_param().crop_mode(); bool use_square = (crop_mode == "square") ? true : false; // zero out batch - caffe_set(prefetch_data_.count(), Dtype(0), top_data); + caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); const int num_fg = static_cast<int>(static_cast<float>(batch_size) * fg_fraction); @@ -238,12 +238,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() { } template <typename Dtype> -WindowDataLayer<Dtype>::~WindowDataLayer<Dtype>() { - JoinPrefetchThread(); -} - -template <typename Dtype> -void WindowDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, +void WindowDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) { // LayerSetUp runs through the window_file and creates two structures // that hold windows: one for foreground (object) windows and one @@ -268,6 +263,16 @@ void WindowDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, << " foreground sampling fraction: " << this->layer_param_.window_data_param().fg_fraction(); + const bool prefetch_needs_rand = + this->layer_param_.window_data_param().mirror() || + this->layer_param_.window_data_param().crop_size(); + if (prefetch_needs_rand) { + const unsigned int prefetch_rng_seed = caffe_rng_rand(); + prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed)); + } else { + prefetch_rng_.reset(); + } + std::ifstream infile(this->layer_param_.window_data_param().source().c_str()); CHECK(infile.good()) << "Failed to open window file " << this->layer_param_.window_data_param().source() << std::endl; @@ -357,14 +362,14 @@ void WindowDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); (*top)[0]->Reshape(batch_size, channels, crop_size, crop_size); - prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << (*top)[0]->num() << "," << (*top)[0]->channels() << "," << (*top)[0]->height() << "," << (*top)[0]->width(); // label (*top)[1]->Reshape(batch_size, 1, 1, 1); - prefetch_label_.Reshape(batch_size, 1, 1, 1); + this->prefetch_label_.Reshape(batch_size, 1, 1, 1); // check if we want to have mean if (this->layer_param_.window_data_param().has_mean_file()) { @@ -373,47 +378,27 @@ void WindowDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, LOG(INFO) << "Loading mean file from" << mean_file; BlobProto blob_proto; ReadProtoFromBinaryFileOrDie(mean_file, &blob_proto); - data_mean_.FromProto(blob_proto); - CHECK_EQ(data_mean_.num(), 1); - CHECK_EQ(data_mean_.width(), data_mean_.height()); - CHECK_EQ(data_mean_.channels(), channels); + this->data_mean_.FromProto(blob_proto); + CHECK_EQ(this->data_mean_.num(), 1); + CHECK_EQ(this->data_mean_.width(), this->data_mean_.height()); + CHECK_EQ(this->data_mean_.channels(), channels); } else { // Simply initialize an all-empty mean. - data_mean_.Reshape(1, channels, crop_size, crop_size); + this->data_mean_.Reshape(1, channels, crop_size, crop_size); } // Now, start the prefetch thread. Before calling prefetch, we make two // cpu_data calls so that the prefetch thread does not accidentally make // simultaneous cudaMalloc calls when the main thread is running. In some // GPUs this seems to cause failures if we do not so. - prefetch_data_.mutable_cpu_data(); - prefetch_label_.mutable_cpu_data(); - data_mean_.cpu_data(); + this->prefetch_data_.mutable_cpu_data(); + this->prefetch_label_.mutable_cpu_data(); + this->data_mean_.cpu_data(); DLOG(INFO) << "Initializing prefetch"; - CreatePrefetchThread(); + this->CreatePrefetchThread(); DLOG(INFO) << "Prefetch initialized."; } template <typename Dtype> -void WindowDataLayer<Dtype>::CreatePrefetchThread() { - const bool prefetch_needs_rand = - this->layer_param_.window_data_param().mirror() || - this->layer_param_.window_data_param().crop_size(); - if (prefetch_needs_rand) { - const unsigned int prefetch_rng_seed = caffe_rng_rand(); - prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed)); - } else { - prefetch_rng_.reset(); - } - // Create the thread. - CHECK(StartInternalThread()) << "Thread execution failed."; -} - -template <typename Dtype> -void WindowDataLayer<Dtype>::JoinPrefetchThread() { - CHECK(WaitForInternalThreadToExit()) << "Thread joining failed."; -} - -template <typename Dtype> unsigned int WindowDataLayer<Dtype>::PrefetchRand() { CHECK(prefetch_rng_); caffe::rng_t* prefetch_rng = @@ -421,24 +406,6 @@ unsigned int WindowDataLayer<Dtype>::PrefetchRand() { return (*prefetch_rng)(); } -template <typename Dtype> -void WindowDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, - vector<Blob<Dtype>*>* top) { - // First, join the thread - JoinPrefetchThread(); - // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), - (*top)[0]->mutable_cpu_data()); - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), - (*top)[1]->mutable_cpu_data()); - // Start a new prefetch thread - CreatePrefetchThread(); -} - -#ifdef CPU_ONLY -STUB_GPU_FORWARD(WindowDataLayer, Forward); -#endif - INSTANTIATE_CLASS(WindowDataLayer); } // namespace caffe |