diff options
author | Ross Girshick <rbg@eecs.berkeley.edu> | 2013-12-05 20:58:03 -0800 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2014-03-19 19:12:54 -0700 |
commit | dee9ce76ffeb8296931c4268caedfb0f00f07925 (patch) | |
tree | 1ece296c7f990a45b4d7e72324255a9b34f84e19 /matlab | |
parent | b68cf5ee131c368be2a7afb1dd736b7955cef334 (diff) | |
download | caffe-dee9ce76ffeb8296931c4268caedfb0f00f07925.tar.gz caffe-dee9ce76ffeb8296931c4268caedfb0f00f07925.tar.bz2 caffe-dee9ce76ffeb8296931c4268caedfb0f00f07925.zip |
return model weights
Diffstat (limited to 'matlab')
-rw-r--r-- | matlab/caffe/matcaffe.cpp | 101 |
1 files changed, 98 insertions, 3 deletions
diff --git a/matlab/caffe/matcaffe.cpp b/matlab/caffe/matcaffe.cpp index ddbacca1..99ef3f4b 100644 --- a/matlab/caffe/matcaffe.cpp +++ b/matlab/caffe/matcaffe.cpp @@ -24,8 +24,8 @@ static shared_ptr<Net<float> > net_; // matlab uses RGB color channel order // images need to have the data mean subtracted // -// Data coming in from matlab needs to be in the order -// [batch_images, channels, height, width] +// Data coming in from matlab needs to be in the order +// [width, height, channels, images] // where width is the fastest dimension. // Here is the rough matlab for putting image data into the correct // format: @@ -87,7 +87,92 @@ static mxArray* do_forward(const mxArray* const bottom) { return mx_out; } -// The caffe::Caffe utility functions. +static mxArray* do_get_weights() { + const vector<shared_ptr<Layer<float> > >& layers = net_->layers(); + const vector<string>& layer_names = net_->layer_names(); + + // Step 1: count the number of layers + int num_layers = 0; + { + string prev_layer_name = ""; + for (unsigned int i = 0; i < layers.size(); ++i) { + vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs(); + if (layer_blobs.size() == 0) { + continue; + } + if (layer_names[i] != prev_layer_name) { + prev_layer_name = layer_names[i]; + num_layers++; + } + } + } + + // Step 2: prepare output array of structures + mxArray* mx_layers; + { + const mwSize dims[2] = {num_layers, 1}; + const char* fnames[2] = {"weights", "layer_names"}; + mx_layers = mxCreateStructArray(2, dims, 2, fnames); + } + + // Step 3: copy weights into output + { + string prev_layer_name = ""; + int mx_layer_index = 0; + for (unsigned int i = 0; i < layers.size(); ++i) { + vector<shared_ptr<Blob<float> > >& layer_blobs = layers[i]->blobs(); + if (layer_blobs.size() == 0) { + continue; + } + + mxArray* mx_layer_cells = NULL; + if (layer_names[i] != prev_layer_name) { + prev_layer_name = layer_names[i]; + const mwSize dims[2] = {layer_blobs.size(), 1}; + mx_layer_cells = mxCreateCellArray(2, dims); + mxSetField(mx_layers, mx_layer_index, "weights", mx_layer_cells); + mxSetField(mx_layers, mx_layer_index, "layer_names", + mxCreateString(layer_names[i].c_str())); + mx_layer_index++; + } + + for (unsigned int j = 0; j < layer_blobs.size(); ++j) { + // internally data is stored as (width, height, channels, num) + // where width is the fastest dimension + mwSize dims[4] = {layer_blobs[j]->width(), layer_blobs[j]->height(), + layer_blobs[j]->channels(), layer_blobs[j]->num()}; + mxArray* mx_weights = mxCreateNumericArray(4, dims, mxSINGLE_CLASS, mxREAL); + mxSetCell(mx_layer_cells, j, mx_weights); + float* weights_ptr = reinterpret_cast<float*>(mxGetPr(mx_weights)); + +// mexPrintf("layer: %s (%d) blob: %d %d: (%d, %d, %d) %d\n", +// layer_names[i].c_str(), i, j, layer_blobs[j]->num(), +// layer_blobs[j]->height(), layer_blobs[j]->width(), +// layer_blobs[j]->channels(), layer_blobs[j]->count()); + + switch (Caffe::mode()) { + case Caffe::CPU: + memcpy(weights_ptr, layer_blobs[j]->cpu_data(), + sizeof(float) * layer_blobs[j]->count()); + break; + case Caffe::GPU: + CUDA_CHECK(cudaMemcpy(weights_ptr, layer_blobs[j]->gpu_data(), + sizeof(float) * layer_blobs[j]->count(), cudaMemcpyDeviceToHost)); + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } + } + } + } + + return mx_layers; +} + +static void get_weights(MEX_ARGS) { + plhs[0] = do_get_weights(); +} + static void set_mode_cpu(MEX_ARGS) { Caffe::set_mode(Caffe::CPU); } @@ -139,6 +224,14 @@ static void forward(MEX_ARGS) { plhs[0] = do_forward(prhs[0]); } +static void is_initialized(MEX_ARGS) { + if (!net_) { + plhs[0] = mxCreateDoubleScalar(0); + } else { + plhs[0] = mxCreateDoubleScalar(1); + } +} + /** ----------------------------------------------------------------- ** Available commands. **/ @@ -151,11 +244,13 @@ static handler_registry handlers[] = { // Public API functions { "forward", forward }, { "init", init }, + { "is_initialized", is_initialized }, { "set_mode_cpu", set_mode_cpu }, { "set_mode_gpu", set_mode_gpu }, { "set_phase_train", set_phase_train }, { "set_phase_test", set_phase_test }, { "set_device", set_device }, + { "get_weights", get_weights }, // The end. { "END", NULL }, }; |