 benchmarks/instruction_counts/README.md               | 142
 benchmarks/instruction_counts/core/__init__.py        |   0
 benchmarks/instruction_counts/core/api.py             | 402
 benchmarks/instruction_counts/core/expand.py          | 260
 benchmarks/instruction_counts/core/types.py           |  94
 benchmarks/instruction_counts/core/utils.py           |  99
 benchmarks/instruction_counts/definitions/__init__.py |   0
 benchmarks/instruction_counts/definitions/setup.py    |  57
 benchmarks/instruction_counts/definitions/standard.py | 143
 benchmarks/instruction_counts/main.py                 |  57
 benchmarks/instruction_counts/worker/__init__.py      |   0
 benchmarks/instruction_counts/worker/main.py          | 188
 mypy-strict.ini                                       |   2
 13 files changed, 1444 insertions, 0 deletions
diff --git a/benchmarks/instruction_counts/README.md b/benchmarks/instruction_counts/README.md
new file mode 100644
index 0000000000..ed2633caba
--- /dev/null
+++ b/benchmarks/instruction_counts/README.md
@@ -0,0 +1,142 @@
+# Instruction count microbenchmarks
+## Quick start
+
+### To run the benchmark:
+
+```
+# From pytorch root
+cd benchmarks/instruction_counts
+python main.py
+```
+
+Currently `main.py` runs the benchmarks through a very simple thread pool (so
+that the overall run time isn't unbearably onerous) and simply prints the
+results. These components will be upgraded in subsequent PRs.
+
+### To define a new benchmark:
+* `TimerArgs`: Low-level definition which maps directly to
+`torch.utils.benchmark.Timer`. (A minimal sketch appears below.)
+* `GroupedStmts`: Benchmark a snippet (Python, C++, or both). Can automatically
+generate TorchScript and autograd variants.
+* `GroupedModules`: Like `GroupedStmts`, but takes `nn.Module`s.
+* `GroupedVariants`: Benchmark-per-line to define many related benchmarks in a
+single code block.
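+
+For orientation, here are minimal sketches of `TimerArgs` and `GroupedModules`.
+(The values are illustrative; the `Linear` case is a simplified form of the
+entry in `definitions/standard.py`.)
+
+```
+# Maps one-to-one onto `torch.utils.benchmark.Timer(...)`
+TimerArgs(
+    stmt="y = x + 1",
+    setup="x = torch.ones((4, 4))",
+    num_threads=1,
+)
+
+# `GroupedModules` is an alias of `GroupedBenchmark.init_from_model`
+GroupedModules(
+    "model = torch.nn.Linear(4, 2)",
+    "auto model = torch::nn::Linear(4, 2);",
+    setup=GroupedSetup(
+        py_setup="x = torch.ones((4, 4))",
+        cpp_setup="auto x = torch::ones({4, 4});",
+    ),
+    signature="f(x) -> y",
+    torchscript=True,
+)
+```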
+
+## Architecture
+### Benchmark definition.
+
+One primary goal of this suite is to make it easy to define semantically
+related clusters of benchmarks. The crux of this effort is the
+`GroupedBenchmark` class, which is defined in `core/api.py`. It takes a
+definition for a set of related benchmarks, and produces one or more concrete
+cases. It's helpful to see an example to understand how the machinery works.
+Consider the following benchmark:
+
+```
+# `GroupedStmts` is an alias of `GroupedBenchmark.init_from_stmts`
+benchmark = GroupedStmts(
+ py_stmt=r"y = x * w",
+ cpp_stmt=r"auto y = x * w;",
+
+ setup=GroupedSetup(
+ py_setup="""
+ x = torch.ones((4, 4))
+ w = torch.ones((4, 4), requires_grad=True)
+ """,
+ cpp_setup="""
+            auto x = torch::ones({4, 4});
+            auto w = torch::ones({4, 4});
+ w.set_requires_grad(true);
+ """,
+ ),
+
+ signature="f(x, w) -> y",
+ torchscript=True,
+ autograd=True,
+)
+```
+
+It is trivial to generate Timers for the eager forward mode case (ignoring
+`num_threads` for now):
+
+```
+Timer(
+ stmt=benchmark.py_fwd_stmt,
+ setup=benchmark.setup.py_setup,
+)
+
+Timer(
+ stmt=benchmark.cpp_fwd_stmt,
+ setup=benchmark.setup.cpp_setup,
+ language="cpp",
+)
+```
+
+Moreover, because `signature` is provided, we know that creation of `x` and `w`
+is part of setup, and that the overall computation uses `x` and `w` to produce `y`.
+As a result, we can derive TorchScript'd and AutoGrad variants as well. We can
+deduce that a TorchScript model will take the form:
+
+```
+@torch.jit.script
+def f(x, w):
+ # Paste `benchmark.py_fwd_stmt` into the function body.
+ y = x * w
+ return y # Set by `-> y` in signature.
+```
+
+And because we will want to use this model in both Python and C++, we save it to
+disk and load it as needed. At this point Timers for TorchScript become:
+
+```
+Timer(
+ stmt="""
+ y = jit_model(x, w)
+ """,
+    setup="""
+ # benchmark.setup.py_setup
+ # jit_model = torch.jit.load(...)
+ # Warm up jit_model
+ """,
+)
+
+Timer(
+ stmt="""
+        std::vector<torch::jit::IValue> ivalue_inputs({
+            torch::jit::IValue(x),
+            torch::jit::IValue(w)
+        });
+ auto y = jit_model.forward(ivalue_inputs);
+ """,
+ setup="""
+        // benchmark.setup.cpp_setup
+        // jit_model = torch::jit::load(...)
+        // Warm up jit_model
+ """,
+)
+```
+
+While nothing above is particularly complex, there is non-trivial bookkeeping
+(managing the model artifact, setting up IValues) which, if done manually, would
+be rather bug-prone and hard to read.
+
+The story is similar for autograd: because we know the output variable (`y`)
+and we make sure to assign it when calling TorchScript models, testing AutoGrad
+is as simple as appending `y.backward()` (or `y.backward();` in C++) to the
+stmt of the forward only variant. Of course this requires that `signature` be
+provided, as there is nothing special about the name `y`.
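+
+For the `GroupedStmts` example above, the eager Python forward + backward stmt
+is therefore simply:
+
+```
+y = x * w
+y.backward()
+```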
+
+The logic for the manipulations above is split between `core/api.py` (for
+generating `stmt` based on language, Eager/TorchScript, with or without AutoGrad)
+and `core/expand.py` (for larger, more expansive generation). The benchmarks
+themselves are defined in `definitions/standard.py`. The current set is chosen
+to demonstrate the various model definition APIs, and will be expanded when the
+benchmark runner infrastructure is better equipped to deal with a larger run.
+
+### Benchmark execution.
+
+Once `expand.materialize` has flattened the abstract benchmark definitions into
+`TimerArgs`, they can be sent to a worker subprocess (`worker/main.py`) for
+execution. The worker has no concept of the larger benchmark suite; each
+`TimerArgs` maps directly, one-to-one, onto the `torch.utils.benchmark.Timer`
+instance that the worker instantiates.
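+
+Schematically, the worker does little more than the following (a simplified
+sketch of the logic in `worker/main.py`; constants, result packaging, and error
+handling are omitted):
+
+```
+timer = Timer(
+    stmt=timer_args.stmt,
+    setup=timer_args.setup,
+    global_setup=timer_args.global_setup,
+    num_threads=timer_args.num_threads,
+    language=timer_args.language,
+)
+wall_times = timer.blocked_autorange(min_run_time=MIN_RUN_TIME).times
+stats = timer.collect_callgrind(number=CALLGRIND_NUMBER, repeats=CALLGRIND_REPEATS)
+```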
diff --git a/benchmarks/instruction_counts/core/__init__.py b/benchmarks/instruction_counts/core/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/benchmarks/instruction_counts/core/__init__.py
diff --git a/benchmarks/instruction_counts/core/api.py b/benchmarks/instruction_counts/core/api.py
new file mode 100644
index 0000000000..b29c468aba
--- /dev/null
+++ b/benchmarks/instruction_counts/core/api.py
@@ -0,0 +1,402 @@
+"""Key enums and structs used to handle data flow within the benchmark."""
+import dataclasses
+import enum
+import re
+import textwrap
+from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING
+
+from worker.main import WorkerTimerArgs
+
+if TYPE_CHECKING:
+ # Benchmark utils are only partially strict compliant, so MyPy won't follow
+ # imports using the public namespace. (Due to an exclusion rule in
+ # mypy-strict.ini)
+ from torch.utils.benchmark.utils.timer import Language
+else:
+ from torch.utils.benchmark import Language
+
+
+# Note:
+#   WorkerTimerArgs is defined in worker.main so that the worker does not
+#   depend on any other benchmark files, including core.api. We mirror it with
+#   the public symbol `TimerArgs` for API consistency.
+TimerArgs = WorkerTimerArgs
+
+
+class RuntimeMode(enum.Enum):
+ EAGER = "Eager"
+ JIT = "TorchScript"
+ EXPLICIT = ""
+
+
+class AutogradMode(enum.Enum):
+ FORWARD = "Forward"
+ FORWARD_BACKWARD = "Forward + Backward"
+ EXPLICIT = ""
+
+
+@dataclasses.dataclass(frozen=True)
+class AutoLabels:
+ """Labels for a TimerArgs instance which are inferred during unpacking."""
+ runtime: RuntimeMode
+ autograd: AutogradMode
+ language: Language
+
+
+@dataclasses.dataclass(frozen=True)
+class GroupedSetup:
+ py_setup: str = ""
+ cpp_setup: str = ""
+ global_setup: str = ""
+
+ def __post_init__(self) -> None:
+ for field in dataclasses.fields(self):
+ assert field.type == str
+ value: str = getattr(self, field.name)
+ object.__setattr__(self, field.name, textwrap.dedent(value))
+
+
+@dataclasses.dataclass(frozen=True)
+class GroupedBenchmark:
+ """Base class for defining groups of benchmarks.
+
+ Concrete interfaces:
+ - `core.api.GroupedStmts` (init_from_stmts)
+ - `core.api.GroupedModules` (init_from_model)
+ - `core.api.GroupedVariants` (init_from_variants)
+
+ There are a variety of dimensions along which one might wish to measure
+ PyTorch performance:
+ - Python, C++
+ - Eager, TorchScript
+ - Single threaded, multi threaded
+ - Training, inference
+
+ It is useful to define them together, both for clear, concise benchmark
+ definition and more intelligent post processing and analysis.
+
+ There are also two programming idioms in PyTorch. One is to write free form
+ code (so-called "NumPy with gradients"), and the other is to organize code
+ using `torch.nn.Module`s. (This is how common neural network layers are
+ exposed through the PyTorch API.) To support easy definition two simple
+ initialization methods are provided:
+ - `init_from_stmts`
+ - `init_from_model`
+
+ Those methods will document their unique constructor arguments, however
+ most are shared and are defined here:
+ setup: Defines how to initialize a benchmark in both Python and C++.
+ signature:
+ A string of the form:
+ ```
+ f(a, b, ...) -> c
+ ```
+ For instance, if Python setup is:
+ ```
+ x = torch.ones((2,), requires_grad=True)
+ y = torch.ones((2,))
+ ```
+ and the corresponding stmt is:
+ ```
+ z = torch.dot(x, y)
+ ```
+ Then the signature is `f(x, y) -> z`. `signature` is required any
+ time we need to generate part of a snippet:
+              - When calling an opaque model provided by `init_from_model`
+ - When `torchscript=True`
+ - When `autograd=True`
+
+            If a return value is not needed (e.g. because of in-place mutation),
+            then `-> None` is valid, but a non-None return must be provided if
+            `autograd=True`.
+
+ torchscript:
+ If True, also JIT the stmt or model and generate benchmarks which
+ call the scripted version. Requires that `signature` is defined.
+
+ autograd:
+ If True, generate both forward and forward + backward benchmarks.
+ Requires that `signature` is defined, and return value is not None.
+
+ num_threads:
+ Maps to the Timer arg. If a tuple of ints is provided, benchmarks
+ will be generated for each value.
+
+ A third method, `init_from_variants`, is provided to define several related
+ benchmarks at once.
+ """
+
+ # These are the stmts which are actually executed by Timer. In the case of
+ # `GroupedStmts` (init_from_stmts) they are passed through from user args.
+ # In the case of `GroupedModules` (init_from_model) they are generated
+ # using `signature`. (e.g. `f(x, y) -> z` generates `z = model(x, y)`)
+ py_fwd_stmt: Optional[str]
+ cpp_fwd_stmt: Optional[str]
+
+ # Code block used to define a model. `init_from_stmts` will never populate
+ # `cpp_model_setup`, but if TorchScript is requested it will generate
+ # `py_model_setup` using `torch.jit.script`.
+ py_model_setup: Optional[str]
+ cpp_model_setup: Optional[str]
+
+ # True if this benchmark used `init_from_stmts`, otherwise False.
+ inferred_model_setup: bool
+
+ # Described above
+ setup: GroupedSetup
+ signature_args: Optional[Tuple[str, ...]]
+ signature_output: Optional[str]
+ torchscript: bool
+ autograd: bool
+ num_threads: Tuple[int, ...]
+
+ @classmethod
+ def init_from_stmts(
+ cls,
+ py_stmt: Optional[str] = None,
+ cpp_stmt: Optional[str] = None,
+
+ # Generic constructor arguments
+ setup: GroupedSetup = GroupedSetup(),
+ signature: Optional[str] = None,
+ torchscript: bool = False,
+ autograd: bool = False,
+ num_threads: Union[int, Tuple[int, ...]] = 1,
+ ) -> "GroupedBenchmark":
+ """Create a set of benchmarks from free-form statements.
+
+ This method of benchmark definition is analogous to Timer use, where
+ we simply execute the provided stmts.
+ """
+ if py_stmt is not None:
+ py_stmt = textwrap.dedent(py_stmt)
+
+ if cpp_stmt is not None:
+ cpp_stmt = textwrap.dedent(cpp_stmt)
+
+ signature_args, signature_output = cls._parse_signature(signature)
+ py_model_setup = (
+ cls._model_from_py_stmt(
+ py_stmt=py_stmt,
+ signature_args=signature_args,
+ signature_output=signature_output
+ ) if torchscript else None
+ )
+
+ return cls(
+ py_fwd_stmt=py_stmt,
+ cpp_fwd_stmt=cpp_stmt,
+ py_model_setup=py_model_setup,
+ cpp_model_setup=None,
+ inferred_model_setup=True,
+ setup=setup,
+ signature_args=signature_args,
+ signature_output=signature_output,
+ torchscript=torchscript,
+ autograd=autograd,
+ num_threads=(num_threads,) if isinstance(num_threads, int) else num_threads,
+ )
+
+ @classmethod
+ def init_from_model(
+ cls,
+ py_model_setup: Optional[str] = None,
+ cpp_model_setup: Optional[str] = None,
+
+ # Generic constructor arguments
+ setup: GroupedSetup = GroupedSetup(),
+ signature: Optional[str] = None,
+ torchscript: bool = False,
+ autograd: bool = False,
+ num_threads: Union[int, Tuple[int, ...]] = 1,
+ ) -> "GroupedBenchmark":
+ """Create a set of benchmarks using torch.nn Modules.
+
+ This method of benchmark creation takes setup code, and then calls
+ a model rather than a free form block of code. As a result, there are
+ two additional requirements compared to `init_from_stmts`:
+ - `signature` must be provided.
+ - A model (named "model") must be defined, either with `model = ...`
+ or `def model(...): ...` in Python or `auto model = ...` in C++.
+ """
+ signature_args, signature_output = cls._parse_signature(signature)
+ if signature_args is None:
+ raise ValueError("signature is needed when initializing from model definitions.")
+
+ return cls(
+ *cls._make_model_invocation(signature_args, signature_output, RuntimeMode.EAGER),
+ py_model_setup=py_model_setup,
+ cpp_model_setup=cpp_model_setup,
+ inferred_model_setup=False,
+ setup=setup,
+ signature_args=signature_args,
+ signature_output=signature_output,
+ torchscript=torchscript,
+ autograd=autograd,
+ num_threads=(num_threads,) if isinstance(num_threads, int) else num_threads,
+ )
+
+ @classmethod
+ def init_from_variants(
+ cls,
+ py_block: str = "",
+ cpp_block: str = "",
+ num_threads: Union[int, Tuple[int, ...]] = 1,
+ ) -> Dict[Union[Tuple[str, ...], Optional[str]], "GroupedBenchmark"]:
+
+ py_cases, py_setup, py_global_setup = cls._parse_variants(py_block, Language.PYTHON)
+ cpp_cases, cpp_setup, cpp_global_setup = cls._parse_variants(cpp_block, Language.CPP)
+
+ assert not py_global_setup
+ setup = GroupedSetup(
+ py_setup=py_setup,
+ cpp_setup=cpp_setup,
+ global_setup=cpp_global_setup,
+ )
+
+ # NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused
+        #     and we use the superset `Union[Tuple[str, ...], Optional[str]]` to
+ # match the expected signature.
+ variants: Dict[Union[Tuple[str, ...], Optional[str]], GroupedBenchmark] = {}
+ for label in set(list(py_cases.keys()) + list(cpp_cases.keys())):
+ py_lines = py_cases.get(label, [])
+ cpp_lines = cpp_cases.get(label, [])
+
+ n_lines = max(len(py_lines), len(cpp_lines))
+ py_lines += [""] * (n_lines - len(py_lines))
+ cpp_lines += [""] * (n_lines - len(cpp_lines))
+ lines = [
+ (py_stmt, cpp_stmt)
+ for py_stmt, cpp_stmt in zip(py_lines, cpp_lines)
+ if py_stmt or cpp_stmt
+ ]
+
+ for i, (py_stmt, cpp_stmt) in enumerate(lines):
+ variants[(label, f"Case: {i:>2}")] = GroupedBenchmark.init_from_stmts(
+ py_stmt=py_stmt or None,
+ cpp_stmt=cpp_stmt or None,
+ setup=setup,
+ num_threads=num_threads,
+ )
+
+ return variants
+
+ def __post_init__(self) -> None:
+ if self.autograd and self.signature_output is None:
+ raise ValueError("An output variable must be specified when `autograd=True`.")
+
+ if self.py_model_setup and "model" not in self.py_model_setup:
+ raise ValueError("`py_model_setup` appears to be missing `model` definition.")
+
+ if self.cpp_model_setup and "model" not in self.cpp_model_setup:
+ raise ValueError("`cpp_model_setup` appears to be missing `model` definition.")
+
+ # =========================================================================
+ # == String manipulation methods ==========================================
+ # =========================================================================
+
+ @staticmethod
+ def _parse_signature(
+ signature: Optional[str]
+ ) -> Tuple[Optional[Tuple[str, ...]], Optional[str]]:
+ if signature is None:
+ return None, None
+
+ match = re.search(r"^f\((.*)\) -> (.*)$", signature)
+ if match is None:
+ raise ValueError(f"Invalid signature: `{signature}`")
+
+ args: Tuple[str, ...] = tuple(match.groups()[0].split(", "))
+ output: str = match.groups()[1].strip()
+
+ if "," in output:
+ raise ValueError(f"Multiple return values are not currently allowed: `{output}`")
+
+ if output == "None":
+ return args, None
+
+ return args, output
+
+ @staticmethod
+ def _model_from_py_stmt(
+ py_stmt: Optional[str],
+ signature_args: Optional[Tuple[str, ...]],
+ signature_output: Optional[str],
+ ) -> str:
+ if py_stmt is None:
+ raise ValueError("`py_stmt` must be defined in order to derive a model.")
+
+ if signature_args is None:
+ raise ValueError("signature is needed in order to derive a model.")
+
+ return textwrap.dedent(f"""\
+ def model({', '.join(signature_args)}):
+            {{stmt_str}}
+ return {signature_output}
+ """).format(stmt_str=textwrap.indent(py_stmt, ' ' * 4))
+
+ @staticmethod
+ def _make_model_invocation(
+ signature_args: Tuple[str, ...],
+ signature_output: Optional[str],
+ runtime: RuntimeMode,
+ ) -> Tuple[str, str]:
+ py_prefix, cpp_prefix = "", ""
+ if signature_output is not None:
+ py_prefix = f"{signature_output} = "
+ cpp_prefix = f"auto {signature_output} = "
+
+ if runtime == RuntimeMode.EAGER:
+ model_name = "model"
+ cpp_invocation = f"{cpp_prefix}{model_name}->forward({', '.join(signature_args)});"
+
+ else:
+ assert runtime == RuntimeMode.JIT
+ model_name = "jit_model"
+ cpp_invocation = textwrap.dedent(f"""\
+ std::vector<torch::jit::IValue> ivalue_inputs({{
+ {', '.join([f'torch::jit::IValue({a})' for a in signature_args])}
+ }});
+ {cpp_prefix}{model_name}.forward(ivalue_inputs);
+ """)
+
+ # NB:
+        #   In Python we invoke `__call__`; C++ doesn't have an analogous
+        #   method, so we invoke `forward` instead. This means that Python
+        #   is doing extra work (e.g. checking hooks) compared to C++; however,
+        #   because this is the default user experience, that's acceptable.
+ py_invocation = f"{py_prefix}{model_name}({', '.join(signature_args)})"
+
+ return py_invocation, cpp_invocation
+
+ @staticmethod
+ def _parse_variants(block: str, language: Language) -> Tuple[Dict[str, List[str]], str, str]:
+ block = textwrap.dedent(block).strip()
+ comment = "#" if language == Language.PYTHON else "//"
+ label_pattern = f"{comment} @(.+)$"
+ label = ""
+
+ lines_by_label: Dict[str, List[str]] = {"SETUP": [], "GLOBAL_SETUP": []}
+ for line in block.splitlines(keepends=False):
+ match = re.search(label_pattern, line.strip())
+ if match:
+ label = match.groups()[0]
+ if label.replace(" ", "_").upper() in ("SETUP", "GLOBAL_SETUP"):
+ label = label.replace(" ", "_").upper()
+ continue
+
+ lines_by_label.setdefault(label, [])
+ if line.startswith(comment):
+ line = ""
+ lines_by_label[label].append(line)
+
+ setup = "\n".join(lines_by_label.pop("SETUP"))
+ global_setup = "\n".join(lines_by_label.pop("GLOBAL_SETUP"))
+
+ return lines_by_label, setup, global_setup
+
+
+# These are the user facing APIs.
+GroupedStmts = GroupedBenchmark.init_from_stmts
+GroupedModules = GroupedBenchmark.init_from_model
+GroupedVariants = GroupedBenchmark.init_from_variants
diff --git a/benchmarks/instruction_counts/core/expand.py b/benchmarks/instruction_counts/core/expand.py
new file mode 100644
index 0000000000..2736a4f805
--- /dev/null
+++ b/benchmarks/instruction_counts/core/expand.py
@@ -0,0 +1,260 @@
+"""Logic for converting human-readable benchmarks into executable form.
+
+This is mostly string manipulation, with just a bit of importlib magic.
+"""
+import importlib.abc
+import importlib.util
+import itertools as it
+import os
+import re
+import textwrap
+from typing import cast, List, Optional, Tuple, TYPE_CHECKING
+import uuid
+
+import torch
+
+if TYPE_CHECKING:
+ # See the note in api.py for why this is necessary.
+ from torch.utils.benchmark.utils.timer import Language
+else:
+ from torch.utils.benchmark import Language
+
+from core.api import AutogradMode, AutoLabels, GroupedBenchmark, RuntimeMode, TimerArgs
+from core.types import FlatDefinition, FlatIntermediateDefinition, Label
+from core.utils import get_temp_dir
+
+
+_ALL_MODES = tuple(it.product(
+ RuntimeMode,
+ AutogradMode,
+ Language,
+))
+
+
+def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]:
+    """Returns the path to a saved model if one can be constructed from `model_src`.
+
+ Because TorchScript requires actual source code in order to script a
+ model, we can't simply `eval` an appropriate model string. Instead, we
+ must write the correct source to a temporary Python file and then import
+ the TorchScript model from that temporary file.
+
+ `model_src` must contain `jit_model = ...`, which `materialize` will supply.
+ """
+ # Double check.
+ assert "jit_model = " in model_src, f"Missing jit_model definition:\n{model_src}"
+
+ # `torch.utils.benchmark.Timer` will automatically import torch, so we
+ # need to match that convention.
+ model_src = f"import torch\n{model_src}"
+
+ model_root = os.path.join(get_temp_dir(), "TorchScript_models")
+ os.makedirs(model_root, exist_ok=True)
+ module_path = os.path.join(model_root, f"torchscript_{name}.py")
+ artifact_path = os.path.join(model_root, f"torchscript_{name}.pt")
+
+ if os.path.exists(module_path):
+ # The uuid in `name` should protect against this, but it doesn't hurt
+ # to confirm.
+ raise ValueError(f"File {module_path} already exists.")
+
+ with open(module_path, "wt") as f:
+ f.write(model_src)
+
+ # Import magic to actually load our function.
+ module_spec = importlib.util.spec_from_file_location(f"torchscript__{name}", module_path)
+ module = importlib.util.module_from_spec(module_spec)
+ loader = module_spec.loader
+ assert loader is not None
+
+ # Module.loader has type Optional[_Loader]. Even when we assert loader is
+ # not None and MyPy narrows it to type _Loader, it will not pass type
+ # checks. So we have to use a cast to tell MyPy that _Loader implements
+ # importlib.abc.Loader.
+ cast(importlib.abc.Loader, loader).exec_module(module)
+
+ # And again, the type checker has no way of knowing that this line is valid.
+ jit_model = module.jit_model # type: ignore
+ assert isinstance(
+ jit_model,
+ (torch.jit.ScriptFunction, torch.jit.ScriptModule)
+ ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}"
+ jit_model.save(artifact_path)
+
+ # Cleanup now that we have the actual serialized model.
+ os.remove(module_path)
+ return artifact_path
+
+
+def _get_stmt(
+ benchmark: GroupedBenchmark,
+ runtime: RuntimeMode,
+ autograd: AutogradMode,
+ language: Language,
+) -> Optional[str]:
+ """Specialize a GroupedBenchmark for a particular configuration."""
+ is_python = (language == Language.PYTHON)
+
+ # During GroupedBenchmark construction, py_fwd_stmt and cpp_fwd_stmt are
+ # set to the eager invocation. So in the RuntimeMode.EAGER case we can
+ # simply reuse them. For the RuntimeMode.JIT case, we need to generate
+ # an appropriate `jit_model(...)` invocation.
+ if runtime == RuntimeMode.EAGER:
+ stmts = (benchmark.py_fwd_stmt, benchmark.cpp_fwd_stmt)
+
+ else:
+ assert runtime == RuntimeMode.JIT
+ assert benchmark.signature_args is not None
+ stmts = GroupedBenchmark._make_model_invocation(
+ benchmark.signature_args, benchmark.signature_output, RuntimeMode.JIT)
+
+ stmt = stmts[0 if is_python else 1]
+
+ if autograd == AutogradMode.FORWARD_BACKWARD and stmt is not None:
+ assert benchmark.signature_output is not None
+ backward = (
+ f"{benchmark.signature_output}"
+
+ # In C++ we have to get the Tensor out of the IValue to call `.backward()`
+ f"{'.toTensor()' if runtime == RuntimeMode.JIT and language == Language.CPP else ''}"
+ f".backward(){';' if language == Language.CPP else ''}"
+ )
+ stmt = f"{stmt}\n{backward}"
+ return stmt
+
+
+def _get_setup(
+ benchmark: GroupedBenchmark,
+ runtime: RuntimeMode,
+ language: Language,
+ stmt: str,
+ model_path: Optional[str]
+) -> str:
+ """Specialize a GroupedBenchmark for a particular configuration.
+
+ Setup requires two extra pieces of information:
+ 1) The benchmark stmt. This is needed to warm up the model and avoid
+ measuring lazy initialization.
+ 2) The model path so we can load it during the benchmark.
+
+ These are only used when `runtime == RuntimeMode.JIT`.
+ """
+
+ # By the time we get here, details about how to set up a model have already
+ # been determined by GroupedBenchmark. (Or set to None if appropriate.) We
+ # simply need to collect and package the code blocks.
+ if language == Language.PYTHON:
+ setup = benchmark.setup.py_setup
+ model_setup = benchmark.py_model_setup
+ else:
+ assert language == Language.CPP
+ setup = benchmark.setup.cpp_setup
+ model_setup = benchmark.cpp_model_setup
+
+ if runtime == RuntimeMode.EAGER:
+ return "\n".join([setup, model_setup or ""])
+
+ assert runtime == RuntimeMode.JIT
+ assert model_path is not None
+
+ # We template `"{model_path}"`, so quotes would break model loading. The
+ # model path is generated within the benchmark, so this is just an
+ # abundance of caution rather than something that is expected in practice.
+ assert '"' not in model_path
+
+    # `stmt` may contain newlines, so we can't interpolate it into an f-string
+    # directly. Instead we dedent a template first and then substitute `stmt`
+    # so that the final indentation works out properly.
+ if language == Language.PYTHON:
+ setup_template: str = textwrap.dedent(f"""
+ jit_model = torch.jit.load("{model_path}")
+
+ # Warmup `jit_model`
+ for _ in range(3):
+            {{stmt}}
+ """)
+
+ else:
+ assert language == Language.CPP
+ setup_template = textwrap.dedent(f"""
+ const std::string fpath = "{model_path}";
+ auto jit_model = torch::jit::load(fpath);
+
+ // Warmup `jit_model`
+ for (int i = 0; i < 3; i++) {{{{
+            {{stmt}}
+ }}}}
+ """)
+
+ model_load = setup_template.format(stmt=textwrap.indent(stmt, ' ' * 4))
+ return "\n".join([setup, model_load])
+
+
+def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition:
+ """Convert a heterogeneous benchmark into an executable state.
+
+ This entails generation of TorchScript model artifacts, splitting
+ GroupedBenchmarks into multiple TimerArgs, and tagging the results with
+ AutoLabels.
+ """
+ results: List[Tuple[Label, AutoLabels, TimerArgs]] = []
+
+ for label, args in benchmarks.items():
+ if isinstance(args, TimerArgs):
+ # User provided an explicit TimerArgs, so no processing is necessary.
+ auto_labels = AutoLabels(
+ RuntimeMode.EXPLICIT,
+ AutogradMode.EXPLICIT,
+ args.language
+ )
+ results.append((label, auto_labels, args))
+
+ else:
+ assert isinstance(args, GroupedBenchmark)
+
+ model_path: Optional[str] = None
+ if args.py_model_setup and args.torchscript:
+ model_setup = f"{args.py_model_setup}\njit_model = torch.jit.script(model)"
+
+ # This is just for debugging. We just need a unique name for the
+ # model, but embedding the label makes debugging easier.
+ name: str = re.sub(r'[^a-z0-9_]', '_', '_'.join(label).lower())
+ name = f"{name}_{uuid.uuid4()}"
+
+ model_path = _generate_torchscript_file(model_setup, name=name)
+
+ for (runtime, autograd, language), num_threads in it.product(_ALL_MODES, args.num_threads):
+ if runtime == RuntimeMode.EXPLICIT or autograd == AutogradMode.EXPLICIT:
+ continue
+
+ if runtime == RuntimeMode.JIT and not args.torchscript:
+ continue
+
+ if autograd == AutogradMode.FORWARD_BACKWARD and not args.autograd:
+ continue
+
+ stmt = _get_stmt(args, runtime, autograd, language)
+ if stmt is None:
+ continue
+
+ setup = _get_setup(args, runtime, language, stmt, model_path)
+
+ global_setup: str = ""
+ if language == Language.CPP and runtime == RuntimeMode.JIT:
+ global_setup = textwrap.dedent("""
+ #include <string>
+ #include <vector>
+ #include <torch/script.h>
+ """)
+
+ autolabels = AutoLabels(runtime, autograd, language)
+ timer_args = TimerArgs(
+ stmt=stmt,
+ setup=setup,
+ global_setup=global_setup,
+ num_threads=num_threads,
+ language=language,
+ )
+
+ results.append((label, autolabels, timer_args))
+
+ return tuple(results)
diff --git a/benchmarks/instruction_counts/core/types.py b/benchmarks/instruction_counts/core/types.py
new file mode 100644
index 0000000000..8f36bfbcec
--- /dev/null
+++ b/benchmarks/instruction_counts/core/types.py
@@ -0,0 +1,94 @@
+"""Type annotations for various benchmark objects."""
+from typing import Any, Dict, Optional, Tuple, Union
+
+from core.api import AutoLabels, TimerArgs, GroupedBenchmark
+
+
+# =============================================================================
+# == Benchmark schema =========================================================
+# =============================================================================
+""" (There is a TL;DR at the end for ad-hoc benchmarks.)
+The end state for representing a benchmark is:
+ ```
+ Tuple[
+ Tuple[
+ Tuple[str, ...], # Primary key
+ core.api.AutoLabels, # Secondary key
+ core.api.TimerArgs, # Value
+ ],
+ ...
+ ]
+ ```
+
+For example:
+ ```
+ [
+ (("pointwise", "add"), AutoLabels(..., Language.PYTHON), TimerArgs(...)),
+ (("pointwise", "add"), AutoLabels(..., Language.CPP), TimerArgs(...)),
+ ...
+ ]
+ ```
+
+However, such a flat list is somewhat tedious to maintain (and read), because
+there is significant duplication in the key structure. So instead, we would
+like to define something like:
+ ```
+ {
+ "pointwise" : {
+ "add": {
+ None: GroupedStmts(...),
+ "with alpha": GroupedStmts(...),
+ },
+ "mul": GroupedStmts(...),
+ },
+ "matmul": GroupedStmts(...),
+ }
+ ```
+and then parse out a flat representation. The type declarations below are
+simply formalizing the structure of nested dictionaries with string or tuple
+of string keys.
+
+TL;DR
+ If you only care about writing an ad-hoc benchmark for a PR, just use a
+ flat dictionary and everything will work. For example:
+ ```
+ {
+ "case 0": TimerArgs(...),
+ "case 1": TimerArgs(...),
+ "case 2": GroupedStmts(...),
+ ...
+ }
+ ```
+"""
+
+# Allow strings in definition for convenience, and None to signify a base
+# case. (No subsequent entry needed. See the "add" example above.)
+Label = Tuple[str, ...]
+_Label = Union[Label, Optional[str]]
+
+# MyPy does not currently support recursive types:
+# https://github.com/python/mypy/issues/731
+#
+# So while the correct type definition would be:
+# _Value = Union[
+# # Base case:
+# Union[TimerArgs, GroupedBenchmark],
+#
+# # Recursive case:
+# Dict[Label, "_Value"],
+# ]
+# we instead have to use Any and rely on runtime asserts when flattening.
+_Value = Union[
+ Union[TimerArgs, GroupedBenchmark],
+ Dict[_Label, Any],
+]
+
+Definition = Dict[_Label, _Value]
+
+# We initially have to parse (flatten) to an intermediate state in order to
+# build TorchScript models since multiple entries will share the same model
+# artifact.
+FlatIntermediateDefinition = Dict[Label, Union[TimerArgs, GroupedBenchmark]]
+
+# Final parsed schema.
+FlatDefinition = Tuple[Tuple[Label, AutoLabels, TimerArgs], ...]
diff --git a/benchmarks/instruction_counts/core/utils.py b/benchmarks/instruction_counts/core/utils.py
new file mode 100644
index 0000000000..1da98c80df
--- /dev/null
+++ b/benchmarks/instruction_counts/core/utils.py
@@ -0,0 +1,99 @@
+import atexit
+import shutil
+import re
+import tempfile
+import textwrap
+from typing import List, Optional, Tuple
+
+from core.api import GroupedBenchmark, TimerArgs
+from core.types import Definition, FlatIntermediateDefinition, Label
+
+
+_TEMPDIR: Optional[str] = None
+def get_temp_dir() -> str:
+ global _TEMPDIR
+ if _TEMPDIR is None:
+ temp_dir = tempfile.mkdtemp()
+ atexit.register(shutil.rmtree, path=temp_dir)
+ _TEMPDIR = temp_dir
+ return _TEMPDIR
+
+
+def _flatten(
+ key_prefix: Label,
+ sub_schema: Definition,
+ result: FlatIntermediateDefinition
+) -> None:
+ for k, value in sub_schema.items():
+ if isinstance(k, tuple):
+ assert all(isinstance(ki, str) for ki in k)
+ key_suffix: Label = k
+ elif k is None:
+ key_suffix = ()
+ else:
+ assert isinstance(k, str)
+ key_suffix = (k,)
+
+ key: Label = key_prefix + key_suffix
+ if isinstance(value, (TimerArgs, GroupedBenchmark)):
+ assert key not in result, f"duplicate key: {key}"
+ result[key] = value
+ else:
+ assert isinstance(value, dict)
+ _flatten(key_prefix=key, sub_schema=value, result=result)
+
+
+def flatten(schema: Definition) -> FlatIntermediateDefinition:
+ """See types.py for an explanation of nested vs. flat definitions."""
+ result: FlatIntermediateDefinition = {}
+ _flatten(key_prefix=(), sub_schema=schema, result=result)
+
+ # Ensure that we produced a valid flat definition.
+ for k, v in result.items():
+ assert isinstance(k, tuple)
+ assert all(isinstance(ki, str) for ki in k)
+ assert isinstance(v, (TimerArgs, GroupedBenchmark))
+ return result
+
+
+def parse_stmts(stmts: str) -> Tuple[str, str]:
+ """Helper function for side-by-side Python and C++ stmts.
+
+ For more complex statements, it can be useful to see Python and C++ code
+ side by side. To this end, we provide an **extremely restricted** way
+    to define Python and C++ code side-by-side. The schema is mostly
+    self-explanatory, with the following non-obvious caveats (an example is
+    shown below):
+ - Width for the left (Python) column MUST be 40 characters.
+ - The column separator is " | ", not "|". Whitespace matters.
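+
+    For instance (a hypothetical input; the real definitions in
+    `definitions/setup.py` and `definitions/standard.py` follow the same layout):
+
+        Python                                   | C++
+        ---------------------------------------- | ----------------------------------------
+        x = torch.ones((4, 4))                   | auto x = torch::ones({4, 4});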
+ """
+ stmts = textwrap.dedent(stmts).strip()
+ lines: List[str] = stmts.splitlines(keepends=False)
+ assert len(lines) >= 3, f"Invalid string:\n{stmts}"
+
+ column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
+ signature_pattern = r"^: f\((.*)\)( -> (.+))?\s*$"
+    separation_pattern = r"^[-]{40} \| [-]{40}$"
+ code_pattern = r"^(.{40}) \|($| (.*)$)"
+
+ column_match = re.search(column_header_pattern, lines[0])
+ if column_match is None:
+ raise ValueError(
+ f"Column header `{lines[0]}` "
+ f"does not match pattern `{column_header_pattern}`")
+
+ assert re.search(separation_pattern, lines[1])
+
+ py_lines: List[str] = []
+ cpp_lines: List[str] = []
+ for l in lines[2:]:
+ l_match = re.search(code_pattern, l)
+ if l_match is None:
+ raise ValueError(f"Invalid line `{l}`")
+ py_lines.append(l_match.groups()[0])
+ cpp_lines.append(l_match.groups()[2] or "")
+
+ # Make sure we can round trip for correctness.
+ l_from_stmts = f"{py_lines[-1]:<40} | {cpp_lines[-1]:<40}".rstrip()
+ assert l_from_stmts == l.rstrip(), f"Failed to round trip `{l}`"
+
+ return "\n".join(py_lines), "\n".join(cpp_lines)
diff --git a/benchmarks/instruction_counts/definitions/__init__.py b/benchmarks/instruction_counts/definitions/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/benchmarks/instruction_counts/definitions/__init__.py
diff --git a/benchmarks/instruction_counts/definitions/setup.py b/benchmarks/instruction_counts/definitions/setup.py
new file mode 100644
index 0000000000..6bd044702b
--- /dev/null
+++ b/benchmarks/instruction_counts/definitions/setup.py
@@ -0,0 +1,57 @@
+"""Define some common setup blocks which benchmarks can reuse."""
+
+import enum
+
+from core.api import GroupedSetup
+from core.utils import parse_stmts
+
+
+_TRIVIAL_2D = GroupedSetup(
+ r"x = torch.ones((4, 4))",
+ r"auto x = torch::ones({4, 4});"
+)
+
+
+_TRIVIAL_4D = GroupedSetup(
+ r"x = torch.ones((4, 4, 4, 4))",
+ r"auto x = torch::ones({4, 4, 4, 4});"
+)
+
+
+_GENERIC = GroupedSetup(*parse_stmts(
+ r"""
+ Python | C++
+ ---------------------------------------- | ----------------------------------------
+ torch.manual_seed(138_10_23) | torch::manual_seed(1381023);
+ x = torch.rand((4, 4)) | auto x = torch::rand({4, 4});
+ y_float = torch.ones((4, 4)) | auto y_float = torch::ones({4, 4});
+ y_int = torch.ones( | auto y_int = torch::ones({4, 4}, at::kInt);
+ (4, 4), dtype=torch.int32) |
+ """
+))
+
+
+_TRAINING = GroupedSetup(*parse_stmts(
+ r"""
+ Python | C++
+ ---------------------------------------- | ----------------------------------------
+ # Inputs | // Inputs
+ x = torch.ones((1,)) | auto x = torch::ones({1});
+ y = torch.ones((1,)) | auto y = torch::ones({1});
+ |
+ # Weights | // Weights
+ w0 = torch.ones( | auto w0 = torch::ones({1});
+ (1,), requires_grad=True) | w0.set_requires_grad(true);
+ w1 = torch.ones( | auto w1 = torch::ones({1});
+ (1,), requires_grad=True) | w1.set_requires_grad(true);
+ w2 = torch.ones( | auto w2 = torch::ones({2});
+ (2,), requires_grad=True) | w2.set_requires_grad(true);
+ """
+))
+
+
+class Setup(enum.Enum):
+ TRIVIAL_2D = _TRIVIAL_2D
+ TRIVIAL_4D = _TRIVIAL_4D
+ GENERIC = _GENERIC
+ TRAINING = _TRAINING
diff --git a/benchmarks/instruction_counts/definitions/standard.py b/benchmarks/instruction_counts/definitions/standard.py
new file mode 100644
index 0000000000..cf4583cd00
--- /dev/null
+++ b/benchmarks/instruction_counts/definitions/standard.py
@@ -0,0 +1,143 @@
+"""Default set of benchmarks.
+
+Parser notes:
+ `parse_stmts`:
+ - Width for the left (Python) column MUST be 40 characters.
+ - The column separator is " | ", not "|". Whitespace matters.
+
+ `GroupedVariants`:
+ - `Setup` and `Global_Setup` (case insensitive) are reserved keywords
+ to populate `setup` and `global_setup` for every generated benchmark.
+ - To set a label for the succeeding block, add `# @YOUR_LABEL` (Python)
+ or `// @YOUR_LABEL` (C++).
+"""
+
+from core.api import GroupedModules, GroupedStmts, GroupedVariants
+from core.types import FlatIntermediateDefinition
+from core.utils import flatten, parse_stmts
+from definitions.setup import Setup
+
+
+BENCHMARKS: FlatIntermediateDefinition = flatten({
+ "Empty": {
+ "no allocation": GroupedStmts(
+ r"torch.empty(())",
+ r"torch::empty({0});",
+ ),
+
+ "with allocation": GroupedStmts(
+ r"torch.empty((1,))",
+ r"torch::empty({1});",
+ ),
+
+ "overloads": GroupedVariants(
+ cpp_block=r"""
+ // @Setup
+ auto options_empty = c10::TensorOptions();
+ auto options_full = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU);
+ auto optional_float = c10::make_optional(at::kFloat);
+
+ // @TensorOptions overload
+ at::empty({0}, options_empty);
+ at::empty({0}, options_full);
+ at::empty({0}, at::kFloat); // implicit conversion
+
+ // @Faithful overload
+ at::empty({0}, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ at::empty({0}, at::kFloat, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ at::empty({0}, optional_float, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt);
+ """
+ ),
+ },
+
+ "Pointwise": {
+ "Math": {
+ "add": {
+ "Tensor-Scalar": GroupedStmts(
+ r"x += 1.0",
+ r"x += 1.0;",
+ setup=Setup.GENERIC.value,
+ ),
+ },
+ },
+ },
+
+ "Indexing": GroupedVariants(*parse_stmts(r"""
+ Python | C++
+ ---------------------------------------- | ----------------------------------------
+ # @setup | // @setup
+ | using namespace torch::indexing;
+ torch.manual_seed(6626_10_34) | torch::manual_seed(66261034);
+ |
+ x = torch.randn(1, 1, 1) | auto x = torch::randn({1, 1, 1});
+ y = torch.randn(1, 1, 1) | auto y = torch::randn({1, 1, 1});
+ |
+ # @Tensor-Scalar | // @Tensor-Scalar
+ x[0] = 1 | x.index_put_({0}, 1);
+ x[0, 0] = 1 | x.index_put_({0, 0}, 1);
+ x[0, 0, 0] = 1 | x.index_put_({0, 0, 0}, 1);
+ |
+ # @Tensor-Scalar (Advanced) | // @Tensor-Scalar (Advanced)
+ x[...] = 1 | x.index_put_({"..."}, 1);
+ x[:] = 1 | x.index_put_({Slice(None, None, None)}, 1);
+ x[None] = 1 | x.index_put_({None}, 1);
+ x[False] = 1 | x.index_put_({false}, 1);
+ x[True] = 1 | x.index_put_({true}, 1);
+ |
+ # @Tensor-Tensor | // @Tensor-Tensor
+ x[0] = y[0] | x.index_put_({0}, y.index({0}));
+ x[0, 0] = y[0, 0] | x.index_put_({0, 0}, y.index({0, 0}));
+ x[0, 0, 0] = y[0, 0, 0] | x.index_put_({0, 0, 0}, y.index({0, 0, 0}));
+ |
+ # @Tensor-Tensor (Advanced) | // @Tensor-Tensor (Advanced)
+ x[...] = y[...] | x.index_put_({"..."}, y.index({"..."}));
+ x[:] = y[:] | x.index_put_({Slice(None, None, None)}, y.index({Slice(None, None, None)}));
+ x[None] = y[None] | x.index_put_({None}, y.index({None}));
+ x[False] = y[False] | x.index_put_({false}, y.index({false}));
+ x[True] = y[True] | x.index_put_({true}, y.index({true}));
+ """)),
+
+ "nn Modules": {
+ "Linear": GroupedModules(
+ "model = torch.nn.Linear(4, 2)",
+ "auto model = torch::nn::Linear(4, 2);",
+ setup=Setup.TRIVIAL_4D.value,
+ signature="f(x) -> y",
+ torchscript=True,
+ ),
+ },
+
+ "training": {
+ "simple": GroupedStmts(
+ *parse_stmts(r"""
+ Python | C++
+ ---------------------------------------- | ----------------------------------------
+ a0 = torch.nn.functional.relu(x * w0) | auto a0 = torch::nn::functional::relu(x * w0);
+ y = a0 * w1 | auto y = a0 * w1;
+ """),
+ Setup.TRAINING.value,
+ num_threads=(1, 2),
+ signature=r"f(x, w0, w1) -> y",
+ torchscript=True,
+ autograd=True,
+ ),
+
+ "ensemble": GroupedStmts(
+ *parse_stmts(r"""
+ Python | C++
+ ---------------------------------------- | ----------------------------------------
+ a0 = torch.nn.functional.gelu(x * w0) | auto a0 = torch::nn::functional::gelu(x * w0);
+ a1 = torch.nn.functional.prelu(y, w1) | auto a1 = torch::nn::functional::prelu(y, w1);
+ z = torch.nn.functional.normalize( | auto z = torch::nn::functional::normalize(
+ torch.cat([a0, a1]), | torch::cat({a0, a1}),
+ p=2.0, dim=0, | torch::nn::functional::NormalizeFuncOptions().p(2).dim(0)
+ ).dot(w2) | ).dot(w2);
+ """),
+ Setup.TRAINING.value,
+ num_threads=(1, 2),
+ signature=r"f(x, y, w0, w1, w2) -> z",
+ torchscript=True,
+ autograd=True,
+ ),
+ },
+})
diff --git a/benchmarks/instruction_counts/main.py b/benchmarks/instruction_counts/main.py
new file mode 100644
index 0000000000..0f8773a9cf
--- /dev/null
+++ b/benchmarks/instruction_counts/main.py
@@ -0,0 +1,57 @@
+"""Basic runner for the instruction count microbenchmarks.
+
+The contents of this file are placeholders, and will be replaced by more
+expressive and robust components (e.g. better runner and result display
+components) in future iterations. However, this allows us to exercise the
+underlying benchmark generation infrastructure in the meantime.
+"""
+import multiprocessing
+import multiprocessing.dummy
+import os
+import pickle
+import subprocess
+from typing import Tuple
+
+from core.api import AutoLabels, TimerArgs
+from core.expand import materialize
+from core.types import Label
+from core.utils import get_temp_dir
+from definitions.standard import BENCHMARKS
+from worker.main import WORKER_PATH, WorkerFailure, WorkerOutput, WorkerTimerArgs, WorkerUnpickler
+
+
+def call_worker(
+ args: Tuple[int, Tuple[Label, AutoLabels, TimerArgs]]
+) -> Tuple[Label, AutoLabels, int, WorkerOutput]:
+ worker_id, (label, autolabels, timer_args) = args
+
+ communication_file = os.path.join(get_temp_dir(), f"communication_file_{worker_id}.pkl")
+ with open(communication_file, "wb") as f:
+ pickle.dump(timer_args, f)
+
+ subprocess.call(
+ ["python", WORKER_PATH, "--communication_file", communication_file],
+ shell=False,
+ )
+
+ with open(communication_file, "rb") as f:
+ result = WorkerUnpickler(f).load_output()
+
+ if isinstance(result, WorkerTimerArgs):
+ raise RuntimeError("Benchmark worker failed without starting.")
+
+ elif isinstance(result, WorkerFailure):
+ raise RuntimeError(f"Worker failed: {label} {autolabels}\n{result.failure_trace}")
+
+ assert isinstance(result, WorkerOutput)
+ return label, autolabels, timer_args.num_threads, result
+
+
+def main() -> None:
+    with multiprocessing.dummy.Pool(max(multiprocessing.cpu_count() - 4, 1)) as pool:
+ for label, autolabels, num_threads, result in pool.imap(call_worker, enumerate(materialize(BENCHMARKS)), 1):
+ print(label, autolabels, num_threads, result.instructions)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks/instruction_counts/worker/__init__.py b/benchmarks/instruction_counts/worker/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/benchmarks/instruction_counts/worker/__init__.py
diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py
new file mode 100644
index 0000000000..f59509de74
--- /dev/null
+++ b/benchmarks/instruction_counts/worker/main.py
@@ -0,0 +1,188 @@
+"""File invoked through subprocess to actually carry out measurements.
+
+`worker/main.py` is deliberately isolated from the rest of the benchmark
+infrastructure. Other parts of the benchmark rely on this file, but
+`worker/` has only one Python file and does not import ANYTHING from the rest
+of the benchmark suite. This is important because the worker can't rely on
+paths to access the other files (namely `core.api`), since a source command
+might change the CWD. It also helps keep startup time down by limiting
+spurious definition work.
+
+The life of a worker is very simple:
+ It receives a file containing a `WorkerTimerArgs` telling it what to run,
+ and writes a `WorkerOutput` result back to the same file.
+
+Because this file only expects to run in a child context, error handling means
+plumbing failures up to the caller, not raising in this process.
+"""
+import argparse
+import dataclasses
+import io
+import os
+import pickle
+import timeit
+import traceback
+from typing import Any, Tuple, Union, TYPE_CHECKING
+import sys
+
+
+if TYPE_CHECKING:
+ # Benchmark utils are only partially strict compliant, so MyPy won't follow
+ # imports using the public namespace. (Due to an exclusion rule in
+ # mypy-strict.ini)
+ from torch.utils.benchmark.utils.timer import Language, Timer
+ from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import CallgrindStats
+
+else:
+ from torch.utils.benchmark import CallgrindStats, Language, Timer
+
+
+WORKER_PATH = os.path.abspath(__file__)
+
+
+# =============================================================================
+# == Interface ================================================================
+# =============================================================================
+
+# While the point of this is mainly to collect instruction counts, we're going
+# to have to compile C++ timers anyway (as they're used as a check before
+# calling Valgrind), so we may as well grab wall times for reference. They
+# are comparatively inexpensive.
+MIN_RUN_TIME = 5
+
+# Repeats are inexpensive as long as they are all run in the same process. This
+# also lets us filter outliers (e.g. malloc arena reorganization), so we don't
+# need a high CALLGRIND_NUMBER to get good data.
+CALLGRIND_NUMBER = 100
+CALLGRIND_REPEATS = 5
+
+
+@dataclasses.dataclass(frozen=True)
+class WorkerTimerArgs:
+ """Container for Timer constructor arguments.
+
+ This dataclass serves two roles. First, it is a simple interface for
+ defining benchmarks. (See core.api.GroupedStmts and core.api.GroupedModules
+ for the advanced interfaces.) Second, it provides serialization for
+ controlling workers. `Timer` is not pickleable, so instead the main process
+ will pass `WorkerTimerArgs` instances to workers for processing.
+ """
+ stmt: str
+ setup: str = "pass"
+ global_setup: str = ""
+ num_threads: int = 1
+ language: Language = Language.PYTHON
+
+
+@dataclasses.dataclass(frozen=True)
+class WorkerOutput:
+    # This dataclass contains only the measured values in order to keep
+    # communication between the main process and workers to a minimum.
+ wall_times: Tuple[float, ...]
+ instructions: Tuple[int, ...]
+
+
+@dataclasses.dataclass(frozen=True)
+class WorkerFailure:
+ # If a worker fails, we attach the string contents of the Exception
+ # rather than the Exception object itself. This is done for two reasons:
+ # 1) Depending on the type thrown, `e` may or may not be pickleable
+ # 2) If we re-throw in the main process, we lose the true stack trace.
+ failure_trace: str
+
+
+class WorkerUnpickler(pickle.Unpickler):
+ def find_class(self, module: str, name: str) -> Any:
+ """Resolve import for pickle.
+
+ When the main runner uses a symbol `foo` from this file, it sees it as
+ `worker.main.foo`. However the worker (called as a standalone file)
+ sees the same symbol as `__main__.foo`. We have to help pickle
+ understand that they refer to the same symbols.
+ """
+ symbol_map = {
+ # Only blessed interface Enums and dataclasses need to be mapped.
+ "WorkerTimerArgs": WorkerTimerArgs,
+ "WorkerOutput": WorkerOutput,
+ "WorkerFailure": WorkerFailure,
+ }
+
+ if name in symbol_map:
+ return symbol_map[name]
+
+ return super().find_class(module, name)
+
+ def load_input(self) -> WorkerTimerArgs:
+ result = self.load()
+ assert isinstance(result, WorkerTimerArgs)
+ return result
+
+ def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]:
+ """Convenience method for type safe loading."""
+ result = self.load()
+ assert isinstance(result, (WorkerTimerArgs, WorkerOutput, WorkerFailure))
+ return result
+
+
+# =============================================================================
+# == Execution ================================================================
+# =============================================================================
+
+def _run(timer_args: WorkerTimerArgs) -> WorkerOutput:
+ timer = Timer(
+ stmt=timer_args.stmt,
+ setup=timer_args.setup or "pass",
+ global_setup=timer_args.global_setup,
+
+ # Prevent NotImplementedError on GPU builds and C++ snippets.
+ timer=timeit.default_timer,
+ num_threads=timer_args.num_threads,
+ language=timer_args.language,
+ )
+
+ m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME)
+
+ stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind(
+ number=CALLGRIND_NUMBER,
+ collect_baseline=False,
+ repeats=CALLGRIND_REPEATS,
+ retain_out_file=False,
+ )
+
+ return WorkerOutput(
+ wall_times=tuple(m.times),
+ instructions=tuple(s.counts(denoise=True) for s in stats)
+ )
+
+
+def main(communication_file: str) -> None:
+ result: Union[WorkerOutput, WorkerFailure]
+ try:
+ with open(communication_file, "rb") as f:
+ timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input()
+ assert isinstance(timer_args, WorkerTimerArgs)
+ result = _run(timer_args)
+
+ except KeyboardInterrupt:
+ # Runner process sent SIGINT.
+ sys.exit()
+
+ except BaseException:
+ trace_f = io.StringIO()
+ traceback.print_exc(file=trace_f)
+ result = WorkerFailure(failure_trace=trace_f.getvalue())
+
+ if not os.path.exists(os.path.split(communication_file)[0]):
+ # This worker is an orphan, and the parent has already cleaned up the
+ # working directory. In that case we can simply exit.
+ print(f"Orphaned worker {os.getpid()} exiting.")
+ return
+
+ with open(communication_file, "wb") as f:
+ pickle.dump(result, f)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--communication_file', type=str)
+ communication_file = parser.parse_args().communication_file
+ main(communication_file)
diff --git a/mypy-strict.ini b/mypy-strict.ini
index e41ca8bf03..4608f7cf70 100644
--- a/mypy-strict.ini
+++ b/mypy-strict.ini
@@ -36,6 +36,8 @@ strict_equality = True
files =
.github/scripts/generate_binary_build_matrix.py,
+ benchmarks/instruction_counts/*.py,
+ benchmarks/instruction_counts/*/*.py,
tools/autograd/*.py,
tools/codegen/gen.py,
tools/mypy_wrapper.py,