summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/caffe/optimization/solver.cpp18
-rw-r--r--src/caffe/optimization/solver.hpp2
-rw-r--r--src/caffe/proto/caffe.proto8
-rw-r--r--src/caffe/pyutil/drawnet.py24
4 files changed, 27 insertions, 25 deletions
diff --git a/src/caffe/optimization/solver.cpp b/src/caffe/optimization/solver.cpp
index 73c69c03..0c68330e 100644
--- a/src/caffe/optimization/solver.cpp
+++ b/src/caffe/optimization/solver.cpp
@@ -38,8 +38,8 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
net_->Update();
// Check if we need to do snapshot
- if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
- Snapshot(false);
+ if (param_.snapshot() && iter_ % param_.snapshot() == 0) {
+ Snapshot();
}
if (param_.display() && iter_ % param_.display() == 0) {
LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
@@ -50,18 +50,14 @@ void Solver<Dtype>::Solve(Net<Dtype>* net, char* resume_file) {
template <typename Dtype>
-void Solver<Dtype>::Snapshot(bool is_final) {
+void Solver<Dtype>::Snapshot() {
NetParameter net_param;
// For intermediate results, we will also dump the gradient values.
- net_->ToProto(&net_param, !is_final);
+ net_->ToProto(&net_param, param_.snapshot_diff());
string filename(param_.snapshot_prefix());
- if (is_final) {
- filename += "_final";
- } else {
- char iter_str_buffer[20];
- sprintf(iter_str_buffer, "_iter_%d", iter_);
- filename += iter_str_buffer;
- }
+ char iter_str_buffer[20];
+ sprintf(iter_str_buffer, "_iter_%d", iter_);
+ filename += iter_str_buffer;
LOG(INFO) << "Snapshotting to " << filename;
WriteProtoToBinaryFile(net_param, filename.c_str());
SolverState state;
diff --git a/src/caffe/optimization/solver.hpp b/src/caffe/optimization/solver.hpp
index a5ea6126..98c872dc 100644
--- a/src/caffe/optimization/solver.hpp
+++ b/src/caffe/optimization/solver.hpp
@@ -27,7 +27,7 @@ class Solver {
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
- void Snapshot(bool is_final = false);
+ void Snapshot();
virtual void SnapshotSolverState(SolverState* state) = 0;
// The Restore function implements how one should restore the solver to a
// previously snapshotted state. You should implement the RestoreSolverState()
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 4be96963..8eb39b36 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -105,7 +105,9 @@ message SolverParameter {
optional float stepsize = 12; // the stepsize for learning rate policy "step"
optional string snapshot_prefix = 13; // The prefix for the snapshot.
-
+ // whether to snapshot diff in the results or not. Snapshotting diff will help
+ // debugging but the final protocol buffer size will be much larger.
+ optional bool snapshot_diff = 14 [ default = false];
// Adagrad solver parameters
// For Adagrad, we will first run normal sgd using the sgd parameters above
// for adagrad_skip iterations, and then kick in the adagrad algorithm, with
@@ -114,8 +116,8 @@ message SolverParameter {
// the layer parameter specifications, as it will adjust the learning rate
// of individual parameters in a data-dependent way.
// WORK IN PROGRESS: not actually implemented yet.
- optional float adagrad_gamma = 14; // adagrad learning rate multiplier
- optional float adagrad_skip = 15; // the steps to skip before adagrad kicks in
+ optional float adagrad_gamma = 15; // adagrad learning rate multiplier
+ optional float adagrad_skip = 16; // the steps to skip before adagrad kicks in
}
// A message that stores the solver snapshots
diff --git a/src/caffe/pyutil/drawnet.py b/src/caffe/pyutil/drawnet.py
index f958c908..bce3dc4f 100644
--- a/src/caffe/pyutil/drawnet.py
+++ b/src/caffe/pyutil/drawnet.py
@@ -11,14 +11,7 @@ NEURON_LAYER_STYLE = {'shape': 'record', 'fillcolor': '#90EE90',
BLOB_STYLE = {'shape': 'octagon', 'fillcolor': '#F0E68C',
'style': 'filled'}
-def draw_net(caffe_net, ext='png'):
- """Draws a caffe net and returns the image string encoded using the given
- extension.
-
- Input:
- caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
- ext: the image extension. Default 'png'.
- """
+def get_pydot_graph(caffe_net):
pydot_graph = pydot.Dot(caffe_net.name, graph_type='digraph')
pydot_nodes = {}
pydot_edges = []
@@ -47,11 +40,22 @@ def draw_net(caffe_net, ext='png'):
for edge in pydot_edges:
pydot_graph.add_edge(
pydot.Edge(pydot_nodes[edge[0]], pydot_nodes[edge[1]]))
- return pydot_graph.create(format=ext)
+ return pydot_graph
+
+def draw_net(caffe_net, ext='png'):
+ """Draws a caffe net and returns the image string encoded using the given
+ extension.
+
+ Input:
+ caffe_net: a caffe.proto.caffe_pb2.NetParameter protocol buffer.
+ ext: the image extension. Default 'png'.
+ """
+ return get_pydot_graph(caffe_net).create(format=ext)
def draw_net_to_file(caffe_net, filename):
"""Draws a caffe net, and saves it to file using the format given as the
- file extension.
+ file extension. Use '.raw' to output raw text that you can manually feed
+ to graphviz to draw graphs.
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'w') as fid: