summaryrefslogtreecommitdiff
path: root/tools/autograd/gen_python_functions.py
blob: d837dd79be1eae2c28481df24475d8373e22f975 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
# Generates Python bindings for ATen functions
#
# The bindings are generated as methods on python_variable or functions on the
# torch._C._nn object.
#
from collections import defaultdict
import re
from .nested_dict import nested_dict
from .gen_variable_type import should_trace
from .utils import write

try:
    from src.ATen.code_template import CodeTemplate
except ImportError:
    from tools.shared.module_loader import import_module
    CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate

# These functions require manual Python bindings or are not exposed to Python
# Each entry is a regular expression; it is matched against the full declaration
# name, anchored with ^...$ (see should_generate_python_binding below).
SKIP_PYTHON_BINDINGS = [
    'alias', 'contiguous', 'is_cuda', 'is_sparse', 'size', 'stride',
    '.*_backward', '.*_backward_(out|input|weight|bias)', '.*_forward',
    '.*_forward_out', '_unsafe_view', 'tensor', '_?sparse_coo_tensor.*',
    '_arange.*', '_range.*', '_linspace.*', '_logspace.*',
    '_sparse_add_out', '_sparse_div.*', '_sparse_mul.*', '_sparse_sub.*',
    'index',
    '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
    '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
    '_th_.*', '_thnn_.*',
    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
    '_potrs.*', '_cholesky.*',
    'slice', 'randint(_out)?', 'unique_dim', '_unique', '_unique_dim',
    'item', '_local_scalar_dense',
    'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
    'copy_sparse_to_sparse_',
]

# These function signatures are not exposed to Python. Note that this signature
# list does not support regex.
# The format is 'name(simple_type, simple_type, ...)' with ', ' separators,
# compared for exact string equality in should_generate_python_binding.
SKIP_PYTHON_BINDINGS_SIGNATURES = [
    'add(Tensor, Scalar, Scalar)', 'add_(Tensor, Scalar, Scalar)',
    'sub(Tensor, Scalar, Scalar)', 'sub_(Tensor, Scalar, Scalar)',
    'mul(Tensor, Scalar)', 'mul_(Tensor, Scalar)',
    'div(Tensor, Scalar)', 'div_(Tensor, Scalar)',
]

# Template for a generated binding that takes (args, kwargs), parses them with
# PythonArgParser and dispatches to one of possibly several overloads.
PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
{
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
    ${signatures}
  }, /*traceable=*/${traceable});
  ${unpack_self}
  ParsedArgs<${max_args}> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  ${declare_namedtuple_return_types}
  ${dispatch}
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
""")

# Template for a generated METH_NOARGS binding: single overload, the only
# argument is 'self' (used when a method group has exactly one declaration
# with one argument; see process_function).
PY_VARIABLE_METHOD_NOARGS = CodeTemplate("""\
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
  HANDLE_TH_ERRORS
  ${declare_namedtuple_return_types}
  ${unpack_self}
  return wrap(${namedtuple_return_type}${dispatch_name}(${actuals}));
  END_HANDLE_TH_ERRORS
}
""")

# One branch of the overload selection chain; ${cond} is 'if' for the first
# overload and '} else if' thereafter (the chain is closed by an appended '}').
PY_VARIABLE_CASE = CodeTemplate("""\
${cond} (r.idx == ${i}) {
  ${call_dispatch}
""")

# Chooses between the base overload and the out= overload at runtime,
# depending on whether the out argument was supplied.
PY_VARIABLE_OUT = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  ${call_dispatch_out}
}
""")

# Same as PY_VARIABLE_OUT, but also verifies that an explicitly passed
# dtype/layout/device agrees with the supplied out tensor.
PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
if (r.isNone(${out_idx})) {
  ${call_dispatch}
} else {
  check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
                         r.layout(${layout_idx}), r.isNone(${layout_idx}),
                         r.device(${device_idx}), r.isNone(${device_idx}));
  ${call_dispatch_out}
}
""")

# The bare call expression into the generated dispatch_* helper.
PY_VARIABLE_CALL_DISPATCH = CodeTemplate("""\
${dispatch_name}(${actuals})""")

# Wraps a dispatch call to apply the requires_grad python binding argument.
PY_VARIABLE_SET_REQUIRES_GRAD = CodeTemplate("""\
${call_dispatch}.set_requires_grad(${requires_grad})""")

# Converts the C++ result back into a Python object (optionally a namedtuple).
PY_VARIABLE_WRAP = CodeTemplate("""\
return wrap(${namedtuple_return_type}${call_dispatch});""")

# The inline dispatch_* helper emitted into the *_dispatch.h header; releases
# the GIL (when allowed) and forwards to the ATen call.
PY_VARIABLE_DISPATCH = CodeTemplate("""\
inline ${simple_return_type} ${dispatch_name}(${formal_args}) {
  ${initialize_cuda}
  ${AutoNoGIL}
  return ${dispatch_call}(${dispatch_args});
}
""")

# A PyMethodDef table entry for the generated binding.
PY_VARIABLE_METHOD_DEF = CodeTemplate("""\
{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""")

# Lazily-initialized PyStructSequence (namedtuple) type declaration for
# functions whose returns carry field_name annotations.
PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\
static PyStructSequence_Field fields${namedtuple_type_index}[] = {
  ${namedtuple_fields} {nullptr}
};
static PyStructSequence_Desc desc${namedtuple_type_index} = {
  "torch.return_types.${name}", nullptr,
  fields${namedtuple_type_index}, ${namedtuple_size}
};
static PyTypeObject type${namedtuple_type_index};
static bool namedtuple_type_initialized${namedtuple_type_index} = false;
if (!namedtuple_type_initialized${namedtuple_type_index}) {
  PyStructSequence_InitType(&type${namedtuple_type_index}, &desc${namedtuple_type_index});
  namedtuple_type_initialized${namedtuple_type_index} = true;
}
""")

