author    Adam Paszke <adam.paszke@gmail.com>    2017-09-29 08:52:35 -0700
committer Adam Paszke <adam.paszke@gmail.com>    2017-10-19 19:51:10 +0200
commit    98e67448fa78bd1bc6f05920ad03efceecc10066 (patch)
tree      50a690a709583559368145aacce9a460b184ce0c
parent    3a4ca7a2696ac5f8d3a32108648f588bbc2b1eaa (diff)
Large Softmax and LogSoftmax refactor
- Cleaned up THNN and THCUNN code and kernels
- Improved THCUNN kernel performance 5x, making it match cuDNN performance
- Added support for computing softmax over arbitrary dims
  NOTE: The default dim for 3D inputs is now 1 (used to be 0)
- Both functions now accept inputs with arbitrarily many dimensions
- Autograd functions no longer save the input (it's unnecessary)
- Added cuDNN bindings for softmax, but they are unused, as THCUNN matches or even exceeds cuDNN performance
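
For context, a minimal sketch of the user-facing change (assuming the post-commit torch.nn API; the dim argument shown here mirrors the test changes below):

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    x = Variable(torch.randn(2, 3, 4, 5))

    # softmax/log_softmax can now reduce over any dimension
    p = F.softmax(x, dim=3)
    lp = F.log_softmax(x, dim=0)

    # module form: the dimension is passed to the constructor
    m = torch.nn.Softmax(1)
    out = m(x)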
-rw-r--r--   setup.py                                             1
-rw-r--r--   test/common_nn.py                                   21
-rw-r--r--   test/test_nn.py                                     47
-rw-r--r--   tools/autograd/gen_variable_type.py                  4
-rw-r--r--   torch/csrc/autograd/functions/init.cpp              24
-rw-r--r--   torch/csrc/autograd/functions/softmax.cpp          107
-rw-r--r--   torch/csrc/autograd/functions/softmax.h             67
-rw-r--r--   torch/csrc/cudnn/Descriptors.h                      63
-rw-r--r--   torch/csrc/cudnn/Exceptions.h                       14
-rw-r--r--   torch/legacy/nn/LogSoftMax.py                       14
-rw-r--r--   torch/legacy/nn/SoftMax.py                          14
-rw-r--r--   torch/legacy/nn/SoftMin.py                          13
-rw-r--r--   torch/legacy/nn/SpatialSoftMax.py                    6
-rw-r--r--   torch/lib/THCUNN/LogSoftMax.cu                     122
-rw-r--r--   torch/lib/THCUNN/generic/LogSoftMax.cu             233
-rw-r--r--   torch/lib/THCUNN/generic/SoftMax.cu                 18
-rw-r--r--   torch/lib/THCUNN/generic/THCUNN.h                   12
-rw-r--r--   torch/lib/THNN/generic/LogSoftMax.c                168
-rw-r--r--   torch/lib/THNN/generic/SoftMax.c                   166
-rw-r--r--   torch/lib/THNN/generic/THNN.h                       12
-rw-r--r--   torch/nn/_functions/thnn/activation.py               1
-rw-r--r--   torch/nn/_functions/thnn/auto.py                     4
-rw-r--r--   torch/nn/_functions/thnn/auto_double_backwards.py   40
-rw-r--r--   torch/nn/functional.py                             113
-rw-r--r--   torch/nn/modules/activation.py                      56
25 files changed, 774 insertions(+), 566 deletions(-)
diff --git a/setup.py b/setup.py
index c43531f717..5df7355739 100644
--- a/setup.py
+++ b/setup.py
@@ -418,6 +418,7 @@ main_sources = [
"torch/csrc/autograd/generated/python_functions.cpp",
"torch/csrc/autograd/functions/batch_normalization.cpp",
"torch/csrc/autograd/functions/convolution.cpp",
+ "torch/csrc/autograd/functions/softmax.cpp",
"torch/csrc/autograd/functions/basic_ops.cpp",
"torch/csrc/autograd/functions/tensor.cpp",
"torch/csrc/autograd/functions/accumulate_grad.cpp",
diff --git a/test/common_nn.py b/test/common_nn.py
index 7606238094..cae61dcb9d 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -91,6 +91,7 @@ module_tests = [
),
dict(
module_name='Softmax',
+ constructor_args=(1,),
input_size=(10, 20),
reference_fn=lambda i, _: torch.exp(i).div(torch.exp(i).sum(1, True).expand(10, 20)),
),
@@ -101,11 +102,13 @@ module_tests = [
),
dict(
module_name='LogSoftmax',
+ constructor_args=(1,),
input_size=(10, 20),
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, True).expand(10, 20)).log_(),
),
dict(
module_name='LogSoftmax',
+ constructor_args=(1,),
input_size=(1, 3, 10, 20),
reference_fn=lambda i, _: torch.exp(i).div_(torch.exp(i).sum(1, False)).log_(),
desc='multiparam',
@@ -220,10 +223,12 @@ module_tests = [
),
dict(
module_name='Softmin',
+ constructor_args=(1,),
input_size=(10, 20),
),
dict(
module_name='Softmin',
+ constructor_args=(1,),
input_size=(2, 3, 5, 10),
desc='multidim',
),
@@ -629,6 +634,7 @@ class ModuleTest(TestBase):
super(ModuleTest, self).__init__(*args, **kwargs)
self.jacobian_input = kwargs.get('jacobian_input', True)
self.should_test_cuda = kwargs.get('test_cuda', True)
+ self.should_test_pickle = kwargs.get('pickle', True)
def __call__(self, test_case):
module = self.constructor(*self.constructor_args)
@@ -644,13 +650,14 @@ class ModuleTest(TestBase):
self.test_noncontig(test_case, module, input)
- # TODO: do this with in-memory files as soon as torch.save will support it
- with TemporaryFile() as f:
- test_case._forward(module, input)
- torch.save(module, f)
- f.seek(0)
- module_copy = torch.load(f)
- test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
+ if self.should_test_pickle:
+ # TODO: do this with in-memory files as soon as torch.save will support it
+ with TemporaryFile() as f:
+ test_case._forward(module, input)
+ torch.save(module, f)
+ f.seek(0)
+ module_copy = torch.load(f)
+ test_case.assertEqual(test_case._forward(module, input), test_case._forward(module_copy, input))
self._do_test(test_case, module, input)
diff --git a/test/test_nn.py b/test/test_nn.py
index 81867c4119..a7c64163cf 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -162,7 +162,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
test_case.assertEqual(type(p.data), torch.DoubleTensor)
# TODO: Hardshrink is lacking a CUDA implementation
- if TEST_CUDA and type(module) != nn.Hardshrink:
+ if TEST_CUDA and self.should_test_cuda and type(module) != nn.Hardshrink:
# to GPU0
input = input.float().cuda()
module.float().cuda()
@@ -3576,6 +3576,13 @@ def add_test(test):
setattr(TestNN, cuda_test_name, lambda self, test=test: test.test_cuda(self))
+def wrap_functional(fn, **kwargs):
+ class FunctionalModule(nn.Module):
+ def forward(self, *args):
+ return fn(*args, **kwargs)
+ return FunctionalModule
+
+
new_criterion_tests = [
dict(
module_name='BCEWithLogitsLoss',
@@ -4367,6 +4374,44 @@ new_module_tests = [
input_size=(5, 6, 7),
desc='dim'
),
+ dict(
+ constructor=wrap_functional(F.softmax, dim=1),
+ input_size=(2, 3, 4, 5),
+ fullname='softmax',
+ pickle=False,
+ ),
+ dict(
+ constructor=wrap_functional(F.softmax, dim=0),
+ input_size=(2, 3, 4, 5),
+ fullname='softmax_dim0',
+ test_cuda=False,
+ pickle=False,
+ ),
+ dict(
+ constructor=wrap_functional(F.softmax, dim=3),
+ input_size=(2, 3, 4, 5),
+ fullname='softmax_dim3',
+ test_cuda=False,
+ pickle=False,
+ ),
+ dict(
+ constructor=wrap_functional(F.log_softmax, dim=1),
+ input_size=(2, 3, 4, 5),
+ fullname='log_softmax',
+ pickle=False,
+ ),
+ dict(
+ constructor=wrap_functional(F.log_softmax, dim=0),
+ input_size=(2, 3, 4, 5),
+ fullname='log_softmax_dim0',
+ pickle=False,
+ ),
+ dict(
+ constructor=wrap_functional(F.log_softmax, dim=3),
+ input_size=(2, 3, 4, 5),
+ fullname='log_softmax_dim3',
+ pickle=False,
+ ),
]
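
The pickle=False flags above are needed because wrap_functional defines FunctionalModule inside a function body, and Python cannot pickle instances of classes that are not importable at module scope (torch.save relies on pickle). A standalone illustration of that limitation (names are hypothetical):

    import pickle

    def make_module():
        class FunctionalModule(object):  # local class, as in wrap_functional
            pass
        return FunctionalModule()

    try:
        pickle.dumps(make_module())
    except (pickle.PicklingError, AttributeError) as e:
        print('cannot pickle a locally defined class:', e)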
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 315c5c76da..bbeaea02bc 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -492,10 +492,6 @@ def create_variable_type(top_env, aten_declarations):
if arg['dynamic_type'] == 'Tensor']
env['tensorlist_args'] = [arg['name'] for arg in option['arguments']
if arg['dynamic_type'] == 'TensorList']
- if option['return_type'] == 'Scalar':
- env['return_value'] = 'Scalar(output)'
- else:
- env['return_value'] = 'Tensor(std::move(output))'
env['type_definition_body'] = emit_body(env, option)
diff --git a/torch/csrc/autograd/functions/init.cpp b/torch/csrc/autograd/functions/init.cpp
index 206c5fc7ca..82ece0359a 100644
--- a/torch/csrc/autograd/functions/init.cpp
+++ b/torch/csrc/autograd/functions/init.cpp
@@ -1,5 +1,6 @@
#include "batch_normalization.h"
#include "convolution.h"
+#include "softmax.h"
#include "accumulate_grad.h"
#include "basic_ops.h"
#include "tensor.h"
@@ -50,6 +51,19 @@ struct ConvCtor {
}
};
+template<bool is_log>
+struct SoftmaxCtor {
+ using fn_type = typename std::conditional<is_log, LogSoftmax, Softmax>::type;
+ fn_type* operator()(PyObject* args) {
+ int dim;
+
+ TupleParser parser(args, 1);
+ parser.parse(dim, "dim");
+
+ return new fn_type(dim);
+ }
+};
+
struct DelayedErrorCtor {
DelayedError* operator()(PyObject* args) {
std::string msg;
@@ -252,6 +266,16 @@ bool THPAutograd_initFunctions(PyObject* _unused)
addClass<ConvBackward, NoCtor>(module, ConvBackwardClass, "ConvNdBackward", conv_backward_properties);
addClass<ConvBackwardBackward, NoCtor>(module, ConvBackwardBackwardClass, "ConvNdBackwardBackward", conv_backward_backward_properties);
+ static PyTypeObject SoftmaxClass, SoftmaxBackwardClass, SoftmaxBackwardBackwardClass;
+ addClass<Softmax, SoftmaxCtor<false>>(module, SoftmaxClass, "Softmax");
+ addClass<SoftmaxBackward, NoCtor>(module, SoftmaxBackwardClass, "SoftmaxBackward");
+ addClass<SoftmaxBackwardBackward, NoCtor>(module, SoftmaxBackwardBackwardClass, "SoftmaxBackwardBackward");
+
+ static PyTypeObject LogSoftmaxClass, LogSoftmaxBackwardClass, LogSoftmaxBackwardBackwardClass;
+ addClass<LogSoftmax, SoftmaxCtor<true>>(module, LogSoftmaxClass, "LogSoftmax");
+ addClass<LogSoftmaxBackward, NoCtor>(module, LogSoftmaxBackwardClass, "LogSoftmaxBackward");
+ addClass<LogSoftmaxBackwardBackward, NoCtor>(module, LogSoftmaxBackwardBackwardClass, "LogSoftmaxBackwardBackward");
+
static PyTypeObject AccumulateGradClass;
addClass<AccumulateGrad, NoCtor>(module, AccumulateGradClass, "AccumulateGrad", accumulate_grad_properties);
diff --git a/torch/csrc/autograd/functions/softmax.cpp b/torch/csrc/autograd/functions/softmax.cpp
new file mode 100644
index 0000000000..50a5b45e2b
--- /dev/null
+++ b/torch/csrc/autograd/functions/softmax.cpp
@@ -0,0 +1,107 @@
+
+#include "softmax.h"
+
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/functions/utils.h"
+#include "torch/csrc/utils/auto_gpu.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+
+namespace torch { namespace autograd {
+
+template<bool is_log>
+variable_list SoftmaxBase<is_log>::apply(const variable_list& inputs) {
+ using BackwardBase = SoftmaxBackwardBase<is_log>;
+ check_input_variables("SoftmaxBase", inputs, 1);
+ AutoGPU gpu_guard(inputs[0].data());
+
+ auto input = inputs[0].data().contiguous();
+ auto output = input.type().tensor(input.sizes());
+
+ if (is_log) {
+ at::log_softmax_out(output, input, dim);
+ } else {
+ at::softmax_out(output, input, dim);
+ }
+
+ // This gets a bit weird because we need to save the output...
+ std::shared_ptr<BackwardBase> backward;
+ auto outputs = wrap_outputs(inputs, as_tensor_list(output), [this, &backward](FunctionFlags f) {
+ backward = std::make_shared<BackwardBase>(std::move(f), this->dim);
+ return backward;
+ });
+ if (backward && backward->is_executable) {
+ backward->saved_output = SavedVariable(outputs[0], backward.get());
+ }
+ return outputs;
+};
+
+template<bool is_log>
+variable_list SoftmaxBackwardBase<is_log>::apply(const variable_list& grad_outputs) {
+ using BackwardBase = typename std::conditional<is_log, LogSoftmaxBackwardBackward, SoftmaxBackwardBackward>::type;
+ check_input_variables("SoftmaxBackwardBase", grad_outputs, 1);
+ AutoGPU gpu_guard(grad_outputs[0]);
+
+ auto output_var = saved_output.unpack(shared_from_this());
+ auto& grad_output_var = grad_outputs[0];
+ auto& output = output_var.data();
+ auto& grad_output = grad_output_var.data();
+ auto grad_input = output.type().tensor(output.sizes());
+
+ auto input = output.type().tensor(); // We don't save the input, because THNN doesn't use it anyway...
+ if (is_log) {
+ at::log_softmax_backward_out(grad_input, grad_output, input, dim, output);
+ } else {
+ at::softmax_backward_out(grad_input, grad_output, input, dim, output);
+ }
+
+ variable_list all_inputs {output_var, grad_output_var};
+ return wrap_outputs(all_inputs, as_tensor_list(grad_input), [this, &output_var, &grad_output_var](FunctionFlags f) {
+ auto fn = std::make_shared<BackwardBase>(std::move(f));
+ if (fn->is_executable) {
+ fn->saved_output = SavedVariable(output_var, fn.get());
+ fn->saved_grad_output = SavedVariable(grad_output_var, fn.get());
+ fn->dim = this->dim;
+ }
+ return fn;
+ });
+}
+
+// These need to be explicitly instantiated, because they're not in the header.
+template struct SoftmaxBase<true>;
+template struct SoftmaxBase<false>;
+template struct SoftmaxBackwardBase<true>;
+template struct SoftmaxBackwardBase<false>;
+
+variable_list SoftmaxBackwardBackward::apply(const variable_list& grad_grad_inputs) {
+ check_input_variables("SoftmaxBackwardBackward", grad_grad_inputs, 1);
+ auto output = saved_output.unpack(shared_from_this());
+ auto gO = saved_grad_output.unpack();
+ auto& ggI = grad_grad_inputs[0];
+
+ // Terms for reuse
+ auto ggI_out_sum = (ggI * output).sum(dim, true);
+ auto gO_out_sum = (gO * output).sum(dim, true);
+
+ // NOTE: this is 2nd order grad output
+ auto gO2 = (gO - gO_out_sum) * ggI - gO * ggI_out_sum;
+ auto ggO = output * (ggI - ggI_out_sum);
+
+ return {Variable(std::move(gO2)), Variable(std::move(ggO))};
+}
+
+variable_list LogSoftmaxBackwardBackward::apply(const variable_list& grad_grad_inputs) {
+ check_input_variables("LogSoftmaxBackwardBackward", grad_grad_inputs, 1);
+ auto output = saved_output.unpack(shared_from_this());
+ auto gO = saved_grad_output.unpack();
+ auto& ggI = grad_grad_inputs[0];
+
+ auto output_exp = output.exp();
+ // NOTE: this is 2nd order grad output
+ auto gO2 = (-output_exp) * ggI * gO.sum(dim, true);
+ auto ggO = ggI - (ggI * output_exp).sum(dim, true);
+
+ return {Variable(std::move(gO2)), Variable(std::move(ggO))};
+}
+
+}}
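
For reference, the vector-Jacobian products implemented above can be checked by hand. Writing s for the saved softmax output, o for the log-softmax output, g for grad_output, and ggI for the incoming double-backward gradient, the two returned values are the gradients of <ggI, gI> w.r.t. the saved output and w.r.t. grad_output (a derivation sketch, not part of the commit):

    % Softmax backward along dim: gI_i = s_i (g_i - \sum_j s_j g_j).
    % Differentiating <ggI, gI> w.r.t. g and w.r.t. s reproduces ggO and gO2:
    \mathrm{ggO}_k = s_k \left( \mathrm{ggI}_k - \sum_i s_i \, \mathrm{ggI}_i \right)
    \mathrm{gO2}_k = g_k \left( \mathrm{ggI}_k - \sum_i s_i \, \mathrm{ggI}_i \right) - \mathrm{ggI}_k \sum_j s_j g_j

    % LogSoftmax backward: gI_i = g_i - e^{o_i} \sum_j g_j, hence
    \mathrm{ggO}_k = \mathrm{ggI}_k - \sum_i e^{o_i} \, \mathrm{ggI}_i
    \mathrm{gO2}_k = -\, e^{o_k} \, \mathrm{ggI}_k \sum_j g_j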
diff --git a/torch/csrc/autograd/functions/softmax.h b/torch/csrc/autograd/functions/softmax.h
new file mode 100644
index 0000000000..e42f4aeba7
--- /dev/null
+++ b/torch/csrc/autograd/functions/softmax.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <Python.h>
+#include <memory>
+#include <ATen/ATen.h>
+
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/saved_variable.h"
+
+namespace torch { namespace autograd {
+
+// Softmax and LogSoftmax are implemented nearly identically until second
+// derivative, so the implementation is contained in a common base class.
+template<bool is_log>
+struct SoftmaxBase : public ForwardFunction<> {
+ SoftmaxBase(int dim)
+ : dim(dim) {}
+
+ virtual variable_list apply(const variable_list& inputs) override;
+
+ int dim;
+};
+
+template<bool is_log>
+struct SoftmaxBackwardBase : public Function {
+ SoftmaxBackwardBase(FunctionFlags f, int dim)
+ : Function(std::move(f))
+ , dim(dim) {}
+
+ virtual variable_list apply(const variable_list& inputs) override;
+
+ SavedVariable saved_output;
+ int dim;
+};
+
+
+struct Softmax : public SoftmaxBase<false> {
+ using SoftmaxBase::SoftmaxBase;
+};
+struct SoftmaxBackward : public SoftmaxBackwardBase<false> {
+ using SoftmaxBackwardBase::SoftmaxBackwardBase;
+};
+struct SoftmaxBackwardBackward : public Function {
+ using Function::Function;
+ virtual variable_list apply(const variable_list& inputs) override;
+ SavedVariable saved_output;
+ SavedVariable saved_grad_output;
+ int dim;
+};
+
+
+struct LogSoftmax : public SoftmaxBase<true> {
+ using SoftmaxBase::SoftmaxBase;
+};
+struct LogSoftmaxBackward : public SoftmaxBackwardBase<true> {
+ using SoftmaxBackwardBase::SoftmaxBackwardBase;
+};
+struct LogSoftmaxBackwardBackward : public Function {
+ using Function::Function;
+ virtual variable_list apply(const variable_list& inputs) override;
+ SavedVariable saved_output;
+ SavedVariable saved_grad_output;
+ int dim;
+};
+
+}}
diff --git a/torch/csrc/cudnn/Descriptors.h b/torch/csrc/cudnn/Descriptors.h
index 6917e1cbf5..77f6f30e38 100644
--- a/torch/csrc/cudnn/Descriptors.h
+++ b/torch/csrc/cudnn/Descriptors.h
@@ -4,27 +4,71 @@
#include "Exceptions.h"
#include "cudnn-wrapper.h"
+#include <ATen/ATen.h>
namespace torch { namespace cudnn {
+inline int dataSize(cudnnDataType_t dataType)
+{
+ switch (dataType) {
+ case CUDNN_DATA_HALF: return 2;
+ case CUDNN_DATA_FLOAT: return 4;
+ default: return 8;
+ }
+}
+
+inline cudnnDataType_t getDataType(const at::Tensor& t) {
+ auto scalar_type = t.type().scalarType();
+ if (scalar_type == at::kFloat) {
+ return CUDNN_DATA_FLOAT;
+ } else if (scalar_type == at::kHalf) {
+ return CUDNN_DATA_HALF;
+ } else if (scalar_type == at::kDouble) {
+ return CUDNN_DATA_DOUBLE;
+ }
+ throw std::runtime_error("TensorDescriptor only supports double, float and half tensors");
+}
+
struct TensorDescriptor
{
cudnnTensorDescriptor_t desc;
+
TensorDescriptor() : desc(NULL) {
CHECK(cudnnCreateTensorDescriptor(&desc));
}
- TensorDescriptor(const TensorDescriptor&) = delete;
- TensorDescriptor(TensorDescriptor&& ref)
- {
- desc = ref.desc;
- ref.desc = NULL;
+ /* implicit */ TensorDescriptor(const at::Tensor& t, int pad = 0) : desc(NULL) {
+ CHECK(cudnnCreateTensorDescriptor(&desc));
+ set(t, pad);
}
+ TensorDescriptor(const TensorDescriptor&) = delete;
+ TensorDescriptor(TensorDescriptor&& ref) { desc = ref.desc; ref.desc = NULL; }
~TensorDescriptor() {
cudnnDestroyTensorDescriptor(desc);
}
void set(cudnnDataType_t dataType, int dim, int* size, int* stride) {
CHECK(cudnnSetTensorNdDescriptor(desc, dataType, dim, size, stride));
}
+ void set(const at::Tensor &t, int pad = 0) {
+ int dim = t.ndimension();
+ if (dim > CUDNN_DIM_MAX || pad > CUDNN_DIM_MAX)
+#define _STR(X) #X
+#define STR(X) _STR(X)
+ throw std::runtime_error("cuDNN supports only up to " STR(CUDNN_DIM_MAX) " dimensions");
+#undef _STR
+#undef STR
+ int size[CUDNN_DIM_MAX];
+ int stride[CUDNN_DIM_MAX];
+ for (int i = 0; i < dim; ++i) {
+ size[i] = t.size(i);
+ stride[i] = t.stride(i);
+ }
+ for (int i = dim; i < pad; ++i) {
+ size[i] = 1;
+ stride[i] = 1;
+ }
+ dim = std::max(dim, pad);
+ set(getDataType(t), dim, size, stride);
+ }
};
struct FilterDescriptor
@@ -109,15 +153,6 @@ union Constant
}
};
-inline int dataSize(cudnnDataType_t dataType)
-{
- switch (dataType) {
- case CUDNN_DATA_HALF: return 2;
- case CUDNN_DATA_FLOAT: return 4;
- default: return 8;
- }
-}
-
}} // namespace
#endif
diff --git a/torch/csrc/cudnn/Exceptions.h b/torch/csrc/cudnn/Exceptions.h
index 362565e48c..75c3a398d0 100644
--- a/torch/csrc/cudnn/Exceptions.h
+++ b/torch/csrc/cudnn/Exceptions.h
@@ -37,6 +37,20 @@ void assertSameGPU(cudnnDataType_t dataType, T* ... tensors) {
}
}
+inline int assertSameGPU(const at::Tensor& t) {
+ return t.get_device();
+}
+
+template<typename ...T>
+int assertSameGPU(const at::Tensor& t, T& ... tensors) {
+ static_assert(std::is_same<at::Tensor, typename std::common_type<T...>::type>::value,
+ "all arguments to assertSameGPU have to be at::Tensor&");
+ auto t_device = t.get_device();
+ if (t_device != assertSameGPU(tensors...))
+ throw std::runtime_error("tensors are on different GPUs");
+ return t_device;
+}
+
class cudnn_exception : public std::runtime_error {
public:
cudnnStatus_t status;
diff --git a/torch/legacy/nn/LogSoftMax.py b/torch/legacy/nn/LogSoftMax.py
index 948e9512ac..66ad2b0573 100644
--- a/torch/legacy/nn/LogSoftMax.py
+++ b/torch/legacy/nn/LogSoftMax.py
@@ -4,11 +4,20 @@ from .Module import Module
class LogSoftMax(Module):
+ def __init__(self, dim=None):
+ super(LogSoftMax, self).__init__()
+ if dim is not None:
+ self.dim = dim
+
+ def _get_dim(self, input):
+ return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1)
+
def updateOutput(self, input):
self._backend.LogSoftMax_updateOutput(
self._backend.library_state,
input,
- self.output
+ self.output,
+ self._get_dim(input)
)
return self.output
@@ -18,6 +27,7 @@ class LogSoftMax(Module):
input,
gradOutput,
self.gradInput,
- self.output
+ self.output,
+ self._get_dim(input)
)
return self.gradInput
diff --git a/torch/legacy/nn/SoftMax.py b/torch/legacy/nn/SoftMax.py
index 24d5fa5967..3f91bf80d6 100644
--- a/torch/legacy/nn/SoftMax.py
+++ b/torch/legacy/nn/SoftMax.py
@@ -4,11 +4,20 @@ from .Module import Module
class SoftMax(Module):
+ def __init__(self, dim=None):
+ super(SoftMax, self).__init__()
+ if dim is not None:
+ self.dim = dim
+
+ def _get_dim(self, input):
+ return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1)
+
def updateOutput(self, input):
self._backend.SoftMax_updateOutput(
self._backend.library_state,
input,
- self.output
+ self.output,
+ self._get_dim(input)
)
return self.output
@@ -18,6 +27,7 @@ class SoftMax(Module):
input,
gradOutput,
self.gradInput,
- self.output
+ self.output,
+ self._get_dim(input)
)
return self.gradInput
diff --git a/torch/legacy/nn/SoftMin.py b/torch/legacy/nn/SoftMin.py
index 7c1bbbff3f..4c9915c37b 100644
--- a/torch/legacy/nn/SoftMin.py
+++ b/torch/legacy/nn/SoftMin.py
@@ -5,9 +5,14 @@ from .utils import clear
class SoftMin(Module):
- def __init__(self):
+ def __init__(self, dim=None):
super(SoftMin, self).__init__()
self.mininput = None
+ if dim is not None:
+ self.dim = dim
+
+ def _get_dim(self, input):
+ return getattr(self, 'dim', 0 if input.dim() == 1 or input.dim() == 3 else 1)
def updateOutput(self, input):
if self.mininput is None:
@@ -16,7 +21,8 @@ class SoftMin(Module):
self._backend.SoftMax_updateOutput(
self._backend.library_state,
self.mininput,
- self.output
+ self.output,
+ self._get_dim(input)
)
return self.output
@@ -29,7 +35,8 @@ class SoftMin(Module):
self.mininput,
gradOutput,
self.gradInput,
- self.output
+ self.output,
+ self._get_dim(input)
)
self.gradInput.mul_(-1)
diff --git a/torch/legacy/nn/SpatialSoftMax.py b/torch/legacy/nn/SpatialSoftMax.py
index 526e6d47dc..5c9c0a45d1 100644
--- a/torch/legacy/nn/SpatialSoftMax.py
+++ b/torch/legacy/nn/SpatialSoftMax.py
@@ -8,7 +8,8 @@ class SpatialSoftMax(Module):
self._backend.SoftMax_updateOutput(
self._backend.library_state,
input,
- self.output
+ self.output,
+ 0 if input.dim() == 1 or input.dim() == 3 else 1
)
return self.output
@@ -18,6 +19,7 @@ class SpatialSoftMax(Module):
input,
gradOutput,
self.gradInput,
- self.output
+ self.output,
+ 0 if input.dim() == 1 or input.dim() == 3 else 1
)
return self.gradInput
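
The _get_dim fallback added to the legacy modules above keeps the old implicit dimension choice; a small sketch of the rule (helper name is illustrative):

    def legacy_default_dim(ndim):
        # 1D and 3D inputs reduce over dim 0; batched 2D/4D inputs
        # reduce over dim 1 (the class/channel dimension).
        return 0 if ndim == 1 or ndim == 3 else 1

    assert [legacy_default_dim(d) for d in (1, 2, 3, 4)] == [0, 1, 0, 1]

Note the contrast with the new-style modules: per the commit message, the default dim for 3D inputs is now 1 there.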
diff --git a/torch/lib/THCUNN/LogSoftMax.cu b/torch/lib/THCUNN/LogSoftMax.cu
index 98b7670718..91a9fbd7cd 100644
--- a/torch/lib/THCUNN/LogSoftMax.cu
+++ b/torch/lib/THCUNN/LogSoftMax.cu
@@ -1,88 +1,78 @@
#include "THCUNN.h"
#include "THCHalf.h"
+#include "THCTensorTypeUtils.cuh"
#include "THCHalfAutoNumerics.cuh"
#include "SharedMem.cuh"
template <typename T, typename AccumT>
-__global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, int classSize, int height, int width)
+__global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
{
- int batchIndex = blockIdx.x;
- int index = threadIdx.x;
-
- while (index < height*width) {
- int y = index / width;
- int x = index % width;
- if (y >= height)
- break;
-
- // calculate input starting index in cuda layout (B x H x W x C)
- int inputStartIndex =
- (height*width*classSize)*batchIndex +
- (width*classSize)*y +
- (classSize)*x;
-
- T maxInput = input[inputStartIndex];
- for (int i = 1; i < classSize; i++) {
- T value = input[inputStartIndex + i];
- maxInput = THCNumerics<T>::ge(maxInput, value) ? maxInput : value;
- }
+ const uint32_t outer_stride = inner_size * dim_size;
+ const uint32_t dim_stride = inner_size;
+
+ for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
+ const uint32_t outer_offset = outer_index * outer_stride;
+ for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) {
+ const uint32_t data_offset = outer_offset + inner_index;
+
+ T max_input = input[data_offset];
+ for (uint32_t d = 1; d < dim_size; d++) {
+ const T value = input[data_offset + d * dim_stride];
+ max_input = THCNumerics<T>::ge(max_input, value) ? max_input : value;
+ }
- AccumT sum = 0;
- for (int i = 0; i < classSize; i++) {
- sum += THCNumerics<T>::exp(input[inputStartIndex + i] - maxInput);
- }
- T logsum = maxInput + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum));
-
- for (int i = 0; i < classSize; i++) {
- // calculate output index in torch layout (B x C x H x W)
- int outputIndex =
- (classSize*height*width)*batchIndex +
- (height*width)*i +
- (width)*y +
- x;
- output[outputIndex] = input[inputStartIndex + i] - logsum;
+ AccumT sum = 0;
+ for (uint32_t d = 0; d < dim_size; d++)
+ sum += THCNumerics<T>::exp(input[data_offset + d * dim_stride] - max_input);
+ const T logsum = max_input + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum));
+
+ for (uint32_t d = 0; d < dim_size; d++)
+ output[data_offset + d * dim_stride] = input[data_offset + d * dim_stride] - logsum;
}
- index += blockDim.x;
}
}
template <typename T, typename AccumT>
-__global__ void cunn_SpatialLogSoftMax_updateGradInput_kernel(T *gradInput, T *output, T *gradOutput, int classSize, int height, int width)
+__global__ void cunn_SpatialLogSoftMax_updateGradInput_kernel(T *gradInput, T *output, T *gradOutput, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
{
- int batchIndex = blockIdx.x;
- int index = threadIdx.x;
-
- while (index < height*width) {
- int y = index / width;
- int x = index % width;
- if (y >= height)
- break;
-
- // calculate output starting index in cuda layout (B x H x W x C)
- int outputStartIndex =
- (height*width*classSize)*batchIndex +
- (width*classSize)*y +
- (classSize)*x;
-
- AccumT sum = 0;
- for (int i = 0; i < classSize; i++) {
- sum += gradOutput[outputStartIndex + i];
- }
+ const uint32_t outer_stride = inner_size * dim_size;
+ const uint32_t dim_stride = inner_size;
+
+ for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
+ const uint32_t outer_offset = outer_index * outer_stride;
+ for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) {
+ const uint32_t data_offset = outer_offset + inner_index;
- for (int i = 0; i < classSize; i++) {
- // calculate input index in torch layout (B x C x H x W)
- int inputIndex =
- (classSize*height*width)*batchIndex +
- (height*width)*i +
- (width)*y +
- x;
- gradInput[inputIndex] = ScalarConvert<AccumT, T>::to(
- gradOutput[outputStartIndex + i] - THCNumerics<T>::exp(output[outputStartIndex + i]) * sum);
+ AccumT sum = 0;
+ for (uint32_t d = 0; d < dim_size; d++) {
+ sum += gradOutput[data_offset + d * dim_stride];
+ }
+ const T real_sum = ScalarConvert<AccumT, T>::to(sum);
+
+ for (uint32_t d = 0; d < dim_size; d++) {
+ gradInput[data_offset + d * dim_stride] = gradOutput[data_offset + d * dim_stride] -
+ THCNumerics<T>::exp(output[data_offset + d * dim_stride]) * real_sum;
+ }
}
- index += blockDim.x;
}
}
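
Both kernels above implement the standard log-softmax identities along the reduced dimension (a reference restatement; x is the input, o the output, g the incoming gradient):

    o_d = x_d - \left( \max_j x_j + \log \sum_j e^{\,x_j - \max_k x_k} \right)
    \frac{\partial L}{\partial x_d} = g_d - e^{o_d} \sum_j g_j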
+static void LogSoftMax_getSpatialGridSize(
+ uint32_t block_size, uint32_t max_active_blocks,
+ uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
+ dim3& grid, dim3& block) {
+ // First, tile as many blocks as we can over the y axis
+ uint32_t y_size = (inner_size + block_size - 1) / block_size;
+ if (y_size > max_active_blocks)
+ y_size = max_active_blocks;
+ // Fill the x axis with as many blocks as we can fit
+ uint32_t x_size = (max_active_blocks + y_size - 1) / y_size;
+ if (x_size > outer_size)
+ x_size = outer_size;
+ grid = dim3(x_size, y_size);
+ block = dim3(block_size);
+}
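
A host-side transcription of the grid-sizing heuristic above, in Python for illustration (the unused dim_size parameter is dropped):

    def spatial_grid_size(block_size, max_active_blocks, outer_size, inner_size):
        # Tile the y axis over inner_size first, then fill x with the
        # remaining resident blocks, capped at outer_size.
        y_size = min((inner_size + block_size - 1) // block_size, max_active_blocks)
        x_size = min((max_active_blocks + y_size - 1) // y_size, outer_size)
        return (x_size, y_size), block_size

    # e.g. 40 resident blocks, inner_size=1000, block_size=128 -> grid (5, 8)
    print(spatial_grid_size(128, 40, outer_size=16, inner_size=1000))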
+
template <typename T, typename AccumT>
struct MaxFloat
{
diff --git a/torch/lib/THCUNN/generic/LogSoftMax.cu b/torch/lib/THCUNN/generic/LogSoftMax.cu
index 2f24697f53..4e35335191 100644
--- a/torch/lib/THCUNN/generic/LogSoftMax.cu
+++ b/torch/lib/THCUNN/generic/LogSoftMax.cu
@@ -7,101 +7,60 @@
void THNN_(LogSoftMax_updateOutput)(
THCState *state,
THCTensor *input,
- THCTensor *output)
+ THCTensor *output,
+ int dim)
{
THCUNN_assertSameGPU(state, 2, input, output);
+ THArgCheck(dim >= 0 && dim < input->nDimension, 4,
+ "dim out of range (got %d, but input has %d dims)", dim, input->nDimension);
+  THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, input), 4,
+             "input tensor is too large (unsupported size; please file a feature request)");
THCTensor_(resizeAs)(state, output, input);
- bool spatial = false;
- int batchSize = 1;
- int classSize = 0;
- int height = 0;
- int width = 0;
-
- int ndims = THCTensor_(nDimension)(state, input);
-
- if (ndims == 1)
- {
- classSize = THCTensor_(size)(state, input, 0);
- input = THCTensor_(newContiguous)(state, input);
- }
- else if (ndims == 2)
- {
- batchSize = THCTensor_(size)(state, input, 0);
- classSize = THCTensor_(size)(state, input, 1);
- input = THCTensor_(newContiguous)(state, input);
- }
- else if (ndims == 3)
- {
- spatial = true;
- classSize = THCTensor_(size)(state, input, 0);
- height = THCTensor_(size)(state, input, 1);
- width = THCTensor_(size)(state, input, 2);
-
- // create contiguous tensor with cuda layout from tensor with torch layout
- THCTensor *tinput = THCTensor_(new)(state);
- // C x H x W -> W x H x C
- THCTensor_(transpose)(state, tinput, input, 0, 2);
- // W x H x C -> H x W x C
- THCTensor_(transpose)(state, tinput, tinput, 0, 1);
- THCTensor *transposedInput = THCTensor_(newContiguous)(state, tinput);
- THCTensor_(free)(state, tinput);
- input = transposedInput;
- }
- else if (ndims == 4)
- {
- spatial = true;
- batchSize = THCTensor_(size)(state, input, 0);
- classSize = THCTensor_(size)(state, input, 1);
- height = THCTensor_(size)(state, input, 2);
- width = THCTensor_(size)(state, input, 3);
-
- // create contiguous tensor with cuda layout from tensor with torch layout
- // B x C x H x W -> B x W x H x C
- THCTensor *tinput = THCTensor_(new)(state);
- THCTensor_(transpose)(state, tinput, input, 1, 3);
- // B x W x H x C -> B x H x W x C
- THCTensor_(transpose)(state, tinput, tinput, 1, 2);
- THCTensor *transposedInput = THCTensor_(newContiguous)(state, tinput);
- THCTensor_(free)(state, tinput);
- input = transposedInput;
- }
- else
- {
- THError("1D, 2D, 3D or 4D Tensor expected");
- }
-
- if (!spatial)
- {
- dim3 grid(batchSize);
+ uint64_t outer_size = 1;
+ uint64_t dim_size = input->size[dim];
+ uint64_t inner_size = 1;
+ for (uint64_t i = 0; i < dim; ++i)
+ outer_size *= input->size[i];
+ for (uint64_t i = dim + 1; i < input->nDimension; ++i)
+ inner_size *= input->size[i];
+
+  // This kernel spawns a block of 1024 threads for each element in the batch.
+ // XXX: it assumes that inner_size == 1
+ input = THCTensor_(newContiguous)(state, input);
+ if (inner_size == 1 && dim_size >= 64) {
+ dim3 grid(outer_size);
dim3 block(1024);
cunn_LogSoftMax_updateOutput_kernel<2, real, accreal>
<<<grid, block, block.x * sizeof(accreal), THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, output),
THCTensor_(data)(state, input),
- classSize
+ dim_size
);
- }
- else
- {
- dim3 grid(batchSize);
- dim3 block(1024);
+  // This kernel runs in a 2D grid: blocks along the x axis parallelize over
+  // outer_size, while blocks along the y axis (together with the threads of
+  // each block) parallelize over inner_size. The reduction over dim is done
+  // sequentially within each thread.
+ } else {
+ dim3 grid, block;
+ uint32_t block_size = 1024;
+    while (block_size > inner_size) block_size >>= 1; // block_size = min(1024, 2^floor(log2(inner_size)))
+ int max_active_blocks;
+ cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
+ &cunn_SpatialLogSoftMax_updateOutput_kernel<real, accreal>,
+ block_size, 0);
+ max_active_blocks *= THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
+ LogSoftMax_getSpatialGridSize(block_size, max_active_blocks, outer_size, dim_size, inner_size, grid, block);
cunn_SpatialLogSoftMax_updateOutput_kernel<real, accreal>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, output),
THCTensor_(data)(state, input),
- classSize, height, width
+ outer_size, dim_size, inner_size
);
}
-
- cudaError errcode = cudaGetLastError();
- if (errcode != cudaSuccess)
- {
- THError(cudaGetErrorString(errcode));
- }
+ THCudaCheck(cudaGetLastError());
THCTensor_(free)(state, input);
}
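
The outer_size / dim_size / inner_size decomposition used throughout this commit reduces softmax over an arbitrary dim to a fixed three-axis view; a NumPy sketch of the equivalence (illustrative only, not the shipped code path):

    import numpy as np

    def log_softmax_any_dim(x, dim):
        outer = int(np.prod(x.shape[:dim]))      # product of leading dims
        inner = int(np.prod(x.shape[dim + 1:]))  # product of trailing dims
        v = x.reshape(outer, x.shape[dim], inner)
        m = v.max(axis=1, keepdims=True)         # max-shift for stability
        logsum = m + np.log(np.exp(v - m).sum(axis=1, keepdims=True))
        return (v - logsum).reshape(x.shape)

    x = np.random.randn(2, 3, 4, 5)
    ref = np.log(np.exp(x) / np.exp(x).sum(axis=2, keepdims=True))
    assert np.allclose(log_softmax_any_dim(x, 2), ref)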
@@ -111,97 +70,32 @@ void THNN_(LogSoftMax_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- THCTensor *output)
+ THCTensor *output,
+ int dim)
{
- THCUNN_check_nElement(state, input, gradOutput);
+ THArgCheck(dim >= 0 && dim < output->nDimension, 6,
+ "dim out of range (got %d, but input has %d dims)", dim, output->nDimension);
+  THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, output), 6,
+             "input tensor is too large (unsupported size; please file a feature request)");
+ THCUNN_check_nElement(state, output, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
THCTensor_(resizeAs)(state, gradInput, output);
- bool spatial = false;
- int batchSize = 1;
- int classSize = 0;
- int height = 0;
- int width = 0;
+ uint64_t outer_size = 1;
+ uint64_t dim_size = output->size[dim];
+ uint64_t inner_size = 1;
+ for (uint64_t i = 0; i < dim; ++i)
+ outer_size *= output->size[i];
+ for (uint64_t i = dim + 1; i < output->nDimension; ++i)
+ inner_size *= output->size[i];
- int ndims = THCTensor_(nDimension)(state, input);
+ output = THCTensor_(newContiguous)(state, output);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- if (ndims == 1)
- {
- classSize = THCTensor_(size)(state, gradInput, 0);
- output = THCTensor_(newContiguous)(state, output);
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- }
- else if (ndims == 2)
- {
- batchSize = THCTensor_(size)(state, gradInput, 0);
- classSize = THCTensor_(size)(state, gradInput, 1);
- output = THCTensor_(newContiguous)(state, output);
- gradOutput = THCTensor_(newContiguous)(state, gradOutput);
- }
- else if (ndims == 3)
- {
- spatial = true;
- classSize = THCTensor_(size)(state, input, 0);
- height = THCTensor_(size)(state, input, 1);
- width = THCTensor_(size)(state, input, 2);
-
- // create contiguous tensor with cuda layout from tensor with torch layout
- // C x H x W -> W x H x C
- THCTensor_(transpose)(state, output, output, 0, 2);
- // W x H x C -> H x W x C
- THCTensor_(transpose)(state, output, output, 0, 1);
- THCTensor *transposedOutput = THCTensor_(newContiguous)(state, output);
- THCTensor_(transpose)(state, output, output, 0, 1);
- THCTensor_(transpose)(state, output, output, 0, 2);
- output = transposedOutput;
-
- // create contiguous tensor with cuda layout from tensor with torch layout
- // C x H x W -> W x H x C
- THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 2);
- // W x H x C -> H x W x C
- THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 1);
- THCTensor *transposedGradOutput = THCTensor_(newContiguous)(state, gradOutput);
- THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 1);
- THCTensor_(transpose)(state, gradOutput, gradOutput, 0, 2);
- gradOutput = transposedGradOutput;
- }
- else if (ndims == 4)
- {
- spatial = true;
- batchSize = THCTensor_(size)(state, gradInput, 0);
- classSize = THCTensor_(size)(state, input, 1);
- height = THCTensor_(size)(state, input, 2);
- width = THCTensor_(size)(state, input, 3);
-
- // create contiguous tensor with cuda layout from tensor with torch layout
- // B x C x H x W -> B x W x H x C
- THCTensor_(transpose)(state, output, output, 1, 3);
- // B x W x H x C -> B x H x W x C
- THCTensor_(transpose)(state, output, output, 1, 2);
- THCTensor *transposedOutput = THCTensor_(newContiguous)(state, output);
- THCTensor_(transpose)(state, output, output, 1, 2);
- THCTensor_(transpose)(state, output, output, 1, 3);
- output = transposedOutput;
-
- // create contiguous tensor with cuda layout from tensor with torch layout
- // B x C x H x W -> B x W x H x C
- THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 3);
- // B x W x H x C -> B x H x W x C
- THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 2);
- THCTensor *transposedGradOutput = THCTensor_(newContiguous)(state, gradOutput);
- THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 2);
- THCTensor_(transpose)(state, gradOutput, gradOutput, 1, 3);
- gradOutput = transposedGradOutput;
- }
- else
- {
- THError("1D, 2D, 3D or 4D Tensor expected");
- }
-
- if (!spatial)
- {
- dim3 grid(batchSize);
+ // See descriptions of kernels above.
+ if (inner_size == 1 && dim_size >= 64) {
+ dim3 grid(outer_size);
dim3 block(1024);
cunn_LogSoftMax_updateGradInput_kernel<2, real, accreal>
@@ -209,20 +103,25 @@ void THNN_(LogSoftMax_updateGradInput)(
THCTensor_(data)(state, gradInput),
THCTensor_(data)(state, output),
THCTensor_(data)(state, gradOutput),
- classSize
+ dim_size
);
- }
- else
- {
- dim3 grid(batchSize);
- dim3 block(1024);
+ } else {
+ dim3 grid, block;
+ uint32_t block_size = 1024;
+    while (block_size > inner_size) block_size >>= 1; // block_size = min(1024, 2^floor(log2(inner_size)))
+ int max_active_blocks;
+ cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
+ &cunn_SpatialLogSoftMax_updateGradInput_kernel<real, accreal>,
+ block_size, 0);
+ max_active_blocks *= THCState_getCurrentDeviceProperties(state)->multiProcessorCount;
+ LogSoftMax_getSpatialGridSize(block_size, max_active_blocks, outer_size, dim_size, inner_size, grid, block);
cunn_SpatialLogSoftMax_updateGradInput_kernel<real, accreal>
<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
THCTensor_(data)(state, gradInput),
THCTensor_(data)(state, output),
THCTensor_(data)(state, gradOutput),
- classSize, height, width
+ outer_size, dim_size, inner_size
);
}
diff --git a/torch/lib/THCUNN/generic/SoftMax.cu b/torch/lib/THCUNN/generic/SoftMax.cu
index b52ca179ea..030f1ee94f 100644
--- a/torch/lib/THCUNN/generic/SoftMax.cu
+++ b/torch/lib/THCUNN/generic/SoftMax.cu
@@ -7,7 +7,8 @@
void THNN_(SoftMax_updateOutput)(
THCState *state,
THCTensor *input,
- THCTensor *output)
+ THCTensor *output,
+ int _dim)
{
THCUNN_assertSameGPU(state, 2, input, output);
@@ -21,12 +22,14 @@ void THNN_(SoftMax_updateOutput)(
batchSize = 1;
dim = input->size[0];
stride0 = 1;
+ THArgCheck(_dim == 0, 4, "dim has to be 0 for 1D input");
}
else if (input->nDimension == 2)
{
batchSize = input->size[0];
dim = input->size[1];
stride0 = 1;
+ THArgCheck(_dim == 1, 4, "dim has to be 1 for 2D input");
}
else if (input->nDimension == 3)
{
@@ -36,6 +39,7 @@ void THNN_(SoftMax_updateOutput)(
blocksZ = input->size[2];
stride0 = blocksY * blocksZ;
stride1 = blocksZ;
+ THArgCheck(_dim == 0, 4, "dim has to be 0 for 3D input");
}
else if (input->nDimension == 4)
{
@@ -45,6 +49,7 @@ void THNN_(SoftMax_updateOutput)(
blocksZ = input->size[3];
stride0 = blocksY * blocksZ;
stride1 = blocksZ;
+ THArgCheck(_dim == 1, 4, "dim has to be 1 for 4D input");
}
else
{
@@ -79,9 +84,10 @@ void THNN_(SoftMax_updateGradInput)(
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- THCTensor *output)
+ THCTensor *output,
+ int _dim)
{
- THCUNN_check_nElement(state, input, gradOutput);
+ THCUNN_check_nElement(state, output, gradOutput);
THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
output = THCTensor_(newContiguous)(state, output);
@@ -96,12 +102,14 @@ void THNN_(SoftMax_updateGradInput)(
batchSize = 1;
dim = gradInput->size[0];
stride0 = 1;
+ THArgCheck(_dim == 0, 6, "dim has to be 0 for 1D input");
}
else if (gradInput->nDimension == 2)
{
batchSize = gradInput->size[0];
dim = gradInput->size[1];
stride0 = 1;
+    THArgCheck(_dim == 1, 6, "dim has to be 1 for 2D input");
}
else if (gradInput->nDimension == 3)
{
@@ -111,6 +119,7 @@ void THNN_(SoftMax_updateGradInput)(
blocksZ = gradInput->size[2];
stride0 = blocksY * blocksZ;
stride1 = blocksZ;
+ THArgCheck(_dim == 0, 6, "dim has to be 0 for 3D input");
}
else if (gradInput->nDimension == 4)
{
@@ -120,6 +129,7 @@ void THNN_(SoftMax_updateGradInput)(
blocksZ = gradInput->size[3];
stride0 = blocksY * blocksZ;
stride1 = blocksZ;
+    THArgCheck(_dim == 1, 6, "dim has to be 1 for 4D input");
}
else
{
@@ -131,7 +141,7 @@ void THNN_(SoftMax_updateGradInput)(
{
blocksY *= blocksZ;
blocksZ = 1;
- if (input->nDimension == 3 || input->nDimension == 4) {
+ if (output->nDimension == 3 || output->nDimension == 4) {
stride0 = blocksY * blocksZ;
stride1 = blocksZ;
}
diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h
index aa1842f50d..bc75e0484b 100644
--- a/torch/lib/THCUNN/generic/THCUNN.h
+++ b/torch/lib/THCUNN/generic/THCUNN.h
@@ -242,14 +242,16 @@ TH_API void THNN_(LogSigmoid_updateGradInput)(
TH_API void THNN_(LogSoftMax_updateOutput)(
THCState *state,
THCTensor *input,
- THCTensor *output);
+ THCTensor *output,
+ int dim);
TH_API void THNN_(LogSoftMax_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- THCTensor *output);
+ THCTensor *output,
+ int dim);
TH_API void THNN_(LookupTable_accGradParameters)(
THCState *state,
@@ -1066,14 +1068,16 @@ TH_API void THNN_(SoftMarginCriterion_updateGradInput)(
TH_API void THNN_(SoftMax_updateOutput)(
THCState *state,
THCTensor *input,
- THCTensor *output);
+ THCTensor *output,
+ int dim);
TH_API void THNN_(SoftMax_updateGradInput)(
THCState *state,
THCTensor *input,
THCTensor *gradOutput,
THCTensor *gradInput,
- THCTensor *output);
+ THCTensor *output,
+ int dim);
TH_API void THNN_(SoftPlus_updateOutput)(
THCState *state,
diff --git a/torch/lib/THNN/generic/LogSoftMax.c b/torch/lib/THNN/generic/LogSoftMax.c
index a7280422b1..4c4700eb1e 100644
--- a/torch/lib/THNN/generic/LogSoftMax.c
+++ b/torch/lib/THNN/generic/LogSoftMax.c
@@ -5,64 +5,48 @@
void THNN_(LogSoftMax_updateOutput)(
THNNState *state,
THTensor *input,
- THTensor *output)
+ THTensor *output,
+ int dim)
{
- real *input_data, *output_data;
- ptrdiff_t nframe = 0, dim = 0, stride = 0;
- ptrdiff_t t, d;
+ THArgCheck(dim >= 0 && dim < input->nDimension, 4,
+ "dim out of range (got %d, but input has %d dims)", dim, input->nDimension);
- if (input->nDimension == 1)
- {
- nframe = 1;
- dim = input->size[0];
- stride = 1;
- }
- else if (input->nDimension == 2)
- {
- nframe = input->size[0];
- dim = input->size[1];
- stride = 1;
- }
- else if (input->nDimension == 3)
- {
- nframe = 1;
- dim = input->size[0];
- stride = input->size[1]*input->size[2];
- }
- else if (input->nDimension == 4)
- {
- nframe = input->size[0];
- dim = input->size[1];
- stride = input->size[2]*input->size[3];
- }
- else
- THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected");
+ uint64_t outer_size = 1;
+ uint64_t dim_size = input->size[dim];
+ uint64_t inner_size = 1;
+ for (uint64_t i = 0; i < dim; ++i)
+ outer_size *= input->size[i];
+ for (uint64_t i = dim + 1; i < input->nDimension; ++i)
+ inner_size *= input->size[i];
input = THTensor_(newContiguous)(input);
THTensor_(resizeAs)(output, input);
- real *input_data0 = THTensor_(data)(input);
- real *output_data0 = THTensor_(data)(output);
-
- accreal logsum;
- real maxInput;
- #pragma omp parallel for private(t, d, maxInput, logsum, input_data, output_data)
- for (t = 0; t < stride*nframe; t++)
- {
- logsum = 0;
- maxInput = -THInf;
- input_data = input_data0 + (t/stride)*dim*stride + t % stride;
- output_data = output_data0 + (t/stride)*dim*stride + t % stride;
-
- for (d = 0; d < dim; d++)
- maxInput = THMax(maxInput, input_data[d*stride]);
+ real *input_data_base = THTensor_(data)(input);
+ real *output_data_base = THTensor_(data)(output);
- for (d = 0; d < dim; d++)
- logsum += exp(input_data[d*stride] - maxInput);
- logsum = maxInput + log(logsum);
+ uint64_t dim_stride = inner_size;
+ uint64_t outer_stride = dim_size * dim_stride;
- for (d = 0; d < dim; d++)
- output_data[d*stride] = input_data[d*stride] - logsum;
+#pragma omp parallel for
+ for (uint64_t i = 0; i < outer_size * inner_size; i++)
+ {
+ uint64_t outer_idx = i / inner_size;
+ uint64_t inner_idx = i % inner_size;
+ real *input_data = input_data_base + outer_idx * outer_stride + inner_idx;
+ real *output_data = output_data_base + outer_idx * outer_stride + inner_idx;
+
+ real max_input = -THInf;
+      for (uint64_t d = 0; d < dim_size; d++)
+ max_input = THMax(max_input, input_data[d * dim_stride]);
+
+ accreal logsum = 0;
+ for (uint64_t d = 0; d < dim_size; d++)
+ logsum += exp(input_data[d * dim_stride] - max_input);
+ logsum = max_input + log(logsum);
+
+ for (uint64_t d = 0; d < dim_size; d++)
+ output_data[d * dim_stride] = input_data[d * dim_stride] - logsum;
}
THTensor_(free)(input);
@@ -73,61 +57,47 @@ void THNN_(LogSoftMax_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *output)
+ THTensor *output,
+ int dim)
{
- THNN_CHECK_SHAPE(input, gradOutput);
- real *gradInput_data, *gradOutput_data, *output_data;
- ptrdiff_t nframe = 0, dim = 0, stride = 0;
- ptrdiff_t t, d;
+ THNN_CHECK_SHAPE(output, gradOutput);
+ THArgCheck(dim >= 0 && dim < output->nDimension, 6,
+ "dim out of range (got %d, but input has %d dims)", dim, output->nDimension);
+
+ uint64_t outer_size = 1;
+ uint64_t dim_size = output->size[dim];
+ uint64_t inner_size = 1;
+ for (uint64_t i = 0; i < dim; ++i)
+ outer_size *= output->size[i];
+ for (uint64_t i = dim + 1; i < output->nDimension; ++i)
+ inner_size *= output->size[i];
- if (output->nDimension == 1)
- {
- nframe = 1;
- dim = output->size[0];
- stride = 1;
- }
- else if (output->nDimension == 2)
- {
- nframe = output->size[0];
- dim = output->size[1];
- stride = 1;
- }
- else if (output->nDimension == 3)
- {
- nframe = 1;
- dim = output->size[0];
- stride = output->size[1]*output->size[2];
- }
- else if (output->nDimension == 4)
- {
- nframe = output->size[0];
- dim = output->size[1];
- stride = output->size[2]*output->size[3];
- }
- else
- THError("1D, 2D, 3D or 4D tensor expected");
-
- output = THTensor_(newContiguous)(output);
gradOutput = THTensor_(newContiguous)(gradOutput);
-
+ output = THTensor_(newContiguous)(output);
THTensor_(resizeAs)(gradInput, output);
- real *gradInput_data0 = THTensor_(data)(gradInput);
- real *output_data0 = THTensor_(data)(output);
- real *gradOutput_data0 = THTensor_(data)(gradOutput);
- accreal sum;
- #pragma omp parallel for private(t, sum, d, gradInput_data, output_data, gradOutput_data)
- for (t = 0; t < stride*nframe; t++)
- {
- sum = 0;
- gradInput_data = gradInput_data0 + (t/stride)*dim*stride + t % stride;
- output_data = output_data0 + (t/stride)*dim*stride + t % stride;
- gradOutput_data = gradOutput_data0 + (t/stride)*dim*stride + t % stride;
- for (d = 0; d < dim; d++)
- sum += gradOutput_data[d*stride];
+ real *gradInput_data_base = THTensor_(data)(gradInput);
+ real *output_data_base = THTensor_(data)(output);
+ real *gradOutput_data_base = THTensor_(data)(gradOutput);
+
+ uint64_t dim_stride = inner_size;
+ uint64_t outer_stride = dim_size * dim_stride;
- for (d = 0; d < dim; d++)
- gradInput_data[d*stride] = gradOutput_data[d*stride] - exp(output_data[d*stride])*sum;
+#pragma omp parallel for
+ for (uint64_t i = 0; i < outer_size * inner_size; i++)
+ {
+ uint64_t outer_idx = i / inner_size;
+ uint64_t inner_idx = i % inner_size;
+ real *gradInput_data = gradInput_data_base + outer_idx * outer_stride + inner_idx;
+ real *output_data = output_data_base + outer_idx * outer_stride + inner_idx;
+ real *gradOutput_data = gradOutput_data_base + outer_idx * outer_stride + inner_idx;
+
+ accreal sum = 0;
+ for (uint64_t d = 0; d < dim_size; d++)
+ sum += gradOutput_data[d * dim_stride];
+
+ for (uint64_t d = 0; d < dim_size; d++)
+ gradInput_data[d * dim_stride] = gradOutput_data[d * dim_stride] - exp(output_data[d * dim_stride]) * sum;
}
THTensor_(free)(gradOutput);
diff --git a/torch/lib/THNN/generic/SoftMax.c b/torch/lib/THNN/generic/SoftMax.c
index 303526a222..6a11834b95 100644
--- a/torch/lib/THNN/generic/SoftMax.c
+++ b/torch/lib/THNN/generic/SoftMax.c
@@ -5,73 +5,50 @@
void THNN_(SoftMax_updateOutput)(
THNNState *state,
THTensor *input,
- THTensor *output)
-{
- real *input_data, *output_data;
- ptrdiff_t nframe = 0, dim = 0, stride = 0;
- ptrdiff_t t;
-
- if (input->nDimension == 1)
- {
- nframe = 1;
- dim = input->size[0];
- stride = 1;
- }
- else if (input->nDimension == 2)
- {
- nframe = input->size[0];
- dim = input->size[1];
- stride = 1;
- }
- else if (input->nDimension == 3)
- {
- nframe = 1;
- dim = input->size[0];
- stride = input->size[1]*input->size[2];
- }
- else if (input->nDimension == 4)
- {
- nframe = input->size[0];
- dim = input->size[1];
- stride = input->size[2]*input->size[3];
- }
- else
- {
- THArgCheck(0, 2, "1D, 2D, 3D or 4D tensor expected");
- }
+ THTensor *output,
+ int dim) {
+ THArgCheck(dim >= 0 && dim < input->nDimension, 4,
+ "dim out of range (got %d, but input has %d dims)", dim, input->nDimension);
+
+ uint64_t outer_size = 1;
+ uint64_t dim_size = input->size[dim];
+ uint64_t inner_size = 1;
+ for (uint64_t i = 0; i < dim; ++i)
+ outer_size *= input->size[i];
+ for (uint64_t i = dim + 1; i < input->nDimension; ++i)
+ inner_size *= input->size[i];
input = THTensor_(newContiguous)(input);
THTensor_(resizeAs)(output, input);
- input_data = THTensor_(data)(input);
- output_data = THTensor_(data)(output);
+ real *input_data_base = THTensor_(data)(input);
+ real *output_data_base = THTensor_(data)(output);
-#pragma omp parallel for private(t)
- for (t = 0; t < stride*nframe; t++)
- {
- real *input_ptr = input_data + (t/stride)*dim*stride + t % stride;
- real *output_ptr = output_data + (t/stride)*dim*stride + t % stride;
+ uint64_t dim_stride = inner_size;
+ uint64_t outer_stride = dim_size * dim_stride;
- real inputMax = -THInf;
- accreal sum;
+#pragma omp parallel for
+ for (uint64_t i = 0; i < outer_size * inner_size; i++) {
+ uint64_t outer_idx = i / inner_size;
+ uint64_t inner_idx = i % inner_size;
+ real *input_data = input_data_base + outer_idx * outer_stride + inner_idx;
+ real *output_data = output_data_base + outer_idx * outer_stride + inner_idx;
- ptrdiff_t d;
- for (d = 0; d < dim; d++)
- {
- if (input_ptr[d*stride] >= inputMax) inputMax = input_ptr[d*stride];
+ real input_max = -THInf;
+ for (uint64_t d = 0; d < dim_size; d++) {
+ if (input_data[d * dim_stride] >= input_max) input_max = input_data[d * dim_stride];
}
- sum = 0;
- for (d = 0; d < dim; d++)
- {
- real z = exp(input_ptr[d*stride] - inputMax);
- output_ptr[d*stride] = z;
+ accreal sum = 0;
+ for (uint64_t d = 0; d < dim_size; d++) {
+ real z = exp(input_data[d * dim_stride] - input_max);
+ output_data[d * dim_stride] = z;
sum += z;
}
- for (d = 0; d < dim; d++)
- {
- output_ptr[d*stride] *= 1/sum;
+ real invsum = 1 / sum; // NOTE: truncate sum to real once
+ for (uint64_t d = 0; d < dim_size; d++) {
+ output_data[d * dim_stride] *= invsum;
}
}
@@ -83,64 +60,47 @@ void THNN_(SoftMax_updateGradInput)(
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *output)
+ THTensor *output,
+ int dim)
{
- THNN_CHECK_SHAPE(input, gradOutput);
- real *gradInput_data, *gradOutput_data, *output_data;
- ptrdiff_t nframe = 0, dim = 0, stride = 0;
- ptrdiff_t t;
-
- if (output->nDimension == 1)
- {
- nframe = 1;
- dim = output->size[0];
- stride = 1;
- }
- else if (output->nDimension == 2)
- {
- nframe = output->size[0];
- dim = output->size[1];
- stride = 1;
- }
- else if (output->nDimension == 3)
- {
- nframe = 1;
- dim = output->size[0];
- stride = output->size[1]*output->size[2];
- }
- else if (output->nDimension == 4)
- {
- nframe = output->size[0];
- dim = output->size[1];
- stride = output->size[2]*output->size[3];
- }
- else
- {
- THError("1D, 2D, 3D or 4D tensor expected");
- }
+ THNN_CHECK_SHAPE(output, gradOutput);
+ THArgCheck(dim >= 0 && dim < output->nDimension, 6,
+ "dim out of range (got %d, but input has %d dims)", dim, output->nDimension);
+
+ uint64_t outer_size = 1;
+ uint64_t dim_size = output->size[dim];
+ uint64_t inner_size = 1;
+ for (uint64_t i = 0; i < dim; ++i)
+ outer_size *= output->size[i];
+ for (uint64_t i = dim + 1; i < output->nDimension; ++i)
+ inner_size *= output->size[i];
gradOutput = THTensor_(newContiguous)(gradOutput);
output = THTensor_(newContiguous)(output);
-
THTensor_(resizeAs)(gradInput, output);
- gradInput_data = THTensor_(data)(gradInput);
- output_data = THTensor_(data)(output);
- gradOutput_data = THTensor_(data)(gradOutput);
-#pragma omp parallel for private(t)
- for (t = 0; t < stride*nframe; t++)
+ real *gradInput_data_base = THTensor_(data)(gradInput);
+ real *output_data_base = THTensor_(data)(output);
+ real *gradOutput_data_base = THTensor_(data)(gradOutput);
+
+ uint64_t dim_stride = inner_size;
+ uint64_t outer_stride = dim_size * dim_stride;
+
+#pragma omp parallel for
+ for (uint64_t i = 0; i < outer_size * inner_size; i++)
{
- real *gradInput_ptr = gradInput_data + (t/stride)*dim*stride + t % stride;
- real *output_ptr = output_data + (t/stride)*dim*stride + t % stride;
- real *gradOutput_ptr = gradOutput_data + (t/stride)*dim*stride + t % stride;
+ uint64_t outer_idx = i / inner_size;
+ uint64_t inner_idx = i % inner_size;
+ real *gradInput_data = gradInput_data_base + outer_idx * outer_stride + inner_idx;
+ real *output_data = output_data_base + outer_idx * outer_stride + inner_idx;
+ real *gradOutput_data = gradOutput_data_base + outer_idx * outer_stride + inner_idx;
- ptrdiff_t d;
accreal sum = 0;
- for (d = 0; d < dim; d++)
- sum += (accreal)gradOutput_ptr[d*stride] * output_ptr[d*stride];
+ for (uint64_t d = 0; d < dim_size; d++)
+ sum += ((accreal)gradOutput_data[d * dim_stride]) * ((accreal)output_data[d * dim_stride]);
- for (d = 0; d < dim; d++)
- gradInput_ptr[d*stride] = output_ptr[d*stride] * (gradOutput_ptr[d*stride] - sum);
+ for (uint64_t d = 0; d < dim_size; d++)
+ gradInput_data[d * dim_stride] = output_data[d * dim_stride] * (gradOutput_data[d * dim_stride] - sum);
}
THTensor_(free)(gradOutput);
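
The loop above computes the standard softmax Jacobian-transpose product (s the saved output, g the incoming gradient):

    \frac{\partial L}{\partial x_k} = s_k \left( g_k - \sum_j s_j \, g_j \right)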
diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h
index dbbf5f1bca..9da18bc0c4 100644
--- a/torch/lib/THNN/generic/THNN.h
+++ b/torch/lib/THNN/generic/THNN.h
@@ -224,13 +224,15 @@ TH_API void THNN_(LogSigmoid_updateGradInput)(
TH_API void THNN_(LogSoftMax_updateOutput)(
THNNState *state, // library's state
THTensor *input, // input tensor
- THTensor *output); // [OUT] output tensor
+ THTensor *output, // [OUT] output tensor
+ int dim);
TH_API void THNN_(LogSoftMax_updateGradInput)(
THNNState *state, // library's state
THTensor *input, // input tensor
THTensor *gradOutput, // gradient w.r.t. module's output
THTensor *gradInput, // [OUT] gradient w.r.t. input
- THTensor *output); // module's output
+ THTensor *output, // module's output
+ int dim);
TH_API void THNN_(LookupTable_accGradParameters)(
THNNState *state,
@@ -423,13 +425,15 @@ TH_API void THNN_(SmoothL1Criterion_updateGradInput)(
TH_API void THNN_(SoftMax_updateOutput)(
THNNState *state,
THTensor *input,
- THTensor *output);
+ THTensor *output,
+ int dim);
TH_API void THNN_(SoftMax_updateGradInput)(
THNNState *state,
THTensor *input,
THTensor *gradOutput,
THTensor *gradInput,
- THTensor *output);
+ THTensor *output,
+ int dim);
TH_API void THNN_(SoftPlus_updateOutput)(
THNNState *state,
diff --git a/torch/nn/_functions/thnn/activation.py b/torch/nn/_functions/thnn/activation.py
index 4d2ac18656..25c2cac93f 100644
--- a/torch/nn/_functions/thnn/activation.py
+++ b/torch/nn/_functions/thnn/activation.py
@@ -4,7 +4,6 @@ from torch._thnn import type2backend
from torch.autograd.variable import Variable
from . import _all_functions
-from .auto_double_backwards import softmax_double_backwards
class PReLU(Function):
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
index 653eef0eb3..fbd83f40ff 100644
--- a/torch/nn/_functions/thnn/auto.py
+++ b/torch/nn/_functions/thnn/auto.py
@@ -298,6 +298,8 @@ def _generate_function_classes(scope_dict):
'LookupTableBag',
'PReLU',
'RReLU',
+ 'SoftMax',
+ 'LogSoftMax',
'GRUFused',
'LSTMFused',
'unfolded',
@@ -312,8 +314,6 @@ def _generate_function_classes(scope_dict):
'SpatialReplicationPadding': 'ReplicationPad2d',
'VolumetricReplicationPadding': 'ReplicationPad3d',
'VolumetricMaxUnpooling': 'MaxUnpool3d',
- 'SoftMax': 'Softmax',
- 'LogSoftMax': 'LogSoftmax',
'HardTanh': 'Hardtanh',
'HardShrink': 'Hardshrink',
'SoftPlus': 'Softplus',
diff --git a/torch/nn/_functions/thnn/auto_double_backwards.py b/torch/nn/_functions/thnn/auto_double_backwards.py
index ba900e9d99..e3a96d2e49 100644
--- a/torch/nn/_functions/thnn/auto_double_backwards.py
+++ b/torch/nn/_functions/thnn/auto_double_backwards.py
@@ -91,21 +91,6 @@ def logsigmoid_double_backwards(ctx, ggI):
return gI, ggO, None, None, None, None
-def logsoftmax_double_backwards(ctx, ggI):
- t = ctx.saved_variables
- gO, output = t[1], t[2]
-
- output_exp = output.exp()
- gO_sum = gO.sum(dim=1, keepdim=True)
- ggI_output_exp = ggI * output_exp
- ggI_output_exp_sum = ggI_output_exp.sum(dim=1, keepdim=True)
-
- gI = output_exp * gO_sum * ggI_output_exp_sum - ggI_output_exp * gO_sum
- ggO = ggI - ggI_output_exp_sum
-
- return gI, ggO, None, None, None, None
-
-
def reflectionpad1d_double_backwards(ctx, ggI):
gI = None
ggO = torch.nn._functions.thnn.auto.ReflectionPad1d.apply(ggI, *ctx.additional_args)
@@ -141,29 +126,6 @@ def replicationpad3d_double_backwards(ctx, ggI):
return gI, ggO, None, None, None, None
-def softmax_double_backwards(ctx, ggI):
- t = ctx.saved_variables
- gO, output = t[1], t[2]
-
- # terms for reuse
- ggI_output = ggI * output
- ggI_out_sum = ggI_output.sum(dim=1, keepdim=True)
- ggI_out_sum_output = ggI_out_sum * output
- gO_out_sum = (gO * output).sum(dim=1, keepdim=True)
-
- # gI calculation
- gI_t0 = ggI_output * (gO - gO_out_sum)
- gI_t1 = output * ((ggI_output * gO).sum(dim=1, keepdim=True).sub_(gO_out_sum * ggI_out_sum))
- gI_t2 = ggI_out_sum_output * gO
- gI_t3 = ggI_out_sum_output * gO_out_sum
- gI = gI_t0 - gI_t1 - gI_t2 + gI_t3
-
- # gO calculation
- ggO = output * (ggI - ggI_out_sum)
-
- return gI, ggO, None, None, None, None
-
-
def softplus_double_backwards(ctx, ggI):
t = ctx.saved_variables
input, gO, output = t[0], t[1], t[2]
@@ -313,13 +275,11 @@ double_backwards_fns = {
'Hardtanh': hardtanh_double_backwards,
'LeakyReLU': leakyrelu_double_backwards,
'LogSigmoid': logsigmoid_double_backwards,
- 'LogSoftmax': logsoftmax_double_backwards,
'ReflectionPad1d': reflectionpad1d_double_backwards,
'ReflectionPad2d': reflectionpad2d_double_backwards,
'ReplicationPad1d': replicationpad1d_double_backwards,
'ReplicationPad2d': replicationpad2d_double_backwards,
'ReplicationPad3d': replicationpad3d_double_backwards,
- 'Softmax': softmax_double_backwards,
'Softplus': softplus_double_backwards,
'Softshrink': softshrink_double_backwards,
'Threshold': threshold_double_backwards,
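The hand-rolled double-backward rules are removed because the new C++ Softmax/LogSoftmax backwards are built from differentiable operations, so grad-of-grad now falls out of ordinary autograd. For reference, the deleted logsoftmax rule restated as a standalone helper (dim=1 was hard-coded; gO is the incoming grad_output and output is log_softmax(input)):

    def logsoftmax_double_backward(ggI, gO, output):
        # Restatement of the removed rule above, same math.
        output_exp = output.exp()  # recovers softmax(input)
        gO_sum = gO.sum(dim=1, keepdim=True)
        ggI_output_exp = ggI * output_exp
        ggI_output_exp_sum = ggI_output_exp.sum(dim=1, keepdim=True)
        gI = output_exp * gO_sum * ggI_output_exp_sum - ggI_output_exp * gO_sum
        ggO = ggI - ggI_output_exp_sum
        return gI, ggO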
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 52ec672864..ca6a18a71a 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -17,7 +17,9 @@ from torch.autograd import Variable
from .modules.utils import _single, _pair, _triple
# Convolutions
-ConvNd = torch._C._functions.ConvNd
+_ConvNd = torch._C._functions.ConvNd
+_Softmax = torch._C._functions.Softmax
+_LogSoftmax = torch._C._functions.LogSoftmax
def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
@@ -49,9 +51,9 @@ def conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1,
if input is not None and input.dim() != 3:
raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))
- f = ConvNd(_single(stride), _single(padding), _single(dilation), False,
- _single(0), groups, torch.backends.cudnn.benchmark,
- torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+ f = _ConvNd(_single(stride), _single(padding), _single(dilation), False,
+ _single(0), groups, torch.backends.cudnn.benchmark,
+ torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
@@ -85,9 +87,9 @@ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
if input is not None and input.dim() != 4:
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
- f = ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
- _pair(0), groups, torch.backends.cudnn.benchmark,
- torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+ f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), False,
+ _pair(0), groups, torch.backends.cudnn.benchmark,
+ torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
@@ -121,9 +123,9 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1,
if input is not None and input.dim() != 5:
raise ValueError("Expected 5D tensor as input, got {}D tensor instead.".format(input.dim()))
- f = ConvNd(_triple(stride), _triple(padding), _triple(dilation), False,
- _triple(0), groups, torch.backends.cudnn.benchmark,
- torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+ f = _ConvNd(_triple(stride), _triple(padding), _triple(dilation), False,
+ _triple(0), groups, torch.backends.cudnn.benchmark,
+ torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
@@ -153,10 +155,10 @@ def conv_transpose1d(input, weight, bias=None, stride=1, padding=0,
if input is not None and input.dim() != 3:
raise ValueError("Expected 3D tensor as input, got {}D tensor instead.".format(input.dim()))
- f = ConvNd(_single(stride), _single(padding), _single(dilation), True,
- _single(output_padding),
- groups, torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
- torch.backends.cudnn.enabled)
+ f = _ConvNd(_single(stride), _single(padding), _single(dilation), True,
+ _single(output_padding),
+ groups, torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
+ torch.backends.cudnn.enabled)
return f(input, weight, bias)
@@ -187,9 +189,9 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0,
if input is not None and input.dim() != 4:
raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()))
- f = ConvNd(_pair(stride), _pair(padding), _pair(dilation), True,
- _pair(output_padding), groups, torch.backends.cudnn.benchmark,
- torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+ f = _ConvNd(_pair(stride), _pair(padding), _pair(dilation), True,
+ _pair(output_padding), groups, torch.backends.cudnn.benchmark,
+ torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
@@ -219,9 +221,9 @@ def conv_transpose3d(input, weight, bias=None, stride=1, padding=0,
if input is not None and input.dim() != 5:
raise ValueError("Expected 5D tensor as input, got {}D tensor instead.".format(input.dim()))
- f = ConvNd(_triple(stride), _triple(padding), _triple(dilation), True,
- _triple(output_padding), groups, torch.backends.cudnn.benchmark,
- torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
+ f = _ConvNd(_triple(stride), _triple(padding), _triple(dilation), True,
+ _triple(output_padding), groups, torch.backends.cudnn.benchmark,
+ torch.backends.cudnn.deterministic, torch.backends.cudnn.enabled)
return f(input, weight, bias)
@@ -599,20 +601,73 @@ def softplus(input, beta=1, threshold=20):
return _functions.thnn.auto.Softplus.apply(input, beta, threshold)
-def softmin(input):
- return softmax(-input)
+def _get_softmax_dim(name, ndim, stacklevel):
+ warnings.warn("Implicit dimension choice for " + name + " has been deprecated. "
+ "Change the call to include dim=X as an argument.", stacklevel=stacklevel)
+ if ndim == 0 or ndim == 1 or ndim == 3:  # a 1D input can only be reduced over dim 0
+ return 0
+ else:
+ return 1
-def softmax(input):
- return _functions.thnn.auto.Softmax.apply(input)
+def softmin(input, dim=None, _stacklevel=3):
+ """Applies a softmin function.
+ Note that softmin(x) = softmax(-x). See the softmax definition for the mathematical formula.
-def softshrink(input, lambd=0.5):
- return _functions.thnn.auto.Softshrink.apply(input, lambd)
+ Arguments:
+ input (Variable): input
+ dim (int): A dimension along which softmin will be computed (so every slice
+ along dim will sum to 1).
+ """
+ if dim is None:
+ dim = _get_softmax_dim('softmin', input.dim(), _stacklevel)
+ return _Softmax(dim)(-input)
+
+
+def softmax(input, dim=None, _stacklevel=3):
+ """Applies a softmax function.
+
+ Softmax is defined as:
+
+ :math:`softmax(x_i) = \frac{exp(x_i)}{\sum_j exp(x_j)}`
+
+ It is applied to all slices along dim, and will rescale them so that the elements
+ lie in the range `(0, 1)` and sum to 1.
+
+ Arguments:
+ input (Variable): input
+ dim (int): A dimension along which softmax will be computed.
+ .. note::
+ This function doesn't work directly with NLLLoss, which expects
+ log-probabilities as input rather than the raw probabilities that softmax produces.
+ Use log_softmax instead (it's faster and has better numerical properties).
-def log_softmax(input):
- return _functions.thnn.LogSoftmax.apply(input)
+ """
+ if dim is None:
+ dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
+ return _Softmax(dim)(input)
+
+
+def log_softmax(input, dim=None, _stacklevel=3):
+ """Applies a softmax followed by a logarithm.
+
+ While mathematically equivalent to log(softmax(x)), doing these two
+ operations separately is slower and numerically unstable. This function
+ uses an alternative formulation to compute the output and gradient correctly.
+
+ Arguments:
+ input (Variable): input
+ dim (int): A dimension along which log_softmax will be computed.
+ """
+ if dim is None:
+ dim = _get_softmax_dim('log_softmax', input.dim(), _stacklevel)
+ return _LogSoftmax(dim)(input)
+
+
+def softshrink(input, lambd=0.5):
+ return _functions.thnn.auto.Softshrink.apply(input, lambd)
def tanh(input):
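The stability claim in the log_softmax docstring is easy to demonstrate: when one logit dominates, softmax underflows the small entries to 0 and the subsequent log produces -inf, while the fused computation stays exact. A small sketch:

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    x = Variable(torch.Tensor([[0.0, 10000.0]]))
    print(F.softmax(x, dim=1).log())   # [-inf, 0]: underflow, then log
    print(F.log_softmax(x, dim=1))     # [-10000, 0]: exact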
@@ -931,7 +986,7 @@ def cross_entropy(input, target, weight=None, size_average=True, ignore_index=-1
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()
"""
- return nll_loss(log_softmax(input), target, weight, size_average, ignore_index)
+ return nll_loss(log_softmax(input, 1), target, weight, size_average, ignore_index)
def binary_cross_entropy(input, target, weight=None, size_average=True):
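Passing dim=1 explicitly keeps cross_entropy on the non-deprecated path; it remains the composition of log_softmax and nll_loss over the class dimension. A quick check of the identity:

    import torch
    import torch.nn.functional as F
    from torch.autograd import Variable

    input = Variable(torch.randn(3, 5))             # (batch, classes)
    target = Variable(torch.LongTensor([1, 0, 4]))
    a = F.cross_entropy(input, target)
    b = F.nll_loss(F.log_softmax(input, 1), target)
    print((a - b).abs().data[0])                    # ~0.0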
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 399555da27..762e51fa6a 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -601,8 +601,12 @@ class Softmin(Module):
:math:`f(x) = exp(-x_i) / sum_j exp(-x_j)`
Shape:
- - Input: :math:`(N, L)`
- - Output: :math:`(N, L)`
+ - Input: any shape
+ - Output: same as input
+
+ Arguments:
+ dim (int): A dimension along which Softmin will be computed (so every slice
+ along dim will sum to 1).
Returns:
a Tensor of the same dimension and shape as the input, with
@@ -615,9 +619,12 @@ class Softmin(Module):
>>> print(input)
>>> print(m(input))
"""
+ def __init__(self, dim=None):
+ super(Softmin, self).__init__()
+ self.dim = dim
def forward(self, input):
- return F.softmin(input)
+ return F.softmin(input, self.dim, _stacklevel=5)
def __repr__(self):
return self.__class__.__name__ + ' ()'
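Softmin now mirrors the functional API: constructed with an explicit dim it normalizes slices along that dimension, and constructed without one it falls back to _get_softmax_dim and warns. A usage sketch:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    m = nn.Softmin(dim=1)
    input = Variable(torch.randn(2, 3))
    out = m(input)        # smaller inputs get larger mass
    print(out.sum(1))     # each row sums to ~1

    nn.Softmin()(input)   # dim omitted: emits the deprecation warning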
@@ -632,17 +639,21 @@ class Softmax(Module):
:math:`f_i(x) = exp(x_i) / sum_j exp(x_j)`
Shape:
- - Input: :math:`(N, L)`
- - Output: :math:`(N, L)`
+ - Input: any shape
+ - Output: same as input
Returns:
a Tensor of the same dimension and shape as the input with
values in the range [0, 1]
+ Arguments:
+ dim (int): A dimension along which Softmax will be computed (so every slice
+ along dim will sum to 1).
+
.. note::
This module doesn't work directly with NLLLoss,
which expects the Log to be computed between the Softmax and itself.
- Use Logsoftmax instead (it's faster).
+ Use LogSoftmax instead (it's faster and has better numerical properties).
Examples::
@@ -652,9 +663,17 @@ class Softmax(Module):
>>> print(m(input))
"""
+ def __init__(self, dim=None):
+ super(Softmax, self).__init__()
+ self.dim = dim
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ if not hasattr(self, 'dim'):
+ self.dim = None
+
def forward(self, input):
- assert input.dim() == 2, 'Softmax requires a 2D tensor as input'
- return F.softmax(input)
+ return F.softmax(input, self.dim, _stacklevel=5)
def __repr__(self):
return self.__class__.__name__ + ' ()'
@@ -686,7 +705,7 @@ class Softmax2d(Module):
def forward(self, input):
assert input.dim() == 4, 'Softmax2d requires a 4D tensor as input'
- return F.softmax(input)
+ return F.softmax(input, 1, _stacklevel=5)
def __repr__(self):
return self.__class__.__name__ + ' ()'
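Softmax2d keeps its 4D assertion but now simply forwards to F.softmax over dim 1, i.e. it normalizes across channels independently at every spatial location:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    m = nn.Softmax2d()
    x = Variable(torch.randn(1, 3, 4, 4))   # (N, C, H, W)
    y = m(x)
    print(y.sum(1))                         # all ones: channels sum to 1 per pixel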
@@ -699,8 +718,12 @@ class LogSoftmax(Module):
:math:`f_i(x) = log(exp(x_i) / sum_j exp(x_j) )`
Shape:
- - Input: :math:`(N, L)`
- - Output: :math:`(N, L)`
+ - Input: any shape
+ - Output: same as input
+
+ Arguments:
+ dim (int): A dimension along which LogSoftmax will be computed (so every slice
+ along dim will sum to 1).
Returns:
a Tensor of the same dimension and shape as the input with
@@ -714,8 +737,17 @@ class LogSoftmax(Module):
>>> print(m(input))
"""
+ def __init__(self, dim=None):
+ super(LogSoftmax, self).__init__()
+ self.dim = dim
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ if not hasattr(self, 'dim'):
+ self.dim = None
+
def forward(self, input):
- return F.log_softmax(input)
+ return F.log_softmax(input, self.dim, _stacklevel=5)
def __repr__(self):
return self.__class__.__name__ + ' ()'
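The __setstate__ hooks on Softmax and LogSoftmax are there for pickles created before this change: such modules have no dim attribute, so it is backfilled with None and the next forward() takes the implicit-dim deprecation path instead of raising AttributeError. A sketch of the round trip:

    import pickle
    import torch.nn as nn

    m = nn.LogSoftmax(dim=1)
    m2 = pickle.loads(pickle.dumps(m))   # __setstate__ runs on load
    print(m2.dim)                        # 1

    # A pre-refactor pickle lacks 'dim' entirely; __setstate__ sets it to
    # None, so forward() falls back to _get_softmax_dim and warns.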