summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2015-09-04 14:36:31 -0300
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-09-04 14:36:31 -0300
commit5367a1af5dc8a56a284b7f1c67efce097871955a (patch)
tree3e927ad825979a158a49e8ecd27097a705858aa4
parent9a244f9bcf83d033ddc0fe7355cb6bf4fd4fcb03 (diff)
parent5cc76ad2e38f19a140497ff09c475500da9d76cf (diff)
downloadcaffeonacl-5367a1af5dc8a56a284b7f1c67efce097871955a.tar.gz
caffeonacl-5367a1af5dc8a56a284b7f1c67efce097871955a.tar.bz2
caffeonacl-5367a1af5dc8a56a284b7f1c67efce097871955a.zip
Merge pull request #3024 from danielgordon10/python-solver-fix
[pycaffe] expose all solvers for direct instantiation (although note get_solver)
-rw-r--r--python/caffe/__init__.py2
-rw-r--r--python/caffe/_caffe.cpp9
-rw-r--r--python/caffe/pycaffe.py3
3 files changed, 12 insertions, 2 deletions
diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py
index 6cc44e72..ccda1bca 100644
--- a/python/caffe/__init__.py
+++ b/python/caffe/__init__.py
@@ -1,4 +1,4 @@
-from .pycaffe import Net, SGDSolver
+from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver
from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list
from .proto.caffe_pb2 import TRAIN, TEST
from .classifier import Classifier
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index cc49f60a..ccd5776a 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -297,6 +297,15 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
"AdaGradSolver", bp::init<string>());
+ bp::class_<RMSPropSolver<Dtype>, bp::bases<Solver<Dtype> >,
+ shared_ptr<RMSPropSolver<Dtype> >, boost::noncopyable>(
+ "RMSPropSolver", bp::init<string>());
+ bp::class_<AdaDeltaSolver<Dtype>, bp::bases<Solver<Dtype> >,
+ shared_ptr<AdaDeltaSolver<Dtype> >, boost::noncopyable>(
+ "AdaDeltaSolver", bp::init<string>());
+ bp::class_<AdamSolver<Dtype>, bp::bases<Solver<Dtype> >,
+ shared_ptr<AdamSolver<Dtype> >, boost::noncopyable>(
+ "AdamSolver", bp::init<string>());
bp::def("get_solver", &GetSolverFromFile,
bp::return_value_policy<bp::manage_new_object>());
diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py
index 4f980a92..8ea24da4 100644
--- a/python/caffe/pycaffe.py
+++ b/python/caffe/pycaffe.py
@@ -10,7 +10,8 @@ except:
from itertools import zip_longest as izip_longest
import numpy as np
-from ._caffe import Net, SGDSolver
+from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \
+ RMSPropSolver, AdaDeltaSolver, AdamSolver
import caffe.io
# We directly update methods from Net here (rather than using composition or