diff options
author | Jeff Donahue <jeff.donahue@gmail.com> | 2015-07-29 17:27:04 -0700 |
---|---|---|
committer | Eric Tzeng <etzeng@eecs.berkeley.edu> | 2015-08-07 13:48:42 -0700 |
commit | f973819240768df207ed8e4d307564b105950333 (patch) | |
tree | f8dc951b6b6fb38cea6f4c2b9df4fe04f8ff65c3 /src | |
parent | 1d5f4e5491f57a267ab9dacf6184a14c0e231159 (diff) | |
download | caffeonacl-f973819240768df207ed8e4d307564b105950333.tar.gz caffeonacl-f973819240768df207ed8e4d307564b105950333.tar.bz2 caffeonacl-f973819240768df207ed8e4d307564b105950333.zip |
add double_data, double_diff to BlobProto for weights/snapshots saved
when using Dtype == double
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/blob.cpp | 49 | ||||
-rw-r--r-- | src/caffe/proto/caffe.proto | 2 |
2 files changed, 44 insertions, 7 deletions
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 94fdcc35..8450aa14 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -456,10 +456,25 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) { } // copy data Dtype* data_vec = mutable_cpu_data(); - for (int i = 0; i < count_; ++i) { - data_vec[i] = proto.data(i); + if (proto.double_data_size() > 0) { + CHECK_EQ(count_, proto.double_data_size()); + for (int i = 0; i < count_; ++i) { + data_vec[i] = proto.double_data(i); + } + } else { + CHECK_EQ(count_, proto.data_size()); + for (int i = 0; i < count_; ++i) { + data_vec[i] = proto.data(i); + } } - if (proto.diff_size() > 0) { + if (proto.double_diff_size() > 0) { + CHECK_EQ(count_, proto.double_diff_size()); + Dtype* diff_vec = mutable_cpu_diff(); + for (int i = 0; i < count_; ++i) { + diff_vec[i] = proto.double_diff(i); + } + } else if (proto.diff_size() > 0) { + CHECK_EQ(count_, proto.diff_size()); Dtype* diff_vec = mutable_cpu_diff(); for (int i = 0; i < count_; ++i) { diff_vec[i] = proto.diff(i); @@ -467,20 +482,40 @@ void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) { } } -template <typename Dtype> -void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const { +template <> +void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const { + proto->clear_shape(); + for (int i = 0; i < shape_.size(); ++i) { + proto->mutable_shape()->add_dim(shape_[i]); + } + proto->clear_double_data(); + proto->clear_double_diff(); + const double* data_vec = cpu_data(); + for (int i = 0; i < count_; ++i) { + proto->add_double_data(data_vec[i]); + } + if (write_diff) { + const double* diff_vec = cpu_diff(); + for (int i = 0; i < count_; ++i) { + proto->add_double_diff(diff_vec[i]); + } + } +} + +template <> +void Blob<float>::ToProto(BlobProto* proto, bool write_diff) const { proto->clear_shape(); for (int i = 0; i < shape_.size(); ++i) { proto->mutable_shape()->add_dim(shape_[i]); } proto->clear_data(); proto->clear_diff(); - const Dtype* data_vec = cpu_data(); + const float* data_vec = cpu_data(); for (int i = 0; i < count_; ++i) { proto->add_data(data_vec[i]); } if (write_diff) { - const Dtype* diff_vec = cpu_diff(); + const float* diff_vec = cpu_diff(); for (int i = 0; i < count_; ++i) { proto->add_diff(diff_vec[i]); } diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index adcf4e2f..03daa808 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -11,6 +11,8 @@ message BlobProto { optional BlobShape shape = 7; repeated float data = 5 [packed = true]; repeated float diff = 6 [packed = true]; + repeated double double_data = 8 [packed = true]; + repeated double double_diff = 9 [packed = true]; // 4D dimensions -- deprecated. Use "shape" instead. optional int32 num = 1 [default = 0]; |