From 8448708ba37c920ae6b126a3b6f0f4353e848b01 Mon Sep 17 00:00:00 2001
From: Aravindh Mahendran
Date: Sun, 16 Feb 2014 10:43:34 -0500
Subject: Added tanh activation function layer.

---
 include/caffe/vision_layers.hpp | 17 ++++++++
 src/caffe/layer_factory.cpp     |  2 +
 src/caffe/layers/tanh_layer.cu  | 97 +++++++++++++++++++++++++++++++++++++++++
 3 files changed, 116 insertions(+)
 create mode 100644 src/caffe/layers/tanh_layer.cu

diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 82e52cd5..47909a21 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -44,6 +44,23 @@ class ReLULayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
+template <typename Dtype>
+class TanHLayer : public NeuronLayer<Dtype> {
+ public:
+  explicit TanHLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
 
 template <typename Dtype>
 class SigmoidLayer : public NeuronLayer<Dtype> {
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index b62ba383..cb65e8f7 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -47,6 +47,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new PoolingLayer<Dtype>(param);
   } else if (type == "relu") {
     return new ReLULayer<Dtype>(param);
+  } else if (type == "tanh") {
+    return new TanHLayer<Dtype>(param);
   } else if (type == "sigmoid") {
     return new SigmoidLayer<Dtype>(param);
   } else if (type == "softmax") {
diff --git a/src/caffe/layers/tanh_layer.cu b/src/caffe/layers/tanh_layer.cu
new file mode 100644
index 00000000..22e0831a
--- /dev/null
+++ b/src/caffe/layers/tanh_layer.cu
@@ -0,0 +1,97 @@
+// Copyright 2014 Aravindh Mahendran
+// TanH neuron activation function layer. Adapted from ReLU layer code written by Yangqing Jia
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include <algorithm>
+
+namespace caffe {
+
+template <typename Dtype>
+void TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype exp2x;
+  const int count = bottom[0]->count();
+  for (int i = 0; i < count; ++i) {
+    exp2x = exp(2*bottom_data[i]);
+    top_data[i] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+  }
+}
+
+template <typename Dtype>
+Dtype TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* top_diff = top[0]->cpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int count = (*bottom)[0]->count();
+    Dtype exp2x;
+    Dtype tanhx;
+    for (int i = 0; i < count; ++i) {
+      exp2x = exp(2*bottom_data[i]);
+      tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+      bottom_diff[i] = top_diff[i] * (1 - tanhx*tanhx);
+    }
+  }
+  return Dtype(0);
+}
+
+template <typename Dtype>
+__global__ void TanHForward(const int n, const Dtype* in, Dtype* out) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    Dtype exp2x = exp(2*in[index]);
+    out[index] = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+  }
+}
+
+template <typename Dtype>
+void TanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  const int count = bottom[0]->count();
+  TanHForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data);
+  CUDA_POST_KERNEL_CHECK;
+  // << " count: " << count << " bottom_data: "
+  //     << (unsigned long)bottom_data << " top_data: " << (unsigned long)top_data
+  //     << " blocks: " << CAFFE_GET_BLOCKS(count)
+  //     << " threads: " << CAFFE_CUDA_NUM_THREADS;
+}
+
+template <typename Dtype>
+__global__ void TanHBackward(const int n, const Dtype* in_diff,
+    const Dtype* in_data, Dtype* out_diff) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    Dtype exp2x = exp(2*in_data[index]);
+    Dtype tanhx = (exp2x - Dtype(1))/(exp2x + Dtype(1));
+    out_diff[index] = in_diff[index] * (1 - tanhx*tanhx);
+  }
+}
+
+template <typename Dtype>
+Dtype TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down) {
+    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
+    const Dtype* top_diff = top[0]->gpu_diff();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const int count = (*bottom)[0]->count();
+    TanHBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff);
+    CUDA_POST_KERNEL_CHECK;
+  }
+  return Dtype(0);
+}
+
+INSTANTIATE_CLASS(TanHLayer);
+
+
+}  // namespace caffe
--
cgit v1.2.3
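
Note (illustrative, not part of the patch above): the layer_factory.cpp hunk makes the new layer reachable through the type string "tanh", so a network definition selects it by that string alone. A rough sketch of how that might be referenced in a prototxt of this era follows; the layer name "tanh1" and blob name "ip1" are placeholders, and the exact field layout is an assumption rather than something this patch defines.

layers {
  layer {
    name: "tanh1"
    type: "tanh"
  }
  bottom: "ip1"
  top: "tanh1"
}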
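
Note (illustrative, not part of the patch above): both the CPU and GPU paths compute tanh(x) as (exp(2x) - 1) / (exp(2x) + 1) in the forward pass and scale the top gradient by 1 - tanh(x)^2 in the backward pass. The small standalone C++ sketch below checks those two formulas against std::tanh and a central finite difference; the file name and tolerances are arbitrary choices, not part of the commit.

// tanh_formula_check.cpp -- illustrative sketch only, not part of the commit.
// Verifies the exp(2x)-based forward value and the (1 - tanh^2) gradient
// factor used in tanh_layer.cu against std::tanh.
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  for (double x = -5.0; x <= 5.0; x += 0.25) {
    double exp2x = std::exp(2.0 * x);
    double forward = (exp2x - 1.0) / (exp2x + 1.0);  // same form as Forward_cpu/Forward_gpu
    double grad = 1.0 - forward * forward;           // same factor as Backward_cpu/Backward_gpu
    double eps = 1e-6;                               // central finite difference of tanh
    double fd = (std::tanh(x + eps) - std::tanh(x - eps)) / (2.0 * eps);
    assert(std::fabs(forward - std::tanh(x)) < 1e-9);
    assert(std::fabs(grad - fd) < 1e-5);
  }
  std::printf("forward and gradient formulas agree with std::tanh\n");
  return 0;
}

Compiled with, for example, g++ -O2 tanh_formula_check.cpp, this prints the confirmation line. One thing to keep in mind about this formulation: exp(2x) overflows single precision once x exceeds roughly 44, which matters when Dtype is float.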