summaryrefslogtreecommitdiff
path: root/tools/calibration/aggregated_statistics.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/calibration/aggregated_statistics.py')
-rw-r--r--tools/calibration/aggregated_statistics.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/tools/calibration/aggregated_statistics.py b/tools/calibration/aggregated_statistics.py
index 52072c3b9..cc381a917 100644
--- a/tools/calibration/aggregated_statistics.py
+++ b/tools/calibration/aggregated_statistics.py
@@ -61,7 +61,8 @@ class AggregatedStatistics:
# TODO: can be refactored: we are itterating by all layers (to cover input layers output) to collect statistics
# for inference_result in inference_results:
for out_layer_name in layer_names:
- if self._ignore_layer_names and out_layer_name in self._ignore_layer_names:
+ if self._ignore_layer_names and out_layer_name in self._ignore_layer_names or \
+ out_layer_name in network.outputs and network.outputs[out_layer_name].layout.lower() == 'blocked':
continue
if out_layer_name in network.inputs:
@@ -133,6 +134,9 @@ class AggregatedStatistics:
element_to_take = int(len(max_values) * threshold / 100) if threshold else len(max_values)
elements_to_throw = len(max_values) - element_to_take if threshold else 0
+ element_to_take = len(max_values) if element_to_take > len(max_values) else element_to_take
+ elements_to_throw = (len(min_values) - 1) if elements_to_throw >= len(min_values) else elements_to_throw
+
max_values.sort()
min_values.sort()