summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorJonathan L Long <jonlong@cs.berkeley.edu>2015-01-16 13:20:15 -0800
committerJonathan L Long <jonlong@cs.berkeley.edu>2015-01-19 15:15:04 -0800
commit7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2 (patch)
tree8b1d3821ffebb78823cdb258a8097e7e101e7eb7 /examples
parent1856bb2d89ecbabe84cbf0a647de7a2ab6c29207 (diff)
downloadcaffeonacl-7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2.tar.gz
caffeonacl-7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2.tar.bz2
caffeonacl-7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2.zip
use db wrappers
Diffstat (limited to 'examples')
-rw-r--r--examples/cifar10/convert_cifar_data.cpp41
1 files changed, 21 insertions, 20 deletions
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp
index 9eecc74c..f4c42e4d 100644
--- a/examples/cifar10/convert_cifar_data.cpp
+++ b/examples/cifar10/convert_cifar_data.cpp
@@ -9,19 +9,18 @@
#include <fstream> // NOLINT(readability/streams)
#include <string>
+#include "boost/scoped_ptr.hpp"
#include "glog/logging.h"
#include "google/protobuf/text_format.h"
#include "stdint.h"
-#include "caffe/dataset_factory.hpp"
#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/db.hpp"
-using std::string;
-
-using caffe::Dataset;
-using caffe::DatasetFactory;
using caffe::Datum;
-using caffe::shared_ptr;
+using boost::scoped_ptr;
+using std::string;
+namespace db = caffe::db;
const int kCIFARSize = 32;
const int kCIFARImageNBytes = 3072;
@@ -38,10 +37,9 @@ void read_image(std::ifstream* file, int* label, char* buffer) {
void convert_dataset(const string& input_folder, const string& output_folder,
const string& db_type) {
- shared_ptr<Dataset<string, Datum> > train_dataset =
- DatasetFactory<string, Datum>(db_type);
- CHECK(train_dataset->open(output_folder + "/cifar10_train_" + db_type,
- Dataset<string, Datum>::New));
+ scoped_ptr<db::DB> train_db(db::GetDB(db_type));
+ train_db->Open(output_folder + "/cifar10_train_" + db_type, db::NEW);
+ scoped_ptr<db::Transaction> txn(train_db->NewTransaction());
// Data buffer
int label;
char str_buffer[kCIFARImageNBytes];
@@ -64,17 +62,18 @@ void convert_dataset(const string& input_folder, const string& output_folder,
datum.set_data(str_buffer, kCIFARImageNBytes);
int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d",
fileid * kCIFARBatchSize + itemid);
- CHECK(train_dataset->put(string(str_buffer, length), datum));
+ string out;
+ CHECK(datum.SerializeToString(&out));
+ txn->Put(string(str_buffer, length), out);
}
}
- CHECK(train_dataset->commit());
- train_dataset->close();
+ txn->Commit();
+ train_db->Close();
LOG(INFO) << "Writing Testing data";
- shared_ptr<Dataset<string, Datum> > test_dataset =
- DatasetFactory<string, Datum>(db_type);
- CHECK(test_dataset->open(output_folder + "/cifar10_test_" + db_type,
- Dataset<string, Datum>::New));
+ scoped_ptr<db::DB> test_db(db::GetDB(db_type));
+ test_db->Open(output_folder + "/cifar10_test_" + db_type, db::NEW);
+ txn.reset(test_db->NewTransaction());
// Open files
std::ifstream data_file((input_folder + "/test_batch.bin").c_str(),
std::ios::in | std::ios::binary);
@@ -84,10 +83,12 @@ void convert_dataset(const string& input_folder, const string& output_folder,
datum.set_label(label);
datum.set_data(str_buffer, kCIFARImageNBytes);
int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid);
- CHECK(test_dataset->put(string(str_buffer, length), datum));
+ string out;
+ CHECK(datum.SerializeToString(&out));
+ txn->Put(string(str_buffer, length), out);
}
- CHECK(test_dataset->commit());
- test_dataset->close();
+ txn->Commit();
+ test_db->Close();
}
int main(int argc, char** argv) {