path: root/torch/legacy/nn/VolumetricDropout.py
blob: dda60c82f78b6a6e283fc2cdb40050b835c518b0
import torch
from .Module import Module
from .utils import clear

class VolumetricDropout(Module):
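    """Feature-map dropout for volumetric (5D) input of shape
    ``(batch, channels, depth, height, width)``.

    During training, each ``(batch, channel)`` feature map is zeroed in its
    entirety with probability ``p``; at evaluation time the output is scaled
    by ``1 - p`` instead.

    Example sketch (assumes the legacy ``nn`` package is importable as
    ``torch.legacy.nn``, which is no longer the case in recent PyTorch
    releases)::

        import torch
        from torch.legacy import nn

        drop = nn.VolumetricDropout(p=0.3)
        x = torch.randn(2, 4, 8, 8, 8)

        drop.train = True
        out = drop.updateOutput(x)                              # whole feature maps zeroed
        grad = drop.updateGradInput(x, torch.ones(out.size()))  # masked by the same noise

        drop.train = False
        out = drop.updateOutput(x)                              # scaled by (1 - p)
    """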

    def __init__(self, p=0.5):
        super(VolumetricDropout, self).__init__()
        self.p = p
        self.train = True
        self.noise = torch.Tensor()

    def updateOutput(self, input):
        self.output.resize_as_(input).copy_(input)
        if self.train:
            assert input.dim() == 5, 'VolumetricDropout expects a 5D input: (batch, channels, depth, height, width)'
            self.noise.resize_(input.size(0), input.size(1), 1, 1, 1)

            self.noise.bernoulli_(1-self.p)
            # We expand the random dropouts to the entire feature map because the
            # features are likely correlated across the map and so the dropout
            # should also be correlated.
            self.output.mul_(self.noise.expand_as(input))
        else:
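            # At evaluation time, scale the output by (1 - p) so its expected
            # value matches the training-time activations (non-inverted dropout).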
            self.output.mul_(1-self.p)

        return self.output

    def updateGradInput(self, input, gradOutput):
        if self.train:
            self.gradInput.resize_as_(gradOutput).copy_(gradOutput)
            self.gradInput.mul_(self.noise.expand_as(input))  # mask the gradients with the same noise used in the forward pass
        else:
            raise RuntimeError('backprop only defined while training')

        return self.gradInput

    def setp(self, p):
        self.p = p

    def __repr__(self):
        return super(VolumetricDropout, self).__repr__() + '({:.4f})'.format(self.p)

    def clearState(self):
        clear(self, 'noise')
        return super(VolumetricDropout, self).clearState()