diff options
author | Ronghang Hu <huronghang@hotmail.com> | 2015-05-29 07:50:23 +0800 |
---|---|---|
committer | Ronghang Hu <huronghang@hotmail.com> | 2015-05-29 16:06:35 +0800 |
commit | d07e5f796907a2bc048bdab3cdb4ace05fa60d7a (patch) | |
tree | 82011dcce9bd4f289afe6d26f66fa6aa004ffd3a /matlab | |
parent | 18adbb8d1a1be91598aa23bad6550eed954e32a9 (diff) | |
download | caffeonacl-d07e5f796907a2bc048bdab3cdb4ace05fa60d7a.tar.gz caffeonacl-d07e5f796907a2bc048bdab3cdb4ace05fa60d7a.tar.bz2 caffeonacl-d07e5f796907a2bc048bdab3cdb4ace05fa60d7a.zip |
More tests for Blob, Layer, copy_from and step, fix some typos
More testes are added into test_net.m and test_solver.m
Diffstat (limited to 'matlab')
-rw-r--r-- | matlab/+caffe/+test/test_net.m | 24 | ||||
-rw-r--r-- | matlab/+caffe/+test/test_solver.m | 2 | ||||
-rw-r--r-- | matlab/+caffe/Net.m | 2 | ||||
-rw-r--r-- | matlab/+caffe/Solver.m | 2 | ||||
-rw-r--r-- | matlab/+caffe/io.m | 2 | ||||
-rw-r--r-- | matlab/+caffe/run_tests.m | 3 |
6 files changed, 31 insertions, 4 deletions
diff --git a/matlab/+caffe/+test/test_net.m b/matlab/+caffe/+test/test_net.m index 5d9ba000..3dabe84d 100644 --- a/matlab/+caffe/+test/test_net.m +++ b/matlab/+caffe/+test/test_net.m @@ -48,6 +48,24 @@ classdef test_net < matlab.unittest.TestCase end end methods (Test) + function self = test_blob(self) + self.net.blobs('data').set_data(10 * ones(self.net.blobs('data').shape)); + self.verifyEqual(self.net.blobs('data').get_data(), ... + 10 * ones(self.net.blobs('data').shape, 'single')); + self.net.blobs('data').set_diff(-2 * ones(self.net.blobs('data').shape)); + self.verifyEqual(self.net.blobs('data').get_diff(), ... + -2 * ones(self.net.blobs('data').shape, 'single')); + original_shape = self.net.blobs('data').shape; + self.net.blobs('data').reshape([6 5 4 3 2 1]); + self.verifyEqual(self.net.blobs('data').shape, [6 5 4 3 2 1]); + self.net.blobs('data').reshape(original_shape); + self.net.reshape(); + end + function self = test_layer(self) + self.verifyEqual(self.net.params('conv', 1).shape, [2 2 2 11]); + self.verifyEqual(self.net.layers('conv').params(2).shape, 11); + self.verifyEqual(self.net.layers('conv').type(), 'Convolution'); + end function test_forward_backward(self) self.net.forward_prefilled(); self.net.backward_prefilled(); @@ -60,13 +78,17 @@ classdef test_net < matlab.unittest.TestCase weights_file = tempname(); self.net.save(weights_file); model_file2 = caffe.test.test_net.simple_net_file(self.num_output); - net2 = caffe.Net(model_file2, weights_file, 'train'); + net2 = caffe.Net(model_file2, 'train'); + net2.copy_from(weights_file); + net3 = caffe.Net(model_file2, weights_file, 'train'); delete(model_file2); delete(weights_file); for l = 1:length(self.net.layer_vec) for i = 1:length(self.net.layer_vec(l).params) self.verifyEqual(self.net.layer_vec(l).params(i).get_data(), ... net2.layer_vec(l).params(i).get_data()); + self.verifyEqual(self.net.layer_vec(l).params(i).get_data(), ... + net3.layer_vec(l).params(i).get_data()); end end end diff --git a/matlab/+caffe/+test/test_solver.m b/matlab/+caffe/+test/test_solver.m index 682dad48..739258b0 100644 --- a/matlab/+caffe/+test/test_solver.m +++ b/matlab/+caffe/+test/test_solver.m @@ -36,6 +36,8 @@ classdef test_solver < matlab.unittest.TestCase methods (Test) function test_solve(self) self.verifyEqual(self.solver.iter(), 0) + self.solver.step(30); + self.verifyEqual(self.solver.iter(), 30) self.solver.solve() self.verifyEqual(self.solver.iter(), 100) end diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m index a6761060..e6295bba 100644 --- a/matlab/+caffe/Net.m +++ b/matlab/+caffe/Net.m @@ -111,7 +111,7 @@ classdef Net < handle self.blobs(self.outputs{n}).set_diff(output_diff{n}); end self.backward_prefilled(); - % retrieve diff from input_blobs + % retrieve diff from input blobs res = cell(length(self.inputs), 1); for n = 1:length(self.inputs) res{n} = self.blobs(self.inputs{n}).get_diff(); diff --git a/matlab/+caffe/Solver.m b/matlab/+caffe/Solver.m index daaa8022..f8bdc4e2 100644 --- a/matlab/+caffe/Solver.m +++ b/matlab/+caffe/Solver.m @@ -41,7 +41,7 @@ classdef Solver < handle end function restore(self, snapshot_filename) CHECK(ischar(snapshot_filename), 'snapshot_filename must be a string'); - CHECK_FILE_EXIST(snapshot_filename) + CHECK_FILE_EXIST(snapshot_filename); caffe_('solver_restore', self.hSolver_self, snapshot_filename); end function solve(self) diff --git a/matlab/+caffe/io.m b/matlab/+caffe/io.m index 7a30bfb5..c9e07aee 100644 --- a/matlab/+caffe/io.m +++ b/matlab/+caffe/io.m @@ -17,7 +17,7 @@ classdef io function mean_data = read_mean(mean_proto_file) % mean_data = read_mean(mean_proto_file) % read image mean data from binaryproto file - CHECK(ischar(mean_proto_file), 'im_file must be a string'); + CHECK(ischar(mean_proto_file), 'mean_proto_file must be a string'); CHECK_FILE_EXIST(mean_proto_file); mean_data = caffe_('read_mean', mean_proto_file); end diff --git a/matlab/+caffe/run_tests.m b/matlab/+caffe/run_tests.m index 8773c9f6..93896855 100644 --- a/matlab/+caffe/run_tests.m +++ b/matlab/+caffe/run_tests.m @@ -2,6 +2,9 @@ function results = run_tests() % results = run_tests() % run all tests in this caffe matlab wrapper package +% use CPU for testing +caffe.set_mode_cpu(); + % reset caffe before testing caffe.reset_all(); |