summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-09-24 13:11:05 -0700
committerJeff Donahue <jeff.donahue@gmail.com>2015-09-24 13:11:05 -0700
commit349ff65adcab47612dead5142308c2b920797182 (patch)
treefe2bb5736f58f43fee452443c645047ccca588c8
parent37dc63cf3609134eec7950faaee0467071994e96 (diff)
parentebc9963fea7b72f397c446a10a9aeab576979566 (diff)
downloadcaffeonacl-349ff65adcab47612dead5142308c2b920797182.tar.gz
caffeonacl-349ff65adcab47612dead5142308c2b920797182.tar.bz2
caffeonacl-349ff65adcab47612dead5142308c2b920797182.zip
Merge pull request #2978 from lukeyeager/h5t_integer
Allow H5T_INTEGER in HDF5 files
-rw-r--r--src/caffe/test/test_data/generate_sample_data.py14
-rw-r--r--src/caffe/test/test_data/sample_data_2_gzip.h5bin15446 -> 15446 bytes
-rw-r--r--src/caffe/util/hdf5.cpp29
3 files changed, 36 insertions, 7 deletions
diff --git a/src/caffe/test/test_data/generate_sample_data.py b/src/caffe/test/test_data/generate_sample_data.py
index 3703b418..8349dbbc 100644
--- a/src/caffe/test/test_data/generate_sample_data.py
+++ b/src/caffe/test/test_data/generate_sample_data.py
@@ -36,23 +36,25 @@ with h5py.File(script_dir + '/sample_data.h5', 'w') as f:
f['label'] = label
f['label2'] = label2
-with h5py.File(script_dir + '/sample_data_2_gzip.h5', 'w') as f:
+with h5py.File(script_dir + '/sample_data_uint8_gzip.h5', 'w') as f:
f.create_dataset(
'data', data=data + total_size,
compression='gzip', compression_opts=1
)
f.create_dataset(
'label', data=label,
- compression='gzip', compression_opts=1
+ compression='gzip', compression_opts=1,
+ dtype='uint8',
)
f.create_dataset(
'label2', data=label2,
- compression='gzip', compression_opts=1
+ compression='gzip', compression_opts=1,
+ dtype='uint8',
)
with open(script_dir + '/sample_data_list.txt', 'w') as f:
- f.write(script_dir + '/sample_data.h5\n')
- f.write(script_dir + '/sample_data_2_gzip.h5\n')
+ f.write('src/caffe/test/test_data/sample_data.h5\n')
+ f.write('src/caffe/test/test_data/sample_uint8_gzip.h5\n')
# Generate GradientBasedSolver solver_data.h5
@@ -76,4 +78,4 @@ with h5py.File(script_dir + '/solver_data.h5', 'w') as f:
f['targets'] = targets
with open(script_dir + '/solver_data_list.txt', 'w') as f:
- f.write(script_dir + '/solver_data.h5\n')
+ f.write('src/caffe/test/test_data/solver_data.h5\n')
diff --git a/src/caffe/test/test_data/sample_data_2_gzip.h5 b/src/caffe/test/test_data/sample_data_2_gzip.h5
index a138e036..0cb9ef92 100644
--- a/src/caffe/test/test_data/sample_data_2_gzip.h5
+++ b/src/caffe/test/test_data/sample_data_2_gzip.h5
Binary files differ
diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp
index d0d05f70..7730e76a 100644
--- a/src/caffe/util/hdf5.cpp
+++ b/src/caffe/util/hdf5.cpp
@@ -27,7 +27,34 @@ void hdf5_load_nd_dataset_helper(
status = H5LTget_dataset_info(
file_id, dataset_name_, dims.data(), &class_, NULL);
CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
- CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
+ switch (class_) {
+ case H5T_FLOAT:
+ LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_FLOAT";
+ break;
+ case H5T_INTEGER:
+ LOG_FIRST_N(INFO, 1) << "Datatype class: H5T_INTEGER";
+ break;
+ case H5T_TIME:
+ LOG(FATAL) << "Unsupported datatype class: H5T_TIME";
+ case H5T_STRING:
+ LOG(FATAL) << "Unsupported datatype class: H5T_STRING";
+ case H5T_BITFIELD:
+ LOG(FATAL) << "Unsupported datatype class: H5T_BITFIELD";
+ case H5T_OPAQUE:
+ LOG(FATAL) << "Unsupported datatype class: H5T_OPAQUE";
+ case H5T_COMPOUND:
+ LOG(FATAL) << "Unsupported datatype class: H5T_COMPOUND";
+ case H5T_REFERENCE:
+ LOG(FATAL) << "Unsupported datatype class: H5T_REFERENCE";
+ case H5T_ENUM:
+ LOG(FATAL) << "Unsupported datatype class: H5T_ENUM";
+ case H5T_VLEN:
+ LOG(FATAL) << "Unsupported datatype class: H5T_VLEN";
+ case H5T_ARRAY:
+ LOG(FATAL) << "Unsupported datatype class: H5T_ARRAY";
+ default:
+ LOG(FATAL) << "Datatype class unknown";
+ }
vector<int> blob_dims(dims.size());
for (int i = 0; i < dims.size(); ++i) {