summaryrefslogtreecommitdiff
path: root/tools/upgrade_net_proto_text.cpp
blob: 9200431bc278562de4ad898ac55ef8459f1fb8bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
// This tool upgrades "V0" network prototxts, and prototxts that use deprecated
// data transformation parameters, to the current format.
// Usage:
//    upgrade_net_proto_text v0_net_proto_file_in net_proto_file_out

#include <cstring>
#include <fstream>  // NOLINT(readability/streams)
#include <iostream>  // NOLINT(readability/streams)
#include <string>

#include "caffe/caffe.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"

using std::ofstream;

using namespace caffe;  // NOLINT(build/namespaces)

int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc != 3) {
    LOG(ERROR) << "Usage: "
        << "upgrade_net_proto_text v0_net_proto_file_in net_proto_file_out";
    return 1;
  }

  NetParameter net_param;
  string input_filename(argv[1]);
  if (!ReadProtoFromTextFile(input_filename, &net_param)) {
    LOG(ERROR) << "Failed to parse input text file as NetParameter: "
               << input_filename;
    return 2;
  }
  bool need_upgrade = NetNeedsUpgrade(net_param);
  bool need_data_upgrade = NetNeedsDataUpgrade(net_param);
  bool success = true;
  if (need_upgrade) {
    success = UpgradeNetAsNeeded(input_filename, &net_param);
    if (!success) {
      LOG(ERROR) << "Encountered error(s) while upgrading prototxt; "
                 << "see details above.";
    }
  } else {
    LOG(ERROR) << "File already in latest proto format: " << input_filename;
  }

  if (need_data_upgrade) {
    UpgradeNetDataTransformation(&net_param);
  }

  // Save new format prototxt.
  WriteProtoToTextFile(net_param, argv[2]);

  LOG(ERROR) << "Wrote upgraded NetParameter text proto to " << argv[2];
  return !success;
}