author     Mingzhe Li <mingzhe0908@fb.com>    2019-04-02 17:03:23 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-04-02 17:06:19 -0700
commit     5f5a2aaab9d0b847df09e0a7f367e603ef6bcb2a (patch)
tree       e0c8f2dc6c4e829b241d47e3bf31064981ba4052
parent     b832b99afb241a8b6ea9fc34698d1f3bfd451f00 (diff)
Operator-level performance microbenchmarks (#18740)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18740

Test utilities for writing Caffe2/PyTorch performance microbenchmarks. Brief description of the file structure:
* benchmark_core.py : core utilities for running microbenchmark tests
* benchmark_caffe2.py : Caffe2-specific benchmark utilities
* benchmark_pytorch.py : PyTorch-specific benchmark utilities
* benchmark_runner.py : main entry point. Currently it can run the microbenchmark tests in a stand-alone mode. The next step is to integrate this with AI-PEP.

The utilities are located at https://github.com/pytorch/pytorch/tree/master/test to have access to both the Caffe2 and PyTorch Python frontends.

Two operator microbenchmarks are included, each supporting both Caffe2 and PyTorch:
* MatMul
* Add

Reference: PyTorch benchmarks: https://github.com/pytorch/benchmark/tree/master/timing/python. In this work we start with two example binary operators, MatMul and Add, but eventually we should cover unary operators like those in the PyTorch benchmark repo.

Reviewed By: zheng-xq

Differential Revision: D13887111

fbshipit-source-id: b7a56b95448c9ec3e674b0de0ffb96af4439bfce
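benchmark_runner.py is meant to be launched as a stand-alone module. A minimal example invocation, using only the flags defined in benchmark_runner.py below (the module path is an assumption based on the caffe2.benchmarks.operator_benchmark imports used throughout this diff; adjust it to wherever the benchmarks/ tree is importable from):

    python -m caffe2.benchmarks.operator_benchmark.benchmark_runner \
        --run_mode short --framework Caffe2 --operator matmul \
        --iterations 100 --warmup_iterations 10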
-rw-r--r--  benchmarks/operator_benchmark/__init__.py             0
-rw-r--r--  benchmarks/operator_benchmark/benchmark_caffe2.py     47
-rw-r--r--  benchmarks/operator_benchmark/benchmark_core.py       187
-rw-r--r--  benchmarks/operator_benchmark/benchmark_pytorch.py    29
-rw-r--r--  benchmarks/operator_benchmark/benchmark_runner.py     90
-rw-r--r--  benchmarks/operator_benchmark/benchmark_utils.py      35
-rw-r--r--  benchmarks/operator_benchmark/ops/__init__.py         0
-rw-r--r--  benchmarks/operator_benchmark/ops/add.py              68
-rw-r--r--  benchmarks/operator_benchmark/ops/matmul.py           63
-rw-r--r--  caffe2/python/pybind_state.cc                         20
-rw-r--r--  caffe2/python/workspace.py                            8
11 files changed, 547 insertions, 0 deletions
diff --git a/benchmarks/operator_benchmark/__init__.py b/benchmarks/operator_benchmark/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/benchmarks/operator_benchmark/__init__.py
diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py
new file mode 100644
index 0000000000..cf341c4fe8
--- /dev/null
+++ b/benchmarks/operator_benchmark/benchmark_caffe2.py
@@ -0,0 +1,47 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core, workspace
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+"""Caffe2 performance microbenchmarks.
+
+This module contains Caffe2-specific functionalities for performance
+microbenchmarks.
+"""
+
+
+def Caffe2OperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
+ """Benchmark Tester function for Caffe2 framework.
+ test_case is expected to be a Caffe2OperatorTestCase object. If not, the
+ function will return False.
+ It returns a function that contains the code to benchmarked
+ (operator execution).
+ """
+ idx = 0
+ input_blobs = []
+ for input in input_shapes:
+ blob_name = 'input_' + test_name + str(input_shapes) + str(op_args) + str(idx)
+ input_blobs.append(blob_name)
+        # TODO: figure out the data type from the operator schema, or accept
+        # a custom data type for more comprehensive coverage.
+        # Also, consider a more complex range/distribution of numerical inputs.
+ workspace.FeedBlob(blob_name, benchmark_utils.numpy_random_fp32(*input))
+ idx += 1
+
+    # TODO: consider reusing the logic in Caffe2's Functional utility to get
+    # these benefits:
+    # - Read the operator schema to figure out whether in-place enforcement is
+    #   needed for the operator and name the output blob appropriately.
+    # - Figure out the number of outputs from the operator schema.
+ op = core.CreateOperator(
+ op_type, input_blobs, ['out'], **op_args
+ )
+
+ def benchmark_func(num_runs):
+ if not workspace.RunOperatorMultiple(op, num_runs):
+            raise RuntimeError('Unable to run operator test case %s' % test_name)
+
+ benchmark_core.add_benchmark_tester("Caffe2", test_name, input_shapes, op_args, run_mode, benchmark_func)
diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py
new file mode 100644
index 0000000000..2693f8461b
--- /dev/null
+++ b/benchmarks/operator_benchmark/benchmark_core.py
@@ -0,0 +1,187 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import functools
+import numpy as np
+import timeit
+import json
+
+from caffe2.benchmarks.operator_benchmark import benchmark_utils
+
+"""Performance microbenchmarks.
+
+This module contains core functionalities for performance microbenchmark tests.
+"""
+
+
+# List of run modes we support.
+# Each benchmark test case is associated with a run mode.
+# A test case is executed when the value of its run mode does not exceed the
+# value of the benchmark binary's run mode: a short-mode test case is executed
+# when the binary runs in either short or long mode, while a long-mode test
+# case is only executed when the binary runs in long mode.
+RUN_MODES = {'short': 0, 'long': 1}
+BENCHMARK_TESTER = [{} for _ in range(len(RUN_MODES))]
+BENCHMARK_TEST_GROUP = {}
+
+
+def add_benchmark_tester(framework, op_name, input_shapes, op_args, run_mode, func):
+ func_name = "__".join([framework, op_name, benchmark_utils.shape_to_string(input_shapes)
+ , str(op_args), run_mode])
+ run_mode = RUN_MODES[run_mode]
+    for mode in RUN_MODES.values():
+        # A test case registered in short mode runs under both short and long
+        # binary modes; a long-mode test case only runs under long mode.
+        if (mode < run_mode):
+            continue
+        BENCHMARK_TESTER[mode][func_name] = func
+
+
+def benchmark_test_group(func):
+ """Decorator to register a benchmark test group.
+ A benchmark test group is a function that returns a list of benchmark test
+ case objects to be run.
+ """
+    BENCHMARK_TEST_GROUP[func.__module__ + "." + func.__name__] = func
+ return func
+
+
+HEADER_LINE = """
+# {}
+# PyTorch/Caffe2 Operator Micro-benchmarks
+# {}
+# Run_mode : {}
+"""
+
+
+class BenchmarkRunner(object):
+ """BenchmarkRunner is responsible for benchmarking all the registered
+ benchmark test groups.
+
+ Attributes:
+        run_mode (str): Must be one of 'short', 'long'. In long mode, the
+        benchmark runner takes longer to run since it also executes the
+        longer-running test cases that are marked as long mode.
+        operator (str): Only run benchmark test cases that contain
+        this filter string in the test case's id.
+ """
+ def __init__(self, args):
+        # Depending on the run mode, set the execution constraints: the
+        # number of runs per measurement and the number of measurements.
+ # TODO: consider time-bound constraints as well.
+ self.args = args
+ self.iters = 100
+ self.has_explicit_iteration_count = False
+ self.multiplier = 2
+ self.min_time = 0.8
+ self.max_iters = 1e6
+ for test_group in BENCHMARK_TEST_GROUP.items():
+ test_group_func = test_group[1]
+ test_group_func()
+ if self.args.iterations:
+ self.has_explicit_iteration_count = True
+ self.iters = self.args.iterations
+
+ def _print_header(self, run_mode):
+ DASH_LINE = '-' * 40
+ print(HEADER_LINE.format(DASH_LINE, DASH_LINE, self.args.run_mode, self.iters))
+ print("# List of Operators to run:")
+ if self.args.operator is None:
+ ops = set()
+ for tester in BENCHMARK_TESTER[run_mode].items():
+ full_test_id = tester[0]
+ framework, op_name, input_shapes, args, run_mode = full_test_id.split("__")
+ if op_name not in ops:
+ print("# {}".format(op_name))
+ ops.add(op_name)
+ else:
+ print("# {}".format(self.args.operator))
+ print("\n")
+
+ def _print_perf_result(self, full_test_id, input_shapes, args, reported_run_time):
+ if self.args.ai_pep_format:
+ # Output for AI-PEP
+ print("Caffe2Observer " + json.dumps(
+ {
+ "type": "NET",
+ "metric": full_test_id,
+ "unit": "ms",
+ "value": str(reported_run_time),
+ }
+ ))
+ else:
+ print("# Input Shape: {}\n"
+ "Execution Time (us) : {:.3f} \n"
+ .format(input_shapes, reported_run_time))
+
+ def _predict_num_iter_needed(self, i):
+ return (i * self.multiplier)
+
+    def _report_iteration_result(self, iters, run_time):
+        # The measurement is considered significant once either the iteration
+        # cap or the minimum total run time has been reached.
+        return (iters > self.max_iters or
+                run_time > 5 * self.min_time)
+
+ def run(self):
+ run_mode = RUN_MODES[self.args.run_mode]
+ self._print_header(run_mode)
+
+ if self.args.list_tests:
+ return
+
+ for tester in BENCHMARK_TESTER[run_mode].items():
+ full_test_id = tester[0]
+ benchmark_func = tester[1]
+ framework, op_name, input_shapes, args, run_mode = full_test_id.split("__")
+ # TODO: consider regex matching for test filtering.
+ # Currently, this is a sub-string matching.
+ if self.args.operator and (self.args.operator not in full_test_id):
+ continue
+ if self.args.framework and (self.args.framework not in full_test_id):
+ continue
+
+            # To reduce variance, fix a numpy random seed per test case so
+            # that the randomly generated input tensors remain the same
+            # across repeated runs of the same test case.
+            # The seed is masked to 32 bits because numpy requires a
+            # 32-bit seed.
+ np.random.seed(seed=hash(full_test_id) & ((1 << 32) - 1))
+
+ print("# Benchmarking {} {}".format(
+ framework,
+ op_name))
+ # Warmup
+            benchmark_func(self.args.warmup_iterations)
+
+ # Actual Execution
+ run_time = 0
+ iters = self.iters
+ while True:
+ # Use Python's timeit module to measure execution time.
+ # Each experiment consists of repeated execution of
+ # the benchmark_func a number of times (self.iters)
+ # because otherwise the duration is too short to get
+ # an accurate measure. The benchmark loop is pushed
+ # to C++ to minimize Python overhead.
+                # The outer loop below repeats the experiment, doubling the
+                # iteration count each pass, and accumulates the wall time
+                # until the total is long enough to be significant.
+ run_time = run_time + min(timeit.repeat(functools.partial(benchmark_func, iters),
+ repeat=1, number=1))
+ # Analyze time after each run to decide if the result is stable
+ results_are_significant = self.has_explicit_iteration_count or \
+ self._report_iteration_result(iters, run_time)
+
+ if results_are_significant:
+ break
+
+ # Re-estimate the hopefully-sufficient
+ # iteration count, and run the benchmark again...
+ iters = self._predict_num_iter_needed(iters)
+
+ reported_run_time = (1e6 * run_time / iters)
+ self._print_perf_result(full_test_id, input_shapes, args, reported_run_time)
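The adaptive timing strategy in BenchmarkRunner.run() above can be summarized by the following self-contained sketch (names and defaults mirror the class attributes; the helper itself is illustrative, not part of the benchmark API):

    import timeit

    def measure(benchmark_func, start_iters=100, multiplier=2,
                min_time=0.8, max_iters=1e6):
        """Double the iteration count until the accumulated wall time is
        significant, then report the time per iteration in microseconds."""
        iters = start_iters
        run_time = 0.0
        while True:
            # benchmark_func(iters) executes the operator `iters` times
            # (in C++ for Caffe2, inside a scripted loop for PyTorch).
            run_time += min(timeit.repeat(
                lambda: benchmark_func(iters), repeat=1, number=1))
            if iters > max_iters or run_time > 5 * min_time:
                break
            iters *= multiplier
        return 1e6 * run_time / iters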
diff --git a/benchmarks/operator_benchmark/benchmark_pytorch.py b/benchmarks/operator_benchmark/benchmark_pytorch.py
new file mode 100644
index 0000000000..5f30542231
--- /dev/null
+++ b/benchmarks/operator_benchmark/benchmark_pytorch.py
@@ -0,0 +1,29 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+import torch
+
+"""PyTorch performance microbenchmarks.
+
+This module contains PyTorch-specific functionalities for performance
+microbenchmarks.
+"""
+
+
+def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode):
+ """Benchmark Tester function for Pytorch framework.
+ test_case is expected to be a PyTorchOperatorTestCase object. If not, the
+ function will return False.
+ It returns a function that contains the code to benchmarked
+ (operator execution).
+ """
+ inputs = [torch.from_numpy(benchmark_utils.numpy_random_fp32(*input)) for input in input_shapes]
+
+ def benchmark_func(num_runs):
+ op_type(*(inputs + [num_runs]))
+
+ benchmark_core.add_benchmark_tester("PyTorch", test_name, input_shapes, op_args, run_mode, benchmark_func)
diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py
new file mode 100644
index 0000000000..5e06a5a9e8
--- /dev/null
+++ b/benchmarks/operator_benchmark/benchmark_runner.py
@@ -0,0 +1,90 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import sys
+import argparse
+
+from caffe2.python import workspace
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core
+
+import caffe2.benchmarks.operator_benchmark.benchmark_caffe2
+import caffe2.benchmarks.operator_benchmark.benchmark_pytorch
+
+import caffe2.benchmarks.operator_benchmark.ops.add
+import caffe2.benchmarks.operator_benchmark.ops.matmul
+
+"""Performance microbenchmarks's main binary.
+
+This is the main function for running performance microbenchmark tests.
+It also registers existing benchmark tests via Python module imports.
+"""
+
+
+if __name__ == "__main__":
+ print("Python version " + str(sys.version_info[0]))
+
+ parser = argparse.ArgumentParser(
+ description="Run microbenchmarks.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ '--run_mode',
+        help='Run mode. '
+        'short: run all operators with a few shapes; '
+        'long: run all operators with all shapes.',
+ choices=benchmark_core.RUN_MODES.keys(),
+ default='short')
+
+ # This option is used to filter test cases to run.
+    # Currently, the matching is sub-string, but we could consider supporting
+    # regex. For example, if the operator filter is 'matmul', it will match
+    # test cases such as:
+ # matmul_benchmark.Caffe2OperatorTestCase.matmul_512_128_512_transa_transb
+ # matmul_benchmark.PyTorchOperatorTestCase.matmul_100_200_150
+ # ...
+ parser.add_argument(
+ '--operator',
+ help='Only run the test cases that contain the provided operator'
+ ' as a substring of their names',
+ default=None)
+
+ parser.add_argument(
+ '--list_tests',
+ help='List all test cases without running them',
+ action='store_true')
+
+ parser.add_argument(
+ "--iterations",
+ help="Repeat each operator for the number of iterations",
+ type=int
+ )
+
+ parser.add_argument(
+ "--warmup_iterations",
+ help="Number of iterations to ignore before measuring performance",
+ default=10,
+ type=int
+ )
+
+    parser.add_argument(
+        "--ai_pep_format",
+        help="Print results in the format expected by AI-PEP",
+        action='store_true'
+    )
+
+ parser.add_argument(
+ '--framework',
+        help='Only run test cases for the given framework (Caffe2 or PyTorch)',
+ default=None)
+
+ args = parser.parse_args()
+
+ workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
+ workspace.ClearGlobalNetObserver()
+
+ benchmark_core.BenchmarkRunner(args).run()
diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py
new file mode 100644
index 0000000000..e0d5231312
--- /dev/null
+++ b/benchmarks/operator_benchmark/benchmark_utils.py
@@ -0,0 +1,35 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import numpy as np
+import itertools
+import random
+
+
+"""Performance microbenchmarks's utils.
+
+This module contains utilities for writing microbenchmark tests.
+"""
+
+
+def shape_to_string(shape):
+ return ', '.join([str(x) for x in shape])
+
+
+def numpy_random_fp32(*shape):
+ """Return a random numpy tensor of float32 type.
+ """
+ # TODO: consider more complex/custom dynamic ranges for
+ # comprehensive test coverage.
+ return np.random.rand(*shape).astype(np.float32)
+
+
+def cross_product(*inputs):
+    """Return the Cartesian product of the input iterables as a list."""
+    return list(itertools.product(*inputs))
+
+
+def get_n_rand_nums(min_val, max_val, n):
+    """Return n distinct random integers in [min_val, max_val), using a fixed
+    seed so that the generated shapes are reproducible across runs.
+    """
+    random.seed((1 << 32) - 1)
+    return random.sample(range(min_val, max_val), n)
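A small usage sketch of the helpers above, mirroring how generate_inputs() in ops/add.py combines shapes with run modes (the import path follows the one used elsewhere in this diff):

    from caffe2.benchmarks.operator_benchmark import benchmark_utils

    # Deterministic pools of random dimensions (the helper fixes its seed).
    ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=1)
    ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)

    # First build the 2-D shapes, then tag every shape with a run mode.
    two_dims = benchmark_utils.cross_product(ms, ns)
    test_cases = benchmark_utils.cross_product(two_dims, ['long'])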
diff --git a/benchmarks/operator_benchmark/ops/__init__.py b/benchmarks/operator_benchmark/ops/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/benchmarks/operator_benchmark/ops/__init__.py
diff --git a/benchmarks/operator_benchmark/ops/add.py b/benchmarks/operator_benchmark/ops/add.py
new file mode 100644
index 0000000000..23d208cf47
--- /dev/null
+++ b/benchmarks/operator_benchmark/ops/add.py
@@ -0,0 +1,68 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
+from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
+
+import torch
+
+
+"""Microbenchmarks for element-wise Add operator. Supports both Caffe2/PyTorch."""
+
+# Input shapes that we test and the run mode for each shape.
+# Sum up two tensors with the same shape
+
+
+def generate_inputs():
+ ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=1)
+ ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+ ks = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+ mode = ['long']
+
+ test_cases = benchmark_utils.cross_product([ms], mode)
+
+ two_dims = benchmark_utils.cross_product(ms, ns)
+ two_dims = benchmark_utils.cross_product(two_dims, mode)
+ test_cases.extend(two_dims)
+
+ three_dims = benchmark_utils.cross_product(ms, ns, ks)
+ three_dims = benchmark_utils.cross_product(three_dims, mode)
+ test_cases.extend(three_dims)
+
+ # Representative inputs
+ test_cases.extend([([128], 'short'),
+ ([64, 128], 'short'),
+ ([32, 64, 128], 'short')])
+ return test_cases
+
+
+@torch.jit.script
+def torch_add(a, b, iterations):
+    # type: (Tensor, Tensor, int) -> Tensor
+ result = torch.jit.annotate(torch.Tensor, None)
+ for _ in range(iterations):
+ result = torch.add(a, b)
+ return result
+
+
+@benchmark_core.benchmark_test_group
+def add_test_cases():
+ test_cases = generate_inputs()
+ for test_case in test_cases:
+ X, run_mode = test_case
+ Caffe2OperatorTestCase(
+ test_name='add',
+ op_type='Add',
+ input_shapes=[X, X],
+ op_args={},
+ run_mode=run_mode)
+ PyTorchOperatorTestCase(
+ test_name='add',
+ op_type=torch_add,
+ input_shapes=[X, X],
+ op_args={},
+ run_mode=run_mode)
diff --git a/benchmarks/operator_benchmark/ops/matmul.py b/benchmarks/operator_benchmark/ops/matmul.py
new file mode 100644
index 0000000000..214e2a5eac
--- /dev/null
+++ b/benchmarks/operator_benchmark/ops/matmul.py
@@ -0,0 +1,63 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
+
+from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
+from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
+
+import torch
+
+
+"""Microbenchmarks for MatMul operator. Supports both Caffe2/PyTorch."""
+
+
+def generate_inputs():
+ # Random inputs
+ Ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+ Ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+ Ks = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
+ transpose_a = [False, True]
+ transpose_b = [True, False]
+ mode = ['long']
+ test_cases = benchmark_utils.cross_product(Ms, Ns, Ks, transpose_a, transpose_b, mode)
+
+ # Representative inputs
+ test_cases.extend([(8, 16, 64, False, False, 'short'),
+ (64, 64, 256, False, False, 'short'),
+ (256, 256, 256, False, False, 'short')])
+ return test_cases
+
+
+@torch.jit.script
+def torch_matmul(a, b, iterations):
+    # type: (Tensor, Tensor, int) -> Tensor
+ result = torch.jit.annotate(torch.Tensor, None)
+ for _ in range(iterations):
+ result = torch.matmul(a, b)
+ return result
+
+
+@benchmark_core.benchmark_test_group
+def matmul_test_cases():
+ test_cases = generate_inputs()
+ for test_case in test_cases:
+ M, N, K, trans_a, trans_b, run_mode = test_case
+ input_shapes = [(N, M) if trans_a else (M, N), (K, N) if trans_b else (N, K)]
+ Caffe2OperatorTestCase(
+ test_name='matmul',
+ op_type='MatMul',
+ input_shapes=input_shapes,
+ op_args={'trans_a': trans_a, 'trans_b': trans_b},
+ run_mode=run_mode)
+ if not trans_a and not trans_b:
+ # PyTorch's matmul does not take transpose flags, so we only
+ # have a test case when there are no transpose flags.
+ PyTorchOperatorTestCase(
+ test_name='matmul',
+ op_type=torch_matmul,
+ input_shapes=input_shapes,
+ op_args={},
+ run_mode=run_mode)
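For illustration, the shape convention in matmul_test_cases above treats N as the contraction dimension, i.e. the product is C(M, K) = A(M, N) x B(N, K), and the transpose flags only change how A and B are stored (example values chosen arbitrarily):

    # M, N, K, trans_a, trans_b = 8, 16, 64, True, False
    # input_shapes == [(16, 8), (16, 64)]   # A stored as (N, M), B as (N, K)
    # Caffe2's MatMul is created with trans_a=1 and produces an (8, 64) output.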
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index e4f3e6fa58..dc2b339a84 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -1190,6 +1190,10 @@ void addGlobalMethods(py::module& m) {
NetBase* net = gWorkspace->GetNet(net_name);
net->DetachObserver(observer);
});
+ m.def("clear_global_net_observer", []() {
+ py::gil_scoped_release g;
+ caffe2::ClearGlobalNetObservers();
+ });
m.def("num_observers_on_net", [](const std::string& net_name) {
CAFFE_ENFORCE(gWorkspace);
CAFFE_ENFORCE(gWorkspace->GetNet(net_name), "Can't find net ", net_name);
@@ -1227,6 +1231,22 @@ void addGlobalMethods(py::module& m) {
CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
return true;
});
+ // Run an operator multiple times.
+ // This is needed for microbenchmarking as we want the benchmark loop to be in
+ // C++ to minimize overhead.
+ m.def("run_operator_multiple", [](const py::bytes& op_def, int num_runs) {
+ CAFFE_ENFORCE(gWorkspace);
+ OperatorDef def;
+ CAFFE_ENFORCE(ParseProtoFromLargeString(op_def.cast<std::string>(), &def));
+ py::gil_scoped_release g;
+ std::unique_ptr<OperatorBase> op(CreateOperator(def, gWorkspace));
+ for (int i = 0; i < num_runs; i++) {
+ if (!op->Run()) {
+ return false;
+ }
+ }
+ return true;
+ });
m.def(
"get_operator_cost",
[](const py::bytes& op_def, const std::vector<string>& input_blobs) {
diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py
index 342bdfcf8a..18fcd9bb42 100644
--- a/caffe2/python/workspace.py
+++ b/caffe2/python/workspace.py
@@ -185,6 +185,10 @@ def RunOperatorOnce(operator):
return C.run_operator_once(StringifyProto(operator))
+def RunOperatorMultiple(operator, num_runs):
+ return C.run_operator_multiple(StringifyProto(operator), num_runs)
+
+
def RunOperatorsOnce(operators):
for op in operators:
success = RunOperatorOnce(op)
@@ -193,6 +197,10 @@ def RunOperatorsOnce(operators):
return True
+def ClearGlobalNetObserver():
+ return C.clear_global_net_observer()
+
+
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
try:
return func(*args, **kwargs)
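To close the loop, here is a minimal Python-side sketch of the new RunOperatorMultiple() entry point, mirroring how benchmark_caffe2.py drives it (blob names and shapes are illustrative):

    import numpy as np
    from caffe2.python import core, workspace

    workspace.FeedBlob('x', np.random.rand(64, 128).astype(np.float32))
    workspace.FeedBlob('y', np.random.rand(64, 128).astype(np.float32))
    op = core.CreateOperator('Add', ['x', 'y'], ['out'])
    # The loop over num_runs runs inside the C++ binding added in
    # pybind_state.cc, keeping Python overhead out of the measured region.
    assert workspace.RunOperatorMultiple(op, 1000)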