# Extracts the C++ Variable from the Python-level self for Tensor methods.
UNPACK_SELF = "auto& self = reinterpret_cast<THPVariable*>(self_)->cdata;"

PYTHON_FUNCTION_SIGNATURE = CodeTemplate("""\
${name}(${py_formal_args})""")

# XXX: if you got here because of an assertion failure, it doesn't mean
# it's enough to just extend the list here. Before you do this, make sure
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
    'Tensor', 'std::tuple<Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,double,int64_t>',
    'std::tuple<Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor>',
    'std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>',
    'std::vector<Tensor>',
    'Scalar', 'bool', 'int64_t', 'void*', 'void'
}

# Builds a TensorOptions local from the parsed dtype/device/layout/requires_grad
# arguments (used for factory functions that take TensorOptions).
TENSOR_OPTIONS = CodeTemplate("""\
const auto options = TensorOptions()
    .dtype(${dtype})
    .device(${device})
    .layout(${layout}.layout)
    .requires_grad(${requires_grad});
""")


def should_generate_python_binding(declaration):
    """Return True unless this declaration is excluded from Python bindings.

    A declaration is skipped when its name matches one of the
    SKIP_PYTHON_BINDINGS regexes, when its exact signature appears in
    SKIP_PYTHON_BINDINGS_SIGNATURES, or when it takes a SparseTensorRef
    (except for sparse_mask).
    """
    name = declaration['name']
    if any(re.match('^' + pattern + '$', name) for pattern in SKIP_PYTHON_BINDINGS):
        return False

    arg_types = ', '.join(arg['simple_type'] for arg in declaration['arguments'])
    if '{}({})'.format(name, arg_types) in SKIP_PYTHON_BINDINGS_SIGNATURES:
        return False

    # TODO: fix handling of SparseTensor. We don't want to generate Python
    # bindings to SparseTensor overloads, such as add(Tensor, SparseTensorRef),
    # since the Tensor-based signature already dynamically dispatches correctly.
    # However, sparse_mask only has a SparseTensor signature so we need to bind
    # that function.
    return not any(arg['type'] == 'SparseTensorRef' and name != 'sparse_mask'
                   for arg in declaration['arguments'])


def gen_py_variable_methods(out, declarations, template_path):
    """Generate the Python bindings exposed as methods on Variable/Tensor."""
    methods_cpp = CodeTemplate.from_file(template_path + '/python_variable_methods.cpp')
    dispatch_h = CodeTemplate.from_file(template_path + '/python_variable_methods_dispatch.h')

    def should_bind(declaration):
        # Tensor methods: everything bindable that is not an NN function.
        if not should_generate_python_binding(declaration):
            return False
        if declaration['mode'] == 'NN':
            return False
        if declaration.get('python_module') == 'nn':
            return False
        return 'Tensor' in declaration['method_of']

    grouped = group_declarations_by_name(declarations, should_bind)

    env = create_python_bindings(grouped, True)
    write(out, 'python_variable_methods.cpp', methods_cpp, env)
    write(out, 'python_variable_methods_dispatch.h', dispatch_h, env)


def gen_py_nn_functions(out, declarations, template_path):
    """Generate the Python bindings exposed on the torch._C._nn module."""
    functions_cpp = CodeTemplate.from_file(template_path + '/python_nn_functions.cpp')
    functions_h = CodeTemplate.from_file(template_path + '/python_nn_functions.h')
    dispatch_h = CodeTemplate.from_file(template_path + '/python_nn_functions_dispatch.h')

    def should_bind(declaration):
        # NN functions only: either declared in NN mode or tagged python_module=nn.
        if not should_generate_python_binding(declaration):
            return False
        return declaration['mode'] == 'NN' or declaration.get('python_module') == 'nn'

    grouped = group_declarations_by_name(declarations, should_bind)

    env = create_python_bindings(grouped, has_self=False, is_module=True)
    write(out, 'python_nn_functions.cpp', functions_cpp, env)
    write(out, 'python_nn_functions.h', functions_h, env)
    write(out, 'python_nn_functions_dispatch.h', dispatch_h, env)


