summaryrefslogtreecommitdiff
path: root/torch/legacy/nn/Replicate.py
blob: 10f4d80884b976410627409a3841499ace5911b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from .Module import Module


class Replicate(Module):
    """Replicate the input ``nf`` times along a newly inserted dimension.

    The forward pass copies no data: the output is re-pointed at the input's
    storage and the inserted dimension is given stride 0, so every replica
    reads the same elements.

    Args:
        nf: number of replicas, i.e. the size of the inserted dimension.
        dim: non-negative index at which the new dimension is inserted
            (default 0). It may be as large as ``input.dim()``, in which
            case the new dimension is appended after the last existing one.
    """

    def __init__(self, nf, dim=0):
        super(Replicate, self).__init__()
        self.nfeatures = nf
        self.dim = dim
        assert self.dim >= 0, \
            "replication dimension must be non-negative, got {}".format(self.dim)

    def updateOutput(self, input):
        """Return a view of ``input`` with a size-``nf`` axis inserted at ``dim``.

        The result aliases ``input``'s storage (see ``set_`` below), so it is
        only valid while that storage is, and the replicas are not
        independent copies.

        Raises:
            AssertionError: if ``input`` has no dimensions or ``dim`` is
                greater than ``input.dim()``.
        """
        ndim = input.dim()
        # The new axis can be placed anywhere from the front (0) up to just
        # past the last existing axis (ndim). Inputs without any dimension
        # are rejected, exactly as they were when the bound was strict.
        assert ndim > 0 and self.dim <= ndim, \
            "cannot replicate along dimension {} of a {}-dimensional input".format(
                self.dim, ndim)

        size = list(input.size())
        size.insert(self.dim, self.nfeatures)

        # Stride 0 along the new axis makes all nf slices address the same
        # elements of the underlying storage.
        stride = list(input.stride())
        stride.insert(self.dim, 0)

        # self.output is presumably allocated by Module.__init__ -- it is not
        # created anywhere in this class.
        self.output.set_(input.storage(), input.storage_offset(),
                         torch.Size(size), tuple(stride))
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Sum ``gradOutput`` over the replicated axis into ``self.gradInput``.

        ``self.gradInput`` is resized to match ``input`` and returned.
        """
        self.gradInput.resize_as_(input).zero_()
        size = list(input.size())
        size.insert(self.dim, 1)

        # View gradInput with a singleton axis where the replicas live so the
        # reduction can write straight into it.
        # NOTE(review): this shape assumes torch.sum keeps the reduced
        # dimension as size 1 when writing to `out` -- confirm against the
        # torch version this file targets.
        grad_view = self.gradInput.view(*size)
        torch.sum(gradOutput, self.dim, out=grad_view)
        return self.gradInput