diff options
Diffstat (limited to 'torch/legacy/nn/SpatialCrossMapLRN.py')
-rw-r--r-- | torch/legacy/nn/SpatialCrossMapLRN.py | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/torch/legacy/nn/SpatialCrossMapLRN.py b/torch/legacy/nn/SpatialCrossMapLRN.py index 7fa34c92d2..4b7402a46d 100644 --- a/torch/legacy/nn/SpatialCrossMapLRN.py +++ b/torch/legacy/nn/SpatialCrossMapLRN.py @@ -2,6 +2,7 @@ import torch from .Module import Module from .utils import clear + class SpatialCrossMapLRN(Module): def __init__(self, size, alpha=1e-4, beta=0.75, k=1): @@ -19,7 +20,7 @@ class SpatialCrossMapLRN(Module): assert input.dim() == 4 if self.scale is None: - self.scale = input.new() + self.scale = input.new() if input.type() == 'torch.cuda.FloatTensor': self._backend.SpatialCrossMapLRN_updateOutput( self._backend.library_state, @@ -32,10 +33,10 @@ class SpatialCrossMapLRN(Module): self.k ) else: - batchSize = input.size(0) - channels = input.size(1) + batchSize = input.size(0) + channels = input.size(1) inputHeight = input.size(2) - inputWidth = input.size(3) + inputWidth = input.size(3) self.output.resize_as_(input) self.scale.resize_as_(input) @@ -44,7 +45,7 @@ class SpatialCrossMapLRN(Module): inputSquare = self.output torch.pow(input, 2, out=inputSquare) - prePad = int((self.size - 1)/2 + 1) + prePad = int((self.size - 1) / 2 + 1) prePadCrop = channels if prePad > channels else prePad scaleFirst = self.scale.select(1, 0) @@ -57,10 +58,10 @@ class SpatialCrossMapLRN(Module): # by adding the next feature map and removing the previous for c in range(1, channels): scalePrevious = self.scale.select(1, c - 1) - scaleCurrent = self.scale.select(1, c) + scaleCurrent = self.scale.select(1, c) scaleCurrent.copy_(scalePrevious) if c < channels - prePad + 1: - squareNext = inputSquare.select(1, c + prePad - 1) + squareNext = inputSquare.select(1, c + prePad - 1) scaleCurrent.add_(1, squareNext) if c > prePad: @@ -91,15 +92,15 @@ class SpatialCrossMapLRN(Module): self.k ) else: - batchSize = input.size(0) - channels = input.size(1) + batchSize = input.size(0) + channels = input.size(1) inputHeight = input.size(2) - inputWidth = input.size(3) + inputWidth = input.size(3) if self.paddedRatio is None: - self.paddedRatio = input.new() + self.paddedRatio = input.new() if self.accumRatio is None: - self.accumRatio = input.new() + self.accumRatio = input.new() self.paddedRatio.resize_(channels + self.size - 1, inputHeight, inputWidth) self.accumRatio.resize_(inputHeight, inputWidth) @@ -114,9 +115,9 @@ class SpatialCrossMapLRN(Module): for n in range(batchSize): torch.mul(gradOutput[n], self.output[n], out=paddedRatioCenter) paddedRatioCenter.div_(self.scale[n]) - torch.sum(self.paddedRatio.narrow(0, 0,self.size-1), 0, out=self.accumRatio) + torch.sum(self.paddedRatio.narrow(0, 0, self.size - 1), 0, out=self.accumRatio) for c in range(channels): - self.accumRatio.add_(self.paddedRatio[c+self.size-1]) + self.accumRatio.add_(self.paddedRatio[c + self.size - 1]) self.gradInput[n][c].addcmul_(-cacheRatioValue, input[n][c], self.accumRatio) self.accumRatio.add_(-1, self.paddedRatio[c]) @@ -125,4 +126,3 @@ class SpatialCrossMapLRN(Module): def clearState(self): clear(self, 'scale', 'paddedRatio', 'accumRatio') return super(SpatialCrossMapLRN, self).clearState() - |