summaryrefslogtreecommitdiff
path: root/src/caffe/blob.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/caffe/blob.cpp')
-rw-r--r--src/caffe/blob.cpp18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 4a34e4c5..603e52f7 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -89,6 +89,12 @@ const Dtype* Blob<Dtype>::cpu_data() const {
template <typename Dtype>
void Blob<Dtype>::set_cpu_data(Dtype* data) {
CHECK(data);
+ // Make sure CPU and GPU sizes remain equal
+ size_t size = count_ * sizeof(Dtype);
+ if (data_->size() != size) {
+ data_.reset(new SyncedMemory(size));
+ diff_.reset(new SyncedMemory(size));
+ }
data_->set_cpu_data(data);
}
@@ -99,6 +105,18 @@ const Dtype* Blob<Dtype>::gpu_data() const {
}
template <typename Dtype>
+void Blob<Dtype>::set_gpu_data(Dtype* data) {
+ CHECK(data);
+ // Make sure CPU and GPU sizes remain equal
+ size_t size = count_ * sizeof(Dtype);
+ if (data_->size() != size) {
+ data_.reset(new SyncedMemory(size));
+ diff_.reset(new SyncedMemory(size));
+ }
+ data_->set_gpu_data(data);
+}
+
+template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_diff() const {
CHECK(diff_);
return (const Dtype*)diff_->cpu_data();