author     Jon Long <jonlong@cs.berkeley.edu>    2015-08-26 18:12:17 -0700
committer  Jon Long <jonlong@cs.berkeley.edu>    2015-08-26 18:12:17 -0700
commit     f572eefc8a415c01c6da24ee0d8e0f75b3f00d44 (patch)
tree       d42d18e347dd2b9c6dccd772f79ce25cab2077b3 /python
parent     b7e4bfe2522cf7fa81694710a780c1e64d6116b2 (diff)
parent     60c0d58baab7be6c770d81f4c5a7cc1fce0ef7af (diff)
Merge pull request #2944 from philkr/python_layer_param
Give the python layer parameter/weight blobs.
Diffstat (limited to 'python')
-rw-r--r--  python/caffe/_caffe.cpp                 | 18
-rw-r--r--  python/caffe/test/test_python_layer.py  | 54
2 files changed, 71 insertions(+), 1 deletion(-)
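
Not part of the commit, but for orientation: with this change a Python layer can allocate its own parameter blobs from setup() by calling self.blobs.add_blob() with the desired shape, and those blobs then behave like any other layer weights (gradients, saving, weight sharing). Below is a minimal sketch under that assumption; the ScalarBiasLayer name and its behaviour are illustrative and not taken from this PR.

import caffe

class ScalarBiasLayer(caffe.Layer):
    """Illustrative Python layer: adds one learned scalar to its input."""

    def setup(self, bottom, top):
        # New in this change: allocate a 1-element parameter blob on the layer.
        self.blobs.add_blob(1)
        self.blobs[0].data[0] = 0

    def reshape(self, bottom, top):
        top[0].reshape(*bottom[0].data.shape)

    def forward(self, bottom, top):
        top[0].data[...] = bottom[0].data + self.blobs[0].data[0]

    def backward(self, top, propagate_down, bottom):
        # Accumulate the parameter gradient; pass the top gradient through.
        self.blobs[0].diff[0] += top[0].diff.sum()
        if propagate_down[0]:
            bottom[0].diff[...] = top[0].diff

Such a layer would be referenced from a net prototxt through a python_param block, exactly as the test below does for its ParameterLayer.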
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 020a5bee..cc49f60a 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -190,6 +190,21 @@ bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
   return bp::object();
 }
 
+bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
+  if (bp::len(kwargs) > 0) {
+    throw std::runtime_error("BlobVec.add_blob takes no kwargs");
+  }
+  typedef vector<shared_ptr<Blob<Dtype> > > BlobVec;
+  BlobVec* self = bp::extract<BlobVec*>(args[0]);
+  vector<int> shape(bp::len(args) - 1);
+  for (int i = 1; i < bp::len(args); ++i) {
+    shape[i - 1] = bp::extract<int>(args[i]);
+  }
+  self->push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  // We need to explicitly return None to use bp::raw_function.
+  return bp::object();
+}
+
 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
 
 BOOST_PYTHON_MODULE(_caffe) {
@@ -288,7 +303,8 @@ BOOST_PYTHON_MODULE(_caffe) {
   // vector wrappers for all the vector types we use
   bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
-    .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
+    .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>())
+    .def("add_blob", bp::raw_function(&BlobVec_add_blob));
   bp::class_<vector<Blob<Dtype>*> >("RawBlobVec")
     .def(bp::vector_indexing_suite<vector<Blob<Dtype>*>, true>());
   bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
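
A note on the binding above: add_blob is registered with bp::raw_function, so it accepts any number of positional integer arguments, uses them as the dimensions of the new blob, and rejects keyword arguments with a RuntimeError. A hedged sketch of that calling convention, written as an illustrative layer (the class name is an assumption, not part of this PR):

import caffe

class MultiParamLayer(caffe.Layer):
    """Illustrative only: shows the add_blob calling convention."""

    def setup(self, bottom, top):
        # Each positional int becomes one dimension of the new blob.
        self.blobs.add_blob(1)           # a single scalar parameter
        self.blobs.add_blob(4, 3, 5, 5)  # e.g. a conv-style weight blob
        assert self.blobs[1].data.shape == (4, 3, 5, 5)
        # self.blobs.add_blob(num=1) would raise: the binding takes no kwargs.

    def reshape(self, bottom, top):
        top[0].reshape(*bottom[0].data.shape)

    def forward(self, bottom, top):
        pass

    def backward(self, top, propagate_down, bottom):
        pass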
diff --git a/python/caffe/test/test_python_layer.py b/python/caffe/test/test_python_layer.py
index a1e11bc2..8ed86655 100644
--- a/python/caffe/test/test_python_layer.py
+++ b/python/caffe/test/test_python_layer.py
@@ -28,6 +28,21 @@ class ExceptionLayer(caffe.Layer):
     def setup(self, bottom, top):
         raise RuntimeError
 
+class ParameterLayer(caffe.Layer):
+    """A layer that exposes a single learnable parameter blob"""
+
+    def setup(self, bottom, top):
+        self.blobs.add_blob(1)
+        self.blobs[0].data[0] = 0
+
+    def reshape(self, bottom, top):
+        top[0].reshape(*bottom[0].data.shape)
+
+    def forward(self, bottom, top):
+        pass
+
+    def backward(self, top, propagate_down, bottom):
+        self.blobs[0].diff[0] = 1
 
 def python_net_file():
     with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
@@ -52,6 +67,16 @@ def exception_net_file():
     return f.name
 
+def parameter_net_file():
+    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
+        f.write("""name: 'pythonnet' force_backward: true
+        input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
+        layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
+          python_param { module: 'test_python_layer' layer: 'ParameterLayer' } }
+        """)
+    return f.name
+
+
 class TestPythonLayer(unittest.TestCase):
     def setUp(self):
         net_file = python_net_file()
@@ -84,3 +109,32 @@ class TestPythonLayer(unittest.TestCase):
         net_file = exception_net_file()
         self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
         os.remove(net_file)
+
+    def test_parameter(self):
+        net_file = parameter_net_file()
+        net = caffe.Net(net_file, caffe.TRAIN)
+        # Test forward and backward
+        net.forward()
+        net.backward()
+        layer = net.layers[list(net._layer_names).index('layer')]
+        self.assertEqual(layer.blobs[0].data[0], 0)
+        self.assertEqual(layer.blobs[0].diff[0], 1)
+        layer.blobs[0].data[0] += layer.blobs[0].diff[0]
+        self.assertEqual(layer.blobs[0].data[0], 1)
+
+        # Test saving and loading
+        h, caffemodel_file = tempfile.mkstemp()
+        net.save(caffemodel_file)
+        layer.blobs[0].data[0] = -1
+        self.assertEqual(layer.blobs[0].data[0], -1)
+        net.copy_from(caffemodel_file)
+        self.assertEqual(layer.blobs[0].data[0], 1)
+        os.remove(caffemodel_file)
+
+        # Test weight sharing
+        net2 = caffe.Net(net_file, caffe.TRAIN)
+        net2.share_with(net)
+        layer = net2.layers[list(net2._layer_names).index('layer')]
+        self.assertEqual(layer.blobs[0].data[0], 1)
+
+        os.remove(net_file)
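
Finally, a hedged end-to-end sketch of how the new parameter blobs could be driven from a script, assuming the illustrative ScalarBiasLayer shown earlier is saved in an importable module named scalar_bias_layer on PYTHONPATH; the net definition, file handling, and learning-rate value below are assumptions, not part of this commit.

import os
import tempfile

import numpy as np
import caffe

# Minimal net that routes 'data' through the illustrative ScalarBiasLayer
# (assumed to live in an importable module called scalar_bias_layer).
NET_PROTO = """name: 'biasnet' force_backward: true
input: 'data' input_shape { dim: 2 dim: 3 }
layer { type: 'Python' name: 'bias' bottom: 'data' top: 'out'
  python_param { module: 'scalar_bias_layer' layer: 'ScalarBiasLayer' } }
"""

def main():
    with tempfile.NamedTemporaryFile(mode='w', suffix='.prototxt',
                                     delete=False) as f:
        f.write(NET_PROTO)
        net_file = f.name

    net = caffe.Net(net_file, caffe.TRAIN)
    net.blobs['data'].data[...] = np.random.randn(2, 3)
    net.forward()
    net.blobs['out'].diff[...] = 1.0  # stand-in gradient from a loss
    net.backward()

    # The Python layer's parameter blob shows up in net.params just like
    # the weights of a built-in layer.
    bias = net.params['bias'][0]
    print('bias value:', bias.data[0], 'gradient:', bias.diff[0])

    # A manual gradient step on the learned scalar (no solver involved).
    bias.data[0] -= 0.1 * bias.diff[0]

    # The parameter also round-trips through save() / copy_from().
    fd, weights = tempfile.mkstemp(suffix='.caffemodel')
    os.close(fd)
    net.save(weights)
    net.copy_from(weights)

    os.remove(weights)
    os.remove(net_file)

if __name__ == '__main__':
    main()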