diff options
Diffstat (limited to 'torch/legacy/nn/SpatialSoftMax.py')
-rw-r--r-- | torch/legacy/nn/SpatialSoftMax.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/torch/legacy/nn/SpatialSoftMax.py b/torch/legacy/nn/SpatialSoftMax.py index 526e6d47dc..5c9c0a45d1 100644 --- a/torch/legacy/nn/SpatialSoftMax.py +++ b/torch/legacy/nn/SpatialSoftMax.py @@ -8,7 +8,8 @@ class SpatialSoftMax(Module): self._backend.SoftMax_updateOutput( self._backend.library_state, input, - self.output + self.output, + 0 if input.dim() == 1 or input.dim() == 3 else 1 ) return self.output @@ -18,6 +19,7 @@ class SpatialSoftMax(Module): input, gradOutput, self.gradInput, - self.output + self.output, + 0 if input.dim() == 1 or input.dim() == 3 else 1 ) return self.gradInput |