summaryrefslogtreecommitdiff
path: root/matlab
diff options
context:
space:
mode:
authorRoss Girshick <rbg@eecs.berkeley.edu>2013-12-05 20:58:03 -0800
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2014-03-19 19:12:54 -0700
commitdee9ce76ffeb8296931c4268caedfb0f00f07925 (patch)
tree1ece296c7f990a45b4d7e72324255a9b34f84e19 /matlab
parentb68cf5ee131c368be2a7afb1dd736b7955cef334 (diff)
downloadcaffe-dee9ce76ffeb8296931c4268caedfb0f00f07925.tar.gz
caffe-dee9ce76ffeb8296931c4268caedfb0f00f07925.tar.bz2
caffe-dee9ce76ffeb8296931c4268caedfb0f00f07925.zip
return model weights
Diffstat (limited to 'matlab')
-rw-r--r--matlab/caffe/matcaffe.cpp101
1 files changed, 98 insertions, 3 deletions
diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp
index ddbacca1..99ef3f4b 100644
--- a/matlab/caffe/matcaffe.cpp
+++ b/matlab/caffe/matcaffe.cpp
@@ -24,8 +24,8 @@ static shared_ptr<Net<float> > net_;
// matlab uses RGB color channel order
// images need to have the data mean subtracted
//
-// Data coming in from matlab needs to be in the order
-// [batch_images, channels, height, width]
+// Data coming in from matlab needs to be in the order
+// [width, height, channels, images]
// where width is the fastest dimension.
// Here is the rough matlab for putting image data into the correct
// format:
@@ -87,7 +87,92 @@ static mxArray* do_forward(const mxArray* const bottom) {
return mx_out;
}
-// The caffe::Caffe utility functions.
+static mxArray* do_get_weights() {
+ const vector<shared_ptr<Layer<float> > >& layers = net_->layers();
+ const vector<string>& layer_names = net_->layer_names();
+
+ // Step 1: count the number of layers
+ int num_layers = 0;
+ {
+ string prev_layer_name = "";
+ for (unsigned int i = 0; i < layers.size(); ++i) {
+ vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs();
+ if (layer_blobs.size() == 0) {
+ continue;
+ }
+ if (layer_names[i] != prev_layer_name) {
+ prev_layer_name = layer_names[i];
+ num_layers++;
+ }
+ }
+ }
+
+ // Step 2: prepare output array of structures
+ mxArray* mx_layers;
+ {
+ const mwSize dims[2] = {num_layers, 1};
+ const char* fnames[2] = {"weights", "layer_names"};
+ mx_layers = mxCreateStructArray(2, dims, 2, fnames);
+ }
+
+ // Step 3: copy weights into output
+ {
+ string prev_layer_name = "";
+ int mx_layer_index = 0;
+ for (unsigned int i = 0; i < layers.size(); ++i) {
+ vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs();
+ if (layer_blobs.size() == 0) {
+ continue;
+ }
+
+ mxArray* mx_layer_cells = NULL;
+ if (layer_names[i] != prev_layer_name) {
+ prev_layer_name = layer_names[i];
+ const mwSize dims[2] = {layer_blobs.size(), 1};
+ mx_layer_cells = mxCreateCellArray(2, dims);
+ mxSetField(mx_layers, mx_layer_index, "weights", mx_layer_cells);
+ mxSetField(mx_layers, mx_layer_index, "layer_names",
+ mxCreateString(layer_names[i].c_str()));
+ mx_layer_index++;
+ }
+
+ for (unsigned int j = 0; j < layer_blobs.size(); ++j) {
+ // internally data is stored as (width, height, channels, num)
+ // where width is the fastest dimension
+ mwSize dims[4] = {layer_blobs[j]->width(), layer_blobs[j]->height(),
+ layer_blobs[j]->channels(), layer_blobs[j]->num()};
+ mxArray* mx_weights = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL);
+ mxSetCell(mx_layer_cells, j, mx_weights);
+ float* weights_ptr = reinterpret_cast<float*>(mxGetPr(mx_weights));
+
+// mexPrintf("layer: %s (%d) blob: %d %d: (%d, %d, %d) %d\n",
+// layer_names[i].c_str(), i, j, layer_blobs[j]->num(),
+// layer_blobs[j]->height(), layer_blobs[j]->width(),
+// layer_blobs[j]->channels(), layer_blobs[j]->count());
+
+ switch (Caffe::mode()) {
+ case Caffe::CPU:
+ memcpy(weights_ptr, layer_blobs[j]->cpu_data(),
+ sizeof(float) * layer_blobs[j]->count());
+ break;
+ case Caffe::GPU:
+ CUDA_CHECK(cudaMemcpy(weights_ptr, layer_blobs[j]->gpu_data(),
+ sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDeviceToHost));
+ break;
+ default:
+ LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+ }
+ }
+ }
+ }
+
+ return mx_layers;
+}
+
+static void get_weights(MEX_ARGS) {
+ plhs[0] = do_get_weights();
+}
+
static void set_mode_cpu(MEX_ARGS) {
Caffe::set_mode(Caffe::CPU);
}
@@ -139,6 +224,14 @@ static void forward(MEX_ARGS) {
plhs[0] = do_forward(prhs[0]);
}
+static void is_initialized(MEX_ARGS) {
+ if (!net_) {
+ plhs[0] = mxCreateDoubleScalar(0);
+ } else {
+ plhs[0] = mxCreateDoubleScalar(1);
+ }
+}
+
/** -----------------------------------------------------------------
** Available commands.
**/
@@ -151,11 +244,13 @@ static handler_registry handlers[] = {
// Public API functions
{ "forward", forward },
{ "init", init },
+ { "is_initialized", is_initialized },
{ "set_mode_cpu", set_mode_cpu },
{ "set_mode_gpu", set_mode_gpu },
{ "set_phase_train", set_phase_train },
{ "set_phase_test", set_phase_test },
{ "set_device", set_device },
+ { "get_weights", get_weights },
// The end.
{ "END", NULL },
};