summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2017-03-07 20:52:57 -0800
committerGitHub <noreply@github.com>2017-03-07 20:52:57 -0800
commite687a71fac81718d40d4e0e98d29eab34f784b5b (patch)
tree55196ea9e06269ae7a27533ce11208dc5a8237b7 /src
parent85ab6100a122042c7dfd4adaf06f4c0b2e71148d (diff)
parent95a436c601a04af620a0e166393d3ff695905bc4 (diff)
downloadcaffeonacl-e687a71fac81718d40d4e0e98d29eab34f784b5b.tar.gz
caffeonacl-e687a71fac81718d40d4e0e98d29eab34f784b5b.tar.bz2
caffeonacl-e687a71fac81718d40d4e0e98d29eab34f784b5b.zip
Merge pull request #4630 from BlGene/load_hdf5_fix
Made load_hd5 check blob dims by default, instead of reshaping.
Diffstat (limited to 'src')
-rw-r--r--src/caffe/layers/hdf5_data_layer.cpp3
-rw-r--r--src/caffe/test/test_hdf5_output_layer.cpp10
-rw-r--r--src/caffe/test/test_hdf5data_layer.cpp2
-rw-r--r--src/caffe/util/hdf5.cpp34
4 files changed, 37 insertions, 12 deletions
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index b9a071ce..00716a92 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -39,8 +39,9 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
for (int i = 0; i < top_size; ++i) {
hdf_blobs_[i] = shared_ptr<Blob<Dtype> >(new Blob<Dtype>());
+ // Allow reshape here, as we are loading data not params
hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(),
- MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get());
+ MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get(), true);
}
herr_t status = H5Fclose(file_id);
diff --git a/src/caffe/test/test_hdf5_output_layer.cpp b/src/caffe/test/test_hdf5_output_layer.cpp
index 3833ebff..2bc2de1e 100644
--- a/src/caffe/test/test_hdf5_output_layer.cpp
+++ b/src/caffe/test/test_hdf5_output_layer.cpp
@@ -77,10 +77,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
H5P_DEFAULT);
ASSERT_GE(file_id, 0)<< "Failed to open HDF5 file" <<
this->input_file_name_;
+ // Allow reshape here as we are loading data not params
+ bool reshape = true;
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
- this->blob_data_);
+ this->blob_data_, reshape);
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
- this->blob_label_);
+ this->blob_label_, reshape);
herr_t status = H5Fclose(file_id);
EXPECT_GE(status, 0)<< "Failed to close HDF5 file " <<
this->input_file_name_;
@@ -105,12 +107,12 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
Blob<Dtype>* blob_data = new Blob<Dtype>();
hdf5_load_nd_dataset(file_id, HDF5_DATA_DATASET_NAME, 0, 4,
- blob_data);
+ blob_data, reshape);
this->CheckBlobEqual(*(this->blob_data_), *blob_data);
Blob<Dtype>* blob_label = new Blob<Dtype>();
hdf5_load_nd_dataset(file_id, HDF5_DATA_LABEL_NAME, 0, 4,
- blob_label);
+ blob_label, reshape);
this->CheckBlobEqual(*(this->blob_label_), *blob_label);
status = H5Fclose(file_id);
diff --git a/src/caffe/test/test_hdf5data_layer.cpp b/src/caffe/test/test_hdf5data_layer.cpp
index 68e10286..487f5176 100644
--- a/src/caffe/test/test_hdf5data_layer.cpp
+++ b/src/caffe/test/test_hdf5data_layer.cpp
@@ -70,7 +70,7 @@ TYPED_TEST(HDF5DataLayerTest, TestRead) {
int height = 6;
int width = 5;
- // Test that the layer setup got the correct parameters.
+ // Test that the layer setup gives correct parameters.
HDF5DataLayer<Dtype> layer(param);
layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
EXPECT_EQ(this->blob_top_data_->num(), batch_size);
diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp
index d255877b..ed737429 100644
--- a/src/caffe/util/hdf5.cpp
+++ b/src/caffe/util/hdf5.cpp
@@ -9,7 +9,7 @@ namespace caffe {
template <typename Dtype>
void hdf5_load_nd_dataset_helper(
hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob) {
+ Blob<Dtype>* blob, bool reshape) {
// Verify that the dataset exists.
CHECK(H5LTfind_dataset(file_id, dataset_name_))
<< "Failed to find HDF5 dataset " << dataset_name_;
@@ -56,17 +56,38 @@ void hdf5_load_nd_dataset_helper(
LOG(FATAL) << "Datatype class unknown";
}
+
vector<int> blob_dims(dims.size());
for (int i = 0; i < dims.size(); ++i) {
blob_dims[i] = dims[i];
}
- blob->Reshape(blob_dims);
+
+ if (reshape) {
+ blob->Reshape(blob_dims);
+ } else {
+ if (blob_dims != blob->shape()) {
+ // create shape string for error message
+ ostringstream stream;
+ int count = 1;
+ for (int i = 0; i < blob_dims.size(); ++i) {
+ stream << blob_dims[i] << " ";
+ count = count * blob_dims[i];
+ }
+ stream << "(" << count << ")";
+ string source_shape_string = stream.str();
+
+ CHECK(blob_dims == blob->shape()) << "Cannot load blob from hdf5; shape "
+ << "mismatch. Source shape is " << source_shape_string
+ << " target shape is " << blob->shape_string();
+ }
+ }
}
template <>
void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
- int min_dim, int max_dim, Blob<float>* blob) {
- hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+ int min_dim, int max_dim, Blob<float>* blob, bool reshape) {
+ hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+ reshape);
herr_t status = H5LTread_dataset_float(
file_id, dataset_name_, blob->mutable_cpu_data());
CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
@@ -74,8 +95,9 @@ void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
template <>
void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
- int min_dim, int max_dim, Blob<double>* blob) {
- hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+ int min_dim, int max_dim, Blob<double>* blob, bool reshape) {
+ hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob,
+ reshape);
herr_t status = H5LTread_dataset_double(
file_id, dataset_name_, blob->mutable_cpu_data());
CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;