summaryrefslogtreecommitdiff
path: root/torch/legacy/nn/SoftShrink.py
blob: b663c54f47ff6542d1715157637581e219ed9d80 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from .Module import Module

class SoftShrink(Module):
    """Soft-shrinkage layer for the legacy (Lua-torch style) ``nn`` API.

    This class does no arithmetic itself. Both passes hand the input,
    the destination buffer and the threshold ``lambd`` to the backend
    object ``self._backend`` (defined outside this class, presumably by
    ``Module``).

    NOTE(review): presumably the backend applies the usual soft-thresholding
    rule (``x - lambd`` above ``lambd``, ``x + lambd`` below ``-lambd``,
    ``0`` otherwise). Confirm against the backend's ``SoftShrink`` kernel.

    Args:
        lambd: shrinkage threshold forwarded to the backend (default ``0.5``).
    """

    def __init__(self, lambd=0.5):
        super(SoftShrink, self).__init__()
        # Threshold; passed unchanged to both backend calls below.
        self.lambd = lambd

    def updateOutput(self, input):
        """Compute the forward pass into ``self.output``.

        Args:
            input: tensor handed straight to the backend.

        Returns:
            ``self.output``, the same buffer object the backend was given.
        """
        backend = self._backend
        state = backend.library_state
        backend.SoftShrink_updateOutput(state, input, self.output, self.lambd)
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute the gradient w.r.t. ``input`` into ``self.gradInput``.

        Args:
            input: the tensor originally given to the forward pass.
            gradOutput: gradient flowing back from the next layer.

        Returns:
            ``self.gradInput``, the same buffer object the backend was given.
        """
        backend = self._backend
        state = backend.library_state
        backend.SoftShrink_updateGradInput(
            state, input, gradOutput, self.gradInput, self.lambd
        )
        return self.gradInput