diff options
author | philkr <philkr@users.noreply.github.com> | 2016-05-23 20:09:45 -0700 |
---|---|---|
committer | philkr <philkr@users.noreply.github.com> | 2016-06-02 15:27:17 -0700 |
commit | 742c93f31be4c874aa5fd0103f25f8a2f8d4d63d (patch) | |
tree | ccd63897f8aa26712fab2e98f6b98d5be81e2053 /python | |
parent | 923e7e8b6337f610115ae28859408bc392d13136 (diff) | |
download | caffeonacl-742c93f31be4c874aa5fd0103f25f8a2f8d4d63d.tar.gz caffeonacl-742c93f31be4c874aa5fd0103f25f8a2f8d4d63d.tar.bz2 caffeonacl-742c93f31be4c874aa5fd0103f25f8a2f8d4d63d.zip |
Exposing load_hdf5 and save_hdf5 to python
Diffstat (limited to 'python')
-rw-r--r-- | python/caffe/_caffe.cpp | 12 | ||||
-rw-r--r-- | python/caffe/test/test_net.py | 14 |
2 files changed, 25 insertions, 1 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 32b5d921..48a0c8f2 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -114,6 +114,14 @@ void Net_Save(const Net<Dtype>& net, string filename) { WriteProtoToBinaryFile(net_param, filename.c_str()); } +void Net_SaveHDF5(const Net<Dtype>& net, string filename) { + net.ToHDF5(filename); +} + +void Net_LoadHDF5(Net<Dtype>* net, string filename) { + net->CopyTrainedLayersFromHDF5(filename.c_str()); +} + void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj, bp::object labels_obj) { // check that this network has an input MemoryDataLayer @@ -267,7 +275,9 @@ BOOST_PYTHON_MODULE(_caffe) { bp::return_value_policy<bp::copy_const_reference>())) .def("_set_input_arrays", &Net_SetInputArrays, bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >()) - .def("save", &Net_Save); + .def("save", &Net_Save) + .def("save_hdf5", &Net_SaveHDF5) + .def("load_hdf5", &Net_LoadHDF5); BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>); bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>( diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py index aad828aa..4cacfcd0 100644 --- a/python/caffe/test/test_net.py +++ b/python/caffe/test/test_net.py @@ -79,3 +79,17 @@ class TestNet(unittest.TestCase): for i in range(len(self.net.params[name])): self.assertEqual(abs(self.net.params[name][i].data - net2.params[name][i].data).sum(), 0) + + def test_save_hdf5(self): + f = tempfile.NamedTemporaryFile(mode='w+', delete=False) + f.close() + self.net.save_hdf5(f.name) + net_file = simple_net_file(self.num_output) + net2 = caffe.Net(net_file, caffe.TRAIN) + net2.load_hdf5(f.name) + os.remove(net_file) + os.remove(f.name) + for name in self.net.params: + for i in range(len(self.net.params[name])): + self.assertEqual(abs(self.net.params[name][i].data + - net2.params[name][i].data).sum(), 0) |