diff options
author | Jeff Donahue <jeff.donahue@gmail.com> | 2015-03-04 21:31:34 -0800 |
---|---|---|
committer | Jeff Donahue <jeff.donahue@gmail.com> | 2015-09-18 17:53:19 -0700 |
commit | 0813f32038bf7477d343ae369981166cfed783b5 (patch) | |
tree | 402bd51a8cd61b9c02d8aed7e1cd465b49acfd40 | |
parent | 4c2ff1693ea509dc4758e73b913f4cbec6c1ac3a (diff) | |
download | caffeonacl-0813f32038bf7477d343ae369981166cfed783b5.tar.gz caffeonacl-0813f32038bf7477d343ae369981166cfed783b5.tar.bz2 caffeonacl-0813f32038bf7477d343ae369981166cfed783b5.zip |
Blob: add SyncedMemory shape accessor for GPU shape access
-rw-r--r-- | include/caffe/blob.hpp | 2 | ||||
-rw-r--r-- | src/caffe/blob.cpp | 11 |
2 files changed, 13 insertions, 0 deletions
diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp index dda7b1f8..fea5117e 100644 --- a/include/caffe/blob.hpp +++ b/include/caffe/blob.hpp @@ -219,6 +219,7 @@ class Blob { const Dtype* cpu_data() const; void set_cpu_data(Dtype* data); + const int* gpu_shape() const; const Dtype* gpu_data() const; const Dtype* cpu_diff() const; const Dtype* gpu_diff() const; @@ -268,6 +269,7 @@ class Blob { protected: shared_ptr<SyncedMemory> data_; shared_ptr<SyncedMemory> diff_; + shared_ptr<SyncedMemory> shape_data_; vector<int> shape_; int count_; int capacity_; diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 8450aa14..c86fd5d1 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) { CHECK_LE(shape.size(), kMaxBlobAxes); count_ = 1; shape_.resize(shape.size()); + if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) { + shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int))); + } + int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data()); for (int i = 0; i < shape.size(); ++i) { CHECK_GE(shape[i], 0); CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX"; count_ *= shape[i]; shape_[i] = shape[i]; + shape_data[i] = shape[i]; } if (count_ > capacity_) { capacity_ = count_; @@ -68,6 +73,12 @@ Blob<Dtype>::Blob(const vector<int>& shape) } template <typename Dtype> +const int* Blob<Dtype>::gpu_shape() const { + CHECK(shape_data_); + return (const int*)shape_data_->gpu_data(); +} + +template <typename Dtype> const Dtype* Blob<Dtype>::cpu_data() const { CHECK(data_); return (const Dtype*)data_->cpu_data(); |