diff options
Diffstat (limited to 'src/caffe/blob.cpp')
-rw-r--r-- | src/caffe/blob.cpp | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 4a34e4c5..603e52f7 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -89,6 +89,12 @@ const Dtype* Blob<Dtype>::cpu_data() const { template <typename Dtype> void Blob<Dtype>::set_cpu_data(Dtype* data) { CHECK(data); + // Make sure CPU and GPU sizes remain equal + size_t size = count_ * sizeof(Dtype); + if (data_->size() != size) { + data_.reset(new SyncedMemory(size)); + diff_.reset(new SyncedMemory(size)); + } data_->set_cpu_data(data); } @@ -99,6 +105,18 @@ const Dtype* Blob<Dtype>::gpu_data() const { } template <typename Dtype> +void Blob<Dtype>::set_gpu_data(Dtype* data) { + CHECK(data); + // Make sure CPU and GPU sizes remain equal + size_t size = count_ * sizeof(Dtype); + if (data_->size() != size) { + data_.reset(new SyncedMemory(size)); + diff_.reset(new SyncedMemory(size)); + } + data_->set_gpu_data(data); +} + +template <typename Dtype> const Dtype* Blob<Dtype>::cpu_diff() const { CHECK(diff_); return (const Dtype*)diff_->cpu_data(); |