diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-09-04 14:36:31 -0300 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-09-04 14:36:31 -0300 |
commit | 5367a1af5dc8a56a284b7f1c67efce097871955a (patch) | |
tree | 3e927ad825979a158a49e8ecd27097a705858aa4 /python | |
parent | 9a244f9bcf83d033ddc0fe7355cb6bf4fd4fcb03 (diff) | |
parent | 5cc76ad2e38f19a140497ff09c475500da9d76cf (diff) | |
download | caffeonacl-5367a1af5dc8a56a284b7f1c67efce097871955a.tar.gz caffeonacl-5367a1af5dc8a56a284b7f1c67efce097871955a.tar.bz2 caffeonacl-5367a1af5dc8a56a284b7f1c67efce097871955a.zip |
Merge pull request #3024 from danielgordon10/python-solver-fix
[pycaffe] expose all solvers for direct instantiation (although note get_solver)
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/__init__.py | 2 | ||||
-rw-r--r-- | python/caffe/_caffe.cpp | 9 | ||||
-rw-r--r-- | python/caffe/pycaffe.py | 3 |
3 files changed, 12 insertions, 2 deletions
diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py index 6cc44e72..ccda1bca 100644 --- a/python/caffe/__init__.py +++ b/python/caffe/__init__.py @@ -1,4 +1,4 @@ -from .pycaffe import Net, SGDSolver +from .pycaffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, RMSPropSolver, AdaDeltaSolver, AdamSolver from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list from .proto.caffe_pb2 import TRAIN, TEST from .classifier import Classifier diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index cc49f60a..ccd5776a 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -297,6 +297,15 @@ BOOST_PYTHON_MODULE(_caffe) { bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >, shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>( "AdaGradSolver", bp::init<string>()); + bp::class_<RMSPropSolver<Dtype>, bp::bases<Solver<Dtype> >, + shared_ptr<RMSPropSolver<Dtype> >, boost::noncopyable>( + "RMSPropSolver", bp::init<string>()); + bp::class_<AdaDeltaSolver<Dtype>, bp::bases<Solver<Dtype> >, + shared_ptr<AdaDeltaSolver<Dtype> >, boost::noncopyable>( + "AdaDeltaSolver", bp::init<string>()); + bp::class_<AdamSolver<Dtype>, bp::bases<Solver<Dtype> >, + shared_ptr<AdamSolver<Dtype> >, boost::noncopyable>( + "AdamSolver", bp::init<string>()); bp::def("get_solver", &GetSolverFromFile, bp::return_value_policy<bp::manage_new_object>()); diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 4f980a92..8ea24da4 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -10,7 +10,8 @@ except: from itertools import zip_longest as izip_longest import numpy as np -from ._caffe import Net, SGDSolver +from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \ + RMSPropSolver, AdaDeltaSolver, AdamSolver import caffe.io # We directly update methods from Net here (rather than using composition or |