summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/caffe/blob.hpp2
-rw-r--r--src/caffe/blob.cpp11
2 files changed, 13 insertions, 0 deletions
diff --git a/include/caffe/blob.hpp b/include/caffe/blob.hpp
index dda7b1f8..fea5117e 100644
--- a/include/caffe/blob.hpp
+++ b/include/caffe/blob.hpp
@@ -219,6 +219,7 @@ class Blob {
const Dtype* cpu_data() const;
void set_cpu_data(Dtype* data);
+ const int* gpu_shape() const;
const Dtype* gpu_data() const;
const Dtype* cpu_diff() const;
const Dtype* gpu_diff() const;
@@ -268,6 +269,7 @@ class Blob {
protected:
shared_ptr<SyncedMemory> data_;
shared_ptr<SyncedMemory> diff_;
+ shared_ptr<SyncedMemory> shape_data_;
vector<int> shape_;
int count_;
int capacity_;
diff --git a/src/caffe/blob.cpp b/src/caffe/blob.cpp
index 8450aa14..c86fd5d1 100644
--- a/src/caffe/blob.cpp
+++ b/src/caffe/blob.cpp
@@ -24,11 +24,16 @@ void Blob<Dtype>::Reshape(const vector<int>& shape) {
CHECK_LE(shape.size(), kMaxBlobAxes);
count_ = 1;
shape_.resize(shape.size());
+ if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
+ shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
+ }
+ int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
for (int i = 0; i < shape.size(); ++i) {
CHECK_GE(shape[i], 0);
CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
count_ *= shape[i];
shape_[i] = shape[i];
+ shape_data[i] = shape[i];
}
if (count_ > capacity_) {
capacity_ = count_;
@@ -68,6 +73,12 @@ Blob<Dtype>::Blob(const vector<int>& shape)
}
template <typename Dtype>
+const int* Blob<Dtype>::gpu_shape() const {
+ CHECK(shape_data_);
+ return (const int*)shape_data_->gpu_data();
+}
+
+template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
CHECK(data_);
return (const Dtype*)data_->cpu_data();