diff options
author | Jon Long <jonlong@cs.berkeley.edu> | 2015-08-26 18:12:17 -0700 |
---|---|---|
committer | Jon Long <jonlong@cs.berkeley.edu> | 2015-08-26 18:12:17 -0700 |
commit | f572eefc8a415c01c6da24ee0d8e0f75b3f00d44 (patch) | |
tree | d42d18e347dd2b9c6dccd772f79ce25cab2077b3 /python | |
parent | b7e4bfe2522cf7fa81694710a780c1e64d6116b2 (diff) | |
parent | 60c0d58baab7be6c770d81f4c5a7cc1fce0ef7af (diff) | |
download | caffeonacl-f572eefc8a415c01c6da24ee0d8e0f75b3f00d44.tar.gz caffeonacl-f572eefc8a415c01c6da24ee0d8e0f75b3f00d44.tar.bz2 caffeonacl-f572eefc8a415c01c6da24ee0d8e0f75b3f00d44.zip |
Merge pull request #2944 from philkr/python_layer_param
Give the python layer parameter/weight blobs.
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/_caffe.cpp | 18 | ||||
-rw-r--r-- | python/caffe/test/test_python_layer.py | 54 |
2 files changed, 71 insertions, 1 deletion
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 020a5bee..cc49f60a 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -190,6 +190,21 @@ bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
   return bp::object();
 }
 
+bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
+  if (bp::len(kwargs) > 0) {
+    throw std::runtime_error("BlobVec.add_blob takes no kwargs");
+  }
+  typedef vector<shared_ptr<Blob<Dtype> > > BlobVec;
+  BlobVec* self = bp::extract<BlobVec*>(args[0]);
+  vector<int> shape(bp::len(args) - 1);
+  for (int i = 1; i < bp::len(args); ++i) {
+    shape[i - 1] = bp::extract<int>(args[i]);
+  }
+  self->push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
+  // We need to explicitly return None to use bp::raw_function.
+  return bp::object();
+}
+
 BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);
 
 BOOST_PYTHON_MODULE(_caffe) {
@@ -288,7 +303,8 @@ BOOST_PYTHON_MODULE(_caffe) {
 
   // vector wrappers for all the vector types we use
   bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
-    .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>());
+    .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>())
+    .def("add_blob", bp::raw_function(&BlobVec_add_blob));
   bp::class_<vector<Blob<Dtype>*> >("RawBlobVec")
     .def(bp::vector_indexing_suite<vector<Blob<Dtype>*>, true>());
   bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
diff --git a/python/caffe/test/test_python_layer.py b/python/caffe/test/test_python_layer.py
index a1e11bc2..8ed86655 100644
--- a/python/caffe/test/test_python_layer.py
+++ b/python/caffe/test/test_python_layer.py
@@ -28,6 +28,21 @@ class ExceptionLayer(caffe.Layer):
     def setup(self, bottom, top):
         raise RuntimeError
 
+class ParameterLayer(caffe.Layer):
+    """A layer that just multiplies by ten"""
+
+    def setup(self, bottom, top):
+        self.blobs.add_blob(1)
+        self.blobs[0].data[0] = 0
+
+    def reshape(self, bottom, top):
+        top[0].reshape(*bottom[0].data.shape)
+
+    def forward(self, bottom, top):
+        pass
+
+    def backward(self, top, propagate_down, bottom):
+        self.blobs[0].diff[0] = 1
 
 def python_net_file():
     with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
@@ -52,6 +67,16 @@ def exception_net_file():
     return f.name
 
 
+def parameter_net_file():
+    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
+        f.write("""name: 'pythonnet' force_backward: true
+        input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
+        layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
+          python_param { module: 'test_python_layer' layer: 'ParameterLayer' } }
+        """)
+    return f.name
+
+
 class TestPythonLayer(unittest.TestCase):
     def setUp(self):
         net_file = python_net_file()
@@ -84,3 +109,32 @@ class TestPythonLayer(unittest.TestCase):
         net_file = exception_net_file()
         self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
         os.remove(net_file)
+
+    def test_parameter(self):
+        net_file = parameter_net_file()
+        net = caffe.Net(net_file, caffe.TRAIN)
+        # Test forward and backward
+        net.forward()
+        net.backward()
+        layer = net.layers[list(net._layer_names).index('layer')]
+        self.assertEqual(layer.blobs[0].data[0], 0)
+        self.assertEqual(layer.blobs[0].diff[0], 1)
+        layer.blobs[0].data[0] += layer.blobs[0].diff[0]
+        self.assertEqual(layer.blobs[0].data[0], 1)
+
+        # Test saving and loading
+        h, caffemodel_file = tempfile.mkstemp()
+        net.save(caffemodel_file)
+        layer.blobs[0].data[0] = -1
+        self.assertEqual(layer.blobs[0].data[0], -1)
+        net.copy_from(caffemodel_file)
+        self.assertEqual(layer.blobs[0].data[0], 1)
+        os.remove(caffemodel_file)
+
+        # Test weight sharing
+        net2 = caffe.Net(net_file, caffe.TRAIN)
+        net2.share_with(net)
+        layer = net.layers[list(net2._layer_names).index('layer')]
+        self.assertEqual(layer.blobs[0].data[0], 1)
+
+        os.remove(net_file)