summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorphilkr <philkr@users.noreply.github.com>2016-05-23 20:09:45 -0700
committerphilkr <philkr@users.noreply.github.com>2016-06-02 15:27:17 -0700
commit742c93f31be4c874aa5fd0103f25f8a2f8d4d63d (patch)
treeccd63897f8aa26712fab2e98f6b98d5be81e2053 /python
parent923e7e8b6337f610115ae28859408bc392d13136 (diff)
downloadcaffeonacl-742c93f31be4c874aa5fd0103f25f8a2f8d4d63d.tar.gz
caffeonacl-742c93f31be4c874aa5fd0103f25f8a2f8d4d63d.tar.bz2
caffeonacl-742c93f31be4c874aa5fd0103f25f8a2f8d4d63d.zip
Exposing load_hdf5 and save_hdf5 to python
Diffstat (limited to 'python')
-rw-r--r--python/caffe/_caffe.cpp12
-rw-r--r--python/caffe/test/test_net.py14
2 files changed, 25 insertions, 1 deletions
diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp
index 32b5d921..48a0c8f2 100644
--- a/python/caffe/_caffe.cpp
+++ b/python/caffe/_caffe.cpp
@@ -114,6 +114,14 @@ void Net_Save(const Net<Dtype>& net, string filename) {
WriteProtoToBinaryFile(net_param, filename.c_str());
}
+void Net_SaveHDF5(const Net<Dtype>& net, string filename) {
+ net.ToHDF5(filename);
+}
+
+void Net_LoadHDF5(Net<Dtype>* net, string filename) {
+ net->CopyTrainedLayersFromHDF5(filename.c_str());
+}
+
void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
bp::object labels_obj) {
// check that this network has an input MemoryDataLayer
@@ -267,7 +275,9 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::return_value_policy<bp::copy_const_reference>()))
.def("_set_input_arrays", &Net_SetInputArrays,
bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
- .def("save", &Net_Save);
+ .def("save", &Net_Save)
+ .def("save_hdf5", &Net_SaveHDF5)
+ .def("load_hdf5", &Net_LoadHDF5);
BP_REGISTER_SHARED_PTR_TO_PYTHON(Net<Dtype>);
bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
diff --git a/python/caffe/test/test_net.py b/python/caffe/test/test_net.py
index aad828aa..4cacfcd0 100644
--- a/python/caffe/test/test_net.py
+++ b/python/caffe/test/test_net.py
@@ -79,3 +79,17 @@ class TestNet(unittest.TestCase):
for i in range(len(self.net.params[name])):
self.assertEqual(abs(self.net.params[name][i].data
- net2.params[name][i].data).sum(), 0)
+
+ def test_save_hdf5(self):
+ f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
+ f.close()
+ self.net.save_hdf5(f.name)
+ net_file = simple_net_file(self.num_output)
+ net2 = caffe.Net(net_file, caffe.TRAIN)
+ net2.load_hdf5(f.name)
+ os.remove(net_file)
+ os.remove(f.name)
+ for name in self.net.params:
+ for i in range(len(self.net.params[name])):
+ self.assertEqual(abs(self.net.params[name][i].data
+ - net2.params[name][i].data).sum(), 0)