From 0813f32038bf7477d343ae369981166cfed783b5 Mon Sep 17 00:00:00 2001
From: Jeff Donahue
Date: Wed, 4 Mar 2015 21:31:34 -0800
Subject: Blob: add SyncedMemory shape accessor for GPU shape access

---
 include/caffe/blob.hpp |  2 ++
 src/caffe/blob.cpp     | 11 +++++++++++
 2 files changed, 13 insertions(+)

diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index dda7b1f8..fea5117e 100644
--- a/include/caffe/blob.hpp
+++ b/include/caffe/blob.hpp
@@ -219,6 +219,7 @@ class Blob {
 
   const Dtype* cpu_data() const;
   void set_cpu_data(Dtype* data);
+  const int* gpu_shape() const;
   const Dtype* gpu_data() const;
   const Dtype* cpu_diff() const;
   const Dtype* gpu_diff() const;
@@ -268,6 +269,7 @@ class Blob {
  protected:
   shared_ptr<SyncedMemory> data_;
   shared_ptr<SyncedMemory> diff_;
+  shared_ptr<SyncedMemory> shape_data_;
   vector<int> shape_;
   int count_;
   int capacity_;
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 8450aa14..c86fd5d1 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) {
   CHECK_LE(shape.size(), kMaxBlobAxes);
   count_ = 1;
   shape_.resize(shape.size());
+  if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
+    shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
+  }
+  int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
   for (int i = 0; i < shape.size(); ++i) {
     CHECK_GE(shape[i], 0);
     CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
     count_ *= shape[i];
     shape_[i] = shape[i];
+    shape_data[i] = shape[i];
   }
   if (count_ > capacity_) {
     capacity_ = count_;
@@ -67,6 +72,12 @@ Blob<Dtype>::Blob(const vector<int>& shape)
   Reshape(shape);
 }
 
+template <typename Dtype>
+const int* Blob<Dtype>::gpu_shape() const {
+  CHECK(shape_data_);
+  return (const int*)shape_data_->gpu_data();
+}
+
 template <typename Dtype>
 const Dtype* Blob<Dtype>::cpu_data() const {
   CHECK(data_);
-- 
cgit v1.2.3