summaryrefslogtreecommitdiff
path: root/src/caffe/test/test_net_proto.cpp
blob: 38b5b68de95ff3a363ef785428c292ac6d5c4ccb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Copyright 2013 Yangqing Jia

#include <cuda_runtime.h>
#include <fcntl.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <gtest/gtest.h>

#include <cstring>

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/net.hpp"
#include "caffe/filler.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

// Empty typed fixture. TYPED_TEST_CASE below registers it so that every
// TYPED_TEST(NetProtoTest, ...) is compiled and run once with
// TypeParam = float and once with TypeParam = double.
template <typename Dtype>
struct NetProtoTest : ::testing::Test {
};

typedef ::testing::Types<float, double> Dtypes;
TYPED_TEST_CASE(NetProtoTest, Dtypes);

// Parses data/lenet.prototxt and checks the layer/input counts in the
// resulting NetParameter. It then builds a Net from it (with no bottom blobs)
// and checks the number of layer and blob names. Finally it runs one
// Forward/Backward pass in CPU mode and one in GPU mode, logging the value
// returned by Backward(). The global Caffe mode is restored afterwards.
TYPED_TEST(NetProtoTest, TestSetup) {
  // Expected counts for data/lenet.prototxt.
  const int kNumLayers = 10;
  const int kNumBlobs = 10;

  NetParameter net_param;
  ReadProtoFromTextFile("data/lenet.prototxt", &net_param);
  // Check that the prototxt was parsed as expected.
  EXPECT_EQ(net_param.layers_size(), kNumLayers);
  EXPECT_EQ(net_param.input_size(), 0);

  // The parsed definition has no inputs (checked above), so no bottom blobs
  // are supplied.
  vector<Blob<TypeParam>*> bottom_vec;

  Net<TypeParam> caffe_net(net_param, bottom_vec);
  // size() is unsigned; compare against size_t to avoid -Wsign-compare.
  EXPECT_EQ(caffe_net.layer_names().size(), static_cast<size_t>(kNumLayers));
  EXPECT_EQ(caffe_net.blob_names().size(), static_cast<size_t>(kNumBlobs));

  // Per-blob shape dump, only emitted when running with --v=1 or higher.
  for (size_t i = 0; i < caffe_net.blobs().size(); ++i) {
    VLOG(1) << "Blob: " << caffe_net.blob_names()[i];
    VLOG(1) << "size: " << caffe_net.blobs()[i]->num() << ", "
        << caffe_net.blobs()[i]->channels() << ", "
        << caffe_net.blobs()[i]->height() << ", "
        << caffe_net.blobs()[i]->width();
  }

  // Caffe's mode is process-wide state: remember it so that switching to GPU
  // here does not leak into tests that run after this one.
  const Caffe::Brew original_mode = Caffe::mode();
  const Caffe::Brew modes[] = {Caffe::CPU, Caffe::GPU};
  const size_t num_modes = sizeof(modes) / sizeof(modes[0]);
  for (size_t m = 0; m < num_modes; ++m) {
    Caffe::set_mode(modes[m]);
    // Run the network without training.
    LOG(ERROR) << "Performing Forward";
    caffe_net.Forward(bottom_vec);
    LOG(ERROR) << "Performing Backward";
    LOG(ERROR) << caffe_net.Backward();
  }
  Caffe::set_mode(original_mode);
}

}  // namespace caffe