author    Eric Tzeng <etzeng@eecs.berkeley.edu>  2015-07-22 16:17:01 -0700
committer Eric Tzeng <etzeng@eecs.berkeley.edu>  2015-08-07 14:56:38 -0700
commit    ada055bbf155882534907a7fb98a35e4f7bba392 (patch)
tree      c91d108a3cab4709c1bf75ece74248a41a5ad3b1
parent    1e740e1e8568063ffe69162a30a97961c5c62af0 (diff)
Snapshot model weights/solver state to HDF5 files.
Summary of changes:
- HDF5 helper functions were moved into a separate file, util/hdf5.cpp
- hdf5_save_nd_dataset now saves n-d blobs and can save diffs instead of data
- Minor fix for a memory leak in the HDF5 functions (delete[] instead of delete)
- Extra methods were added to both Net and Solver, enabling snapshotting to and restoring from HDF5 files
- snapshot_format was added to SolverParameter, with possible values HDF5 or BINARYPROTO (default HDF5)
- kMaxBlobAxes was reduced to 32 to match the limitations of HDF5
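Usage note: with this change a solver picks its snapshot format from SolverParameter. A minimal C++ sketch of driving that programmatically (the "solver.prototxt" path and the direct use of SGDSolver are illustrative assumptions, not part of this commit):

  #include "caffe/proto/caffe.pb.h"
  #include "caffe/solver.hpp"
  #include "caffe/util/io.hpp"

  int main() {
    caffe::SolverParameter solver_param;
    // "solver.prototxt" is a hypothetical path.
    caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param);
    // HDF5 is now the default; BINARYPROTO keeps the old behavior.
    solver_param.set_snapshot_format(caffe::SolverParameter_SnapshotFormat_HDF5);
    caffe::SGDSolver<float> solver(solver_param);
    // Snapshots are written as <prefix>_iter_N.caffemodel.h5 and
    // <prefix>_iter_N.solverstate.h5.
    solver.Solve();
    return 0;
  }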
-rw-r--r--  include/caffe/blob.hpp                     |   2
-rw-r--r--  include/caffe/net.hpp                      |   4
-rw-r--r--  include/caffe/solver.hpp                   |  21
-rw-r--r--  include/caffe/util/hdf5.hpp                |  39
-rw-r--r--  include/caffe/util/io.hpp                  |  18
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cpp       |   2
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cpp     |   2
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cu      |   1
-rw-r--r--  src/caffe/net.cpp                          | 122
-rw-r--r--  src/caffe/proto/caffe.proto                |   7
-rw-r--r--  src/caffe/solver.cpp                       | 164
-rw-r--r--  src/caffe/test/test_hdf5_output_layer.cpp  |   1
-rw-r--r--  src/caffe/util/hdf5.cpp                    | 160
-rw-r--r--  src/caffe/util/io.cpp                      |  74
14 files changed, 482 insertions(+), 135 deletions(-)
diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index 472cc184..9b813e73 100644
--- a/include/caffe/blob.hpp
+++ b/include/caffe/blob.hpp
@@ -10,7 +10,7 @@
#include "caffe/syncedmem.hpp"
#include "caffe/util/math_functions.hpp"
-const int kMaxBlobAxes = INT_MAX;
+const int kMaxBlobAxes = 32;
namespace caffe {
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index 5665df1e..dfd2e556 100644
--- a/include/caffe/net.hpp
+++ b/include/caffe/net.hpp
@@ -98,8 +98,12 @@ class Net {
*/
void CopyTrainedLayersFrom(const NetParameter& param);
void CopyTrainedLayersFrom(const string trained_filename);
+ void CopyTrainedLayersFromBinaryProto(const string trained_filename);
+ void CopyTrainedLayersFromHDF5(const string trained_filename);
/// @brief Writes the net to a proto.
void ToProto(NetParameter* param, bool write_diff = false) const;
+ /// @brief Writes the net to an HDF5 file.
+ void ToHDF5(const string& filename, bool write_diff = false) const;
/// @brief returns the network name.
inline const string& name() const { return name_; }
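Usage note: CopyTrainedLayersFrom now dispatches on the file extension (".h5" selects the HDF5 loader), and ToHDF5 is the HDF5 counterpart of ToProto. A minimal sketch with hypothetical file names:

  #include "caffe/net.hpp"

  int main() {
    // "deploy.prototxt" is a hypothetical net definition.
    caffe::Net<float> net("deploy.prototxt", caffe::TEST);
    net.CopyTrainedLayersFrom("weights.caffemodel.h5");  // ".h5" -> CopyTrainedLayersFromHDF5
    net.ToHDF5("weights_copy.h5");                       // write_diff defaults to false
    return 0;
  }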
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index c2ced487..703434b5 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -27,9 +27,9 @@ class Solver {
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
- // The Restore function implements how one should restore the solver to a
- // previously snapshotted state. You should implement the RestoreSolverState()
- // function that restores the state from a SolverState protocol buffer.
+ // The Restore method simply dispatches to one of the
+ // RestoreSolverStateFrom___ protected methods. You should implement these
+ // methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
virtual ~Solver() {}
inline shared_ptr<Net<Dtype> > net() { return net_; }
@@ -46,11 +46,15 @@ class Solver {
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
+ string SnapshotFilename(const string extension);
+ string SnapshotToBinaryProto();
+ string SnapshotToHDF5();
// The test routine
void TestAll();
void Test(const int test_net_id = 0);
- virtual void SnapshotSolverState(SolverState* state) = 0;
- virtual void RestoreSolverState(const SolverState& state) = 0;
+ virtual void SnapshotSolverState(const string& model_filename) = 0;
+ virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
+ virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
SolverParameter param_;
@@ -85,8 +89,11 @@ class SGDSolver : public Solver<Dtype> {
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
- virtual void SnapshotSolverState(SolverState * state);
- virtual void RestoreSolverState(const SolverState& state);
+ virtual void SnapshotSolverState(const string& model_filename);
+ virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
+ virtual void SnapshotSolverStateToHDF5(const string& model_filename);
+ virtual void RestoreSolverStateFromHDF5(const string& state_file);
+ virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
// history maintains the historical momentum data.
// update maintains update related data and is not needed in snapshots.
// temp maintains other information that might be needed in computation
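Note: the old SnapshotSolverState(SolverState*) / RestoreSolverState(const SolverState&) pair becomes three hooks, so the base Solver can dispatch per format. A sketch of what a hypothetical custom solver would now override (other pure-virtual Solver methods are omitted here):

  #include <string>
  #include "caffe/solver.hpp"

  template <typename Dtype>
  class MySolver : public caffe::Solver<Dtype> {
   public:
    explicit MySolver(const caffe::SolverParameter& param)
        : caffe::Solver<Dtype>(param) {}
   protected:
    virtual void SnapshotSolverState(const std::string& model_filename) {
      // write solver state next to model_filename, in the configured format
    }
    virtual void RestoreSolverStateFromHDF5(const std::string& state_file) {
      // read state from an .h5 snapshot
    }
    virtual void RestoreSolverStateFromBinaryProto(const std::string& state_file) {
      // read state from a binary proto snapshot
    }
  };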
diff --git a/include/caffe/util/hdf5.hpp b/include/caffe/util/hdf5.hpp
new file mode 100644
index 00000000..ce568c5e
--- /dev/null
+++ b/include/caffe/util/hdf5.hpp
@@ -0,0 +1,39 @@
+#ifndef CAFFE_UTIL_HDF5_H_
+#define CAFFE_UTIL_HDF5_H_
+
+#include <string>
+
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
+#include "caffe/blob.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void hdf5_load_nd_dataset_helper(
+ hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+ Blob<Dtype>* blob);
+
+template <typename Dtype>
+void hdf5_load_nd_dataset(
+ hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+ Blob<Dtype>* blob);
+
+template <typename Dtype>
+void hdf5_save_nd_dataset(
+ const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob,
+ bool write_diff = false);
+
+int hdf5_load_int(hid_t loc_id, const string& dataset_name);
+void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i);
+string hdf5_load_string(hid_t loc_id, const string& dataset_name);
+void hdf5_save_string(hid_t loc_id, const string& dataset_name,
+ const string& s);
+
+int hdf5_get_num_links(hid_t loc_id);
+string hdf5_get_name_by_idx(hid_t loc_id, int idx);
+
+} // namespace caffe
+
+#endif // CAFFE_UTIL_HDF5_H_
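Note: the scalar and string helpers are new in this header. A hypothetical round trip through them (file name and values are made up for illustration):

  #include "hdf5.h"
  #include "caffe/blob.hpp"
  #include "caffe/util/hdf5.hpp"

  int main() {
    hid_t file = H5Fcreate("scratch.h5", H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT);
    caffe::hdf5_save_int(file, "iter", 100);
    caffe::hdf5_save_string(file, "note", "snapshot test");
    caffe::Blob<float> blob(2, 3, 4, 5);
    caffe::hdf5_save_nd_dataset<float>(file, "0", blob);             // saves data
    caffe::hdf5_save_nd_dataset<float>(file, "0_diff", blob, true);  // saves diff
    caffe::Blob<float> loaded;
    caffe::hdf5_load_nd_dataset<float>(file, "0", 0, kMaxBlobAxes, &loaded);
    int iter = caffe::hdf5_load_int(file, "iter");  // == 100
    H5Fclose(file);
    return 0;
  }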
diff --git a/include/caffe/util/io.hpp b/include/caffe/util/io.hpp
index 3a62c3c9..c0938ad0 100644
--- a/include/caffe/util/io.hpp
+++ b/include/caffe/util/io.hpp
@@ -5,15 +5,11 @@
#include <string>
#include "google/protobuf/message.h"
-#include "hdf5.h"
-#include "hdf5_hl.h"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
-#define HDF5_NUM_DIMS 4
-
namespace caffe {
using ::google::protobuf::Message;
@@ -140,20 +136,6 @@ cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
-template <typename Dtype>
-void hdf5_load_nd_dataset_helper(
- hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob);
-
-template <typename Dtype>
-void hdf5_load_nd_dataset(
- hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob);
-
-template <typename Dtype>
-void hdf5_save_nd_dataset(
- const hid_t file_id, const string& dataset_name, const Blob<Dtype>& blob);
-
} // namespace caffe
#endif // CAFFE_UTIL_IO_H_
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index 8a782f7e..8ced5103 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -16,7 +16,7 @@ TODO:
#include "caffe/data_layers.hpp"
#include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
+#include "caffe/util/hdf5.hpp"
namespace caffe {
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
index f63375c3..56788c21 100644
--- a/src/caffe/layers/hdf5_output_layer.cpp
+++ b/src/caffe/layers/hdf5_output_layer.cpp
@@ -6,7 +6,7 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
+#include "caffe/util/hdf5.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu
index ae497c34..eb6d0e47 100644
--- a/src/caffe/layers/hdf5_output_layer.cu
+++ b/src/caffe/layers/hdf5_output_layer.cu
@@ -6,7 +6,6 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
-#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index a18ee638..0812b367 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -5,12 +5,14 @@
#include <utility>
#include <vector>
+#include "hdf5.h"
+
#include "caffe/common.hpp"
#include "caffe/layer.hpp"
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/hdf5.hpp"
#include "caffe/util/insert_splits.hpp"
-#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/upgrade_proto.hpp"
@@ -747,12 +749,73 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
template <typename Dtype>
void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
+ if (trained_filename.size() >= 3 &&
+ trained_filename.compare(trained_filename.size() - 3, 3, ".h5") == 0) {
+ CopyTrainedLayersFromHDF5(trained_filename);
+ } else {
+ CopyTrainedLayersFromBinaryProto(trained_filename);
+ }
+}
+
+template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFromBinaryProto(
+ const string trained_filename) {
NetParameter param;
ReadNetParamsFromBinaryFileOrDie(trained_filename, &param);
CopyTrainedLayersFrom(param);
}
template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
+ hid_t file_hid = H5Fopen(trained_filename.c_str(), H5F_ACC_RDONLY,
+ H5P_DEFAULT);
+ CHECK_GE(file_hid, 0) << "Couldn't open " << trained_filename;
+ hid_t data_hid = H5Gopen2(file_hid, "data", H5P_DEFAULT);
+ CHECK_GE(data_hid, 0) << "Error reading weights from " << trained_filename;
+ int num_layers = hdf5_get_num_links(data_hid);
+ for (int i = 0; i < num_layers; ++i) {
+ string source_layer_name = hdf5_get_name_by_idx(data_hid, i);
+ if (!layer_names_index_.count(source_layer_name)) {
+ DLOG(INFO) << "Ignoring source layer " << source_layer_name;
+ continue;
+ }
+ int target_layer_id = layer_names_index_[source_layer_name];
+ DLOG(INFO) << "Copying source layer " << source_layer_name;
+ vector<shared_ptr<Blob<Dtype> > >& target_blobs =
+ layers_[target_layer_id]->blobs();
+ hid_t layer_hid = H5Gopen2(data_hid, source_layer_name.c_str(),
+ H5P_DEFAULT);
+ CHECK_GE(layer_hid, 0)
+ << "Error reading weights from " << trained_filename;
+ // Check that source layer doesn't have more params than target layer
+ int num_source_params = hdf5_get_num_links(layer_hid);
+ CHECK_LE(num_source_params, target_blobs.size())
+ << "Incompatible number of blobs for layer " << source_layer_name;
+ for (int j = 0; j < target_blobs.size(); ++j) {
+ ostringstream oss;
+ oss << j;
+ string dataset_name = oss.str();
+ int target_net_param_id = param_id_vecs_[target_layer_id][j];
+ if (!H5Lexists(layer_hid, dataset_name.c_str(), H5P_DEFAULT)) {
+ // Target param doesn't exist in source weights...
+ if (param_owners_[target_net_param_id] != -1) {
+ // ...but it's weight-shared in target, so that's fine.
+ continue;
+ } else {
+ LOG(FATAL) << "Incompatible number of blobs for layer "
+ << source_layer_name;
+ }
+ }
+ hdf5_load_nd_dataset(layer_hid, dataset_name.c_str(), 0, kMaxBlobAxes,
+ target_blobs[j].get());
+ }
+ H5Gclose(layer_hid);
+ }
+ H5Gclose(data_hid);
+ H5Fclose(file_hid);
+}
+
+template <typename Dtype>
void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
param->Clear();
param->set_name(name_);
@@ -774,6 +837,63 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
}
template <typename Dtype>
+void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
+ hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
+ H5P_DEFAULT);
+ CHECK_GE(file_hid, 0)
+ << "Couldn't open " << filename << " to save weights.";
+ hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
+ H5P_DEFAULT);
+ CHECK_GE(data_hid, 0) << "Error saving weights to " << filename << ".";
+ hid_t diff_hid = -1;
+ if (write_diff) {
+ diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
+ H5P_DEFAULT);
+ CHECK_GE(diff_hid, 0) << "Error saving weights to " << filename << ".";
+ }
+ for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
+ const LayerParameter& layer_param = layers_[layer_id]->layer_param();
+ string layer_name = layer_param.name();
+ hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
+ H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
+ CHECK_GE(layer_data_hid, 0)
+ << "Error saving weights to " << filename << ".";
+ hid_t layer_diff_hid = -1;
+ if (write_diff) {
+ layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
+ H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
+ CHECK_GE(layer_diff_hid, 0)
+ << "Error saving weights to " << filename << ".";
+ }
+ int num_params = layers_[layer_id]->blobs().size();
+ for (int param_id = 0; param_id < num_params; ++param_id) {
+ ostringstream dataset_name;
+ dataset_name << param_id;
+ const int net_param_id = param_id_vecs_[layer_id][param_id];
+ if (param_owners_[net_param_id] == -1) {
+ // Only save params that own themselves
+ hdf5_save_nd_dataset<Dtype>(layer_data_hid, dataset_name.str(),
+ *params_[net_param_id]);
+ }
+ if (write_diff) {
+ // Write diffs regardless of weight-sharing
+ hdf5_save_nd_dataset<Dtype>(layer_diff_hid, dataset_name.str(),
+ *params_[net_param_id], true);
+ }
+ }
+ H5Gclose(layer_data_hid);
+ if (write_diff) {
+ H5Gclose(layer_diff_hid);
+ }
+ }
+ H5Gclose(data_hid);
+ if (write_diff) {
+ H5Gclose(diff_hid);
+ }
+ H5Fclose(file_hid);
+}
+
+template <typename Dtype>
void Net<Dtype>::Update() {
// First, accumulate the diffs of any shared parameters into their owner's
// diff. (Assumes that the learning rate, weight decay, etc. have already been
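Note on the resulting layout: ToHDF5 writes one group per layer under /data (and, with write_diff, under /diff), with datasets named "0", "1", ... for each parameter blob; only self-owned params appear under /data. A hypothetical inspection sketch using the new link helpers:

  #include <string>
  #include "hdf5.h"
  #include "caffe/util/hdf5.hpp"

  int main() {
    // "weights.caffemodel.h5" is a hypothetical snapshot.
    hid_t file = H5Fopen("weights.caffemodel.h5", H5F_ACC_RDONLY, H5P_DEFAULT);
    hid_t data = H5Gopen2(file, "data", H5P_DEFAULT);
    for (int i = 0; i < caffe::hdf5_get_num_links(data); ++i) {
      std::string layer = caffe::hdf5_get_name_by_idx(data, i);
      // each /data/<layer> group holds datasets "0", "1", ... for its params
    }
    H5Gclose(data);
    H5Fclose(file);
    return 0;
  }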
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 03daa808..96e975be 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 37 (last added: iter_size)
+// SolverParameter next available ID: 38 (last added: snapshot_format)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -175,6 +175,11 @@ message SolverParameter {
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [default = false];
+ enum SnapshotFormat {
+ HDF5 = 0;
+ BINARYPROTO = 1;
+ }
+ optional SnapshotFormat snapshot_format = 37 [default = HDF5];
// the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
enum SolverMode {
CPU = 0;
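In a solver definition this surfaces as a single new field; for example, a solver that keeps the old binary proto snapshots would add this line to its prototxt (hypothetical snippet):

  snapshot_format: BINARYPROTO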
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index aabe0ede..75271138 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -4,9 +4,13 @@
#include <string>
#include <vector>
+#include "hdf5.h"
+#include "hdf5_hl.h"
+
#include "caffe/net.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
+#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/upgrade_proto.hpp"
@@ -348,42 +352,58 @@ void Solver<Dtype>::Test(const int test_net_id) {
template <typename Dtype>
void Solver<Dtype>::Snapshot() {
- NetParameter net_param;
- // For intermediate results, we will also dump the gradient values.
- net_->ToProto(&net_param, param_.snapshot_diff());
+ string model_filename;
+ switch (param_.snapshot_format()) {
+ case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+ model_filename = SnapshotToBinaryProto();
+ break;
+ case caffe::SolverParameter_SnapshotFormat_HDF5:
+ model_filename = SnapshotToHDF5();
+ break;
+ default:
+ LOG(FATAL) << "Unsupported snapshot format.";
+ }
+
+ SnapshotSolverState(model_filename);
+}
+
+template <typename Dtype>
+string Solver<Dtype>::SnapshotFilename(const string extension) {
string filename(param_.snapshot_prefix());
- string model_filename, snapshot_filename;
const int kBufferSize = 20;
char iter_str_buffer[kBufferSize];
snprintf(iter_str_buffer, kBufferSize, "_iter_%d", iter_);
- filename += iter_str_buffer;
- model_filename = filename + ".caffemodel";
- LOG(INFO) << "Snapshotting to " << model_filename;
- WriteProtoToBinaryFile(net_param, model_filename.c_str());
- SolverState state;
- SnapshotSolverState(&state);
- state.set_iter(iter_);
- state.set_learned_net(model_filename);
- state.set_current_step(current_step_);
- snapshot_filename = filename + ".solverstate";
- LOG(INFO) << "Snapshotting solver state to " << snapshot_filename;
- WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+ return filename + iter_str_buffer + extension;
}
template <typename Dtype>
-void Solver<Dtype>::Restore(const char* state_file) {
- SolverState state;
+string Solver<Dtype>::SnapshotToBinaryProto() {
+ string model_filename = SnapshotFilename(".caffemodel");
+ LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
NetParameter net_param;
- ReadProtoFromBinaryFile(state_file, &state);
- if (state.has_learned_net()) {
- ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
- net_->CopyTrainedLayersFrom(net_param);
- }
- iter_ = state.iter();
- current_step_ = state.current_step();
- RestoreSolverState(state);
+ net_->ToProto(&net_param, param_.snapshot_diff());
+ WriteProtoToBinaryFile(net_param, model_filename);
+ return model_filename;
+}
+
+template <typename Dtype>
+string Solver<Dtype>::SnapshotToHDF5() {
+ string model_filename = SnapshotFilename(".caffemodel.h5");
+ LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
+ net_->ToHDF5(model_filename, param_.snapshot_diff());
+ return model_filename;
}
+template <typename Dtype>
+void Solver<Dtype>::Restore(const char* state_file) {
+ string state_filename(state_file);
+ if (state_filename.size() >= 3 &&
+ state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
+ RestoreSolverStateFromHDF5(state_filename);
+ } else {
+ RestoreSolverStateFromBinaryProto(state_filename);
+ }
+}
// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
@@ -618,17 +638,76 @@ void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
}
template <typename Dtype>
-void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) {
- state->clear_history();
+void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
+ switch (this->param_.snapshot_format()) {
+ case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
+ SnapshotSolverStateToBinaryProto(model_filename);
+ break;
+ case caffe::SolverParameter_SnapshotFormat_HDF5:
+ SnapshotSolverStateToHDF5(model_filename);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported snapshot format.";
+ }
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
+ const string& model_filename) {
+ SolverState state;
+ state.set_iter(this->iter_);
+ state.set_learned_net(model_filename);
+ state.set_current_step(this->current_step_);
+ state.clear_history();
for (int i = 0; i < history_.size(); ++i) {
// Add history
- BlobProto* history_blob = state->add_history();
+ BlobProto* history_blob = state.add_history();
history_[i]->ToProto(history_blob);
}
+ string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
+ LOG(INFO)
+ << "Snapshotting solver state to binary proto file" << snapshot_filename;
+ WriteProtoToBinaryFile(state, snapshot_filename.c_str());
+}
+
+template <typename Dtype>
+void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
+ const string& model_filename) {
+ string snapshot_filename =
+ Solver<Dtype>::SnapshotFilename(".solverstate.h5");
+ LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
+ hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
+ H5P_DEFAULT, H5P_DEFAULT);
+ CHECK_GE(file_hid, 0)
+ << "Couldn't open " << snapshot_filename << " to save solver state.";
+ hdf5_save_int(file_hid, "iter", this->iter_);
+ hdf5_save_string(file_hid, "learned_net", model_filename);
+ hdf5_save_int(file_hid, "current_step", this->current_step_);
+ hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
+ H5P_DEFAULT);
+ CHECK_GE(history_hid, 0)
+ << "Error saving solver state to " << snapshot_filename << ".";
+ for (int i = 0; i < history_.size(); ++i) {
+ ostringstream oss;
+ oss << i;
+ hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
+ }
+ H5Gclose(history_hid);
+ H5Fclose(file_hid);
}
template <typename Dtype>
-void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
+void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
+ const string& state_file) {
+ SolverState state;
+ ReadProtoFromBinaryFile(state_file, &state);
+ this->iter_ = state.iter();
+ if (state.has_learned_net()) {
+ NetParameter net_param;
+ ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
+ this->net_->CopyTrainedLayersFrom(net_param);
+ }
+ this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
@@ -638,6 +717,31 @@ void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) {
}
template <typename Dtype>
+void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+ hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
+ CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
+ this->iter_ = hdf5_load_int(file_hid, "iter");
+ if (H5LTfind_dataset(file_hid, "learned_net")) {
+ string learned_net = hdf5_load_string(file_hid, "learned_net");
+ this->net_->CopyTrainedLayersFrom(learned_net);
+ }
+ this->current_step_ = hdf5_load_int(file_hid, "current_step");
+ hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
+ CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
+ int state_history_size = hdf5_get_num_links(history_hid);
+ CHECK_EQ(state_history_size, history_.size())
+ << "Incorrect length of history blobs.";
+ for (int i = 0; i < history_.size(); ++i) {
+ ostringstream oss;
+ oss << i;
+ hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
+ kMaxBlobAxes, history_[i].get());
+ }
+ H5Gclose(history_hid);
+ H5Fclose(file_hid);
+}
+
+template <typename Dtype>
void NesterovSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
const vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
const vector<float>& net_params_lr = this->net_->params_lr();
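Usage note: restoring also dispatches on extension, so resuming training from an HDF5 snapshot only requires passing the .h5 solverstate. A sketch with hypothetical paths:

  #include "caffe/proto/caffe.pb.h"
  #include "caffe/solver.hpp"
  #include "caffe/util/io.hpp"

  int main() {
    caffe::SolverParameter solver_param;
    caffe::ReadProtoFromTextFileOrDie("solver.prototxt", &solver_param);  // hypothetical
    caffe::SGDSolver<float> solver(solver_param);
    // ".h5" -> RestoreSolverStateFromHDF5; anything else -> ...FromBinaryProto
    solver.Solve("snaps_iter_1000.solverstate.h5");  // hypothetical snapshot name
    return 0;
  }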
diff --git a/src/caffe/test/test_hdf5_output_layer.cpp b/src/caffe/test/test_hdf5_output_layer.cpp
index a23034f2..b56277b5 100644
--- a/src/caffe/test/test_hdf5_output_layer.cpp
+++ b/src/caffe/test/test_hdf5_output_layer.cpp
@@ -6,6 +6,7 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/vision_layers.hpp"
diff --git a/src/caffe/util/hdf5.cpp b/src/caffe/util/hdf5.cpp
new file mode 100644
index 00000000..d0d05f70
--- /dev/null
+++ b/src/caffe/util/hdf5.cpp
@@ -0,0 +1,160 @@
+#include "caffe/util/hdf5.hpp"
+
+#include <string>
+#include <vector>
+
+namespace caffe {
+
+// Verifies format of data stored in HDF5 file and reshapes blob accordingly.
+template <typename Dtype>
+void hdf5_load_nd_dataset_helper(
+ hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
+ Blob<Dtype>* blob) {
+ // Verify that the dataset exists.
+ CHECK(H5LTfind_dataset(file_id, dataset_name_))
+ << "Failed to find HDF5 dataset " << dataset_name_;
+ // Verify that the number of dimensions is in the accepted range.
+ herr_t status;
+ int ndims;
+ status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
+ CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
+ CHECK_GE(ndims, min_dim);
+ CHECK_LE(ndims, max_dim);
+
+ // Verify that the data format is what we expect: float or double.
+ std::vector<hsize_t> dims(ndims);
+ H5T_class_t class_;
+ status = H5LTget_dataset_info(
+ file_id, dataset_name_, dims.data(), &class_, NULL);
+ CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
+ CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
+
+ vector<int> blob_dims(dims.size());
+ for (int i = 0; i < dims.size(); ++i) {
+ blob_dims[i] = dims[i];
+ }
+ blob->Reshape(blob_dims);
+}
+
+template <>
+void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
+ int min_dim, int max_dim, Blob<float>* blob) {
+ hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+ herr_t status = H5LTread_dataset_float(
+ file_id, dataset_name_, blob->mutable_cpu_data());
+ CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
+}
+
+template <>
+void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
+ int min_dim, int max_dim, Blob<double>* blob) {
+ hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
+ herr_t status = H5LTread_dataset_double(
+ file_id, dataset_name_, blob->mutable_cpu_data());
+ CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
+}
+
+template <>
+void hdf5_save_nd_dataset<float>(
+ const hid_t file_id, const string& dataset_name, const Blob<float>& blob,
+ bool write_diff) {
+ int num_axes = blob.num_axes();
+ hsize_t *dims = new hsize_t[num_axes];
+ for (int i = 0; i < num_axes; ++i) {
+ dims[i] = blob.shape(i);
+ }
+ const float* data;
+ if (write_diff) {
+ data = blob.cpu_diff();
+ } else {
+ data = blob.cpu_data();
+ }
+ herr_t status = H5LTmake_dataset_float(
+ file_id, dataset_name.c_str(), num_axes, dims, data);
+ CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
+ delete[] dims;
+}
+
+template <>
+void hdf5_save_nd_dataset<double>(
+ hid_t file_id, const string& dataset_name, const Blob<double>& blob,
+ bool write_diff) {
+ int num_axes = blob.num_axes();
+ hsize_t *dims = new hsize_t[num_axes];
+ for (int i = 0; i < num_axes; ++i) {
+ dims[i] = blob.shape(i);
+ }
+ const double* data;
+ if (write_diff) {
+ data = blob.cpu_diff();
+ } else {
+ data = blob.cpu_data();
+ }
+ herr_t status = H5LTmake_dataset_double(
+ file_id, dataset_name.c_str(), num_axes, dims, data);
+ CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
+ delete[] dims;
+}
+
+string hdf5_load_string(hid_t loc_id, const string& dataset_name) {
+ // Get size of dataset
+ size_t size;
+ H5T_class_t class_;
+ herr_t status = \
+ H5LTget_dataset_info(loc_id, dataset_name.c_str(), NULL, &class_, &size);
+ CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name;
+ char *buf = new char[size];
+ status = H5LTread_dataset_string(loc_id, dataset_name.c_str(), buf);
+ CHECK_GE(status, 0)
+ << "Failed to load int dataset with name " << dataset_name;
+ string val(buf);
+ delete[] buf;
+ return val;
+}
+
+void hdf5_save_string(hid_t loc_id, const string& dataset_name,
+ const string& s) {
+ herr_t status = \
+ H5LTmake_dataset_string(loc_id, dataset_name.c_str(), s.c_str());
+ CHECK_GE(status, 0)
+ << "Failed to save string dataset with name " << dataset_name;
+}
+
+int hdf5_load_int(hid_t loc_id, const string& dataset_name) {
+ int val;
+ herr_t status = H5LTread_dataset_int(loc_id, dataset_name.c_str(), &val);
+ CHECK_GE(status, 0)
+ << "Failed to load int dataset with name " << dataset_name;
+ return val;
+}
+
+void hdf5_save_int(hid_t loc_id, const string& dataset_name, int i) {
+ hsize_t one = 1;
+ herr_t status = \
+ H5LTmake_dataset_int(loc_id, dataset_name.c_str(), 1, &one, &i);
+ CHECK_GE(status, 0)
+ << "Failed to save int dataset with name " << dataset_name;
+}
+
+int hdf5_get_num_links(hid_t loc_id) {
+ H5G_info_t info;
+ herr_t status = H5Gget_info(loc_id, &info);
+ CHECK_GE(status, 0) << "Error while counting HDF5 links.";
+ return info.nlinks;
+}
+
+string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
+ ssize_t str_size = H5Lget_name_by_idx(
+ loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, NULL, 0, H5P_DEFAULT);
+ CHECK_GE(str_size, 0) << "Error retrieving HDF5 dataset at index " << idx;
+ char *c_str = new char[str_size+1];
+ ssize_t status = H5Lget_name_by_idx(
+ loc_id, ".", H5_INDEX_NAME, H5_ITER_NATIVE, idx, c_str, str_size+1,
+ H5P_DEFAULT);
+ CHECK_GE(status, 0) << "Error retrieving HDF5 dataset at index " << idx;
+ string result(c_str);
+ delete[] c_str;
+ return result;
+}
+
+} // namespace caffe
diff --git a/src/caffe/util/io.cpp b/src/caffe/util/io.cpp
index 77ef7f25..6f033142 100644
--- a/src/caffe/util/io.cpp
+++ b/src/caffe/util/io.cpp
@@ -228,79 +228,5 @@ void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
datum->set_data(buffer);
}
-// Verifies format of data stored in HDF5 file and reshapes blob accordingly.
-template <typename Dtype>
-void hdf5_load_nd_dataset_helper(
- hid_t file_id, const char* dataset_name_, int min_dim, int max_dim,
- Blob<Dtype>* blob) {
- // Verify that the dataset exists.
- CHECK(H5LTfind_dataset(file_id, dataset_name_))
- << "Failed to find HDF5 dataset " << dataset_name_;
- // Verify that the number of dimensions is in the accepted range.
- herr_t status;
- int ndims;
- status = H5LTget_dataset_ndims(file_id, dataset_name_, &ndims);
- CHECK_GE(status, 0) << "Failed to get dataset ndims for " << dataset_name_;
- CHECK_GE(ndims, min_dim);
- CHECK_LE(ndims, max_dim);
-
- // Verify that the data format is what we expect: float or double.
- std::vector<hsize_t> dims(ndims);
- H5T_class_t class_;
- status = H5LTget_dataset_info(
- file_id, dataset_name_, dims.data(), &class_, NULL);
- CHECK_GE(status, 0) << "Failed to get dataset info for " << dataset_name_;
- CHECK_EQ(class_, H5T_FLOAT) << "Expected float or double data";
-
- vector<int> blob_dims(dims.size());
- for (int i = 0; i < dims.size(); ++i) {
- blob_dims[i] = dims[i];
- }
- blob->Reshape(blob_dims);
-}
-
-template <>
-void hdf5_load_nd_dataset<float>(hid_t file_id, const char* dataset_name_,
- int min_dim, int max_dim, Blob<float>* blob) {
- hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
- herr_t status = H5LTread_dataset_float(
- file_id, dataset_name_, blob->mutable_cpu_data());
- CHECK_GE(status, 0) << "Failed to read float dataset " << dataset_name_;
-}
-
-template <>
-void hdf5_load_nd_dataset<double>(hid_t file_id, const char* dataset_name_,
- int min_dim, int max_dim, Blob<double>* blob) {
- hdf5_load_nd_dataset_helper(file_id, dataset_name_, min_dim, max_dim, blob);
- herr_t status = H5LTread_dataset_double(
- file_id, dataset_name_, blob->mutable_cpu_data());
- CHECK_GE(status, 0) << "Failed to read double dataset " << dataset_name_;
-}
-
-template <>
-void hdf5_save_nd_dataset<float>(
- const hid_t file_id, const string& dataset_name, const Blob<float>& blob) {
- hsize_t dims[HDF5_NUM_DIMS];
- dims[0] = blob.num();
- dims[1] = blob.channels();
- dims[2] = blob.height();
- dims[3] = blob.width();
- herr_t status = H5LTmake_dataset_float(
- file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
- CHECK_GE(status, 0) << "Failed to make float dataset " << dataset_name;
-}
-
-template <>
-void hdf5_save_nd_dataset<double>(
- const hid_t file_id, const string& dataset_name, const Blob<double>& blob) {
- hsize_t dims[HDF5_NUM_DIMS];
- dims[0] = blob.num();
- dims[1] = blob.channels();
- dims[2] = blob.height();
- dims[3] = blob.width();
- herr_t status = H5LTmake_dataset_double(
- file_id, dataset_name.c_str(), HDF5_NUM_DIMS, dims, blob.cpu_data());
- CHECK_GE(status, 0) << "Failed to make double dataset " << dataset_name;
-}
} // namespace caffe