author     Adam Paszke <adam.paszke@gmail.com>    2017-09-29 08:52:35 -0700
committer  Adam Paszke <adam.paszke@gmail.com>    2017-10-19 19:51:10 +0200
commit     98e67448fa78bd1bc6f05920ad03efceecc10066
tree       50a690a709583559368145aacce9a460b184ce0c
parent     3a4ca7a2696ac5f8d3a32108648f588bbc2b1eaa

Large Softmax and LogSoftmax refactor
- Cleaned up THNN and THCUNN code and kernels
- Improved THCUNN kernel performance 5x, making it match cuDNN performance
- Added support for computing softmax over arbitrary dims
  NOTE: The default dim for 3D inputs is now 1 (used to be 0)
- Both functions now accept inputs with arbitrarily many dimensions
- Autograd functions no longer save the input (it's unnecessary)
- Added cuDNN bindings for softmax, but they are unused as THCUNN
  matches or even exceeds cuDNN performance
25 files changed, 774 insertions, 566 deletions
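
Before the diff itself, a minimal usage sketch of the Python-facing change, assuming Variable-era PyTorch semantics; the explicit `dim` arguments and the deprecation fallback are taken from the `torch/nn/functional.py` hunk further down, so treat this as an illustration rather than authoritative documentation:

```python
import torch
from torch.autograd import Variable
import torch.nn.functional as F

x = Variable(torch.randn(2, 3, 4, 5))

# softmax, log_softmax and softmin now take an explicit dim and accept
# inputs with arbitrarily many dimensions; slices along dim sum to 1.
p = F.softmax(x, dim=1)
lp = F.log_softmax(x, dim=3)   # reduce over the last dimension
m = F.softmin(x, dim=0)        # softmin(x) == softmax(-x)

# Omitting dim still works, but warns that the implicit dimension choice
# is deprecated and falls back to the old dimensionality-based default.
p_legacy = F.softmax(x)
```
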
@@ -418,6 +418,7 @@ main_sources = [ "torch/csrc/autograd/generated/python_functions.cpp", "torch/csrc/autograd/functions/batch_normalization.cpp", "torch/csrc/autograd/functions/convolution.cpp", + "torch/csrc/autograd/functions/softmax.cpp", "torch/csrc/autograd/functions/basic_ops.cpp", "torch/csrc/autograd/functions/tensor.cpp", "torch/csrc/autograd/functions/accumulate_grad.cpp", diff --git a/test/common_nn.py b/test/common_nn.py index 7606238094..cae61dcb9d 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -91,6 +91,7 @@ module_tests = [ ), dict( module_name='Softmax', + constructor_args=(1,), input_size=(10, 20), reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)), ), @@ -101,11 +102,13 @@ module_tests = [ ), dict( module_name='LogSoftmax', + constructor_args=(1,), input_size=(10, 20), reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(), ), dict( module_name='LogSoftmax', + constructor_args=(1,), input_size=(1, 3, 10, 20), reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(), desc='multiparam', @@ -220,10 +223,12 @@ module_tests = [ ), dict( module_name='Softmin', + constructor_args=(1,), input_size=(10, 20), ), dict( module_name='Softmin', + constructor_args=(1,), input_size=(2, 3, 5, 10), desc='multidim', ), @@ -629,6 +634,7 @@ class ModuleTest(TestBase): super(ModuleTest, self).__init__(*args, **kwargs) self.jacobian_input = kwargs.get('jacobian_input', True) self.should_test_cuda = kwargs.get('test_cuda', True) + self.should_test_pickle = kwargs.get('pickle', True) def __call__(self, test_case): module = self.constructor(*self.constructor_args) @@ -644,13 +650,14 @@ class ModuleTest(TestBase): self.test_noncontig(test_case, module, input) - # TODO: do this with in-memory files as soon as torch.save will support it - with TemporaryFile() as f: - test_case._forward(module, input) - torch.save(module, f) - f.seek(0) - module_copy = torch.load(f) - test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input)) + if self.should_test_pickle: + # TODO: do this with in-memory files as soon as torch.save will support it + with TemporaryFile() as f: + test_case._forward(module, input) + torch.save(module, f) + f.seek(0) + module_copy = torch.load(f) + test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input)) self._do_test(test_case, module, input) diff --git a/test/test_nn.py b/test/test_nn.py index 81867c4119..a7c64163cf 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -162,7 +162,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest): test_case.assertEqual(type(p.data), torch.DoubleTensor) # TODO: Hardshrink is lacking a CUDA implementation - if TEST_CUDA and type(module) != nn.Hardshrink: + if TEST_CUDA and self.should_test_cuda and type(module) != nn.Hardshrink: # to GPU0 input = input.float().cuda() module.float().cuda() @@ -3576,6 +3576,13 @@ def add_test(test): setattr(TestNN, cuda_test_name, lambda self, test=test: test.test_cuda(self)) +def wrap_functional(fn, **kwargs): + class FunctionalModule(nn.Module): + def forward(self, *args): + return fn(*args, **kwargs) + return FunctionalModule + + new_criterion_tests = [ dict( module_name='BCEWithLogitsLoss', @@ -4367,6 +4374,44 @@ new_module_tests = [ input_size=(5, 6, 7), desc='dim' ), + dict( + constructor=wrap_functional(F.softmax, dim=1), + input_size=(2, 3, 4, 5), + fullname='softmax', + pickle=False, + ), + dict( + 
constructor=wrap_functional(F.softmax, dim=0), + input_size=(2, 3, 4, 5), + fullname='softmax_dim0', + test_cuda=False, + pickle=False, + ), + dict( + constructor=wrap_functional(F.softmax, dim=3), + input_size=(2, 3, 4, 5), + fullname='softmax_dim3', + test_cuda=False, + pickle=False, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=1), + input_size=(2, 3, 4, 5), + fullname='log_softmax', + pickle=False, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=0), + input_size=(2, 3, 4, 5), + fullname='log_softmax_dim0', + pickle=False, + ), + dict( + constructor=wrap_functional(F.log_softmax, dim=3), + input_size=(2, 3, 4, 5), + fullname='log_softmax_dim3', + pickle=False, + ), ] diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 315c5c76da..bbeaea02bc 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -492,10 +492,6 @@ def create_variable_type(top_env, aten_declarations): if arg['dynamic_type'] == 'Tensor'] env['tensorlist_args'] = [arg['name'] for arg in option['arguments'] if arg['dynamic_type'] == 'TensorList'] - if option['return_type'] == 'Scalar': - env['return_value'] = 'Scalar(output)' - else: - env['return_value'] = 'Tensor(std::move(output))' env['type_definition_body'] = emit_body(env, option) diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp index 206c5fc7ca..82ece0359a 100644 --- a/torch/csrc/autograd/functions/init.cpp +++ b/torch/csrc/autograd/functions/init.cpp @@ -1,5 +1,6 @@ #include "batch_normalization.h" #include "convolution.h" +#include "softmax.h" #include "accumulate_grad.h" #include "basic_ops.h" #include "tensor.h" @@ -50,6 +51,19 @@ struct ConvCtor { } }; +template<bool is_log> +struct SoftmaxCtor { + using fn_type = typename std::conditional<is_log, LogSoftmax, Softmax>::type; + fn_type* operator()(PyObject* args) { + int dim; + + TupleParser parser(args, 1); + parser.parse(dim, "dim"); + + return new fn_type(dim); + } +}; + struct DelayedErrorCtor { DelayedError* operator()(PyObject* args) { std::string msg; @@ -252,6 +266,16 @@ bool THPAutograd_initFunctions(PyObject* _unused) addClass<ConvBackward, NoCtor>(module, ConvBackwardClass, "ConvNdBackward", conv_backward_properties); addClass<ConvBackwardBackward, NoCtor>(module, ConvBackwardBackwardClass, "ConvNdBackwardBackward", conv_backward_backward_properties); + static PyTypeObject SoftmaxClass, SoftmaxBackwardClass, SoftmaxBackwardBackwardClass; + addClass<Softmax, SoftmaxCtor<false>>(module, SoftmaxClass, "Softmax"); + addClass<SoftmaxBackward, NoCtor>(module, SoftmaxBackwardClass, "SoftmaxBackward"); + addClass<SoftmaxBackwardBackward, NoCtor>(module, SoftmaxBackwardBackwardClass, "SoftmaxBackwardBackward"); + + static PyTypeObject LogSoftmaxClass, LogSoftmaxBackwardClass, LogSoftmaxBackwardBackwardClass; + addClass<LogSoftmax, SoftmaxCtor<true>>(module, LogSoftmaxClass, "LogSoftmax"); + addClass<LogSoftmaxBackward, NoCtor>(module, LogSoftmaxBackwardClass, "LogSoftmaxBackward"); + addClass<LogSoftmaxBackwardBackward, NoCtor>(module, LogSoftmaxBackwardBackwardClass, "LogSoftmaxBackwardBackward"); + static PyTypeObject AccumulateGradClass; addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties); diff --git a/torch/csrc/autograd/functions/softmax.cpp b/torch/csrc/autograd/functions/softmax.cpp new file mode 100644 index 0000000000..50a5b45e2b --- /dev/null +++ b/torch/csrc/autograd/functions/softmax.cpp @@ -0,0 
+1,107 @@ + +#include "softmax.h" + +#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/autograd/functions/utils.h" +#include "torch/csrc/utils/auto_gpu.h" +#include "torch/csrc/DynamicTypes.h" +#include "torch/csrc/Exceptions.h" + +namespace torch { namespace autograd { + +template<bool is_log> +variable_list SoftmaxBase<is_log>::apply(const variable_list& inputs) { + using BackwardBase = SoftmaxBackwardBase<is_log>; + check_input_variables("SoftmaxBase", inputs, 1); + AutoGPU gpu_guard(inputs[0].data()); + + auto input = inputs[0].data().contiguous(); + auto output = input.type().tensor(input.sizes()); + + if (is_log) { + at::log_softmax_out(output, input, dim); + } else { + at::softmax_out(output, input, dim); + } + + // This gets a bit weird because we need to save the output... + std::shared_ptr<BackwardBase> backward; + auto outputs = wrap_outputs(inputs, as_tensor_list(output), [this, &backward](FunctionFlags f) { + backward = std::make_shared<BackwardBase>(std::move(f), this->dim); + return backward; + }); + if (backward && backward->is_executable) { + backward->saved_output = SavedVariable(outputs[0], backward.get()); + } + return outputs; +}; + +template<bool is_log> +variable_list SoftmaxBackwardBase<is_log>::apply(const variable_list& grad_outputs) { + using BackwardBase = typename std::conditional<is_log, LogSoftmaxBackwardBackward, SoftmaxBackwardBackward>::type; + check_input_variables("SoftmaxBackwardBase", grad_outputs, 1); + AutoGPU gpu_guard(grad_outputs[0]); + + auto output_var = saved_output.unpack(shared_from_this()); + auto& grad_output_var = grad_outputs[0]; + auto& output = output_var.data(); + auto& grad_output = grad_output_var.data(); + auto grad_input = output.type().tensor(output.sizes()); + + auto input = output.type().tensor(); // We don't save the input, because THNN doesn't use it anyway... + if (is_log) { + at::log_softmax_backward_out(grad_input, grad_output, input, dim, output); + } else { + at::softmax_backward_out(grad_input, grad_output, input, dim, output); + } + + variable_list all_inputs {output_var, grad_output_var}; + return wrap_outputs(all_inputs, as_tensor_list(grad_input), [this, &output_var, &grad_output_var](FunctionFlags f) { + auto fn = std::make_shared<BackwardBase>(std::move(f)); + if (fn->is_executable) { + fn->saved_output = SavedVariable(output_var, fn.get()); + fn->saved_grad_output = SavedVariable(grad_output_var, fn.get()); + fn->dim = this->dim; + } + return fn; + }); +} + +// These need to be explicitly instantiated, because they're not in the header. 
+template struct SoftmaxBase<true>; +template struct SoftmaxBase<false>; +template struct SoftmaxBackwardBase<true>; +template struct SoftmaxBackwardBase<false>; + +variable_list SoftmaxBackwardBackward::apply(const variable_list& grad_grad_inputs) { + check_input_variables("SoftmaxBackwardBackward", grad_grad_inputs, 1); + auto output = saved_output.unpack(shared_from_this()); + auto gO = saved_grad_output.unpack(); + auto& ggI = grad_grad_inputs[0]; + + // Terms for reuse + auto ggI_out_sum = (ggI * output).sum(dim, true); + auto gO_out_sum = (gO * output).sum(dim, true); + + // NOTE: this is 2nd order grad output + auto gO2 = (gO - gO_out_sum) * ggI - gO * ggI_out_sum; + auto ggO = output * (ggI - ggI_out_sum); + + return {Variable(std::move(gO2)), Variable(std::move(ggO))}; +} + +variable_list LogSoftmaxBackwardBackward::apply(const variable_list& grad_grad_inputs) { + check_input_variables("LogSoftmaxBackwardBackward", grad_grad_inputs, 1); + auto output = saved_output.unpack(shared_from_this()); + auto gO = saved_grad_output.unpack(); + auto& ggI = grad_grad_inputs[0]; + + auto output_exp = output.exp(); + // NOTE: this is 2nd order grad output + auto gO2 = (-output_exp) * ggI * gO.sum(dim, true); + auto ggO = ggI - (ggI * output_exp).sum(dim, true); + + return {Variable(std::move(gO2)), Variable(std::move(ggO))}; +} + +}} diff --git a/torch/csrc/autograd/functions/softmax.h b/torch/csrc/autograd/functions/softmax.h new file mode 100644 index 0000000000..e42f4aeba7 --- /dev/null +++ b/torch/csrc/autograd/functions/softmax.h @@ -0,0 +1,67 @@ +#pragma once + +#include <Python.h> +#include <memory> +#include <ATen/ATen.h> + +#include "torch/csrc/autograd/function.h" +#include "torch/csrc/autograd/variable.h" +#include "torch/csrc/autograd/saved_variable.h" + +namespace torch { namespace autograd { + +// Softmax and LogSoftmax are implemented nearly identically until second +// derivative, so the implementation is contained in a common base class. 
+template<bool is_log> +struct SoftmaxBase : public ForwardFunction<> { + SoftmaxBase(int dim) + : dim(dim) {} + + virtual variable_list apply(const variable_list& inputs) override; + + int dim; +}; + +template<bool is_log> +struct SoftmaxBackwardBase : public Function { + SoftmaxBackwardBase(FunctionFlags f, int dim) + : Function(std::move(f)) + , dim(dim) {} + + virtual variable_list apply(const variable_list& inputs) override; + + SavedVariable saved_output; + int dim; +}; + + +struct Softmax : public SoftmaxBase<false> { + using SoftmaxBase::SoftmaxBase; +}; +struct SoftmaxBackward : public SoftmaxBackwardBase<false> { + using SoftmaxBackwardBase::SoftmaxBackwardBase; +}; +struct SoftmaxBackwardBackward : public Function { + using Function::Function; + virtual variable_list apply(const variable_list& inputs) override; + SavedVariable saved_output; + SavedVariable saved_grad_output; + int dim; +}; + + +struct LogSoftmax : public SoftmaxBase<true> { + using SoftmaxBase::SoftmaxBase; +}; +struct LogSoftmaxBackward : public SoftmaxBackwardBase<true> { + using SoftmaxBackwardBase::SoftmaxBackwardBase; +}; +struct LogSoftmaxBackwardBackward : public Function { + using Function::Function; + virtual variable_list apply(const variable_list& inputs) override; + SavedVariable saved_output; + SavedVariable saved_grad_output; + int dim; +}; + +}} diff --git a/torch/csrc/cudnn/Descriptors.h b/torch/csrc/cudnn/Descriptors.h index 6917e1cbf5..77f6f30e38 100644 --- a/torch/csrc/cudnn/Descriptors.h +++ b/torch/csrc/cudnn/Descriptors.h @@ -4,27 +4,71 @@ #include "Exceptions.h" #include "cudnn-wrapper.h" +#include <ATen/ATen.h> namespace torch { namespace cudnn { +inline int dataSize(cudnnDataType_t dataType) +{ + switch (dataType) { + case CUDNN_DATA_HALF: return 2; + case CUDNN_DATA_FLOAT: return 4; + default: return 8; + } +} + +inline cudnnDataType_t getDataType(const at::Tensor& t) { + auto scalar_type = t.type().scalarType(); + if (scalar_type == at::kFloat) { + return CUDNN_DATA_FLOAT; + } else if (scalar_type == at::kHalf) { + return CUDNN_DATA_HALF; + } else if (scalar_type == at::kDouble) { + return CUDNN_DATA_DOUBLE; + } + throw std::runtime_error("TensorDescriptor only supports double, float and half tensors"); +} + struct TensorDescriptor { cudnnTensorDescriptor_t desc; + TensorDescriptor() : desc(NULL) { CHECK(cudnnCreateTensorDescriptor(&desc)); } - TensorDescriptor(const TensorDescriptor&) = delete; - TensorDescriptor(TensorDescriptor&& ref) - { - desc = ref.desc; - ref.desc = NULL; + /* implicit */ TensorDescriptor(const at::Tensor& t, int pad = 0) : desc(NULL) { + CHECK(cudnnCreateTensorDescriptor(&desc)); + set(t, pad); } + TensorDescriptor(const TensorDescriptor&) = delete; + TensorDescriptor(TensorDescriptor&& ref) { desc = ref.desc; ref.desc = NULL; } ~TensorDescriptor() { cudnnDestroyTensorDescriptor(desc); } void set(cudnnDataType_t dataType, int dim, int* size, int* stride) { CHECK(cudnnSetTensorNdDescriptor(desc, dataType, dim, size, stride)); } + void set(const at::Tensor &t, int pad = 0) { + int dim = t.ndimension(); + if (dim > CUDNN_DIM_MAX || pad > CUDNN_DIM_MAX) +#define _STR(X) #X +#define STR(X) _STR(X) + throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions"); +#undef _STR +#undef STR + int size[CUDNN_DIM_MAX]; + int stride[CUDNN_DIM_MAX]; + for (int i = 0; i < dim; ++i) { + size[i] = t.size(i); + stride[i] = t.stride(i); + } + for (int i = dim; i < pad; ++i) { + size[i] = 1; + stride[i] = 1; + } + dim = std::max(dim, pad); + 
set(getDataType(t), dim, size, stride); + } }; struct FilterDescriptor @@ -109,15 +153,6 @@ union Constant } }; -inline int dataSize(cudnnDataType_t dataType) -{ - switch (dataType) { - case CUDNN_DATA_HALF: return 2; - case CUDNN_DATA_FLOAT: return 4; - default: return 8; - } -} - }} // namespace #endif diff --git a/torch/csrc/cudnn/Exceptions.h b/torch/csrc/cudnn/Exceptions.h index 362565e48c..75c3a398d0 100644 --- a/torch/csrc/cudnn/Exceptions.h +++ b/torch/csrc/cudnn/Exceptions.h @@ -37,6 +37,20 @@ void assertSameGPU(cudnnDataType_t dataType, T* ... tensors) { } } +inline int assertSameGPU(const at::Tensor& t) { + return t.get_device(); +} + +template<typename ...T> +int assertSameGPU(const at::Tensor& t, T& ... tensors) { + static_assert(std::is_same<at::Tensor, typename std::common_type<T...>::type>::value, + "all arguments to assertSameGPU have to be at::Tensor&"); + auto t_device = t.get_device(); + if (t_device != assertSameGPU(tensors...)) + throw std::runtime_error("tensors are on different GPUs"); + return t_device; +} + class cudnn_exception : public std::runtime_error { public: cudnnStatus_t status; diff --git a/torch/legacy/nn/LogSoftMax.py b/torch/legacy/nn/LogSoftMax.py index 948e9512ac..66ad2b0573 100644 --- a/torch/legacy/nn/LogSoftMax.py +++ b/torch/legacy/nn/LogSoftMax.py @@ -4,11 +4,20 @@ from .Module import Module class LogSoftMax(Module): + def __init__(self, dim=None): + super(LogSoftMax, self).__init__() + if dim is not None: + self.dim = dim + + def _get_dim(self, input): + return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1) + def updateOutput(self, input): self._backend.LogSoftMax_updateOutput( self._backend.library_state, input, - self.output + self.output, + self._get_dim(input) ) return self.output @@ -18,6 +27,7 @@ class LogSoftMax(Module): input, gradOutput, self.gradInput, - self.output + self.output, + self._get_dim(input) ) return self.gradInput diff --git a/torch/legacy/nn/SoftMax.py b/torch/legacy/nn/SoftMax.py index 24d5fa5967..3f91bf80d6 100644 --- a/torch/legacy/nn/SoftMax.py +++ b/torch/legacy/nn/SoftMax.py @@ -4,11 +4,20 @@ from .Module import Module class SoftMax(Module): + def __init__(self, dim=None): + super(SoftMax, self).__init__() + if dim is not None: + self.dim = dim + + def _get_dim(self, input): + return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1) + def updateOutput(self, input): self._backend.SoftMax_updateOutput( self._backend.library_state, input, - self.output + self.output, + self._get_dim(input) ) return self.output @@ -18,6 +27,7 @@ class SoftMax(Module): input, gradOutput, self.gradInput, - self.output + self.output, + self._get_dim(input) ) return self.gradInput diff --git a/torch/legacy/nn/SoftMin.py b/torch/legacy/nn/SoftMin.py index 7c1bbbff3f..4c9915c37b 100644 --- a/torch/legacy/nn/SoftMin.py +++ b/torch/legacy/nn/SoftMin.py @@ -5,9 +5,14 @@ from .utils import clear class SoftMin(Module): - def __init__(self): + def __init__(self, dim=None): super(SoftMin, self).__init__() self.mininput = None + if dim is not None: + self.dim = dim + + def _get_dim(self, input): + return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1) def updateOutput(self, input): if self.mininput is None: @@ -16,7 +21,8 @@ class SoftMin(Module): self._backend.SoftMax_updateOutput( self._backend.library_state, self.mininput, - self.output + self.output, + self._get_dim(input) ) return self.output @@ -29,7 +35,8 @@ class SoftMin(Module): self.mininput, gradOutput, self.gradInput, - 
self.output + self.output, + self._get_dim(input) ) self.gradInput.mul_(-1) diff --git a/torch/legacy/nn/SpatialSoftMax.py b/torch/legacy/nn/SpatialSoftMax.py index 526e6d47dc..5c9c0a45d1 100644 --- a/torch/legacy/nn/SpatialSoftMax.py +++ b/torch/legacy/nn/SpatialSoftMax.py @@ -8,7 +8,8 @@ class SpatialSoftMax(Module): self._backend.SoftMax_updateOutput( self._backend.library_state, input, - self.output + self.output, + 0 if input.dim() == 1 or input.dim() == 3 else 1 ) return self.output @@ -18,6 +19,7 @@ class SpatialSoftMax(Module): input, gradOutput, self.gradInput, - self.output + self.output, + 0 if input.dim() == 1 or input.dim() == 3 else 1 ) return self.gradInput diff --git a/torch/lib/THCUNN/LogSoftMax.cu b/torch/lib/THCUNN/LogSoftMax.cu index 98b7670718..91a9fbd7cd 100644 --- a/torch/lib/THCUNN/LogSoftMax.cu +++ b/torch/lib/THCUNN/LogSoftMax.cu @@ -1,88 +1,78 @@ #include "THCUNN.h" #include "THCHalf.h" +#include "THCTensorTypeUtils.cuh" #include "THCHalfAutoNumerics.cuh" #include "SharedMem.cuh" template <typename T, typename AccumT> -__global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, int classSize, int height, int width) +__global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) { - int batchIndex = blockIdx.x; - int index = threadIdx.x; - - while (index < height*width) { - int y = index / width; - int x = index % width; - if (y >= height) - break; - - // calculate input starting index in cuda layout (B x H x W x C) - int inputStartIndex = - (height*width*classSize)*batchIndex + - (width*classSize)*y + - (classSize)*x; - - T maxInput = input[inputStartIndex]; - for (int i = 1; i < classSize; i++) { - T value = input[inputStartIndex + i]; - maxInput = THCNumerics<T>::ge(maxInput, value) ? maxInput : value; - } + const uint32_t outer_stride = inner_size * dim_size; + const uint32_t dim_stride = inner_size; + + for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { + const uint32_t outer_offset = outer_index * outer_stride; + for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) { + const uint32_t data_offset = outer_offset + inner_index; + + T max_input = input[data_offset]; + for (uint32_t d = 1; d < dim_size; d++) { + const T value = input[data_offset + d * dim_stride]; + max_input = THCNumerics<T>::ge(max_input, value) ? 
max_input : value; + } - AccumT sum = 0; - for (int i = 0; i < classSize; i++) { - sum += THCNumerics<T>::exp(input[inputStartIndex + i] - maxInput); - } - T logsum = maxInput + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum)); - - for (int i = 0; i < classSize; i++) { - // calculate output index in torch layout (B x C x H x W) - int outputIndex = - (classSize*height*width)*batchIndex + - (height*width)*i + - (width)*y + - x; - output[outputIndex] = input[inputStartIndex + i] - logsum; + AccumT sum = 0; + for (uint32_t d = 0; d < dim_size; d++) + sum += THCNumerics<T>::exp(input[data_offset + d * dim_stride] - max_input); + const T logsum = max_input + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum)); + + for (uint32_t d = 0; d < dim_size; d++) + output[data_offset + d * dim_stride] = input[data_offset + d * dim_stride] - logsum; } - index += blockDim.x; } } template <typename T, typename AccumT> -__global__ void cunn_SpatialLogSoftMax_updateGradInput_kernel(T *gradInput, T *output, T *gradOutput, int classSize, int height, int width) +__global__ void cunn_SpatialLogSoftMax_updateGradInput_kernel(T *gradInput, T *output, T *gradOutput, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) { - int batchIndex = blockIdx.x; - int index = threadIdx.x; - - while (index < height*width) { - int y = index / width; - int x = index % width; - if (y >= height) - break; - - // calculate output starting index in cuda layout (B x H x W x C) - int outputStartIndex = - (height*width*classSize)*batchIndex + - (width*classSize)*y + - (classSize)*x; - - AccumT sum = 0; - for (int i = 0; i < classSize; i++) { - sum += gradOutput[outputStartIndex + i]; - } + const uint32_t outer_stride = inner_size * dim_size; + const uint32_t dim_stride = inner_size; + + for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { + const uint32_t outer_offset = outer_index * outer_stride; + for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) { + const uint32_t data_offset = outer_offset + inner_index; - for (int i = 0; i < classSize; i++) { - // calculate input index in torch layout (B x C x H x W) - int inputIndex = - (classSize*height*width)*batchIndex + - (height*width)*i + - (width)*y + - x; - gradInput[inputIndex] = ScalarConvert<AccumT, T>::to( - gradOutput[outputStartIndex + i] - THCNumerics<T>::exp(output[outputStartIndex + i]) * sum); + AccumT sum = 0; + for (uint32_t d = 0; d < dim_size; d++) { + sum += gradOutput[data_offset + d * dim_stride]; + } + const T real_sum = ScalarConvert<AccumT, T>::to(sum); + + for (uint32_t d = 0; d < dim_size; d++) { + gradInput[data_offset + d * dim_stride] = gradOutput[data_offset + d * dim_stride] - + THCNumerics<T>::exp(output[data_offset + d * dim_stride]) * real_sum; + } } - index += blockDim.x; } } +static void LogSoftMax_getSpatialGridSize( + uint32_t block_size, uint32_t max_active_blocks, + uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, + dim3& grid, dim3& block) { + // First, tile as many blocks as we can over the y axis + uint32_t y_size = (inner_size + block_size - 1) / block_size; + if (y_size > max_active_blocks) + y_size = max_active_blocks; + // Fill the x axis with as many blocks as we can fit + uint32_t x_size = (max_active_blocks + y_size - 1) / y_size; + if (x_size > outer_size) + x_size = outer_size; + grid = dim3(x_size, y_size); + block = dim3(block_size); +} + template <typename T, typename AccumT> struct 
MaxFloat { diff --git a/torch/lib/THCUNN/generic/LogSoftMax.cu b/torch/lib/THCUNN/generic/LogSoftMax.cu index 2f24697f53..4e35335191 100644 --- a/torch/lib/THCUNN/generic/LogSoftMax.cu +++ b/torch/lib/THCUNN/generic/LogSoftMax.cu @@ -7,101 +7,60 @@ void THNN_(LogSoftMax_updateOutput)( THCState *state, THCTensor *input, - THCTensor *output) + THCTensor *output, + int dim) { THCUNN_assertSameGPU(state, 2, input, output); + THArgCheck(dim >= 0 && dim < input->nDimension, 4, + "dim out of range (got %d, but input has %d dims)", dim, input->nDimension); + THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, input), 4, + "input tensor is too large (unsupported size. file a feature request)"); THCTensor_(resizeAs)(state, output, input); - bool spatial = false; - int batchSize = 1; - int classSize = 0; - int height = 0; - int width = 0; - - int ndims = THCTensor_(nDimension)(state, input); - - if (ndims == 1) - { - classSize = THCTensor_(size)(state, input, 0); - input = THCTensor_(newContiguous)(state, input); - } - else if (ndims == 2) - { - batchSize = THCTensor_(size)(state, input, 0); - classSize = THCTensor_(size)(state, input, 1); - input = THCTensor_(newContiguous)(state, input); - } - else if (ndims == 3) - { - spatial = true; - classSize = THCTensor_(size)(state, input, 0); - height = THCTensor_(size)(state, input, 1); - width = THCTensor_(size)(state, input, 2); - - // create contiguous tensor with cuda layout from tensor with torch layout - THCTensor *tinput = THCTensor_(new)(state); - // C x H x W -> W x H x C - THCTensor_(transpose)(state, tinput, input, 0, 2); - // W x H x C -> H x W x C - THCTensor_(transpose)(state, tinput, tinput, 0, 1); - THCTensor *transposedInput = THCTensor_(newContiguous)(state, tinput); - THCTensor_(free)(state, tinput); - input = transposedInput; - } - else if (ndims == 4) - { - spatial = true; - batchSize = THCTensor_(size)(state, input, 0); - classSize = THCTensor_(size)(state, input, 1); - height = THCTensor_(size)(state, input, 2); - width = THCTensor_(size)(state, input, 3); - - // create contiguous tensor with cuda layout from tensor with torch layout - // B x C x H x W -> B x W x H x C - THCTensor *tinput = THCTensor_(new)(state); - THCTensor_(transpose)(state, tinput, input, 1, 3); - // B x W x H x C -> B x H x W x C - THCTensor_(transpose)(state, tinput, tinput, 1, 2); - THCTensor *transposedInput = THCTensor_(newContiguous)(state, tinput); - THCTensor_(free)(state, tinput); - input = transposedInput; - } - else - { - THError("1D, 2D, 3D or 4D Tensor expected"); - } - - if (!spatial) - { - dim3 grid(batchSize); + uint64_t outer_size = 1; + uint64_t dim_size = input->size[dim]; + uint64_t inner_size = 1; + for (uint64_t i = 0; i < dim; ++i) + outer_size *= input->size[i]; + for (uint64_t i = dim + 1; i < input->nDimension; ++i) + inner_size *= input->size[i]; + + // This kernel spawns a block of 1024 threads per each element in the batch. + // XXX: it assumes that inner_size == 1 + input = THCTensor_(newContiguous)(state, input); + if (inner_size == 1 && dim_size >= 64) { + dim3 grid(outer_size); dim3 block(1024); cunn_LogSoftMax_updateOutput_kernel<2, real, accreal> <<<grid, block, block.x * sizeof(accreal), THCState_getCurrentStream(state)>>>( THCTensor_(data)(state, output), THCTensor_(data)(state, input), - classSize + dim_size ); - } - else - { - dim3 grid(batchSize); - dim3 block(1024); + // This kernel runs in a 2D grid, where each application along y dimension has a fixed + // outer_size, and runs in parallel over inner_size. 
Dimension x is parallel over outer_size. + // Reductions over dim are done in a single-threaded manner. + } else { + dim3 grid, block; + uint32_t block_size = 1024; + while (block_size > inner_size) block_size >>= 1; // block_size = floor(log2(inner_size)) + int max_active_blocks; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + &cunn_SpatialLogSoftMax_updateOutput_kernel<real, accreal>, + block_size, 0); + max_active_blocks *= THCState_getCurrentDeviceProperties(state)->multiProcessorCount; + LogSoftMax_getSpatialGridSize(block_size, max_active_blocks, outer_size, dim_size, inner_size, grid, block); cunn_SpatialLogSoftMax_updateOutput_kernel<real, accreal> <<<grid, block, 0, THCState_getCurrentStream(state)>>>( THCTensor_(data)(state, output), THCTensor_(data)(state, input), - classSize, height, width + outer_size, dim_size, inner_size ); } - - cudaError errcode = cudaGetLastError(); - if (errcode != cudaSuccess) - { - THError(cudaGetErrorString(errcode)); - } + THCudaCheck(cudaGetLastError()); THCTensor_(free)(state, input); } @@ -111,97 +70,32 @@ void THNN_(LogSoftMax_updateGradInput)( THCTensor *input, THCTensor *gradOutput, THCTensor *gradInput, - THCTensor *output) + THCTensor *output, + int dim) { - THCUNN_check_nElement(state, input, gradOutput); + THArgCheck(dim >= 0 && dim < output->nDimension, 6, + "dim out of range (got %d, but input has %d dims)", dim, output->nDimension); + THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, output), 6, + "input tensor is too large (unsupported size. file a feature request)"); + THCUNN_check_nElement(state, output, gradOutput); THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput); THCTensor_(resizeAs)(state, gradInput, output); - bool spatial = false; - int batchSize = 1; - int classSize = 0; - int height = 0; - int width = 0; + uint64_t outer_size = 1; + uint64_t dim_size = output->size[dim]; + uint64_t inner_size = 1; + for (uint64_t i = 0; i < dim; ++i) + outer_size *= output->size[i]; + for (uint64_t i = dim + 1; i < output->nDimension; ++i) + inner_size *= output->size[i]; - int ndims = THCTensor_(nDimension)(state, input); + output = THCTensor_(newContiguous)(state, output); + gradOutput = THCTensor_(newContiguous)(state, gradOutput); - if (ndims == 1) - { - classSize = THCTensor_(size)(state, gradInput, 0); - output = THCTensor_(newContiguous)(state, output); - gradOutput = THCTensor_(newContiguous)(state, gradOutput); - } - else if (ndims == 2) - { - batchSize = THCTensor_(size)(state, gradInput, 0); - classSize = THCTensor_(size)(state, gradInput, 1); - output = THCTensor_(newContiguous)(state, output); - gradOutput = THCTensor_(newContiguous)(state, gradOutput); - } - else if (ndims == 3) - { - spatial = true; - classSize = THCTensor_(size)(state, input, 0); - height = THCTensor_(size)(state, input, 1); - width = THCTensor_(size)(state, input, 2); - - // create contiguous tensor with cuda layout from tensor with torch layout - // C x H x W -> W x H x C - THCTensor_(transpose)(state, output, output, 0, 2); - // W x H x C -> H x W x C - THCTensor_(transpose)(state, output, output, 0, 1); - THCTensor *transposedOutput = THCTensor_(newContiguous)(state, output); - THCTensor_(transpose)(state, output, output, 0, 1); - THCTensor_(transpose)(state, output, output, 0, 2); - output = transposedOutput; - - // create contiguous tensor with cuda layout from tensor with torch layout - // C x H x W -> W x H x C - THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 2); - // W x H x C -> H x W x C - 
THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 1); - THCTensor *transposedGradOutput = THCTensor_(newContiguous)(state, gradOutput); - THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 1); - THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 2); - gradOutput = transposedGradOutput; - } - else if (ndims == 4) - { - spatial = true; - batchSize = THCTensor_(size)(state, gradInput, 0); - classSize = THCTensor_(size)(state, input, 1); - height = THCTensor_(size)(state, input, 2); - width = THCTensor_(size)(state, input, 3); - - // create contiguous tensor with cuda layout from tensor with torch layout - // B x C x H x W -> B x W x H x C - THCTensor_(transpose)(state, output, output, 1, 3); - // B x W x H x C -> B x H x W x C - THCTensor_(transpose)(state, output, output, 1, 2); - THCTensor *transposedOutput = THCTensor_(newContiguous)(state, output); - THCTensor_(transpose)(state, output, output, 1, 2); - THCTensor_(transpose)(state, output, output, 1, 3); - output = transposedOutput; - - // create contiguous tensor with cuda layout from tensor with torch layout - // B x C x H x W -> B x W x H x C - THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 3); - // B x W x H x C -> B x H x W x C - THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 2); - THCTensor *transposedGradOutput = THCTensor_(newContiguous)(state, gradOutput); - THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 2); - THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 3); - gradOutput = transposedGradOutput; - } - else - { - THError("1D, 2D, 3D or 4D Tensor expected"); - } - - if (!spatial) - { - dim3 grid(batchSize); + // See descriptions of kernels above. + if (inner_size == 1 && dim_size >= 64) { + dim3 grid(outer_size); dim3 block(1024); cunn_LogSoftMax_updateGradInput_kernel<2, real, accreal> @@ -209,20 +103,25 @@ void THNN_(LogSoftMax_updateGradInput)( THCTensor_(data)(state, gradInput), THCTensor_(data)(state, output), THCTensor_(data)(state, gradOutput), - classSize + dim_size ); - } - else - { - dim3 grid(batchSize); - dim3 block(1024); + } else { + dim3 grid, block; + uint32_t block_size = 1024; + while (block_size > inner_size) block_size >>= 1; // block_size = floor(log2(inner_size)) + int max_active_blocks; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + &cunn_SpatialLogSoftMax_updateGradInput_kernel<real, accreal>, + block_size, 0); + max_active_blocks *= THCState_getCurrentDeviceProperties(state)->multiProcessorCount; + LogSoftMax_getSpatialGridSize(block_size, max_active_blocks, outer_size, dim_size, inner_size, grid, block); cunn_SpatialLogSoftMax_updateGradInput_kernel<real, accreal> <<<grid, block, 0, THCState_getCurrentStream(state)>>>( THCTensor_(data)(state, gradInput), THCTensor_(data)(state, output), THCTensor_(data)(state, gradOutput), - classSize, height, width + outer_size, dim_size, inner_size ); } diff --git a/torch/lib/THCUNN/generic/SoftMax.cu b/torch/lib/THCUNN/generic/SoftMax.cu index b52ca179ea..030f1ee94f 100644 --- a/torch/lib/THCUNN/generic/SoftMax.cu +++ b/torch/lib/THCUNN/generic/SoftMax.cu @@ -7,7 +7,8 @@ void THNN_(SoftMax_updateOutput)( THCState *state, THCTensor *input, - THCTensor *output) + THCTensor *output, + int _dim) { THCUNN_assertSameGPU(state, 2, input, output); @@ -21,12 +22,14 @@ void THNN_(SoftMax_updateOutput)( batchSize = 1; dim = input->size[0]; stride0 = 1; + THArgCheck(_dim == 0, 4, "dim has to be 0 for 1D input"); } else if (input->nDimension == 2) { batchSize = input->size[0]; dim = input->size[1]; stride0 = 
1; + THArgCheck(_dim == 1, 4, "dim has to be 1 for 2D input"); } else if (input->nDimension == 3) { @@ -36,6 +39,7 @@ void THNN_(SoftMax_updateOutput)( blocksZ = input->size[2]; stride0 = blocksY * blocksZ; stride1 = blocksZ; + THArgCheck(_dim == 0, 4, "dim has to be 0 for 3D input"); } else if (input->nDimension == 4) { @@ -45,6 +49,7 @@ void THNN_(SoftMax_updateOutput)( blocksZ = input->size[3]; stride0 = blocksY * blocksZ; stride1 = blocksZ; + THArgCheck(_dim == 1, 4, "dim has to be 1 for 4D input"); } else { @@ -79,9 +84,10 @@ void THNN_(SoftMax_updateGradInput)( THCTensor *input, THCTensor *gradOutput, THCTensor *gradInput, - THCTensor *output) + THCTensor *output, + int _dim) { - THCUNN_check_nElement(state, input, gradOutput); + THCUNN_check_nElement(state, output, gradOutput); THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput); output = THCTensor_(newContiguous)(state, output); @@ -96,12 +102,14 @@ void THNN_(SoftMax_updateGradInput)( batchSize = 1; dim = gradInput->size[0]; stride0 = 1; + THArgCheck(_dim == 0, 6, "dim has to be 0 for 1D input"); } else if (gradInput->nDimension == 2) { batchSize = gradInput->size[0]; dim = gradInput->size[1]; stride0 = 1; + THArgCheck(_dim == 1, 6, "dim has to be 0 for 2D input"); } else if (gradInput->nDimension == 3) { @@ -111,6 +119,7 @@ void THNN_(SoftMax_updateGradInput)( blocksZ = gradInput->size[2]; stride0 = blocksY * blocksZ; stride1 = blocksZ; + THArgCheck(_dim == 0, 6, "dim has to be 0 for 3D input"); } else if (gradInput->nDimension == 4) { @@ -120,6 +129,7 @@ void THNN_(SoftMax_updateGradInput)( blocksZ = gradInput->size[3]; stride0 = blocksY * blocksZ; stride1 = blocksZ; + THArgCheck(_dim == 1, 6, "dim has to be 0 for 4D input"); } else { @@ -131,7 +141,7 @@ void THNN_(SoftMax_updateGradInput)( { blocksY *= blocksZ; blocksZ = 1; - if (input->nDimension == 3 || input->nDimension == 4) { + if (output->nDimension == 3 || output->nDimension == 4) { stride0 = blocksY * blocksZ; stride1 = blocksZ; } diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h index aa1842f50d..bc75e0484b 100644 --- a/torch/lib/THCUNN/generic/THCUNN.h +++ b/torch/lib/THCUNN/generic/THCUNN.h @@ -242,14 +242,16 @@ TH_API void THNN_(LogSigmoid_updateGradInput)( TH_API void THNN_(LogSoftMax_updateOutput)( THCState *state, THCTensor *input, - THCTensor *output); + THCTensor *output, + int dim); TH_API void THNN_(LogSoftMax_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, THCTensor *gradInput, - THCTensor *output); + THCTensor *output, + int dim); TH_API void THNN_(LookupTable_accGradParameters)( THCState *state, @@ -1066,14 +1068,16 @@ TH_API void THNN_(SoftMarginCriterion_updateGradInput)( TH_API void THNN_(SoftMax_updateOutput)( THCState *state, THCTensor *input, - THCTensor *output); + THCTensor *output, + int dim); TH_API void THNN_(SoftMax_updateGradInput)( THCState *state, THCTensor *input, THCTensor *gradOutput, THCTensor *gradInput, - THCTensor *output); + THCTensor *output, + int dim); TH_API void THNN_(SoftPlus_updateOutput)( THCState *state, diff --git a/torch/lib/THNN/generic/LogSoftMax.c b/torch/lib/THNN/generic/LogSoftMax.c index a7280422b1..4c4700eb1e 100644 --- a/torch/lib/THNN/generic/LogSoftMax.c +++ b/torch/lib/THNN/generic/LogSoftMax.c @@ -5,64 +5,48 @@ void THNN_(LogSoftMax_updateOutput)( THNNState *state, THTensor *input, - THTensor *output) + THTensor *output, + int dim) { - real *input_data, *output_data; - ptrdiff_t nframe = 0, dim = 0, stride = 0; - ptrdiff_t t, d; + 
THArgCheck(dim >= 0 && dim < input->nDimension, 4, + "dim out of range (got %d, but input has %d dims)", dim, input->nDimension); - if (input->nDimension == 1) - { - nframe = 1; - dim = input->size[0]; - stride = 1; - } - else if (input->nDimension == 2) - { - nframe = input->size[0]; - dim = input->size[1]; - stride = 1; - } - else if (input->nDimension == 3) - { - nframe = 1; - dim = input->size[0]; - stride = input->size[1]*input->size[2]; - } - else if (input->nDimension == 4) - { - nframe = input->size[0]; - dim = input->size[1]; - stride = input->size[2]*input->size[3]; - } - else - THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected"); + uint64_t outer_size = 1; + uint64_t dim_size = input->size[dim]; + uint64_t inner_size = 1; + for (uint64_t i = 0; i < dim; ++i) + outer_size *= input->size[i]; + for (uint64_t i = dim + 1; i < input->nDimension; ++i) + inner_size *= input->size[i]; input = THTensor_(newContiguous)(input); THTensor_(resizeAs)(output, input); - real *input_data0 = THTensor_(data)(input); - real *output_data0 = THTensor_(data)(output); - - accreal logsum; - real maxInput; - #pragma omp parallel for private(t, d, maxInput, logsum, input_data, output_data) - for (t = 0; t < stride*nframe; t++) - { - logsum = 0; - maxInput = -THInf; - input_data = input_data0 + (t/stride)*dim*stride + t % stride; - output_data = output_data0 + (t/stride)*dim*stride + t % stride; - - for (d = 0; d < dim; d++) - maxInput = THMax(maxInput, input_data[d*stride]); + real *input_data_base = THTensor_(data)(input); + real *output_data_base = THTensor_(data)(output); - for (d = 0; d < dim; d++) - logsum += exp(input_data[d*stride] - maxInput); - logsum = maxInput + log(logsum); + uint64_t dim_stride = inner_size; + uint64_t outer_stride = dim_size * dim_stride; - for (d = 0; d < dim; d++) - output_data[d*stride] = input_data[d*stride] - logsum; +#pragma omp parallel for + for (uint64_t i = 0; i < outer_size * inner_size; i++) + { + uint64_t outer_idx = i / inner_size; + uint64_t inner_idx = i % inner_size; + real *input_data = input_data_base + outer_idx * outer_stride + inner_idx; + real *output_data = output_data_base + outer_idx * outer_stride + inner_idx; + + real max_input = -THInf; + for (uint64_t d = 1; d < dim_size; d++) + max_input = THMax(max_input, input_data[d * dim_stride]); + + accreal logsum = 0; + for (uint64_t d = 0; d < dim_size; d++) + logsum += exp(input_data[d * dim_stride] - max_input); + logsum = max_input + log(logsum); + + for (uint64_t d = 0; d < dim_size; d++) + output_data[d * dim_stride] = input_data[d * dim_stride] - logsum; } THTensor_(free)(input); @@ -73,61 +57,47 @@ void THNN_(LogSoftMax_updateGradInput)( THTensor *input, THTensor *gradOutput, THTensor *gradInput, - THTensor *output) + THTensor *output, + int dim) { - THNN_CHECK_SHAPE(input, gradOutput); - real *gradInput_data, *gradOutput_data, *output_data; - ptrdiff_t nframe = 0, dim = 0, stride = 0; - ptrdiff_t t, d; + THNN_CHECK_SHAPE(output, gradOutput); + THArgCheck(dim >= 0 && dim < output->nDimension, 6, + "dim out of range (got %d, but input has %d dims)", dim, output->nDimension); + + uint64_t outer_size = 1; + uint64_t dim_size = output->size[dim]; + uint64_t inner_size = 1; + for (uint64_t i = 0; i < dim; ++i) + outer_size *= output->size[i]; + for (uint64_t i = dim + 1; i < output->nDimension; ++i) + inner_size *= output->size[i]; - if (output->nDimension == 1) - { - nframe = 1; - dim = output->size[0]; - stride = 1; - } - else if (output->nDimension == 2) - { - nframe = output->size[0]; - dim = 
output->size[1]; - stride = 1; - } - else if (output->nDimension == 3) - { - nframe = 1; - dim = output->size[0]; - stride = output->size[1]*output->size[2]; - } - else if (output->nDimension == 4) - { - nframe = output->size[0]; - dim = output->size[1]; - stride = output->size[2]*output->size[3]; - } - else - THError("1D, 2D, 3D or 4D tensor expected"); - - output = THTensor_(newContiguous)(output); gradOutput = THTensor_(newContiguous)(gradOutput); - + output = THTensor_(newContiguous)(output); THTensor_(resizeAs)(gradInput, output); - real *gradInput_data0 = THTensor_(data)(gradInput); - real *output_data0 = THTensor_(data)(output); - real *gradOutput_data0 = THTensor_(data)(gradOutput); - accreal sum; - #pragma omp parallel for private(t, sum, d, gradInput_data, output_data, gradOutput_data) - for (t = 0; t < stride*nframe; t++) - { - sum = 0; - gradInput_data = gradInput_data0 + (t/stride)*dim*stride + t % stride; - output_data = output_data0 + (t/stride)*dim*stride + t % stride; - gradOutput_data = gradOutput_data0 + (t/stride)*dim*stride + t % stride; - for (d = 0; d < dim; d++) - sum += gradOutput_data[d*stride]; + real *gradInput_data_base = THTensor_(data)(gradInput); + real *output_data_base = THTensor_(data)(output); + real *gradOutput_data_base = THTensor_(data)(gradOutput); + + uint64_t dim_stride = inner_size; + uint64_t outer_stride = dim_size * dim_stride; - for (d = 0; d < dim; d++) - gradInput_data[d*stride] = gradOutput_data[d*stride] - exp(output_data[d*stride])*sum; +#pragma omp parallel for + for (uint64_t i = 0; i < outer_size * inner_size; i++) + { + uint64_t outer_idx = i / inner_size; + uint64_t inner_idx = i % inner_size; + real *gradInput_data = gradInput_data_base + outer_idx * outer_stride + inner_idx; + real *output_data = output_data_base + outer_idx * outer_stride + inner_idx; + real *gradOutput_data = gradOutput_data_base + outer_idx * outer_stride + inner_idx; + + accreal sum = 0; + for (uint64_t d = 0; d < dim_size; d++) + sum += gradOutput_data[d * dim_stride]; + + for (uint64_t d = 0; d < dim_size; d++) + gradInput_data[d * dim_stride] = gradOutput_data[d * dim_stride] - exp(output_data[d * dim_stride]) * sum; } THTensor_(free)(gradOutput); diff --git a/torch/lib/THNN/generic/SoftMax.c b/torch/lib/THNN/generic/SoftMax.c index 303526a222..6a11834b95 100644 --- a/torch/lib/THNN/generic/SoftMax.c +++ b/torch/lib/THNN/generic/SoftMax.c @@ -5,73 +5,50 @@ void THNN_(SoftMax_updateOutput)( THNNState *state, THTensor *input, - THTensor *output) -{ - real *input_data, *output_data; - ptrdiff_t nframe = 0, dim = 0, stride = 0; - ptrdiff_t t; - - if (input->nDimension == 1) - { - nframe = 1; - dim = input->size[0]; - stride = 1; - } - else if (input->nDimension == 2) - { - nframe = input->size[0]; - dim = input->size[1]; - stride = 1; - } - else if (input->nDimension == 3) - { - nframe = 1; - dim = input->size[0]; - stride = input->size[1]*input->size[2]; - } - else if (input->nDimension == 4) - { - nframe = input->size[0]; - dim = input->size[1]; - stride = input->size[2]*input->size[3]; - } - else - { - THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected"); - } + THTensor *output, + int dim) { + THArgCheck(dim >= 0 && dim < input->nDimension, 4, + "dim out of range (got %d, but input has %d dims)", dim, input->nDimension); + + uint64_t outer_size = 1; + uint64_t dim_size = input->size[dim]; + uint64_t inner_size = 1; + for (uint64_t i = 0; i < dim; ++i) + outer_size *= input->size[i]; + for (uint64_t i = dim + 1; i < input->nDimension; ++i) + inner_size *= 
input->size[i]; input = THTensor_(newContiguous)(input); THTensor_(resizeAs)(output, input); - input_data = THTensor_(data)(input); - output_data = THTensor_(data)(output); + real *input_data_base = THTensor_(data)(input); + real *output_data_base = THTensor_(data)(output); -#pragma omp parallel for private(t) - for (t = 0; t < stride*nframe; t++) - { - real *input_ptr = input_data + (t/stride)*dim*stride + t % stride; - real *output_ptr = output_data + (t/stride)*dim*stride + t % stride; + uint64_t dim_stride = inner_size; + uint64_t outer_stride = dim_size * dim_stride; - real inputMax = -THInf; - accreal sum; +#pragma omp parallel for + for (uint64_t i = 0; i < outer_size * inner_size; i++) { + uint64_t outer_idx = i / inner_size; + uint64_t inner_idx = i % inner_size; + real *input_data = input_data_base + outer_idx * outer_stride + inner_idx; + real *output_data = output_data_base + outer_idx * outer_stride + inner_idx; - ptrdiff_t d; - for (d = 0; d < dim; d++) - { - if (input_ptr[d*stride] >= inputMax) inputMax = input_ptr[d*stride]; + real input_max = -THInf; + for (uint64_t d = 0; d < dim_size; d++) { + if (input_data[d * dim_stride] >= input_max) input_max = input_data[d * dim_stride]; } - sum = 0; - for (d = 0; d < dim; d++) - { - real z = exp(input_ptr[d*stride] - inputMax); - output_ptr[d*stride] = z; + accreal sum = 0; + for (uint64_t d = 0; d < dim_size; d++) { + real z = exp(input_data[d * dim_stride] - input_max); + output_data[d * dim_stride] = z; sum += z; } - for (d = 0; d < dim; d++) - { - output_ptr[d*stride] *= 1/sum; + real invsum = 1 / sum; // NOTE: truncate sum to real once + for (uint64_t d = 0; d < dim_size; d++) { + output_data[d * dim_stride] *= invsum; } } @@ -83,64 +60,47 @@ void THNN_(SoftMax_updateGradInput)( THTensor *input, THTensor *gradOutput, THTensor *gradInput, - THTensor *output) + THTensor *output, + int dim) { - THNN_CHECK_SHAPE(input, gradOutput); - real *gradInput_data, *gradOutput_data, *output_data; - ptrdiff_t nframe = 0, dim = 0, stride = 0; - ptrdiff_t t; - - if (output->nDimension == 1) - { - nframe = 1; - dim = output->size[0]; - stride = 1; - } - else if (output->nDimension == 2) - { - nframe = output->size[0]; - dim = output->size[1]; - stride = 1; - } - else if (output->nDimension == 3) - { - nframe = 1; - dim = output->size[0]; - stride = output->size[1]*output->size[2]; - } - else if (output->nDimension == 4) - { - nframe = output->size[0]; - dim = output->size[1]; - stride = output->size[2]*output->size[3]; - } - else - { - THError("1D, 2D, 3D or 4D tensor expected"); - } + THNN_CHECK_SHAPE(output, gradOutput); + THArgCheck(dim >= 0 && dim < output->nDimension, 6, + "dim out of range (got %d, but input has %d dims)", dim, output->nDimension); + + uint64_t outer_size = 1; + uint64_t dim_size = output->size[dim]; + uint64_t inner_size = 1; + for (uint64_t i = 0; i < dim; ++i) + outer_size *= output->size[i]; + for (uint64_t i = dim + 1; i < output->nDimension; ++i) + inner_size *= output->size[i]; gradOutput = THTensor_(newContiguous)(gradOutput); output = THTensor_(newContiguous)(output); - THTensor_(resizeAs)(gradInput, output); - gradInput_data = THTensor_(data)(gradInput); - output_data = THTensor_(data)(output); - gradOutput_data = THTensor_(data)(gradOutput); -#pragma omp parallel for private(t) - for (t = 0; t < stride*nframe; t++) + real *gradInput_data_base = THTensor_(data)(gradInput); + real *output_data_base = THTensor_(data)(output); + real *gradOutput_data_base = THTensor_(data)(gradOutput); + + uint64_t dim_stride = 
inner_size; + uint64_t outer_stride = dim_size * dim_stride; + +#pragma omp parallel for + for (uint64_t i = 0; i < outer_size * inner_size; i++) { - real *gradInput_ptr = gradInput_data + (t/stride)*dim*stride + t % stride; - real *output_ptr = output_data + (t/stride)*dim*stride + t % stride; - real *gradOutput_ptr = gradOutput_data + (t/stride)*dim*stride + t % stride; + uint64_t outer_idx = i / inner_size; + uint64_t inner_idx = i % inner_size; + real *gradInput_data = gradInput_data_base + outer_idx * outer_stride + inner_idx; + real *output_data = output_data_base + outer_idx * outer_stride + inner_idx; + real *gradOutput_data = gradOutput_data_base + outer_idx * outer_stride + inner_idx; - ptrdiff_t d; accreal sum = 0; - for (d = 0; d < dim; d++) - sum += (accreal)gradOutput_ptr[d*stride] * output_ptr[d*stride]; + for (uint64_t d = 0; d < dim_size; d++) + sum += ((accreal)gradOutput_data[d * dim_stride]) * ((accreal)output_data[d * dim_stride]); - for (d = 0; d < dim; d++) - gradInput_ptr[d*stride] = output_ptr[d*stride] * (gradOutput_ptr[d*stride] - sum); + for (uint64_t d = 0; d < dim_size; d++) + gradInput_data[d * dim_stride] = output_data[d * dim_stride] * (gradOutput_data[d * dim_stride] - sum); } THTensor_(free)(gradOutput); diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h index dbbf5f1bca..9da18bc0c4 100644 --- a/torch/lib/THNN/generic/THNN.h +++ b/torch/lib/THNN/generic/THNN.h @@ -224,13 +224,15 @@ TH_API void THNN_(LogSigmoid_updateGradInput)( TH_API void THNN_(LogSoftMax_updateOutput)( THNNState *state, // library's state THTensor *input, // input tensor - THTensor *output); // [OUT] output tensor + THTensor *output, // [OUT] output tensor + int dim); TH_API void THNN_(LogSoftMax_updateGradInput)( THNNState *state, // library's state THTensor *input, // input tensor THTensor *gradOutput, // gradient w.r.t. module's output THTensor *gradInput, // [OUT] gradient w.r.t. input - THTensor *output); // module's output + THTensor *output, // module's output + int dim); TH_API void THNN_(LookupTable_accGradParameters)( THNNState *state, @@ -423,13 +425,15 @@ TH_API void THNN_(SmoothL1Criterion_updateGradInput)( TH_API void THNN_(SoftMax_updateOutput)( THNNState *state, THTensor *input, - THTensor *output); + THTensor *output, + int dim); TH_API void THNN_(SoftMax_updateGradInput)( THNNState *state, THTensor *input, THTensor *gradOutput, THTensor *gradInput, - THTensor *output); + THTensor *output, + int dim); TH_API void THNN_(SoftPlus_updateOutput)( THNNState *state, diff --git a/torch/nn/_functions/thnn/activation.py b/torch/nn/_functions/thnn/activation.py index 4d2ac18656..25c2cac93f 100644 --- a/torch/nn/_functions/thnn/activation.py +++ b/torch/nn/_functions/thnn/activation.py @@ -4,7 +4,6 @@ from torch._thnn import type2backend from torch.autograd.variable import Variable from . 
import _all_functions -from .auto_double_backwards import softmax_double_backwards class PReLU(Function): diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py index 653eef0eb3..fbd83f40ff 100644 --- a/torch/nn/_functions/thnn/auto.py +++ b/torch/nn/_functions/thnn/auto.py @@ -298,6 +298,8 @@ def _generate_function_classes(scope_dict): 'LookupTableBag', 'PReLU', 'RReLU', + 'SoftMax', + 'LogSoftMax', 'GRUFused', 'LSTMFused', 'unfolded', @@ -312,8 +314,6 @@ def _generate_function_classes(scope_dict): 'SpatialReplicationPadding': 'ReplicationPad2d', 'VolumetricReplicationPadding': 'ReplicationPad3d', 'VolumetricMaxUnpooling': 'MaxUnpool3d', - 'SoftMax': 'Softmax', - 'LogSoftMax': 'LogSoftmax', 'HardTanh': 'Hardtanh', 'HardShrink': 'Hardshrink', 'SoftPlus': 'Softplus', diff --git a/torch/nn/_functions/thnn/auto_double_backwards.py b/torch/nn/_functions/thnn/auto_double_backwards.py index ba900e9d99..e3a96d2e49 100644 --- a/torch/nn/_functions/thnn/auto_double_backwards.py +++ b/torch/nn/_functions/thnn/auto_double_backwards.py @@ -91,21 +91,6 @@ def logsigmoid_double_backwards(ctx, ggI): return gI, ggO, None, None, None, None -def logsoftmax_double_backwards(ctx, ggI): - t = ctx.saved_variables - gO, output = t[1], t[2] - - output_exp = output.exp() - gO_sum = gO.sum(dim=1, keepdim=True) - ggI_output_exp = ggI * output_exp - ggI_output_exp_sum = ggI_output_exp.sum(dim=1, keepdim=True) - - gI = output_exp * gO_sum * ggI_output_exp_sum - ggI_output_exp * gO_sum - ggO = ggI - ggI_output_exp_sum - - return gI, ggO, None, None, None, None - - def reflectionpad1d_double_backwards(ctx, ggI): gI = None ggO = torch.nn._functions.thnn.auto.ReflectionPad1d.apply(ggI, *ctx.additional_args) @@ -141,29 +126,6 @@ def replicationpad3d_double_backwards(ctx, ggI): return gI, ggO, None, None, None, None -def softmax_double_backwards(ctx, ggI): - t = ctx.saved_variables - gO, output = t[1], t[2] - - # terms for reuse - ggI_output = ggI * output - ggI_out_sum = ggI_output.sum(dim=1, keepdim=True) - ggI_out_sum_output = ggI_out_sum * output - gO_out_sum = (gO * output).sum(dim=1, keepdim=True) - - # gI calculation - gI_t0 = ggI_output * (gO - gO_out_sum) - gI_t1 = output * ((ggI_output * gO).sum(dim=1, keepdim=True).sub_(gO_out_sum * ggI_out_sum)) - gI_t2 = ggI_out_sum_output * gO - gI_t3 = ggI_out_sum_output * gO_out_sum - gI = gI_t0 - gI_t1 - gI_t2 + gI_t3 - - # gO calculation - ggO = output * (ggI - ggI_out_sum) - - return gI, ggO, None, None, None, None - - def softplus_double_backwards(ctx, ggI): t = ctx.saved_variables input, gO, output = t[0], t[1], t[2] @@ -313,13 +275,11 @@ double_backwards_fns = { 'Hardtanh': hardtanh_double_backwards, 'LeakyReLU': leakyrelu_double_backwards, 'LogSigmoid': logsigmoid_double_backwards, - 'LogSoftmax': logsoftmax_double_backwards, 'ReflectionPad1d': reflectionpad1d_double_backwards, 'ReflectionPad2d': reflectionpad2d_double_backwards, 'ReplicationPad1d': replicationpad1d_double_backwards, 'ReplicationPad2d': replicationpad2d_double_backwards, 'ReplicationPad3d': replicationpad3d_double_backwards, - 'Softmax': softmax_double_backwards, 'Softplus': softplus_double_backwards, 'Softshrink': softshrink_double_backwards, 'Threshold': threshold_double_backwards, diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 52ec672864..ca6a18a71a 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -17,7 +17,9 @@ from torch.autograd import Variable from .modules.utils import _single, _pair, _triple # Convolutions -ConvNd = 
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 52ec672864..ca6a18a71a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -17,7 +17,9 @@ from torch.autograd import Variable
 from .modules.utils import _single, _pair, _triple
 
 # Convolutions
-ConvNd = torch._C._functions.ConvNd
+_ConvNd = torch._C._functions.ConvNd
+_Softmax = torch._C._functions.Softmax
+_LogSoftmax = torch._C._functions.LogSoftmax
 
 
 def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
@@ -49,9 +51,9 @@ def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
     if input is not None and input.dim() != 3:
         raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))
 
-    f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
-               _single(0), groups, torch.backends.cudnn.benchmark,
-               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+    f = _ConvNd(_single(stride), _single(padding), _single(dilation), False,
+                _single(0), groups, torch.backends.cudnn.benchmark,
+                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
     return f(input, weight, bias)
 
 
@@ -85,9 +87,9 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
     if input is not None and input.dim() != 4:
         raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
 
-    f = ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
-               _pair(0), groups, torch.backends.cudnn.benchmark,
-               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+    f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
+                _pair(0), groups, torch.backends.cudnn.benchmark,
+                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
     return f(input, weight, bias)
 
 
@@ -121,9 +123,9 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1,
     if input is not None and input.dim() != 5:
         raise ValueError("Expected 5D tensor as input, got {}D tensor instead.".format(input.dim()))
 
-    f = ConvNd(_triple(stride), _triple(padding), _triple(dilation), False,
-               _triple(0), groups, torch.backends.cudnn.benchmark,
-               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+    f = _ConvNd(_triple(stride), _triple(padding), _triple(dilation), False,
+                _triple(0), groups, torch.backends.cudnn.benchmark,
+                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
     return f(input, weight, bias)
 
 
@@ -153,10 +155,10 @@ def conv_transpose1d(input, weight, bias=None, stride=1, padding=0,
     if input is not None and input.dim() != 3:
         raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))
 
-    f = ConvNd(_single(stride), _single(padding), _single(dilation), True,
-               _single(output_padding),
-               groups, torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
-               torch.backends.cudnn.enabled)
+    f = _ConvNd(_single(stride), _single(padding), _single(dilation), True,
+                _single(output_padding),
+                groups, torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
+                torch.backends.cudnn.enabled)
     return f(input, weight, bias)
 
 
@@ -187,9 +189,9 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0,
     if input is not None and input.dim() != 4:
         raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
 
-    f = ConvNd(_pair(stride), _pair(padding), _pair(dilation), True,
-               _pair(output_padding), groups, torch.backends.cudnn.benchmark,
-               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+    f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), True,
+                _pair(output_padding), groups, torch.backends.cudnn.benchmark,
+                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
     return f(input, weight, bias)
 
 
@@ -219,9 +221,9 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0,
     if input is not None and input.dim() != 5:
         raise ValueError("Expected 5D tensor as input, got {}D tensor instead.".format(input.dim()))
 
-    f = ConvNd(_triple(stride), _triple(padding), _triple(dilation), True,
-               _triple(output_padding), groups, torch.backends.cudnn.benchmark,
-               torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+    f = _ConvNd(_triple(stride), _triple(padding), _triple(dilation), True,
+                _triple(output_padding), groups, torch.backends.cudnn.benchmark,
+                torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
     return f(input, weight, bias)
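The ConvNd rename above is a purely internal refactor; the public functional convolution API is untouched. For instance (shapes illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)   # N x C x H x W input
w = torch.randn(6, 3, 3, 3)   # out_channels x in_channels x kH x kW
out = F.conv2d(x, w, padding=1)
print(out.shape)              # torch.Size([1, 6, 8, 8])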
@@ -599,20 +601,73 @@ def softplus(input, beta=1, threshold=20):
     return _functions.thnn.auto.Softplus.apply(input, beta, threshold)
 
 
-def softmin(input):
-    return softmax(-input)
+def _get_softmax_dim(name, ndim, stacklevel):
+    warnings.warn("Implicit dimension choice for " + name + " has been deprecated. "
+                  "Change the call to include dim=X as an argument.", stacklevel=stacklevel)
+    if ndim == 0 or ndim == 3:
+        return 0
+    else:
+        return 1
 
 
-def softmax(input):
-    return _functions.thnn.auto.Softmax.apply(input)
+def softmin(input, dim=None, _stacklevel=3):
+    """Applies a softmin function.
 
+    Note that softmin(x) = softmax(-x). See softmax definition for the mathematical formula.
 
-def softshrink(input, lambd=0.5):
-    return _functions.thnn.auto.Softshrink.apply(input, lambd)
+    Arguments:
+        input (Variable): input
+        dim (int): A dimension along which softmin will be computed (so every slice
+            along dim will sum to 1).
+    """
+    if dim is None:
+        dim = _get_softmax_dim('softmin', input.dim(), _stacklevel)
+    return _Softmax(dim)(-input)
+
+
+def softmax(input, dim=None, _stacklevel=3):
+    """Applies a softmax function.
+
+    Softmax is defined as:
+
+    :math:`softmax(x) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
+
+    It is applied to all slices along dim, and will rescale them so that the elements
+    lie in the range `(0, 1)` and sum to 1.
+
+    Arguments:
+        input (Variable): input
+        dim (int): A dimension along which softmax will be computed.
+
+    .. note::
+        This function doesn't work directly with NLLLoss,
+        which expects the Log to be computed between the Softmax and itself.
+        Use log_softmax instead (it's faster and has better numerical properties).
-def log_softmax(input):
-    return _functions.thnn.LogSoftmax.apply(input)
+    """
+    if dim is None:
+        dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
+    return _Softmax(dim)(input)
+
+
+def log_softmax(input, dim=None, _stacklevel=3):
+    """Applies a softmax followed by a logarithm.
+
+    While mathematically equivalent to log(softmax(x)), doing these two
+    operations separately is slower and numerically unstable. This function
+    uses an alternative formulation to compute the output and gradient correctly.
+
+    Arguments:
+        input (Variable): input
+        dim (int): A dimension along which log_softmax will be computed.
+    """
+    if dim is None:
+        dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel)
+    return _LogSoftmax(dim)(input)
+
+
+def softshrink(input, lambd=0.5):
+    return _functions.thnn.auto.Softshrink.apply(input, lambd)
 
 
 def tanh(input):
@@ -931,7 +986,7 @@ def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-1
     >>> loss = F.cross_entropy(input, target)
     >>> loss.backward()
     """
-    return nll_loss(log_softmax(input), target, weight, size_average, ignore_index)
+    return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index)
 
 
 def binary_cross_entropy(input, target, weight=None, size_average=True):
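Taken together, the new functional interface computes softmax over an explicit dim and only falls back to _get_softmax_dim, with a deprecation warning, when dim is omitted. A short usage sketch (example mine, not part of the patch):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 5)

p = F.softmax(x, dim=1)         # every slice along dim 1 sums to 1
print(p.sum(1))                 # all ones, shape (2, 4, 5)

logp = F.log_softmax(x, dim=1)  # numerically stable log(softmax(x))

# softmin(x) is softmax(-x) along the same dim
print(torch.allclose(F.softmin(x, dim=1), F.softmax(-x, dim=1)))

F.softmax(x)                    # deprecated: warns and picks dim=1 for this 4D input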
+ """ + if dim is None: + dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel) + return _LogSoftmax(dim)(input) + + +def softshrink(input, lambd=0.5): + return _functions.thnn.auto.Softshrink.apply(input, lambd) def tanh(input): @@ -931,7 +986,7 @@ def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-1 >>> loss = F.cross_entropy(input, target) >>> loss.backward() """ - return nll_loss(log_softmax(input), target, weight, size_average, ignore_index) + return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index) def binary_cross_entropy(input, target, weight=None, size_average=True): diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 399555da27..762e51fa6a 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -601,8 +601,12 @@ class Softmin(Module): :math:`f(x) = exp(-x_i) / sum_j exp(-x_j)` Shape: - - Input: :math:`(N, L)` - - Output: :math:`(N, L)` + - Input: any shape + - Output: same as input + + Arguments: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). Returns: a Tensor of the same dimension and shape as the input, with @@ -615,9 +619,12 @@ class Softmin(Module): >>> print(input) >>> print(m(input)) """ + def __init__(self, dim=None): + super(Softmin, self).__init__() + self.dim = dim def forward(self, input): - return F.softmin(input) + return F.softmin(input, self.dim, _stacklevel=5) def __repr__(self): return self.__class__.__name__ + ' ()' @@ -632,17 +639,21 @@ class Softmax(Module): :math:`f_i(x) = exp(x_i) / sum_j exp(x_j)` Shape: - - Input: :math:`(N, L)` - - Output: :math:`(N, L)` + - Input: any shape + - Output: same as input Returns: a Tensor of the same dimension and shape as the input with values in the range [0, 1] + Arguments: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). + .. note:: This module doesn't work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. - Use Logsoftmax instead (it's faster). + Use Logsoftmax instead (it's faster and has better numerical properties). Examples:: @@ -652,9 +663,17 @@ class Softmax(Module): >>> print(m(input)) """ + def __init__(self, dim=None): + super(Softmax, self).__init__() + self.dim = dim + + def __setstate__(self, state): + self.__dict__.update(state) + if not hasattr(self, 'dim'): + self.dim = None + def forward(self, input): - assert input.dim() == 2, 'Softmax requires a 2D tensor as input' - return F.softmax(input) + return F.softmax(input, self.dim, _stacklevel=5) def __repr__(self): return self.__class__.__name__ + ' ()' @@ -686,7 +705,7 @@ class Softmax2d(Module): def forward(self, input): assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input' - return F.softmax(input) + return F.softmax(input, 1, _stacklevel=5) def __repr__(self): return self.__class__.__name__ + ' ()' @@ -699,8 +718,12 @@ class LogSoftmax(Module): :math:`f_i(x) = log(exp(x_i) / sum_j exp(x_j) )` Shape: - - Input: :math:`(N, L)` - - Output: :math:`(N, L)` + - Input: any shape + - Output: same as input + + Arguments: + dim (int): A dimension along which Softmax will be computed (so every slice + along dim will sum to 1). 
@@ -699,8 +718,12 @@ class LogSoftmax(Module):
     :math:`f_i(x) = log(exp(x_i) / sum_j exp(x_j))`
 
     Shape:
-        - Input: :math:`(N, L)`
-        - Output: :math:`(N, L)`
+        - Input: any shape
+        - Output: same as input
+
+    Arguments:
+        dim (int): A dimension along which LogSoftmax will be computed (so every slice
+            along dim will sum to 1 after exponentiation).
 
     Returns:
         a Tensor of the same dimension and shape as the input with
@@ -714,8 +737,17 @@
         >>> print(m(input))
     """
+    def __init__(self, dim=None):
+        super(LogSoftmax, self).__init__()
+        self.dim = dim
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        if not hasattr(self, 'dim'):
+            self.dim = None
+
     def forward(self, input):
-        return F.log_softmax(input)
+        return F.log_softmax(input, self.dim, _stacklevel=5)
 
     def __repr__(self):
         return self.__class__.__name__ + ' ()'
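The __setstate__ hooks exist so that modules pickled before this change, which carry no dim attribute, still load: they get dim = None and then take the deprecated implicit-dim path on their next forward. A sketch of the round trip, using a temporary file as the test suite does (example mine):

from tempfile import TemporaryFile
import torch
import torch.nn as nn

m = nn.LogSoftmax(dim=1)
with TemporaryFile() as f:
    torch.save(m, f)
    f.seek(0)
    m2 = torch.load(f)          # __setstate__ runs on load
print(m2.dim)                   # 1; an old checkpoint lacking 'dim' would get None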