summaryrefslogtreecommitdiff
path: root/tools/autograd/gen_python_functions.py
blob: 46821d615227651264e4480d72849a5cdc6b145a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from .nested_dict import nested_dict
from tools.shared.module_loader import import_module

# Load ATen's CodeTemplate class by importing its source file directly (via
# import_module) rather than through a package import.
# NOTE(review): the path looks relative to the repository root — presumably the
# generator is run from there; confirm.
CodeTemplate = import_module('code_template', 'torch/lib/ATen/code_template.py').CodeTemplate


# C++ wrapper for a binding that accepts positional and keyword arguments
# (registered with METH_VARARGS | METH_KEYWORDS by create_python_bindings).
# ${prototypes} seeds a function-local static PythonArgParser, ${unpack_self}
# is UNPACK_SELF for Tensor methods (empty otherwise), and ${dispatch} is the
# if / else-if chain built from PY_VARIABLE_CASE, whose branches all return.
# Py_RETURN_NONE is therefore only reached when no branch matched.
# HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS presumably convert C++ exceptions
# into Python exceptions — confirm against their definition.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${prototypes}
  });
  ${unpack_self}
  PyObject* parsed_args[${max_args}];
  auto r = parser.parse(args, kwargs, parsed_args);
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

# C++ wrapper for a Tensor method with a single overload whose only parameter
# is self (registered with METH_NOARGS): no argument parsing, it unpacks self
# and calls the dispatch helper directly with ${actuals}.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${unpack_self}
  return wrap(${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

# One branch of the overload chain, taken when the parser result's r.idx
# equals this overload's position ${i} (presumably the index of the matching
# prototype). ${cond} is "if" for the first branch and "} else if" for the
# rest; the final closing "}" is appended separately after the last branch.
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  return wrap(${dispatch_name}(${actuals}));
""")

# Inline C++ helper performing the real call for one overload: either
# self.<name>(...) for Tensor methods or at::<name>(...) otherwise.
# ${AutoNoGIL} is always "AutoNoGIL no_gil;" and ${AutoGPU} is an AutoGPU
# declaration on the first tensor argument, or empty if there is none.
# Presumably these release the GIL and pin the device for the duration of
# the call — confirm against their definitions.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${return_type} ${dispatch_name}(${formal_args}) {
  ${AutoNoGIL}
  ${AutoGPU}
  return ${dispatch_call}(${dispatch_args});
}
""")

# One initializer row of a PyMethodDef table:
# {method name, function pointer, flags, docstring (NULL)}.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

# C++ statement binding self_ to the cdata of the THPVariable behind the
# Python `self` object; emitted only for Tensor methods.
UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"


def create_python_bindings(
        python_functions, py_methods, py_method_defs, py_method_dispatch,
        is_class):
    """python_variable_methods.cpp

    Generates Python bindings to Variable methods.

    Names are processed in sorted order, and generated C++ source text is
    appended to the caller-supplied output lists.

    Args:
        python_functions: dict mapping a Python-visible name to the list of
            declaration dicts (overloads) bound under that name. Each dict is
            read for 'name', 'arguments', 'method_of', 'prototype', 'args',
            and optionally 'deprecated' and 'call_args'.
        py_methods: list; receives one ``THPVariable_<name>`` wrapper
            function per name.
        py_method_defs: list; receives one method-table entry per name.
        py_method_dispatch: list; receives one inline ``dispatch_<name>``
            helper per overload.
        is_class: if true, bindings that are not Tensor methods are flagged
            ``METH_STATIC``.

    Returns:
        None. Output is delivered only by appending to the three lists.
    """
    import re

    # Maps an argument type to the parser-result accessor used to unpack it
    # (``r.<accessor>(idx)``); unlisted types fall back to the lower-cased
    # type name.
    unpack_methods = {
        'const Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'int64_t': 'toInt64',
        'int': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
    }

    # Matches a ``Tensor self`` parameter at any position in a prototype
    # together with exactly one adjoining ", " separator, so removing it never
    # leaves a dangling comma: "f(Tensor self, X)" -> "f(X)",
    # "f(X, Tensor self, Y)" -> "f(X, Y)", "f(X, Tensor self)" -> "f(X)".
    self_param_re = re.compile(r'Tensor self\b(?:, )?|, Tensor self\b')

    def first_tensor_arg(arguments):
        """Return the name of the first Tensor/TensorList argument, or None."""
        for arg in arguments:
            if arg['simple_type'] in {'Tensor', 'TensorList'}:
                return arg['name']
        return None

    def auto_gpu(option):
        """Return an AutoGPU declaration on the first tensor argument of
        ``option``, or '' when it has no tensor arguments."""
        tensor_arg = first_tensor_arg(option['arguments'])
        if tensor_arg is None:
            return ''
        return 'AutoGPU auto_gpu({});'.format(tensor_arg)

    def emit_dispatch(i, function):
        """Emit overload number ``i`` of a binding.

        Appends the overload's inline dispatch helper to
        ``py_method_dispatch`` and returns the ``if`` / ``} else if`` branch
        of the wrapper that calls it.
        """
        env = {}

        actuals = []      # C++ expressions passed to the dispatch helper
        formal_args = []  # the dispatch helper's parameter declarations
        arg_idx = 0       # index into the parser output (self is not parsed)
        for arg in function['arguments']:
            if 'Tensor' in function['method_of'] and arg['name'] == 'self':
                # For Tensor methods self comes from the Python receiver
                # (UNPACK_SELF), not from the parsed arguments.
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('self_')
                continue

            typename = arg['type']
            if typename.startswith('IntList['):
                typename = 'IntList'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            unpack = unpack_methods.get(typename, typename.lower())
            actuals.append('r.{}({})'.format(unpack, arg_idx))
            dispatch_type = 'const Tensor &' if typename == 'Tensor' else typename
            formal_args.append('{} {}'.format(dispatch_type, arg['name']))
            arg_idx += 1

        env['i'] = i
        env['actuals'] = actuals
        env['formal_args'] = formal_args
        if 'call_args' in function:
            env['dispatch_args'] = function['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in function['arguments']]
        if 'Tensor' in function['method_of']:
            # Methods are invoked on self, so self is dropped from the call.
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(function['name'])
        else:
            env['dispatch_call'] = 'at::{}'.format(function['name'])
        env['AutoNoGIL'] = 'AutoNoGIL no_gil;'
        env['AutoGPU'] = auto_gpu(function)
        env['cond'] = 'if' if i == 0 else '} else if'
        env = nested_dict(env, function)
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return PY_VARIABLE_CASE.substitute(env)

    def process_function(name, functions):
        """Emit the wrapper function and method-table entry for ``name``."""
        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'prototypes': [],
            # Length of the generated parsed_args[] array. Clamped to 1
            # because a zero-length array is ill-formed in standard C++, which
            # would otherwise be emitted when every overload takes no arguments.
            'max_args': max(1, max(len(o['arguments']) for o in functions)),
            'unpack_self': [],
            'dispatch': [],
        }

        # Method-ness of the whole binding is decided by the first overload.
        is_method = 'Tensor' in functions[0]['method_of']
        if is_method:
            env['unpack_self'] = [UNPACK_SELF]

        for o in functions:
            prototype = o['prototype']
            if is_method:
                # self is bound from the receiver, so hide it from the parser.
                prototype = self_param_re.sub('', prototype)
            if 'deprecated' in o:
                prototype += '|deprecated'
            env['prototypes'].append('"{}",'.format(prototype))

        # Branch i is taken when r.idx == i, so branches are emitted in the
        # same order as the prototypes above (presumably r.idx is the index of
        # the matching prototype).
        for i, option in enumerate(functions):
            env['dispatch'].append(emit_dispatch(i, nested_dict(env, option)))
        env['dispatch'].append('}')  # closes the last if / else-if branch

        # NOTE(review): this reads 'args' while everything else reads
        # 'arguments'; presumably both list the same parameters — confirm.
        if len(functions) == 1 and len(functions[0]['args']) == 1 and is_method:
            # A single overload whose only parameter is self needs no parsing.
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self_']
            env['flags'] = 'METH_NOARGS'
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if is_class and not is_method:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions):
        process_function(name, python_functions[name])