diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2017-03-07 20:52:57 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-03-07 20:52:57 -0800 |
commit | e687a71fac81718d40d4e0e98d29eab34f784b5b (patch) | |
tree | 55196ea9e06269ae7a27533ce11208dc5a8237b7 /src | |
parent | 85ab6100a122042c7dfd4adaf06f4c0b2e71148d (diff) | |
parent | 95a436c601a04af620a0e166393d3ff695905bc4 (diff) | |
download | caffeonacl-e687a71fac81718d40d4e0e98d29eab34f784b5b.tar.gz caffeonacl-e687a71fac81718d40d4e0e98d29eab34f784b5b.tar.bz2 caffeonacl-e687a71fac81718d40d4e0e98d29eab34f784b5b.zip |
Merge pull request #4630 from BlGene/load_hdf5_fix
Made load_hd5 check blob dims by default, instead of reshaping.
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/layers/hdf5_data_layer.cpp | 3 | ||||
-rw-r--r-- | src/caffe/test/test_hdf5_output_layer.cpp | 10 | ||||
-rw-r--r-- | src/caffe/test/test_hdf5data_layer.cpp | 2 | ||||
-rw-r--r-- | src/caffe/util/hdf5.cpp | 34 |
4 files changed, 37 insertions, 12 deletions
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp index b9a071ce..00716a92 100644 --- a/src/caffe/layers/hdf5_data_layer.cpp +++ b/src/caffe/layers/hdf5_data_layer.cpp @@ -39,8 +39,9 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) { for (int i = 0; i < top_size; ++i) { hdf_blobs_[i] = shared_ptr<Blob<Dtype> >(new Blob<Dtype>()); + // Allow reshape here, as we are loading data not params hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(), - MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get()); + MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get(), true); } herr_t status = H5Fclose(file_id); diff --git a/src/caffe/test/test_hdf5_output_layer.cpp b/src/caffe/test/test_hdf5_output_layer.cpp index 3833ebff..2bc2de1e 100644 --- a/src/caffe/test/test_hdf5_output_layer.cpp +++ b/src/caffe/test/test_hdf5_output_layer.cpp @@ -77,10 +77,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) { H5P_DEFAULT); ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" << this->input_file_name_; + // Allow reshape here as we are loading data not params + bool reshape = true; hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4, - this->blob_data_); + this->blob_data_, reshape); hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4, - this->blob_label_); + this->blob_label_, reshape); herr_t status = H5Fclose(file_id); EXPECT_GE(status, 0)<< "Failed to close HDF5 file " << this->input_file_name_; @@ -105,12 +107,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) { Blob<Dtype>* blob_data = new Blob<Dtype>(); hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4, - blob_data); + blob_data, reshape); this->CheckBlobEqual(*(this->blob_data_), *blob_data); Blob<Dtype>* blob_label = new Blob<Dtype>(); hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4, - blob_label); + blob_label, reshape); this->CheckBlobEqual(*(this->blob_label_), *blob_label); status = H5Fclose(file_id); diff --git a/src/caffe/test/test_hdf5data_layer.cpp b/src/caffe/test/test_hdf5data_layer.cpp index 68e10286..487f5176 100644 --- a/src/caffe/test/test_hdf5data_layer.cpp +++ b/src/caffe/test/test_hdf5data_layer.cpp @@ -70,7 +70,7 @@ TYPED_TEST(HDF5DataLayerTest, TestRead) { int height = 6; int width = 5; - // Test that the layer setup got the correct parameters. + // Test that the layer setup gives correct parameters. HDF5DataLayer<Dtype> layer(param); layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); EXPECT_EQ(this->blob_top_data_->num(), batch_size); diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp index d255877b..ed737429 100644 --- a/src/caffe/util/hdf5.cpp +++ b/src/caffe/util/hdf5.cpp @@ -9,7 +9,7 @@ namespace caffe { template <typename Dtype> void hdf5_load_nd_dataset_helper( hid_t file_id, const char* dataset_name_, int min_dim, int max_dim, - Blob<Dtype>* blob) { + Blob<Dtype>* blob, bool reshape) { // Verify that the dataset exists. CHECK(H5LTfind_dataset(file_id, dataset_name_)) << "Failed to find HDF5 dataset " << dataset_name_; @@ -56,17 +56,38 @@ void hdf5_load_nd_dataset_helper( LOG(FATAL) << "Datatype class unknown"; } + vector<int> blob_dims(dims.size()); for (int i = 0; i < dims.size(); ++i) { blob_dims[i] = dims[i]; } - blob->Reshape(blob_dims); + + if (reshape) { + blob->Reshape(blob_dims); + } else { + if (blob_dims != blob->shape()) { + // create shape string for error message + ostringstream stream; + int count = 1; + for (int i = 0; i < blob_dims.size(); ++i) { + stream << blob_dims[i] << " "; + count = count * blob_dims[i]; + } + stream << "(" << count << ")"; + string source_shape_string = stream.str(); + + CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape " + << "mismatch. Source shape is " << source_shape_string + << " target shape is " << blob->shape_string(); + } + } } template <> void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_, - int min_dim, int max_dim, Blob<float>* blob) { - hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob); + int min_dim, int max_dim, Blob<float>* blob, bool reshape) { + hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob, + reshape); herr_t status = H5LTread_dataset_float( file_id, dataset_name_, blob->mutable_cpu_data()); CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_; @@ -74,8 +95,9 @@ void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_, template <> void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_, - int min_dim, int max_dim, Blob<double>* blob) { - hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob); + int min_dim, int max_dim, Blob<double>* blob, bool reshape) { + hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob, + reshape); herr_t status = H5LTread_dataset_double( file_id, dataset_name_, blob->mutable_cpu_data()); CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_; |