/** * Copyright (c) 2016-present, Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "caffe2/core/db.h" #include "caffe2/core/init.h" #include "caffe2/proto/caffe2_pb.h" #include "caffe2/core/logging.h" C10_DEFINE_string(input_db, "", "The input db."); C10_DEFINE_int(splits, 0, "The number of splits."); C10_DEFINE_string(db_type, "", "The db type."); C10_DEFINE_int(batch_size, 1000, "The write batch size."); namespace caffe2 { static int Split(int argc, char** argv) { GlobalInit(&argc, &argv); CAFFE_ENFORCE(FLAGS_input_db.size(), "Must specify --input_db=/path/to/db."); CAFFE_ENFORCE(FLAGS_splits > 0, "Must specify a nonnegative split number."); CAFFE_ENFORCE(FLAGS_db_type.size(), "Must specify --db_type=[a db type]."); unique_ptr in_db( db::CreateDB(FLAGS_db_type, FLAGS_input_db, db::READ)); CAFFE_ENFORCE(in_db != nullptr, "Cannot open input db: ", FLAGS_input_db); unique_ptr cursor(in_db->NewCursor()); // This usually won't happen, but FWIW. CAFFE_ENFORCE( cursor != nullptr, "Cannot obtain cursor for input db: ", FLAGS_input_db); vector> out_dbs; vector> transactions; for (int i = 0; i < FLAGS_splits; ++i) { out_dbs.push_back(unique_ptr(db::CreateDB( FLAGS_db_type, FLAGS_input_db + "_split_" + to_string(i), db::NEW))); CAFFE_ENFORCE(out_dbs.back().get(), "Cannot create output db #", i); transactions.push_back( unique_ptr(out_dbs[i]->NewTransaction())); CAFFE_ENFORCE( transactions.back().get(), "Cannot get transaction for output db #", i); } int count = 0; for (; cursor->Valid(); cursor->Next()) { transactions[count % FLAGS_splits]->Put(cursor->key(), cursor->value()); if (++count % FLAGS_batch_size == 0) { for (int i = 0; i < FLAGS_splits; ++i) { transactions[i]->Commit(); } LOG(INFO) << "Split " << count << " items so far."; } } LOG(INFO) << "A total of " << count << " items processed."; return 0; } } // namespace caffe2 int main(int argc, char** argv) { return caffe2::Split(argc, argv); }