diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/blob.cpp | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp index 8450aa14..c86fd5d1 100644 --- a/src/caffe/blob.cpp +++ b/src/caffe/blob.cpp @@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) { CHECK_LE(shape.size(), kMaxBlobAxes); count_ = 1; shape_.resize(shape.size()); + if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) { + shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int))); + } + int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data()); for (int i = 0; i < shape.size(); ++i) { CHECK_GE(shape[i], 0); CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX"; count_ *= shape[i]; shape_[i] = shape[i]; + shape_data[i] = shape[i]; } if (count_ > capacity_) { capacity_ = count_; @@ -68,6 +73,12 @@ Blob<Dtype>::Blob(const vector<int>& shape) } template <typename Dtype> +const int* Blob<Dtype>::gpu_shape() const { + CHECK(shape_data_); + return (const int*)shape_data_->gpu_data(); +} + +template <typename Dtype> const Dtype* Blob<Dtype>::cpu_data() const { CHECK(data_); return (const Dtype*)data_->cpu_data(); |