diff options
author | Jonathan L Long <jonlong@cs.berkeley.edu> | 2015-01-16 13:20:15 -0800 |
---|---|---|
committer | Jonathan L Long <jonlong@cs.berkeley.edu> | 2015-01-19 15:15:04 -0800 |
commit | 7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2 (patch) | |
tree | 8b1d3821ffebb78823cdb258a8097e7e101e7eb7 /examples | |
parent | 1856bb2d89ecbabe84cbf0a647de7a2ab6c29207 (diff) | |
download | caffeonacl-7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2.tar.gz caffeonacl-7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2.tar.bz2 caffeonacl-7dfe23963c69759c3bdca1f0ee8ff9866c0fb5c2.zip |
use db wrappers
Diffstat (limited to 'examples')
-rw-r--r-- | examples/cifar10/convert_cifar_data.cpp | 41 |
1 files changed, 21 insertions, 20 deletions
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index 9eecc74c..f4c42e4d 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -9,19 +9,18 @@ #include <fstream> // NOLINT(readability/streams) #include <string> +#include "boost/scoped_ptr.hpp" #include "glog/logging.h" #include "google/protobuf/text_format.h" #include "stdint.h" -#include "caffe/dataset_factory.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/db.hpp" -using std::string; - -using caffe::Dataset; -using caffe::DatasetFactory; using caffe::Datum; -using caffe::shared_ptr; +using boost::scoped_ptr; +using std::string; +namespace db = caffe::db; const int kCIFARSize = 32; const int kCIFARImageNBytes = 3072; @@ -38,10 +37,9 @@ void read_image(std::ifstream* file, int* label, char* buffer) { void convert_dataset(const string& input_folder, const string& output_folder, const string& db_type) { - shared_ptr<Dataset<string, Datum> > train_dataset = - DatasetFactory<string, Datum>(db_type); - CHECK(train_dataset->open(output_folder + "/cifar10_train_" + db_type, - Dataset<string, Datum>::New)); + scoped_ptr<db::DB> train_db(db::GetDB(db_type)); + train_db->Open(output_folder + "/cifar10_train_" + db_type, db::NEW); + scoped_ptr<db::Transaction> txn(train_db->NewTransaction()); // Data buffer int label; char str_buffer[kCIFARImageNBytes]; @@ -64,17 +62,18 @@ void convert_dataset(const string& input_folder, const string& output_folder, datum.set_data(str_buffer, kCIFARImageNBytes); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); - CHECK(train_dataset->put(string(str_buffer, length), datum)); + string out; + CHECK(datum.SerializeToString(&out)); + txn->Put(string(str_buffer, length), out); } } - CHECK(train_dataset->commit()); - train_dataset->close(); + txn->Commit(); + train_db->Close(); LOG(INFO) << "Writing Testing data"; - shared_ptr<Dataset<string, Datum> > test_dataset = - DatasetFactory<string, Datum>(db_type); - CHECK(test_dataset->open(output_folder + "/cifar10_test_" + db_type, - Dataset<string, Datum>::New)); + scoped_ptr<db::DB> test_db(db::GetDB(db_type)); + test_db->Open(output_folder + "/cifar10_test_" + db_type, db::NEW); + txn.reset(test_db->NewTransaction()); // Open files std::ifstream data_file((input_folder + "/test_batch.bin").c_str(), std::ios::in | std::ios::binary); @@ -84,10 +83,12 @@ void convert_dataset(const string& input_folder, const string& output_folder, datum.set_label(label); datum.set_data(str_buffer, kCIFARImageNBytes); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); - CHECK(test_dataset->put(string(str_buffer, length), datum)); + string out; + CHECK(datum.SerializeToString(&out)); + txn->Put(string(str_buffer, length), out); } - CHECK(test_dataset->commit()); - test_dataset->close(); + txn->Commit(); + test_db->Close(); } int main(int argc, char** argv) { |