summary refs log tree commit diff
path: root/src/caffe
diff options
context:
space:
mode:
authorSimon Layton <slayton@nvidia.com>2015-02-09 22:08:39 -0500
committerEvan Shelhamer <shelhamer@imaginarynumber.net>2015-03-24 14:04:08 -0700
commit600c81d3da6dcc135734af789a9ffcad11515545 (patch)
tree7897df1a37b42edf2b78ef60a9517924e65c4647 /src/caffe
parentf2b52165679b38703bf3a429b9b7cd8979068a0e (diff)
downloadcaffeonacl-600c81d3da6dcc135734af789a9ffcad11515545.tar.gz
caffeonacl-600c81d3da6dcc135734af789a9ffcad11515545.tar.bz2
caffeonacl-600c81d3da6dcc135734af789a9ffcad11515545.zip
switch to cuDNN R2
Diffstat (limited to 'src/caffe')
-rw-r--r--src/caffe/layers/cudnn_conv_layer.cpp10
-rw-r--r--src/caffe/layers/cudnn_conv_layer.cu89
-rw-r--r--src/caffe/layers/cudnn_pooling_layer.cpp4
-rw-r--r--src/caffe/layers/cudnn_pooling_layer.cu20
-rw-r--r--src/caffe/layers/cudnn_relu_layer.cpp4
-rw-r--r--src/caffe/layers/cudnn_relu_layer.cu24
-rw-r--r--src/caffe/layers/cudnn_sigmoid_layer.cpp4
-rw-r--r--src/caffe/layers/cudnn_sigmoid_layer.cu24
-rw-r--r--src/caffe/layers/cudnn_softmax_layer.cpp4
-rw-r--r--src/caffe/layers/cudnn_softmax_layer.cu22
-rw-r--r--src/caffe/layers/cudnn_tanh_layer.cpp4
-rw-r--r--src/caffe/layers/cudnn_tanh_layer.cu24
12 files changed, 174 insertions(+), 59 deletions(-)
diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp
index 4a69ca20..524caf13 100644
--- a/src/caffe/layers/cudnn_conv_layer.cpp
+++ b/src/caffe/layers/cudnn_conv_layer.cpp
@@ -43,10 +43,10 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
- cudnnTensor4dDescriptor_t bottom_desc;
+ cudnnTensorDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
bottom_descs_.push_back(bottom_desc);
- cudnnTensor4dDescriptor_t top_desc;
+ cudnnTensorDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
@@ -104,12 +104,12 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
if (!handles_setup_) { return; }
for (int i = 0; i < bottom_descs_.size(); i++) {
- cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
- cudnnDestroyTensor4dDescriptor(top_descs_[i]);
+ cudnnDestroyTensorDescriptor(bottom_descs_[i]);
+ cudnnDestroyTensorDescriptor(top_descs_[i]);
cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
}
if (this->bias_term_) {
- cudnnDestroyTensor4dDescriptor(bias_desc_);
+ cudnnDestroyTensorDescriptor(bias_desc_);
}
cudnnDestroyFilterDescriptor(filter_desc_);
diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu
index 071014e1..4a70c69a 100644
--- a/src/caffe/layers/cudnn_conv_layer.cu
+++ b/src/caffe/layers/cudnn_conv_layer.cu
@@ -21,21 +21,59 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
+ cudnnConvolutionFwdAlgo_t algo;
+
+ // get the desired convolution algorithm
+ CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+ 0, // memoryLimitInBytes,
+ &algo));
+
+ // get minimum size of the workspace needed for the desired algorithm
+ size_t workspaceSizeInBytes_temp = 0;
+
+ CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g],
+ bottom_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ top_descs_[i],
+ algo,
+ &workspaceSizeInBytes_temp));
+
+ if (workspaceSizeInBytes_temp > workspaceSizeInBytes) {
+ workspaceSizeInBytes = workspaceSizeInBytes_temp;
+ // free the existing workspace and allocate a new (larger) one
+ cudaFree(this->workspace);
+ cudaMalloc(&(this->workspace), workspaceSizeInBytes);
+ }
+
// Filters.
CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
- bottom_descs_[i], bottom_data + bottom_offset_ * g,
- filter_desc_, weight + weight_offset_ * g,
- conv_descs_[i],
- top_descs_[i], top_data + top_offset_ * g,
- CUDNN_RESULT_NO_ACCUMULATE));
+ reinterpret_cast<void *>(&alpha),
+ bottom_descs_[i], bottom_data + bottom_offset_ * g,
+ filter_desc_, weight + weight_offset_ * g,
+ conv_descs_[i],
+ algo, workspace, workspaceSizeInBytes,
+ reinterpret_cast<void *>(&beta),
+ top_descs_[i], top_data + top_offset_ * g));
// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
- Dtype alpha = 1.;
- CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha,
- bias_desc_, bias_data + bias_offset_ * g,
- top_descs_[i], top_data + top_offset_ * g));
+ Dtype alpha = 1.0;
+ Dtype beta = 1.0;
+ CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
+ reinterpret_cast<void *>(&alpha),
+ bias_desc_, bias_data + bias_offset_ * g,
+ reinterpret_cast<void *>(&beta),
+ top_descs_[i], top_data + top_offset_ * g));
}
}
@@ -67,21 +105,26 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
for (int g = 0; g < this->group_; g++) {
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
+ Dtype alpha = 1.0;
+ Dtype beta = 1.0;
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
- top_descs_[i], top_diff + top_offset_ * g,
- bias_desc_, bias_diff + bias_offset_ * g,
- CUDNN_RESULT_ACCUMULATE));
+ reinterpret_cast<void *>(&alpha),
+ top_descs_[i], top_diff + top_offset_ * g,
+ reinterpret_cast<void *>(&beta),
+ bias_desc_, bias_diff + bias_offset_ * g));
}
// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = bottom[i]->gpu_data();
+ Dtype alpha = 1.0;
+ Dtype beta = 1.0;
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
- bottom_descs_[i], bottom_data + bottom_offset_ * g,
- top_descs_[i], top_diff + top_offset_ * g,
- conv_descs_[i],
- filter_desc_, weight_diff + weight_offset_ * g,
- CUDNN_RESULT_ACCUMULATE));
+ reinterpret_cast<void *>(&alpha),
+ bottom_descs_[i], bottom_data + bottom_offset_ * g,
+ top_descs_[i], top_diff + top_offset_ * g,
+ conv_descs_[i], reinterpret_cast<void *>(&beta),
+ filter_desc_, weight_diff + weight_offset_ * g));
}
// Gradient w.r.t. bottom data.
@@ -90,12 +133,14 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
weight = this->blobs_[0]->gpu_data();
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
- filter_desc_, weight + weight_offset_ * g,
- top_descs_[i], top_diff + top_offset_ * g,
- conv_descs_[i],
- bottom_descs_[i], bottom_diff + bottom_offset_ * g,
- CUDNN_RESULT_NO_ACCUMULATE));
+ reinterpret_cast<void *>(&alpha),
+ filter_desc_, weight + weight_offset_ * g,
+ top_descs_[i], top_diff + top_offset_ * g,
+ conv_descs_[i], reinterpret_cast<void *>(&beta),
+ bottom_descs_[i], bottom_diff + bottom_offset_ * g));
}
}
diff --git a/src/caffe/layers/cudnn_pooling_layer.cpp b/src/caffe/layers/cudnn_pooling_layer.cpp
index dd901956..b447f19b 100644
--- a/src/caffe/layers/cudnn_pooling_layer.cpp
+++ b/src/caffe/layers/cudnn_pooling_layer.cpp
@@ -40,8 +40,8 @@ CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(bottom_desc_);
- cudnnDestroyTensor4dDescriptor(top_desc_);
+ cudnnDestroyTensorDescriptor(bottom_desc_);
+ cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroyPoolingDescriptor(pooling_desc_);
cudnnDestroy(handle_);
}
diff --git a/src/caffe/layers/cudnn_pooling_layer.cu b/src/caffe/layers/cudnn_pooling_layer.cu
index 1c113aad..be7c4a8e 100644
--- a/src/caffe/layers/cudnn_pooling_layer.cu
+++ b/src/caffe/layers/cudnn_pooling_layer.cu
@@ -14,8 +14,15 @@ void CuDNNPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_,
- bottom_desc_, bottom_data, top_desc_, top_data));
+ reinterpret_cast<void *>(&alpha),
+ bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ top_desc_, top_data));
}
template <typename Dtype>
@@ -28,9 +35,16 @@ void CuDNNPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_data = top[0]->gpu_data();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_,
- top_desc_, top_data, top_desc_, top_diff,
- bottom_desc_, bottom_data, bottom_desc_, bottom_diff));
+ reinterpret_cast<void *>(&alpha),
+ top_desc_, top_data, top_desc_, top_diff,
+ bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer);
diff --git a/src/caffe/layers/cudnn_relu_layer.cpp b/src/caffe/layers/cudnn_relu_layer.cpp
index 0b8a6bc3..759d8398 100644
--- a/src/caffe/layers/cudnn_relu_layer.cpp
+++ b/src/caffe/layers/cudnn_relu_layer.cpp
@@ -35,8 +35,8 @@ CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
- cudnnDestroyTensor4dDescriptor(this->top_desc_);
+ cudnnDestroyTensorDescriptor(this->bottom_desc_);
+ cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
diff --git a/src/caffe/layers/cudnn_relu_layer.cu b/src/caffe/layers/cudnn_relu_layer.cu
index 86250870..b9d0870a 100644
--- a/src/caffe/layers/cudnn_relu_layer.cu
+++ b/src/caffe/layers/cudnn_relu_layer.cu
@@ -17,9 +17,16 @@ void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnActivationForward(this->handle_,
- CUDNN_ACTIVATION_RELU,
- this->bottom_desc_, bottom_data, this->top_desc_, top_data));
+ CUDNN_ACTIVATION_RELU,
+ reinterpret_cast<void *>(&alpha),
+ this->bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ this->top_desc_, top_data));
}
template <typename Dtype>
@@ -39,10 +46,17 @@ void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
- CUDNN_ACTIVATION_RELU,
- this->top_desc_, top_data, this->top_desc_, top_diff,
- this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
+ CUDNN_ACTIVATION_RELU,
+ reinterpret_cast<void *>(&alpha),
+ this->top_desc_, top_data, this->top_desc_, top_diff,
+ this->bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ this->bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer);
diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cpp b/src/caffe/layers/cudnn_sigmoid_layer.cpp
index 67bd9c37..32637873 100644
--- a/src/caffe/layers/cudnn_sigmoid_layer.cpp
+++ b/src/caffe/layers/cudnn_sigmoid_layer.cpp
@@ -35,8 +35,8 @@ CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
- cudnnDestroyTensor4dDescriptor(this->top_desc_);
+ cudnnDestroyTensorDescriptor(this->bottom_desc_);
+ cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
diff --git a/src/caffe/layers/cudnn_sigmoid_layer.cu b/src/caffe/layers/cudnn_sigmoid_layer.cu
index 31b094e2..9bb91501 100644
--- a/src/caffe/layers/cudnn_sigmoid_layer.cu
+++ b/src/caffe/layers/cudnn_sigmoid_layer.cu
@@ -12,9 +12,16 @@ void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnActivationForward(this->handle_,
- CUDNN_ACTIVATION_SIGMOID,
- this->bottom_desc_, bottom_data, this->top_desc_, top_data));
+ CUDNN_ACTIVATION_SIGMOID,
+ reinterpret_cast<void *>(&alpha),
+ this->bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ this->top_desc_, top_data));
}
template <typename Dtype>
@@ -29,10 +36,17 @@ void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
- CUDNN_ACTIVATION_SIGMOID,
- this->top_desc_, top_data, this->top_desc_, top_diff,
- this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
+ CUDNN_ACTIVATION_SIGMOID,
+ reinterpret_cast<void *>(&alpha),
+ this->top_desc_, top_data, this->top_desc_, top_diff,
+ this->bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ this->bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer);
diff --git a/src/caffe/layers/cudnn_softmax_layer.cpp b/src/caffe/layers/cudnn_softmax_layer.cpp
index 211701ca..77a3225a 100644
--- a/src/caffe/layers/cudnn_softmax_layer.cpp
+++ b/src/caffe/layers/cudnn_softmax_layer.cpp
@@ -39,8 +39,8 @@ CuDNNSoftmaxLayer<Dtype>::~CuDNNSoftmaxLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(bottom_desc_);
- cudnnDestroyTensor4dDescriptor(top_desc_);
+ cudnnDestroyTensorDescriptor(bottom_desc_);
+ cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroy(handle_);
}
diff --git a/src/caffe/layers/cudnn_softmax_layer.cu b/src/caffe/layers/cudnn_softmax_layer.cu
index f328afdd..59c304f6 100644
--- a/src/caffe/layers/cudnn_softmax_layer.cu
+++ b/src/caffe/layers/cudnn_softmax_layer.cu
@@ -16,9 +16,16 @@ void CuDNNSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnSoftmaxForward(handle_, CUDNN_SOFTMAX_ACCURATE,
- CUDNN_SOFTMAX_MODE_CHANNEL,
- bottom_desc_, bottom_data, top_desc_, top_data));
+ CUDNN_SOFTMAX_MODE_CHANNEL,
+ reinterpret_cast<void *>(&alpha),
+ bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ top_desc_, top_data));
}
template <typename Dtype>
@@ -29,9 +36,16 @@ void CuDNNSoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnSoftmaxBackward(handle_, CUDNN_SOFTMAX_ACCURATE,
- CUDNN_SOFTMAX_MODE_CHANNEL,
- top_desc_, top_data, top_desc_, top_diff, bottom_desc_, bottom_diff));
+ CUDNN_SOFTMAX_MODE_CHANNEL,
+ reinterpret_cast<void *>(&alpha),
+ top_desc_, top_data, top_desc_, top_diff,
+ reinterpret_cast<void *>(&beta),
+ bottom_desc_, bottom_diff));
}
}
diff --git a/src/caffe/layers/cudnn_tanh_layer.cpp b/src/caffe/layers/cudnn_tanh_layer.cpp
index b1d2b863..376faad3 100644
--- a/src/caffe/layers/cudnn_tanh_layer.cpp
+++ b/src/caffe/layers/cudnn_tanh_layer.cpp
@@ -35,8 +35,8 @@ CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
- cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
- cudnnDestroyTensor4dDescriptor(this->top_desc_);
+ cudnnDestroyTensorDescriptor(this->bottom_desc_);
+ cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}
diff --git a/src/caffe/layers/cudnn_tanh_layer.cu b/src/caffe/layers/cudnn_tanh_layer.cu
index bf9ec7cf..e008b0dc 100644
--- a/src/caffe/layers/cudnn_tanh_layer.cu
+++ b/src/caffe/layers/cudnn_tanh_layer.cu
@@ -12,9 +12,16 @@ void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnActivationForward(this->handle_,
- CUDNN_ACTIVATION_TANH,
- this->bottom_desc_, bottom_data, this->top_desc_, top_data));
+ CUDNN_ACTIVATION_TANH,
+ reinterpret_cast<void *>(&alpha),
+ this->bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ this->top_desc_, top_data));
}
template <typename Dtype>
@@ -29,10 +36,17 @@ void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+ Dtype alpha = 1.0;
+ Dtype beta = 0.0;
+
CUDNN_CHECK(cudnnActivationBackward(this->handle_,
- CUDNN_ACTIVATION_TANH,
- this->top_desc_, top_data, this->top_desc_, top_diff,
- this->bottom_desc_, bottom_data, this->bottom_desc_, bottom_diff));
+ CUDNN_ACTIVATION_TANH,
+ reinterpret_cast<void *>(&alpha),
+ this->top_desc_, top_data, this->top_desc_, top_diff,
+ this->bottom_desc_, bottom_data,
+ reinterpret_cast<void *>(&beta),
+ this->bottom_desc_, bottom_diff));
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer);