summaryrefslogtreecommitdiff
path: root/src/caffe/util/io.cpp
diff options
context:
space:
mode:
authorTobias Domhan <tdomhan@gmail.com>2014-03-16 12:50:30 +0100
committerTobias Domhan <tdomhan@gmail.com>2014-03-16 12:50:30 +0100
commit585000bbcb3bd189a704d325928554d3811ca84e (patch)
treef10bf683d4ab75634eddb7d440dab9bc3558c7be /src/caffe/util/io.cpp
parent4370b3e83e5e923cc68af61ffc820e449917ea64 (diff)
downloadcaffeonacl-585000bbcb3bd189a704d325928554d3811ca84e.tar.gz
caffeonacl-585000bbcb3bd189a704d325928554d3811ca84e.tar.bz2
caffeonacl-585000bbcb3bd189a704d325928554d3811ca84e.zip
support for more than 2 dimensions in hdf5 files
Diffstat (limited to 'src/caffe/util/io.cpp')
-rw-r--r--src/caffe/util/io.cpp45
1 files changed, 33 insertions, 12 deletions
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index d6151d5c..c301c55b 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -12,6 +12,7 @@
#include <algorithm>
#include <string>
+#include <vector>
#include <fstream> // NOLINT(readability/streams)
#include "caffe/common.hpp"
@@ -100,39 +101,59 @@ bool ReadImageToDatum(const string& filename, const int label,
}
template <>
-void load_2d_dataset<float>(hid_t file_id, const char* dataset_name_,
- boost::scoped_ptr<float>* array, hsize_t* dims) {
+void hd5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
+ int min_dim, int max_dim,
+ boost::scoped_ptr<float>* array, std::vector<hsize_t>& out_dims) {
herr_t status;
int ndims;
status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
- assert(ndims == 2);
+ CHECK_GE(ndims, min_dim);
+ CHECK_LE(ndims, max_dim);
+
+ boost::scoped_ptr<hsize_t> dims(new hsize_t[ndims]);
H5T_class_t class_;
status = H5LTget_dataset_info(
- file_id, dataset_name_, dims, &class_, NULL);
- assert(class_ == H5T_NATIVE_FLOAT);
+ file_id, dataset_name_, dims.get(), &class_, NULL);
+ CHECK_EQ(class_, H5T_FLOAT) << "Epected float data";
+
+ int array_size = 1;
+ for (int i=0; i<ndims; ++i) {
+ out_dims.push_back(dims.get()[i]);
+ array_size *= dims.get()[i];
+ }
- array->reset(new float[dims[0] * dims[1]]);
+ array->reset(new float[array_size]);
status = H5LTread_dataset_float(
file_id, dataset_name_, array->get());
}
template <>
-void load_2d_dataset<double>(hid_t file_id, const char* dataset_name_,
- boost::scoped_ptr<double>* array, hsize_t* dims) {
+void hd5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
+ int min_dim, int max_dim,
+ boost::scoped_ptr<double>* array, std::vector<hsize_t>& out_dims) {
herr_t status;
int ndims;
status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
- assert(ndims == 2);
+ CHECK_GE(ndims, min_dim);
+ CHECK_LE(ndims, max_dim);
+
+ boost::scoped_ptr<hsize_t> dims(new hsize_t[ndims]);
H5T_class_t class_;
status = H5LTget_dataset_info(
- file_id, dataset_name_, dims, &class_, NULL);
- assert(class_ == H5T_NATIVE_DOUBLE);
+ file_id, dataset_name_, dims.get(), &class_, NULL);
+ CHECK_EQ(class_, H5T_FLOAT) << "Epected float data";
+
+ int array_size = 1;
+ for (int i=0; i<ndims; ++i) {
+ out_dims.push_back(dims.get()[i]);
+ array_size *= dims.get()[i];
+ }
- array->reset(new double[dims[0] * dims[1]]);
+ array->reset(new double[array_size]);
status = H5LTread_dataset_double(
file_id, dataset_name_, array->get());
}