1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
import torch
from .Module import Module
class Replicate(Module):
    """Repeat the input ``nf`` times along a newly inserted dimension.

    The forward pass does not copy data: the output is a view on the input's
    storage in which the inserted dimension has stride 0, so every index
    along it reads the same elements. The backward pass sums the incoming
    gradient over that dimension.

    Args:
        nf: Size of the inserted dimension (stored as ``nfeatures``).
        dim: Zero-based index at which the dimension is inserted. Must be
            non-negative.
    """

    def __init__(self, nf, dim=0):
        super(Replicate, self).__init__()
        self.nfeatures = nf
        self.dim = dim
        assert self.dim >= 0, "dim must be non-negative"

    def updateOutput(self, input):
        """Point ``self.output`` at ``input``'s storage with an extra dimension.

        Args:
            input: Tensor to replicate; must have more than ``self.dim``
                dimensions.

        Returns:
            ``self.output``, whose size is ``input.size()`` with
            ``self.nfeatures`` inserted at ``self.dim``.

        Raises:
            AssertionError: If ``self.dim`` is not smaller than
                ``input.dim()``.
        """
        # NOTE(review): self.output is not created in this class; presumably
        # the Module base class provides it -- confirm.
        assert self.dim < input.dim(), \
            "not enough input dimensions to replicate along dim"

        out_size = list(input.size())
        out_size.insert(self.dim, self.nfeatures)

        # Stride 0 on the new dimension makes every slice along it alias the
        # same input elements, so nothing is copied.
        out_stride = list(input.stride())
        out_stride.insert(self.dim, 0)

        self.output.set_(input.storage(), input.storage_offset(),
                         torch.Size(out_size), tuple(out_stride))
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Sum ``gradOutput`` over the replicated dimension into ``self.gradInput``.

        Args:
            input: The tensor that was passed to ``updateOutput``; only its
                size is used here.
            gradOutput: Gradient with respect to the output of
                ``updateOutput``.

        Returns:
            ``self.gradInput``, resized to the size of ``input``.
        """
        # NOTE(review): self.gradInput is not created in this class;
        # presumably the Module base class provides it -- confirm.
        self.gradInput.resize_as_(input).zero_()

        # View gradInput with a size-1 dimension at self.dim so the reduction
        # can write straight into its storage.
        view_size = list(input.size())
        view_size.insert(self.dim, 1)
        grad_view = self.gradInput.view(*view_size)

        # keepdim=True keeps the reduced dimension (size 1) so the result has
        # exactly the shape of grad_view. Without it torch.sum drops that
        # dimension and the ``out`` tensor no longer matches the result shape.
        torch.sum(gradOutput, self.dim, keepdim=True, out=grad_view)
        return self.gradInput
|