summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRok Mandeljc <rok.mandeljc@fe.uni-lj.si>2015-06-29 15:48:43 +0200
committerRok Mandeljc <rok.mandeljc@gmail.com>2016-09-19 00:31:56 +0200
commit2f55f42cff9147e69b1f5dff9232058d7b654eba (patch)
tree2ad309bdf0bbd2a7d4d553e6799daf1063a92a0f
parent25422de79f58e214e55834524bfe696f8651889f (diff)
downloadcaffeonacl-2f55f42cff9147e69b1f5dff9232058d7b654eba.tar.gz
caffeonacl-2f55f42cff9147e69b1f5dff9232058d7b654eba.tar.bz2
caffeonacl-2f55f42cff9147e69b1f5dff9232058d7b654eba.zip
matcaffe: allow destruction of individual networks and solvers
-rw-r--r--matlab/+caffe/Net.m3
-rw-r--r--matlab/+caffe/Solver.m3
-rw-r--r--matlab/+caffe/private/caffe_.cpp24
3 files changed, 30 insertions, 0 deletions
diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m
index e6295bba..349e060e 100644
--- a/matlab/+caffe/Net.m
+++ b/matlab/+caffe/Net.m
@@ -68,6 +68,9 @@ classdef Net < handle
self.layer_names = self.attributes.layer_names;
self.blob_names = self.attributes.blob_names;
end
+ function delete (self)
+ caffe_('delete_net', self.hNet_self);
+ end
function layer = layers(self, layer_name)
CHECK(ischar(layer_name), 'layer_name must be a string');
layer = self.layer_vec(self.name2layer_index(layer_name));
diff --git a/matlab/+caffe/Solver.m b/matlab/+caffe/Solver.m
index f8bdc4e2..2d3c98b2 100644
--- a/matlab/+caffe/Solver.m
+++ b/matlab/+caffe/Solver.m
@@ -36,6 +36,9 @@ classdef Solver < handle
self.test_nets(n) = caffe.Net(self.attributes.hNet_test_nets(n));
end
end
+ function delete (self)
+ caffe_('delete_solver', self.hSolver_self);
+ end
function iter = iter(self)
iter = caffe_('solver_get_iter', self.hSolver_self);
end
diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp
index 1b1b2bff..bc04f417 100644
--- a/matlab/+caffe/private/caffe_.cpp
+++ b/matlab/+caffe/private/caffe_.cpp
@@ -197,6 +197,17 @@ static void get_solver(MEX_ARGS) {
mxFree(solver_file);
}
+// Usage: caffe_('delete_solver', hSolver)
+static void delete_solver(MEX_ARGS) {
+ mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
+ "Usage: caffe_('delete_solver', hSolver)");
+ Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
+ solvers_.erase(std::remove_if(solvers_.begin(), solvers_.end(),
+ [solver] (const shared_ptr< Solver<float> > &solverPtr) {
+ return solverPtr.get() == solver;
+ }), solvers_.end());
+}
+
// Usage: caffe_('solver_get_attr', hSolver)
static void solver_get_attr(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -271,6 +282,17 @@ static void get_net(MEX_ARGS) {
mxFree(phase_name);
}
+// Usage: caffe_('delete_solver', hSolver)
+static void delete_net(MEX_ARGS) {
+ mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
+ "Usage: caffe_('delete_solver', hNet)");
+ Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
+ nets_.erase(std::remove_if(nets_.begin(), nets_.end(),
+ [net] (const shared_ptr< Net<float> > &netPtr) {
+ return netPtr.get() == net;
+ }), nets_.end());
+}
+
// Usage: caffe_('net_get_attr', hNet)
static void net_get_attr(MEX_ARGS) {
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -522,12 +544,14 @@ struct handler_registry {
static handler_registry handlers[] = {
// Public API functions
{ "get_solver", get_solver },
+ { "delete_solver", delete_solver },
{ "solver_get_attr", solver_get_attr },
{ "solver_get_iter", solver_get_iter },
{ "solver_restore", solver_restore },
{ "solver_solve", solver_solve },
{ "solver_step", solver_step },
{ "get_net", get_net },
+ { "delete_net", delete_net },
{ "net_get_attr", net_get_attr },
{ "net_forward", net_forward },
{ "net_backward", net_backward },