def gen_py_torch_functions(out, declarations, template_path):
    """Generate the Python bindings for namespace-level torch.* functions."""
    functions_cpp = CodeTemplate.from_file(template_path + '/python_torch_functions.cpp')
    dispatch_h = CodeTemplate.from_file(template_path + '/python_torch_functions_dispatch.h')

    def should_bind(declaration):
        # Namespace functions: bindable, not NN, and callable as at::<name>.
        if not should_generate_python_binding(declaration):
            return False
        if declaration['mode'] == 'NN':
            return False
        if declaration.get('python_module') == 'nn':
            return False
        return 'namespace' in declaration['method_of']

    grouped = group_declarations_by_name(declarations, should_bind)

    env = create_python_bindings(grouped, has_self=False)
    write(out, 'python_torch_functions.cpp', functions_cpp, env)
    write(out, 'python_torch_functions_dispatch.h', dispatch_h, env)


def group_declarations_by_name(declarations, should_bind_fn):
    """Group declarations by name ignoring _out suffix"""
    groups = defaultdict(list)
    for decl in declarations:
        if not should_bind_fn(decl):
            continue
        key = decl['name']
        if key.endswith('_out'):
            # out-variants are grouped together with their base function
            key = key[:-len('_out')]
        groups[key].append(decl)
    return groups


def get_type_default(declaration):
    """Default dtype (as a Python-source string) for a factory's dtype arg."""
    name = declaration['name']
    # randperm/tril_indices/triu_indices produce index tensors, so they
    # default to int64 rather than the global default dtype.
    if name.startswith('randperm') or name in ('tril_indices', 'triu_indices'):
        return 'torch.int64'
    return 'None'


