// Copyright 2014 BVLC and contributors.
// This program converts a set of images to a leveldb by storing them as Datum
// protocol buffers.
// Usage:
//   convert_imageset ROOTFOLDER/ LISTFILE DB_NAME RANDOM_SHUFFLE[0 or 1] \
//     [resize_height] [resize_width]
// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE
// is a list of image paths and their labels, one per line, in the format
//   subfolder1/file1.JPEG 7
//   ....
// If RANDOM_SHUFFLE is 1, the lines are shuffled randomly before they are
// processed.
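//
// Example invocation (hypothetical paths; resizes every image to 256x256):
//   convert_imageset /path/to/imagenet/train/ train_list.txt \
//     imagenet_train_leveldb 1 256 256
// Each image is stored as a serialized Datum under a key of the form
// "%08d_<image path>", where %08d is the (possibly shuffled) line index.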
#include <glog/logging.h>
#include <leveldb/db.h>
#include <leveldb/write_batch.h>

#include <algorithm>
#include <cstdio>   // printf, snprintf
#include <cstdlib>  // atoi
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"
using namespace caffe; // NOLINT(build/namespaces)
using std::pair;
using std::string;
int main(int argc, char** argv) {
  ::google::InitGoogleLogging(argv[0]);
  if (argc < 4 || argc > 7) {
    printf("Convert a set of images to the leveldb format used\n"
        "as input for Caffe.\n"
        "Usage:\n"
        "    convert_imageset ROOTFOLDER/ LISTFILE DB_NAME"
        " RANDOM_SHUFFLE_DATA[0 or 1] [resize_height] [resize_width]\n"
        "The ImageNet dataset for the training demo is at\n"
        "    http://www.image-net.org/download-images\n");
    return 1;
  }
  // Read the list of (image path, label) pairs.
  std::ifstream infile(argv[2]);
  std::vector<std::pair<string, int> > lines;
  string filename;
  int label;
  while (infile >> filename >> label) {
    lines.push_back(std::make_pair(filename, label));
  }
  if (argc >= 5 && argv[4][0] == '1') {
    // Randomly shuffle the lines so the database is not ordered by class.
    LOG(INFO) << "Shuffling data";
    std::random_shuffle(lines.begin(), lines.end());
  }
  LOG(INFO) << "A total of " << lines.size() << " images.";

  int resize_height = 0;
  int resize_width = 0;
  if (argc >= 6) {
    resize_height = atoi(argv[5]);
  }
  if (argc >= 7) {
    resize_width = atoi(argv[6]);
  }

  // Open the output leveldb; refuse to overwrite an existing database.
  leveldb::DB* db;
  leveldb::Options options;
  options.error_if_exists = true;
  options.create_if_missing = true;
  options.write_buffer_size = 268435456;  // 256 MB write buffer
  LOG(INFO) << "Opening leveldb " << argv[3];
  leveldb::Status status = leveldb::DB::Open(
      options, argv[3], &db);
  CHECK(status.ok()) << "Failed to open leveldb " << argv[3];

  string root_folder(argv[1]);
  Datum datum;
  int count = 0;
  const int kMaxKeyLength = 256;
  char key_cstr[kMaxKeyLength];
  leveldb::WriteBatch* batch = new leveldb::WriteBatch();
  int data_size = 0;
  bool data_size_initialized = false;
  for (int line_id = 0; line_id < static_cast<int>(lines.size()); ++line_id) {
    if (!ReadImageToDatum(root_folder + lines[line_id].first,
        lines[line_id].second, resize_height, resize_width, &datum)) {
      continue;
    }
    if (!data_size_initialized) {
      data_size = datum.channels() * datum.height() * datum.width();
      data_size_initialized = true;
    } else {
      const string& data = datum.data();
      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "
          << data.size();
    }
    // Build a sequential key: 8-digit line index followed by the image path.
    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,
        lines[line_id].first.c_str());
    string value;
    // Serialize the Datum into the value string.
    datum.SerializeToString(&value);
    batch->Put(string(key_cstr), value);
    // Commit every 1000 images so the write batch stays small.
    if (++count % 1000 == 0) {
      db->Write(leveldb::WriteOptions(), batch);
      LOG(ERROR) << "Processed " << count << " files.";
      delete batch;
      batch = new leveldb::WriteBatch();
    }
  }
  // Write the last, partially filled batch.
  if (count % 1000 != 0) {
    db->Write(leveldb::WriteOptions(), batch);
    LOG(ERROR) << "Processed " << count << " files.";
  }
  delete batch;
  delete db;
  return 0;
}