summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/caffe/solver.hpp2
-rw-r--r--src/caffe/solver.cpp19
2 files changed, 21 insertions, 0 deletions
diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp
index aba3e036..8d52785a 100644
--- a/include/caffe/solver.hpp
+++ b/include/caffe/solver.hpp
@@ -82,6 +82,8 @@ class Solver {
callbacks_.push_back(value);
}
+ void CheckSnapshotWritePermissions();
+
protected:
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 2e59a881..3574ce75 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -55,6 +55,7 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
<< std::endl << param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
+ CheckSnapshotWritePermissions();
if (Caffe::root_solver() && param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
@@ -435,6 +436,24 @@ void Solver<Dtype>::Snapshot() {
}
template <typename Dtype>
+void Solver<Dtype>::CheckSnapshotWritePermissions() {
+ if (Caffe::root_solver() && param_.snapshot()) {
+ CHECK(param_.has_snapshot_prefix())
+ << "In solver params, snapshot is specified but snapshot_prefix is not";
+ string probe_filename = SnapshotFilename(".tempfile");
+ std::ofstream probe_ofs(probe_filename.c_str());
+ if (probe_ofs.good()) {
+ probe_ofs.close();
+ std::remove(probe_filename.c_str());
+ } else {
+ LOG(FATAL) << "Cannot write to snapshot prefix '"
+ << param_.snapshot_prefix() << "'. Make sure "
+ << "that the directory exists and is writeable.";
+ }
+ }
+}
+
+template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
string filename(param_.snapshot_prefix());
const int kBufferSize = 20;