summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/graph/Types.h3
-rw-r--r--arm_compute/graph/backends/CL/CLDeviceBackend.h12
-rw-r--r--examples/graph_alexnet.cpp3
-rw-r--r--examples/graph_deepspeech_v0_4_1.cpp3
-rw-r--r--examples/graph_edsr.cpp3
-rw-r--r--examples/graph_googlenet.cpp3
-rw-r--r--examples/graph_inception_resnet_v1.cpp3
-rw-r--r--examples/graph_inception_resnet_v2.cpp3
-rw-r--r--examples/graph_inception_v3.cpp3
-rw-r--r--examples/graph_inception_v4.cpp3
-rw-r--r--examples/graph_lenet.cpp3
-rw-r--r--examples/graph_mnist.cpp3
-rw-r--r--examples/graph_mobilenet.cpp4
-rw-r--r--examples/graph_mobilenet_v2.cpp3
-rw-r--r--examples/graph_resnet12.cpp3
-rw-r--r--examples/graph_resnet50.cpp3
-rw-r--r--examples/graph_resnet_v2_50.cpp3
-rw-r--r--examples/graph_resnext50.cpp3
-rw-r--r--examples/graph_shufflenet.cpp3
-rw-r--r--examples/graph_squeezenet.cpp3
-rw-r--r--examples/graph_squeezenet_v1_1.cpp3
-rw-r--r--examples/graph_srcnn955.cpp3
-rw-r--r--examples/graph_ssd_mobilenet.cpp3
-rw-r--r--examples/graph_vgg16.cpp3
-rw-r--r--examples/graph_vgg19.cpp3
-rw-r--r--examples/graph_vgg_vdsr.cpp3
-rw-r--r--examples/graph_yolov3.cpp3
-rw-r--r--src/graph/backends/CL/CLDeviceBackend.cpp8
-rw-r--r--tests/benchmark_examples/RunExample.cpp6
-rw-r--r--utils/CommonGraphOptions.cpp8
-rw-r--r--utils/CommonGraphOptions.h4
31 files changed, 79 insertions, 38 deletions
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h
index c5d3d17a9..b891c1772 100644
--- a/arm_compute/graph/Types.h
+++ b/arm_compute/graph/Types.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -87,6 +87,7 @@ struct GraphConfig
CLTunerMode tuner_mode{ CLTunerMode::EXHAUSTIVE }; /**< Tuner mode to be used by the CL tuner */
int num_threads{ -1 }; /**< Number of threads to use (thread capable backends), if 0 the backend will auto-initialize, if -1 the backend will stay as it is. */
std::string tuner_file{ "acl_tuner.csv" }; /**< File to load/store tuning values from */
+ std::string mlgo_file{ "heuristics.mlgo" }; /**< Filename to load MLGO heuristics from */
};
/**< Device target types */
diff --git a/arm_compute/graph/backends/CL/CLDeviceBackend.h b/arm_compute/graph/backends/CL/CLDeviceBackend.h
index a8ee25d7e..82c0eacd1 100644
--- a/arm_compute/graph/backends/CL/CLDeviceBackend.h
+++ b/arm_compute/graph/backends/CL/CLDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/graph/IDeviceBackend.h"
#include "arm_compute/runtime/CL/CLBufferAllocator.h"
+#include "arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h"
#include "arm_compute/runtime/CL/CLTuner.h"
namespace arm_compute
@@ -70,10 +71,11 @@ public:
std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
- int _context_count; /**< Counts how many contexts are currently using the backend */
- CLTuner _tuner; /**< CL kernel tuner */
- std::unique_ptr<CLBufferAllocator> _allocator; /**< CL buffer affinity allocator */
- std::string _tuner_file; /**< Filename to load/store the tuner's values from */
+ int _context_count; /**< Counts how many contexts are currently using the backend */
+ CLTuner _tuner; /**< CL kernel tuner */
+ CLGEMMHeuristicsHandle _gemm_heuristics; /**< GEMM heuristics */
+ std::unique_ptr<CLBufferAllocator> _allocator; /**< CL buffer affinity allocator */
+ std::string _tuner_file; /**< Filename to load/store the tuner's values from */
};
} // namespace backends
} // namespace graph
diff --git a/examples/graph_alexnet.cpp b/examples/graph_alexnet.cpp
index ce398be6c..7f4e75aaf 100644
--- a/examples/graph_alexnet.cpp
+++ b/examples/graph_alexnet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -159,6 +159,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
// Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed
// compilation won't be required.
diff --git a/examples/graph_deepspeech_v0_4_1.cpp b/examples/graph_deepspeech_v0_4_1.cpp
index 4a8a8b15a..a5658625c 100644
--- a/examples/graph_deepspeech_v0_4_1.cpp
+++ b/examples/graph_deepspeech_v0_4_1.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -211,6 +211,7 @@ public:
config.num_threads = common_params.threads;
config.use_tuner = common_params.enable_tuner;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_edsr.cpp b/examples/graph_edsr.cpp
index 77783d97e..0e41f1215 100644
--- a/examples/graph_edsr.cpp
+++ b/examples/graph_edsr.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2020-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -75,6 +75,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
context.set_config(config);
diff --git a/examples/graph_googlenet.cpp b/examples/graph_googlenet.cpp
index 0a5335561..7555d805c 100644
--- a/examples/graph_googlenet.cpp
+++ b/examples/graph_googlenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -130,6 +130,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_inception_resnet_v1.cpp b/examples/graph_inception_resnet_v1.cpp
index 7a55733a2..6ae5b5dc7 100644
--- a/examples/graph_inception_resnet_v1.cpp
+++ b/examples/graph_inception_resnet_v1.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -215,6 +215,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_inception_resnet_v2.cpp b/examples/graph_inception_resnet_v2.cpp
index 60236d078..ae37ee507 100644
--- a/examples/graph_inception_resnet_v2.cpp
+++ b/examples/graph_inception_resnet_v2.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -196,6 +196,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_inception_v3.cpp b/examples/graph_inception_v3.cpp
index 5cacbcb6e..8ceeb5c68 100644
--- a/examples/graph_inception_v3.cpp
+++ b/examples/graph_inception_v3.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -201,6 +201,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_inception_v4.cpp b/examples/graph_inception_v4.cpp
index db2a31047..cafa5c9f1 100644
--- a/examples/graph_inception_v4.cpp
+++ b/examples/graph_inception_v4.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -156,6 +156,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
// Load the precompiled kernels from a file into the kernel library, in this way the next time they are needed
diff --git a/examples/graph_lenet.cpp b/examples/graph_lenet.cpp
index e5783078f..6560a980c 100644
--- a/examples/graph_lenet.cpp
+++ b/examples/graph_lenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -111,6 +111,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_mnist.cpp b/examples/graph_mnist.cpp
index 85ab0ab97..4ef96cc59 100644
--- a/examples/graph_mnist.cpp
+++ b/examples/graph_mnist.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 Arm Limited.
+ * Copyright (c) 2019-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -140,6 +140,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_mobilenet.cpp b/examples/graph_mobilenet.cpp
index b73f7a2ab..09b6e6e09 100644
--- a/examples/graph_mobilenet.cpp
+++ b/examples/graph_mobilenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -100,6 +100,8 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_mobilenet_v2.cpp b/examples/graph_mobilenet_v2.cpp
index fa16c9464..b1b33be2f 100644
--- a/examples/graph_mobilenet_v2.cpp
+++ b/examples/graph_mobilenet_v2.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -91,6 +91,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_resnet12.cpp b/examples/graph_resnet12.cpp
index ebd2e5dd1..8818cf742 100644
--- a/examples/graph_resnet12.cpp
+++ b/examples/graph_resnet12.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -136,6 +136,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_resnet50.cpp b/examples/graph_resnet50.cpp
index 47d258ede..b585284c6 100644
--- a/examples/graph_resnet50.cpp
+++ b/examples/graph_resnet50.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -115,6 +115,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_resnet_v2_50.cpp b/examples/graph_resnet_v2_50.cpp
index 921fb145d..472bf02b4 100644
--- a/examples/graph_resnet_v2_50.cpp
+++ b/examples/graph_resnet_v2_50.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -118,6 +118,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_resnext50.cpp b/examples/graph_resnext50.cpp
index 1d9ed8dc8..ec87e0b88 100644
--- a/examples/graph_resnext50.cpp
+++ b/examples/graph_resnext50.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -102,6 +102,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_shufflenet.cpp b/examples/graph_shufflenet.cpp
index 300d0f15a..f90f36149 100644
--- a/examples/graph_shufflenet.cpp
+++ b/examples/graph_shufflenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -148,6 +148,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_squeezenet.cpp b/examples/graph_squeezenet.cpp
index 2e72c1476..3d32794e8 100644
--- a/examples/graph_squeezenet.cpp
+++ b/examples/graph_squeezenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -168,6 +168,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_squeezenet_v1_1.cpp b/examples/graph_squeezenet_v1_1.cpp
index 1708ac2f5..6d4ffee99 100644
--- a/examples/graph_squeezenet_v1_1.cpp
+++ b/examples/graph_squeezenet_v1_1.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -168,6 +168,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_srcnn955.cpp b/examples/graph_srcnn955.cpp
index bcc3824c6..f4ffc0213 100644
--- a/examples/graph_srcnn955.cpp
+++ b/examples/graph_srcnn955.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -119,6 +119,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_ssd_mobilenet.cpp b/examples/graph_ssd_mobilenet.cpp
index f5af84f4d..c0859227a 100644
--- a/examples/graph_ssd_mobilenet.cpp
+++ b/examples/graph_ssd_mobilenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -97,6 +97,7 @@ public:
config.num_threads = common_params.threads;
config.use_tuner = common_params.enable_tuner;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp
index a4c5e6bbd..83e663798 100644
--- a/examples/graph_vgg16.cpp
+++ b/examples/graph_vgg16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -216,6 +216,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_vgg19.cpp b/examples/graph_vgg19.cpp
index c95fb0336..03f7e1606 100644
--- a/examples/graph_vgg19.cpp
+++ b/examples/graph_vgg19.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -227,6 +227,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_vgg_vdsr.cpp b/examples/graph_vgg_vdsr.cpp
index 3fa7dd133..bdb898081 100644
--- a/examples/graph_vgg_vdsr.cpp
+++ b/examples/graph_vgg_vdsr.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -140,6 +140,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
config.convert_to_uint8 = (common_params.data_type == DataType::QASYMM8);
graph.finalize(common_params.target, config);
diff --git a/examples/graph_yolov3.cpp b/examples/graph_yolov3.cpp
index 54aaf201c..3c8ddbffd 100644
--- a/examples/graph_yolov3.cpp
+++ b/examples/graph_yolov3.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -401,6 +401,7 @@ public:
config.use_tuner = common_params.enable_tuner;
config.tuner_mode = common_params.tuner_mode;
config.tuner_file = common_params.tuner_file;
+ config.mlgo_file = common_params.mlgo_file;
graph.finalize(common_params.target, config);
diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp
index 50dd799ee..f8e22ca7a 100644
--- a/src/graph/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph/backends/CL/CLDeviceBackend.cpp
@@ -65,7 +65,7 @@ bool file_exists(const std::string &filename)
static detail::BackendRegistrar<CLDeviceBackend> CLDeviceBackend_registrar(Target::CL);
CLDeviceBackend::CLDeviceBackend()
- : _context_count(0), _tuner(), _allocator(nullptr), _tuner_file()
+ : _context_count(0), _tuner(), _gemm_heuristics(), _allocator(nullptr), _tuner_file()
{
}
@@ -87,7 +87,7 @@ void CLDeviceBackend::set_kernel_tuning_mode(CLTunerMode tuning_mode)
void CLDeviceBackend::initialize_backend()
{
// Setup Scheduler
- CLScheduler::get().default_init(&_tuner);
+ CLScheduler::get().default_init(&_tuner, &_gemm_heuristics);
// Create allocator with new context
_allocator = std::make_unique<CLBufferAllocator>(nullptr /* legacy path for CLCoreRuntimeContext */);
}
@@ -123,6 +123,10 @@ void CLDeviceBackend::setup_backend_context(GraphContext &ctx)
set_kernel_tuning(ctx.config().use_tuner);
set_kernel_tuning_mode(ctx.config().tuner_mode);
+ // Attempt to load mlgo heuristics
+ ARM_COMPUTE_ERROR_ON(CLScheduler::get().gemm_heuristics() == nullptr);
+ CLScheduler::get().gemm_heuristics()->reload_from_file(ctx.config().mlgo_file);
+
// Setup a management backend
if(ctx.memory_management_ctx(Target::CL) == nullptr)
{
diff --git a/tests/benchmark_examples/RunExample.cpp b/tests/benchmark_examples/RunExample.cpp
index 925daaf15..8adcd95ff 100644
--- a/tests/benchmark_examples/RunExample.cpp
+++ b/tests/benchmark_examples/RunExample.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,6 +34,7 @@
#include "utils/command_line/CommandLineParser.h"
#ifdef ARM_COMPUTE_CL
+#include "arm_compute/runtime/CL/CLGEMMHeuristicsHandle.h"
#include "arm_compute/runtime/CL/CLHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#endif /* ARM_COMPUTE_CL */
@@ -127,12 +128,13 @@ int run_example(int argc, char **argv, std::unique_ptr<Example> example)
}
#ifdef ARM_COMPUTE_CL
+ CLGEMMHeuristicsHandle gemm_h;
if(opencl_is_available())
{
auto ctx_dev_err = create_opencl_context_and_device();
ARM_COMPUTE_ERROR_ON_MSG(std::get<2>(ctx_dev_err) != CL_SUCCESS, "Failed to create OpenCL context");
CLScheduler::get()
- .default_init_with_context(std::get<1>(ctx_dev_err), std::get<0>(ctx_dev_err));
+ .default_init_with_context(std::get<1>(ctx_dev_err), std::get<0>(ctx_dev_err), nullptr, &gemm_h);
}
#endif /* ARM_COMPUTE_CL */
diff --git a/utils/CommonGraphOptions.cpp b/utils/CommonGraphOptions.cpp
index d262ea86e..44d66fa91 100644
--- a/utils/CommonGraphOptions.cpp
+++ b/utils/CommonGraphOptions.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -87,6 +87,7 @@ namespace utils
os << "Cache enabled? : " << (common_params.enable_cl_cache ? true_str : false_str) << std::endl;
os << "Tuner mode : " << common_params.tuner_mode << std::endl;
os << "Tuner file : " << common_params.tuner_file << std::endl;
+ os << "MLGO file : " << common_params.mlgo_file << std::endl;
os << "Fast math enabled? : " << (common_params.fast_math_hint == FastMathHint::Enabled ? true_str : false_str) << std::endl;
if(!common_params.data_path.empty())
{
@@ -129,7 +130,8 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
validation_file(parser.add_option<SimpleOption<std::string>>("validation-file")),
validation_path(parser.add_option<SimpleOption<std::string>>("validation-path")),
validation_range(parser.add_option<SimpleOption<std::string>>("validation-range")),
- tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file"))
+ tuner_file(parser.add_option<SimpleOption<std::string>>("tuner-file")),
+ mlgo_file(parser.add_option<SimpleOption<std::string>>("mlgo-file"))
{
std::set<arm_compute::graph::Target> supported_targets
{
@@ -183,6 +185,7 @@ CommonGraphOptions::CommonGraphOptions(CommandLineParser &parser)
validation_path->set_help("Path to the validation data");
validation_range->set_help("Range of the images to validate for (Format : start,end)");
tuner_file->set_help("File to load/save CLTuner values");
+ mlgo_file->set_help("File to load MLGO heuristics");
}
CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options)
@@ -211,6 +214,7 @@ CommonGraphParams consume_common_graph_parameters(CommonGraphOptions &options)
common_params.validation_range_start = validation_range.first;
common_params.validation_range_end = validation_range.second;
common_params.tuner_file = options.tuner_file->value();
+ common_params.mlgo_file = options.mlgo_file->value();
return common_params;
}
diff --git a/utils/CommonGraphOptions.h b/utils/CommonGraphOptions.h
index dac2e10b1..13cd653e4 100644
--- a/utils/CommonGraphOptions.h
+++ b/utils/CommonGraphOptions.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -108,6 +108,7 @@ struct CommonGraphParams
std::string validation_file{};
std::string validation_path{};
std::string tuner_file{};
+ std::string mlgo_file{};
unsigned int validation_range_start{ 0 };
unsigned int validation_range_end{ std::numeric_limits<unsigned int>::max() };
};
@@ -165,6 +166,7 @@ public:
SimpleOption<std::string> *validation_path; /**< Validation data path */
SimpleOption<std::string> *validation_range; /**< Validation range */
SimpleOption<std::string> *tuner_file; /**< File to load/store the tuner's values from */
+ SimpleOption<std::string> *mlgo_file; /**< File to load the MLGO heuristics from */
};
/** Consumes the common graph options and creates a structure containing any information