diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-01-25 22:06:23 -0800 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2015-02-17 11:35:51 -0800 |
commit | bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62 (patch) | |
tree | c826be29ce06b6e0619f22193fb1117904be90b8 /matlab | |
parent | d2934aee4cee0805992a9c1690d0eedeed98ea51 (diff) | |
download | caffeonacl-bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62.tar.gz caffeonacl-bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62.tar.bz2 caffeonacl-bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62.zip |
[matcaffe] give phase to Net
Diffstat (limited to 'matlab')
-rw-r--r-- | matlab/caffe/matcaffe.cpp | 25 | ||||
-rw-r--r-- | matlab/caffe/matcaffe_init.m | 7 |
2 files changed, 15 insertions, 17 deletions
diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp index fd8397e7..996d3d21 100644 --- a/matlab/caffe/matcaffe.cpp +++ b/matlab/caffe/matcaffe.cpp @@ -254,14 +254,6 @@ static void set_mode_gpu(MEX_ARGS) { Caffe::set_mode(Caffe::GPU); } -static void set_phase_train(MEX_ARGS) { - Caffe::set_phase(Caffe::TRAIN); -} - -static void set_phase_test(MEX_ARGS) { - Caffe::set_phase(Caffe::TEST); -} - static void set_device(MEX_ARGS) { if (nrhs != 1) { ostringstream error_msg; @@ -278,7 +270,7 @@ static void get_init_key(MEX_ARGS) { } static void init(MEX_ARGS) { - if (nrhs != 2) { + if (nrhs != 3) { ostringstream error_msg; error_msg << "Expected 2 arguments, got " << nrhs; mex_error(error_msg.str()); @@ -286,12 +278,23 @@ static void init(MEX_ARGS) { char* param_file = mxArrayToString(prhs[0]); char* model_file = mxArrayToString(prhs[1]); + char* phase_name = mxArrayToString(prhs[2]); + + Phase phase; + if (strcmp(phase_name, "train") == 0) { + phase = TRAIN; + } else if (strcmp(phase_name, "test") == 0) { + phase = TEST; + } else { + mex_error("Unknown phase."); + } - net_.reset(new Net<float>(string(param_file))); + net_.reset(new Net<float>(string(param_file), phase)); net_->CopyTrainedLayersFrom(string(model_file)); mxFree(param_file); mxFree(model_file); + mxFree(phase_name); init_key = random(); // NOLINT(caffe/random_fn) @@ -377,8 +380,6 @@ static handler_registry handlers[] = { { "is_initialized", is_initialized }, { "set_mode_cpu", set_mode_cpu }, { "set_mode_gpu", set_mode_gpu }, - { "set_phase_train", set_phase_train }, - { "set_phase_test", set_phase_test }, { "set_device", set_device }, { "get_weights", get_weights }, { "get_init_key", get_init_key }, diff --git a/matlab/caffe/matcaffe_init.m b/matlab/caffe/matcaffe_init.m index 7cc69357..5d0a0a70 100644 --- a/matlab/caffe/matcaffe_init.m +++ b/matlab/caffe/matcaffe_init.m @@ -25,7 +25,8 @@ if caffe('is_initialized') == 0 % NOTE: you'll have to get network definition error('You need the network prototxt definition'); end - caffe('init', model_def_file, model_file) + % load network in TEST phase + caffe('init', model_def_file, model_file, 'test') end fprintf('Done with init\n'); @@ -38,7 +39,3 @@ else caffe('set_mode_cpu'); end fprintf('Done with set_mode\n'); - -% put into test mode -caffe('set_phase_test'); -fprintf('Done with set_phase_test\n'); |