summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorYangqing Jia <jiayq84@gmail.com>2015-07-17 11:24:46 -0700
committerYangqing Jia <jiayq84@gmail.com>2015-07-18 07:23:09 -0700
commit47c70a43b4821fe0848345a3f6abf2c197669f30 (patch)
tree75c9379969402bebd57efcda1f5c1c10d54b689c /caffe2
parentc5166e578c90db43c2d479446673efbd3e99b01e (diff)
downloadpytorch-47c70a43b4821fe0848345a3f6abf2c197669f30.tar.gz
pytorch-47c70a43b4821fe0848345a3f6abf2c197669f30.tar.bz2
pytorch-47c70a43b4821fe0848345a3f6abf2c197669f30.zip
(1) minidb bugfix
(2) blob serialization comments (3) cudnn: putting it under a separate device name so we can explicitly choose cudnn instead of having CUDA device prioritizing it. (4) note that mint is not available with ipython due to zeromq conflict (5) db_throughput utility (6) added gprofiler
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/binaries/BREW38
-rw-r--r--caffe2/binaries/db_throughput.cc48
-rw-r--r--caffe2/core/blob.h14
-rw-r--r--caffe2/core/blob_serialization.cc6
-rw-r--r--caffe2/core/minidb.cc15
-rw-r--r--caffe2/core/operator.cc8
-rw-r--r--caffe2/proto/caffe2.proto3
7 files changed, 102 insertions, 30 deletions
diff --git a/caffe2/binaries/BREW b/caffe2/binaries/BREW
index 46680ead3a..834755c08a 100644
--- a/caffe2/binaries/BREW
+++ b/caffe2/binaries/BREW
@@ -25,43 +25,56 @@ cc_binary(
)
cc_binary(
- name = "make_cifar_db",
+ name = "convert_encoded_to_raw_leveldb",
srcs = [
- "make_cifar_db.cc",
+ "convert_encoded_to_raw_leveldb.cc",
],
deps = [
- "//caffe2/db:db",
+ "//caffe2/core:core",
"//caffe2/proto:caffe2_proto",
+ "//third_party/leveldb:leveldb",
"//third_party/gflags:gflags",
"//third_party/glog:glog",
+ "//third_party/opencv:opencv_core",
+ "//third_party/opencv:opencv_highgui",
+ "//third_party/opencv:opencv_imgproc",
],
)
cc_binary(
- name = "make_image_db",
+ name = "db_throughput",
srcs = [
- "make_image_db.cc",
+ "db_throughput.cc",
+ ],
+ deps = [
+ "//caffe2/db:db",
+ "//third_party/gflags:gflags",
+ "//third_party/glog:glog",
+ "//third_party/google:profiler",
+ ],
+)
+
+cc_binary(
+ name = "make_cifar_db",
+ srcs = [
+ "make_cifar_db.cc",
],
deps = [
"//caffe2/db:db",
"//caffe2/proto:caffe2_proto",
"//third_party/gflags:gflags",
"//third_party/glog:glog",
- "//third_party/opencv:opencv_core",
- "//third_party/opencv:opencv_highgui",
- "//third_party/opencv:opencv_imgproc",
],
)
cc_binary(
- name = "convert_encoded_to_raw_leveldb",
+ name = "make_image_db",
srcs = [
- "convert_encoded_to_raw_leveldb.cc",
+ "make_image_db.cc",
],
deps = [
- "//caffe2/core:core",
+ "//caffe2/db:db",
"//caffe2/proto:caffe2_proto",
- "//third_party/leveldb:leveldb",
"//third_party/gflags:gflags",
"//third_party/glog:glog",
"//third_party/opencv:opencv_core",
@@ -70,7 +83,6 @@ cc_binary(
],
)
-
cc_binary(
name = "make_mnist_db",
srcs = [
diff --git a/caffe2/binaries/db_throughput.cc b/caffe2/binaries/db_throughput.cc
new file mode 100644
index 0000000000..7c9ae1a5bb
--- /dev/null
+++ b/caffe2/binaries/db_throughput.cc
@@ -0,0 +1,48 @@
+#include <ctime>
+#include <cstdio>
+
+#include "caffe2/core/db.h"
+#include "caffe2/proto/caffe2.pb.h"
+#include "gflags/gflags.h"
+#include "glog/logging.h"
+#include "google/profiler.h"
+
+DEFINE_string(input_db, "", "The input db.");
+DEFINE_string(input_db_type, "", "The input db type.");
+DEFINE_string(profile_file, "db_throughput_profile", "The profile output.");
+DEFINE_int32(report_interval, 1000, "The report interval.");
+DEFINE_int32(repeat, 10, "The number to repeat the throughput test.");
+
+using caffe2::db::Cursor;
+using caffe2::db::DB;
+using caffe2::string;
+
+int main(int argc, char** argv) {
+ google::InitGoogleLogging(argv[0]);
+ gflags::SetUsageMessage(
+ "This script reports the throughput .");
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ std::unique_ptr<DB> in_db(caffe2::db::CreateDB(
+ FLAGS_input_db_type, FLAGS_input_db, caffe2::db::READ));
+ std::unique_ptr<Cursor> cursor(in_db->NewCursor());
+
+ ProfilerStart(FLAGS_profile_file.c_str());
+ for (int iter_id = 0; iter_id < FLAGS_repeat; ++iter_id) {
+ clock_t start = clock();
+ for (int i = 0; i < FLAGS_report_interval; ++i) {
+ volatile string key = cursor->key();
+ volatile string value = cursor->value();
+ cursor->Next();
+ if (!cursor->Valid()) {
+ cursor->SeekToFirst();
+ }
+ }
+ clock_t elapsed = clock() - start;
+ double elapsed_seconds = static_cast<double>(elapsed) / CLOCKS_PER_SEC;
+ printf("Iteration %03d, took %4.5f seconds, throughput %f items/sec.\n",
+ iter_id, elapsed_seconds, FLAGS_report_interval / elapsed_seconds);
+ }
+ ProfilerStop();
+ return 0;
+}
diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h
index c14a273a13..5a5c8cd5ee 100644
--- a/caffe2/core/blob.h
+++ b/caffe2/core/blob.h
@@ -66,7 +66,7 @@ class Blob {
// Serializes the current blob, if possible. This serialization uses
// registration so we don't need to deal with multiple platform problems.
- string Serialize(const string& name) const;
+ inline string Serialize(const string& name) const;
private:
internal::TypeId id_;
@@ -76,12 +76,15 @@ class Blob {
DISABLE_COPY_AND_ASSIGN(Blob);
};
-// BlobSerializer is a class that serializes a blob to a string.
+// BlobSerializerBase is a class that serializes a blob to a string. This class
+// exists purely for the purpose of registering type-specific serialization
+// code.
class BlobSerializerBase {
public:
virtual string Serialize(const Blob& blob, const string& name) = 0;
};
+// THe Blob serialization registry and serializer creator functions.
DECLARE_TYPED_REGISTRY(BlobSerializerRegistry, internal::TypeId,
BlobSerializerBase);
#define REGISTER_BLOB_SERIALIZER(name, id, ...) \
@@ -91,6 +94,13 @@ inline BlobSerializerBase* CreateSerializer(internal::TypeId id) {
return BlobSerializerRegistry()->Create(id);
}
+// The blob serialization member function implementation.
+inline string Blob::Serialize(const string& name) const {
+ std::unique_ptr<BlobSerializerBase> serializer(CreateSerializer(id_));
+ return serializer->Serialize(*this, name);
+}
+
+
template <typename dtype, class Context>
class Tensor {
public:
diff --git a/caffe2/core/blob_serialization.cc b/caffe2/core/blob_serialization.cc
index d44e3c25b4..18e9fc3b2b 100644
--- a/caffe2/core/blob_serialization.cc
+++ b/caffe2/core/blob_serialization.cc
@@ -5,12 +5,6 @@ namespace caffe2 {
DEFINE_TYPED_REGISTRY(BlobSerializerRegistry, internal::TypeId,
BlobSerializerBase);
-string Blob::Serialize(const string& name) const {
- std::unique_ptr<BlobSerializerBase> serializer(CreateSerializer(id_));
- return serializer->Serialize(*this, name);
-}
-
-
namespace {
REGISTER_BLOB_SERIALIZER(float_cpu,
(internal::GetTypeId<Tensor<float, CPUContext> >()),
diff --git a/caffe2/core/minidb.cc b/caffe2/core/minidb.cc
index 3577fc92d3..744c1b7806 100644
--- a/caffe2/core/minidb.cc
+++ b/caffe2/core/minidb.cc
@@ -10,7 +10,10 @@ namespace db {
class MiniDBCursor : public Cursor {
public:
explicit MiniDBCursor(FILE* f, std::mutex* mutex)
- : file_(f), lock_(*mutex) {}
+ : file_(f), lock_(*mutex), valid_(true) {
+ // We call Next() to read in the first entry.
+ Next();
+ }
~MiniDBCursor() {}
void SeekToFirst() override {
@@ -22,31 +25,37 @@ class MiniDBCursor : public Cursor {
}
void Next() override {
+ // First, read in the key and value length.
if (fread(&key_len_, sizeof(int), 1, file_) == 0) {
// Reaching EOF.
+ LOG(INFO) << "EOF reached, setting valid to false";
valid_ = false;
return;
}
CHECK_EQ(fread(&value_len_, sizeof(int), 1, file_), 1);
CHECK_GT(key_len_, 0);
CHECK_GT(value_len_, 0);
+ // Resize if the key and value len is larger than the current one.
if (key_len_ > key_.size()) {
key_.resize(key_len_);
}
if (value_len_ > value_.size()) {
value_.resize(value_len_);
}
+ // Actually read in the contents.
CHECK_EQ(fread(key_.data(), sizeof(char), key_len_, file_), key_len_);
CHECK_EQ(fread(value_.data(), sizeof(char), value_len_, file_), value_len_);
+ // Note(Yangqing): as we read the file, the cursor naturally moves to the
+ // beginning of the next entry.
}
string key() override {
- CHECK(valid_) << "Invalid position!";
+ CHECK(valid_) << "Cursor is at invalid location!";
return string(key_.data(), key_len_);
}
string value() override {
- CHECK(valid_) << "Invalid position!";
+ CHECK(valid_) << "Cursor is at invalid location!";
return string(value_.data(), value_len_);
}
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index ed59d1ddf6..3993eb974e 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -98,12 +98,10 @@ OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws) {
return CPUOperatorRegistry()->Create(key, operator_def, ws);
case CUDA:
VLOG(1) << "Creating CUDA operator " << key;
- // In Cuda, if we have cudnn, we will prefer to use cudnn first.
- if (CUDNNOperatorRegistry()->Has(key)) {
- VLOG(1) << "Using CuDNN implementation.";
- return CUDNNOperatorRegistry()->Create(key, operator_def, ws);
- }
return CUDAOperatorRegistry()->Create(key, operator_def, ws);
+ case CUDNN:
+ VLOG(1) << "Using CuDNN implementation.";
+ return CUDNNOperatorRegistry()->Create(key, operator_def, ws);
}
// Just to suppress some compiler error
return nullptr;
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index c34b58da68..7313389eab 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -36,8 +36,9 @@ message Argument {
}
enum DeviceType {
- CPU = 0; // In default, we will use CPU.
+ CPU = 0; // In default, we will use CPU.
CUDA = 1; // CUDA, with custom kernels.
+ CUDNN = 2; // CUDA, with CUDNN implementations.
}
message DeviceOption {