summaryrefslogtreecommitdiff
path: root/src/caffe/solver.cpp
diff options
context:
space:
mode:
authorRonghang Hu <huronghang@hotmail.com>2015-09-14 14:57:18 -0700
committerRonghang Hu <huronghang@hotmail.com>2015-09-14 14:57:18 -0700
commite4baef28d6969cba1b656c7e7d525d877826c251 (patch)
tree716c50140d95536134916030f25e8358c4b8b6f6 /src/caffe/solver.cpp
parentf87658a257d81a4a9b065b65ab1e899bfa16656a (diff)
parentab554cb4918cf7bccfada00339b4d1d5ccf3b4af (diff)
downloadcaffeonacl-e4baef28d6969cba1b656c7e7d525d877826c251.tar.gz
caffeonacl-e4baef28d6969cba1b656c7e7d525d877826c251.tar.bz2
caffeonacl-e4baef28d6969cba1b656c7e7d525d877826c251.zip
Merge pull request #3049 from seanbell/check-snapshot-prefix
Check that the snapshot directory is writeable before starting training
Diffstat (limited to 'src/caffe/solver.cpp')
-rw-r--r--src/caffe/solver.cpp19
1 files changed, 19 insertions, 0 deletions
diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index 2e59a881..3574ce75 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -55,6 +55,7 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
<< std::endl << param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
+ CheckSnapshotWritePermissions();
if (Caffe::root_solver() && param_.random_seed() >= 0) {
Caffe::set_random_seed(param_.random_seed());
}
@@ -435,6 +436,24 @@ void Solver<Dtype>::Snapshot() {
}
template <typename Dtype>
+void Solver<Dtype>::CheckSnapshotWritePermissions() {
+ if (Caffe::root_solver() && param_.snapshot()) {
+ CHECK(param_.has_snapshot_prefix())
+ << "In solver params, snapshot is specified but snapshot_prefix is not";
+ string probe_filename = SnapshotFilename(".tempfile");
+ std::ofstream probe_ofs(probe_filename.c_str());
+ if (probe_ofs.good()) {
+ probe_ofs.close();
+ std::remove(probe_filename.c_str());
+ } else {
+ LOG(FATAL) << "Cannot write to snapshot prefix '"
+ << param_.snapshot_prefix() << "'. Make sure "
+ << "that the directory exists and is writeable.";
+ }
+ }
+}
+
+template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {
string filename(param_.snapshot_prefix());
const int kBufferSize = 20;