summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJeff Donahue <jeff.donahue@gmail.com>2015-07-29 17:27:04 -0700
committerEric Tzeng <etzeng@eecs.berkeley.edu>2015-08-07 13:48:42 -0700
commitf973819240768df207ed8e4d307564b105950333 (patch)
treef8dc951b6b6fb38cea6f4c2b9df4fe04f8ff65c3 /src
parent1d5f4e5491f57a267ab9dacf6184a14c0e231159 (diff)
downloadcaffeonacl-f973819240768df207ed8e4d307564b105950333.tar.gz
caffeonacl-f973819240768df207ed8e4d307564b105950333.tar.bz2
caffeonacl-f973819240768df207ed8e4d307564b105950333.zip
add double_data, double_diff to BlobProto for weights/snapshots saved
when using Dtype == double
Diffstat (limited to 'src')
-rw-r--r--src/caffe/blob.cpp49
-rw-r--r--src/caffe/proto/caffe.proto2
2 files changed, 44 insertions, 7 deletions
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 94fdcc35..8450aa14 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -456,10 +456,25 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
}
// copy data
Dtype* data_vec = mutable_cpu_data();
- for (int i = 0; i < count_; ++i) {
- data_vec[i] = proto.data(i);
+ if (proto.double_data_size() > 0) {
+ CHECK_EQ(count_, proto.double_data_size());
+ for (int i = 0; i < count_; ++i) {
+ data_vec[i] = proto.double_data(i);
+ }
+ } else {
+ CHECK_EQ(count_, proto.data_size());
+ for (int i = 0; i < count_; ++i) {
+ data_vec[i] = proto.data(i);
+ }
}
- if (proto.diff_size() > 0) {
+ if (proto.double_diff_size() > 0) {
+ CHECK_EQ(count_, proto.double_diff_size());
+ Dtype* diff_vec = mutable_cpu_diff();
+ for (int i = 0; i < count_; ++i) {
+ diff_vec[i] = proto.double_diff(i);
+ }
+ } else if (proto.diff_size() > 0) {
+ CHECK_EQ(count_, proto.diff_size());
Dtype* diff_vec = mutable_cpu_diff();
for (int i = 0; i < count_; ++i) {
diff_vec[i] = proto.diff(i);
@@ -467,20 +482,40 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {
}
}
-template <typename Dtype>
-void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {
+template <>
+void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const {
+ proto->clear_shape();
+ for (int i = 0; i < shape_.size(); ++i) {
+ proto->mutable_shape()->add_dim(shape_[i]);
+ }
+ proto->clear_double_data();
+ proto->clear_double_diff();
+ const double* data_vec = cpu_data();
+ for (int i = 0; i < count_; ++i) {
+ proto->add_double_data(data_vec[i]);
+ }
+ if (write_diff) {
+ const double* diff_vec = cpu_diff();
+ for (int i = 0; i < count_; ++i) {
+ proto->add_double_diff(diff_vec[i]);
+ }
+ }
+}
+
+template <>
+void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const {
proto->clear_shape();
for (int i = 0; i < shape_.size(); ++i) {
proto->mutable_shape()->add_dim(shape_[i]);
}
proto->clear_data();
proto->clear_diff();
- const Dtype* data_vec = cpu_data();
+ const float* data_vec = cpu_data();
for (int i = 0; i < count_; ++i) {
proto->add_data(data_vec[i]);
}
if (write_diff) {
- const Dtype* diff_vec = cpu_diff();
+ const float* diff_vec = cpu_diff();
for (int i = 0; i < count_; ++i) {
proto->add_diff(diff_vec[i]);
}
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index adcf4e2f..03daa808 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -11,6 +11,8 @@ message BlobProto {
optional BlobShape shape = 7;
repeated float data = 5 [packed = true];
repeated float diff = 6 [packed = true];
+ repeated double double_data = 8 [packed = true];
+ repeated double double_diff = 9 [packed = true];
// 4D dimensions -- deprecated. Use "shape" instead.
optional int32 num = 1 [default = 0];