summaryrefslogtreecommitdiff
path: root/matlab
diff options
context:
space:
mode:
authorRonghang Hu <huronghang@hotmail.com>2015-05-29 07:50:23 +0800
committerRonghang Hu <huronghang@hotmail.com>2015-05-29 16:06:35 +0800
commitd07e5f796907a2bc048bdab3cdb4ace05fa60d7a (patch)
tree82011dcce9bd4f289afe6d26f66fa6aa004ffd3a /matlab
parent18adbb8d1a1be91598aa23bad6550eed954e32a9 (diff)
downloadcaffeonacl-d07e5f796907a2bc048bdab3cdb4ace05fa60d7a.tar.gz
caffeonacl-d07e5f796907a2bc048bdab3cdb4ace05fa60d7a.tar.bz2
caffeonacl-d07e5f796907a2bc048bdab3cdb4ace05fa60d7a.zip
More tests for Blob, Layer, copy_from and step, fix some typos
More testes are added into test_net.m and test_solver.m
Diffstat (limited to 'matlab')
-rw-r--r--matlab/+caffe/+test/test_net.m24
-rw-r--r--matlab/+caffe/+test/test_solver.m2
-rw-r--r--matlab/+caffe/Net.m2
-rw-r--r--matlab/+caffe/Solver.m2
-rw-r--r--matlab/+caffe/io.m2
-rw-r--r--matlab/+caffe/run_tests.m3
6 files changed, 31 insertions, 4 deletions
diff --git a/matlab/+caffe/+test/test_net.m b/matlab/+caffe/+test/test_net.m
index 5d9ba000..3dabe84d 100644
--- a/matlab/+caffe/+test/test_net.m
+++ b/matlab/+caffe/+test/test_net.m
@@ -48,6 +48,24 @@ classdef test_net < matlab.unittest.TestCase
end
end
methods (Test)
+ function self = test_blob(self)
+ self.net.blobs('data').set_data(10 * ones(self.net.blobs('data').shape));
+ self.verifyEqual(self.net.blobs('data').get_data(), ...
+ 10 * ones(self.net.blobs('data').shape, 'single'));
+ self.net.blobs('data').set_diff(-2 * ones(self.net.blobs('data').shape));
+ self.verifyEqual(self.net.blobs('data').get_diff(), ...
+ -2 * ones(self.net.blobs('data').shape, 'single'));
+ original_shape = self.net.blobs('data').shape;
+ self.net.blobs('data').reshape([6 5 4 3 2 1]);
+ self.verifyEqual(self.net.blobs('data').shape, [6 5 4 3 2 1]);
+ self.net.blobs('data').reshape(original_shape);
+ self.net.reshape();
+ end
+ function self = test_layer(self)
+ self.verifyEqual(self.net.params('conv', 1).shape, [2 2 2 11]);
+ self.verifyEqual(self.net.layers('conv').params(2).shape, 11);
+ self.verifyEqual(self.net.layers('conv').type(), 'Convolution');
+ end
function test_forward_backward(self)
self.net.forward_prefilled();
self.net.backward_prefilled();
@@ -60,13 +78,17 @@ classdef test_net < matlab.unittest.TestCase
weights_file = tempname();
self.net.save(weights_file);
model_file2 = caffe.test.test_net.simple_net_file(self.num_output);
- net2 = caffe.Net(model_file2, weights_file, 'train');
+ net2 = caffe.Net(model_file2, 'train');
+ net2.copy_from(weights_file);
+ net3 = caffe.Net(model_file2, weights_file, 'train');
delete(model_file2);
delete(weights_file);
for l = 1:length(self.net.layer_vec)
for i = 1:length(self.net.layer_vec(l).params)
self.verifyEqual(self.net.layer_vec(l).params(i).get_data(), ...
net2.layer_vec(l).params(i).get_data());
+ self.verifyEqual(self.net.layer_vec(l).params(i).get_data(), ...
+ net3.layer_vec(l).params(i).get_data());
end
end
end
diff --git a/matlab/+caffe/+test/test_solver.m b/matlab/+caffe/+test/test_solver.m
index 682dad48..739258b0 100644
--- a/matlab/+caffe/+test/test_solver.m
+++ b/matlab/+caffe/+test/test_solver.m
@@ -36,6 +36,8 @@ classdef test_solver < matlab.unittest.TestCase
methods (Test)
function test_solve(self)
self.verifyEqual(self.solver.iter(), 0)
+ self.solver.step(30);
+ self.verifyEqual(self.solver.iter(), 30)
self.solver.solve()
self.verifyEqual(self.solver.iter(), 100)
end
diff --git a/matlab/+caffe/Net.m b/matlab/+caffe/Net.m
index a6761060..e6295bba 100644
--- a/matlab/+caffe/Net.m
+++ b/matlab/+caffe/Net.m
@@ -111,7 +111,7 @@ classdef Net < handle
self.blobs(self.outputs{n}).set_diff(output_diff{n});
end
self.backward_prefilled();
- % retrieve diff from input_blobs
+ % retrieve diff from input blobs
res = cell(length(self.inputs), 1);
for n = 1:length(self.inputs)
res{n} = self.blobs(self.inputs{n}).get_diff();
diff --git a/matlab/+caffe/Solver.m b/matlab/+caffe/Solver.m
index daaa8022..f8bdc4e2 100644
--- a/matlab/+caffe/Solver.m
+++ b/matlab/+caffe/Solver.m
@@ -41,7 +41,7 @@ classdef Solver < handle
end
function restore(self, snapshot_filename)
CHECK(ischar(snapshot_filename), 'snapshot_filename must be a string');
- CHECK_FILE_EXIST(snapshot_filename)
+ CHECK_FILE_EXIST(snapshot_filename);
caffe_('solver_restore', self.hSolver_self, snapshot_filename);
end
function solve(self)
diff --git a/matlab/+caffe/io.m b/matlab/+caffe/io.m
index 7a30bfb5..c9e07aee 100644
--- a/matlab/+caffe/io.m
+++ b/matlab/+caffe/io.m
@@ -17,7 +17,7 @@ classdef io
function mean_data = read_mean(mean_proto_file)
% mean_data = read_mean(mean_proto_file)
% read image mean data from binaryproto file
- CHECK(ischar(mean_proto_file), 'im_file must be a string');
+ CHECK(ischar(mean_proto_file), 'mean_proto_file must be a string');
CHECK_FILE_EXIST(mean_proto_file);
mean_data = caffe_('read_mean', mean_proto_file);
end
diff --git a/matlab/+caffe/run_tests.m b/matlab/+caffe/run_tests.m
index 8773c9f6..93896855 100644
--- a/matlab/+caffe/run_tests.m
+++ b/matlab/+caffe/run_tests.m
@@ -2,6 +2,9 @@ function results = run_tests()
% results = run_tests()
% run all tests in this caffe matlab wrapper package
+% use CPU for testing
+caffe.set_mode_cpu();
+
% reset caffe before testing
caffe.reset_all();