diff options
author | Rok Mandeljc <rok.mandeljc@fe.uni-lj.si> | 2015-06-29 15:48:43 +0200 |
---|---|---|
committer | Rok Mandeljc <rok.mandeljc@gmail.com> | 2016-09-19 00:31:56 +0200 |
commit | 2f55f42cff9147e69b1f5dff9232058d7b654eba (patch) | |
tree | 2ad309bdf0bbd2a7d4d553e6799daf1063a92a0f | |
parent | 25422de79f58e214e55834524bfe696f8651889f (diff) | |
download | caffeonacl-2f55f42cff9147e69b1f5dff9232058d7b654eba.tar.gz caffeonacl-2f55f42cff9147e69b1f5dff9232058d7b654eba.tar.bz2 caffeonacl-2f55f42cff9147e69b1f5dff9232058d7b654eba.zip |
matcaffe: allow destruction of individual networks and solvers
-rw-r--r-- | matlab/+caffe/Net.m | 3 | ||||
-rw-r--r-- | matlab/+caffe/Solver.m | 3 | ||||
-rw-r--r-- | matlab/+caffe/private/caffe_.cpp | 24 |
3 files changed, 30 insertions, 0 deletions
diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m index e6295bba..349e060e 100644 --- a/matlab/+caffe/Net.m +++ b/matlab/+caffe/Net.m @@ -68,6 +68,9 @@ classdef Net < handle self.layer_names = self.attributes.layer_names; self.blob_names = self.attributes.blob_names; end + function delete (self) + caffe_('delete_net', self.hNet_self); + end function layer = layers(self, layer_name) CHECK(ischar(layer_name), 'layer_name must be a string'); layer = self.layer_vec(self.name2layer_index(layer_name)); diff --git a/matlab/+caffe/Solver.m b/matlab/+caffe/Solver.m index f8bdc4e2..2d3c98b2 100644 --- a/matlab/+caffe/Solver.m +++ b/matlab/+caffe/Solver.m @@ -36,6 +36,9 @@ classdef Solver < handle self.test_nets(n) = caffe.Net(self.attributes.hNet_test_nets(n)); end end + function delete (self) + caffe_('delete_solver', self.hSolver_self); + end function iter = iter(self) iter = caffe_('solver_get_iter', self.hSolver_self); end diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp index 1b1b2bff..bc04f417 100644 --- a/matlab/+caffe/private/caffe_.cpp +++ b/matlab/+caffe/private/caffe_.cpp @@ -197,6 +197,17 @@ static void get_solver(MEX_ARGS) { mxFree(solver_file); } +// Usage: caffe_('delete_solver', hSolver) +static void delete_solver(MEX_ARGS) { + mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]), + "Usage: caffe_('delete_solver', hSolver)"); + Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]); + solvers_.erase(std::remove_if(solvers_.begin(), solvers_.end(), + [solver] (const shared_ptr< Solver<float> > &solverPtr) { + return solverPtr.get() == solver; + }), solvers_.end()); +} + // Usage: caffe_('solver_get_attr', hSolver) static void solver_get_attr(MEX_ARGS) { mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]), @@ -271,6 +282,17 @@ static void get_net(MEX_ARGS) { mxFree(phase_name); } +// Usage: caffe_('delete_solver', hSolver) +static void delete_net(MEX_ARGS) { + mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]), + "Usage: caffe_('delete_solver', hNet)"); + Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]); + nets_.erase(std::remove_if(nets_.begin(), nets_.end(), + [net] (const shared_ptr< Net<float> > &netPtr) { + return netPtr.get() == net; + }), nets_.end()); +} + // Usage: caffe_('net_get_attr', hNet) static void net_get_attr(MEX_ARGS) { mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]), @@ -522,12 +544,14 @@ struct handler_registry { static handler_registry handlers[] = { // Public API functions { "get_solver", get_solver }, + { "delete_solver", delete_solver }, { "solver_get_attr", solver_get_attr }, { "solver_get_iter", solver_get_iter }, { "solver_restore", solver_restore }, { "solver_solve", solver_solve }, { "solver_step", solver_step }, { "get_net", get_net }, + { "delete_net", delete_net }, { "net_get_attr", net_get_attr }, { "net_forward", net_forward }, { "net_backward", net_backward }, |