blob: b663c54f47ff6542d1715157637581e219ed9d80 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
import torch
from .Module import Module
class SoftShrink(Module):
    """Module wrapper around the backend's ``SoftShrink`` kernels.

    The arithmetic itself is not implemented here: both passes are
    delegated to ``self._backend``, which receives the input tensor, the
    module's result buffer and the ``lambd`` parameter.

    NOTE(review): ``self._backend``, ``self.output`` and ``self.gradInput``
    are not created in this class; presumably they are provided by the
    ``Module`` base class -- confirm against ``Module``.
    """

    def __init__(self, lambd=0.5):
        """Create the module.

        Args:
            lambd: value passed unchanged to the backend kernels on every
                forward and backward call (defaults to ``0.5``).
                Presumably the shrinkage threshold -- confirm against the
                backend implementation.
        """
        super(SoftShrink, self).__init__()
        self.lambd = lambd

    def updateOutput(self, input):
        """Run the forward kernel and return ``self.output``.

        Args:
            input: value handed to the backend as the kernel input.

        Returns:
            ``self.output``, which is passed to the backend call
            (presumably filled in place by it).
        """
        backend = self._backend
        backend.SoftShrink_updateOutput(
            backend.library_state, input, self.output, self.lambd
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backward kernel and return ``self.gradInput``.

        Args:
            input: the value that was given to the forward pass.
            gradOutput: gradient with respect to the module's output.

        Returns:
            ``self.gradInput``, which is passed to the backend call
            (presumably filled in place by it).
        """
        backend = self._backend
        backend.SoftShrink_updateGradInput(
            backend.library_state, input, gradOutput, self.gradInput, self.lambd
        )
        return self.gradInput
|