diff options
author | gchanan <gregchanan@gmail.com> | 2018-02-09 15:45:41 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-09 15:45:41 -0500 |
commit | 4b8bf737297ae86049112c76914fc6730fa689d4 (patch) | |
tree | 396860377b5339de5c05311317e35046649e1d5c | |
parent | 8f1f84a6f2fa6c73a4e8d4751d5d231dd80105b2 (diff) | |
download | pytorch-4b8bf737297ae86049112c76914fc6730fa689d4.tar.gz pytorch-4b8bf737297ae86049112c76914fc6730fa689d4.tar.bz2 pytorch-4b8bf737297ae86049112c76914fc6730fa689d4.zip |
Enable scalars. (#5158)
* Enable scalars.
* Avoid variable name shadowing in list comprehension, because it rebinds in python2, but not python3.
-rw-r--r-- | aten/src/ATen/nn_parse.py | 2 | ||||
-rw-r--r-- | setup.py | 3 |
2 files changed, 3 insertions, 2 deletions
diff --git a/aten/src/ATen/nn_parse.py b/aten/src/ATen/nn_parse.py index 776e3dc988..7810620d8f 100644 --- a/aten/src/ATen/nn_parse.py +++ b/aten/src/ATen/nn_parse.py @@ -350,7 +350,7 @@ def backward_declaration(base, thnn_functions): pass else: base_name = arg['name'][len('grad_'):] if arg['name'] != 'grad_input' else 'self' - if base_name in [arg['name'] for arg in arguments]: + if base_name in [a['name'] for a in arguments]: scalar_check[arg['name']] = base_name + '_->isScalar()' else: raise ValueError(("Could not infer scalar_check for {} argument of func {} because {} " @@ -34,7 +34,8 @@ IS_WINDOWS = (platform.system() == 'Windows') IS_DARWIN = (platform.system() == 'Darwin') IS_LINUX = (platform.system() == 'Linux') - +if 'WITH_SCALARS' not in os.environ: + os.environ['WITH_SCALARS'] = '1' WITH_SCALARS = check_env_flag('WITH_SCALARS') try: |