def create_python_bindings(python_functions, has_self, is_module=False):
    """Generates Python bindings to ATen functions.

    python_functions: dict mapping an overload-group name to a list of
        declarations (as built by group_declarations_by_name).
    has_self: True when generating Tensor methods; a C++ `self` is unpacked
        from the Python-level self_ instead of being parsed from args.
    is_module: True for plain module functions (suppresses METH_STATIC).

    Returns a dict with 'py_methods', 'py_method_defs' and
    'py_method_dispatch' lists of generated C++ source fragments.
    """
    py_methods = []
    py_method_defs = []
    py_method_dispatch = []

    # Maps a C++ formal type to the PythonArgParser accessor used to unpack
    # it from the parsed args, i.e. r.<accessor>(index).
    unpack_methods = {
        'const Tensor &': 'tensor',
        'SparseTensorRef': 'tensor',
        'Tensor &': 'tensor',
        'Generator *': 'generator',
        'Storage &': 'storage',
        'const Type &': 'scalartype',
        'const THPLayout &': 'layout',
        'const Device &': 'device',
        'c10::optional<ScalarType>': 'scalartypeOptional',
        'c10::optional<Scalar>': 'scalarOptional',
        'c10::optional<int64_t>': 'toInt64Optional',
        'int64_t': 'toInt64',
        'bool': 'toBool',
        'double': 'toDouble',
        'std::string': 'string',
    }

    # Accessors used instead of the above when the argument carries a
    # 'python_default_init' expression (a default evaluated in C++).
    unpack_with_default_methods = {
        'IntList': 'setDefaultIntlist',
        'Scalar': 'scalarWithDefault',
        'int64_t': 'toInt64WithDefault',
        'bool': 'setDefaultBool',
        'double': 'setDefaultDouble',
        'const Type &': 'scalartypeWithDefault',
        'const THPLayout &': 'layoutWithDefault',
        'const Device &': 'deviceWithDefault',
        'ScalarType': 'scalartypeWithDefault',
    }

    def emit_single_dispatch(declaration, out_idx, base_env):
        # Emit the C++ body that unpacks parsed args and calls one overload.
        # out_idx is the parser index of the out= argument (None when the
        # group has no out variant). Also appends the inline dispatch_*
        # helper to py_method_dispatch as a side effect.
        env = {}
        simple_return_type = declaration['return_type'].replace(' &', '')
        assert simple_return_type in SUPPORTED_RETURN_TYPES, \
            declaration['name'] + ' returns unsupported type: ' + simple_return_type

        body = []
        actuals = []
        formal_args = []
        arg_idx = 0

        def is_output(arg):
            return arg.get('output', False)

        inputs = [arg for arg in declaration['arguments'] if not is_output(arg)]
        outputs = [arg for arg in declaration['arguments'] if is_output(arg)]

        has_tensor_options = any(arg['simple_type'] == 'TensorOptions' for arg in declaration['arguments'])

        def get_type_args(args):
            return [arg for arg in args if arg['simple_type'] == 'Type']
        type_actual_args = get_type_args(declaration['arguments'])
        type_binding_args = get_type_args(declaration['python_binding_arguments'])
        assert len(type_actual_args + type_binding_args) <= 1
        if type_binding_args and len(outputs) == 0:
            # out(s) determines the dtype if it is present, so only use this if there are no outputs.
            type_args = type_binding_args
        else:
            type_args = type_actual_args

        if type_args and len(outputs) > 1:
                raise RuntimeError("Not supported: type dispatched parameter with multiple outputs")

        def parse_arg(arg, arg_index, unpack_args=False):
            # Returns (actual_expr, formal_decl) for one argument. When
            # unpack_args is set, the value is first bound to a local via a
            # line appended to `body` (needed e.g. for TensorOptions).
            name = arg['name']
            typename = arg['type']
            if typename.startswith('IntList['):
                typename = 'IntList'
            if typename.startswith('LongTensor'):
                typename = 'Tensor'

            if arg.get('python_default_init'):
                assert typename in unpack_with_default_methods, \
                    '`{}` type is not supported in python_default_init'.format(typename)
                unpack_with_default = unpack_with_default_methods.get(typename)
                default_expr = arg.get('python_default_init')
                # TODO: Type currently maps to ScalarType, figure out a cleaner solution
                if typename == 'const Type &':
                    default_expr += '.scalarType()'
                expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
            else:
                unpack = unpack_methods.get(typename, typename.lower())
                expr = 'r.{}({})'.format(unpack, arg_index)

            if unpack_args:
                body.append('auto {} = {};'.format(name, expr))
                expr = name

            if typename == 'SparseTensorRef':
                expr = 'SparseTensorRef({})'.format(expr)

            dispatch_type = typename
            if dispatch_type == 'Tensor':
                dispatch_type = 'const Tensor &'
            elif dispatch_type == 'Tensor &':
                # mutable tensors are taken by value in the dispatch helper;
                # see the comment about dangling references further down.
                dispatch_type = 'Tensor'
            elif dispatch_type == 'const Device &':
                dispatch_type = 'c10::optional<int32_t>'
            formal = '{} {}'.format(dispatch_type, name)
            return expr, formal

        def append_actuals_formals(actual, formal):
            actuals.append(actual)
            formal_args.append(formal)

        # We always want to unpack when we have TensorOptions.
        unpack = has_tensor_options
        for arg in inputs:
            if arg['simple_type'] in ['Type', 'TensorOptions']:
                continue
            if has_self and arg['name'] == 'self':
                # 'self' comes from the unpacked self_, not from the parser,
                # so it does not consume a parser index.
                formal_args.append('Tensor & self')
                actuals.append('self')
                continue
            append_actuals_formals(*parse_arg(arg, arg_idx, unpack))
            arg_idx += 1

        if len(outputs) == 1:
            append_actuals_formals(*parse_arg(outputs[0], arg_idx))
        elif len(outputs) > 1:
            # multiple outputs are parsed as a single tuple of tensors
            N = len(outputs)
            body.append('auto results = r.tensorlist_n<{}>({});'.format(N, arg_idx))
            for i, arg in enumerate(outputs):
                formal_args.append('Tensor & {}'.format(arg['name']))
                actuals.append('results[{}]'.format(i))

        layout = None
        parsed_type_args = None
        # type args go after the outputs to match the signature generation.
        arg_idx = arg_idx if out_idx is None else out_idx + 1
        for arg in type_args:
            parsed_type_args = parse_arg(arg, arg_idx, unpack)
            arg_idx += 1

        # check python_binding_arguments
        # NOTE(review): has_device_bind is assigned below but appears unused
        # afterwards in this function — confirm before removing.
        has_device_bind = False
        requires_grad = None
        python_binding_arguments = declaration.get('python_binding_arguments', [])
        if 'dtype' in (a['name'] for a in python_binding_arguments):
            if not has_tensor_options:
                arg_idx += 1

        if 'layout' in (a['name'] for a in python_binding_arguments):
            layout_idx, device_idx, requires_grad_idx = (arg_idx, arg_idx + 1, arg_idx + 2)
        else:
            device_idx, requires_grad_idx = (arg_idx, arg_idx + 1)

        device = None
        for arg in python_binding_arguments:
            if arg['name'] == 'dtype' and arg['simple_type'] == 'Type':
                pass  # already handled by type_dispatched_args
            elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
                # out(s) determines the type and layout if it is present, so only use this if there are no outputs.
                if len(outputs) == 0:
                    layout = parse_arg(arg, layout_idx)[0]
            elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
                if len(outputs) == 0:
                    assert parsed_type_args
                    assert layout
                    device, device_type = parse_arg(arg, device_idx, True)

                    if not has_tensor_options:
                        # add type, device formals and corresponding actuals.
                        # The type actual is the ATen type mapped from (ScalarType, Layout, Device)
                        # The device actual is the corresponding AutoGPU index for the Device.
                        formal_args.append(parsed_type_args[1])
                        formal_args.append(device_type)
                        actuals.append("torch::getVariableType({}, {}, {})".format(parsed_type_args[0], layout, device))
                        actuals.append('{}.index()'.format(device))

                    has_device_bind = True
            elif arg['name'] == 'requires_grad' and arg['simple_type'] == 'bool':
                requires_grad = parse_arg(arg, requires_grad_idx)[0]
            else:
                raise RuntimeError(("found {} in python_binding_arguments but only "
                                    "\"bool requires_grad\", \"ScalarType dtype\", \"Layout layout\", "
                                    "\"Device device\" are supported".format(arg)))

        # When the declaration takes TensorOptions and all four pieces were
        # parsed, collapse them into a single `options` formal/actual.
        dtype = parsed_type_args[0] if parsed_type_args else None
        if has_tensor_options and all([dtype, device, layout, requires_grad]):
            body.append(TENSOR_OPTIONS.substitute({
                'dtype': dtype,
                'layout': layout,
                'device': device,
                'requires_grad': requires_grad
            }))
            formal_args.append('const TensorOptions & options')
            actuals.append('options')

        env['unpack_args'] = []
        env['formal_args'] = formal_args
        env['actuals'] = actuals

        if has_tensor_options:
            env['initialize_cuda'] = 'maybe_initialize_cuda(options);'
        else:
            env['initialize_cuda'] = ''

        if 'call_args' in declaration:
            env['dispatch_args'] = declaration['call_args']
        else:
            env['dispatch_args'] = [arg['name'] for arg in declaration['arguments']]

        if 'Tensor' in declaration['method_of']:
            # dispatch as a method: self.<name>(...) with self removed from args
            env['dispatch_args'] = [arg for arg in env['dispatch_args'] if arg != 'self']
            env['dispatch_call'] = 'self.{}'.format(declaration['name'])
        elif 'namespace' in declaration['method_of']:
            namespace = 'torch' if (has_tensor_options or declaration['name'].endswith('_like')) else 'at'
            env['dispatch_call'] = '{}::{}'.format(namespace, declaration['name'])
        else:
            raise RuntimeError('could not dispatch, neither namespace function nor Tensor method')

        env['AutoNoGIL'] = 'AutoNoGIL no_gil;' if not declaration['with_gil'] else ''

        # Use the simple_return_type (Tensor) rather than the fancy return type
        # (Tensor &).  This is important because the dispatch functions take
        # mutable arguments *by value*, not by reference.  If you then return
        # a a reference to such an argument, you will now have a pointer to a
        # dangling stack entry.  Not good.
        #
        # You want:
        #
        #   Tensor dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # *not*
        #
        #   Tensor& dispatch_selu_(Tensor self) { return at::selu_(self); }
        #
        # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
        # codegen looks like dispatch_selu_(wrap(tensor)), and you can't take a
        # mutable reference to temporary.  Maybe we could assign it to a
        # variable itself.)
        env['simple_return_type'] = simple_return_type

        env = nested_dict(env, nested_dict(base_env, declaration))
        call_dispatch = PY_VARIABLE_CALL_DISPATCH.substitute(env)
        if requires_grad and not has_tensor_options:
            call_dispatch = PY_VARIABLE_SET_REQUIRES_GRAD.substitute(env, call_dispatch=call_dispatch,
                                                                     requires_grad=requires_grad)
        if simple_return_type == 'void':
            body.append('{call_dispatch};'.format(call_dispatch=call_dispatch))
            body.append('Py_RETURN_NONE;')
        else:
            body.append(PY_VARIABLE_WRAP.substitute(env, call_dispatch=call_dispatch))
        py_method_dispatch.append(PY_VARIABLE_DISPATCH.substitute(env))
        return body

    def emit_dispatch(i, dictionary, base_env):
        # Emit the i-th branch of the overload if/else-if chain. When the
        # group has an out variant, both base and out bodies are generated
        # and selected at runtime on whether out= was passed.
        if 'out' in dictionary:
            out_idx = len([arg for arg in dictionary['out']['arguments']
                           if not arg.get('output', False)])
            env = {}
            env['call_dispatch_out'] = emit_single_dispatch(dictionary['out'], out_idx, base_env)
            env['call_dispatch'] = emit_single_dispatch(dictionary['base'], out_idx, base_env)

            has_dtype_bind = 'dtype' in [d['name'] for d in dictionary['out'].get('python_binding_arguments', [])]
            if has_dtype_bind:
                body = PY_VARIABLE_OUT_CHECK_TYPE.substitute(env, out_idx=out_idx, type_idx=out_idx + 1,
                                                             layout_idx=out_idx + 2, device_idx=out_idx + 3).split('\n')
            else:
                body = PY_VARIABLE_OUT.substitute(env, out_idx=out_idx).split('\n')
        else:
            body = emit_single_dispatch(dictionary['base'], None, base_env)

        cond = 'if' if i == 0 else '} else if'
        return PY_VARIABLE_CASE.substitute(i=i, cond=cond, call_dispatch=body)

    def get_python_binding_arguments(declaration):
        # Synthesize the extra Python-only keyword arguments (dtype, layout,
        # device, requires_grad) appended to factory/_like function signatures.
        python_binding_arguments = []
        has_tensor_input_arg = False
        has_type_input_arg = False
        has_options_arg = False
        for arg in declaration['arguments']:
            if arg.get('output', False):
                continue
            typename = arg['simple_type']
            if typename in ['Tensor', 'TensorList']:
                has_tensor_input_arg = True
            if arg['simple_type'] == 'Type':
                has_type_input_arg = True
            elif arg['simple_type'] == 'TensorOptions':
                has_options_arg = True
            if arg['name'] == 'requires_grad':
                raise ValueError("argument named requires_grad not supported")

        has_tensor_return = False
        for ret in declaration['returns']:
            if ret['dynamic_type'] in ['Tensor', 'TensorList']:
                # this probably won't work if one of the returns is not a tensor, but it will
                # produce a compile-time error that is obvious
                has_tensor_return = True

        # NOTE: 'name' here is the overload-group name — the loop variable of
        # the `for name in sorted(python_functions.keys())` loop at the bottom
        # of create_python_bindings — not a local of this function.
        is_like_function = name.endswith('_like')
        is_like_function_with_options = is_like_function and has_options_arg
        is_factory_function = has_tensor_return and not has_tensor_input_arg
        is_factory_or_like_function = has_tensor_return and (not has_tensor_input_arg or is_like_function)

        if (is_factory_function and not has_type_input_arg) or has_options_arg:
            default_type = get_type_default(declaration)
            py_default_dtype = 'self.type()' if is_like_function_with_options else None
            dtype_arg = {
                'default': default_type,
                'dynamic_type': 'Type',
                'kwarg_only': True,
                'name': 'dtype',
                'type': 'const Type &',
                'simple_type': 'Type',
                'python_default_init': py_default_dtype,
            }
            python_binding_arguments.append(dtype_arg)
        if is_factory_function or is_like_function_with_options:
            py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_function_with_options else None
            layout_arg = {
                'default': 'torch.strided',
                'dynamic_type': 'Layout',
                'kwarg_only': True,
                'name': 'layout',
                'type': 'const THPLayout &',
                'simple_type': 'Layout',
                'python_default_init': py_default_layout,
            }
            python_binding_arguments.append(layout_arg)
            py_default_device = 'self.device()' if is_like_function_with_options else None
            device_arg = {
                'default': 'None',
                'default_init': 'None',
                'dynamic_type': 'Device',
                'kwarg_only': True,
                'name': 'device',
                'type': 'const Device &',
                'simple_type': 'Device',
                'python_default_init': py_default_device
            }
            python_binding_arguments.append(device_arg)
        if is_factory_or_like_function:
            requires_grad_arg = {
                'default': False,
                'dynamic_type': 'bool',
                'kwarg_only': True,
                'name': 'requires_grad',
                'type': 'bool',
                'simple_type': 'bool',
            }
            python_binding_arguments.append(requires_grad_arg)
        return python_binding_arguments

    def emit_namedtuple_return_type_def(declaration, next_index):
        # Emit the PyStructSequence type definition for a declaration whose
        # returns have field names; returns (typedef_source, next_index).
        # Mutates the declaration with namedtuple_* keys used by templates.
        returns = declaration['returns']
        if len(returns) <= 1 or all(['field_name' not in x for x in returns]):
            declaration['namedtuple_return_type'] = ''
            return '', next_index
        declaration['namedtuple_type_index'] = next_index
        declaration['namedtuple_fields'] = ''
        for x in returns:
            # See Note [field_name versus name]
            if 'field_name' not in x:
                # When building on Windows, `PyStructSequence_UnnamedField` could not be
                # resolved by the linker for some reason, which cause error in building:
                #
                # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
                # PyStructSequence_UnnamedField
                #
                # Thus, at this point in time, we do not support unnamed
                # fields in namedtuple; you must either name all fields,
                # or none of them.
                raise ValueError("Unnamed field is not supported by codegen")
            else:
                declaration['namedtuple_fields'] += '{"' + x['field_name'] + '", ""}, '
        declaration['namedtuple_size'] = len(returns)
        declaration['namedtuple_return_type'] = '&type{}, '.format(next_index)
        return PY_RETURN_NAMEDTUPLE_DEF.substitute(declaration), next_index + 1

    def process_function(name, declarations):
        # Generate the binding, method-def entry and dispatch helpers for one
        # overload group; appends to py_methods/py_method_defs.
        for declaration in declarations:
            declaration['python_binding_arguments'] = get_python_binding_arguments(declaration)

        env = {
            'name': name,
            'dispatch_name': 'dispatch_{}'.format(name),
            'pycname': 'THPVariable_{}'.format(name),
            'signatures': [],
            'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations),
            'unpack_self': [],
            'dispatch': [],
            'declare_namedtuple_return_types': '',
        }

        if has_self:
            env['unpack_self'] = [UNPACK_SELF]

        # generate namedtuple type declare
        next_index = 0
        for declaration in declarations:
            typedef, next_index = emit_namedtuple_return_type_def(declaration, next_index)
            env['declare_namedtuple_return_types'] += typedef

        # emit dispatch
        grouped = group_declarations(declarations)
        for i, dictionary in enumerate(grouped):
            signature = dictionary['signature']
            if has_self:
                signature = signature.replace('Tensor self, ', '')
                signature = signature.replace('Tensor self', '')
            if not has_self:
                # Use 'input' instead of 'self' for NN functions
                signature = signature.replace('Tensor self', 'Tensor input')
            signature = signature.replace('SparseTensorRef', 'Tensor')
            if dictionary['base'].get('deprecated', False):
                signature += '|deprecated'
            env['signatures'].append('"{}",'.format(signature))
            env['dispatch'].append(emit_dispatch(i, dictionary, env))

        # close the final overload branch opened by PY_VARIABLE_CASE
        env['dispatch'].append('}')

        env['traceable'] = 'true' if all(should_trace(d) for d in declarations) else 'false'

        if len(declarations) == 1 and len(declarations[0]['args']) == 1 and has_self:
            tmpl = PY_VARIABLE_METHOD_NOARGS
            env['actuals'] = ['self']
            env['flags'] = 'METH_NOARGS'
            env['namedtuple_return_type'] = declarations[0]['namedtuple_return_type']
        else:
            tmpl = PY_VARIABLE_METHOD_VARARGS
            env['flags'] = 'METH_VARARGS | METH_KEYWORDS'

        if not is_module and not has_self:
            env['flags'] += ' | METH_STATIC'

        py_methods.append(tmpl.substitute(env))
        py_method_defs.append(PY_VARIABLE_METHOD_DEF.substitute(env))

    for name in sorted(python_functions.keys()):
        process_function(name, python_functions[name])

    return {
        'py_methods': py_methods,
        'py_method_defs': py_method_defs,
        'py_method_dispatch': py_method_dispatch,
    }


