summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgchanan <gregchanan@gmail.com>2018-02-09 15:45:41 -0500
committerGitHub <noreply@github.com>2018-02-09 15:45:41 -0500
commit4b8bf737297ae86049112c76914fc6730fa689d4 (patch)
tree396860377b5339de5c05311317e35046649e1d5c
parent8f1f84a6f2fa6c73a4e8d4751d5d231dd80105b2 (diff)
downloadpytorch-4b8bf737297ae86049112c76914fc6730fa689d4.tar.gz
pytorch-4b8bf737297ae86049112c76914fc6730fa689d4.tar.bz2
pytorch-4b8bf737297ae86049112c76914fc6730fa689d4.zip
Enable scalars. (#5158)
* Enable scalars. * Avoid variable name shadowing in list comprehension, because it rebinds in python2, but not python3.
-rw-r--r--aten/src/ATen/nn_parse.py2
-rw-r--r--setup.py3
2 files changed, 3 insertions, 2 deletions
diff --git a/aten/src/ATen/nn_parse.py b/aten/src/ATen/nn_parse.py
index 776e3dc988..7810620d8f 100644
--- a/aten/src/ATen/nn_parse.py
+++ b/aten/src/ATen/nn_parse.py
@@ -350,7 +350,7 @@ def backward_declaration(base, thnn_functions):
pass
else:
base_name = arg['name'][len('grad_'):] if arg['name'] != 'grad_input' else 'self'
- if base_name in [arg['name'] for arg in arguments]:
+ if base_name in [a['name'] for a in arguments]:
scalar_check[arg['name']] = base_name + '_->isScalar()'
else:
raise ValueError(("Could not infer scalar_check for {} argument of func {} because {} "
diff --git a/setup.py b/setup.py
index 1364725444..e87075fa38 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,8 @@ IS_WINDOWS = (platform.system() == 'Windows')
IS_DARWIN = (platform.system() == 'Darwin')
IS_LINUX = (platform.system() == 'Linux')
-
+if 'WITH_SCALARS' not in os.environ:
+ os.environ['WITH_SCALARS'] = '1'
WITH_SCALARS = check_env_flag('WITH_SCALARS')
try: