Diffstat (limited to 'torch/legacy/nn/SpatialCrossMapLRN.py')
-rw-r--r--  torch/legacy/nn/SpatialCrossMapLRN.py  30
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/torch/legacy/nn/SpatialCrossMapLRN.py b/torch/legacy/nn/SpatialCrossMapLRN.py
index 7fa34c92d2..4b7402a46d 100644
--- a/torch/legacy/nn/SpatialCrossMapLRN.py
+++ b/torch/legacy/nn/SpatialCrossMapLRN.py
@@ -2,6 +2,7 @@ import torch
from .Module import Module
from .utils import clear
+
class SpatialCrossMapLRN(Module):
def __init__(self, size, alpha=1e-4, beta=0.75, k=1):
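
For reference, the module computes cross-map local response normalization: each activation is divided by (k + alpha/size * sum of squared activations over a window of `size` neighbouring channels) ** beta. A minimal dense sketch of that formula in plain PyTorch (illustrative names, assuming an odd `size`; this is not the module's own code path):

    import torch

    def lrn_reference(x, size, alpha=1e-4, beta=0.75, k=1.0):
        # x: (N, C, H, W); sum x**2 over a window of `size` channels centred on ch
        channels = x.size(1)
        sq = x.pow(2)
        scale = torch.full_like(x, k)
        half = size // 2
        for ch in range(channels):
            lo, hi = max(0, ch - half), min(channels, ch + half + 1)
            scale[:, ch] += sq[:, lo:hi].sum(dim=1) * (alpha / size)
        return x * scale.pow(-beta)
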
@@ -19,7 +20,7 @@ class SpatialCrossMapLRN(Module):
assert input.dim() == 4
if self.scale is None:
- self.scale = input.new()
+ self.scale = input.new()
if input.type() == 'torch.cuda.FloatTensor':
self._backend.SpatialCrossMapLRN_updateOutput(
self._backend.library_state,
@@ -32,10 +33,10 @@ class SpatialCrossMapLRN(Module):
self.k
)
else:
- batchSize = input.size(0)
- channels = input.size(1)
+ batchSize = input.size(0)
+ channels = input.size(1)
inputHeight = input.size(2)
- inputWidth = input.size(3)
+ inputWidth = input.size(3)
self.output.resize_as_(input)
self.scale.resize_as_(input)
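
Both `self.output` and `self.scale` are resized to match the input on every forward call, and the `input.new()` call in the hunk above allocates the scale buffer with the same tensor type as the input, so the same code serves CPU and CUDA inputs. A small sketch of that lazy-buffer pattern (hypothetical helper, not part of this file):

    import torch

    class LazyBuffer(object):
        # hypothetical helper illustrating the buffer handling above
        def __init__(self):
            self.scale = None

        def ensure(self, input):
            if self.scale is None:
                # empty tensor of the same type (CPU or CUDA) as `input`
                self.scale = input.new()
            # resize on every call so varying batch sizes keep working
            self.scale.resize_as_(input)
            return self.scale
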
@@ -44,7 +45,7 @@ class SpatialCrossMapLRN(Module):
inputSquare = self.output
torch.pow(input, 2, out=inputSquare)
- prePad = int((self.size - 1)/2 + 1)
+ prePad = int((self.size - 1) / 2 + 1)
prePadCrop = channels if prePad > channels else prePad
scaleFirst = self.scale.select(1, 0)
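
prePad is the number of leading feature maps that fall inside channel 0's normalization window: for size = 5, prePad = int((5 - 1) / 2 + 1) = 3, so the first scale map sums the squares of channels 0..2, clipped when there are fewer channels than prePad. A small arithmetic check of that bookkeeping (assuming positive integer sizes):

    def window_head(size, channels):
        pre_pad = (size - 1) // 2 + 1
        pre_pad_crop = min(pre_pad, channels)
        return pre_pad, pre_pad_crop

    assert window_head(5, 64) == (3, 3)
    assert window_head(5, 2) == (3, 2)   # clipped when there are fewer channels
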
@@ -57,10 +58,10 @@ class SpatialCrossMapLRN(Module):
# by adding the next feature map and removing the previous
for c in range(1, channels):
scalePrevious = self.scale.select(1, c - 1)
- scaleCurrent = self.scale.select(1, c)
+ scaleCurrent = self.scale.select(1, c)
scaleCurrent.copy_(scalePrevious)
if c < channels - prePad + 1:
- squareNext = inputSquare.select(1, c + prePad - 1)
+ squareNext = inputSquare.select(1, c + prePad - 1)
scaleCurrent.add_(1, squareNext)
if c > prePad:
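
The loop keeps a running sum rather than re-summing `size` square maps for every channel: each scale map is copied from the previous one, then the square map entering the window is added and the one leaving it is subtracted. The same idea on a simplified sketch (the off-by-one bookkeeping at the window edges is simplified here relative to the legacy port):

    import torch

    def running_window_sum(sq, pre_pad):
        # sq: (C, ...) squared feature maps; returns per-channel window sums
        channels = sq.size(0)
        out = torch.empty_like(sq)
        out[0] = sq[:min(pre_pad, channels)].sum(dim=0)
        for c in range(1, channels):
            out[c] = out[c - 1]
            if c + pre_pad - 1 < channels:       # map entering the window
                out[c] += sq[c + pre_pad - 1]
            if c - pre_pad >= 0:                 # map leaving the window
                out[c] -= sq[c - pre_pad]
        return out
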
@@ -91,15 +92,15 @@ class SpatialCrossMapLRN(Module):
self.k
)
else:
- batchSize = input.size(0)
- channels = input.size(1)
+ batchSize = input.size(0)
+ channels = input.size(1)
inputHeight = input.size(2)
- inputWidth = input.size(3)
+ inputWidth = input.size(3)
if self.paddedRatio is None:
- self.paddedRatio = input.new()
+ self.paddedRatio = input.new()
if self.accumRatio is None:
- self.accumRatio = input.new()
+ self.accumRatio = input.new()
self.paddedRatio.resize_(channels + self.size - 1, inputHeight, inputWidth)
self.accumRatio.resize_(inputHeight, inputWidth)
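
paddedRatio holds, per sample, gradOutput * output / scale for every channel, padded with (size - 1) extra zero maps so the moving channel window in the inner loop never indexes past either end; accumRatio is the reusable per-pixel window sum. A shape check under those assumptions (sizes are illustrative):

    import torch

    channels, size, h, w = 8, 5, 4, 4
    padded_ratio = torch.zeros(channels + size - 1, h, w)
    accum_ratio = torch.empty(h, w)
    pre_pad = (size - 1) // 2 + 1
    # the live ratios occupy the centre; the (size - 1) border maps stay zero
    centre = padded_ratio.narrow(0, pre_pad - 1, channels)
    assert centre.size(0) == channels
    # seed the running sum from the first (size - 1) maps, as in the loop below
    torch.sum(padded_ratio.narrow(0, 0, size - 1), 0, out=accum_ratio)
    assert accum_ratio.shape == (h, w)
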
@@ -114,9 +115,9 @@ class SpatialCrossMapLRN(Module):
for n in range(batchSize):
torch.mul(gradOutput[n], self.output[n], out=paddedRatioCenter)
paddedRatioCenter.div_(self.scale[n])
- torch.sum(self.paddedRatio.narrow(0, 0,self.size-1), 0, out=self.accumRatio)
+ torch.sum(self.paddedRatio.narrow(0, 0, self.size - 1), 0, out=self.accumRatio)
for c in range(channels):
- self.accumRatio.add_(self.paddedRatio[c+self.size-1])
+ self.accumRatio.add_(self.paddedRatio[c + self.size - 1])
self.gradInput[n][c].addcmul_(-cacheRatioValue, input[n][c], self.accumRatio)
self.accumRatio.add_(-1, self.paddedRatio[c])
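
The loop implements the LRN gradient: gradInput = gradOutput * scale**(-beta) - (2 * alpha * beta / size) * input * (window sum of gradOutput * output / scale), with cacheRatioValue holding the 2 * alpha * beta / size factor. A dense, unoptimised reference under the same assumptions as the forward sketch (odd `size`, illustrative names):

    import torch

    def lrn_backward_reference(x, grad_out, out, scale, size, alpha, beta):
        # x, grad_out, out, scale: (N, C, H, W) from the forward pass
        cache_ratio = 2 * alpha * beta / size
        ratio = grad_out * out / scale
        half = size // 2
        grad_in = grad_out * scale.pow(-beta)
        channels = x.size(1)
        for ch in range(channels):
            lo, hi = max(0, ch - half), min(channels, ch + half + 1)
            grad_in[:, ch] -= cache_ratio * x[:, ch] * ratio[:, lo:hi].sum(dim=1)
        return grad_in
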
@@ -125,4 +126,3 @@ class SpatialCrossMapLRN(Module):
def clearState(self):
clear(self, 'scale', 'paddedRatio', 'accumRatio')
return super(SpatialCrossMapLRN, self).clearState()
-
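
For completeness, a quick usage sketch of the legacy module as patched here, assuming an older PyTorch release that still ships torch.legacy (shapes are illustrative; clearState simply drops the cached scale, paddedRatio and accumRatio buffers so they are reallocated on the next call):

    import torch
    from torch.legacy import nn as legacy_nn

    lrn = legacy_nn.SpatialCrossMapLRN(size=5, alpha=1e-4, beta=0.75, k=1)
    x = torch.randn(2, 8, 16, 16)
    y = lrn.updateOutput(x)                        # forward; fills lrn.scale
    g = lrn.updateGradInput(x, y.clone().fill_(1))
    lrn.clearState()                               # frees scale / paddedRatio / accumRatio
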