def group_declarations(declarations):
    """Group declarations by their out-less Python signature.

    Returns a list of dictionaries containing the optional keys:

       "base": the regular ATen declaration (e.g. conv2d)
       "out": the out variant (e.g. conv2d_out)
       "signature": the signature used for Python argument parsing
    """
    grouped = defaultdict(dict)

    # Bucket declarations by their signature computed *without* out
    # arguments, so a base function and its `_out` variant land together.
    for declaration in declarations:
        base_signature = get_python_signature(declaration, False)
        entry = grouped[base_signature]
        if declaration['name'].endswith('_out'):
            entry['out'] = declaration
            # prefer the signature with optional out=... arguments
            entry['signature'] = get_python_signature(declaration, True)
        else:
            entry['base'] = declaration
            entry.setdefault('signature', base_signature)

    # Every group must have a base declaration; an out-only group is a bug.
    result = []
    for _, entry in sorted(grouped.items()):
        if 'base' not in entry:
            raise RuntimeError("'base' not in dictionary", entry)
        result.append(entry)
    return sort_declarations(result)


# This function declares a partial order on declarations, and sorts them according
# to its linear extension. This is necessary, because there's some ambiguity in the
# choice of overload, and we want a different order.
#
# See Note[Order of overloads matters]
def sort_declarations(grouped_decls):

    # TODO: This is a hack!
    #
    # For some reason, when you specify a Scalar argument in a native
    # function, you get a Declarations.yaml entry that looks like this:
    #
    #   - default: 1
    #     dynamic_type: Scalar
    #     is_nullable: false
    #     kwarg_only: true
    #     name: alpha
    #     type: Scalar
    #
    # This is contrast to when there is a 'real' argument in TH
    # Declarations.cwrap; this gets (correctly?) translated into
    # dynamic_type: real, and type: Scalar.  Until that inconsistency is
    # fixed at the source, fold both spellings into 'Scalar' here.
    def normalized_dynamic_type(arg):
        dyn = arg['dynamic_type']
        return 'Scalar' if dyn == 'real' else dyn

    def is_coord_smaller(arg1, arg2):
        # NB: the right-hand side deliberately checks the raw
        # (un-normalized) dynamic_type.
        return normalized_dynamic_type(arg1) == 'Scalar' and arg2['dynamic_type'] == 'Tensor'

    def is_smaller(d1, d2):
        """Returns True if d1 < d2 in the partial order."""
        args1 = d1['base']['arguments']
        args2 = d2['base']['arguments']
        if len(args1) != len(args2):
            return False
        pairs = list(zip(args1, args2))
        strictly_smaller_somewhere = any(is_coord_smaller(a, b) for a, b in pairs)
        nowhere_larger = all(normalized_dynamic_type(a) == normalized_dynamic_type(b) or
                             is_coord_smaller(a, b)
                             for a, b in pairs)
        return strictly_smaller_somewhere and nowhere_larger

    # Build the relation graph: larger_than[i] is the set of indices whose
    # declarations are strictly larger than declaration i.
    larger_than = defaultdict(set)
    for lo, small_decl in enumerate(grouped_decls):
        for hi, big_decl in enumerate(grouped_decls):
            if is_smaller(small_decl, big_decl):
                larger_than[lo].add(hi)

    # No comparable pairs at all: keep the incoming order untouched.
    if not larger_than:
        return grouped_decls

    # Topological sort: seed with declarations that have nothing larger
    # than them, then release each blocked declaration once all of its
    # larger counterparts have been emitted.
    ordered = [(idx, decl) for idx, decl in enumerate(grouped_decls)
               if idx not in larger_than]
    for idx, _ in ordered:  # NB: `ordered` grows while we iterate it
        for pending in sorted(larger_than):
            remaining = larger_than[pending]
            remaining.discard(idx)
            if not remaining:
                del larger_than[pending]
                ordered.append((pending, grouped_decls[pending]))

    return [decl for _, decl in ordered]


