diff options
author | Tobias Domhan <tdomhan@gmail.com> | 2014-03-16 12:50:30 +0100 |
---|---|---|
committer | Tobias Domhan <tdomhan@gmail.com> | 2014-03-16 12:50:30 +0100 |
commit | 585000bbcb3bd189a704d325928554d3811ca84e (patch) | |
tree | f10bf683d4ab75634eddb7d440dab9bc3558c7be /src/caffe/util/io.cpp | |
parent | 4370b3e83e5e923cc68af61ffc820e449917ea64 (diff) | |
download | caffeonacl-585000bbcb3bd189a704d325928554d3811ca84e.tar.gz caffeonacl-585000bbcb3bd189a704d325928554d3811ca84e.tar.bz2 caffeonacl-585000bbcb3bd189a704d325928554d3811ca84e.zip |
support for more than 2 dimensions in hdf5 files
Diffstat (limited to 'src/caffe/util/io.cpp')
-rw-r--r-- | src/caffe/util/io.cpp | 45 |
1 files changed, 33 insertions, 12 deletions
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp index d6151d5c..c301c55b 100644 --- a/src/caffe/util/io.cpp +++ b/src/caffe/util/io.cpp @@ -12,6 +12,7 @@ #include <algorithm> #include <string> +#include <vector> #include <fstream> // NOLINT(readability/streams) #include "caffe/common.hpp" @@ -100,39 +101,59 @@ bool ReadImageToDatum(const string& filename, const int label, } template <> -void load_2d_dataset<float>(hid_t file_id, const char* dataset_name_, - boost::scoped_ptr<float>* array, hsize_t* dims) { +void hd5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_, + int min_dim, int max_dim, + boost::scoped_ptr<float>* array, std::vector<hsize_t>& out_dims) { herr_t status; int ndims; status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims); - assert(ndims == 2); + CHECK_GE(ndims, min_dim); + CHECK_LE(ndims, max_dim); + + boost::scoped_ptr<hsize_t> dims(new hsize_t[ndims]); H5T_class_t class_; status = H5LTget_dataset_info( - file_id, dataset_name_, dims, &class_, NULL); - assert(class_ == H5T_NATIVE_FLOAT); + file_id, dataset_name_, dims.get(), &class_, NULL); + CHECK_EQ(class_, H5T_FLOAT) << "Epected float data"; + + int array_size = 1; + for (int i=0; i<ndims; ++i) { + out_dims.push_back(dims.get()[i]); + array_size *= dims.get()[i]; + } - array->reset(new float[dims[0] * dims[1]]); + array->reset(new float[array_size]); status = H5LTread_dataset_float( file_id, dataset_name_, array->get()); } template <> -void load_2d_dataset<double>(hid_t file_id, const char* dataset_name_, - boost::scoped_ptr<double>* array, hsize_t* dims) { +void hd5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_, + int min_dim, int max_dim, + boost::scoped_ptr<double>* array, std::vector<hsize_t>& out_dims) { herr_t status; int ndims; status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims); - assert(ndims == 2); + CHECK_GE(ndims, min_dim); + CHECK_LE(ndims, max_dim); + + boost::scoped_ptr<hsize_t> dims(new hsize_t[ndims]); H5T_class_t class_; status = H5LTget_dataset_info( - file_id, dataset_name_, dims, &class_, NULL); - assert(class_ == H5T_NATIVE_DOUBLE); + file_id, dataset_name_, dims.get(), &class_, NULL); + CHECK_EQ(class_, H5T_FLOAT) << "Epected float data"; + + int array_size = 1; + for (int i=0; i<ndims; ++i) { + out_dims.push_back(dims.get()[i]); + array_size *= dims.get()[i]; + } - array->reset(new double[dims[0] * dims[1]]); + array->reset(new double[array_size]); status = H5LTread_dataset_double( file_id, dataset_name_, array->get()); } |