summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-08-20 11:04:05 -0700
committerJeff Donahue <jeff.donahue@gmail.com>2015-08-20 11:04:05 -0700
commitaa2a6f55b9e50b29d607aaee0fae19bd085d6565 (patch)
tree60e2d7214f66fe265586e0766ab588c1cac07db7 /python
parenta7cf704f8ecf48b1ad2ad6b0a7b5ee9e787bbe17 (diff)
parent51b172ce2fcd7f63aa7830389af54d353f53a3bc (diff)
downloadcaffeonacl-aa2a6f55b9e50b29d607aaee0fae19bd085d6565.tar.gz
caffeonacl-aa2a6f55b9e50b29d607aaee0fae19bd085d6565.tar.bz2
caffeonacl-aa2a6f55b9e50b29d607aaee0fae19bd085d6565.zip
Merge pull request #2930 from lukeyeager/pycaffe-layer_type_list
Expose LayerFactory::LayerTypeList in pycaffe
Diffstat (limited to 'python')
-rw-r--r--python/caffe/__init__.py2
-rw-r--r--python/caffe/_caffe.cpp2
-rw-r--r--python/caffe/test/test_layer_type_list.py10
3 files changed, 13 insertions, 1 deletions
diff --git a/python/caffe/__init__.py b/python/caffe/__init__.py
index 1b2da510..6cc44e72 100644
--- a/python/caffe/__init__.py
+++ b/python/caffe/__init__.py
@@ -1,5 +1,5 @@
from .pycaffe import Net, SGDSolver
-from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver
+from ._caffe import set_mode_cpu, set_mode_gpu, set_device, Layer, get_solver, layer_type_list
from .proto.caffe_pb2 import TRAIN, TEST
from .classifier import Classifier
from .detector import Detector
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index e1ae3ec7..020a5bee 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -200,6 +200,8 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::def("set_mode_gpu", &set_mode_gpu);
bp::def("set_device", &Caffe::SetDevice);
+ bp::def("layer_type_list", &LayerRegistry<Dtype>::LayerTypeList);
+
bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
bp::no_init)
.def("__init__", bp::make_constructor(&Net_Init))
diff --git a/python/caffe/test/test_layer_type_list.py b/python/caffe/test/test_layer_type_list.py
new file mode 100644
index 00000000..7edc80df
--- /dev/null
+++ b/python/caffe/test/test_layer_type_list.py
@@ -0,0 +1,10 @@
+import unittest
+
+import caffe
+
+class TestLayerTypeList(unittest.TestCase):
+
+ def test_standard_types(self):
+ for type_name in ['Data', 'Convolution', 'InnerProduct']:
+ self.assertIn(type_name, caffe.layer_type_list(),
+ '%s not in layer_type_list()' % type_name)