path: root/src/caffe
author     Jeff Donahue <jeff.donahue@gmail.com>   2014-07-11 01:55:17 -0700
committer  Jeff Donahue <jeff.donahue@gmail.com>   2014-08-13 13:22:04 -0700
commit     512a626fc71c69ed4460024b31c5fe8dff1e668c (patch)
tree       f3d11beb593a4e64e779a99b82538ceee7fae21a /src/caffe
parent     7a3ed9b8edf43895770b63cb4d9f5cacf0dba047 (diff)
Generalize loss by allowing any top blob to be used as a loss whose elements are summed with a scalar coefficient.

Forward for layers no longer returns a loss; instead, all loss layers must have top blobs. Existing loss layers are given a top blob automatically by Net::Init, with an associated loss weight of 1 (set in LossLayer::LayerSetUp). Because of the increased amount of common SetUp logic, the SetUp interface is changed so that subclasses normally override only LayerSetUp, which SetUp calls.
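
To make the new contract concrete, here is a minimal, self-contained sketch (not Caffe code; MiniBlob, MiniLayer, MeanLossLayer, and TotalLoss are hypothetical stand-ins for illustration only) of the pattern the diff below applies everywhere: SetUp runs the shared logic and delegates to a virtual LayerSetUp, Forward writes its result into a top blob instead of returning it, and the net sums each top blob's elements under a scalar loss weight.

    // A minimal, self-contained sketch (not Caffe itself) of the new pattern.
    // MiniBlob, MiniLayer, MeanLossLayer, and TotalLoss are hypothetical names.
    #include <iostream>
    #include <numeric>
    #include <vector>

    struct MiniBlob {
      std::vector<float> data;
    };

    class MiniLayer {
     public:
      virtual ~MiniLayer() {}
      // Common setup logic lives here; subclasses override only LayerSetUp.
      void SetUp(const std::vector<MiniBlob*>& bottom,
                 std::vector<MiniBlob*>* top) {
        // ... shared checks (blob counts, shapes) would go here ...
        LayerSetUp(bottom, top);
      }
      virtual void LayerSetUp(const std::vector<MiniBlob*>& bottom,
                              std::vector<MiniBlob*>* top) = 0;
      // Forward no longer returns a loss; results go into top blobs.
      virtual void Forward(const std::vector<MiniBlob*>& bottom,
                           std::vector<MiniBlob*>* top) = 0;
    };

    // A toy "loss" layer: writes the mean of its bottom blob into a scalar top.
    class MeanLossLayer : public MiniLayer {
     public:
      void LayerSetUp(const std::vector<MiniBlob*>& bottom,
                      std::vector<MiniBlob*>* top) {
        (void) bottom;
        (*top)[0]->data.assign(1, 0.0f);  // scalar top, like Reshape(1, 1, 1, 1)
      }
      void Forward(const std::vector<MiniBlob*>& bottom,
                   std::vector<MiniBlob*>* top) {
        const std::vector<float>& d = bottom[0]->data;
        (*top)[0]->data[0] = std::accumulate(d.begin(), d.end(), 0.0f)
            / static_cast<float>(d.size());
      }
    };

    // The net-side accumulation the commit message describes: every top blob
    // may contribute, its elements summed under a scalar coefficient.
    float TotalLoss(const std::vector<MiniBlob*>& tops,
                    const std::vector<float>& loss_weights) {
      float loss = 0.0f;
      for (size_t i = 0; i < tops.size(); ++i) {
        if (loss_weights[i] == 0.0f) continue;  // non-loss outputs carry weight 0
        for (size_t j = 0; j < tops[i]->data.size(); ++j) {
          loss += loss_weights[i] * tops[i]->data[j];
        }
      }
      return loss;
    }

    int main() {
      MiniBlob bottom, top;
      bottom.data.push_back(1.0f);
      bottom.data.push_back(2.0f);
      bottom.data.push_back(3.0f);
      bottom.data.push_back(4.0f);
      std::vector<MiniBlob*> bottoms(1, &bottom), tops(1, &top);
      MeanLossLayer layer;
      layer.SetUp(bottoms, &tops);
      layer.Forward(bottoms, &tops);
      // Loss layers default to a weight of 1, mirroring LossLayer::LayerSetUp.
      std::vector<float> loss_weights(1, 1.0f);
      std::cout << "total loss = " << TotalLoss(tops, loss_weights) << std::endl;
      return 0;  // prints: total loss = 2.5
    }

Compiling and running the sketch prints total loss = 2.5; in Caffe itself the analogous bookkeeping is done by Net over real Blob objects, as described above.
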
Diffstat (limited to 'src/caffe')
-rw-r--r--  src/caffe/layers/accuracy_layer.cpp | 7
-rw-r--r--  src/caffe/layers/argmax_layer.cpp | 6
-rw-r--r--  src/caffe/layers/bnll_layer.cpp | 3
-rw-r--r--  src/caffe/layers/bnll_layer.cu | 3
-rw-r--r--  src/caffe/layers/concat_layer.cpp | 10
-rw-r--r--  src/caffe/layers/concat_layer.cu | 3
-rw-r--r--  src/caffe/layers/conv_layer.cpp | 6
-rw-r--r--  src/caffe/layers/conv_layer.cu | 3
-rw-r--r--  src/caffe/layers/data_layer.cpp | 6
-rw-r--r--  src/caffe/layers/data_layer.cu | 3
-rw-r--r--  src/caffe/layers/dropout_layer.cpp | 7
-rw-r--r--  src/caffe/layers/dropout_layer.cu | 3
-rw-r--r--  src/caffe/layers/dummy_data_layer.cpp | 7
-rw-r--r--  src/caffe/layers/eltwise_layer.cpp | 6
-rw-r--r--  src/caffe/layers/eltwise_layer.cu | 3
-rw-r--r--  src/caffe/layers/euclidean_loss_layer.cpp | 13
-rw-r--r--  src/caffe/layers/euclidean_loss_layer.cu | 10
-rw-r--r--  src/caffe/layers/flatten_layer.cpp | 6
-rw-r--r--  src/caffe/layers/flatten_layer.cu | 3
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cpp | 6
-rw-r--r--  src/caffe/layers/hdf5_data_layer.cu | 3
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cpp | 3
-rw-r--r--  src/caffe/layers/hdf5_output_layer.cu | 3
-rw-r--r--  src/caffe/layers/hinge_loss_layer.cpp | 14
-rw-r--r--  src/caffe/layers/im2col_layer.cpp | 6
-rw-r--r--  src/caffe/layers/im2col_layer.cu | 3
-rw-r--r--  src/caffe/layers/image_data_layer.cpp | 8
-rw-r--r--  src/caffe/layers/image_data_layer.cu | 3
-rw-r--r--  src/caffe/layers/infogain_loss_layer.cpp | 16
-rw-r--r--  src/caffe/layers/inner_product_layer.cpp | 6
-rw-r--r--  src/caffe/layers/inner_product_layer.cu | 3
-rw-r--r--  src/caffe/layers/loss_layer.cpp | 11
-rw-r--r--  src/caffe/layers/lrn_layer.cpp | 19
-rw-r--r--  src/caffe/layers/lrn_layer.cu | 12
-rw-r--r--  src/caffe/layers/memory_data_layer.cpp | 6
-rw-r--r--  src/caffe/layers/multinomial_logistic_loss_layer.cpp | 13
-rw-r--r--  src/caffe/layers/mvn_layer.cpp | 7
-rw-r--r--  src/caffe/layers/mvn_layer.cu | 4
-rw-r--r--  src/caffe/layers/neuron_layer.cpp | 6
-rw-r--r--  src/caffe/layers/pooling_layer.cpp | 15
-rw-r--r--  src/caffe/layers/pooling_layer.cu | 3
-rw-r--r--  src/caffe/layers/power_layer.cpp | 8
-rw-r--r--  src/caffe/layers/power_layer.cu | 4
-rw-r--r--  src/caffe/layers/relu_layer.cpp | 3
-rw-r--r--  src/caffe/layers/relu_layer.cu | 3
-rw-r--r--  src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp | 13
-rw-r--r--  src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu | 10
-rw-r--r--  src/caffe/layers/sigmoid_layer.cpp | 3
-rw-r--r--  src/caffe/layers/sigmoid_layer.cu | 3
-rw-r--r--  src/caffe/layers/slice_layer.cpp | 6
-rw-r--r--  src/caffe/layers/slice_layer.cu | 3
-rw-r--r--  src/caffe/layers/softmax_layer.cpp | 6
-rw-r--r--  src/caffe/layers/softmax_layer.cu | 3
-rw-r--r--  src/caffe/layers/softmax_loss_layer.cpp | 28
-rw-r--r--  src/caffe/layers/softmax_loss_layer.cu | 5
-rw-r--r--  src/caffe/layers/split_layer.cpp | 6
-rw-r--r--  src/caffe/layers/split_layer.cu | 3
-rw-r--r--  src/caffe/layers/tanh_layer.cpp | 3
-rw-r--r--  src/caffe/layers/tanh_layer.cu | 3
-rw-r--r--  src/caffe/layers/threshold_layer.cpp | 7
-rw-r--r--  src/caffe/layers/threshold_layer.cu | 4
-rw-r--r--  src/caffe/layers/window_data_layer.cpp | 8
-rw-r--r--  src/caffe/layers/window_data_layer.cu | 3
-rw-r--r--  src/caffe/net.cpp | 32
-rw-r--r--  src/caffe/test/test_euclidean_loss_layer.cpp | 42
-rw-r--r--  src/caffe/test/test_hinge_loss_layer.cpp | 21
-rw-r--r--  src/caffe/test/test_infogain_loss_layer.cpp | 10
-rw-r--r--  src/caffe/test/test_multinomial_logistic_loss_layer.cpp | 10
-rw-r--r--  src/caffe/test/test_net.cpp | 143
-rw-r--r--  src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp | 16
-rw-r--r--  src/caffe/test/test_softmax_with_loss_layer.cpp | 12
-rw-r--r--  src/caffe/test/test_split_layer.cpp | 1
72 files changed, 308 insertions, 392 deletions
diff --git a/src/caffe/layers/accuracy_layer.cpp b/src/caffe/layers/accuracy_layer.cpp
index 76889d8b..062e9271 100644
--- a/src/caffe/layers/accuracy_layer.cpp
+++ b/src/caffe/layers/accuracy_layer.cpp
@@ -11,9 +11,8 @@
namespace caffe {
template <typename Dtype>
-void AccuracyLayer<Dtype>::SetUp(
+void AccuracyLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
top_k_ = this->layer_param_.accuracy_param().top_k();
CHECK_EQ(bottom[0]->num(), bottom[1]->num())
<< "The data and label should have the same number.";
@@ -26,7 +25,7 @@ void AccuracyLayer<Dtype>::SetUp(
}
template <typename Dtype>
-Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Dtype accuracy = 0;
const Dtype* bottom_data = bottom[0]->cpu_data();
@@ -56,9 +55,7 @@ Dtype AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
// LOG(INFO) << "Accuracy: " << accuracy;
(*top)[0]->mutable_cpu_data()[0] = accuracy / num;
-
// Accuracy layer should not be used as a loss function.
- return Dtype(0);
}
INSTANTIATE_CLASS(AccuracyLayer);
diff --git a/src/caffe/layers/argmax_layer.cpp b/src/caffe/layers/argmax_layer.cpp
index b2ef91ea..4b67f24c 100644
--- a/src/caffe/layers/argmax_layer.cpp
+++ b/src/caffe/layers/argmax_layer.cpp
@@ -9,9 +9,8 @@
namespace caffe {
template <typename Dtype>
-void ArgMaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void ArgMaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
out_max_val_ = this->layer_param_.argmax_param().out_max_val();
top_k_ = this->layer_param_.argmax_param().top_k();
CHECK_GE(top_k_, 1) << " top k must not be less than 1.";
@@ -27,7 +26,7 @@ void ArgMaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -51,7 +50,6 @@ Dtype ArgMaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
}
}
}
- return Dtype(0);
}
INSTANTIATE_CLASS(ArgMaxLayer);
diff --git a/src/caffe/layers/bnll_layer.cpp b/src/caffe/layers/bnll_layer.cpp
index 4cb85203..ef98326a 100644
--- a/src/caffe/layers/bnll_layer.cpp
+++ b/src/caffe/layers/bnll_layer.cpp
@@ -9,7 +9,7 @@ namespace caffe {
const float kBNLL_THRESHOLD = 50.;
template <typename Dtype>
-Dtype BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -19,7 +19,6 @@ Dtype BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
bottom_data[i] + log(1. + exp(-bottom_data[i])) :
log(1. + exp(bottom_data[i]));
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/bnll_layer.cu b/src/caffe/layers/bnll_layer.cu
index 9895a061..b940133b 100644
--- a/src/caffe/layers/bnll_layer.cu
+++ b/src/caffe/layers/bnll_layer.cu
@@ -18,7 +18,7 @@ __global__ void BNLLForward(const int n, const Dtype* in, Dtype* out) {
}
template <typename Dtype>
-Dtype BNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void BNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -27,7 +27,6 @@ Dtype BNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
BNLLForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data);
CUDA_POST_KERNEL_CHECK;
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/concat_layer.cpp b/src/caffe/layers/concat_layer.cpp
index b76d4b2c..73d28b17 100644
--- a/src/caffe/layers/concat_layer.cpp
+++ b/src/caffe/layers/concat_layer.cpp
@@ -7,9 +7,8 @@
namespace caffe {
template <typename Dtype>
-void ConcatLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void ConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
concat_dim_ = this->layer_param_.concat_param().concat_dim();
CHECK_GE(concat_dim_, 0) <<
"concat_dim should be >= 0";
@@ -39,7 +38,7 @@ void ConcatLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Dtype* top_data = (*top)[0]->mutable_cpu_data();
if (concat_dim_== 0) {
@@ -61,9 +60,8 @@ Dtype ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
top_data+(*top)[0]->offset(n, offset_channel));
}
offset_channel += bottom[i]->channels();
- } // concat_dim_ is guaranteed to be 0 or 1 by SetUp.
+ } // concat_dim_ is guaranteed to be 0 or 1 by LayerSetUp.
}
- return Dtype(0.);
}
template <typename Dtype>
@@ -95,7 +93,7 @@ void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
}
offset_channel += blob->channels();
}
- } // concat_dim_ is guaranteed to be 0 or 1 by SetUp.
+ } // concat_dim_ is guaranteed to be 0 or 1 by LayerSetUp.
}
#ifdef CPU_ONLY
diff --git a/src/caffe/layers/concat_layer.cu b/src/caffe/layers/concat_layer.cu
index aea8b77e..99c55da2 100644
--- a/src/caffe/layers/concat_layer.cu
+++ b/src/caffe/layers/concat_layer.cu
@@ -7,7 +7,7 @@
namespace caffe {
template <typename Dtype>
-Dtype ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Dtype* top_data = (*top)[0]->mutable_gpu_data();
if (concat_dim_ == 0) {
@@ -34,7 +34,6 @@ Dtype ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
LOG(FATAL) << "concat_dim along dim" << concat_dim_ <<
" not implemented yet";
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp
index df3e31ba..1a1248f3 100644
--- a/src/caffe/layers/conv_layer.cpp
+++ b/src/caffe/layers/conv_layer.cpp
@@ -9,9 +9,8 @@
namespace caffe {
template <typename Dtype>
-void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
ConvolutionParameter conv_param = this->layer_param_.convolution_param();
CHECK(!conv_param.has_kernel_size() !=
!(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
@@ -117,7 +116,7 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
-Dtype ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->cpu_data();
@@ -147,7 +146,6 @@ Dtype ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
}
}
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu
index 04ae1393..f7f393ba 100644
--- a/src/caffe/layers/conv_layer.cu
+++ b/src/caffe/layers/conv_layer.cu
@@ -9,7 +9,7 @@
namespace caffe {
template <typename Dtype>
-Dtype ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
for (int i = 0; i < bottom.size(); ++i) {
const Dtype* bottom_data = bottom[i]->gpu_data();
@@ -39,7 +39,6 @@ Dtype ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
}
}
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp
index 8f17c45e..c2b0c73a 100644
--- a/src/caffe/layers/data_layer.cpp
+++ b/src/caffe/layers/data_layer.cpp
@@ -160,9 +160,8 @@ DataLayer<Dtype>::~DataLayer<Dtype>() {
}
template <typename Dtype>
-void DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
if (top->size() == 1) {
output_labels_ = false;
} else {
@@ -332,7 +331,7 @@ unsigned int DataLayer<Dtype>::PrefetchRand() {
}
template <typename Dtype>
-Dtype DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
JoinPrefetchThread();
@@ -345,7 +344,6 @@ Dtype DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
}
// Start a new prefetch thread
CreatePrefetchThread();
- return Dtype(0.);
}
#ifdef CPU_ONLY
diff --git a/src/caffe/layers/data_layer.cu b/src/caffe/layers/data_layer.cu
index 2ae1a640..467b146f 100644
--- a/src/caffe/layers/data_layer.cu
+++ b/src/caffe/layers/data_layer.cu
@@ -12,7 +12,7 @@
namespace caffe {
template <typename Dtype>
-Dtype DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
JoinPrefetchThread();
@@ -25,7 +25,6 @@ Dtype DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
}
// Start a new prefetch thread
CreatePrefetchThread();
- return Dtype(0.);
}
INSTANTIATE_CLASS(DataLayer);
diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp
index 0621b56e..52537d1a 100644
--- a/src/caffe/layers/dropout_layer.cpp
+++ b/src/caffe/layers/dropout_layer.cpp
@@ -11,9 +11,9 @@
namespace caffe {
template <typename Dtype>
-void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void DropoutLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- NeuronLayer<Dtype>::SetUp(bottom, top);
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
// Set up the cache for random number generation
rand_vec_.Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
@@ -25,7 +25,7 @@ void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -40,7 +40,6 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
} else {
caffe_copy(bottom[0]->count(), bottom_data, top_data);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu
index 9bcd687b..9756c862 100644
--- a/src/caffe/layers/dropout_layer.cu
+++ b/src/caffe/layers/dropout_layer.cu
@@ -21,7 +21,7 @@ __global__ void DropoutForward(const int n, const Dtype* in,
}
template <typename Dtype>
-Dtype DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -38,7 +38,6 @@ Dtype DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
} else {
caffe_copy(count, bottom_data, top_data);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/dummy_data_layer.cpp b/src/caffe/layers/dummy_data_layer.cpp
index 98b437ee..883f2528 100644
--- a/src/caffe/layers/dummy_data_layer.cpp
+++ b/src/caffe/layers/dummy_data_layer.cpp
@@ -7,7 +7,7 @@
namespace caffe {
template <typename Dtype>
-void DummyDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void DummyDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const int num_top = top->size();
const DummyDataParameter& param = this->layer_param_.dummy_data_param();
@@ -32,7 +32,7 @@ void DummyDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
// If refill_[i] is false, Forward does nothing for Blob i. We use this to
// avoid wastefully refilling "constant" Blobs in every forward pass.
// We first fill refill_ in with the INVERSE of its final values.
- // The first time we run Forward from the SetUp method, we'll fill only the
+ // The first time we run Forward from the LayerSetUp method, we'll fill only
// Blobs for which refill_ is normally false. These Blobs will never be
// filled again.
refill_.clear();
@@ -82,7 +82,7 @@ void DummyDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype DummyDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void DummyDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
for (int i = 0; i < top->size(); ++i) {
const int filler_id = (fillers_.size() > 1) ? i : 0;
@@ -90,7 +90,6 @@ Dtype DummyDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
fillers_[filler_id]->Fill((*top)[i]);
}
}
- return Dtype(0.);
}
INSTANTIATE_CLASS(DummyDataLayer);
diff --git a/src/caffe/layers/eltwise_layer.cpp b/src/caffe/layers/eltwise_layer.cpp
index 8085b464..ec6a46ff 100644
--- a/src/caffe/layers/eltwise_layer.cpp
+++ b/src/caffe/layers/eltwise_layer.cpp
@@ -7,9 +7,8 @@
namespace caffe {
template <typename Dtype>
-void EltwiseLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void EltwiseLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
CHECK(this->layer_param().eltwise_param().coeff_size() == 0
|| this->layer_param().eltwise_param().coeff_size() == bottom.size()) <<
"Eltwise Layer takes one coefficient per bottom blob.";
@@ -39,7 +38,7 @@ void EltwiseLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype EltwiseLayer<Dtype>::Forward_cpu(
+void EltwiseLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
const int count = (*top)[0]->count();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -60,7 +59,6 @@ Dtype EltwiseLayer<Dtype>::Forward_cpu(
default:
LOG(FATAL) << "Unknown elementwise operation.";
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/eltwise_layer.cu b/src/caffe/layers/eltwise_layer.cu
index eec8857c..4b38949d 100644
--- a/src/caffe/layers/eltwise_layer.cu
+++ b/src/caffe/layers/eltwise_layer.cu
@@ -7,7 +7,7 @@
namespace caffe {
template <typename Dtype>
-Dtype EltwiseLayer<Dtype>::Forward_gpu(
+void EltwiseLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
const int count = (*top)[0]->count();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -29,7 +29,6 @@ Dtype EltwiseLayer<Dtype>::Forward_gpu(
default:
LOG(FATAL) << "Unknown elementwise operation.";
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/euclidean_loss_layer.cpp b/src/caffe/layers/euclidean_loss_layer.cpp
index 17180d40..be83601f 100644
--- a/src/caffe/layers/euclidean_loss_layer.cpp
+++ b/src/caffe/layers/euclidean_loss_layer.cpp
@@ -8,8 +8,9 @@
namespace caffe {
template <typename Dtype>
-void EuclideanLossLayer<Dtype>::FurtherSetUp(
+void EuclideanLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ LossLayer<Dtype>::LayerSetUp(bottom, top);
CHECK_EQ(bottom[0]->channels(), bottom[1]->channels());
CHECK_EQ(bottom[0]->height(), bottom[1]->height());
CHECK_EQ(bottom[0]->width(), bottom[1]->width());
@@ -18,7 +19,7 @@ void EuclideanLossLayer<Dtype>::FurtherSetUp(
}
template <typename Dtype>
-Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
int count = bottom[0]->count();
caffe_sub(
@@ -28,10 +29,7 @@ Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
diff_.mutable_cpu_data());
Dtype dot = caffe_cpu_dot(count, diff_.cpu_data(), diff_.cpu_data());
Dtype loss = dot / bottom[0]->num() / Dtype(2);
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss;
- }
- return loss;
+ (*top)[0]->mutable_cpu_data()[0] = loss;
}
template <typename Dtype>
@@ -40,9 +38,10 @@ void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
+ const Dtype alpha = sign * top[0]->cpu_diff()[0] / (*bottom)[i]->num();
caffe_cpu_axpby(
(*bottom)[i]->count(), // count
- sign / (*bottom)[i]->num(), // alpha
+ alpha, // alpha
diff_.cpu_data(), // a
Dtype(0), // beta
(*bottom)[i]->mutable_cpu_diff()); // b
diff --git a/src/caffe/layers/euclidean_loss_layer.cu b/src/caffe/layers/euclidean_loss_layer.cu
index f4dfd0b5..70b1b9ee 100644
--- a/src/caffe/layers/euclidean_loss_layer.cu
+++ b/src/caffe/layers/euclidean_loss_layer.cu
@@ -8,7 +8,7 @@
namespace caffe {
template <typename Dtype>
-Dtype EuclideanLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void EuclideanLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
int count = bottom[0]->count();
caffe_gpu_sub(
@@ -19,10 +19,7 @@ Dtype EuclideanLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
Dtype dot;
caffe_gpu_dot(count, diff_.gpu_data(), diff_.gpu_data(), &dot);
Dtype loss = dot / bottom[0]->num() / Dtype(2);
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss;
- }
- return loss;
+ (*top)[0]->mutable_cpu_data()[0] = loss;
}
template <typename Dtype>
@@ -31,9 +28,10 @@ void EuclideanLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
+ const Dtype alpha = sign * top[0]->cpu_diff()[0] / (*bottom)[i]->num();
caffe_gpu_axpby(
(*bottom)[i]->count(), // count
- sign / (*bottom)[i]->num(), // alpha
+ alpha, // alpha
diff_.gpu_data(), // a
Dtype(0), // beta
(*bottom)[i]->mutable_gpu_diff()); // b
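
The alpha introduced in both Euclidean loss backward passes above is just the chain rule with the new loss weight: the net writes the scalar weight w into top[0]->cpu_diff()[0] and diff_ holds a_n - b_n, so (sketching with ad-hoc symbols w, N, a_n, b_n) the layer's contribution to the weighted objective and its gradients are

    E = \frac{w}{2N} \sum_{n=1}^{N} \lVert a_n - b_n \rVert_2^2,
    \qquad
    \frac{\partial E}{\partial a_n} = \frac{w}{N} (a_n - b_n),
    \qquad
    \frac{\partial E}{\partial b_n} = -\frac{w}{N} (a_n - b_n),

i.e. alpha = sign * top[0]->cpu_diff()[0] / num applied to diff_, which reduces to the old sign / num when w = 1.
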
diff --git a/src/caffe/layers/flatten_layer.cpp b/src/caffe/layers/flatten_layer.cpp
index 81a506a8..8c1fc74e 100644
--- a/src/caffe/layers/flatten_layer.cpp
+++ b/src/caffe/layers/flatten_layer.cpp
@@ -7,9 +7,8 @@
namespace caffe {
template <typename Dtype>
-void FlattenLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void FlattenLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
int channels_out = bottom[0]->channels() * bottom[0]->height()
* bottom[0]->width();
(*top)[0]->Reshape(bottom[0]->num(), channels_out, 1, 1);
@@ -19,10 +18,9 @@ void FlattenLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype FlattenLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void FlattenLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
(*top)[0]->ShareData(*bottom[0]);
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/flatten_layer.cu b/src/caffe/layers/flatten_layer.cu
index 7233afb3..ff23f523 100644
--- a/src/caffe/layers/flatten_layer.cu
+++ b/src/caffe/layers/flatten_layer.cu
@@ -7,10 +7,9 @@
namespace caffe {
template <typename Dtype>
-Dtype FlattenLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void FlattenLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
(*top)[0]->ShareData(*bottom[0]);
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index 938d8435..1f2a8358 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -50,9 +50,8 @@ void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
}
template <typename Dtype>
-void HDF5DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void HDF5DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
// Read the source to parse the filenames.
const string& source = this->layer_param_.hdf5_data_param().source();
LOG(INFO) << "Loading filename from " << source;
@@ -85,7 +84,7 @@ void HDF5DataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const int batch_size = this->layer_param_.hdf5_data_param().batch_size();
const int data_count = (*top)[0]->count() / (*top)[0]->num();
@@ -109,7 +108,6 @@ Dtype HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
&label_blob_.cpu_data()[current_row_ * label_data_count],
&(*top)[1]->mutable_cpu_data()[i * label_data_count]);
}
- return Dtype(0.);
}
#ifdef CPU_ONLY
diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu
index 1f682d57..79cc536e 100644
--- a/src/caffe/layers/hdf5_data_layer.cu
+++ b/src/caffe/layers/hdf5_data_layer.cu
@@ -17,7 +17,7 @@ TODO:
namespace caffe {
template <typename Dtype>
-Dtype HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const int batch_size = this->layer_param_.hdf5_data_param().batch_size();
const int data_count = (*top)[0]->count() / (*top)[0]->num();
@@ -44,7 +44,6 @@ Dtype HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
&label_blob_.cpu_data()[current_row_ * label_data_count],
&(*top)[1]->mutable_gpu_data()[i * label_data_count]);
}
- return Dtype(0.);
}
INSTANTIATE_CLASS(HDF5DataLayer);
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
index 0d7590b1..3cdbbb31 100644
--- a/src/caffe/layers/hdf5_output_layer.cpp
+++ b/src/caffe/layers/hdf5_output_layer.cpp
@@ -39,7 +39,7 @@ void HDF5OutputLayer<Dtype>::SaveBlobs() {
}
template <typename Dtype>
-Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_GE(bottom.size(), 2);
CHECK_EQ(bottom[0]->num(), bottom[1]->num());
@@ -57,7 +57,6 @@ Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
&label_blob_.mutable_cpu_data()[i * label_datum_dim]);
}
SaveBlobs();
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu
index d2f20b3f..0813c02a 100644
--- a/src/caffe/layers/hdf5_output_layer.cu
+++ b/src/caffe/layers/hdf5_output_layer.cu
@@ -12,7 +12,7 @@
namespace caffe {
template <typename Dtype>
-Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_GE(bottom.size(), 2);
CHECK_EQ(bottom[0]->num(), bottom[1]->num());
@@ -30,7 +30,6 @@ Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
&label_blob_.mutable_cpu_data()[i * label_datum_dim]);
}
SaveBlobs();
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/hinge_loss_layer.cpp b/src/caffe/layers/hinge_loss_layer.cpp
index bc3a593c..8022aae2 100644
--- a/src/caffe/layers/hinge_loss_layer.cpp
+++ b/src/caffe/layers/hinge_loss_layer.cpp
@@ -11,7 +11,7 @@
namespace caffe {
template <typename Dtype>
-Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
@@ -30,11 +30,14 @@ Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
Dtype(0), 1 + bottom_diff[i * dim + j]);
}
}
+ Dtype* loss = (*top)[0]->mutable_cpu_data();
switch (this->layer_param_.hinge_loss_param().norm()) {
case HingeLossParameter_Norm_L1:
- return caffe_cpu_asum(count, bottom_diff) / num;
+ loss[0] = caffe_cpu_asum(count, bottom_diff) / num;
+ break;
case HingeLossParameter_Norm_L2:
- return caffe_cpu_dot(count, bottom_diff, bottom_diff) / num;
+ loss[0] = caffe_cpu_dot(count, bottom_diff, bottom_diff) / num;
+ break;
default:
LOG(FATAL) << "Unknown Norm";
}
@@ -58,13 +61,14 @@ void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
}
+ const Dtype loss_weight = top[0]->cpu_diff()[0];
switch (this->layer_param_.hinge_loss_param().norm()) {
case HingeLossParameter_Norm_L1:
caffe_cpu_sign(count, bottom_diff, bottom_diff);
- caffe_scal(count, Dtype(1. / num), bottom_diff);
+ caffe_scal(count, loss_weight / num, bottom_diff);
break;
case HingeLossParameter_Norm_L2:
- caffe_scal(count, Dtype(2. / num), bottom_diff);
+ caffe_scal(count, loss_weight * 2 / num, bottom_diff);
break;
default:
LOG(FATAL) << "Unknown Norm";
diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp
index 2dd74762..02f33f1c 100644
--- a/src/caffe/layers/im2col_layer.cpp
+++ b/src/caffe/layers/im2col_layer.cpp
@@ -8,9 +8,8 @@
namespace caffe {
template <typename Dtype>
-void Im2colLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void Im2colLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
ConvolutionParameter conv_param = this->layer_param_.convolution_param();
CHECK(!conv_param.has_kernel_size() !=
!(conv_param.has_kernel_h() && conv_param.has_kernel_w()))
@@ -56,7 +55,7 @@ void Im2colLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -65,7 +64,6 @@ Dtype Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
stride_h_, stride_w_, top_data + (*top)[0]->offset(n));
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu
index 6b4c7010..8df061d8 100644
--- a/src/caffe/layers/im2col_layer.cu
+++ b/src/caffe/layers/im2col_layer.cu
@@ -8,7 +8,7 @@
namespace caffe {
template <typename Dtype>
-Dtype Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -17,7 +17,6 @@ Dtype Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
stride_h_, stride_w_, top_data + (*top)[0]->offset(n));
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp
index a0f03a82..c72bf9c0 100644
--- a/src/caffe/layers/image_data_layer.cpp
+++ b/src/caffe/layers/image_data_layer.cpp
@@ -123,10 +123,9 @@ ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {
}
template <typename Dtype>
-void ImageDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
- const int new_height = this->layer_param_.image_data_param().new_height();
+ const int new_height = this->layer_param_.image_data_param().new_height();
const int new_width = this->layer_param_.image_data_param().new_width();
CHECK((new_height == 0 && new_width == 0) ||
(new_height > 0 && new_width > 0)) << "Current implementation requires "
@@ -252,7 +251,7 @@ unsigned int ImageDataLayer<Dtype>::PrefetchRand() {
}
template <typename Dtype>
-Dtype ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
JoinPrefetchThread();
@@ -263,7 +262,6 @@ Dtype ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
(*top)[1]->mutable_cpu_data());
// Start a new prefetch thread
CreatePrefetchThread();
- return Dtype(0.);
}
#ifdef CPU_ONLY
diff --git a/src/caffe/layers/image_data_layer.cu b/src/caffe/layers/image_data_layer.cu
index f61409cc..30a22100 100644
--- a/src/caffe/layers/image_data_layer.cu
+++ b/src/caffe/layers/image_data_layer.cu
@@ -9,7 +9,7 @@
namespace caffe {
template <typename Dtype>
-Dtype ImageDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void ImageDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
JoinPrefetchThread();
@@ -20,7 +20,6 @@ Dtype ImageDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
(*top)[1]->mutable_gpu_data());
// Start a new prefetch thread
CreatePrefetchThread();
- return Dtype(0.);
}
INSTANTIATE_CLASS(ImageDataLayer);
diff --git a/src/caffe/layers/infogain_loss_layer.cpp b/src/caffe/layers/infogain_loss_layer.cpp
index fa01116e..91dd8924 100644
--- a/src/caffe/layers/infogain_loss_layer.cpp
+++ b/src/caffe/layers/infogain_loss_layer.cpp
@@ -11,8 +11,9 @@
namespace caffe {
template <typename Dtype>
-void InfogainLossLayer<Dtype>::FurtherSetUp(
+void InfogainLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ LossLayer<Dtype>::LayerSetUp(bottom, top);
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
@@ -38,7 +39,7 @@ void InfogainLossLayer<Dtype>::FurtherSetUp(
template <typename Dtype>
-Dtype InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* bottom_label = bottom[1]->cpu_data();
@@ -58,11 +59,7 @@ Dtype InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
loss -= infogain_mat[label * dim + j] * log(prob);
}
}
- loss /= num;
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss;
- }
- return loss;
+ (*top)[0]->mutable_cpu_data()[0] = loss / num;
}
template <typename Dtype>
@@ -89,11 +86,12 @@ void InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
int num = (*bottom)[0]->num();
int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+ const Dtype scale = - top[0]->cpu_diff()[0] / num;
for (int i = 0; i < num; ++i) {
- int label = static_cast<int>(bottom_label[i]);
+ const int label = static_cast<int>(bottom_label[i]);
for (int j = 0; j < dim; ++j) {
Dtype prob = std::max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
- bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
+ bottom_diff[i * dim + j] = scale * infogain_mat[label * dim + j] / prob;
}
}
}
diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp
index a9e0f353..3ba0e1f2 100644
--- a/src/caffe/layers/inner_product_layer.cpp
+++ b/src/caffe/layers/inner_product_layer.cpp
@@ -10,9 +10,8 @@
namespace caffe {
template <typename Dtype>
-void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
const int num_output = this->layer_param_.inner_product_param().num_output();
bias_term_ = this->layer_param_.inner_product_param().bias_term();
// Figure out the dimensions
@@ -52,7 +51,7 @@ void InnerProductLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -64,7 +63,6 @@ Dtype InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
bias_multiplier_.cpu_data(),
this->blobs_[1]->cpu_data(), (Dtype)1., top_data);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu
index e0210720..3a0d4388 100644
--- a/src/caffe/layers/inner_product_layer.cu
+++ b/src/caffe/layers/inner_product_layer.cu
@@ -10,7 +10,7 @@
namespace caffe {
template <typename Dtype>
-Dtype InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -22,7 +22,6 @@ Dtype InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
bias_multiplier_.gpu_data(),
this->blobs_[1]->gpu_data(), (Dtype)1., top_data);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/loss_layer.cpp b/src/caffe/layers/loss_layer.cpp
index 48665221..89d8c91e 100644
--- a/src/caffe/layers/loss_layer.cpp
+++ b/src/caffe/layers/loss_layer.cpp
@@ -11,16 +11,15 @@
namespace caffe {
template <typename Dtype>
-void LossLayer<Dtype>::SetUp(
+void LossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
CHECK_EQ(bottom[0]->num(), bottom[1]->num())
<< "The data and label should have the same number.";
- if (top->size() == 1) {
- // Layers should copy the loss in the top blob
- (*top)[0]->Reshape(1, 1, 1, 1);
+ (*top)[0]->Reshape(1, 1, 1, 1);
+ // LossLayers have a non-zero (1) loss by default.
+ if (this->layer_param_.loss_weight_size() == 0) {
+ this->layer_param_.add_loss_weight(Dtype(1));
}
- FurtherSetUp(bottom, top);
}
INSTANTIATE_CLASS(LossLayer);
diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp
index e77f6857..c76ca95d 100644
--- a/src/caffe/layers/lrn_layer.cpp
+++ b/src/caffe/layers/lrn_layer.cpp
@@ -7,9 +7,8 @@
namespace caffe {
template <typename Dtype>
-void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void LRNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
num_ = bottom[0]->num();
channels_ = bottom[0]->channels();
height_ = bottom[0]->height();
@@ -96,21 +95,22 @@ void LRNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype LRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void LRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
switch (this->layer_param_.lrn_param().norm_region()) {
case LRNParameter_NormRegion_ACROSS_CHANNELS:
- return CrossChannelForward_cpu(bottom, top);
+ CrossChannelForward_cpu(bottom, top);
+ break;
case LRNParameter_NormRegion_WITHIN_CHANNEL:
- return WithinChannelForward(bottom, top);
+ WithinChannelForward(bottom, top);
+ break;
default:
LOG(FATAL) << "Unknown normalization region.";
- return Dtype(0);
}
}
template <typename Dtype>
-Dtype LRNLayer<Dtype>::CrossChannelForward_cpu(
+void LRNLayer<Dtype>::CrossChannelForward_cpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -154,19 +154,16 @@ Dtype LRNLayer<Dtype>::CrossChannelForward_cpu(
// In the end, compute output
caffe_powx<Dtype>(scale_.count(), scale_data, -beta_, top_data);
caffe_mul<Dtype>(scale_.count(), top_data, bottom_data, top_data);
-
- return Dtype(0.);
}
template <typename Dtype>
-Dtype LRNLayer<Dtype>::WithinChannelForward(
+void LRNLayer<Dtype>::WithinChannelForward(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
split_layer_->Forward(bottom, &split_top_vec_);
square_layer_->Forward(square_bottom_vec_, &square_top_vec_);
pool_layer_->Forward(square_top_vec_, &pool_top_vec_);
power_layer_->Forward(pool_top_vec_, &power_top_vec_);
product_layer_->Forward(product_bottom_vec_, top);
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu
index eee12e66..d6cb23bf 100644
--- a/src/caffe/layers/lrn_layer.cu
+++ b/src/caffe/layers/lrn_layer.cu
@@ -54,16 +54,17 @@ __global__ void LRNFillScale(const int nthreads, const Dtype* in,
template <typename Dtype>
-Dtype LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
switch (this->layer_param_.lrn_param().norm_region()) {
case LRNParameter_NormRegion_ACROSS_CHANNELS:
- return CrossChannelForward_gpu(bottom, top);
+ CrossChannelForward_gpu(bottom, top);
+ break;
case LRNParameter_NormRegion_WITHIN_CHANNEL:
- return WithinChannelForward(bottom, top);
+ WithinChannelForward(bottom, top);
+ break;
default:
LOG(FATAL) << "Unknown normalization region.";
- return Dtype(0);
}
}
@@ -77,7 +78,7 @@ __global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
}
template <typename Dtype>
-Dtype LRNLayer<Dtype>::CrossChannelForward_gpu(
+void LRNLayer<Dtype>::CrossChannelForward_gpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
// First, compute scale
const Dtype* bottom_data = bottom[0]->gpu_data();
@@ -96,7 +97,6 @@ Dtype LRNLayer<Dtype>::CrossChannelForward_gpu(
LRNComputeOutput<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
n_threads, bottom_data, scale_data, -beta_, top_data);
CUDA_POST_KERNEL_CHECK;
- return Dtype(0.);
}
diff --git a/src/caffe/layers/memory_data_layer.cpp b/src/caffe/layers/memory_data_layer.cpp
index d1717fd4..fda92976 100644
--- a/src/caffe/layers/memory_data_layer.cpp
+++ b/src/caffe/layers/memory_data_layer.cpp
@@ -6,9 +6,8 @@
namespace caffe {
template <typename Dtype>
-void MemoryDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void MemoryDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
batch_size_ = this->layer_param_.memory_data_param().batch_size();
datum_channels_ = this->layer_param_.memory_data_param().channels();
datum_height_ = this->layer_param_.memory_data_param().height();
@@ -34,13 +33,12 @@ void MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n) {
}
template <typename Dtype>
-Dtype MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";
(*top)[0]->set_cpu_data(data_ + pos_ * datum_size_);
(*top)[1]->set_cpu_data(labels_ + pos_);
pos_ = (pos_ + batch_size_) % n_;
- return Dtype(0.);
}
INSTANTIATE_CLASS(MemoryDataLayer);
diff --git a/src/caffe/layers/multinomial_logistic_loss_layer.cpp b/src/caffe/layers/multinomial_logistic_loss_layer.cpp
index a9c7de65..cf96bfe7 100644
--- a/src/caffe/layers/multinomial_logistic_loss_layer.cpp
+++ b/src/caffe/layers/multinomial_logistic_loss_layer.cpp
@@ -11,15 +11,16 @@
namespace caffe {
template <typename Dtype>
-void MultinomialLogisticLossLayer<Dtype>::FurtherSetUp(
+void MultinomialLogisticLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ LossLayer<Dtype>::LayerSetUp(bottom, top);
CHECK_EQ(bottom[1]->channels(), 1);
CHECK_EQ(bottom[1]->height(), 1);
CHECK_EQ(bottom[1]->width(), 1);
}
template <typename Dtype>
-Dtype MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
+void MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* bottom_label = bottom[1]->cpu_data();
@@ -32,10 +33,7 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
loss -= log(prob);
}
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss / num;
- }
- return loss / num;
+ (*top)[0]->mutable_cpu_data()[0] = loss / num;
}
template <typename Dtype>
@@ -53,11 +51,12 @@ void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
int num = (*bottom)[0]->num();
int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
caffe_set((*bottom)[0]->count(), Dtype(0), bottom_diff);
+ const Dtype scale = - top[0]->cpu_diff()[0] / num;
for (int i = 0; i < num; ++i) {
int label = static_cast<int>(bottom_label[i]);
Dtype prob = std::max(
bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
- bottom_diff[i * dim + label] = -1. / prob / num;
+ bottom_diff[i * dim + label] = scale / prob;
}
}
}
diff --git a/src/caffe/layers/mvn_layer.cpp b/src/caffe/layers/mvn_layer.cpp
index 30235b3f..4d90702f 100644
--- a/src/caffe/layers/mvn_layer.cpp
+++ b/src/caffe/layers/mvn_layer.cpp
@@ -8,9 +8,8 @@
namespace caffe {
template <typename Dtype>
-void MVNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void MVNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
mean_.Reshape(bottom[0]->num(), bottom[0]->channels(),
@@ -26,7 +25,7 @@ void MVNLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype MVNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void MVNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -85,8 +84,6 @@ Dtype MVNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
caffe_add(temp_.count(), bottom_data, temp_.cpu_data(), top_data);
}
-
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/mvn_layer.cu b/src/caffe/layers/mvn_layer.cu
index dd823984..2c02dfe1 100644
--- a/src/caffe/layers/mvn_layer.cu
+++ b/src/caffe/layers/mvn_layer.cu
@@ -8,7 +8,7 @@
namespace caffe {
template <typename Dtype>
-Dtype MVNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void MVNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -68,8 +68,6 @@ Dtype MVNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
caffe_gpu_add(temp_.count(), bottom_data, temp_.gpu_data(), top_data);
}
-
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/neuron_layer.cpp b/src/caffe/layers/neuron_layer.cpp
index 3343b26c..eff7948a 100644
--- a/src/caffe/layers/neuron_layer.cpp
+++ b/src/caffe/layers/neuron_layer.cpp
@@ -6,14 +6,12 @@
namespace caffe {
template <typename Dtype>
-void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void NeuronLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
// NeuronLayer allows in-place computations. If the computation is not
// in-place, we will need to initialize the top blob.
if ((*top)[0] != bottom[0]) {
- (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
- bottom[0]->height(), bottom[0]->width());
+ (*top)[0]->ReshapeLike(*bottom[0]);
}
}
diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp
index 30657b6c..9e77fa28 100644
--- a/src/caffe/layers/pooling_layer.cpp
+++ b/src/caffe/layers/pooling_layer.cpp
@@ -14,18 +14,8 @@ using std::min;
using std::max;
template <typename Dtype>
-void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- // Set the max number of top blobs before calling base Layer::SetUp.
- // If doing MAX pooling, we can optionally output an extra top Blob
- // for the mask. Otherwise, we only have one top Blob.
- if (this->layer_param_.pooling_param().pool() ==
- PoolingParameter_PoolMethod_MAX) {
- max_top_blobs_ = 2;
- } else {
- max_top_blobs_ = 1;
- }
- Layer<Dtype>::SetUp(bottom, top);
PoolingParameter pool_param = this->layer_param_.pooling_param();
CHECK(!pool_param.has_kernel_size() !=
!(pool_param.has_kernel_h() && pool_param.has_kernel_w()))
@@ -111,7 +101,7 @@ void PoolingLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
// TODO(Yangqing): Is there a faster way to do pooling in the channel-first
// case?
template <typename Dtype>
-Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -210,7 +200,6 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
default:
LOG(FATAL) << "Unknown pooling method.";
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu
index 58f1997c..e64128b8 100644
--- a/src/caffe/layers/pooling_layer.cu
+++ b/src/caffe/layers/pooling_layer.cu
@@ -151,7 +151,7 @@ __global__ void StoPoolForwardTest(const int nthreads,
template <typename Dtype>
-Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -206,7 +206,6 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
LOG(FATAL) << "Unknown pooling method.";
}
CUDA_POST_KERNEL_CHECK;
- return Dtype(0.);
}
diff --git a/src/caffe/layers/power_layer.cpp b/src/caffe/layers/power_layer.cpp
index 8b5d8d16..a332c4d2 100644
--- a/src/caffe/layers/power_layer.cpp
+++ b/src/caffe/layers/power_layer.cpp
@@ -8,9 +8,9 @@
namespace caffe {
template <typename Dtype>
-void PowerLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void PowerLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- NeuronLayer<Dtype>::SetUp(bottom, top);
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
power_ = this->layer_param_.power_param().power();
scale_ = this->layer_param_.power_param().scale();
shift_ = this->layer_param_.power_param().shift();
@@ -19,7 +19,7 @@ void PowerLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
// Compute y = (shift + scale * x)^power
template <typename Dtype>
-Dtype PowerLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void PowerLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Dtype* top_data = (*top)[0]->mutable_cpu_data();
const int count = bottom[0]->count();
@@ -27,7 +27,6 @@ Dtype PowerLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
if (diff_scale_ == Dtype(0)) {
Dtype value = (power_ == 0) ? Dtype(1) : pow(shift_, power_);
caffe_set(count, value, top_data);
- return Dtype(0);
}
const Dtype* bottom_data = bottom[0]->cpu_data();
caffe_copy(count, bottom_data, top_data);
@@ -40,7 +39,6 @@ Dtype PowerLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
if (power_ != Dtype(1)) {
caffe_powx(count, top_data, power_, top_data);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/power_layer.cu b/src/caffe/layers/power_layer.cu
index 0950b78b..eaf63c1f 100644
--- a/src/caffe/layers/power_layer.cu
+++ b/src/caffe/layers/power_layer.cu
@@ -8,7 +8,7 @@
namespace caffe {
template <typename Dtype>
-Dtype PowerLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void PowerLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
@@ -16,7 +16,6 @@ Dtype PowerLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
if (diff_scale_ == Dtype(0)) {
Dtype value = (power_ == 0) ? Dtype(1) : pow(shift_, power_);
caffe_gpu_set(count, value, top_data);
- return Dtype(0);
}
const Dtype* bottom_data = bottom[0]->gpu_data();
caffe_copy(count, bottom_data, top_data);
@@ -29,7 +28,6 @@ Dtype PowerLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
if (power_ != Dtype(1)) {
caffe_gpu_powx(count, top_data, power_, top_data);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/relu_layer.cpp b/src/caffe/layers/relu_layer.cpp
index fca10a5a..b50352f8 100644
--- a/src/caffe/layers/relu_layer.cpp
+++ b/src/caffe/layers/relu_layer.cpp
@@ -7,7 +7,7 @@
namespace caffe {
template <typename Dtype>
-Dtype ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -17,7 +17,6 @@ Dtype ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
top_data[i] = std::max(bottom_data[i], Dtype(0))
+ negative_slope * std::min(bottom_data[i], Dtype(0));
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu
index a74428bf..def2bbcd 100644
--- a/src/caffe/layers/relu_layer.cu
+++ b/src/caffe/layers/relu_layer.cu
@@ -15,7 +15,7 @@ __global__ void ReLUForward(const int n, const Dtype* in, Dtype* out,
}
template <typename Dtype>
-Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -30,7 +30,6 @@ Dtype ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// << " top_data: " << (unsigned long)top_data
// << " blocks: " << CAFFE_GET_BLOCKS(count)
// << " threads: " << CAFFE_CUDA_NUM_THREADS;
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
index 24ab6a85..6e440a82 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -9,8 +9,9 @@
namespace caffe {
template <typename Dtype>
-void SigmoidCrossEntropyLossLayer<Dtype>::FurtherSetUp(
+void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ LossLayer<Dtype>::LayerSetUp(bottom, top);
CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
"SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count.";
sigmoid_bottom_vec_.clear();
@@ -21,7 +22,7 @@ void SigmoidCrossEntropyLossLayer<Dtype>::FurtherSetUp(
}
template <typename Dtype>
-Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
+void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
// The forward pass computes the sigmoid outputs.
sigmoid_bottom_vec_[0] = bottom[0];
@@ -37,10 +38,7 @@ Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
}
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss / num;
- }
- return loss / num;
+ (*top)[0]->mutable_cpu_data()[0] = loss / num;
}
template <typename Dtype>
@@ -60,7 +58,8 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
caffe_sub(count, sigmoid_output_data, target, bottom_diff);
// Scale down gradient
- caffe_scal(count, Dtype(1) / num, bottom_diff);
+ const Dtype loss_weight = top[0]->cpu_diff()[0];
+ caffe_scal(count, loss_weight / num, bottom_diff);
}
}
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 0e4dab76..8d0fdc6f 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -9,7 +9,7 @@
namespace caffe {
template <typename Dtype>
-Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
+void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
// The forward pass computes the sigmoid outputs.
sigmoid_bottom_vec_[0] = bottom[0];
@@ -25,10 +25,7 @@ Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
}
- if (top->size() == 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss / num;
- }
- return loss / num;
+ (*top)[0]->mutable_cpu_data()[0] = loss / num;
}
template <typename Dtype>
@@ -49,7 +46,8 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
caffe_copy(count, sigmoid_output_data, bottom_diff);
caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
// Scale down gradient
- caffe_gpu_scal(count, Dtype(1) / num, bottom_diff);
+ const Dtype loss_weight = top[0]->cpu_diff()[0];
+ caffe_gpu_scal(count, loss_weight / num, bottom_diff);
}
}
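The two sigmoid cross-entropy hunks above establish the contract every loss layer in this diff now follows: Forward writes the scalar batch-averaged loss into its top blob's data unconditionally, and Backward reads the loss weight back out of that same blob's diff and folds it into the gradient scale. Below is a compact sketch of that contract as standalone helpers (WriteLoss and ScaleGradientByLossWeight are illustrative names, not Caffe API); it uses only the Blob accessors and the caffe_scal call already present in the hunks.

#include <vector>
#include "caffe/blob.hpp"
#include "caffe/util/math_functions.hpp"

// Forward side: store the averaged loss in the top blob's data, exactly as
// SigmoidCrossEntropyLossLayer::Forward_cpu now does without the old
// "if (top->size() == 1)" guard.
template <typename Dtype>
void WriteLoss(Dtype loss, int num, std::vector<caffe::Blob<Dtype>*>* top) {
  (*top)[0]->mutable_cpu_data()[0] = loss / num;
}

// Backward side: the coefficient stored in top[0]'s diff acts as the loss
// weight, so the gradient is scaled by loss_weight / num instead of the
// previous hard-coded 1 / num.
template <typename Dtype>
void ScaleGradientByLossWeight(const std::vector<caffe::Blob<Dtype>*>& top,
    int num, int count, Dtype* bottom_diff) {
  const Dtype loss_weight = top[0]->cpu_diff()[0];
  caffe::caffe_scal(count, loss_weight / Dtype(num), bottom_diff);
}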
diff --git a/src/caffe/layers/sigmoid_layer.cpp b/src/caffe/layers/sigmoid_layer.cpp
index 0f8b582d..d7bba7fb 100644
--- a/src/caffe/layers/sigmoid_layer.cpp
+++ b/src/caffe/layers/sigmoid_layer.cpp
@@ -13,7 +13,7 @@ inline Dtype sigmoid(Dtype x) {
}
template <typename Dtype>
-Dtype SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -21,7 +21,6 @@ Dtype SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
for (int i = 0; i < count; ++i) {
top_data[i] = sigmoid(bottom_data[i]);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/sigmoid_layer.cu b/src/caffe/layers/sigmoid_layer.cu
index 039796e1..e1ebb1f6 100644
--- a/src/caffe/layers/sigmoid_layer.cu
+++ b/src/caffe/layers/sigmoid_layer.cu
@@ -15,7 +15,7 @@ __global__ void SigmoidForward(const int n, const Dtype* in, Dtype* out) {
}
template <typename Dtype>
-Dtype SigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void SigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -29,7 +29,6 @@ Dtype SigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
// << " top_data: " << (unsigned long)top_data
// << " blocks: " << CAFFE_GET_BLOCKS(count)
// << " threads: " << CAFFE_CUDA_NUM_THREADS;
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/slice_layer.cpp b/src/caffe/layers/slice_layer.cpp
index e182837c..9fa12752 100644
--- a/src/caffe/layers/slice_layer.cpp
+++ b/src/caffe/layers/slice_layer.cpp
@@ -8,9 +8,8 @@
namespace caffe {
template <typename Dtype>
-void SliceLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void SliceLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
const SliceParameter& slice_param = this->layer_param_.slice_param();
slice_dim_ = slice_param.slice_dim();
CHECK_GE(slice_dim_, 0);
@@ -73,7 +72,7 @@ void SliceLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->mutable_cpu_data();
if (slice_dim_ == 0) {
@@ -98,7 +97,6 @@ Dtype SliceLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
offset_channel += blob->channels();
}
} // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/slice_layer.cu b/src/caffe/layers/slice_layer.cu
index 8e01131e..f64e5754 100644
--- a/src/caffe/layers/slice_layer.cu
+++ b/src/caffe/layers/slice_layer.cu
@@ -7,7 +7,7 @@
namespace caffe {
template <typename Dtype>
-Dtype SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->mutable_gpu_data();
if (slice_dim_ == 0) {
@@ -32,7 +32,6 @@ Dtype SliceLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
offset_channel += blob->channels();
}
} // slice_dim_ is guaranteed to be 0 or 1 by SetUp.
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
index 61990ed9..fa2ba17a 100644
--- a/src/caffe/layers/softmax_layer.cpp
+++ b/src/caffe/layers/softmax_layer.cpp
@@ -9,9 +9,8 @@
namespace caffe {
template <typename Dtype>
-void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void SoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
sum_multiplier_.Reshape(1, bottom[0]->channels(),
@@ -24,7 +23,7 @@ void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -52,7 +51,6 @@ Dtype SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
for (int i = 0; i < num; ++i) {
caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu
index 65b0e229..6b853099 100644
--- a/src/caffe/layers/softmax_layer.cu
+++ b/src/caffe/layers/softmax_layer.cu
@@ -39,7 +39,7 @@ __global__ void kernel_exp(const int num, const Dtype* data, Dtype* out) {
}
template <typename Dtype>
-Dtype SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -68,7 +68,6 @@ Dtype SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
kernel_softmax_div<Dtype><<<CAFFE_GET_BLOCKS(num * dim),
CAFFE_CUDA_NUM_THREADS>>>(
num, dim, scale_data, top_data);
- return Dtype(0);
}
// TODO(Yangqing): implement the GPU version of softmax.
diff --git a/src/caffe/layers/softmax_loss_layer.cpp b/src/caffe/layers/softmax_loss_layer.cpp
index 98cf14c4..0fa83ccd 100644
--- a/src/caffe/layers/softmax_loss_layer.cpp
+++ b/src/caffe/layers/softmax_loss_layer.cpp
@@ -9,26 +9,22 @@
namespace caffe {
template <typename Dtype>
-void SoftmaxWithLossLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
- vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
+void SoftmaxWithLossLayer<Dtype>::LayerSetUp(
+ const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+ LossLayer<Dtype>::LayerSetUp(bottom, top);
softmax_bottom_vec_.clear();
softmax_bottom_vec_.push_back(bottom[0]);
+ softmax_top_vec_.clear();
softmax_top_vec_.push_back(&prob_);
softmax_layer_->SetUp(softmax_bottom_vec_, &softmax_top_vec_);
- if (top->size() >= 1) {
- // softmax loss (averaged across batch)
- (*top)[0]->Reshape(1, 1, 1, 1);
- }
- if (top->size() == 2) {
+ if (top->size() >= 2) {
// softmax output
- (*top)[1]->Reshape(bottom[0]->num(), bottom[0]->channels(),
- bottom[0]->height(), bottom[0]->width());
+ (*top)[1]->ReshapeLike(*bottom[0]);
}
}
template <typename Dtype>
-Dtype SoftmaxWithLossLayer<Dtype>::Forward_cpu(
+void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
// The forward pass computes the softmax prob values.
softmax_bottom_vec_[0] = bottom[0];
@@ -42,13 +38,10 @@ Dtype SoftmaxWithLossLayer<Dtype>::Forward_cpu(
loss += -log(std::max(prob_data[i * dim + static_cast<int>(label[i])],
Dtype(FLT_MIN)));
}
- if (top->size() >= 1) {
- (*top)[0]->mutable_cpu_data()[0] = loss / num;
- }
+ (*top)[0]->mutable_cpu_data()[0] = loss / num;
if (top->size() == 2) {
(*top)[1]->ShareData(prob_);
}
- return loss / num;
}
template <typename Dtype>
@@ -69,8 +62,9 @@ void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
for (int i = 0; i < num; ++i) {
bottom_diff[i * dim + static_cast<int>(label[i])] -= 1;
}
- // Scale down gradient
- caffe_scal(prob_.count(), Dtype(1) / num, bottom_diff);
+ // Scale gradient
+ const Dtype loss_weight = top[0]->cpu_diff()[0];
+ caffe_scal(prob_.count(), loss_weight / num, bottom_diff);
}
}
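For orientation, here is what the new interface looks like from the prototxt side, written as a C++ proto string in the same style the net tests further below use. The loss_weight field and the optional second top are the two features the SoftmaxWithLoss hunks above handle; the layer and blob names in this fragment are invented, and the enum name SOFTMAX_LOSS follows the existing layer-type naming convention (EUCLIDEAN_LOSS, SIGMOID_CROSS_ENTROPY_LOSS) rather than being shown in this diff.

#include <string>

// Hypothetical prototxt fragment: a softmax loss whose contribution to the
// objective is scaled by 2.5.  Adding a second "top:" entry would expose the
// softmax probabilities as well, via the ShareData(prob_) branch above.
const std::string kSoftmaxLossProto =
    "layers: { "
    "  name: 'loss' "
    "  type: SOFTMAX_LOSS "
    "  bottom: 'fc_out' "
    "  bottom: 'label' "
    "  top: 'loss' "
    "  loss_weight: 2.5 "
    "} ";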
diff --git a/src/caffe/layers/softmax_loss_layer.cu b/src/caffe/layers/softmax_loss_layer.cu
index 32f3e670..9ef8dd23 100644
--- a/src/caffe/layers/softmax_loss_layer.cu
+++ b/src/caffe/layers/softmax_loss_layer.cu
@@ -9,10 +9,9 @@
namespace caffe {
template <typename Dtype>
-Dtype SoftmaxWithLossLayer<Dtype>::Forward_gpu(
+void SoftmaxWithLossLayer<Dtype>::Forward_gpu(
const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
- // The forward pass computes the softmax prob values.
- return Forward_cpu(bottom, top);
+ Forward_cpu(bottom, top);
}
template <typename Dtype>
diff --git a/src/caffe/layers/split_layer.cpp b/src/caffe/layers/split_layer.cpp
index 2786d3f7..c223d475 100644
--- a/src/caffe/layers/split_layer.cpp
+++ b/src/caffe/layers/split_layer.cpp
@@ -7,9 +7,8 @@
namespace caffe {
template <typename Dtype>
-void SplitLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void SplitLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
count_ = bottom[0]->count();
for (int i = 0; i < top->size(); ++i) {
// Allow the 0th top blob to be 'in-place', but no others.
@@ -25,12 +24,11 @@ void SplitLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
}
template <typename Dtype>
-Dtype SplitLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void SplitLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
for (int i = 0; i < top->size(); ++i) {
(*top)[i]->ShareData(*bottom[0]);
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/split_layer.cu b/src/caffe/layers/split_layer.cu
index 1cf15a79..2d2b3c2b 100644
--- a/src/caffe/layers/split_layer.cu
+++ b/src/caffe/layers/split_layer.cu
@@ -7,12 +7,11 @@
namespace caffe {
template <typename Dtype>
-Dtype SplitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void SplitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
for (int i = 0; i < top->size(); ++i) {
(*top)[i]->ShareData(*bottom[0]);
}
- return Dtype(0.);
}
template <typename Dtype>
diff --git a/src/caffe/layers/tanh_layer.cpp b/src/caffe/layers/tanh_layer.cpp
index 0c8be3fa..8dae0054 100644
--- a/src/caffe/layers/tanh_layer.cpp
+++ b/src/caffe/layers/tanh_layer.cpp
@@ -10,7 +10,7 @@
namespace caffe {
template <typename Dtype>
-Dtype TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -20,7 +20,6 @@ Dtype TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
exp2x = exp(2 * bottom_data[i]);
top_data[i] = (exp2x - Dtype(1)) / (exp2x + Dtype(1));
}
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/tanh_layer.cu b/src/caffe/layers/tanh_layer.cu
index b3daad1e..bdb7a949 100644
--- a/src/caffe/layers/tanh_layer.cu
+++ b/src/caffe/layers/tanh_layer.cu
@@ -18,7 +18,7 @@ __global__ void TanHForward(const int n, const Dtype* in, Dtype* out) {
}
template <typename Dtype>
-Dtype TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -27,7 +27,6 @@ Dtype TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
TanHForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data);
CUDA_POST_KERNEL_CHECK;
- return Dtype(0);
}
template <typename Dtype>
diff --git a/src/caffe/layers/threshold_layer.cpp b/src/caffe/layers/threshold_layer.cpp
index c9323560..180ea6a3 100644
--- a/src/caffe/layers/threshold_layer.cpp
+++ b/src/caffe/layers/threshold_layer.cpp
@@ -7,14 +7,14 @@
namespace caffe {
template <typename Dtype>
-void ThresholdLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void ThresholdLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- NeuronLayer<Dtype>::SetUp(bottom, top);
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
threshold_ = this->layer_param_.threshold_param().threshold();
}
template <typename Dtype>
-Dtype ThresholdLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void ThresholdLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
@@ -22,7 +22,6 @@ Dtype ThresholdLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
for (int i = 0; i < count; ++i) {
top_data[i] = (bottom_data[i] > threshold_) ? Dtype(1) : Dtype(0);
}
- return Dtype(0);
}
#ifdef CPU_ONLY
diff --git a/src/caffe/layers/threshold_layer.cu b/src/caffe/layers/threshold_layer.cu
index 398d56e8..93430815 100644
--- a/src/caffe/layers/threshold_layer.cu
+++ b/src/caffe/layers/threshold_layer.cu
@@ -15,7 +15,7 @@ __global__ void ThresholdForward(const int n, const Dtype threshold,
}
template <typename Dtype>
-Dtype ThresholdLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void ThresholdLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
@@ -24,8 +24,6 @@ Dtype ThresholdLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
ThresholdForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, threshold_, bottom_data, top_data);
CUDA_POST_KERNEL_CHECK;
-
- return Dtype(0);
}
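The threshold-layer hunk just above is the clearest small example of the SetUp refactor running through this diff: layers stop overriding SetUp (and stop calling Layer<Dtype>::SetUp themselves) and instead override LayerSetUp, chaining to their parent's LayerSetUp when they extend another layer. From the caller's point of view nothing changes, as the sketch below illustrates. It assumes ThresholdLayer is reachable through caffe/vision_layers.hpp (the header the tests below include) and that the base-class SetUp now performs the shared checks before delegating to LayerSetUp, which is inferred from these hunks rather than shown in them.

#include <vector>
#include "caffe/blob.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/vision_layers.hpp"

int main() {
  using caffe::Blob;
  // One bottom/top pair, as NeuronLayer-derived layers expect.
  Blob<float> bottom_blob(2, 3, 4, 5);
  Blob<float> top_blob;
  std::vector<Blob<float>*> bottom(1, &bottom_blob);
  std::vector<Blob<float>*> top(1, &top_blob);

  // Fill the input with deterministic values so Forward reads defined data.
  float* in = bottom_blob.mutable_cpu_data();
  for (int i = 0; i < bottom_blob.count(); ++i) { in[i] = 0.01f * i; }

  caffe::LayerParameter param;
  param.mutable_threshold_param()->set_threshold(0.5);
  caffe::ThresholdLayer<float> layer(param);

  // Callers still invoke SetUp; the layer-specific work (reading
  // threshold_param) now happens in the LayerSetUp override above.
  layer.SetUp(bottom, &top);
  layer.Forward(bottom, &top);
  return 0;
}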
diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp
index ddff5555..8138234e 100644
--- a/src/caffe/layers/window_data_layer.cpp
+++ b/src/caffe/layers/window_data_layer.cpp
@@ -246,10 +246,9 @@ WindowDataLayer<Dtype>::~WindowDataLayer<Dtype>() {
}
template <typename Dtype>
-void WindowDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+void WindowDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
- Layer<Dtype>::SetUp(bottom, top);
- // SetUp runs through the window_file and creates two structures
+ // LayerSetUp runs through the window_file and creates two structures
// that hold windows: one for foreground (object) windows and one
// for background (non-object) windows. We use an overlap threshold
// to decide which is which.
@@ -426,7 +425,7 @@ unsigned int WindowDataLayer<Dtype>::PrefetchRand() {
}
template <typename Dtype>
-Dtype WindowDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void WindowDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
JoinPrefetchThread();
@@ -437,7 +436,6 @@ Dtype WindowDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
(*top)[1]->mutable_cpu_data());
// Start a new prefetch thread
CreatePrefetchThread();
- return Dtype(0.);
}
#ifdef CPU_ONLY
diff --git a/src/caffe/layers/window_data_layer.cu b/src/caffe/layers/window_data_layer.cu
index 6e8fa8b3..475ec265 100644
--- a/src/caffe/layers/window_data_layer.cu
+++ b/src/caffe/layers/window_data_layer.cu
@@ -18,7 +18,7 @@
namespace caffe {
template <typename Dtype>
-Dtype WindowDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+void WindowDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, join the thread
JoinPrefetchThread();
@@ -29,7 +29,6 @@ Dtype WindowDataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
(*top)[1]->mutable_gpu_data());
// Start a new prefetch thread
CreatePrefetchThread();
- return Dtype(0.);
}
INSTANTIATE_CLASS(WindowDataLayer);
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index db6b4ffe..68a80261 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -1,3 +1,4 @@
+#include <algorithm>
#include <map>
#include <set>
#include <string>
@@ -73,11 +74,26 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
// If a blob needs backward, this layer should provide it.
need_backward |= blob_need_backward_[blob_id];
}
- for (int top_id = 0; top_id < layer_param.top_size(); ++top_id) {
+ int num_top = layer_param.top_size();
+ for (int top_id = 0; top_id < num_top; ++top_id) {
AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
}
+ // If the layer specifies that AutoTopBlobs() -> true and the LayerParameter
+ // specified fewer than the required number (as specified by
+ // ExactNumTopBlobs() or MinTopBlobs()), allocate them here.
+ Layer<Dtype>* layer = layers_[layer_id].get();
+ if (layer->AutoTopBlobs()) {
+ const int needed_num_top =
+ std::max(layer->MinTopBlobs(), layer->ExactNumTopBlobs());
+ for (; num_top < needed_num_top; ++num_top) {
+ // Add "anonymous" top blobs -- do not modify available_blobs or
+ // blob_name_to_idx as we don't want these blobs to be usable as input
+ // to other layers.
+ AppendTop(param, layer_id, num_top, NULL, NULL);
+ }
+ }
// After this layer is connected, set it up.
- // LOG(INFO) << "Setting up " << layer_names_[layer_id];
+ LOG(INFO) << "Setting up " << layer_names_[layer_id];
layers_[layer_id]->SetUp(bottom_vecs_[layer_id], &top_vecs_[layer_id]);
for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->num() << " "
@@ -272,15 +288,17 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
shared_ptr<LayerParameter> layer_param((layer_id >= 0) ?
(new LayerParameter(param.layers(layer_id))) : NULL);
const string& blob_name = layer_param ?
- layer_param->top(top_id) : param.input(top_id);
+ (layer_param->top_size() > top_id ?
+ layer_param->top(top_id) : "(automatic)") : param.input(top_id);
// Check if we are doing in-place computation
- if (layer_param && layer_param->bottom_size() > top_id &&
+ if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id &&
blob_name == layer_param->bottom(top_id)) {
// In-place computation
LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)";
top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get());
top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]);
- } else if (blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) {
+ } else if (blob_name_to_idx &&
+ blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) {
// If we are not doing in-place computation but have duplicated blobs,
// raise an error.
LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
@@ -296,7 +314,7 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
blobs_.push_back(blob_pointer);
blob_names_.push_back(blob_name);
blob_need_backward_.push_back(false);
- (*blob_name_to_idx)[blob_name] = blob_id;
+ if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; }
if (layer_id == -1) {
// Set the (explicitly specified) dimensions of the input blob.
blob_pointer->Reshape(param.input_dim(top_id * 4),
@@ -311,7 +329,7 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
}
memory_used_ += blob_pointer->count();
}
- available_blobs->insert(blob_name);
+ if (available_blobs) { available_blobs->insert(blob_name); }
}
// Helper for Net::Init: add a new bottom blob to the net.
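A practical consequence of the Net::Init / AppendTop changes above: a layer no longer needs to declare every top blob in its prototxt. When a layer reports AutoTopBlobs() (which the Init hunk queries; the loss layers are the natural candidates here, though that is presumed rather than shown), Init appends an "(automatic)" anonymous top for it, and because AppendTop is then called with NULL for available_blobs and blob_name_to_idx, that blob cannot be consumed by any other layer. Below is a hypothetical fragment in the proto-string style of the tests further down, with invented layer and blob names.

#include <string>

// EUCLIDEAN_LOSS layer with no "top:" at all -- under the new Init logic an
// anonymous loss top is appended for it automatically.
const std::string kAnonymousTopProto =
    "name: 'AnonymousTopExample' "
    "layers: { "
    "  name: 'loss' "
    "  type: EUCLIDEAN_LOSS "
    "  bottom: 'pred' "
    "  bottom: 'label' "
    "} ";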
diff --git a/src/caffe/test/test_euclidean_loss_layer.cpp b/src/caffe/test/test_euclidean_loss_layer.cpp
index 511d38cc..d7d2de7e 100644
--- a/src/caffe/test/test_euclidean_loss_layer.cpp
+++ b/src/caffe/test/test_euclidean_loss_layer.cpp
@@ -22,7 +22,8 @@ class EuclideanLossLayerTest : public MultiDeviceTest<TypeParam> {
protected:
EuclideanLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
- blob_bottom_label_(new Blob<Dtype>(10, 5, 1, 1)) {
+ blob_bottom_label_(new Blob<Dtype>(10, 5, 1, 1)),
+ blob_top_loss_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
GaussianFiller<Dtype> filler(filler_param);
@@ -30,28 +31,61 @@ class EuclideanLossLayerTest : public MultiDeviceTest<TypeParam> {
blob_bottom_vec_.push_back(blob_bottom_data_);
filler.Fill(this->blob_bottom_label_);
blob_bottom_vec_.push_back(blob_bottom_label_);
+ blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~EuclideanLossLayerTest() {
delete blob_bottom_data_;
delete blob_bottom_label_;
+ delete blob_top_loss_;
}
+
+ void TestForward() {
+ // Get the loss without a specified objective weight -- should be
+ // equivalent to explicitly specifying a weight of 1.
+ LayerParameter layer_param;
+ EuclideanLossLayer<Dtype> layer_weight_1(layer_param);
+ layer_weight_1.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+ const Dtype loss_weight_1 =
+ layer_weight_1.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+
+ // Get the loss again with a different objective weight; check that it is
+ // scaled appropriately.
+ const Dtype kLossWeight = 3.7;
+ layer_param.add_loss_weight(kLossWeight);
+ EuclideanLossLayer<Dtype> layer_weight_2(layer_param);
+ layer_weight_2.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+ const Dtype loss_weight_2 =
+ layer_weight_2.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
+ const Dtype kErrorMargin = 1e-5;
+ EXPECT_NEAR(loss_weight_1 * kLossWeight, loss_weight_2, kErrorMargin);
+ // Make sure the loss is non-trivial.
+ const Dtype kNonTrivialAbsThresh = 1e-1;
+ EXPECT_GE(fabs(loss_weight_1), kNonTrivialAbsThresh);
+ }
+
Blob<Dtype>* const blob_bottom_data_;
Blob<Dtype>* const blob_bottom_label_;
+ Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
TYPED_TEST_CASE(EuclideanLossLayerTest, TestDtypesAndDevices);
+TYPED_TEST(EuclideanLossLayerTest, TestForward) {
+ this->TestForward();
+}
+
TYPED_TEST(EuclideanLossLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
+ const Dtype kLossWeight = 3.7;
+ layer_param.add_loss_weight(kLossWeight);
EuclideanLossLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<Dtype> checker(1e-2, 1e-2, 1701);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), -1, -1, -1);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_));
}
-
} // namespace caffe
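One invariant worth spelling out from the TestForward helper above: the value Forward returns already has the loss weight folded in, while the top blob's data holds the unweighted batch average (that is what the Forward_cpu hunks write) and the top blob's diff holds the weight itself (that is what the Backward hunks read). The sketch below checks that relationship after SetUp and Forward; it assumes the weight has been seeded into the top diff by SetUp, which these hunks rely on but do not show, and the helper name is invented.

#include <cmath>
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/vision_layers.hpp"

// Returned loss should equal (loss weight stored in top diff) multiplied by
// (batch average stored in top data), up to floating-point tolerance.
template <typename Dtype>
void CheckLossTopInvariant(caffe::EuclideanLossLayer<Dtype>* layer,
    const std::vector<caffe::Blob<Dtype>*>& bottom,
    std::vector<caffe::Blob<Dtype>*>* top) {
  const Dtype returned = layer->Forward(bottom, top);
  const Dtype unweighted = (*top)[0]->cpu_data()[0];
  const Dtype weight = (*top)[0]->cpu_diff()[0];
  CHECK_LE(std::fabs(returned - weight * unweighted), Dtype(1e-5));
}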
diff --git a/src/caffe/test/test_hinge_loss_layer.cpp b/src/caffe/test/test_hinge_loss_layer.cpp
index 8f6f6f78..3c11b9ac 100644
--- a/src/caffe/test/test_hinge_loss_layer.cpp
+++ b/src/caffe/test/test_hinge_loss_layer.cpp
@@ -22,7 +22,8 @@ class HingeLossLayerTest : public MultiDeviceTest<TypeParam> {
protected:
HingeLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
- blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
+ blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)),
+ blob_top_loss_(new Blob<Dtype>()) {
// fill the values
Caffe::set_random_seed(1701);
FillerParameter filler_param;
@@ -34,13 +35,16 @@ class HingeLossLayerTest : public MultiDeviceTest<TypeParam> {
blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5;
}
blob_bottom_vec_.push_back(blob_bottom_label_);
+ blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~HingeLossLayerTest() {
delete blob_bottom_data_;
delete blob_bottom_label_;
+ delete blob_top_loss_;
}
Blob<Dtype>* const blob_bottom_data_;
Blob<Dtype>* const blob_bottom_label_;
+ Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
@@ -52,10 +56,9 @@ TYPED_TEST(HingeLossLayerTest, TestGradientL1) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
HingeLossLayer<Dtype> layer(layer_param);
- layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
- GradientChecker<Dtype> checker(1e-2, 1e-3, 1701, 1, 0.01);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), 0, -1, -1);
+ GradientChecker<Dtype> checker(1e-2, 2e-3, 1701, 1, 0.01);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_), 0);
}
TYPED_TEST(HingeLossLayerTest, TestGradientL2) {
@@ -65,11 +68,9 @@ TYPED_TEST(HingeLossLayerTest, TestGradientL2) {
HingeLossParameter* hinge_loss_param = layer_param.mutable_hinge_loss_param();
hinge_loss_param->set_norm(HingeLossParameter_Norm_L2);
HingeLossLayer<Dtype> layer(layer_param);
- layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
- GradientChecker<Dtype> checker(1e-2, 2e-3, 1701);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), 0, -1, -1);
+ GradientChecker<Dtype> checker(1e-2, 1e-2, 1701);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_), 0);
}
-
} // namespace caffe
diff --git a/src/caffe/test/test_infogain_loss_layer.cpp b/src/caffe/test/test_infogain_loss_layer.cpp
index 162d0e6c..de2f901a 100644
--- a/src/caffe/test/test_infogain_loss_layer.cpp
+++ b/src/caffe/test/test_infogain_loss_layer.cpp
@@ -23,7 +23,8 @@ class InfogainLossLayerTest : public MultiDeviceTest<TypeParam> {
InfogainLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)),
- blob_bottom_infogain_(new Blob<Dtype>(1, 1, 5, 5)) {
+ blob_bottom_infogain_(new Blob<Dtype>(1, 1, 5, 5)),
+ blob_top_loss_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
FillerParameter filler_param;
PositiveUnitballFiller<Dtype> filler(filler_param);
@@ -38,15 +39,18 @@ class InfogainLossLayerTest : public MultiDeviceTest<TypeParam> {
UniformFiller<Dtype> infogain_filler(filler_param);
infogain_filler.Fill(this->blob_bottom_infogain_);
blob_bottom_vec_.push_back(blob_bottom_infogain_);
+ blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~InfogainLossLayerTest() {
delete blob_bottom_data_;
delete blob_bottom_label_;
delete blob_bottom_infogain_;
+ delete blob_top_loss_;
}
Blob<Dtype>* const blob_bottom_data_;
Blob<Dtype>* const blob_bottom_label_;
Blob<Dtype>* const blob_bottom_infogain_;
+ Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
@@ -59,8 +63,8 @@ TYPED_TEST(InfogainLossLayerTest, TestGradient) {
LayerParameter layer_param;
InfogainLossLayer<Dtype> layer(layer_param);
GradientChecker<Dtype> checker(1e-4, 2e-2, 1701, 1, 0.01);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), 0, -1, -1);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_), 0);
}
} // namespace caffe
diff --git a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
index 3d1037ba..1fc4c42f 100644
--- a/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
+++ b/src/caffe/test/test_multinomial_logistic_loss_layer.cpp
@@ -20,7 +20,8 @@ class MultinomialLogisticLossLayerTest : public ::testing::Test {
protected:
MultinomialLogisticLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
- blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
+ blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)),
+ blob_top_loss_(new Blob<Dtype>()) {
Caffe::set_random_seed(1701);
// fill the values
FillerParameter filler_param;
@@ -31,13 +32,16 @@ class MultinomialLogisticLossLayerTest : public ::testing::Test {
blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5;
}
blob_bottom_vec_.push_back(blob_bottom_label_);
+ blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~MultinomialLogisticLossLayerTest() {
delete blob_bottom_data_;
delete blob_bottom_label_;
+ delete blob_top_loss_;
}
Blob<Dtype>* const blob_bottom_data_;
Blob<Dtype>* const blob_bottom_label_;
+ Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
@@ -51,8 +55,8 @@ TYPED_TEST(MultinomialLogisticLossLayerTest, TestGradientCPU) {
MultinomialLogisticLossLayer<TypeParam> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<TypeParam> checker(1e-2, 2*1e-2, 1701, 0, 0.05);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), 0, -1, -1);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_), 0);
}
} // namespace caffe
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index acd3bcdd..497f11d4 100644
--- a/src/caffe/test/test_net.cpp
+++ b/src/caffe/test/test_net.cpp
@@ -144,7 +144,7 @@ class NetTest : public MultiDeviceTest<TypeParam> {
virtual void InitTrickyNet(Dtype* loss_weight = NULL) {
ostringstream loss_weight_stream;
if (loss_weight) {
- loss_weight_stream << " top_loss_weight: " << *loss_weight << " ";
+ loss_weight_stream << " loss_weight: " << *loss_weight << " ";
}
const string& proto =
"name: 'TrickyTestNetwork' "
@@ -220,14 +220,16 @@ class NetTest : public MultiDeviceTest<TypeParam> {
InitNetFromProtoString(proto);
}
- virtual void InitUnsharedWeightsNet(Dtype* loss_weight,
+ // loss_weight is the loss weight for the EUCLIDEAN_LOSS layer output.
+ // midnet_loss_weight is the loss weight for the first INNER_PRODUCT layer
+ // output. Should both default to 0.0 if unspecified (i.e., if NULL is
+ // passed to this function).
+ virtual void InitUnsharedWeightsNet(const Dtype* loss_weight = NULL,
+ const Dtype* midnet_loss_weight = NULL,
const bool force_backward = false, const bool bias_term = false,
const Dtype blobs_lr_w1 = 1, const Dtype blobs_lr_b1 = 2,
const Dtype blobs_lr_w2 = 1, const Dtype blobs_lr_b2 = 2) {
ostringstream proto;
- if (loss_weight) {
- loss_weight_stream << " top_loss_weight: " << *loss_weight << " ";
- }
proto << "name: 'UnsharedWeightsNetwork' ";
if (force_backward) {
proto << "force_backward: true ";
@@ -270,7 +272,11 @@ class NetTest : public MultiDeviceTest<TypeParam> {
}
proto <<
" bottom: 'data' "
- " top: 'innerproduct1' "
+ " top: 'innerproduct1' ";
+ if (midnet_loss_weight) {
+ proto << " loss_weight: " << *midnet_loss_weight << " ";
+ }
+ proto <<
"} "
"layers: { "
" name: 'innerproduct2' "
@@ -300,7 +306,7 @@ class NetTest : public MultiDeviceTest<TypeParam> {
" name: 'loss' "
" type: EUCLIDEAN_LOSS ";
if (loss_weight) {
- proto << " top_loss_weight: " << *loss_weight << " ";
+ proto << " loss_weight: " << *loss_weight << " ";
}
proto <<
" bottom: 'innerproduct1' "
@@ -544,7 +550,6 @@ TYPED_TEST(NetTest, TestBottomNeedBackward) {
}
TYPED_TEST(NetTest, TestBottomNeedBackwardForce) {
- typedef typename TypeParam::Dtype Dtype;
const bool force_backward = true;
this->InitTinyNet(force_backward);
const vector<vector<bool> >& bottom_need_backward =
@@ -559,7 +564,6 @@ TYPED_TEST(NetTest, TestBottomNeedBackwardForce) {
}
TYPED_TEST(NetTest, TestBottomNeedBackwardEuclideanForce) {
- typedef typename TypeParam::Dtype Dtype;
const bool force_backward = true;
this->InitTinyNetEuclidean(force_backward);
const vector<vector<bool> >& bottom_need_backward =
@@ -591,106 +595,45 @@ TYPED_TEST(NetTest, TestBottomNeedBackwardTricky) {
EXPECT_EQ(true, bottom_need_backward[3][1]);
}
-TYPED_TEST(NetTest, TestLossWeightCPU) {
- Caffe::set_mode(Caffe::CPU);
- // First, compute the loss and gradients with no top_loss_weight specified.
- // In this case, the loss weight for the EUCLIDEAN_LOSS layer should default
- // to 1.
- vector<Blob<TypeParam>*> bottom;
- Caffe::set_random_seed(this->seed_);
- const bool kForceBackward = true;
- this->InitUnsharedWeightsNet(NULL, kForceBackward);
- const TypeParam loss = this->net_->ForwardBackward(bottom);
- const bool kCopyDiff = true;
- const bool kReshape = true;
- const vector<shared_ptr<Blob<TypeParam> > >& net_blobs = this->net_->blobs();
- vector<shared_ptr<Blob<TypeParam> > > blob_grads(net_blobs.size());
- for (int i = 0; i < net_blobs.size(); ++i) {
- blob_grads[i].reset(new Blob<TypeParam>());
- blob_grads[i]->CopyFrom(*net_blobs[i], kCopyDiff, kReshape);
- }
- const vector<shared_ptr<Blob<TypeParam> > >& net_params =
- this->net_->params();
- vector<shared_ptr<Blob<TypeParam> > > param_grads(net_params.size());
- for (int i = 0; i < net_params.size(); ++i) {
- param_grads[i].reset(new Blob<TypeParam>());
- param_grads[i]->CopyFrom(*net_params[i], kCopyDiff, kReshape);
- }
- // Check that the loss is non-trivial, otherwise the test doesn't prove much.
- const TypeParam kMinLossAbsValue = 1e-2;
- ASSERT_GE(fabs(loss), kMinLossAbsValue);
- const TypeParam kErrorMargin = 1e-5;
- const int kNumLossWeights = 6;
- TypeParam kLossWeights[kNumLossWeights] = {2, 0, 1, -1, -2.5, 3.7};
- for (int i = 0; i < kNumLossWeights; ++i) {
- Caffe::set_random_seed(this->seed_);
- this->InitUnsharedWeightsNet(&kLossWeights[i], kForceBackward);
- const TypeParam weighted_loss = this->net_->ForwardBackward(bottom);
- const TypeParam error_margin = kErrorMargin * fabs(kLossWeights[i]);
- EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin)
- << "loss weight = " << kLossWeights[i];
- const vector<shared_ptr<Blob<TypeParam> > >& weighted_blobs =
- this->net_->blobs();
- ASSERT_EQ(blob_grads.size(), weighted_blobs.size());
- for (int j = 0; j < blob_grads.size(); ++j) {
- ASSERT_EQ(blob_grads[j]->count(), weighted_blobs[j]->count());
- for (int k = 0; k < blob_grads[j]->count(); ++k) {
- EXPECT_NEAR(blob_grads[j]->cpu_diff()[k] * kLossWeights[i],
- weighted_blobs[j]->cpu_diff()[k], error_margin);
- }
- }
- const vector<shared_ptr<Blob<TypeParam> > >& weighted_params =
- this->net_->params();
- ASSERT_EQ(param_grads.size(), weighted_params.size());
- for (int j = 0; j < param_grads.size(); ++j) {
- ASSERT_EQ(param_grads[j]->count(), weighted_params[j]->count());
- for (int k = 0; k < param_grads[j]->count(); ++k) {
- EXPECT_NEAR(param_grads[j]->cpu_diff()[k] * kLossWeights[i],
- weighted_params[j]->cpu_diff()[k], error_margin);
- }
- }
- }
-}
-
-TYPED_TEST(NetTest, TestLossWeightGPU) {
- Caffe::set_mode(Caffe::GPU);
- // First, compute the loss and gradients with no top_loss_weight specified.
+TYPED_TEST(NetTest, TestLossWeight) {
+ typedef typename TypeParam::Dtype Dtype;
+ // First, compute the loss and gradients with no loss_weight specified.
// In this case, the loss weight for the EUCLIDEAN_LOSS layer should default
// to 1.
- vector<Blob<TypeParam>*> bottom;
+ vector<Blob<Dtype>*> bottom;
Caffe::set_random_seed(this->seed_);
const bool kForceBackward = true;
- this->InitUnsharedWeightsNet(NULL, kForceBackward);
- const TypeParam loss = this->net_->ForwardBackward(bottom);
+ this->InitUnsharedWeightsNet(NULL, NULL, kForceBackward);
+ const Dtype loss = this->net_->ForwardBackward(bottom);
const bool kCopyDiff = true;
const bool kReshape = true;
- const vector<shared_ptr<Blob<TypeParam> > >& net_blobs = this->net_->blobs();
- vector<shared_ptr<Blob<TypeParam> > > blob_grads(net_blobs.size());
+ const vector<shared_ptr<Blob<Dtype> > >& net_blobs = this->net_->blobs();
+ vector<shared_ptr<Blob<Dtype> > > blob_grads(net_blobs.size());
for (int i = 0; i < net_blobs.size(); ++i) {
- blob_grads[i].reset(new Blob<TypeParam>());
+ blob_grads[i].reset(new Blob<Dtype>());
blob_grads[i]->CopyFrom(*net_blobs[i], kCopyDiff, kReshape);
}
- const vector<shared_ptr<Blob<TypeParam> > >& net_params =
+ const vector<shared_ptr<Blob<Dtype> > >& net_params =
this->net_->params();
- vector<shared_ptr<Blob<TypeParam> > > param_grads(net_params.size());
+ vector<shared_ptr<Blob<Dtype> > > param_grads(net_params.size());
for (int i = 0; i < net_params.size(); ++i) {
- param_grads[i].reset(new Blob<TypeParam>());
+ param_grads[i].reset(new Blob<Dtype>());
param_grads[i]->CopyFrom(*net_params[i], kCopyDiff, kReshape);
}
// Check that the loss is non-trivial, otherwise the test doesn't prove much.
- const TypeParam kMinLossAbsValue = 1e-2;
+ const Dtype kMinLossAbsValue = 1e-2;
ASSERT_GE(fabs(loss), kMinLossAbsValue);
const Dtype kErrorMargin = 1e-4;
const int kNumLossWeights = 6;
- TypeParam kLossWeights[kNumLossWeights] = {2, 0, 1, -1, -2.5, 3.7};
+ Dtype kLossWeights[kNumLossWeights] = {2, 0, 1, -1, -2.5, 3.7};
for (int i = 0; i < kNumLossWeights; ++i) {
Caffe::set_random_seed(this->seed_);
- this->InitUnsharedWeightsNet(&kLossWeights[i], kForceBackward);
- const TypeParam weighted_loss = this->net_->ForwardBackward(bottom);
- const TypeParam error_margin = kErrorMargin * fabs(kLossWeights[i]);
+ this->InitUnsharedWeightsNet(&kLossWeights[i], NULL, kForceBackward);
+ const Dtype weighted_loss = this->net_->ForwardBackward(bottom);
+ const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]);
EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin)
<< "loss weight = " << kLossWeights[i];
- const vector<shared_ptr<Blob<TypeParam> > >& weighted_blobs =
+ const vector<shared_ptr<Blob<Dtype> > >& weighted_blobs =
this->net_->blobs();
ASSERT_EQ(blob_grads.size(), weighted_blobs.size());
for (int j = 0; j < blob_grads.size(); ++j) {
@@ -700,7 +643,7 @@ TYPED_TEST(NetTest, TestLossWeightGPU) {
weighted_blobs[j]->cpu_diff()[k], error_margin);
}
}
- const vector<shared_ptr<Blob<TypeParam> > >& weighted_params =
+ const vector<shared_ptr<Blob<Dtype> > >& weighted_params =
this->net_->params();
ASSERT_EQ(param_grads.size(), weighted_params.size());
for (int j = 0; j < param_grads.size(); ++j) {
@@ -859,15 +802,15 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) {
TYPED_TEST(NetTest, TestParamPropagateDown) {
typedef typename TypeParam::Dtype Dtype;
vector<Blob<Dtype>*> bottom;
- const bool kBiasTerm = true;
- const bool kForceBackward = false;
- const Dtype* kLossWeight = NULL;
+ const bool kBiasTerm = true, kForceBackward = false;
+ const Dtype* kLossWeight1 = NULL;
+ const Dtype* kLossWeight2 = NULL;
// Run the net with all params learned; check that gradients are non-zero.
Caffe::set_random_seed(this->seed_);
Dtype blobs_lr_w1 = 1, blobs_lr_w2 = 1, blobs_lr_b1 = 2, blobs_lr_b2 = 2;
- this->InitUnsharedWeightsNet(kLossWeight, kForceBackward, kBiasTerm,
- blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
+ this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
+ kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
this->net_->Forward(bottom);
this->net_->Backward();
const vector<shared_ptr<Blob<Dtype> > >& params = this->net_->params();
@@ -886,8 +829,8 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
// gradients.
Caffe::set_random_seed(this->seed_);
blobs_lr_w1 *= 2, blobs_lr_w2 *= 2, blobs_lr_b1 *= 2, blobs_lr_b2 *= 2;
- this->InitUnsharedWeightsNet(kLossWeight, kForceBackward, kBiasTerm,
- blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
+ this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
+ kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
this->net_->Forward(bottom);
this->net_->Backward();
const vector<shared_ptr<Blob<Dtype> > >& params2 = this->net_->params();
@@ -902,8 +845,8 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
// gradients for those.
Caffe::set_random_seed(this->seed_);
blobs_lr_w1 = 1, blobs_lr_w2 = 0, blobs_lr_b1 = 0, blobs_lr_b2 = 1;
- this->InitUnsharedWeightsNet(kLossWeight, kForceBackward, kBiasTerm,
- blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
+ this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
+ kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
this->net_->Forward(bottom);
this->net_->Backward();
const vector<shared_ptr<Blob<Dtype> > >& params3 = this->net_->params();
@@ -921,8 +864,8 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
// Change the opposite subset of the learning rates to zero.
Caffe::set_random_seed(this->seed_);
blobs_lr_w1 = 0, blobs_lr_w2 = 1, blobs_lr_b1 = 1, blobs_lr_b2 = 0;
- this->InitUnsharedWeightsNet(kLossWeight, kForceBackward, kBiasTerm,
- blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
+ this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
+ kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
this->net_->Forward(bottom);
this->net_->Backward();
const vector<shared_ptr<Blob<Dtype> > >& params4 = this->net_->params();
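The TestLossWeight rewrite above also shows the testing idiom this file is converging on: instead of separate CPU and GPU tests, a single TYPED_TEST parameterized over device/type pairs, with "typedef typename TypeParam::Dtype Dtype" extracting the scalar type. A minimal, self-contained illustration of that idiom follows; the fixture name and the body are placeholders, and the test_caffe_main.hpp include path reflects the test tree's usual convention rather than something shown in this diff.

#include "gtest/gtest.h"

#include "caffe/blob.hpp"
#include "caffe/test/test_caffe_main.hpp"

namespace caffe {

template <typename TypeParam>
class MultiDeviceIdiomTest : public MultiDeviceTest<TypeParam> {
  typedef typename TypeParam::Dtype Dtype;
};

TYPED_TEST_CASE(MultiDeviceIdiomTest, TestDtypesAndDevices);

TYPED_TEST(MultiDeviceIdiomTest, RunsOnEveryDeviceAndType) {
  typedef typename TypeParam::Dtype Dtype;
  // The fixture has already selected the Caffe mode for this device/type
  // pair, so one body replaces the old per-device test duplicates.
  Blob<Dtype> blob(1, 1, 1, 1);
  blob.mutable_cpu_data()[0] = Dtype(3.7);
  EXPECT_EQ(Dtype(3.7), blob.cpu_data()[0]);
}

}  // namespace caffe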
diff --git a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
index f5716c9e..47ccdea1 100644
--- a/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/test/test_sigmoid_cross_entropy_loss_layer.cpp
@@ -22,7 +22,8 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest<TypeParam> {
protected:
SigmoidCrossEntropyLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
- blob_bottom_targets_(new Blob<Dtype>(10, 5, 1, 1)) {
+ blob_bottom_targets_(new Blob<Dtype>(10, 5, 1, 1)),
+ blob_top_loss_(new Blob<Dtype>()) {
// Fill the data vector
FillerParameter data_filler_param;
data_filler_param.set_std(1);
@@ -36,10 +37,12 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest<TypeParam> {
UniformFiller<Dtype> targets_filler(targets_filler_param);
targets_filler.Fill(blob_bottom_targets_);
blob_bottom_vec_.push_back(blob_bottom_targets_);
+ blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~SigmoidCrossEntropyLossLayerTest() {
delete blob_bottom_data_;
delete blob_bottom_targets_;
+ delete blob_top_loss_;
}
Dtype SigmoidCrossEntropyLossReference(const int count, const int num,
@@ -60,6 +63,8 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest<TypeParam> {
void TestForward() {
LayerParameter layer_param;
+ const Dtype kLossWeight = 3.7;
+ layer_param.add_loss_weight(kLossWeight);
FillerParameter data_filler_param;
data_filler_param.set_std(1);
GaussianFiller<Dtype> data_filler(data_filler_param);
@@ -82,7 +87,7 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest<TypeParam> {
const Dtype* blob_bottom_data = this->blob_bottom_data_->cpu_data();
const Dtype* blob_bottom_targets =
this->blob_bottom_targets_->cpu_data();
- Dtype reference_loss = this->SigmoidCrossEntropyLossReference(
+ Dtype reference_loss = kLossWeight * SigmoidCrossEntropyLossReference(
count, num, blob_bottom_data, blob_bottom_targets);
EXPECT_NEAR(reference_loss, layer_loss, eps) << "debug: trial #" << i;
}
@@ -90,6 +95,7 @@ class SigmoidCrossEntropyLossLayerTest : public MultiDeviceTest<TypeParam> {
Blob<Dtype>* const blob_bottom_data_;
Blob<Dtype>* const blob_bottom_targets_;
+ Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
@@ -103,11 +109,13 @@ TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestSigmoidCrossEntropyLoss) {
TYPED_TEST(SigmoidCrossEntropyLossLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
+ const Dtype kLossWeight = 3.7;
+ layer_param.add_loss_weight(kLossWeight);
SigmoidCrossEntropyLossLayer<Dtype> layer(layer_param);
layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<Dtype> checker(1e-2, 1e-2, 1701);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), 0, -1, -1);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_), 0);
}
diff --git a/src/caffe/test/test_softmax_with_loss_layer.cpp b/src/caffe/test/test_softmax_with_loss_layer.cpp
index bd39bd44..0f0adbba 100644
--- a/src/caffe/test/test_softmax_with_loss_layer.cpp
+++ b/src/caffe/test/test_softmax_with_loss_layer.cpp
@@ -22,7 +22,8 @@ class SoftmaxWithLossLayerTest : public MultiDeviceTest<TypeParam> {
protected:
SoftmaxWithLossLayerTest()
: blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
- blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)) {
+ blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)),
+ blob_top_loss_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
filler_param.set_std(10);
@@ -33,13 +34,16 @@ class SoftmaxWithLossLayerTest : public MultiDeviceTest<TypeParam> {
blob_bottom_label_->mutable_cpu_data()[i] = caffe_rng_rand() % 5;
}
blob_bottom_vec_.push_back(blob_bottom_label_);
+ blob_top_vec_.push_back(blob_top_loss_);
}
virtual ~SoftmaxWithLossLayerTest() {
delete blob_bottom_data_;
delete blob_bottom_label_;
+ delete blob_top_loss_;
}
Blob<Dtype>* const blob_bottom_data_;
Blob<Dtype>* const blob_bottom_label_;
+ Blob<Dtype>* const blob_top_loss_;
vector<Blob<Dtype>*> blob_bottom_vec_;
vector<Blob<Dtype>*> blob_top_vec_;
};
@@ -50,11 +54,11 @@ TYPED_TEST_CASE(SoftmaxWithLossLayerTest, TestDtypesAndDevices);
TYPED_TEST(SoftmaxWithLossLayerTest, TestGradient) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter layer_param;
+ layer_param.add_loss_weight(3);
SoftmaxWithLossLayer<Dtype> layer(layer_param);
- layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
GradientChecker<Dtype> checker(1e-2, 1e-2, 1701);
- checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
- &(this->blob_top_vec_), 0, -1, -1);
+ checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+ &(this->blob_top_vec_), 0);
}
} // namespace caffe
diff --git a/src/caffe/test/test_split_layer.cpp b/src/caffe/test/test_split_layer.cpp
index 711669ba..bf634f58 100644
--- a/src/caffe/test/test_split_layer.cpp
+++ b/src/caffe/test/test_split_layer.cpp
@@ -8,6 +8,7 @@
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
#include "caffe/util/insert_splits.hpp"
#include "caffe/vision_layers.hpp"