diff options
author | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2016-11-15 11:19:37 -0800 |
---|---|---|
committer | Evan Shelhamer <shelhamer@imaginarynumber.net> | 2016-11-15 11:28:37 -0800 |
commit | c6ab96596d9eae01c2c403487dc8be8e3edc8fbb (patch) | |
tree | 7f5d3e3376bcfedd3aebc453f6481703fa8cdeae /src | |
parent | eb4ba30e3c4899edc7a9713158d61503fa8ecf90 (diff) | |
download | caffeonacl-c6ab96596d9eae01c2c403487dc8be8e3edc8fbb.tar.gz caffeonacl-c6ab96596d9eae01c2c403487dc8be8e3edc8fbb.tar.bz2 caffeonacl-c6ab96596d9eae01c2c403487dc8be8e3edc8fbb.zip |
sigmoid cross-entropy loss: ignore selected targets by `ignore_label`
sig-ce ignores selected targets by zeroing out the loss/diff for entries
whose target value equals the configured `ignore_label`.
n.b. as of now the loss/diff are not properly normalized when targets are
ignored: the sum is still divided by the full `num` regardless of how many
targets were skipped. sig-ce loss should adopt the same normalization
options as softmax loss.
Diffstat (limited to 'src')
-rw-r--r-- | src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp | 19 | ||||
-rw-r--r-- | src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu | 23 | ||||
-rw-r--r-- | src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp | 28 |
3 files changed, 70 insertions, 0 deletions
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp index eb77a9c2..21b64c28 100644 --- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp +++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp @@ -14,6 +14,12 @@ void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp( sigmoid_top_vec_.clear(); sigmoid_top_vec_.push_back(sigmoid_output_.get()); sigmoid_layer_->SetUp(sigmoid_bottom_vec_, sigmoid_top_vec_); + + has_ignore_label_ = + this->layer_param_.loss_param().has_ignore_label(); + if (has_ignore_label_) { + ignore_label_ = this->layer_param_.loss_param().ignore_label(); + } } template <typename Dtype> @@ -39,6 +45,10 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu( const Dtype* target = bottom[1]->cpu_data(); Dtype loss = 0; for (int i = 0; i < count; ++i) { + const int target_value = static_cast<int>(target[i]); + if (has_ignore_label_ && target_value == ignore_label_) { + continue; + } loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) - log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0))); } @@ -64,6 +74,15 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu( // Scale down gradient const Dtype loss_weight = top[0]->cpu_diff()[0]; caffe_scal(count, loss_weight / num, bottom_diff); + // Zero out gradient of ignored targets. 
+ if (has_ignore_label_) { + for (int i = 0; i < count; ++i) { + const int target_value = static_cast<int>(target[i]); + if (target_value == ignore_label_) { + bottom_diff[i] = 0; + } + } + } } } diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu index 7cb982d2..39eb0506 100644 --- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu +++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu @@ -15,6 +15,17 @@ __global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads, } template <typename Dtype> +__global__ void SigmoidCrossEntropyLossIgnoreGPU(const int count, + const int ignore_label, const Dtype* target, Dtype* reference) { + CUDA_KERNEL_LOOP(index, count) { + const int target_value = static_cast<int>(target[index]); + if (target_value == ignore_label) { + reference[index] = 0; + } + } +} + +template <typename Dtype> void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { // The forward pass computes the sigmoid outputs. @@ -33,6 +44,12 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu( // NOLINT_NEXT_LINE(whitespace/operators) SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data); + // Zero out loss of ignored targets. + if (has_ignore_label_) { + // NOLINT_NEXT_LINE(whitespace/operators) + SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, loss_data); + } Dtype loss; caffe_gpu_asum(count, loss_data, &loss); top[0]->mutable_cpu_data()[0] = loss / num; @@ -58,6 +75,12 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu( // Scale down gradient const Dtype loss_weight = top[0]->cpu_diff()[0]; caffe_gpu_scal(count, loss_weight / num, bottom_diff); + // Zero out gradient of ignored targets. 
+ if (has_ignore_label_) { + // NOLINT_NEXT_LINE(whitespace/operators) + SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count), + CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, bottom_diff); + } } } diff --git a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp index 5dfd7656..1bd5f937 100644 --- a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp +++ b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp @@ -116,5 +116,33 @@ TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestGradient) { this->blob_top_vec_, 0); } +TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestIgnoreGradient) { + typedef typename TypeParam::Dtype Dtype; + FillerParameter data_filler_param; + data_filler_param.set_std(1); + GaussianFiller<Dtype> data_filler(data_filler_param); + data_filler.Fill(this->blob_bottom_data_); + LayerParameter layer_param; + LossParameter* loss_param = layer_param.mutable_loss_param(); + loss_param->set_ignore_label(-1); + Dtype* target = this->blob_bottom_targets_->mutable_cpu_data(); + const int count = this->blob_bottom_targets_->count(); + // Ignore half of targets, then check that diff of this half is zero, + // while the other half is nonzero. + caffe_set(count / 2, Dtype(-1), target); + SigmoidCrossEntropyLossLayer<Dtype> layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + vector<bool> propagate_down(2); + propagate_down[0] = true; + propagate_down[1] = false; + layer.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_); + const Dtype* diff = this->blob_bottom_data_->cpu_diff(); + for (int i = 0; i < count / 2; ++i) { + EXPECT_FLOAT_EQ(diff[i], 0.); + EXPECT_NE(diff[i + count / 2], 0.); + } +} + } // namespace caffe |