summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorJonathan L Long <jonlong@cs.berkeley.edu>2014-04-25 14:29:50 -0700
committerJonathan L Long <jonlong@cs.berkeley.edu>2014-05-02 13:25:51 -0700
commite1072a66d467b743df75435e1a28a1e34a1a4f25 (patch)
tree2f007357adb7a689b108805eeb553106a2441362 /python
parent634a382bce186e63671d37b4d8e40939a74b6373 (diff)
downloadcaffe-e1072a66d467b743df75435e1a28a1e34a1a4f25.tar.gz
caffe-e1072a66d467b743df75435e1a28a1e34a1a4f25.tar.bz2
caffe-e1072a66d467b743df75435e1a28a1e34a1a4f25.zip
pycaffe: store a shared_ptr<CaffeNet> in SGDSolver
Doing this, rather than constructing the CaffeNet wrapper every time, will allow the wrapper to hold references that last at least as long as SGDSolver (which will be necessary to ensure that data used by MemoryDataLayer doesn't get freed).
Diffstat (limited to 'python')
-rw-r--r--python/caffe/_caffe.cpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 1a44debf..853ddbe1 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -301,9 +301,12 @@ class CaffeSGDSolver {
// exception if param_file can't be opened
CheckFile(param_file);
solver_.reset(new SGDSolver<float>(param_file));
+ // we need to explicitly store the net wrapper, rather than constructing
+ // it on the fly, so that it can hold references to Python objects
+ net_.reset(new CaffeNet(solver_->net()));
}
- CaffeNet net() { return CaffeNet(solver_->net()); }
+ shared_ptr<CaffeNet> net() { return net_; }
void Solve() { return solver_->Solve(); }
void SolveResume(const string& resume_file) {
CheckFile(resume_file);
@@ -311,6 +314,7 @@ class CaffeSGDSolver {
}
protected:
+ shared_ptr<CaffeNet> net_;
shared_ptr<SGDSolver<float> > solver_;
};