summaryrefslogtreecommitdiff
path: root/torch/legacy/nn/SpatialSoftMax.py
diff options
context:
space:
mode:
Diffstat (limited to 'torch/legacy/nn/SpatialSoftMax.py')
-rw-r--r--torch/legacy/nn/SpatialSoftMax.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/torch/legacy/nn/SpatialSoftMax.py b/torch/legacy/nn/SpatialSoftMax.py
index 526e6d47dc..5c9c0a45d1 100644
--- a/torch/legacy/nn/SpatialSoftMax.py
+++ b/torch/legacy/nn/SpatialSoftMax.py
@@ -8,7 +8,8 @@ class SpatialSoftMax(Module):
self._backend.SoftMax_updateOutput(
self._backend.library_state,
input,
- self.output
+ self.output,
+ 0 if input.dim() == 1 or input.dim() == 3 else 1
)
return self.output
@@ -18,6 +19,7 @@ class SpatialSoftMax(Module):
input,
gradOutput,
self.gradInput,
- self.output
+ self.output,
+ 0 if input.dim() == 1 or input.dim() == 3 else 1
)
return self.gradInput