author | gchanan <gregchanan@gmail.com> | 2018-04-18 23:37:54 -0400
committer | GitHub <noreply@github.com> | 2018-04-18 23:37:54 -0400
commit | e1f5d80d5c080ffc68145425a38c31d5cf6437d3
tree | ab30173829c72ce3a9d1875a81f33138fa3732d9 /aten
parent | 9c47eb554858edad4534479878ec1f82cef4c95c
Eliminate handle_zero_dim when broadcasting is applied earlier. (#6683)
* Eliminate handle_zero_dim when broadcasting is applied earlier.
This ends up doing nothing unless all of the broadcast tensors are scalars (zero-dim),
and in that case only it produces inconsistent behavior, because the type promotion rules are different.
This is better solved with real type promotion logic. A sketch of the inconsistency follows the change list below.
* Change type of script comparison to long.
* Fix jit tests.
* Fix cpp jit test by being consistent about long-vs-float.
* Consistent float and long.
* Use int64_t rather than long.
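For context, the `handle_zero_dim` codegen emitted a fast path that re-dispatched an op to its `Scalar` overload when the designated argument was a zero-dim tensor. Below is a minimal, hypothetical Python sketch (not ATen's generated code; names and promotion labels are illustrative) of why that path becomes inconsistent once broadcasting runs first: broadcasting expands a zero-dim operand whenever any other operand has dimensions, so the Scalar path only fires when every operand is zero-dim, and that lone case then follows different type promotion rules.

```python
# Hypothetical sketch of the old dispatch order; not actual ATen code.
def broadcast_dim(a_dim, b_dim):
    # simplified broadcasting: result rank is the max of the operand ranks
    return max(a_dim, b_dim)

def old_dispatch(a_dim, b_dim):
    out_dim = broadcast_dim(a_dim, b_dim)
    if b_dim == 0 and out_dim == 0:
        # reachable only when a_dim == 0 as well; tensor-scalar promotion applies
        return "add(Tensor, Scalar)"
    # everywhere else the zero-dim check is dead; tensor-tensor promotion applies
    return "add(Tensor, Tensor)"

print(old_dispatch(2, 0))  # add(Tensor, Tensor): broadcasting already expanded b
print(old_dispatch(0, 0))  # add(Tensor, Scalar): the single inconsistent case
```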
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/function_wrapper.py | 16
1 file changed, 10 insertions, 6 deletions
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 76aab3d7ba..266b1505e1 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -1079,15 +1079,19 @@ def create_derived(backend_type_env, declarations):
             return backend_type_env['AccScalarName'] == 'Long'
         return False
 
-    def get_zero_dim_dispatch_when_scalar(option):
-        # type: (FunctionOption) -> str
-        return option.get('zero_dim_dispatch_when_scalar', False)  # type: ignore
-
     def handle_zero_dim(env, option):
         # type: (Environment, FunctionOption) -> List[str]
-        zero_dim_dispatch = get_zero_dim_dispatch_when_scalar(option)
+        zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
         if not zero_dim_dispatch:
             return []
+        broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '')
+        zero_dim_only = option.get('zero_dim_tensor_only', False)
+        # this combination doesn't seem to make sense
+        assert not (broadcasts_arg and zero_dim_only)
+        # if the argument broadcasts, then this would only affect cases where all broadcasted
+        # tensors were zero-dim, which is inconsistent with the scalar handling.
+        if broadcasts_arg:
+            return []
         zero_dim_actuals = [arg['name']
                             if arg['name'] != zero_dim_dispatch else "Scalar({})".format(arg['name'])
                             for arg in option['formals_list']]
@@ -1096,7 +1100,7 @@ def create_derived(backend_type_env, declarations):
     def handle_only_zero_dim(env, option):
         # type: (Environment, FunctionOption) -> List[str]
         if option.get('zero_dim_tensor_only', False):
-            check_name = get_zero_dim_dispatch_when_scalar(option)
+            check_name = option['zero_dim_dispatch_when_scalar']
             return [ZERO_DIM_ONLY.substitute(env, check_name=check_name)]
         else:
             return None
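Restated outside the codegen, the new guard reduces to the predicate below. This is a simplified sketch, not the codegen itself: `option` stands in for ATen's `FunctionOption` dict, and `broadcast_actuals` is treated here as a list of argument names.

```python
# Simplified restatement of the new handle_zero_dim guard; not the codegen itself.
def should_emit_zero_dim_dispatch(option):
    # type: (dict) -> bool
    zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
    if not zero_dim_dispatch:
        return False
    broadcasts_arg = zero_dim_dispatch in option.get('broadcast_actuals', '')
    zero_dim_only = option.get('zero_dim_tensor_only', False)
    # this combination doesn't seem to make sense
    assert not (broadcasts_arg and zero_dim_only)
    # skip the Scalar fast path when the argument participates in broadcasting
    return not broadcasts_arg

# the fast path is kept for a non-broadcasting argument...
assert should_emit_zero_dim_dispatch({'zero_dim_dispatch_when_scalar': 'other'})
# ...and dropped once that argument appears among the broadcast actuals
assert not should_emit_zero_dim_dispatch({
    'zero_dim_dispatch_when_scalar': 'other',
    'broadcast_actuals': ['other'],
})
```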