author     Jeff Donahue <jeff.donahue@gmail.com>   2015-02-15 13:26:36 -0800
committer  Jeff Donahue <jeff.donahue@gmail.com>   2015-03-03 15:55:14 -0800
commit     abec30252ced89d9e2550ca47fca569f563479f6 (patch)
tree       5f72270d28a4587cb5761e041bc77d622e9a08d8
parent     b86891635dbb24f70d5634a679150070caf776e4 (diff)
SoftmaxLayer: generalized Blob axes
-rw-r--r--  include/caffe/common_layers.hpp     |  3
-rw-r--r--  src/caffe/layers/softmax_layer.cpp  | 62
-rw-r--r--  src/caffe/layers/softmax_layer.cu   | 35
-rw-r--r--  src/caffe/proto/caffe.proto         |  5
4 files changed, 54 insertions(+), 51 deletions(-)
diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp
index 4e47e55d..b1ac3a93 100644
--- a/include/caffe/common_layers.hpp
+++ b/include/caffe/common_layers.hpp
@@ -353,6 +353,9 @@ class SoftmaxLayer : public Layer<Dtype> {
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ int outer_num_;
+ int inner_num_;
+ int softmax_axis_;
/// sum_multiplier is used to carry out sum using BLAS
Blob<Dtype> sum_multiplier_;
/// scale is an intermediate Blob to hold temporary results.
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
index 25142fde..04712c9e 100644
--- a/src/caffe/layers/softmax_layer.cpp
+++ b/src/caffe/layers/softmax_layer.cpp
@@ -10,14 +10,18 @@ namespace caffe {
template <typename Dtype>
void SoftmaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
- top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
- bottom[0]->height(), bottom[0]->width());
- sum_multiplier_.Reshape(1, bottom[0]->channels(), 1, 1);
+ softmax_axis_ =
+ bottom[0]->CanonicalAxisIndex(this->layer_param_.softmax_param().axis());
+ top[0]->ReshapeLike(*bottom[0]);
+ vector<int> mult_dims(1, bottom[0]->shape(softmax_axis_));
+ sum_multiplier_.Reshape(mult_dims);
Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
- for (int i = 0; i < sum_multiplier_.count(); ++i) {
- multiplier_data[i] = 1.;
- }
- scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
+ caffe_set(sum_multiplier_.count(), Dtype(1), multiplier_data);
+ outer_num_ = bottom[0]->count(0, softmax_axis_);
+ inner_num_ = bottom[0]->count(softmax_axis_ + 1);
+ vector<int> scale_dims = bottom[0]->shape();
+ scale_dims[softmax_axis_] = 1;
+ scale_.Reshape(scale_dims);
}
template <typename Dtype>
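The Reshape hunk above is the heart of the generalization: instead of hard-coding the N x C x H x W layout, the bottom blob is viewed as outer_num_ x channels x inner_num_, where channels is the extent of softmax_axis_ and outer_num_ / inner_num_ are the products of the dimensions before and after it. A minimal standalone sketch of that decomposition (the shape and axis values are illustrative, and count()/canonical_axis() are stand-ins for Blob::count(start, end) and Blob::CanonicalAxisIndex, not the Caffe API itself):

    #include <cassert>
    #include <iostream>
    #include <vector>

    // Product of shape[start..end), mirroring Blob::count(start, end).
    static int count(const std::vector<int>& shape, int start, int end) {
      int c = 1;
      for (int i = start; i < end; ++i) c *= shape[i];
      return c;
    }

    // Resolve a possibly negative axis, mirroring CanonicalAxisIndex.
    static int canonical_axis(int axis, int num_axes) {
      return axis < 0 ? axis + num_axes : axis;  // e.g. -1 -> last axis
    }

    int main() {
      std::vector<int> shape = {2, 3, 4, 5};  // hypothetical bottom blob shape
      int num_axes = static_cast<int>(shape.size());
      int softmax_axis = canonical_axis(1, num_axes);             // default axis = 1
      int outer_num = count(shape, 0, softmax_axis);              // 2
      int channels  = shape[softmax_axis];                        // 3
      int inner_num = count(shape, softmax_axis + 1, num_axes);   // 4 * 5 = 20
      assert(outer_num * channels * inner_num == count(shape, 0, num_axes));
      std::cout << outer_num << " x " << channels << " x " << inner_num << "\n";
      return 0;
    }

Every (i, k) pair with i < outer_num_ and k < inner_num_ then owns one independent softmax over channels values spaced inner_num_ apart in memory.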
@@ -26,34 +30,32 @@ void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
Dtype* scale_data = scale_.mutable_cpu_data();
- int num = bottom[0]->num();
- int channels = bottom[0]->channels();
- int dim = bottom[0]->count() / bottom[0]->num();
- int spatial_dim = bottom[0]->height() * bottom[0]->width();
+ int channels = bottom[0]->shape(softmax_axis_);
+ int dim = bottom[0]->count() / outer_num_;
caffe_copy(bottom[0]->count(), bottom_data, top_data);
// We need to subtract the max to avoid numerical issues, compute the exp,
// and then normalize.
- for (int i = 0; i < num; ++i) {
+ for (int i = 0; i < outer_num_; ++i) {
// initialize scale_data to the first plane
- caffe_copy(spatial_dim, bottom_data + i * dim, scale_data);
+ caffe_copy(inner_num_, bottom_data + i * dim, scale_data);
for (int j = 0; j < channels; j++) {
- for (int k = 0; k < spatial_dim; k++) {
+ for (int k = 0; k < inner_num_; k++) {
scale_data[k] = std::max(scale_data[k],
- bottom_data[i * dim + j * spatial_dim + k]);
+ bottom_data[i * dim + j * inner_num_ + k]);
}
}
// subtraction
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim,
- 1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data + i * dim);
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_,
+ 1, -1., sum_multiplier_.cpu_data(), scale_data, 1., top_data);
// exponentiation
- caffe_exp<Dtype>(dim, top_data + i * dim, top_data + i * dim);
+ caffe_exp<Dtype>(dim, top_data, top_data);
// sum after exp
- caffe_cpu_gemv<Dtype>(CblasTrans, channels, spatial_dim, 1.,
- top_data + i * dim, sum_multiplier_.cpu_data(), 0., scale_data);
+ caffe_cpu_gemv<Dtype>(CblasTrans, channels, inner_num_, 1.,
+ top_data, sum_multiplier_.cpu_data(), 0., scale_data);
// division
for (int j = 0; j < channels; j++) {
- caffe_div(spatial_dim, top_data + top[0]->offset(i, j), scale_data,
- top_data + top[0]->offset(i, j));
+ caffe_div(inner_num_, top_data, scale_data, top_data);
+ top_data += inner_num_;
}
}
}
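For readers following the new Forward_cpu loop: each of the outer_num_ slices of size dim = channels * inner_num_ is normalized column by column, with the channel elements of one column strided inner_num_ apart. A self-contained sketch of the same max-subtract / exp / sum / divide sequence on one slice, written with plain loops instead of the BLAS calls (names and types are illustrative, not the layer's code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Softmax over the channel axis of one outer slice laid out as
    // data[c * inner_num + k], c in [0, channels), k in [0, inner_num).
    void softmax_slice(std::vector<float>& data, int channels, int inner_num) {
      for (int k = 0; k < inner_num; ++k) {
        // max over channels for numerical stability
        float max_val = data[k];
        for (int c = 1; c < channels; ++c) {
          max_val = std::max(max_val, data[c * inner_num + k]);
        }
        // exponentiate the shifted values and accumulate the partition sum
        float sum = 0.f;
        for (int c = 0; c < channels; ++c) {
          float e = std::exp(data[c * inner_num + k] - max_val);
          data[c * inner_num + k] = e;
          sum += e;
        }
        // normalize
        for (int c = 0; c < channels; ++c) {
          data[c * inner_num + k] /= sum;
        }
      }
    }

The layer itself vectorizes these stages with caffe_cpu_gemm (broadcast-subtract the per-column max via sum_multiplier_), caffe_exp, and caffe_cpu_gemv plus caffe_div, advancing top_data by inner_num_ per channel so that plain pointer arithmetic replaces the old offset(i, j) calls.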
@@ -66,20 +68,18 @@ void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_data = top[0]->cpu_data();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
Dtype* scale_data = scale_.mutable_cpu_data();
- int num = top[0]->num();
- int channels = top[0]->channels();
- int dim = top[0]->count() / top[0]->num();
- int spatial_dim = top[0]->height() * top[0]->width();
+ int channels = top[0]->shape(softmax_axis_);
+ int dim = top[0]->count() / outer_num_;
caffe_copy(top[0]->count(), top_diff, bottom_diff);
- for (int i = 0; i < num; ++i) {
+ for (int i = 0; i < outer_num_; ++i) {
// compute dot(top_diff, top_data) and subtract them from the bottom diff
- for (int k = 0; k < spatial_dim; ++k) {
+ for (int k = 0; k < inner_num_; ++k) {
scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,
- bottom_diff + i * dim + k, spatial_dim,
- top_data + i * dim + k, spatial_dim);
+ bottom_diff + i * dim + k, inner_num_,
+ top_data + i * dim + k, inner_num_);
}
// subtraction
- caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, spatial_dim, 1,
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,
-1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);
}
// elementwise multiplication
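What the strided dot, the gemm subtraction, and the elementwise multiplication noted in the trailing comment compute is the standard softmax Jacobian product. For one (i, k) column with softmax outputs y_c and incoming gradient g_c = dE/dy_c, the gradient with respect to the pre-softmax inputs x_c is

    \frac{\partial E}{\partial x_c} = y_c \Bigl( g_c - \sum_{c'} g_{c'}\, y_{c'} \Bigr)

caffe_cpu_strided_dot produces the inner sum for every column of the slice, the gemm broadcasts and subtracts it from bottom_diff (which already holds a copy of g), and the elementwise multiplication that follows the hunk applies the remaining factor of y.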
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu
index 6b8871a0..1f9c3a41 100644
--- a/src/caffe/layers/softmax_layer.cu
+++ b/src/caffe/layers/softmax_layer.cu
@@ -90,36 +90,33 @@ void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
Dtype* top_data = top[0]->mutable_gpu_data();
Dtype* scale_data = scale_.mutable_gpu_data();
int count = bottom[0]->count();
- int num = bottom[0]->num();
- int channels = bottom[0]->channels();
- int spatial_dim = bottom[0]->height() * bottom[0]->width();
+ int channels = top[0]->shape(softmax_axis_);
caffe_copy(count, bottom_data, top_data);
// We need to subtract the max to avoid numerical issues, compute the exp,
// and then normalize.
// compute max
// NOLINT_NEXT_LINE(whitespace/operators)
- kernel_channel_max<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
- CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+ kernel_channel_max<Dtype><<<CAFFE_GET_BLOCKS(outer_num_ * inner_num_),
+ CAFFE_CUDA_NUM_THREADS>>>(outer_num_, channels, inner_num_, top_data,
scale_data);
// subtract
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(count),
- CAFFE_CUDA_NUM_THREADS>>>(count, num, channels, spatial_dim,
+ CAFFE_CUDA_NUM_THREADS>>>(count, outer_num_, channels, inner_num_,
scale_data, top_data);
// exponentiate
// NOLINT_NEXT_LINE(whitespace/operators)
- kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * channels * spatial_dim),
- CAFFE_CUDA_NUM_THREADS>>>(num * channels * spatial_dim, top_data,
- top_data);
+ kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+ count, top_data, top_data);
// sum after exp
// NOLINT_NEXT_LINE(whitespace/operators)
- kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
- CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+ kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(outer_num_ * inner_num_),
+ CAFFE_CUDA_NUM_THREADS>>>(outer_num_, channels, inner_num_, top_data,
scale_data);
// divide
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(count),
- CAFFE_CUDA_NUM_THREADS>>>(count, num, channels, spatial_dim,
+ CAFFE_CUDA_NUM_THREADS>>>(count, outer_num_, channels, inner_num_,
scale_data, top_data);
}
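On the GPU the same decomposition determines the launch sizes: the channel reductions (max, sum, and, in Backward_gpu, dot) use one thread per (outer, inner) pair, each looping over the channel axis, while the elementwise kernels (subtract, exp, divide) use one thread per element. With the illustrative 2 x 3 x 4 x 5 blob and axis 1 from the earlier sketch:

    channel reductions  (kernel_channel_max / _sum / _dot):  outer_num_ * inner_num_ = 2 * 20 =  40 threads
    elementwise kernels (subtract / exp / div):               count = 2 * 3 * 20              = 120 threads

The kernel_exp change is purely cosmetic: num * channels * spatial_dim in the old code already equals count, so the new code just launches over count directly.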
@@ -131,18 +128,16 @@ void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
Dtype* scale_data = scale_.mutable_gpu_data();
int count = top[0]->count();
- int num = top[0]->num();
- int channels = top[0]->channels();
- int spatial_dim = top[0]->height() * top[0]->width();
- caffe_copy(top[0]->count(), top_diff, bottom_diff);
+ int channels = top[0]->shape(softmax_axis_);
+ caffe_copy(count, top_diff, bottom_diff);
// Compute inner1d(top_diff, top_data) and subtract them from the bottom diff.
// NOLINT_NEXT_LINE(whitespace/operators)
- kernel_channel_dot<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
- CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_diff, top_data,
- scale_data);
+ kernel_channel_dot<Dtype><<<CAFFE_GET_BLOCKS(outer_num_ * inner_num_),
+ CAFFE_CUDA_NUM_THREADS>>>(outer_num_, channels, inner_num_,
+ top_diff, top_data, scale_data);
// NOLINT_NEXT_LINE(whitespace/operators)
kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(count),
- CAFFE_CUDA_NUM_THREADS>>>(count, num, channels, spatial_dim,
+ CAFFE_CUDA_NUM_THREADS>>>(count, outer_num_, channels, inner_num_,
scale_data, bottom_diff);
// elementwise multiplication
caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 7783a783..8fcb8def 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -692,6 +692,11 @@ message SoftmaxParameter {
CUDNN = 2;
}
optional Engine engine = 1 [default = DEFAULT];
+
+ // The axis along which to perform the softmax -- may be negative to index
+ // from the end (e.g., -1 for the last axis).
+ // Any other axes will be evaluated as independent softmaxes.
+ optional int32 axis = 2 [default = 1];
}
// Message that stores parameters used by TanHLayer
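As a usage sketch of the new field (layer and blob names here are hypothetical, not part of the commit), a network definition could select the softmax axis like this:

    layer {
      name: "prob"
      type: "Softmax"
      bottom: "score"
      top: "prob"
      # -1 indexes from the end, i.e. the last axis; all other axes
      # are treated as independent softmaxes.
      softmax_param { axis: -1 }
    }

Leaving axis at its default of 1 preserves the previous behavior of a per-channel softmax over each spatial position.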