diff options
-rw-r--r-- | examples/cifar10/convert_cifar_data.cpp | 4 | ||||
-rw-r--r-- | include/caffe/database.hpp | 4 | ||||
-rw-r--r-- | include/caffe/leveldb_database.hpp | 4 | ||||
-rw-r--r-- | include/caffe/lmdb_database.hpp | 4 | ||||
-rw-r--r-- | src/caffe/leveldb_database.cpp | 10 | ||||
-rw-r--r-- | src/caffe/lmdb_database.cpp | 22 | ||||
-rw-r--r-- | src/caffe/test/test_data_layer.cpp | 2 | ||||
-rw-r--r-- | src/caffe/test/test_database.cpp | 80 | ||||
-rw-r--r-- | tools/convert_imageset.cpp | 2 | ||||
-rw-r--r-- | tools/extract_features.cpp | 2 |
10 files changed, 70 insertions, 64 deletions
diff --git a/examples/cifar10/convert_cifar_data.cpp b/examples/cifar10/convert_cifar_data.cpp index b29e4121..af845ead 100644 --- a/examples/cifar10/convert_cifar_data.cpp +++ b/examples/cifar10/convert_cifar_data.cpp @@ -66,7 +66,7 @@ void convert_dataset(const string& input_folder, const string& output_folder, int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", fileid * kCIFARBatchSize + itemid); Database::buffer_t key(str_buffer, str_buffer + length); - CHECK(train_database->put(&key, &value)); + CHECK(train_database->put(key, value)); } } CHECK(train_database->commit()); @@ -89,7 +89,7 @@ void convert_dataset(const string& input_folder, const string& output_folder, reinterpret_cast<unsigned char*>(value.data())); int length = snprintf(str_buffer, kCIFARImageNBytes, "%05d", itemid); Database::buffer_t key(str_buffer, str_buffer + length); - CHECK(test_database->put(&key, &value)); + CHECK(test_database->put(key, value)); } CHECK(test_database->commit()); test_database->close(); diff --git a/include/caffe/database.hpp b/include/caffe/database.hpp index 148b1ed7..3f3970d7 100644 --- a/include/caffe/database.hpp +++ b/include/caffe/database.hpp @@ -27,8 +27,8 @@ class Database { }; virtual bool open(const string& filename, Mode mode) = 0; - virtual bool put(buffer_t* key, buffer_t* value) = 0; - virtual bool get(buffer_t* key, buffer_t* value) = 0; + virtual bool put(const buffer_t& key, const buffer_t& value) = 0; + virtual bool get(const buffer_t& key, buffer_t* value) = 0; virtual bool commit() = 0; virtual void close() = 0; diff --git a/include/caffe/leveldb_database.hpp b/include/caffe/leveldb_database.hpp index 64bfa7ce..e2558ff4 100644 --- a/include/caffe/leveldb_database.hpp +++ b/include/caffe/leveldb_database.hpp @@ -15,8 +15,8 @@ namespace caffe { class LeveldbDatabase : public Database { public: bool open(const string& filename, Mode mode); - bool put(buffer_t* key, buffer_t* value); - bool get(buffer_t* key, buffer_t* value); + bool put(const buffer_t& key, const buffer_t& value); + bool get(const buffer_t& key, buffer_t* value); bool commit(); void close(); diff --git a/include/caffe/lmdb_database.hpp b/include/caffe/lmdb_database.hpp index 69e3ce0f..4a0f3183 100644 --- a/include/caffe/lmdb_database.hpp +++ b/include/caffe/lmdb_database.hpp @@ -19,8 +19,8 @@ class LmdbDatabase : public Database { txn_(NULL) { } bool open(const string& filename, Mode mode); - bool put(buffer_t* key, buffer_t* value); - bool get(buffer_t* key, buffer_t* value); + bool put(const buffer_t& key, const buffer_t& value); + bool get(const buffer_t& key, buffer_t* value); bool commit(); void close(); diff --git a/src/caffe/leveldb_database.cpp b/src/caffe/leveldb_database.cpp index d7506edf..c09112f6 100644 --- a/src/caffe/leveldb_database.cpp +++ b/src/caffe/leveldb_database.cpp @@ -51,7 +51,7 @@ bool LeveldbDatabase::open(const string& filename, Mode mode) { return true; } -bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) { +bool LeveldbDatabase::put(const buffer_t& key, const buffer_t& value) { LOG(INFO) << "LevelDB: Put"; if (read_only_) { @@ -61,18 +61,18 @@ bool LeveldbDatabase::put(buffer_t* key, buffer_t* value) { CHECK_NOTNULL(batch_.get()); - leveldb::Slice key_slice(key->data(), key->size()); - leveldb::Slice value_slice(value->data(), value->size()); + leveldb::Slice key_slice(key.data(), key.size()); + leveldb::Slice value_slice(value.data(), value.size()); batch_->Put(key_slice, value_slice); return true; } -bool LeveldbDatabase::get(buffer_t* key, buffer_t* value) { +bool LeveldbDatabase::get(const buffer_t& key, buffer_t* value) { LOG(INFO) << "LevelDB: Get"; - leveldb::Slice key_slice(key->data(), key->size()); + leveldb::Slice key_slice(key.data(), key.size()); string value_string; leveldb::Status status = diff --git a/src/caffe/lmdb_database.cpp b/src/caffe/lmdb_database.cpp index d71513a5..2cb699b4 100644 --- a/src/caffe/lmdb_database.cpp +++ b/src/caffe/lmdb_database.cpp @@ -77,14 +77,18 @@ bool LmdbDatabase::open(const string& filename, Mode mode) { return true; } -bool LmdbDatabase::put(buffer_t* key, buffer_t* value) { +bool LmdbDatabase::put(const buffer_t& key, const buffer_t& value) { LOG(INFO) << "LMDB: Put"; + // MDB_val::mv_size is not const, so we need to make a local copy. + buffer_t local_key = key; + buffer_t local_value = value; + MDB_val mdbkey, mdbdata; - mdbdata.mv_size = value->size(); - mdbdata.mv_data = value->data(); - mdbkey.mv_size = key->size(); - mdbkey.mv_data = key->data(); + mdbdata.mv_size = local_value.size(); + mdbdata.mv_data = local_value.data(); + mdbkey.mv_size = local_key.size(); + mdbkey.mv_data = local_key.data(); CHECK_NOTNULL(txn_); CHECK_NE(0, dbi_); @@ -98,12 +102,14 @@ bool LmdbDatabase::put(buffer_t* key, buffer_t* value) { return true; } -bool LmdbDatabase::get(buffer_t* key, buffer_t* value) { +bool LmdbDatabase::get(const buffer_t& key, buffer_t* value) { LOG(INFO) << "LMDB: Get"; + buffer_t local_key = key; + MDB_val mdbkey, mdbdata; - mdbkey.mv_data = key->data(); - mdbkey.mv_size = key->size(); + mdbkey.mv_data = local_key.data(); + mdbkey.mv_size = local_key.size(); int retval; MDB_txn* get_txn; diff --git a/src/caffe/test/test_data_layer.cpp b/src/caffe/test/test_data_layer.cpp index cc9ad204..8ae8f2b8 100644 --- a/src/caffe/test/test_data_layer.cpp +++ b/src/caffe/test/test_data_layer.cpp @@ -59,7 +59,7 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> { Database::buffer_t value(datum.ByteSize()); datum.SerializeWithCachedSizesToArray( reinterpret_cast<unsigned char*>(value.data())); - CHECK(database->put(&key, &value)); + CHECK(database->put(key, value)); } CHECK(database->commit()); database->close(); diff --git a/src/caffe/test/test_database.cpp b/src/caffe/test/test_database.cpp index f6586501..70e1d962 100644 --- a/src/caffe/test/test_database.cpp +++ b/src/caffe/test/test_database.cpp @@ -126,7 +126,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLevelDB) { string value = ss.str(); Database::buffer_t key_buf(key.data(), key.data() + key.size()); Database::buffer_t val_buf(value.data(), value.data() + value.size()); - EXPECT_TRUE(database->put(&key_buf, &val_buf)); + EXPECT_TRUE(database->put(key_buf, val_buf)); } EXPECT_TRUE(database->commit()); @@ -151,8 +151,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLevelDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -190,8 +190,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLevelDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -229,7 +229,7 @@ TYPED_TEST(DatabaseTest, TestNewPutLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -254,13 +254,13 @@ TYPED_TEST(DatabaseTest, TestNewGetLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -275,11 +275,11 @@ TYPED_TEST(DatabaseTest, TestNewGetNoCommitLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } @@ -291,7 +291,7 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -316,13 +316,13 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -337,11 +337,11 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) { @@ -355,7 +355,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyPutLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_FALSE(database->put(&key, &val)); + EXPECT_FALSE(database->put(key, val)); } TYPED_TEST(DatabaseTest, TestReadOnlyCommitLevelDBFails) { @@ -377,7 +377,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -387,7 +387,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLevelDBPasses) { Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); } @@ -400,7 +400,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); database->close(); @@ -408,7 +408,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLevelDBFails) { Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestNewDoesntExistLMDBPasses) { @@ -473,7 +473,7 @@ TYPED_TEST(DatabaseTest, TestIteratorsLMDB) { string value = ss.str(); Database::buffer_t key_buf(key.data(), key.data() + key.size()); Database::buffer_t val_buf(value.data(), value.data() + value.size()); - EXPECT_TRUE(database->put(&key_buf, &val_buf)); + EXPECT_TRUE(database->put(key_buf, val_buf)); } EXPECT_TRUE(database->commit()); @@ -498,8 +498,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPreIncrementLMDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -537,8 +537,8 @@ TYPED_TEST(DatabaseTest, TestIteratorsPostIncrementLMDB) { Database::buffer_t key2 = this->TestKey(); Database::buffer_t value2 = this->TestValue(); - EXPECT_TRUE(database->put(&key1, &value1)); - EXPECT_TRUE(database->put(&key2, &value2)); + EXPECT_TRUE(database->put(key1, value1)); + EXPECT_TRUE(database->put(key2, value2)); EXPECT_TRUE(database->commit()); Database::const_iterator iter1 = database->begin(); @@ -576,7 +576,7 @@ TYPED_TEST(DatabaseTest, TestNewPutLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -601,13 +601,13 @@ TYPED_TEST(DatabaseTest, TestNewGetLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -622,11 +622,11 @@ TYPED_TEST(DatabaseTest, TestNewGetNoCommitLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) { @@ -637,7 +637,7 @@ TYPED_TEST(DatabaseTest, TestReadWritePutLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -662,13 +662,13 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); @@ -683,11 +683,11 @@ TYPED_TEST(DatabaseTest, TestReadWriteGetNoCommitLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) { @@ -701,7 +701,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyPutLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_FALSE(database->put(&key, &val)); + EXPECT_FALSE(database->put(key, val)); } TYPED_TEST(DatabaseTest, TestReadOnlyCommitLMDBFails) { @@ -723,7 +723,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); EXPECT_TRUE(database->commit()); @@ -733,7 +733,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetLMDBPasses) { Database::buffer_t new_val; - EXPECT_TRUE(database->get(&key, &new_val)); + EXPECT_TRUE(database->get(key, &new_val)); EXPECT_TRUE(this->BufferEq(val, new_val)); } @@ -746,7 +746,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { Database::buffer_t key = this->TestKey(); Database::buffer_t val = this->TestValue(); - EXPECT_TRUE(database->put(&key, &val)); + EXPECT_TRUE(database->put(key, val)); database->close(); @@ -754,7 +754,7 @@ TYPED_TEST(DatabaseTest, TestReadOnlyGetNoCommitLMDBFails) { Database::buffer_t new_val; - EXPECT_FALSE(database->get(&key, &new_val)); + EXPECT_FALSE(database->get(key, &new_val)); } } // namespace caffe diff --git a/tools/convert_imageset.cpp b/tools/convert_imageset.cpp index 3345c9c7..5ad2c0b4 100644 --- a/tools/convert_imageset.cpp +++ b/tools/convert_imageset.cpp @@ -116,7 +116,7 @@ int main(int argc, char** argv) { Database::buffer_t keystr(key_cstr, key_cstr + length); // Put in db - CHECK(database->put(&keystr, &value)); + CHECK(database->put(keystr, value)); if (++count % 1000 == 0) { // Commit txn diff --git a/tools/extract_features.cpp b/tools/extract_features.cpp index 1560ef60..0c7660d5 100644 --- a/tools/extract_features.cpp +++ b/tools/extract_features.cpp @@ -161,7 +161,7 @@ int feature_extraction_pipeline(int argc, char** argv) { int length = snprintf(key_str, kMaxKeyStrLength, "%d", image_indices[i]); Database::buffer_t key(key_str, key_str + length); - CHECK(feature_dbs.at(i)->put(&key, &value)); + CHECK(feature_dbs.at(i)->put(key, value)); ++image_indices[i]; if (image_indices[i] % 1000 == 0) { CHECK(feature_dbs.at(i)->commit()); |