summaryrefslogtreecommitdiff
path: root/res/PyTorchExamples/examples/sub/__init__.py
blob: 2dc4a5ee0d8aebe030409e9c4dbbbeb2e1a51578 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn


# model
class net_sub(nn.Module):
    """Element-wise subtraction model.

    ``forward`` receives one indexable ``inputs`` argument (for example a list
    of two tensors) and returns ``torch.sub(inputs[0], inputs[1])``. Elements
    after the second one are never read.
    """

    # No custom __init__: the inherited nn.Module.__init__ runs on
    # construction, which is all the previous explicit super().__init__()
    # override did.

    def forward(self, inputs):
        """Return ``inputs[0]`` minus ``inputs[1]`` via ``torch.sub``."""
        minuend = inputs[0]
        subtrahend = inputs[1]
        return torch.sub(minuend, subtrahend)


# Module-level instance of the subtraction model defined above.
_model_ = net_sub()

# Dummy input for ONNX generation: a list holding two random tensors of shape
# (1, 2, 3, 3), presumably handed to net_sub.forward as its single `inputs`
# argument — confirm against the exporting script.
_dummy_ = [torch.randn(1, 2, 3, 3) for _ in range(2)]