-rw-r--r-- | benchmarks/instruction_counts/README.md | 142
-rw-r--r-- | benchmarks/instruction_counts/core/__init__.py | 0
-rw-r--r-- | benchmarks/instruction_counts/core/api.py | 402
-rw-r--r-- | benchmarks/instruction_counts/core/expand.py | 260
-rw-r--r-- | benchmarks/instruction_counts/core/types.py | 94
-rw-r--r-- | benchmarks/instruction_counts/core/utils.py | 99
-rw-r--r-- | benchmarks/instruction_counts/definitions/__init__.py | 0
-rw-r--r-- | benchmarks/instruction_counts/definitions/setup.py | 57
-rw-r--r-- | benchmarks/instruction_counts/definitions/standard.py | 143
-rw-r--r-- | benchmarks/instruction_counts/main.py | 57
-rw-r--r-- | benchmarks/instruction_counts/worker/__init__.py | 0
-rw-r--r-- | benchmarks/instruction_counts/worker/main.py | 188
-rw-r--r-- | mypy-strict.ini | 2
13 files changed, 1444 insertions, 0 deletions
diff --git a/benchmarks/instruction_counts/README.md b/benchmarks/instruction_counts/README.md new file mode 100644 index 0000000000..ed2633caba --- /dev/null +++ b/benchmarks/instruction_counts/README.md @@ -0,0 +1,142 @@ +# Instruction count microbenchmarks +## Quick start + +### To run the benchmark: + +``` +# From pytorch root +cd benchmarks/instruction_counts +python main.py +``` + +Currently `main.py` contains a very simple threadpool (so that run time isn't +unbearably onerous) and simply prints the results. These components will be +upgraded in subsequent PRs. + +### To define a new benchmark: +* `TimerArgs`: Low level definition which maps directly to +`torch.utils.benchmark.Timer` +* `GroupedStmts`: Benchmark a snippet. (Python, C++, or both) Can automatically +generate TorchScript and autograd variants. +* `GroupedModules`: Like `GroupedStmts`, but takes `nn.Module`s +* `GroupedVariants`: Benchmark-per-line to define many related benchmarks in a +single code block. + +## Architecture +### Benchmark definition. + +One primary goal of this suite is to make it easy to define semantically +related clusters of benchmarks. The crux of this effort is the +`GroupedBenchmark` class, which is defined in `core/api.py`. It takes a +definition for a set of related benchmarks, and produces one or more concrete +cases. It's helpful to see an example to understand how the machinery works. +Consider the following benchmark: + +``` +# `GroupedStmts` is an alias of `GroupedBenchmark.init_from_stmts` +benchmark = GroupedStmts( + py_stmt=r"y = x * w", + cpp_stmt=r"auto y = x * w;", + + setup=GroupedSetup( + py_setup=""" + x = torch.ones((4, 4)) + w = torch.ones((4, 4), requires_grad=True) + """, + cpp_setup=""" + auto x = torch::ones((4, 4)); + auto w = torch::ones((4, 4)); + w.set_requires_grad(true); + """, + ), + + signature="f(x, w) -> y", + torchscript=True, + autograd=True, +), +``` + +It is trivial to generate Timers for the eager forward mode case (ignoring +`num_threads` for now): + +``` +Timer( + stmt=benchmark.py_fwd_stmt, + setup=benchmark.setup.py_setup, +) + +Timer( + stmt=benchmark.cpp_fwd_stmt, + setup=benchmark.setup.cpp_setup, + language="cpp", +) +``` + +Moreover, because `signature` is provided we know that creation of `x` and `w` +is part of setup, and the overall comptation uses `x` and `w` to produce `y`. +As a result, we can derive TorchScript'd and AutoGrad variants as well. We can +deduce that a TorchScript model will take the form: + +``` +@torch.jit.script +def f(x, w): + # Paste `benchmark.py_fwd_stmt` into the function body. + y = x * w + return y # Set by `-> y` in signature. +``` + +And because we will want to use this model in both Python and C++, we save it to +disk and load it as needed. At this point Timers for TorchScript become: + +``` +Timer( + stmt=""" + y = jit_model(x, w) + """, + setup=""", + # benchmark.setup.py_setup + # jit_model = torch.jit.load(...) + # Warm up jit_model + """, +) + +Timer( + stmt=""" + std::vector<torch::jit::IValue> ivalue_inputs( + torch::jit::IValue({x}), + torch::jit::IValue({w}) + ); + auto y = jit_model.forward(ivalue_inputs); + """, + setup=""" + # benchmark.setup.cpp_setup + # jit_model = torch::jit::load(...) + # Warm up jit_model + """, +) +``` + +While nothing above is particularly complex, there is non-trivial bookkeeping +(managing the model artifact, setting up IValues) which if done manually would +be rather bug-prone and hard to read. 
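
The same derivation applies when a benchmark is defined from an `nn.Module`
rather than free-form statements. A `GroupedModules` definition might look
like the following sketch (the layer and input shapes are illustrative); the
eager stmt `y = model(x)` is generated from the signature, which is why
`signature` is required for this API:

```
# `GroupedModules` is an alias of `GroupedBenchmark.init_from_model`
benchmark = GroupedModules(
    "model = torch.nn.Linear(4, 2)",
    "auto model = torch::nn::Linear(4, 2);",

    setup=GroupedSetup(
        py_setup="x = torch.ones((4, 4))",
        cpp_setup="auto x = torch::ones({4, 4});",
    ),

    signature="f(x) -> y",
    torchscript=True,
)
```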
+ +The story is similar for autograd: because we know the output variable (`y`) +and we make sure to assign it when calling TorchScript models, testing AutoGrad +is as simple as appending `y.backward()` (or `y.backward();` in C++) to the +stmt of the forward only variant. Of course this requires that `signature` be +provided, as there is nothing special about the name `y`. + +The logic for the manipulations above is split between `core/api.py` (for +generating `stmt` based on language, Eager/TorchScript, with or without AutoGrad) +and `core/expand.py` (for larger, more expansive generation). The benchmarks +themselves are defined in `definitions/standard.py`. The current set is chosen +to demonstrate the various model definition APIs, and will be expanded when the +benchmark runner infrastructure is better equipped to deal with a larger run. + +### Benchmark execution. + +Once `expand.materialize` has flattened the abstract benchmark definitions into +`TimerArgs`, they can be sent to a worker (`worker/main.py`) subprocess to +execution. This worker has no concept of the larger benchmark suite; `TimerArgs` +is a one-to-one and direct mapping to the `torch.utils.benchmark.Timer` instance +that the worker instantiates. diff --git a/benchmarks/instruction_counts/core/__init__.py b/benchmarks/instruction_counts/core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/benchmarks/instruction_counts/core/__init__.py diff --git a/benchmarks/instruction_counts/core/api.py b/benchmarks/instruction_counts/core/api.py new file mode 100644 index 0000000000..b29c468aba --- /dev/null +++ b/benchmarks/instruction_counts/core/api.py @@ -0,0 +1,402 @@ +"""Key enums and structs used to handle data flow within the benchmark.""" +import dataclasses +import enum +import re +import textwrap +from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING + +from worker.main import WorkerTimerArgs + +if TYPE_CHECKING: + # Benchmark utils are only partially strict compliant, so MyPy won't follow + # imports using the public namespace. (Due to an exclusion rule in + # mypy-strict.ini) + from torch.utils.benchmark.utils.timer import Language +else: + from torch.utils.benchmark import Language + + +# Note: +# WorkerTimerArgs is defined in worker.main so that the worker does not +# depend on any files, including core.api. We mirror it with a public symbol +# `TimerArgs` for API consistency. +TimerArgs = WorkerTimerArgs + + +class RuntimeMode(enum.Enum): + EAGER = "Eager" + JIT = "TorchScript" + EXPLICIT = "" + + +class AutogradMode(enum.Enum): + FORWARD = "Forward" + FORWARD_BACKWARD = "Forward + Backward" + EXPLICIT = "" + + +@dataclasses.dataclass(frozen=True) +class AutoLabels: + """Labels for a TimerArgs instance which are inferred during unpacking.""" + runtime: RuntimeMode + autograd: AutogradMode + language: Language + + +@dataclasses.dataclass(frozen=True) +class GroupedSetup: + py_setup: str = "" + cpp_setup: str = "" + global_setup: str = "" + + def __post_init__(self) -> None: + for field in dataclasses.fields(self): + assert field.type == str + value: str = getattr(self, field.name) + object.__setattr__(self, field.name, textwrap.dedent(value)) + + +@dataclasses.dataclass(frozen=True) +class GroupedBenchmark: + """Base class for defining groups of benchmarks. 
+ + Concrete interfaces: + - `core.api.GroupedStmts` (init_from_stmts) + - `core.api.GroupedModules` (init_from_model) + - `core.api.GroupedVariants` (init_from_variants) + + There are a variety of dimensions along which one might wish to measure + PyTorch performance: + - Python, C++ + - Eager, TorchScript + - Single threaded, multi threaded + - Training, inference + + It is useful to define them together, both for clear, concise benchmark + definition and more intelligent post processing and analysis. + + There are also two programming idioms in PyTorch. One is to write free form + code (so-called "NumPy with gradients"), and the other is to organize code + using `torch.nn.Module`s. (This is how common neural network layers are + exposed through the PyTorch API.) To support easy definition two simple + initialization methods are provided: + - `init_from_stmts` + - `init_from_model` + + Those methods will document their unique constructor arguments, however + most are shared and are defined here: + setup: Defines how to initialize a benchmark in both Python and C++. + signature: + A string of the form: + ``` + f(a, b, ...) -> c + ``` + For instance, if Python setup is: + ``` + x = torch.ones((2,), requires_grad=True) + y = torch.ones((2,)) + ``` + and the corresponding stmt is: + ``` + z = torch.dot(x, y) + ``` + Then the signature is `f(x, y) -> z`. `signature` is required any + time we need to generate part of a snippet: + - When calling an opaque model provided by `init_from_models` + - When `torchscript=True` + - When `autograd=True` + + If a return value is not needed (e.g. because of in place mutation) + then `-> None` is valid, but a non-None return must be provided if + `autograd=True` + + torchscript: + If True, also JIT the stmt or model and generate benchmarks which + call the scripted version. Requires that `signature` is defined. + + autograd: + If True, generate both forward and forward + backward benchmarks. + Requires that `signature` is defined, and return value is not None. + + num_threads: + Maps to the Timer arg. If a tuple of ints is provided, benchmarks + will be generated for each value. + + A third method, `init_from_variants`, is provided to define several related + benchmarks at once. + """ + + # These are the stmts which are actually executed by Timer. In the case of + # `GroupedStmts` (init_from_stmts) they are passed through from user args. + # In the case of `GroupedModules` (init_from_model) they are generated + # using `signature`. (e.g. `f(x, y) -> z` generates `z = model(x, y)`) + py_fwd_stmt: Optional[str] + cpp_fwd_stmt: Optional[str] + + # Code block used to define a model. `init_from_stmts` will never populate + # `cpp_model_setup`, but if TorchScript is requested it will generate + # `py_model_setup` using `torch.jit.script`. + py_model_setup: Optional[str] + cpp_model_setup: Optional[str] + + # True if this benchmark used `init_from_stmts`, otherwise False. + inferred_model_setup: bool + + # Described above + setup: GroupedSetup + signature_args: Optional[Tuple[str, ...]] + signature_output: Optional[str] + torchscript: bool + autograd: bool + num_threads: Tuple[int, ...] 
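    # For illustration: with `init_from_stmts(py_stmt="y = x * w",
    # signature="f(x, w) -> y", torchscript=True)` these fields come out
    # roughly as:
    #   py_fwd_stmt          = "y = x * w"
    #   py_model_setup       = "def model(x, w):\n    y = x * w\n    return y"
    #   inferred_model_setup = True
    #   signature_args       = ("x", "w")
    #   signature_output     = "y"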
+ + @classmethod + def init_from_stmts( + cls, + py_stmt: Optional[str] = None, + cpp_stmt: Optional[str] = None, + + # Generic constructor arguments + setup: GroupedSetup = GroupedSetup(), + signature: Optional[str] = None, + torchscript: bool = False, + autograd: bool = False, + num_threads: Union[int, Tuple[int, ...]] = 1, + ) -> "GroupedBenchmark": + """Create a set of benchmarks from free-form statements. + + This method of benchmark definition is analogous to Timer use, where + we simply execute the provided stmts. + """ + if py_stmt is not None: + py_stmt = textwrap.dedent(py_stmt) + + if cpp_stmt is not None: + cpp_stmt = textwrap.dedent(cpp_stmt) + + signature_args, signature_output = cls._parse_signature(signature) + py_model_setup = ( + cls._model_from_py_stmt( + py_stmt=py_stmt, + signature_args=signature_args, + signature_output=signature_output + ) if torchscript else None + ) + + return cls( + py_fwd_stmt=py_stmt, + cpp_fwd_stmt=cpp_stmt, + py_model_setup=py_model_setup, + cpp_model_setup=None, + inferred_model_setup=True, + setup=setup, + signature_args=signature_args, + signature_output=signature_output, + torchscript=torchscript, + autograd=autograd, + num_threads=(num_threads,) if isinstance(num_threads, int) else num_threads, + ) + + @classmethod + def init_from_model( + cls, + py_model_setup: Optional[str] = None, + cpp_model_setup: Optional[str] = None, + + # Generic constructor arguments + setup: GroupedSetup = GroupedSetup(), + signature: Optional[str] = None, + torchscript: bool = False, + autograd: bool = False, + num_threads: Union[int, Tuple[int, ...]] = 1, + ) -> "GroupedBenchmark": + """Create a set of benchmarks using torch.nn Modules. + + This method of benchmark creation takes setup code, and then calls + a model rather than a free form block of code. As a result, there are + two additional requirements compared to `init_from_stmts`: + - `signature` must be provided. + - A model (named "model") must be defined, either with `model = ...` + or `def model(...): ...` in Python or `auto model = ...` in C++. + """ + signature_args, signature_output = cls._parse_signature(signature) + if signature_args is None: + raise ValueError("signature is needed when initializing from model definitions.") + + return cls( + *cls._make_model_invocation(signature_args, signature_output, RuntimeMode.EAGER), + py_model_setup=py_model_setup, + cpp_model_setup=cpp_model_setup, + inferred_model_setup=False, + setup=setup, + signature_args=signature_args, + signature_output=signature_output, + torchscript=torchscript, + autograd=autograd, + num_threads=(num_threads,) if isinstance(num_threads, int) else num_threads, + ) + + @classmethod + def init_from_variants( + cls, + py_block: str = "", + cpp_block: str = "", + num_threads: Union[int, Tuple[int, ...]] = 1, + ) -> Dict[Union[Tuple[str, ...], Optional[str]], "GroupedBenchmark"]: + + py_cases, py_setup, py_global_setup = cls._parse_variants(py_block, Language.PYTHON) + cpp_cases, cpp_setup, cpp_global_setup = cls._parse_variants(cpp_block, Language.CPP) + + assert not py_global_setup + setup = GroupedSetup( + py_setup=py_setup, + cpp_setup=cpp_setup, + global_setup=cpp_global_setup, + ) + + # NB: The key is actually `Tuple[str, ...]`, however MyPy gets confused + # and we use the superset `Union[Tuple[str, ...], Optional[str]` to + # match the expected signature. 
+ variants: Dict[Union[Tuple[str, ...], Optional[str]], GroupedBenchmark] = {} + for label in set(list(py_cases.keys()) + list(cpp_cases.keys())): + py_lines = py_cases.get(label, []) + cpp_lines = cpp_cases.get(label, []) + + n_lines = max(len(py_lines), len(cpp_lines)) + py_lines += [""] * (n_lines - len(py_lines)) + cpp_lines += [""] * (n_lines - len(cpp_lines)) + lines = [ + (py_stmt, cpp_stmt) + for py_stmt, cpp_stmt in zip(py_lines, cpp_lines) + if py_stmt or cpp_stmt + ] + + for i, (py_stmt, cpp_stmt) in enumerate(lines): + variants[(label, f"Case: {i:>2}")] = GroupedBenchmark.init_from_stmts( + py_stmt=py_stmt or None, + cpp_stmt=cpp_stmt or None, + setup=setup, + num_threads=num_threads, + ) + + return variants + + def __post_init__(self) -> None: + if self.autograd and self.signature_output is None: + raise ValueError("An output variable must be specified when `autograd=True`.") + + if self.py_model_setup and "model" not in self.py_model_setup: + raise ValueError("`py_model_setup` appears to be missing `model` definition.") + + if self.cpp_model_setup and "model" not in self.cpp_model_setup: + raise ValueError("`cpp_model_setup` appears to be missing `model` definition.") + + # ========================================================================= + # == String manipulation methods ========================================== + # ========================================================================= + + @staticmethod + def _parse_signature( + signature: Optional[str] + ) -> Tuple[Optional[Tuple[str, ...]], Optional[str]]: + if signature is None: + return None, None + + match = re.search(r"^f\((.*)\) -> (.*)$", signature) + if match is None: + raise ValueError(f"Invalid signature: `{signature}`") + + args: Tuple[str, ...] = tuple(match.groups()[0].split(", ")) + output: str = match.groups()[1].strip() + + if "," in output: + raise ValueError(f"Multiple return values are not currently allowed: `{output}`") + + if output == "None": + return args, None + + return args, output + + @staticmethod + def _model_from_py_stmt( + py_stmt: Optional[str], + signature_args: Optional[Tuple[str, ...]], + signature_output: Optional[str], + ) -> str: + if py_stmt is None: + raise ValueError("`py_stmt` must be defined in order to derive a model.") + + if signature_args is None: + raise ValueError("signature is needed in order to derive a model.") + + return textwrap.dedent(f"""\ + def model({', '.join(signature_args)}): + {{stmt_str}} + return {signature_output} + """).format(stmt_str=textwrap.indent(py_stmt, ' ' * 4)) + + @staticmethod + def _make_model_invocation( + signature_args: Tuple[str, ...], + signature_output: Optional[str], + runtime: RuntimeMode, + ) -> Tuple[str, str]: + py_prefix, cpp_prefix = "", "" + if signature_output is not None: + py_prefix = f"{signature_output} = " + cpp_prefix = f"auto {signature_output} = " + + if runtime == RuntimeMode.EAGER: + model_name = "model" + cpp_invocation = f"{cpp_prefix}{model_name}->forward({', '.join(signature_args)});" + + else: + assert runtime == RuntimeMode.JIT + model_name = "jit_model" + cpp_invocation = textwrap.dedent(f"""\ + std::vector<torch::jit::IValue> ivalue_inputs({{ + {', '.join([f'torch::jit::IValue({a})' for a in signature_args])} + }}); + {cpp_prefix}{model_name}.forward(ivalue_inputs); + """) + + # NB: + # In python we invoke __call__, however C++ doesn't have an analogous + # method so we invoke `forward` instead. This means that that Python + # is doing extra work (e.g. 
checking hooks) compared to C++; however + # because this is the default user experience that's acceptable. + py_invocation = f"{py_prefix}{model_name}({', '.join(signature_args)})" + + return py_invocation, cpp_invocation + + @staticmethod + def _parse_variants(block: str, language: Language) -> Tuple[Dict[str, List[str]], str, str]: + block = textwrap.dedent(block).strip() + comment = "#" if language == Language.PYTHON else "//" + label_pattern = f"{comment} @(.+)$" + label = "" + + lines_by_label: Dict[str, List[str]] = {"SETUP": [], "GLOBAL_SETUP": []} + for line in block.splitlines(keepends=False): + match = re.search(label_pattern, line.strip()) + if match: + label = match.groups()[0] + if label.replace(" ", "_").upper() in ("SETUP", "GLOBAL_SETUP"): + label = label.replace(" ", "_").upper() + continue + + lines_by_label.setdefault(label, []) + if line.startswith(comment): + line = "" + lines_by_label[label].append(line) + + setup = "\n".join(lines_by_label.pop("SETUP")) + global_setup = "\n".join(lines_by_label.pop("GLOBAL_SETUP")) + + return lines_by_label, setup, global_setup + + +# These are the user facing APIs. +GroupedStmts = GroupedBenchmark.init_from_stmts +GroupedModules = GroupedBenchmark.init_from_model +GroupedVariants = GroupedBenchmark.init_from_variants diff --git a/benchmarks/instruction_counts/core/expand.py b/benchmarks/instruction_counts/core/expand.py new file mode 100644 index 0000000000..2736a4f805 --- /dev/null +++ b/benchmarks/instruction_counts/core/expand.py @@ -0,0 +1,260 @@ +"""Logic for converting human-readable benchmarks into executable form. + +This is mostly string manipulation, with just a bit of importlib magic. +""" +import importlib.abc +import importlib.util +import itertools as it +import os +import re +import textwrap +from typing import cast, List, Optional, Tuple, TYPE_CHECKING +import uuid + +import torch + +if TYPE_CHECKING: + # See the note in api.py for why this is necessary. + from torch.utils.benchmark.utils.timer import Language +else: + from torch.utils.benchmark import Language + +from core.api import AutogradMode, AutoLabels, GroupedBenchmark, RuntimeMode, TimerArgs +from core.types import FlatDefinition, FlatIntermediateDefinition, Label +from core.utils import get_temp_dir + + +_ALL_MODES = tuple(it.product( + RuntimeMode, + AutogradMode, + Language, +)) + + +def _generate_torchscript_file(model_src: str, name: str) -> Optional[str]: + """Returns the path a saved model if one can be constructed from `spec`. + + Because TorchScript requires actual source code in order to script a + model, we can't simply `eval` an appropriate model string. Instead, we + must write the correct source to a temporary Python file and then import + the TorchScript model from that temporary file. + + `model_src` must contain `jit_model = ...`, which `materialize` will supply. + """ + # Double check. + assert "jit_model = " in model_src, f"Missing jit_model definition:\n{model_src}" + + # `torch.utils.benchmark.Timer` will automatically import torch, so we + # need to match that convention. + model_src = f"import torch\n{model_src}" + + model_root = os.path.join(get_temp_dir(), "TorchScript_models") + os.makedirs(model_root, exist_ok=True) + module_path = os.path.join(model_root, f"torchscript_{name}.py") + artifact_path = os.path.join(model_root, f"torchscript_{name}.pt") + + if os.path.exists(module_path): + # The uuid in `name` should protect against this, but it doesn't hurt + # to confirm. 
+ raise ValueError(f"File {module_path} already exists.") + + with open(module_path, "wt") as f: + f.write(model_src) + + # Import magic to actually load our function. + module_spec = importlib.util.spec_from_file_location(f"torchscript__{name}", module_path) + module = importlib.util.module_from_spec(module_spec) + loader = module_spec.loader + assert loader is not None + + # Module.loader has type Optional[_Loader]. Even when we assert loader is + # not None and MyPy narrows it to type _Loader, it will not pass type + # checks. So we have to use a cast to tell MyPy that _Loader implements + # importlib.abc.Loader. + cast(importlib.abc.Loader, loader).exec_module(module) + + # And again, the type checker has no way of knowing that this line is valid. + jit_model = module.jit_model # type: ignore + assert isinstance( + jit_model, + (torch.jit.ScriptFunction, torch.jit.ScriptModule) + ), f"Expected ScriptFunction or ScriptModule, got: {type(jit_model)}" + jit_model.save(artifact_path) + + # Cleanup now that we have the actual serialized model. + os.remove(module_path) + return artifact_path + + +def _get_stmt( + benchmark: GroupedBenchmark, + runtime: RuntimeMode, + autograd: AutogradMode, + language: Language, +) -> Optional[str]: + """Specialize a GroupedBenchmark for a particular configuration.""" + is_python = (language == Language.PYTHON) + + # During GroupedBenchmark construction, py_fwd_stmt and cpp_fwd_stmt are + # set to the eager invocation. So in the RuntimeMode.EAGER case we can + # simply reuse them. For the RuntimeMode.JIT case, we need to generate + # an appropriate `jit_model(...)` invocation. + if runtime == RuntimeMode.EAGER: + stmts = (benchmark.py_fwd_stmt, benchmark.cpp_fwd_stmt) + + else: + assert runtime == RuntimeMode.JIT + assert benchmark.signature_args is not None + stmts = GroupedBenchmark._make_model_invocation( + benchmark.signature_args, benchmark.signature_output, RuntimeMode.JIT) + + stmt = stmts[0 if is_python else 1] + + if autograd == AutogradMode.FORWARD_BACKWARD and stmt is not None: + assert benchmark.signature_output is not None + backward = ( + f"{benchmark.signature_output}" + + # In C++ we have to get the Tensor out of the IValue to call `.backward()` + f"{'.toTensor()' if runtime == RuntimeMode.JIT and language == Language.CPP else ''}" + f".backward(){';' if language == Language.CPP else ''}" + ) + stmt = f"{stmt}\n{backward}" + return stmt + + +def _get_setup( + benchmark: GroupedBenchmark, + runtime: RuntimeMode, + language: Language, + stmt: str, + model_path: Optional[str] +) -> str: + """Specialize a GroupedBenchmark for a particular configuration. + + Setup requires two extra pieces of information: + 1) The benchmark stmt. This is needed to warm up the model and avoid + measuring lazy initialization. + 2) The model path so we can load it during the benchmark. + + These are only used when `runtime == RuntimeMode.JIT`. + """ + + # By the time we get here, details about how to set up a model have already + # been determined by GroupedBenchmark. (Or set to None if appropriate.) We + # simply need to collect and package the code blocks. 
+ if language == Language.PYTHON: + setup = benchmark.setup.py_setup + model_setup = benchmark.py_model_setup + else: + assert language == Language.CPP + setup = benchmark.setup.cpp_setup + model_setup = benchmark.cpp_model_setup + + if runtime == RuntimeMode.EAGER: + return "\n".join([setup, model_setup or ""]) + + assert runtime == RuntimeMode.JIT + assert model_path is not None + + # We template `"{model_path}"`, so quotes would break model loading. The + # model path is generated within the benchmark, so this is just an + # abundance of caution rather than something that is expected in practice. + assert '"' not in model_path + + # `stmt` may contain newlines, so we can't use f-strings. Instead we need + # to generate templates so that dedent works properly. + if language == Language.PYTHON: + setup_template: str = textwrap.dedent(f""" + jit_model = torch.jit.load("{model_path}") + + # Warmup `jit_model` + for _ in range(3): + {{stmt}} + """) + + else: + assert language == Language.CPP + setup_template = textwrap.dedent(f""" + const std::string fpath = "{model_path}"; + auto jit_model = torch::jit::load(fpath); + + // Warmup `jit_model` + for (int i = 0; i < 3; i++) {{{{ + {{stmt}} + }}}} + """) + + model_load = setup_template.format(stmt=textwrap.indent(stmt, ' ' * 4)) + return "\n".join([setup, model_load]) + + +def materialize(benchmarks: FlatIntermediateDefinition) -> FlatDefinition: + """Convert a heterogeneous benchmark into an executable state. + + This entails generation of TorchScript model artifacts, splitting + GroupedBenchmarks into multiple TimerArgs, and tagging the results with + AutoLabels. + """ + results: List[Tuple[Label, AutoLabels, TimerArgs]] = [] + + for label, args in benchmarks.items(): + if isinstance(args, TimerArgs): + # User provided an explicit TimerArgs, so no processing is necessary. + auto_labels = AutoLabels( + RuntimeMode.EXPLICIT, + AutogradMode.EXPLICIT, + args.language + ) + results.append((label, auto_labels, args)) + + else: + assert isinstance(args, GroupedBenchmark) + + model_path: Optional[str] = None + if args.py_model_setup and args.torchscript: + model_setup = f"{args.py_model_setup}\njit_model = torch.jit.script(model)" + + # This is just for debugging. We just need a unique name for the + # model, but embedding the label makes debugging easier. 
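                # For illustration, a label like ("nn Modules", "Linear")
                # becomes "nn_modules_linear" here, plus the uuid suffix below.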
+ name: str = re.sub(r'[^a-z0-9_]', '_', '_'.join(label).lower()) + name = f"{name}_{uuid.uuid4()}" + + model_path = _generate_torchscript_file(model_setup, name=name) + + for (runtime, autograd, language), num_threads in it.product(_ALL_MODES, args.num_threads): + if runtime == RuntimeMode.EXPLICIT or autograd == AutogradMode.EXPLICIT: + continue + + if runtime == RuntimeMode.JIT and not args.torchscript: + continue + + if autograd == AutogradMode.FORWARD_BACKWARD and not args.autograd: + continue + + stmt = _get_stmt(args, runtime, autograd, language) + if stmt is None: + continue + + setup = _get_setup(args, runtime, language, stmt, model_path) + + global_setup: str = "" + if language == Language.CPP and runtime == RuntimeMode.JIT: + global_setup = textwrap.dedent(""" + #include <string> + #include <vector> + #include <torch/script.h> + """) + + autolabels = AutoLabels(runtime, autograd, language) + timer_args = TimerArgs( + stmt=stmt, + setup=setup, + global_setup=global_setup, + num_threads=num_threads, + language=language, + ) + + results.append((label, autolabels, timer_args)) + + return tuple(results) diff --git a/benchmarks/instruction_counts/core/types.py b/benchmarks/instruction_counts/core/types.py new file mode 100644 index 0000000000..8f36bfbcec --- /dev/null +++ b/benchmarks/instruction_counts/core/types.py @@ -0,0 +1,94 @@ +"""Type annotations for various benchmark objects.""" +from typing import Any, Dict, Optional, Tuple, Union + +from core.api import AutoLabels, TimerArgs, GroupedBenchmark + + +# ============================================================================= +# == Benchmark schema ========================================================= +# ============================================================================= +""" (There is a TL;DR at the end for ad-hoc benchmarks.) +The end state for representing a benchmark is: + ``` + Tuple[ + Tuple[ + Tuple[str, ...], # Primary key + core.api.AutoLabels, # Secondary key + core.api.TimerArgs, # Value + ], + ... + ] + ``` + +For example: + ``` + [ + (("pointwise", "add"), AutoLabels(..., Language.PYTHON), TimerArgs(...)), + (("pointwise", "add"), AutoLabels(..., Language.CPP), TimerArgs(...)), + ... + ] + ``` + +However, such a flat list is somewhat tedious to maintain (and read), because +there is significant duplication in the key structure. So instead, we would +like to define something like: + ``` + { + "pointwise" : { + "add": { + None: GroupedStmts(...), + "with alpha": GroupedStmts(...), + }, + "mul": GroupedStmts(...), + }, + "matmul": GroupedStmts(...), + } + ``` +and then parse out a flat representation. The type declarations below are +simply formalizing the structure of nested dictionaries with string or tuple +of string keys. + +TL;DR + If you only care about writing an ad-hoc benchmark for a PR, just use a + flat dictionary and everything will work. For example: + ``` + { + "case 0": TimerArgs(...), + "case 1": TimerArgs(...), + "case 2": GroupedStmts(...), + ... + } + ``` +""" + +# Allow strings in definition for convenience, and None to signify a base +# case. (No subsequent entry needed. See the "add" example above.) +Label = Tuple[str, ...] 
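# For illustration: ("pointwise", "add") and ("matmul",) from the example
# above are both valid Labels.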
+_Label = Union[Label, Optional[str]] + +# MyPy does not currently support recursive types: +# https://github.com/python/mypy/issues/731 +# +# So while the correct type definition would be: +# _Value = Union[ +# # Base case: +# Union[TimerArgs, GroupedBenchmark], +# +# # Recursive case: +# Dict[Label, "_Value"], +# ] +# we instead have to use Any and rely on runtime asserts when flattening. +_Value = Union[ + Union[TimerArgs, GroupedBenchmark], + Dict[_Label, Any], +] + +Definition = Dict[_Label, _Value] + +# We initially have to parse (flatten) to an intermediate state in order to +# build TorchScript models since multiple entries will share the same model +# artifact. +FlatIntermediateDefinition = Dict[Label, Union[TimerArgs, GroupedBenchmark]] + +# Final parsed schema. +FlatDefinition = Tuple[Tuple[Label, AutoLabels, TimerArgs], ...] diff --git a/benchmarks/instruction_counts/core/utils.py b/benchmarks/instruction_counts/core/utils.py new file mode 100644 index 0000000000..1da98c80df --- /dev/null +++ b/benchmarks/instruction_counts/core/utils.py @@ -0,0 +1,99 @@ +import atexit +import shutil +import re +import tempfile +import textwrap +from typing import List, Optional, Tuple + +from core.api import GroupedBenchmark, TimerArgs +from core.types import Definition, FlatIntermediateDefinition, Label + + +_TEMPDIR: Optional[str] = None +def get_temp_dir() -> str: + global _TEMPDIR + if _TEMPDIR is None: + temp_dir = tempfile.mkdtemp() + atexit.register(shutil.rmtree, path=temp_dir) + _TEMPDIR = temp_dir + return _TEMPDIR + + +def _flatten( + key_prefix: Label, + sub_schema: Definition, + result: FlatIntermediateDefinition +) -> None: + for k, value in sub_schema.items(): + if isinstance(k, tuple): + assert all(isinstance(ki, str) for ki in k) + key_suffix: Label = k + elif k is None: + key_suffix = () + else: + assert isinstance(k, str) + key_suffix = (k,) + + key: Label = key_prefix + key_suffix + if isinstance(value, (TimerArgs, GroupedBenchmark)): + assert key not in result, f"duplicate key: {key}" + result[key] = value + else: + assert isinstance(value, dict) + _flatten(key_prefix=key, sub_schema=value, result=result) + + +def flatten(schema: Definition) -> FlatIntermediateDefinition: + """See types.py for an explanation of nested vs. flat definitions.""" + result: FlatIntermediateDefinition = {} + _flatten(key_prefix=(), sub_schema=schema, result=result) + + # Ensure that we produced a valid flat definition. + for k, v in result.items(): + assert isinstance(k, tuple) + assert all(isinstance(ki, str) for ki in k) + assert isinstance(v, (TimerArgs, GroupedBenchmark)) + return result + + +def parse_stmts(stmts: str) -> Tuple[str, str]: + """Helper function for side-by-side Python and C++ stmts. + + For more complex statements, it can be useful to see Python and C++ code + side by side. To this end, we provide an **extremely restricted** way + to define Python and C++ code side-by-side. The schema should be mostly + self explanatory, with the following non-obvious caveats: + - Width for the left (Python) column MUST be 40 characters. + - The column separator is " | ", not "|". Whitespace matters. 
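
    For example (the statements themselves are illustrative):

        Python                                   | C++
        ---------------------------------------- | ----------------------------------------
        x = torch.ones((4, 4))                   | auto x = torch::ones({4, 4});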
+ """ + stmts = textwrap.dedent(stmts).strip() + lines: List[str] = stmts.splitlines(keepends=False) + assert len(lines) >= 3, f"Invalid string:\n{stmts}" + + column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$" + signature_pattern = r"^: f\((.*)\)( -> (.+))?\s*$" + separation_pattern = r"^[-]{40} | [-]{40}$" + code_pattern = r"^(.{40}) \|($| (.*)$)" + + column_match = re.search(column_header_pattern, lines[0]) + if column_match is None: + raise ValueError( + f"Column header `{lines[0]}` " + f"does not match pattern `{column_header_pattern}`") + + assert re.search(separation_pattern, lines[1]) + + py_lines: List[str] = [] + cpp_lines: List[str] = [] + for l in lines[2:]: + l_match = re.search(code_pattern, l) + if l_match is None: + raise ValueError(f"Invalid line `{l}`") + py_lines.append(l_match.groups()[0]) + cpp_lines.append(l_match.groups()[2] or "") + + # Make sure we can round trip for correctness. + l_from_stmts = f"{py_lines[-1]:<40} | {cpp_lines[-1]:<40}".rstrip() + assert l_from_stmts == l.rstrip(), f"Failed to round trip `{l}`" + + return "\n".join(py_lines), "\n".join(cpp_lines) diff --git a/benchmarks/instruction_counts/definitions/__init__.py b/benchmarks/instruction_counts/definitions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/benchmarks/instruction_counts/definitions/__init__.py diff --git a/benchmarks/instruction_counts/definitions/setup.py b/benchmarks/instruction_counts/definitions/setup.py new file mode 100644 index 0000000000..6bd044702b --- /dev/null +++ b/benchmarks/instruction_counts/definitions/setup.py @@ -0,0 +1,57 @@ +"""Define some common setup blocks which benchmarks can reuse.""" + +import enum + +from core.api import GroupedSetup +from core.utils import parse_stmts + + +_TRIVIAL_2D = GroupedSetup( + r"x = torch.ones((4, 4))", + r"auto x = torch::ones({4, 4});" +) + + +_TRIVIAL_4D = GroupedSetup( + r"x = torch.ones((4, 4, 4, 4))", + r"auto x = torch::ones({4, 4, 4, 4});" +) + + +_GENERIC = GroupedSetup(*parse_stmts( + r""" + Python | C++ + ---------------------------------------- | ---------------------------------------- + torch.manual_seed(138_10_23) | torch::manual_seed(1381023); + x = torch.rand((4, 4)) | auto x = torch::rand({4, 4}); + y_float = torch.ones((4, 4)) | auto y_float = torch::ones({4, 4}); + y_int = torch.ones( | auto y_int = torch::ones({4, 4}, at::kInt); + (4, 4), dtype=torch.int32) | + """ +)) + + +_TRAINING = GroupedSetup(*parse_stmts( + r""" + Python | C++ + ---------------------------------------- | ---------------------------------------- + # Inputs | // Inputs + x = torch.ones((1,)) | auto x = torch::ones({1}); + y = torch.ones((1,)) | auto y = torch::ones({1}); + | + # Weights | // Weights + w0 = torch.ones( | auto w0 = torch::ones({1}); + (1,), requires_grad=True) | w0.set_requires_grad(true); + w1 = torch.ones( | auto w1 = torch::ones({1}); + (1,), requires_grad=True) | w1.set_requires_grad(true); + w2 = torch.ones( | auto w2 = torch::ones({2}); + (2,), requires_grad=True) | w2.set_requires_grad(true); + """ +)) + + +class Setup(enum.Enum): + TRIVIAL_2D = _TRIVIAL_2D + TRIVIAL_4D = _TRIVIAL_4D + GENERIC = _GENERIC + TRAINING = _TRAINING diff --git a/benchmarks/instruction_counts/definitions/standard.py b/benchmarks/instruction_counts/definitions/standard.py new file mode 100644 index 0000000000..cf4583cd00 --- /dev/null +++ b/benchmarks/instruction_counts/definitions/standard.py @@ -0,0 +1,143 @@ +"""Default set of benchmarks. 
+ +Parser notes: + `parse_stmts`: + - Width for the left (Python) column MUST be 40 characters. + - The column separator is " | ", not "|". Whitespace matters. + + `GroupedVariants`: + - `Setup` and `Global_Setup` (case insensitive) are reserved keywords + to populate `setup` and `global_setup` for every generated benchmark. + - To set a label for the succeeding block, add `# @YOUR_LABEL` (Python) + or `// @YOUR_LABEL` (C++). +""" + +from core.api import GroupedModules, GroupedStmts, GroupedVariants +from core.types import FlatIntermediateDefinition +from core.utils import flatten, parse_stmts +from definitions.setup import Setup + + +BENCHMARKS: FlatIntermediateDefinition = flatten({ + "Empty": { + "no allocation": GroupedStmts( + r"torch.empty(())", + r"torch::empty({0});", + ), + + "with allocation": GroupedStmts( + r"torch.empty((1,))", + r"torch::empty({1});", + ), + + "overloads": GroupedVariants( + cpp_block=r""" + // @Setup + auto options_empty = c10::TensorOptions(); + auto options_full = c10::TensorOptions().dtype(at::kFloat).device(at::kCPU); + auto optional_float = c10::make_optional(at::kFloat); + + // @TensorOptions overload + at::empty({0}, options_empty); + at::empty({0}, options_full); + at::empty({0}, at::kFloat); // implicit conversion + + // @Faithful overload + at::empty({0}, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + at::empty({0}, at::kFloat, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + at::empty({0}, optional_float, c10::nullopt, c10::nullopt, c10::nullopt, c10::nullopt); + """ + ), + }, + + "Pointwise": { + "Math": { + "add": { + "Tensor-Scalar": GroupedStmts( + r"x += 1.0", + r"x += 1.0;", + setup=Setup.GENERIC.value, + ), + }, + }, + }, + + "Indexing": GroupedVariants(*parse_stmts(r""" + Python | C++ + ---------------------------------------- | ---------------------------------------- + # @setup | // @setup + | using namespace torch::indexing; + torch.manual_seed(6626_10_34) | torch::manual_seed(66261034); + | + x = torch.randn(1, 1, 1) | auto x = torch::randn({1, 1, 1}); + y = torch.randn(1, 1, 1) | auto y = torch::randn({1, 1, 1}); + | + # @Tensor-Scalar | // @Tensor-Scalar + x[0] = 1 | x.index_put_({0}, 1); + x[0, 0] = 1 | x.index_put_({0, 0}, 1); + x[0, 0, 0] = 1 | x.index_put_({0, 0, 0}, 1); + | + # @Tensor-Scalar (Advanced) | // @Tensor-Scalar (Advanced) + x[...] = 1 | x.index_put_({"..."}, 1); + x[:] = 1 | x.index_put_({Slice(None, None, None)}, 1); + x[None] = 1 | x.index_put_({None}, 1); + x[False] = 1 | x.index_put_({false}, 1); + x[True] = 1 | x.index_put_({true}, 1); + | + # @Tensor-Tensor | // @Tensor-Tensor + x[0] = y[0] | x.index_put_({0}, y.index({0})); + x[0, 0] = y[0, 0] | x.index_put_({0, 0}, y.index({0, 0})); + x[0, 0, 0] = y[0, 0, 0] | x.index_put_({0, 0, 0}, y.index({0, 0, 0})); + | + # @Tensor-Tensor (Advanced) | // @Tensor-Tensor (Advanced) + x[...] = y[...] 
| x.index_put_({"..."}, y.index({"..."})); + x[:] = y[:] | x.index_put_({Slice(None, None, None)}, y.index({Slice(None, None, None)})); + x[None] = y[None] | x.index_put_({None}, y.index({None})); + x[False] = y[False] | x.index_put_({false}, y.index({false})); + x[True] = y[True] | x.index_put_({true}, y.index({true})); + """)), + + "nn Modules": { + "Linear": GroupedModules( + "model = torch.nn.Linear(4, 2)", + "auto model = torch::nn::Linear(4, 2);", + setup=Setup.TRIVIAL_4D.value, + signature="f(x) -> y", + torchscript=True, + ), + }, + + "training": { + "simple": GroupedStmts( + *parse_stmts(r""" + Python | C++ + ---------------------------------------- | ---------------------------------------- + a0 = torch.nn.functional.relu(x * w0) | auto a0 = torch::nn::functional::relu(x * w0); + y = a0 * w1 | auto y = a0 * w1; + """), + Setup.TRAINING.value, + num_threads=(1, 2), + signature=r"f(x, w0, w1) -> y", + torchscript=True, + autograd=True, + ), + + "ensemble": GroupedStmts( + *parse_stmts(r""" + Python | C++ + ---------------------------------------- | ---------------------------------------- + a0 = torch.nn.functional.gelu(x * w0) | auto a0 = torch::nn::functional::gelu(x * w0); + a1 = torch.nn.functional.prelu(y, w1) | auto a1 = torch::nn::functional::prelu(y, w1); + z = torch.nn.functional.normalize( | auto z = torch::nn::functional::normalize( + torch.cat([a0, a1]), | torch::cat({a0, a1}), + p=2.0, dim=0, | torch::nn::functional::NormalizeFuncOptions().p(2).dim(0) + ).dot(w2) | ).dot(w2); + """), + Setup.TRAINING.value, + num_threads=(1, 2), + signature=r"f(x, y, w0, w1, w2) -> z", + torchscript=True, + autograd=True, + ), + }, +}) diff --git a/benchmarks/instruction_counts/main.py b/benchmarks/instruction_counts/main.py new file mode 100644 index 0000000000..0f8773a9cf --- /dev/null +++ b/benchmarks/instruction_counts/main.py @@ -0,0 +1,57 @@ +"""Basic runner for the instruction count microbenchmarks. + +The contents of this file are placeholders, and will be replaced by more +expressive and robust components (e.g. better runner and result display +components) in future iterations. However this allows us to excercise the +underlying benchmark generation infrastructure in the mean time. 
+""" +import multiprocessing +import multiprocessing.dummy +import os +import pickle +import subprocess +from typing import Tuple + +from core.api import AutoLabels, TimerArgs +from core.expand import materialize +from core.types import Label +from core.utils import get_temp_dir +from definitions.standard import BENCHMARKS +from worker.main import WORKER_PATH, WorkerFailure, WorkerOutput, WorkerTimerArgs, WorkerUnpickler + + +def call_worker( + args: Tuple[int, Tuple[Label, AutoLabels, TimerArgs]] +) -> Tuple[Label, AutoLabels, int, WorkerOutput]: + worker_id, (label, autolabels, timer_args) = args + + communication_file = os.path.join(get_temp_dir(), f"communication_file_{worker_id}.pkl") + with open(communication_file, "wb") as f: + pickle.dump(timer_args, f) + + subprocess.call( + ["python", WORKER_PATH, "--communication_file", communication_file], + shell=False, + ) + + with open(communication_file, "rb") as f: + result = WorkerUnpickler(f).load_output() + + if isinstance(result, WorkerTimerArgs): + raise RuntimeError("Benchmark worker failed without starting.") + + elif isinstance(result, WorkerFailure): + raise RuntimeError(f"Worker failed: {label} {autolabels}\n{result.failure_trace}") + + assert isinstance(result, WorkerOutput) + return label, autolabels, timer_args.num_threads, result + + +def main() -> None: + with multiprocessing.dummy.Pool(multiprocessing.cpu_count() - 4) as pool: + for label, autolabels, num_threads, result in pool.imap(call_worker, enumerate(materialize(BENCHMARKS)), 1): + print(label, autolabels, num_threads, result.instructions) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/instruction_counts/worker/__init__.py b/benchmarks/instruction_counts/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/benchmarks/instruction_counts/worker/__init__.py diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py new file mode 100644 index 0000000000..f59509de74 --- /dev/null +++ b/benchmarks/instruction_counts/worker/main.py @@ -0,0 +1,188 @@ +"""File invoked through subprocess to actually carry out measurements. + +`worker/main.py` is deliberately isolated from the rest of the benchmark +infrastructure. Other parts of the benchmark rely on this file, but +`worker/` has only one Python file and does not import ANYTHING from the rest +of the benchmark suite. The reason that this is important is that we can't +rely on paths to access the other files (namely `core.api`) since a source +command might change the CWD. It also helps keep startup time down by limiting +spurious definition work. + +The life of a worker is very simple: + It receives a file containing a `WorkerTimerArgs` telling it what to run, + and writes a `WorkerOutput` result back to the same file. + +Because this file only expects to run in a child context, error handling means +plumbing failures up to the caller, not raising in this process. +""" +import argparse +import dataclasses +import io +import os +import pickle +import timeit +import traceback +from typing import Any, Tuple, Union, TYPE_CHECKING +import sys + + +if TYPE_CHECKING: + # Benchmark utils are only partially strict compliant, so MyPy won't follow + # imports using the public namespace. 
(Due to an exclusion rule in + # mypy-strict.ini) + from torch.utils.benchmark.utils.timer import Language, Timer + from torch.utils.benchmark.utils.valgrind_wrapper.timer_interface import CallgrindStats + +else: + from torch.utils.benchmark import CallgrindStats, Language, Timer + + +WORKER_PATH = os.path.abspath(__file__) + + +# ============================================================================= +# == Interface ================================================================ +# ============================================================================= + +# While the point of this is mainly to collect instruction counts, we're going +# to have to compile C++ timers anyway (as they're used as a check before +# calling Valgrind), so we may as well grab wall times for reference. They +# are comparatively inexpensive. +MIN_RUN_TIME = 5 + +# Repeats are inexpensive as long as they are all run in the same process. This +# also lets us filter outliers (e.g. malloc arena reorganization), so we don't +# need a high CALLGRIND_NUMBER to get good data. +CALLGRIND_NUMBER = 100 +CALLGRIND_REPEATS = 5 + + +@dataclasses.dataclass(frozen=True) +class WorkerTimerArgs: + """Container for Timer constructor arguments. + + This dataclass serves two roles. First, it is a simple interface for + defining benchmarks. (See core.api.GroupedStmts and core.api.GroupedModules + for the advanced interfaces.) Second, it provides serialization for + controlling workers. `Timer` is not pickleable, so instead the main process + will pass `WorkerTimerArgs` instances to workers for processing. + """ + stmt: str + setup: str = "pass" + global_setup: str = "" + num_threads: int = 1 + language: Language = Language.PYTHON + + +@dataclasses.dataclass(frozen=True) +class WorkerOutput: + # Only return values to reduce communication between main process and workers. + wall_times: Tuple[float, ...] + instructions: Tuple[int, ...] + + +@dataclasses.dataclass(frozen=True) +class WorkerFailure: + # If a worker fails, we attach the string contents of the Exception + # rather than the Exception object itself. This is done for two reasons: + # 1) Depending on the type thrown, `e` may or may not be pickleable + # 2) If we re-throw in the main process, we lose the true stack trace. + failure_trace: str + + +class WorkerUnpickler(pickle.Unpickler): + def find_class(self, module: str, name: str) -> Any: + """Resolve import for pickle. + + When the main runner uses a symbol `foo` from this file, it sees it as + `worker.main.foo`. However the worker (called as a standalone file) + sees the same symbol as `__main__.foo`. We have to help pickle + understand that they refer to the same symbols. + """ + symbol_map = { + # Only blessed interface Enums and dataclasses need to be mapped. 
+ "WorkerTimerArgs": WorkerTimerArgs, + "WorkerOutput": WorkerOutput, + "WorkerFailure": WorkerFailure, + } + + if name in symbol_map: + return symbol_map[name] + + return super().find_class(module, name) + + def load_input(self) -> WorkerTimerArgs: + result = self.load() + assert isinstance(result, WorkerTimerArgs) + return result + + def load_output(self) -> Union[WorkerTimerArgs, WorkerOutput, WorkerFailure]: + """Convenience method for type safe loading.""" + result = self.load() + assert isinstance(result, (WorkerTimerArgs, WorkerOutput, WorkerFailure)) + return result + + +# ============================================================================= +# == Execution ================================================================ +# ============================================================================= + +def _run(timer_args: WorkerTimerArgs) -> WorkerOutput: + timer = Timer( + stmt=timer_args.stmt, + setup=timer_args.setup or "pass", + global_setup=timer_args.global_setup, + + # Prevent NotImplementedError on GPU builds and C++ snippets. + timer=timeit.default_timer, + num_threads=timer_args.num_threads, + language=timer_args.language, + ) + + m = timer.blocked_autorange(min_run_time=MIN_RUN_TIME) + + stats: Tuple[CallgrindStats, ...] = timer.collect_callgrind( + number=CALLGRIND_NUMBER, + collect_baseline=False, + repeats=CALLGRIND_REPEATS, + retain_out_file=False, + ) + + return WorkerOutput( + wall_times=tuple(m.times), + instructions=tuple(s.counts(denoise=True) for s in stats) + ) + + +def main(communication_file: str) -> None: + result: Union[WorkerOutput, WorkerFailure] + try: + with open(communication_file, "rb") as f: + timer_args: WorkerTimerArgs = WorkerUnpickler(f).load_input() + assert isinstance(timer_args, WorkerTimerArgs) + result = _run(timer_args) + + except KeyboardInterrupt: + # Runner process sent SIGINT. + sys.exit() + + except BaseException: + trace_f = io.StringIO() + traceback.print_exc(file=trace_f) + result = WorkerFailure(failure_trace=trace_f.getvalue()) + + if not os.path.exists(os.path.split(communication_file)[0]): + # This worker is an orphan, and the parent has already cleaned up the + # working directory. In that case we can simply exit. + print(f"Orphaned worker {os.getpid()} exiting.") + return + + with open(communication_file, "wb") as f: + pickle.dump(result, f) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--communication_file', type=str) + communication_file = parser.parse_args().communication_file + main(communication_file) diff --git a/mypy-strict.ini b/mypy-strict.ini index e41ca8bf03..4608f7cf70 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -36,6 +36,8 @@ strict_equality = True files = .github/scripts/generate_binary_build_matrix.py, + benchmarks/instruction_counts/*.py, + benchmarks/instruction_counts/*/*.py, tools/autograd/*.py, tools/codegen/gen.py, tools/mypy_wrapper.py, |