summaryrefslogtreecommitdiff
path: root/torch/legacy/nn/SplitTable.py
blob: 6f2f12ed46789410467310095e4a137450bed891 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from .Module import Module

class SplitTable(Module):
    """Split a tensor into a Python list of slices along one dimension.

    Forward: the output is a list whose i-th element is
    ``input.select(dimension, i)``, for ``i`` in ``range(input.size(dimension))``.

    Backward: ``gradInput`` is resized to ``input``'s shape, and each slice
    along ``dimension`` is filled from the matching ``gradOutput[i]``.

    Args:
        dimension: index of the dimension to split along. Negative values
            count from the end, so ``-1`` is the last dimension of the input.
    """

    def __init__(self, dimension):
        super(SplitTable, self).__init__()
        self.dimension = dimension

    def _getPositiveDimension(self, input):
        """Return ``self.dimension`` as a non-negative dimension index of ``input``."""
        dimension = self.dimension
        if dimension < 0:
            # Python-style negative indexing: -1 maps to input.dim() - 1.
            dimension = input.dim() + dimension
        return dimension

    def updateOutput(self, input):
        """Split ``input`` along ``self.dimension`` and return the list of slices.

        The list is also stored on ``self.output``.
        """
        dimension = self._getPositiveDimension(input)
        # select() removes the split dimension from each slice.
        self.output = [input.select(dimension, i)
                       for i in range(input.size(dimension))]
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Scatter ``gradOutput`` back into a gradient shaped like ``input``.

        Args:
            input: the tensor that was passed to ``updateOutput``.
            gradOutput: indexable sequence with one gradient per slice; entry
                ``i`` is copied into slice ``i`` of ``gradInput`` along
                ``self.dimension``. A ``None`` entry is treated as a zero
                gradient for that slice.

        Returns:
            ``self.gradInput``. If it is ``None`` it is returned as-is and
            nothing is computed.
        """
        if self.gradInput is None:
            # No buffer to fill; keep returning the (None) gradInput.
            return self.gradInput

        dimension = self._getPositiveDimension(input)
        slices = input.size(dimension)
        self.gradInput.resize_as_(input)

        for i in range(slices):
            grad_slice = self.gradInput.select(dimension, i)
            current = gradOutput[i]
            if current is None:
                # A slice that received no gradient contributes zeros instead
                # of crashing copy_().
                grad_slice.zero_()
            else:
                grad_slice.copy_(current)

        return self.gradInput