diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/optimization/solver.cpp | 18 | ||||
-rw-r--r-- | src/caffe/optimization/solver.hpp | 2 | ||||
-rw-r--r-- | src/caffe/proto/caffe.proto | 8 | ||||
-rw-r--r-- | src/caffe/pyutil/drawnet.py | 24 |
4 files changed, 27 insertions, 25 deletions
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp index 73c69c03..0c68330e 100644 --- a/src/caffe/optimization/solver.cpp +++ b/src/caffe/optimization/solver.cpp @@ -38,8 +38,8 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) { net_->Update(); // Check if we need to do snapshot - if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) { - Snapshot(false); + if (param_.snapshot() && iter_ % param_.snapshot() == 0) { + Snapshot(); } if (param_.display() && iter_ % param_.display() == 0) { LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss; @@ -50,18 +50,14 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) { template <typename Dtype> -void Solver<Dtype>::Snapshot(bool is_final) { +void Solver<Dtype>::Snapshot() { NetParameter net_param; // For intermediate results, we will also dump the gradient values. - net_->ToProto(&net_param, !is_final); + net_->ToProto(&net_param, param_.snapshot_diff()); string filename(param_.snapshot_prefix()); - if (is_final) { - filename += "_final"; - } else { - char iter_str_buffer[20]; - sprintf(iter_str_buffer, "_iter_%d", iter_); - filename += iter_str_buffer; - } + char iter_str_buffer[20]; + sprintf(iter_str_buffer, "_iter_%d", iter_); + filename += iter_str_buffer; LOG(INFO) << "Snapshotting to " << filename; WriteProtoToBinaryFile(net_param, filename.c_str()); SolverState state; diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp index a5ea6126..98c872dc 100644 --- a/src/caffe/optimization/solver.hpp +++ b/src/caffe/optimization/solver.hpp @@ -27,7 +27,7 @@ class Solver { // that stores the learned net. You should implement the SnapshotSolverState() // function that produces a SolverState protocol buffer that needs to be // written to disk together with the learned net. - void Snapshot(bool is_final = false); + void Snapshot(); virtual void SnapshotSolverState(SolverState* state) = 0; // The Restore function implements how one should restore the solver to a // previously snapshotted state. You should implement the RestoreSolverState() diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 4be96963..8eb39b36 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -105,7 +105,9 @@ message SolverParameter { optional float stepsize = 12; // the stepsize for learning rate policy "step" optional string snapshot_prefix = 13; // The prefix for the snapshot. - + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 14 [ default = false]; // Adagrad solver parameters // For Adagrad, we will first run normal sgd using the sgd parameters above // for adagrad_skip iterations, and then kick in the adagrad algorithm, with @@ -114,8 +116,8 @@ message SolverParameter { // the layer parameter specifications, as it will adjust the learning rate // of individual parameters in a data-dependent way. // WORK IN PROGRESS: not actually implemented yet. - optional float adagrad_gamma = 14; // adagrad learning rate multiplier - optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in + optional float adagrad_gamma = 15; // adagrad learning rate multiplier + optional float adagrad_skip = 16; // the steps to skip before adagrad kicks in } // A message that stores the solver snapshots diff --git a/src/caffe/pyutil/drawnet.py b/src/caffe/pyutil/drawnet.py index f958c908..bce3dc4f 100644 --- a/src/caffe/pyutil/drawnet.py +++ b/src/caffe/pyutil/drawnet.py @@ -11,14 +11,7 @@ NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90', BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C', 'style': 'filled'} -def draw_net(caffe_net, ext='png'): - """Draws a caffe net and returns the image string encoded using the given - extension. - - Input: - caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer. - ext: the image extension. Default 'png'. - """ +def get_pydot_graph(caffe_net): pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph') pydot_nodes = {} pydot_edges = [] @@ -47,11 +40,22 @@ def draw_net(caffe_net, ext='png'): for edge in pydot_edges: pydot_graph.add_edge( pydot.Edge(pydot_nodes[edge[0]], pydot_nodes[edge[1]])) - return pydot_graph.create(format=ext) + return pydot_graph + +def draw_net(caffe_net, ext='png'): + """Draws a caffe net and returns the image string encoded using the given + extension. + + Input: + caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer. + ext: the image extension. Default 'png'. + """ + return get_pydot_graph(caffe_net).create(format=ext) def draw_net_to_file(caffe_net, filename): """Draws a caffe net, and saves it to file using the format given as the - file extension. + file extension. Use '.raw' to output raw text that you can manually feed + to graphviz to draw graphs. """ ext = filename[filename.rfind('.')+1:] with open(filename, 'w') as fid: |