diff options
-rw-r--r-- | include/caffe/solver.hpp | 2 | ||||
-rw-r--r-- | src/caffe/solver.cpp | 19 |
2 files changed, 21 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index aba3e036..8d52785a 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -82,6 +82,8 @@ class Solver { callbacks_.push_back(value); } + void CheckSnapshotWritePermissions(); + protected: // Make and apply the update value for the current iteration. virtual void ApplyUpdate() = 0; diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 2e59a881..3574ce75 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -55,6 +55,7 @@ void Solver<Dtype>::Init(const SolverParameter& param) { << std::endl << param.DebugString(); param_ = param; CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; + CheckSnapshotWritePermissions(); if (Caffe::root_solver() && param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } @@ -435,6 +436,24 @@ void Solver<Dtype>::Snapshot() { } template <typename Dtype> +void Solver<Dtype>::CheckSnapshotWritePermissions() { + if (Caffe::root_solver() && param_.snapshot()) { + CHECK(param_.has_snapshot_prefix()) + << "In solver params, snapshot is specified but snapshot_prefix is not"; + string probe_filename = SnapshotFilename(".tempfile"); + std::ofstream probe_ofs(probe_filename.c_str()); + if (probe_ofs.good()) { + probe_ofs.close(); + std::remove(probe_filename.c_str()); + } else { + LOG(FATAL) << "Cannot write to snapshot prefix '" + << param_.snapshot_prefix() << "'. Make sure " + << "that the directory exists and is writeable."; + } + } +} + +template <typename Dtype> string Solver<Dtype>::SnapshotFilename(const string extension) { string filename(param_.snapshot_prefix()); const int kBufferSize = 20; |