summaryrefslogtreecommitdiff
path: root/matlab
diff options
context:
space:
mode:
authorEvan Shelhamer <shelhamer@imaginarynumber.net>2015-01-25 22:06:23 -0800
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-02-17 11:35:51 -0800
commitbb8b5d60eb47dff3977377a2a4db4fe9f2fbee62 (patch)
treec826be29ce06b6e0619f22193fb1117904be90b8 /matlab
parentd2934aee4cee0805992a9c1690d0eedeed98ea51 (diff)
downloadcaffeonacl-bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62.tar.gz
caffeonacl-bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62.tar.bz2
caffeonacl-bb8b5d60eb47dff3977377a2a4db4fe9f2fbee62.zip
[matcaffe] give phase to Net
Diffstat (limited to 'matlab')
-rw-r--r--matlab/caffe/matcaffe.cpp25
-rw-r--r--matlab/caffe/matcaffe_init.m7
2 files changed, 15 insertions, 17 deletions
diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp
index fd8397e7..996d3d21 100644
--- a/matlab/caffe/matcaffe.cpp
+++ b/matlab/caffe/matcaffe.cpp
@@ -254,14 +254,6 @@ static void set_mode_gpu(MEX_ARGS) {
Caffe::set_mode(Caffe::GPU);
}
-static void set_phase_train(MEX_ARGS) {
- Caffe::set_phase(Caffe::TRAIN);
-}
-
-static void set_phase_test(MEX_ARGS) {
- Caffe::set_phase(Caffe::TEST);
-}
-
static void set_device(MEX_ARGS) {
if (nrhs != 1) {
ostringstream error_msg;
@@ -278,7 +270,7 @@ static void get_init_key(MEX_ARGS) {
}
static void init(MEX_ARGS) {
- if (nrhs != 2) {
+ if (nrhs != 3) {
ostringstream error_msg;
error_msg << "Expected 2 arguments, got " << nrhs;
mex_error(error_msg.str());
@@ -286,12 +278,23 @@ static void init(MEX_ARGS) {
char* param_file = mxArrayToString(prhs[0]);
char* model_file = mxArrayToString(prhs[1]);
+ char* phase_name = mxArrayToString(prhs[2]);
+
+ Phase phase;
+ if (strcmp(phase_name, "train") == 0) {
+ phase = TRAIN;
+ } else if (strcmp(phase_name, "test") == 0) {
+ phase = TEST;
+ } else {
+ mex_error("Unknown phase.");
+ }
- net_.reset(new Net<float>(string(param_file)));
+ net_.reset(new Net<float>(string(param_file), phase));
net_->CopyTrainedLayersFrom(string(model_file));
mxFree(param_file);
mxFree(model_file);
+ mxFree(phase_name);
init_key = random(); // NOLINT(caffe/random_fn)
@@ -377,8 +380,6 @@ static handler_registry handlers[] = {
{ "is_initialized", is_initialized },
{ "set_mode_cpu", set_mode_cpu },
{ "set_mode_gpu", set_mode_gpu },
- { "set_phase_train", set_phase_train },
- { "set_phase_test", set_phase_test },
{ "set_device", set_device },
{ "get_weights", get_weights },
{ "get_init_key", get_init_key },
diff --git a/matlab/caffe/matcaffe_init.m b/matlab/caffe/matcaffe_init.m
index 7cc69357..5d0a0a70 100644
--- a/matlab/caffe/matcaffe_init.m
+++ b/matlab/caffe/matcaffe_init.m
@@ -25,7 +25,8 @@ if caffe('is_initialized') == 0
% NOTE: you'll have to get network definition
error('You need the network prototxt definition');
end
- caffe('init', model_def_file, model_file)
+ % load network in TEST phase
+ caffe('init', model_def_file, model_file, 'test')
end
fprintf('Done with init\n');
@@ -38,7 +39,3 @@ else
caffe('set_mode_cpu');
end
fprintf('Done with set_mode\n');
-
-% put into test mode
-caffe('set_phase_test');
-fprintf('Done with set_phase_test\n');