from .nested_dict import nested_dict
from tools.shared.module_loader import import_module

# CodeTemplate performs ${...} substitution into the C++ source templates
# below; it is loaded from the ATen code generator.
CodeTemplate = import_module('code_template', 'torch/lib/ATen/code_template.py').CodeTemplate

# Binds a method that takes positional and keyword arguments; overload
# selection goes through PythonArgParser.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${prototypes}
  });
  ${unpack_self}
  PyObject* parsed_args[${max_args}];
  auto r = parser.parse(args, kwargs, parsed_args);
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

# Binds a method that takes no arguments beyond self.
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${unpack_self}
  return wrap(${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

# One branch of the if / else-if chain that handles the overload matched by
# the argument parser (r.idx).
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  return wrap(${dispatch_name}(${actuals}));
""")

# Inline helper that releases the GIL, selects the GPU, and forwards to the
# underlying C++ call.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${return_type} ${dispatch_name}(${formal_args}) {
  ${AutoNoGIL}
  ${AutoGPU}
  return ${dispatch_call}(${dispatch_args});
}
""")

# Entry in the PyMethodDef table.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"
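
# Illustrative only: for a hypothetical no-argument method named "sigmoid",
# PY_VARIABLE_METHOD_NOARGS would expand to C++ along these lines (the name is
# an assumption for this sketch, not actual generated output):
#
#   static PyObject * THPVariable_sigmoid(PyObject* self, PyObject* args)
#   {
#     HANDLE_TH_ERRORS
#     auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
#     return wrap(dispatch_sigmoid(self_));
#     END_HANDLE_TH_ERRORS
#   }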


def create_python_bindings(
        python_functions, py_methods, py_method_defs, py_method_dispatch,
        is_class):
    """Generates Python bindings to Variable methods (python_variable_methods.cpp).

    Appends the generated method implementations, PyMethodDef entries, and
    dispatch helpers to py_methods, py_method_defs, and py_method_dispatch.
    """
    # Maps a formal C++ type to the PythonArgParser accessor that unpacks it
    # from the parsed Python arguments.
    unpack_methods = {
        'const Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'int64_t': 'toInt64',
        'int': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
    }

    def first_tensor_arg(arguments):
        # Name of the first Tensor (or TensorList) argument, if any; it is
        # used to choose the device for AutoGPU.
        for arg in arguments:
            if arg['simple_type'] in {'Tensor', 'TensorList'}:
                return arg['name']
        return None

    def auto_gpu(option):
        tensor_arg = first_tensor_arg(option['arguments'])
        if tensor_arg is None:
            return ''
        return 'AutoGPU auto_gpu({});'.format(tensor_arg)

    def emit_dispatch(i, function):
        env = {}
        actuals = []
        formal_args = []
        arg_idx = 0
        for arg in function['arguments']:
            if 'Tensor' in function['method_of'] and arg['name'] == 'self':
                # 'self' comes from the unpacked THPVariable, not from the
                # argument parser, so it does not consume a parser index.
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('self_')
                continue
            typename = arg['type']
            if typename.startswith('IntList['):
                typename = 'IntList'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'
            unpack = unpack_methods.get(typename, typename.lower())
            actuals.append('r.{}({})'.format(unpack, arg_idx))
            dispatch_type = 'const Tensor &' if typename == 'Tensor' else typename
            formal_args.append('{} {}'.format(dispatch_type, arg['name']))
            arg_idx += 1
        env['i'] = i
        env['actuals'] = actuals
        env['formal_args'] = formal_args
        if 'call_args' in function:
            env['dispatch_args'] = function['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in function['arguments']]
        if 'Tensor' in function['method_of']:
            # Tensor methods drop 'self' from the call arguments and dispatch
            # through the self object.
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(function['name'])
        else:
            env['dispatch_call'] = 'at::{}'.format(function['name'])
        env['AutoNoGIL'] = 'AutoNoGIL no_gil;'
        env['AutoGPU'] = auto_gpu(function)
        env['cond'] = 'if' if i == 0 else '} else if'
        env = nested_dict(env, function)
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return PY_VARIABLE_CASE.substitute(env)
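
    # Illustrative only: for a hypothetical Tensor method "add(Tensor self,
    # Tensor other)", emit_dispatch would append a helper roughly like this
    # (the name and signature are assumptions for the sketch):
    #
    #   inline Tensor dispatch_add(Tensor & self, const Tensor & other) {
    #     AutoNoGIL no_gil;
    #     AutoGPU auto_gpu(self);
    #     return self.add(other);
    #   }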

    def process_function(name, functions):
        # 'functions' holds every overload registered under 'name'; all of
        # them share a single Python binding.
        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'prototypes': [],
            'max_args': max(len(o['arguments']) for o in functions),
            'unpack_self': [],
            'dispatch': [],
        }

        is_method = 'Tensor' in functions[0]['method_of']
        if is_method:
            env['unpack_self'] = [UNPACK_SELF]

        for o in functions:
            prototype = o['prototype']
            if is_method:
                # The Python-level signature does not include 'self'.
                prototype = prototype.replace('Tensor self, ', '')
                prototype = prototype.replace('Tensor self', '')
            if 'deprecated' in o:
                prototype += '|deprecated'
            env['prototypes'].append('"{}",'.format(prototype))

        for i, option in enumerate(functions):
            env['dispatch'].append(emit_dispatch(i, nested_dict(env, option)))
        env['dispatch'].append('}')

        if len(functions) == 1 and len(functions[0]['args']) == 1 and is_method:
            # A single overload whose only argument is 'self' can skip the
            # argument parser and bind as METH_NOARGS.
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self_']
            env['flags'] = 'METH_NOARGS'
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'
        if is_class and not is_method:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])
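
# Minimal usage sketch (an assumption about the caller, not part of this
# module): the generator driver groups declarations by name and collects the
# generated pieces into lists that it then splices into a .cpp template:
#
#   py_methods, py_method_defs, py_method_dispatch = [], [], []
#   create_python_bindings(
#       python_functions,   # dict: name -> list of declaration dicts
#       py_methods, py_method_defs, py_method_dispatch,
#       is_class=True)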