def get_python_signature(declaration, include_out):
    """Compute the Python signature string used for argument parsing.

    Arguments:
        declaration: a Declarations.yaml entry (dict with 'name',
            'arguments', 'python_binding_arguments', ...)
        include_out: if True and the declaration is an ``_out`` variant,
            append an optional ``out=None`` formal argument

    Returns the substituted PYTHON_FUNCTION_SIGNATURE template string,
    which is later handed to FunctionParameter for parsing.
    """
    py_formal_args = []
    output_args = []
    type_args = []
    positional = True  # flips to False once the '*' kwarg-only marker is emitted

    def get_py_formal_arg(arg):
        # Render a single argument as "<type> <name>[=<default>]".
        typename = arg['simple_type']
        typename = typename if typename != 'Type' else 'ScalarType'

        # TODO: remove this and make optional types in simple_type to be consistent across
        # tensor and other types after make Tensor? be optional instead of undefined
        if arg.get('is_nullable') and '?' not in typename:
            typename = '{}?'.format(typename)

        if arg.get('size') is not None:
            typename = '{}[{}]'.format(typename, arg['size'])
        param = typename + ' ' + arg['name']
        default = arg.get('default')
        # C++ "empty" default spellings all map to Python None
        if default in ('nullptr', 'nullopt', '{}'):
            default = 'None'
        if default is not None:
            param += '=' + str(default)
        return param

    # Partition arguments: outputs and Type args are handled after the
    # positional arguments; TensorOptions never appears in Python.
    for arg in declaration['arguments']:
        if arg.get('output', False):
            output_args.append(arg)
            continue
        if arg['simple_type'] == 'Type':
            type_args.append(arg)
            continue
        # Skip `TensorOptions` in Python, as it is only used on the C++ side.
        if arg['simple_type'] == 'TensorOptions':
            continue
        if arg.get('kwarg_only', False) and positional:
            py_formal_args.append('*')
            positional = False
        py_formal_args.append(get_py_formal_arg(arg))

    # the Python-visible name drops the '_out' suffix of out-variants
    name = declaration['name']
    if name.endswith('_out'):
        name = name[:-4]

    # add output arguments (always keyword-only)
    if output_args and include_out:
        assert declaration['name'].endswith('_out')
        if positional:
            py_formal_args.append('*')
            positional = False
        typenames = [arg['simple_type'] for arg in output_args]
        if len(typenames) > 1:
            typename = 'TensorList[{}]'.format(len(typenames))
        else:
            typename = typenames[0]
        py_formal_args.append(typename + ' out=None')

    # we could put this in the loop above but we want to ensure both type dispatched args
    # and python binding arguments are after the out argument; this matches the case
    # where there is a python binding argument dtype, which is necessary to match
    # the function signatures between the out and non-out variant.
    assert len(type_args) <= 1
    for arg in type_args:
        if positional:  # assume type_args should be kwarg_only.
            py_formal_args.append('*')
            positional = False
        py_formal_args.append(get_py_formal_arg(arg))

    for arg in declaration['python_binding_arguments']:
        if arg.get('kwarg_only', False) and positional:
            py_formal_args.append('*')
            positional = False
        py_formal_args.append(get_py_formal_arg(arg))

    # Python function signature.
    # This is the string that we give to FunctionParameter, which is
    # then parsed into the actual structure which we do parsing
    # with.
    return PYTHON_FUNCTION_SIGNATURE.substitute(name=name, py_formal_args=py_formal_args)