diff options
author | Michael Dagitses <mikeyd@fb.com> | 2021-09-02 06:49:09 -0700 |
---|---|---|
committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> | 2021-09-02 07:32:11 -0700 |
commit | b737629ff0d4dd82f246b0efa6aef53f15971e78 (patch) | |
tree | 47fc536a30524d1d49ee3ac5560ba5bb4eeef07f /tools | |
parent | b2c7c1dfcf9c366ecef5db635b201954981c609f (diff) | |
download | pytorch-b737629ff0d4dd82f246b0efa6aef53f15971e78.tar.gz pytorch-b737629ff0d4dd82f246b0efa6aef53f15971e78.tar.bz2 pytorch-b737629ff0d4dd82f246b0efa6aef53f15971e78.zip |
simplify op name determination into a single forward pass (#64261)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64261
Note that this does not preserve byte-for-byte compatibility with
existing names.
Test Plan:
* Rely on CI to catch gross errors.
* Merge after release cut to catch subtle issues.
Reviewed By: albanD
Differential Revision: D30700647
Pulled By: dagitses
fbshipit-source-id: 7b02f34b8fae3041240cc78fbc6bcae498c3acd4
Diffstat (limited to 'tools')
-rw-r--r-- | tools/autograd/load_derivatives.py | 82 |
1 files changed, 30 insertions, 52 deletions
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 3ff11f4d18..8a5904b732 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -2,9 +2,9 @@ # # Each autograd function is represented by `DifferentiabilityInfo` containing # a list of `Derivative`. See `tools.codegen.api.autograd` for the data models. -from collections import defaultdict, Counter +from collections import defaultdict import re -from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional +from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional import yaml from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo, @@ -43,32 +43,15 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque assert str(function.func) not in functions_by_schema functions_by_schema[str(function.func)] = function + # Keep track of how many of which ops we've seen so we can + # disambiguate them with a numeric suffix. + op_counter = Counter[str]() + infos = [ - create_differentiability_info(defn, functions_by_signature, functions_by_schema) + create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter) for defn in definitions] - # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate - # step. We only assign op names to those with differentiable args, and only append suffix to - # duplicated op names. This can be simplified if the first of the duplicates can be named - # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons. - op_names = create_op_names(infos) - res = [ - DifferentiabilityInfo( - name=info.name, - func=info.func, - op=op_name, - derivatives=info.derivatives, - forward_derivatives=info.forward_derivatives, - all_saved_inputs=info.all_saved_inputs, - all_saved_outputs=info.all_saved_outputs, - args_with_derivatives=info.args_with_derivatives, - non_differentiable_arg_names=info.non_differentiable_arg_names, - output_differentiability=info.output_differentiability, - output_differentiability_conditions=info.output_differentiability_conditions, - ) - for info, op_name in zip(infos, op_names)] - - _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res + _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos return _GLOBAL_LOAD_DERIVATIVE_CACHE[key] @@ -279,6 +262,7 @@ def create_differentiability_info( defn: Dict[Any, Any], functions_by_signature: Dict[FunctionSchema, List[NativeFunction]], functions_by_schema: Dict[str, NativeFunction], + op_counter: Counter[str], ) -> DifferentiabilityInfo: """Processes a single entry `defn` in derivatives.yaml""" @@ -424,10 +408,17 @@ def create_differentiability_info( derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical) + # only assign an op name if we are actually going to calculate a derivative + op = None + if args_with_derivatives: + op_prefix = _create_op_prefix(defn_name) + op = f'{op_prefix}{op_counter[op_prefix]}' + op_counter[op_prefix] += 1 + return DifferentiabilityInfo( name=defn_name, func=canonical, - op=None, + op=op, derivatives=derivatives, forward_derivatives=forward_derivatives, all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]), @@ -566,35 +557,22 @@ def saved_variables( return formula, tuple(saved) -def create_op_name(info: DifferentiabilityInfo) -> Optional[str]: - # only assign an op name if we are actually going to calculate a derivative - if not info.args_with_derivatives: - return None - name = info.name +def _create_op_prefix(name: str) -> str: + """Takes a native function name converts to a op prefix name. + + Note that the "name" parameter must be the native function name + without the optional variant suffix, so "add" instead of + "add.out". + + OP names correspond to classes, hence the change to title case. + + Example:: + >>> _create_op_prefix('add') + 'AddBackward' + """ camel_case = ''.join([p.title() for p in name.split('_')]) return (camel_case + 'Backward').replace('ForwardBackward', 'Backward') -def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]: - names = list(map(create_op_name, infos)) - dups = set(item for item, count in Counter(names).items() if count > 1) - - # de-duplicate operation names - # you end up with something like: - # AddBackward0 - # AddBackward1 - # one for each overload - counter: Dict[str, int] = Counter() - dedup: List[Optional[str]] = [] - for name in names: - if name is None: - # Keep a placeholder - dedup.append(None) - elif name in dups: - dedup.append(f'{name}{counter[name]}') - counter[name] += 1 - else: - dedup.append(name) - return dedup def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]: seen: Set[str] = set() |