summaryrefslogtreecommitdiff
path: root/tools/autograd
diff options
context:
space:
mode:
authorMichael Dagitses <mikeyd@fb.com>2021-09-02 06:49:09 -0700
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>2021-09-02 07:32:11 -0700
commitb737629ff0d4dd82f246b0efa6aef53f15971e78 (patch)
tree47fc536a30524d1d49ee3ac5560ba5bb4eeef07f /tools/autograd
parentb2c7c1dfcf9c366ecef5db635b201954981c609f (diff)
downloadpytorch-b737629ff0d4dd82f246b0efa6aef53f15971e78.tar.gz
pytorch-b737629ff0d4dd82f246b0efa6aef53f15971e78.tar.bz2
pytorch-b737629ff0d4dd82f246b0efa6aef53f15971e78.zip
simplify op name determination into a single forward pass (#64261)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64261 Note that this does not preserve byte-for-byte compatibility with existing names. Test Plan: * Rely on CI to catch gross errors. * Merge after release cut to catch subtle issues. Reviewed By: albanD Differential Revision: D30700647 Pulled By: dagitses fbshipit-source-id: 7b02f34b8fae3041240cc78fbc6bcae498c3acd4
Diffstat (limited to 'tools/autograd')
-rw-r--r--tools/autograd/load_derivatives.py82
1 files changed, 30 insertions, 52 deletions
diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py
index 3ff11f4d18..8a5904b732 100644
--- a/tools/autograd/load_derivatives.py
+++ b/tools/autograd/load_derivatives.py
@@ -2,9 +2,9 @@
#
# Each autograd function is represented by `DifferentiabilityInfo` containing
# a list of `Derivative`. See `tools.codegen.api.autograd` for the data models.
-from collections import defaultdict, Counter
+from collections import defaultdict
import re
-from typing import Sequence, Any, Tuple, List, Set, Dict, Match, Optional
+from typing import Counter, Sequence, Any, Tuple, List, Set, Dict, Match, Optional
import yaml
from tools.codegen.api.autograd import (Derivative, DifferentiabilityInfo,
@@ -43,32 +43,15 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque
assert str(function.func) not in functions_by_schema
functions_by_schema[str(function.func)] = function
+ # Keep track of how many of which ops we've seen so we can
+ # disambiguate them with a numeric suffix.
+ op_counter = Counter[str]()
+
infos = [
- create_differentiability_info(defn, functions_by_signature, functions_by_schema)
+ create_differentiability_info(defn, functions_by_signature, functions_by_schema, op_counter)
for defn in definitions]
- # To keep it byte-for-byte compatible with the old codegen, we assign op names as a separate
- # step. We only assign op names to those with differentiable args, and only append suffix to
- # duplicated op names. This can be simplified if the first of the duplicates can be named
- # 'XyzBackward' instead of 'XyzBackward0' or unconditionally append '0' to singletons.
- op_names = create_op_names(infos)
- res = [
- DifferentiabilityInfo(
- name=info.name,
- func=info.func,
- op=op_name,
- derivatives=info.derivatives,
- forward_derivatives=info.forward_derivatives,
- all_saved_inputs=info.all_saved_inputs,
- all_saved_outputs=info.all_saved_outputs,
- args_with_derivatives=info.args_with_derivatives,
- non_differentiable_arg_names=info.non_differentiable_arg_names,
- output_differentiability=info.output_differentiability,
- output_differentiability_conditions=info.output_differentiability_conditions,
- )
- for info, op_name in zip(infos, op_names)]
-
- _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = res
+ _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos
return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
@@ -279,6 +262,7 @@ def create_differentiability_info(
defn: Dict[Any, Any],
functions_by_signature: Dict[FunctionSchema, List[NativeFunction]],
functions_by_schema: Dict[str, NativeFunction],
+ op_counter: Counter[str],
) -> DifferentiabilityInfo:
"""Processes a single entry `defn` in derivatives.yaml"""
@@ -424,10 +408,17 @@ def create_differentiability_info(
derivatives, forward_derivatives, args_with_derivatives, non_differentiable_arg_names = set_up_derivatives(canonical)
+ # only assign an op name if we are actually going to calculate a derivative
+ op = None
+ if args_with_derivatives:
+ op_prefix = _create_op_prefix(defn_name)
+ op = f'{op_prefix}{op_counter[op_prefix]}'
+ op_counter[op_prefix] += 1
+
return DifferentiabilityInfo(
name=defn_name,
func=canonical,
- op=None,
+ op=op,
derivatives=derivatives,
forward_derivatives=forward_derivatives,
all_saved_inputs=dedup_vars([v for d in derivatives for v in d.saved_inputs]),
@@ -566,35 +557,22 @@ def saved_variables(
return formula, tuple(saved)
-def create_op_name(info: DifferentiabilityInfo) -> Optional[str]:
- # only assign an op name if we are actually going to calculate a derivative
- if not info.args_with_derivatives:
- return None
- name = info.name
+def _create_op_prefix(name: str) -> str:
+ """Takes a native function name converts to a op prefix name.
+
+ Note that the "name" parameter must be the native function name
+ without the optional variant suffix, so "add" instead of
+ "add.out".
+
+ OP names correspond to classes, hence the change to title case.
+
+ Example::
+ >>> _create_op_prefix('add')
+ 'AddBackward'
+ """
camel_case = ''.join([p.title() for p in name.split('_')])
return (camel_case + 'Backward').replace('ForwardBackward', 'Backward')
-def create_op_names(infos: Sequence[DifferentiabilityInfo]) -> Sequence[Optional[str]]:
- names = list(map(create_op_name, infos))
- dups = set(item for item, count in Counter(names).items() if count > 1)
-
- # de-duplicate operation names
- # you end up with something like:
- # AddBackward0
- # AddBackward1
- # one for each overload
- counter: Dict[str, int] = Counter()
- dedup: List[Optional[str]] = []
- for name in names:
- if name is None:
- # Keep a placeholder
- dedup.append(None)
- elif name in dups:
- dedup.append(f'{name}{counter[name]}')
- counter[name] += 1
- else:
- dedup.append(name)
- return dedup
def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
seen: Set[str] = set()