summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMingzhe Li <mingzhe0908@fb.com>2019-04-18 17:03:56 -0700
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-04-18 17:07:02 -0700
commit45d5b6be48caae761970d6d4c99e4ed8bc82263e (patch)
treebdeb64abe7cf308e0de6d33c153e0130a21fb209
parentedf77fe64ae2f9becfbc1848721b096a14bcd820 (diff)
downloadpytorch-45d5b6be48caae761970d6d4c99e4ed8bc82263e.tar.gz
pytorch-45d5b6be48caae761970d6d4c99e4ed8bc82263e.tar.bz2
pytorch-45d5b6be48caae761970d6d4c99e4ed8bc82263e.zip
Enhance front-end to add op (#19433)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19433 For operator benchmark project, we need to cover a lot of operators, so the interface for adding operators needs to be very clean and simple. This diff is implementing a new interface to add op. Here is the logic to add new operator to the benchmark: ``` long_config = {} short_config = {} map_func add_test( [long_config, short_config], map_func, [caffe2 op] [pt op] ) ``` Reviewed By: zheng-xq Differential Revision: D14791191 fbshipit-source-id: ac6738507cf1b9d6013dc8e546a2022a9b177f05
-rw-r--r--benchmarks/operator_benchmark/benchmark_core.py2
-rw-r--r--benchmarks/operator_benchmark/benchmark_runner.py1
-rw-r--r--benchmarks/operator_benchmark/benchmark_test_generator.py72
-rw-r--r--benchmarks/operator_benchmark/benchmark_utils.py17
-rw-r--r--benchmarks/operator_benchmark/ops/add.py74
-rw-r--r--benchmarks/operator_benchmark/ops/matmul.py78
6 files changed, 157 insertions, 87 deletions
diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py
index f6ff591df5..34df22cbf0 100644
--- a/benchmarks/operator_benchmark/benchmark_core.py
+++ b/benchmarks/operator_benchmark/benchmark_core.py
@@ -40,7 +40,7 @@ def add_benchmark_tester(framework, op_name, input_shapes, op_args, run_mode, fu
BENCHMARK_TESTER[mode][func_name] = func
-def benchmark_test_group(func):
+def register_test(func):
"""Decorator to register a benchmark test group.
A benchmark test group is a function that returns a list of benchmark test
case objects to be run.
diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py
index 9f86006a99..86ed2dbd6b 100644
--- a/benchmarks/operator_benchmark/benchmark_runner.py
+++ b/benchmarks/operator_benchmark/benchmark_runner.py
@@ -12,6 +12,7 @@ from caffe2.benchmarks.operator_benchmark import benchmark_core
import caffe2.benchmarks.operator_benchmark.benchmark_caffe2
import caffe2.benchmarks.operator_benchmark.benchmark_pytorch
+import caffe2.benchmarks.operator_benchmark.benchmark_test_generator
import caffe2.benchmarks.operator_benchmark.ops.add
import caffe2.benchmarks.operator_benchmark.ops.matmul # noqa
diff --git a/benchmarks/operator_benchmark/benchmark_test_generator.py b/benchmarks/operator_benchmark/benchmark_test_generator.py
new file mode 100644
index 0000000000..f03e175002
--- /dev/null
+++ b/benchmarks/operator_benchmark/benchmark_test_generator.py
@@ -0,0 +1,72 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
+from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
+from caffe2.benchmarks.operator_benchmark.benchmark_utils import * # noqa
+
+
+def generate_test(configs, map_config, ops, OperatorTestCase):
+ """
+ This function is used to create PyTorch/Caffe2 operators based on configs.
+ configs usually include both long_config and short_config and they will be
+ mapped to input_shapes and args which are ready to be digested by an operator.
+ OperatorTestCase is used to create an operator with inputs/outputs and args.
+ """
+ for config in configs:
+ for case in config:
+ shapes_args_config = case[:-1]
+ mode = case[-1]
+ shapes_args = map_config(*shapes_args_config)
+ if shapes_args is not None:
+ for op in ops:
+ OperatorTestCase(
+ test_name=op[0],
+ op_type=op[1],
+ input_shapes=shapes_args[0],
+ op_args=shapes_args[1],
+ run_mode=mode)
+
+
+def generate_pt_test(configs, pt_map_func, pt_ops):
+ """
+ This function creates PyTorch operators which will be benchmarked.
+ """
+ generate_test(configs, pt_map_func, pt_ops, PyTorchOperatorTestCase)
+
+
+def generate_c2_test(configs, c2_map_func, c2_ops):
+ """
+ This function creates Caffe2 operators which will be benchmarked.
+ """
+ generate_test(configs, c2_map_func, c2_ops, Caffe2OperatorTestCase)
+
+
+def map_c2_config_add(M, N, K):
+ input_one = (M, N, K)
+ input_two = (M, N, K)
+ input_shapes = [input_one, input_two]
+ args = {}
+ return (input_shapes, args)
+
+map_pt_config_add = map_c2_config_add
+
+
+def map_c2_config_matmul(M, N, K, trans_a, trans_b):
+ input_one = (N, M) if trans_a else (M, N)
+ input_two = (K, N) if trans_b else (N, K)
+ input_shapes = [input_one, input_two]
+ args = {'trans_a': trans_a, 'trans_b': trans_b}
+ return (input_shapes, args)
+
+
+def map_pt_config_matmul(M, N, K, trans_a, trans_b):
+ input_one = (N, M) if trans_a else (M, N)
+ input_two = (K, N) if trans_b else (N, K)
+ input_shapes = [input_one, input_two]
+ args = {}
+ if not trans_a and not trans_b:
+ return (input_shapes, args)
+ return None
diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py
index e0d5231312..1dc6479351 100644
--- a/benchmarks/operator_benchmark/benchmark_utils.py
+++ b/benchmarks/operator_benchmark/benchmark_utils.py
@@ -27,9 +27,26 @@ def numpy_random_fp32(*shape):
def cross_product(*inputs):
+ """
+ Return a list of cartesian product of input iterables.
+ For example, cross_product(A, B) returns ((x,y) for x in A for y in B).
+ """
return (list(itertools.product(*inputs)))
def get_n_rand_nums(min_val, max_val, n):
random.seed((1 << 32) - 1)
return random.sample(range(min_val, max_val), n)
+
+
+def generate_configs(**configs):
+ """
+ Given configs from users, we want to generate different combinations of
+ those configs
+ For example, given M = ((1, 2), N = (4, 5)) and sample_func being cross_product,
+ we will generate ((1, 4), (1, 5), (2, 4), (2, 5))
+ """
+ assert 'sample_func' in configs, "Missing sample_func to generat configs"
+ results = configs['sample_func'](
+ *[value for key, value in configs.items() if key != 'sample_func'])
+ return results
diff --git a/benchmarks/operator_benchmark/ops/add.py b/benchmarks/operator_benchmark/ops/add.py
index 23d208cf47..4025c0c0e5 100644
--- a/benchmarks/operator_benchmark/ops/add.py
+++ b/benchmarks/operator_benchmark/ops/add.py
@@ -3,10 +3,8 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
-from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
-
-from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
-from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
+from caffe2.benchmarks.operator_benchmark import benchmark_core
+from caffe2.benchmarks.operator_benchmark.benchmark_test_generator import *
import torch
@@ -16,28 +14,23 @@ import torch
# Input shapes that we test and the run mode for each shape.
# Sum up two tensors with the same shape
-
-def generate_inputs():
- ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=1)
- ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
- ks = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
- mode = ['long']
-
- test_cases = benchmark_utils.cross_product([ms], mode)
-
- two_dims = benchmark_utils.cross_product(ms, ns)
- two_dims = benchmark_utils.cross_product(two_dims, mode)
- test_cases.extend(two_dims)
-
- three_dims = benchmark_utils.cross_product(ms, ns, ks)
- three_dims = benchmark_utils.cross_product(three_dims, mode)
- test_cases.extend(three_dims)
-
- # Representative inputs
- test_cases.extend([([128], 'short'),
- ([64, 128], 'short'),
- ([32, 64, 128], 'short')])
- return test_cases
+# Long config
+long_config = generate_configs(
+ M=get_n_rand_nums(min_val=1, max_val=128, n=2),
+ N=get_n_rand_nums(min_val=1, max_val=128, n=2),
+ K=get_n_rand_nums(min_val=1, max_val=128, n=2),
+ mode=['long'],
+ sample_func=cross_product,
+)
+
+# Short config
+short_config = generate_configs(
+ M=[8, 16],
+ N=[32, 64],
+ K=[64, 128],
+ mode=['short'],
+ sample_func=cross_product
+)
@torch.jit.script
@@ -49,20 +42,15 @@ def torch_add(a, b, iterations):
return result
-@benchmark_core.benchmark_test_group
-def add_test_cases():
- test_cases = generate_inputs()
- for test_case in test_cases:
- X, run_mode = test_case
- Caffe2OperatorTestCase(
- test_name='add',
- op_type='Add',
- input_shapes=[X, X],
- op_args={},
- run_mode=run_mode)
- PyTorchOperatorTestCase(
- test_name='add',
- op_type=torch_add,
- input_shapes=[X, X],
- op_args={},
- run_mode=run_mode)
+@benchmark_core.register_test
+def test_add():
+ generate_pt_test(
+ [long_config, short_config],
+ map_pt_config_add,
+ [('add', torch_add)]
+ )
+ generate_c2_test(
+ [long_config, short_config],
+ map_c2_config_add,
+ [('add', 'Add')],
+ )
diff --git a/benchmarks/operator_benchmark/ops/matmul.py b/benchmarks/operator_benchmark/ops/matmul.py
index 214e2a5eac..62a9b43f7a 100644
--- a/benchmarks/operator_benchmark/ops/matmul.py
+++ b/benchmarks/operator_benchmark/ops/matmul.py
@@ -3,32 +3,33 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
-from caffe2.benchmarks.operator_benchmark import benchmark_core, benchmark_utils
-
-from caffe2.benchmarks.operator_benchmark.benchmark_caffe2 import Caffe2OperatorTestCase
-from caffe2.benchmarks.operator_benchmark.benchmark_pytorch import PyTorchOperatorTestCase
-
+from caffe2.benchmarks.operator_benchmark import benchmark_core
+from caffe2.benchmarks.operator_benchmark.benchmark_test_generator import *
import torch
"""Microbenchmarks for MatMul operator. Supports both Caffe2/PyTorch."""
-
-
-def generate_inputs():
- # Random inputs
- Ms = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
- Ns = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
- Ks = benchmark_utils.get_n_rand_nums(min_val=1, max_val=128, n=2)
- transpose_a = [False, True]
- transpose_b = [True, False]
- mode = ['long']
- test_cases = benchmark_utils.cross_product(Ms, Ns, Ks, transpose_a, transpose_b, mode)
-
- # Representative inputs
- test_cases.extend([(8, 16, 64, False, False, 'short'),
- (64, 64, 256, False, False, 'short'),
- (256, 256, 256, False, False, 'short')])
- return test_cases
+# Long config
+long_config = generate_configs(
+ M=get_n_rand_nums(min_val=1, max_val=128, n=2),
+ N=get_n_rand_nums(min_val=1, max_val=128, n=2),
+ K=get_n_rand_nums(min_val=1, max_val=128, n=2),
+ transpose_a=[False, True],
+ transpose_b=[True, False],
+ mode=['long'],
+ sample_func=cross_product
+)
+
+# Short config
+short_config = generate_configs(
+ M=[8, 16],
+ N=[32, 64],
+ K=[64, 128],
+ transpose_a=[False, True],
+ transpose_b=[True, False],
+ mode=['short'],
+ sample_func=cross_product
+)
@torch.jit.script
@@ -40,24 +41,15 @@ def torch_matmul(a, b, iterations):
return result
-@benchmark_core.benchmark_test_group
-def matmul_test_cases():
- test_cases = generate_inputs()
- for test_case in test_cases:
- M, N, K, trans_a, trans_b, run_mode = test_case
- input_shapes = [(N, M) if trans_a else (M, N), (K, N) if trans_b else (N, K)]
- Caffe2OperatorTestCase(
- test_name='matmul',
- op_type='MatMul',
- input_shapes=input_shapes,
- op_args={'trans_a': trans_a, 'trans_b': trans_b},
- run_mode=run_mode)
- if not trans_a and not trans_b:
- # PyTorch's matmul does not take transpose flags, so we only
- # have a test case when there are no transpose flags.
- PyTorchOperatorTestCase(
- test_name='matmul',
- op_type=torch_matmul,
- input_shapes=input_shapes,
- op_args={},
- run_mode=run_mode)
+@benchmark_core.register_test
+def test_matmul():
+ generate_pt_test(
+ [long_config, short_config],
+ map_pt_config_matmul,
+ [('matmul', torch_matmul)]
+ )
+ generate_c2_test(
+ [long_config, short_config],
+ map_c2_config_matmul,
+ [('matmul', 'MatMul')],
+ )