summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexey Suhov <alexey.suhov@intel.com>2019-10-28 21:25:18 +0300
committerAlexey Suhov <alexey.suhov@intel.com>2019-10-28 21:25:18 +0300
commit6dfc778940ec1e52737404ddc5c9634a40064b4d (patch)
tree1c2d2e45d9e7cef570768714d170d81005db5c90
parent1798ac0d26d60c000c45c7d3614160d8f7f40925 (diff)
downloaddldt-6dfc778940ec1e52737404ddc5c9634a40064b4d.tar.gz
dldt-6dfc778940ec1e52737404ddc5c9634a40064b4d.tar.bz2
dldt-6dfc778940ec1e52737404ddc5c9634a40064b4d.zip
Publishing 2019 R3.1 content
-rw-r--r--inference-engine/README.md70
-rw-r--r--inference-engine/include/builders/ie_layer_decorator.hpp4
-rw-r--r--inference-engine/include/cldnn/cldnn_config.hpp3
-rw-r--r--inference-engine/include/dlia/dlia_config.hpp6
-rw-r--r--inference-engine/include/gna/gna_config.hpp15
-rw-r--r--inference-engine/include/hetero/hetero_plugin_config.hpp3
-rw-r--r--inference-engine/include/ie_plugin_config.hpp6
-rw-r--r--inference-engine/include/inference_engine.hpp3
-rw-r--r--inference-engine/include/multi-device/multi_device_config.hpp3
-rw-r--r--inference-engine/include/vpu/vpu_plugin_config.hpp3
-rw-r--r--inference-engine/tools/calibration_tool/README.md6
-rw-r--r--model-optimizer/extensions/analysis/__init__.py0
-rw-r--r--model-optimizer/extensions/analysis/inputs.py98
-rw-r--r--model-optimizer/extensions/analysis/json_print.py56
-rw-r--r--model-optimizer/extensions/analysis/nodes.py32
-rw-r--r--model-optimizer/extensions/analysis/tf_od_api.py81
-rw-r--r--model-optimizer/extensions/analysis/tf_yolo.py107
-rw-r--r--model-optimizer/extensions/front/tf/fifo_replacer.py4
-rw-r--r--model-optimizer/extensions/front/tf/fifo_replacer_test.py4
-rw-r--r--model-optimizer/extensions/front/tf/placeholder_with_default_ext.py33
-rw-r--r--model-optimizer/mo/front/tf/extractor.py1
-rw-r--r--model-optimizer/mo/main.py6
-rw-r--r--model-optimizer/mo/middle/passes/convert_data_type.py1
-rw-r--r--model-optimizer/mo/utils/import_extensions.py2
-rw-r--r--model-optimizer/mo/utils/model_analysis.py91
-rw-r--r--model-optimizer/mo/utils/utils.py31
-rw-r--r--model-optimizer/requirements.txt2
-rw-r--r--model-optimizer/requirements_caffe.txt2
-rw-r--r--model-optimizer/requirements_kaldi.txt2
-rw-r--r--model-optimizer/requirements_mxnet.txt2
-rw-r--r--model-optimizer/requirements_onnx.txt2
-rw-r--r--model-optimizer/requirements_tf.txt2
-rw-r--r--tools/calibration/logging.py2
33 files changed, 635 insertions, 48 deletions
diff --git a/inference-engine/README.md b/inference-engine/README.md
index cc2738330..9de589c03 100644
--- a/inference-engine/README.md
+++ b/inference-engine/README.md
@@ -22,6 +22,7 @@
- [Build Steps](#build-steps-2)
- [Additional Build Options](#additional-build-options-3)
- [Use Custom OpenCV Builds for Inference Engine](#use-custom-opencv-builds-for-inference-engine)
+- [Adding Inference Engine to your project](#adding-inference-engine-to-your-project)
- [(Optional) Additional Installation Steps for the Intel® Movidius™ Neural Compute Stick and Neural Compute Stick 2](#optional-additional-installation-steps-for-the-intel-movidius-neural-compute-stick-and-neural-compute-stick-2)
- [For Linux, Raspbian Stretch* OS](#for-linux-raspbian-stretch-os)
- [For Windows](#for-windows-1)
@@ -62,7 +63,13 @@ The software was validated on:
git submodule init
git submodule update --recursive
```
-2. Install build dependencies using the `install_dependencies.sh` script in the project root folder.
+2. Install build dependencies using the `install_dependencies.sh` script in the project root folder:
+ ```sh
+ chmod +x install_dependencies.sh
+ ```
+ ```sh
+ ./install_dependencies.sh
+ ```
3. By default, the build enables the Inference Engine GPU plugin to infer models on your Intel® Processor Graphics. This requires you to [Install Intel® Graphics Compute Runtime for OpenCL™ Driver package 19.04.12237](https://github.com/intel/compute-runtime/releases/tag/19.04.12237) before running the build. If you don't want to use the GPU plugin, use the `-DENABLE_CLDNN=OFF` CMake build option and skip the installation of the Intel® Graphics Compute Runtime for OpenCL™ Driver.
4. Create a build folder:
```sh
@@ -90,33 +97,20 @@ You can use the following additional build options:
- If the CMake-based build script can not find and download the OpenCV package that is supported on your platform, or if you want to use a custom build of the OpenCV library, refer to the [Use Custom OpenCV Builds](#use-custom-opencv-builds-for-inference-engine) section for details.
-- To build the Python API wrapper, use the `-DENABLE_PYTHON=ON` option. To specify an exact Python version, use the following options:
- ```sh
- -DPYTHON_EXECUTABLE=`which python3.7` \
- -DPYTHON_LIBRARY=/usr/lib/x86_64-linux-gnu/libpython3.7m.so \
- -DPYTHON_INCLUDE_DIR=/usr/include/python3.7
- ```
+- To build the Python API wrapper:
+ 1. Install all additional packages listed in the `/inference-engine/ie_bridges/python/requirements.txt` file:
+ ```sh
+ pip install -r requirements.txt
+ ```
+ 2. use the `-DENABLE_PYTHON=ON` option. To specify an exact Python version, use the following options:
+ ```sh
+ -DPYTHON_EXECUTABLE=`which python3.7` \
+ -DPYTHON_LIBRARY=/usr/lib/x86_64-linux-gnu/libpython3.7m.so \
+ -DPYTHON_INCLUDE_DIR=/usr/include/python3.7
+ ```
- To switch off/on the CPU and GPU plugins, use the `cmake` options `-DENABLE_MKL_DNN=ON/OFF` and `-DENABLE_CLDNN=ON/OFF` respectively.
-5. Adding to your project
-
- For CMake projects, set an environment variable `InferenceEngine_DIR`:
-
- ```sh
- export InferenceEngine_DIR=/path/to/dldt/inference-engine/build/
- ```
-
- Then you can find Inference Engine by `find_package`:
-
- ```cmake
- find_package(InferenceEngine)
-
- include_directories(${InferenceEngine_INCLUDE_DIRS})
-
- target_link_libraries(${PROJECT_NAME} ${InferenceEngine_LIBRARIES} dl)
- ```
-
## Build for Raspbian Stretch* OS
> **NOTE**: Only the MYRIAD plugin is supported.
@@ -371,7 +365,13 @@ The software was validated on:
git submodule init
git submodule update --recursive
```
-2. Install build dependencies using the `install_dependencies.sh` script in the project root folder.
+2. Install build dependencies using the `install_dependencies.sh` script in the project root folder:
+ ```sh
+ chmod +x install_dependencies.sh
+ ```
+ ```sh
+ ./install_dependencies.sh
+ ```
3. Create a build folder:
```sh
mkdir build
@@ -419,6 +419,22 @@ After you got the built OpenCV library, perform the following preparation steps
1. Set the `OpenCV_DIR` environment variable to the directory where the `OpenCVConfig.cmake` file of you custom OpenCV build is located.
2. Disable the package automatic downloading with using the `-DENABLE_OPENCV=OFF` option for CMake-based build script for Inference Engine.
+## Adding Inference Engine to your project
+
+For CMake projects, set the `InferenceEngine_DIR` environment variable:
+
+```sh
+export InferenceEngine_DIR=/path/to/dldt/inference-engine/build/
+```
+
+Then you can find Inference Engine by `find_package`:
+
+```cmake
+find_package(InferenceEngine)
+include_directories(${InferenceEngine_INCLUDE_DIRS})
+target_link_libraries(${PROJECT_NAME} ${InferenceEngine_LIBRARIES} dl)
+```
+
## (Optional) Additional Installation Steps for the Intel® Movidius™ Neural Compute Stick and Neural Compute Stick 2
> **NOTE**: These steps are only required if you want to perform inference on Intel® Movidius™ Neural Compute Stick or the Intel® Neural Compute Stick 2 using the Inference Engine MYRIAD Plugin. See also [Intel® Neural Compute Stick 2 Get Started](https://software.intel.com/en-us/neural-compute-stick/get-started)
@@ -461,7 +477,7 @@ For Intel® Movidius™ Neural Compute Stick and Intel® Neural Compute Stick 2,
1. Go to the `<DLDT_ROOT_DIR>/inference-engine/thirdparty/movidius/MovidiusDriver` directory, where the `DLDT_ROOT_DIR` is the directory to which the DLDT repository was cloned.
2. Right click on the `Movidius_VSC_Device.inf` file and choose **Install** from the pop up menu.
-You have installed the driver for your Intel® Movidius™ Neural Compute Stick or Intel® Neural Compute Stick 2.
+You have installed the driver for your Intel® Movidius™ Neural Compute Stick or Intel® Neural Compute Stick 2.
## Next Steps
diff --git a/inference-engine/include/builders/ie_layer_decorator.hpp b/inference-engine/include/builders/ie_layer_decorator.hpp
index c3b9c3488..3396a0fca 100644
--- a/inference-engine/include/builders/ie_layer_decorator.hpp
+++ b/inference-engine/include/builders/ie_layer_decorator.hpp
@@ -9,6 +9,10 @@
#include <vector>
namespace InferenceEngine {
+
+/**
+ * @brief Neural network builder API
+ */
namespace Builder {
/**
diff --git a/inference-engine/include/cldnn/cldnn_config.hpp b/inference-engine/include/cldnn/cldnn_config.hpp
index 6153c0aae..64ded2d2d 100644
--- a/inference-engine/include/cldnn/cldnn_config.hpp
+++ b/inference-engine/include/cldnn/cldnn_config.hpp
@@ -15,6 +15,9 @@
namespace InferenceEngine {
+/**
+ * @brief GPU plugin configuration
+ */
namespace CLDNNConfigParams {
/**
diff --git a/inference-engine/include/dlia/dlia_config.hpp b/inference-engine/include/dlia/dlia_config.hpp
index 1adca7ed3..a097205dd 100644
--- a/inference-engine/include/dlia/dlia_config.hpp
+++ b/inference-engine/include/dlia/dlia_config.hpp
@@ -16,6 +16,9 @@
namespace InferenceEngine {
+/**
+ * @brief DLIA plugin metrics
+ */
namespace DliaMetrics {
/**
@@ -37,6 +40,9 @@ DECLARE_DLIA_METRIC_VALUE(INPUT_STREAMING);
} // namespace DliaMetrics
+/**
+ * @brief DLIA plugin configuration
+ */
namespace DLIAConfigParams {
/**
diff --git a/inference-engine/include/gna/gna_config.hpp b/inference-engine/include/gna/gna_config.hpp
index ad8acf338..8f00d77b7 100644
--- a/inference-engine/include/gna/gna_config.hpp
+++ b/inference-engine/include/gna/gna_config.hpp
@@ -3,10 +3,10 @@
//
/**
- * @brief A header that defines advanced related properties for VPU plugins.
+ * @brief A header that defines advanced related properties for GNA plugin.
* These properties should be used in SetConfig() and LoadNetwork() methods of plugins
*
- * @file vpu_plugin_config.hpp
+ * @file gna_config.hpp
*/
#pragma once
@@ -16,9 +16,20 @@
namespace InferenceEngine {
+/**
+ * @brief GNA plugin configuration
+ */
namespace GNAConfigParams {
+/**
+ * @def GNA_CONFIG_KEY(name)
+ * @brief Shortcut for defining configuration keys
+ */
#define GNA_CONFIG_KEY(name) InferenceEngine::GNAConfigParams::_CONFIG_KEY(GNA_##name)
+/**
+ * @def GNA_CONFIG_VALUE(name)
+ * @brief Shortcut for defining configuration values
+ */
#define GNA_CONFIG_VALUE(name) InferenceEngine::GNAConfigParams::GNA_##name
#define DECLARE_GNA_CONFIG_KEY(name) DECLARE_CONFIG_KEY(GNA_##name)
diff --git a/inference-engine/include/hetero/hetero_plugin_config.hpp b/inference-engine/include/hetero/hetero_plugin_config.hpp
index 2eb362137..db6c0b40d 100644
--- a/inference-engine/include/hetero/hetero_plugin_config.hpp
+++ b/inference-engine/include/hetero/hetero_plugin_config.hpp
@@ -18,6 +18,9 @@
namespace InferenceEngine {
+/**
+ * @brief Heterogeneous plugin configuration
+ */
namespace HeteroConfigParams {
/**
diff --git a/inference-engine/include/ie_plugin_config.hpp b/inference-engine/include/ie_plugin_config.hpp
index 2d14316e0..7b9490e82 100644
--- a/inference-engine/include/ie_plugin_config.hpp
+++ b/inference-engine/include/ie_plugin_config.hpp
@@ -17,6 +17,9 @@
namespace InferenceEngine {
+/**
+ * @brief %Metrics
+ */
namespace Metrics {
#ifndef DECLARE_METRIC_KEY_IMPL
@@ -144,6 +147,9 @@ DECLARE_EXEC_NETWORK_METRIC_KEY(OPTIMAL_NUMBER_OF_INFER_REQUESTS, unsigned int);
} // namespace Metrics
+/**
+ * @brief Generic plugin configuration
+ */
namespace PluginConfigParams {
/**
diff --git a/inference-engine/include/inference_engine.hpp b/inference-engine/include/inference_engine.hpp
index 48ced7de9..dbe614e6b 100644
--- a/inference-engine/include/inference_engine.hpp
+++ b/inference-engine/include/inference_engine.hpp
@@ -28,6 +28,9 @@
#include <cpp/ie_executable_network.hpp>
#include <ie_version.hpp>
+/**
+ * @brief Inference Engine API
+ */
namespace InferenceEngine {
/**
* @brief Gets the top n results from a tblob
diff --git a/inference-engine/include/multi-device/multi_device_config.hpp b/inference-engine/include/multi-device/multi_device_config.hpp
index a5f037a53..41ee2c470 100644
--- a/inference-engine/include/multi-device/multi_device_config.hpp
+++ b/inference-engine/include/multi-device/multi_device_config.hpp
@@ -16,6 +16,9 @@
namespace InferenceEngine {
+/**
+ * @brief Multi Device plugin configuration
+ */
namespace MultiDeviceConfigParams {
/**
diff --git a/inference-engine/include/vpu/vpu_plugin_config.hpp b/inference-engine/include/vpu/vpu_plugin_config.hpp
index 5462acf19..62e84d0f3 100644
--- a/inference-engine/include/vpu/vpu_plugin_config.hpp
+++ b/inference-engine/include/vpu/vpu_plugin_config.hpp
@@ -37,6 +37,9 @@
namespace InferenceEngine {
+/**
+ * @brief VPU plugin configuration
+ */
namespace VPUConfigParams {
//
diff --git a/inference-engine/tools/calibration_tool/README.md b/inference-engine/tools/calibration_tool/README.md
index 06193dbb6..be870ede5 100644
--- a/inference-engine/tools/calibration_tool/README.md
+++ b/inference-engine/tools/calibration_tool/README.md
@@ -136,7 +136,7 @@ Command line:
python collect_statistics.py --config ~/inception_v1.yml -d ~/defenitions.yml -M /home/user/intel/openvino/deployment_tools/model_optimizer --models ~/models --source /media/user/calibration/datasets --annotations ~/annotations --converted_models ~/models
```
-Result model has statistics which allow you to infer this model in INT8 precision. To measure performance, you can use the [Benchmark App](./inference-engine/ie_bridges/python/sample/benchmark_app/README.md).
+Result model has statistics which allow you to infer this model in INT8 precision. To measure performance, you can use the [Benchmark App](./inference-engine/tools/benchmark_tool/README.md).
### Calibrate the Model
During calibration process, the model is adjusted for efficient quantization and minimization of accuracy drop on calibration dataset. Calibration tool produces calibrated model which will be executed in low precision 8-bit quantized mode after loading into CPU plugin.
@@ -180,4 +180,6 @@ To run the Calibration Tool in the simplified mode, use the following command:
```sh
python3 calibrate.py -sm -m <path-to-ir.xml> -s <path-to-dataset> -ss <images-number> -e <path-to-extensions-folder> -td <target-device> -precision <output-ir-precision> --output-dir <output-directory-path>
```
-It accepts models with FP32, FP16 precisions and image files as the dataset. \ No newline at end of file
+Input:
+- FP32 and FP16 models
+- image files as a dataset
diff --git a/model-optimizer/extensions/analysis/__init__.py b/model-optimizer/extensions/analysis/__init__.py
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/model-optimizer/extensions/analysis/__init__.py
diff --git a/model-optimizer/extensions/analysis/inputs.py b/model-optimizer/extensions/analysis/inputs.py
new file mode 100644
index 000000000..4cc678280
--- /dev/null
+++ b/model-optimizer/extensions/analysis/inputs.py
@@ -0,0 +1,98 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import logging as log
+
+import numpy as np
+
+from mo.graph.graph import Graph
+from mo.utils.model_analysis import AnalyzeAction
+
+
+class InputsAnalysis(AnalyzeAction):
+ """
+ The analyser gets information about model inputs and their default values if any.
+ """
+
+ @classmethod
+ def fifo_queue_analysis(cls, graph: Graph, inputs_desc: dict):
+ """
+ The FIFOQueue with QueueDeque has a separate input that specifies the size of batch to extract from queue. This
+ input is redundant and should be remove from the model analysis output.
+ """
+ inputs_to_ignore = set()
+ for fifo_queue in graph.get_op_nodes(op='FIFOQueueV2'):
+ if len(fifo_queue.get_outputs({'out': 0})) != 1:
+ log.debug('The FIFOQueue operation "{}" has more than 1 consumers'.format(fifo_queue.id))
+ continue
+ queue_deque = fifo_queue.out_node(0)
+ if queue_deque.op in ['QueueDequeueMany', 'QueueDequeueManyV2', 'QueueDequeueUpTo', 'QueueDequeueUpToV2']:
+ queue_deque_input_1 = queue_deque.in_node(1)
+ if queue_deque_input_1.op in ['Parameter', 'PlaceholderWithDefault']:
+ log.debug('Adding node "{}" to placeholder ignore list'.format(queue_deque_input_1.id))
+ inputs_to_ignore.add(queue_deque_input_1.id)
+
+ # create input per each QueueDeque output port
+ for port_ind in range(len(queue_deque.out_nodes())):
+ inputs_desc["{}:{}".format(queue_deque.id, port_ind)] = {'shape': fifo_queue.shapes[port_ind].tolist(),
+ 'value': None,
+ 'data_type': fifo_queue.types[port_ind]}
+ return inputs_to_ignore
+
+ @classmethod
+ def ignore_mxnet_softmax_inputs(cls, graph: Graph):
+ """
+ MxNet Softmax layers may have additional inputs which should be ignored. Refer to the
+ extensions/front/mxnet/check_softmax_node_inputs.py.
+ """
+ inputs_to_ignore = set()
+ softmax_nodes = []
+ [softmax_nodes.extend(graph.get_op_nodes(op=op)) for op in ('SoftMax', 'SoftmaxActivation', 'SoftmaxOutput')]
+ for softmax_node in softmax_nodes:
+ for i in range(1, len(softmax_node.in_nodes())):
+ if softmax_node.in_node(i).has_valid('op') and softmax_node.in_node(i).op == 'Parameter':
+ inputs_to_ignore.add(softmax_node.in_node(i).id)
+ return inputs_to_ignore
+
+ def analyze(self, graph: Graph):
+ inputs_desc = dict()
+
+ inputs_to_ignore = InputsAnalysis.fifo_queue_analysis(graph, inputs_desc)
+ if graph.graph['fw'] == 'mxnet':
+ inputs_to_ignore.update(InputsAnalysis.ignore_mxnet_softmax_inputs(graph))
+
+ inputs = graph.get_op_nodes(op='Parameter')
+ for input in inputs:
+ inputs_desc[input.name] = {'shape': input.soft_get('shape', None),
+ 'data_type': input.soft_get('data_type', None),
+ 'value': None,
+ }
+
+ placeholders_with_default = graph.get_op_nodes(op='PlaceholderWithDefault')
+ for input in placeholders_with_default:
+ inputs_desc[input.name] = {'shape': input.soft_get('shape', None),
+ 'data_type': input.soft_get('data_type', None),
+ 'value': input.in_node(0).value if 0 in input.in_nodes() and
+ input.in_node(0).has_valid('value') else None}
+
+ for input_to_ignore in inputs_to_ignore:
+ del inputs_desc[input_to_ignore]
+
+ # workaround for the ONNX models case where input shape is specified as string value like: "width", "height".
+ # In this case the string value is converted to 0, but in fact it is an arbitrary value so should be -1
+ if graph.graph['fw'] == 'onnx':
+ for inp in inputs_desc.values():
+ inp['shape'] = [-1 if item == 0 else item for item in inp['shape']]
+ return {'inputs': inputs_desc}
diff --git a/model-optimizer/extensions/analysis/json_print.py b/model-optimizer/extensions/analysis/json_print.py
new file mode 100644
index 000000000..327785f3b
--- /dev/null
+++ b/model-optimizer/extensions/analysis/json_print.py
@@ -0,0 +1,56 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import json
+import sys
+
+import numpy as np
+
+from extensions.front.user_data_repack import UserDataRepack
+from mo.graph.graph import Graph
+from mo.middle.passes.convert_data_type import np_data_type_to_precision
+from mo.utils.model_analysis import AnalyzeAction, AnalysisCollectorAnchor
+
+
+def prepare_obj_for_dump(obj: object):
+ if isinstance(obj, dict):
+ return {k: prepare_obj_for_dump(v) for k, v in obj.items()}
+ elif isinstance(obj, np.ndarray) or isinstance(obj, list):
+ return [prepare_obj_for_dump(elem) for elem in obj]
+ elif isinstance(obj, type):
+ return np_data_type_to_precision(obj)
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ else:
+ return obj
+
+
+class AnalysisJSONPrint(AnalyzeAction):
+ """
+ The action prints the analysis results in JSON format.
+ """
+ enabled = False
+ id = 'ANALYSIS_JSON_PRINT'
+
+ def run_before(self):
+ return [UserDataRepack]
+
+ def run_after(self):
+ return [AnalysisCollectorAnchor]
+
+ def analyze(self, graph: Graph):
+ if 'analysis_results' in graph.graph and graph.graph['analysis_results'] is not None:
+ print(json.dumps(prepare_obj_for_dump(graph.graph['analysis_results'])))
+ sys.exit(0)
diff --git a/model-optimizer/extensions/analysis/nodes.py b/model-optimizer/extensions/analysis/nodes.py
new file mode 100644
index 000000000..34c0242f0
--- /dev/null
+++ b/model-optimizer/extensions/analysis/nodes.py
@@ -0,0 +1,32 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from mo.graph.graph import Graph
+from mo.utils.model_analysis import AnalyzeAction
+
+
+class IntermediatesNodesAnalysis(AnalyzeAction):
+ """
+ The analyser gets node names, their shapes and values (if possible) of all nodes in the model.
+ """
+ def analyze(self, graph: Graph):
+ outputs_desc = dict()
+
+ for node in graph.get_op_nodes():
+ outputs_desc[node.id] = {'shape': node.soft_get('shape', None),
+ 'data_type': None,
+ 'value': None,
+ }
+ return {'intermediate': outputs_desc}
diff --git a/model-optimizer/extensions/analysis/tf_od_api.py b/model-optimizer/extensions/analysis/tf_od_api.py
new file mode 100644
index 000000000..459a5f6ab
--- /dev/null
+++ b/model-optimizer/extensions/analysis/tf_od_api.py
@@ -0,0 +1,81 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import logging as log
+
+from mo.graph.graph import Graph
+from mo.utils.model_analysis import AnalyzeAction, graph_contains_scope
+from mo.utils.utils import files_by_pattern, get_mo_root_dir
+
+
+class TensorFlowObjectDetectionAPIAnalysis(AnalyzeAction):
+ """
+ The analyser checks if the provided model is TF OD API model from
+ https://github.com/tensorflow/models/tree/master/research/object_detection/g3doc/detection_model_zoo.md of one of 4
+ supported flavors: SSD, RFCN, Faster RCNN, Mask RCNN.
+ """
+ graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+
+ model_scopes = [('MaskRCNN', ['Preprocessor',
+ 'FirstStageFeatureExtractor',
+ 'SecondStageFeatureExtractor',
+ 'SecondStageBoxPredictor',
+ 'SecondStageBoxPredictor_1',
+ 'SecondStageFeatureExtractor_1',
+ ]),
+ ('RFCN', ['Preprocessor',
+ 'FirstStageFeatureExtractor',
+ 'SecondStageFeatureExtractor',
+ 'SecondStageBoxPredictor',
+ 'SecondStageBoxPredictor/map',
+ 'SecondStageBoxPredictor/map_1',
+ 'SecondStagePostprocessor',
+ ]),
+ ('FasterRCNN', ['Preprocessor',
+ 'FirstStageFeatureExtractor',
+ 'SecondStageFeatureExtractor',
+ 'SecondStageBoxPredictor',
+ 'SecondStagePostprocessor',
+ ]),
+ ('SSD', ['Preprocessor',
+ 'FeatureExtractor',
+ 'Postprocessor',
+ ]),
+ ]
+
+ file_patterns = {'MaskRCNN': 'mask_rcnn_support.*\\.json',
+ 'RFCN': 'rfcn_support.*\\.json',
+ 'FasterRCNN': 'faster_rcnn_support.*\\.json',
+ 'SSD': 'ssd.*_support.*\\.json',
+ }
+
+ def analyze(self, graph: Graph):
+ if any([name not in graph.nodes() for name in ['image_tensor', 'detection_classes', 'detection_boxes',
+ 'detection_scores']]):
+ log.debug('The model does not contain nodes that must exist in the TF OD API models')
+ return None
+
+ for flavor, scopes in __class__.model_scopes:
+ if all([graph_contains_scope(graph, scope) for scope in scopes]):
+ result = dict()
+ result['flavor'] = flavor
+ result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config':
+ files_by_pattern(get_mo_root_dir() + '/extensions/front/tf',
+ __class__.file_patterns[flavor],
+ add_prefix=True),
+ 'tensorflow_object_detection_api_pipeline_config': None,
+ }
+ return {'model_type': {'TF_OD_API': result}}
+ return None
diff --git a/model-optimizer/extensions/analysis/tf_yolo.py b/model-optimizer/extensions/analysis/tf_yolo.py
new file mode 100644
index 000000000..3e9752fe0
--- /dev/null
+++ b/model-optimizer/extensions/analysis/tf_yolo.py
@@ -0,0 +1,107 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+
+from mo.graph.graph import Graph
+from mo.middle.pattern_match import apply_pattern
+from mo.utils.model_analysis import AnalyzeAction, graph_contains_scope
+
+
+YOLO_PATTERN = {
+ 'nodes': [
+ ('pad', dict(op='Pad')),
+ ('conv', dict(op='Conv2D')),
+ ('sub', dict(op='Sub')),
+ ('div', dict(op='Div')),
+ ('mul', dict(op='Mul')),
+ ('bias_add', dict(op='Add')),
+ ('mul_2', dict(op='Mul')),
+ ('max', dict(op='Maximum')),
+ ],
+ 'edges': [
+ ('pad', 'conv', {'out': 0}),
+ ('conv', 'sub', {'out': 0}),
+ ('sub', 'div', {'out': 0}),
+ ('div', 'mul', {'out': 0}),
+ ('mul', 'bias_add', {'out': 0}),
+ ('bias_add', 'mul_2', {'out': 0}),
+ ('bias_add', 'max', {'out': 0}),
+ ('mul_2', 'max', {'out': 0}),
+ ]
+}
+
+
+def pattern_instance_counter(graph: Graph, match: dict):
+ pattern_instance_counter.counter += 1
+
+
+pattern_instance_counter.counter = 0
+
+
+YOLO_CONFIGS = {'YOLOV2Full': ['extensions/front/tf/yolo_v2.json', 'extensions/front/tf/yolo_v2_voc.json'],
+ 'YOLOV3Full': ['extensions/front/tf/yolo_v3.json', 'extensions/front/tf/yolo_v3_voc.json'],
+ 'YOLOV2Tiny': ['extensions/front/tf/yolo_v2_tiny.json', 'extensions/front/tf/yolo_v2_tiny_voc.json'],
+ 'YOLOV3Tiny': ['extensions/front/tf/yolo_v3_tiny.json', 'extensions/front/tf/yolo_v3_tiny_voc.json'],
+ }
+
+
+def get_YOLO_params_by_flavor(flavor: str):
+ result = dict()
+ result['flavor'] = flavor
+ result['mandatory_parameters'] = {'tensorflow_use_custom_operations_config': YOLO_CONFIGS[flavor]}
+ return result
+
+
+class TensorFlowYOLOV1V2Analysis(AnalyzeAction):
+ """
+ The analyser checks if the provided model is TensorFlow YOLO models from https://github.com/thtrieu/darkflow .
+ """
+ graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+
+ def analyze(self, graph: Graph):
+ pattern_instance_counter.counter = 0
+ apply_pattern(graph, **YOLO_PATTERN, action=pattern_instance_counter)
+
+ flavor = None
+ if pattern_instance_counter.counter > 0:
+ if pattern_instance_counter.counter == 22:
+ flavor = 'YOLOV2Full'
+ elif pattern_instance_counter.counter == 8:
+ flavor = 'YOLOV2Tiny'
+
+ if flavor is not None:
+ return {'model_type': {'YOLO': get_YOLO_params_by_flavor(flavor)}}
+ else:
+ return None
+
+
+class TensorFlowYOLOV3Analysis(AnalyzeAction):
+ """
+ The analyser checks if the provided model is TensorFlow YOLO models from
+ https://github.com/mystic123/tensorflow-yolo-v3.
+ """
+ graph_condition = [lambda graph: graph.graph['fw'] == 'tf']
+
+ def analyze(self, graph: Graph):
+ flavor = None
+ if graph_contains_scope(graph, 'detector/yolo-v3') and graph_contains_scope(graph, 'detector/darknet-53'):
+ flavor = 'YOLOV3Full'
+ elif graph_contains_scope(graph, 'detector/yolo-v3-tiny'):
+ flavor = 'YOLOV3Tiny'
+
+ if flavor is not None:
+ return {'model_type': {'YOLO': get_YOLO_params_by_flavor(flavor)}}
+ else:
+ return None
diff --git a/model-optimizer/extensions/front/tf/fifo_replacer.py b/model-optimizer/extensions/front/tf/fifo_replacer.py
index a8b421dde..8374a58a5 100644
--- a/model-optimizer/extensions/front/tf/fifo_replacer.py
+++ b/model-optimizer/extensions/front/tf/fifo_replacer.py
@@ -65,6 +65,7 @@ class FIFOQueue(FrontReplacementSubgraph):
"""
true_placeholder_shape = match['placeholder'].shape
placeholder_shape = match['fifo_queue'].shapes[0]
+ placeholder_data_type = match['fifo_queue'].types[0]
assert true_placeholder_shape.ndim <= 1
if true_placeholder_shape.ndim == 1 and len(true_placeholder_shape) > 1:
log.warning(
@@ -81,7 +82,8 @@ class FIFOQueue(FrontReplacementSubgraph):
graph.remove_node(out.out_node().id)
graph.remove_node(out.id)
graph.remove_node(match['batch_join'].id)
- placeholder = Parameter(graph, {'name': placeholder_name, 'shape': placeholder_shape}).create_node()
+ placeholder = Parameter(graph, {'name': placeholder_name, 'shape': placeholder_shape,
+ 'data_type': placeholder_data_type}).create_node()
graph.create_edge(placeholder, match['image_batch'])
log.info("FIFOQueueV2 pattern was detected. New shape of placeholder {} is {}. Use -b to set batch size if "
"needed".format(placeholder.id, placeholder['shape']))
diff --git a/model-optimizer/extensions/front/tf/fifo_replacer_test.py b/model-optimizer/extensions/front/tf/fifo_replacer_test.py
index b2a439996..13f9dd167 100644
--- a/model-optimizer/extensions/front/tf/fifo_replacer_test.py
+++ b/model-optimizer/extensions/front/tf/fifo_replacer_test.py
@@ -27,7 +27,7 @@ class TestFIFOQueueReplacement(unittest.TestCase):
nodes = {
'placeholder': {'op': 'Parameter', 'data_type': np.int32, 'kind': 'op', 'shape': np.array(1)},
'batch_join/fifo_queue': {'op': 'FIFOQueueV2', 'name': 'batch_join/fifo_queue',
- 'shapes': np.array([[1, 2, 3]]), 'kind': 'op'},
+ 'shapes': np.array([[1, 2, 3]]), 'types': np.array([np.float32]), 'kind': 'op'},
'batch_join': {'op': 'QueueDequeueUpToV2', 'kind': 'op'},
'image_batch': {'op': 'Identity', 'data_type': np.float32, 'kind': 'op'},
'label_batch': {'op': 'Identity', 'kind': 'op'},
@@ -56,7 +56,7 @@ class TestFIFOQueueReplacement(unittest.TestCase):
nodes_no_label = {
'placeholder': {'op': 'Parameter', 'data_type': np.int32, 'kind': 'op', 'shape': np.array(0)},
'batch_join/fifo_queue': {'op': 'FIFOQueueV2', 'name': 'batch_join/fifo_queue',
- 'shapes': np.array([[1, 2, 3]]), 'kind': 'op'},
+ 'shapes': np.array([[1, 2, 3]]), 'types': np.array([np.float32]), 'kind': 'op'},
'batch_join': {'op': 'QueueDequeueUpToV2', 'kind': 'op'},
'image_batch': {'op': 'Identity', 'data_type': np.float32, 'kind': 'op'},
}
diff --git a/model-optimizer/extensions/front/tf/placeholder_with_default_ext.py b/model-optimizer/extensions/front/tf/placeholder_with_default_ext.py
new file mode 100644
index 000000000..c23d9f097
--- /dev/null
+++ b/model-optimizer/extensions/front/tf/placeholder_with_default_ext.py
@@ -0,0 +1,33 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+from mo.front.extractor import FrontExtractorOp
+from mo.front.tf.extractors.utils import tf_dtype_extractor, tf_tensor_shape
+from mo.ops.op import Op
+
+
+class PlaceholderWithDefaultExtractor(FrontExtractorOp):
+ op = 'PlaceholderWithDefault'
+ enabled = True
+
+ @staticmethod
+ def extract(node):
+ attrs = {
+ 'data_type': tf_dtype_extractor(node.pb.attr["dtype"].type),
+ 'shape': tf_tensor_shape(node.pb.attr["shape"].shape),
+ 'identity': True,
+ }
+ Op.update_node_stat(node, attrs)
+ return __class__.enabled
diff --git a/model-optimizer/mo/front/tf/extractor.py b/model-optimizer/mo/front/tf/extractor.py
index ab17ee131..3097aaf15 100644
--- a/model-optimizer/mo/front/tf/extractor.py
+++ b/model-optimizer/mo/front/tf/extractor.py
@@ -82,7 +82,6 @@ tf_op_extractors = {
'SpaceToBatchND': node_pb_arg(tf_space_to_batch_ext),
'BatchToSpaceND': node_pb_arg(tf_batch_to_space_ext),
'ReadVariableOp': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True})),
- 'PlaceholderWithDefault': node_pb_arg(make_tf_eltwise(lambda v: v, attrs={'identity': True}))
}
diff --git a/model-optimizer/mo/main.py b/model-optimizer/mo/main.py
index 6dad42c97..24cd1754a 100644
--- a/model-optimizer/mo/main.py
+++ b/model-optimizer/mo/main.py
@@ -175,10 +175,6 @@ def driver(argv: argparse.Namespace):
if ret_code:
return ret_code
- if is_mxnet and not argv.input_shape:
- raise Error('Input shape is required to convert MXNet model. Please provide it with --input_shape. ' +
- refer_to_faq_msg(16))
-
mean_file_offsets = None
if is_caffe and argv.mean_file and argv.mean_values:
raise Error('Both --mean_file and mean_values are specified. Specify either mean file or mean values. ' +
@@ -279,7 +275,7 @@ def driver(argv: argparse.Namespace):
if ret_res != 0:
return ret_res
- if not (is_tf and argv.tensorflow_custom_operations_config_update):
+ if not (is_tf and argv.tensorflow_custom_operations_config_update) and not argv.silent:
output_dir = argv.output_dir if argv.output_dir != '.' else os.getcwd()
print('\n[ SUCCESS ] Generated IR model.')
print('[ SUCCESS ] XML file: {}.xml'.format(os.path.join(output_dir, model_name)))
diff --git a/model-optimizer/mo/middle/passes/convert_data_type.py b/model-optimizer/mo/middle/passes/convert_data_type.py
index ce785dbf4..d9859c567 100644
--- a/model-optimizer/mo/middle/passes/convert_data_type.py
+++ b/model-optimizer/mo/middle/passes/convert_data_type.py
@@ -30,6 +30,7 @@ SUPPORTED_DATA_TYPES = {
'uint8': (np.uint8, 'UI8'),
'int32': (np.int32, 'I32'),
'int64': (np.int64, 'I64'),
+ 'bool': (np.bool, 'BOOL'),
}
diff --git a/model-optimizer/mo/utils/import_extensions.py b/model-optimizer/mo/utils/import_extensions.py
index 0ed0ce653..1ac1385a6 100644
--- a/model-optimizer/mo/utils/import_extensions.py
+++ b/model-optimizer/mo/utils/import_extensions.py
@@ -24,6 +24,7 @@ from mo.back.replacement import BackReplacementPattern
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.op import Op
from mo.utils.class_registration import _check_unique_ids, update_registration, get_enabled_and_disabled_transforms
+from mo.utils.model_analysis import AnalyzeAction
def import_by_path(path: str, middle_names: list = ()):
@@ -73,6 +74,7 @@ def load_dir(framework: str, path: str, get_front_classes: callable):
front_classes = get_front_classes()
internal_dirs = {
('ops', ): [Op],
+ ('analysis',): [AnalyzeAction],
('front', ): front_classes,
('front', framework): front_classes,
('middle', ): [MiddleReplacementPattern],
diff --git a/model-optimizer/mo/utils/model_analysis.py b/model-optimizer/mo/utils/model_analysis.py
new file mode 100644
index 000000000..433f5160b
--- /dev/null
+++ b/model-optimizer/mo/utils/model_analysis.py
@@ -0,0 +1,91 @@
+"""
+ Copyright (c) 2019 Intel Corporation
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+"""
+import sys
+
+from extensions.front.user_data_repack import UserDataRepack
+from mo.graph.graph import Graph
+from mo.utils import class_registration
+from mo.utils.error import Error
+
+
+class AnalyzeAction(object):
+ registered_cls = []
+ registered_ops = {}
+ excluded_replacers = []
+ run_not_recursively = True
+
+ def find_and_replace_pattern(self, graph: Graph):
+ if 'analysis_results' not in graph.graph:
+ graph.graph['analysis_results'] = {'failed_analysers': []}
+
+ try:
+ result = self.analyze(graph) # pylint: disable=assignment-from-no-return
+ except SystemExit:
+ # the analysis transformation printing analysis results to the screen calls sys.exit(0) which in fact raises
+ # SystemExit exception, so we handle it here
+ sys.exit(0)
+ except:
+ graph.graph['analysis_results']['failed_analysers'].append(str(self.__class__))
+ result = None
+
+ if result is not None:
+ graph.graph['analysis_results'].update(result)
+
+ def analyze(self, graph: Graph):
+ raise Error('The method must be implemented in the sub-class')
+
+ def run_before(self):
+ """
+ Returns list of replacer classes which this replacer must be run before.
+ :return: list of classes
+ """
+ return [AnalysisCollectorAnchor, UserDataRepack]
+
+ def run_after(self):
+ """
+ Returns list of replacer classes which this replacer must be run after.
+ :return: list of classes
+ """
+ return []
+
+ @classmethod
+ def class_type(cls):
+ return class_registration.ClassType.FRONT_REPLACER
+
+
+class AnalysisCollectorAnchor(AnalyzeAction):
+ """
+ All analyzers should depend on this one which is an anchor analyzer to develop custom post-processor of all
+ analyzers results.
+ """
+
+ def run_before(self):
+ return []
+
+ def analyze(self, graph: Graph):
+ pass
+
+
+def graph_contains_scope(graph: Graph, scope: str):
+ """
+ Checks whether the graph contains node(s) which name starts with "scope" string.
+ :param graph: graph to check
+ :param scope: string defining the scope
+ :return: the result of the check (True/False)
+ """
+ if scope[-1] != '/':
+ scope += '/'
+ return any([node.soft_get('name').startswith(scope) for node in graph.get_op_nodes()])
diff --git a/model-optimizer/mo/utils/utils.py b/model-optimizer/mo/utils/utils.py
index 1001b42fa..4a919cb1e 100644
--- a/model-optimizer/mo/utils/utils.py
+++ b/model-optimizer/mo/utils/utils.py
@@ -14,8 +14,10 @@
limitations under the License.
"""
import functools
+import os
+import re
import warnings
-import logging as log
+
import numpy as np
@@ -77,3 +79,30 @@ def shrink_str_value(value: np.array, max_symbols=100):
if len(value) > max_symbols:
value = value.strip('\n')[:max_symbols - 3] + '...'
return value
+
+
+def files_by_pattern(dir: str, pattern: str, files_only=True, add_prefix=False):
+ """
+ Return a list of files and directories (or only files if the files_only is set to True) in the directory dir that
+ match pattern string pattern.
+ :param dir: Directory to search for files
+ :param pattern: string defining pattern name
+ :param files_only: flag to include only files (not directories) to the result
+ :param add_prefix: flag to include the prefix string to the file names
+ :return: list of file and directory names
+ """
+ pattern_compiled = re.compile(pattern)
+ matched_file_names = []
+ for file_name in os.listdir(dir):
+ if re.match(pattern_compiled, file_name) and (not files_only or os.path.isfile(os.path.join(dir, file_name))):
+ matched_file_names.append(os.path.join(dir, file_name) if add_prefix else file_name)
+ return matched_file_names
+
+
+def get_mo_root_dir():
+ """
+ Return the absolute path to the Model Optimizer root directory (where mo.py file is located)
+ :return: path to the MO root directory
+ """
+ return os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(os.path.realpath(__file__))), os.pardir,
+ os.pardir))
diff --git a/model-optimizer/requirements.txt b/model-optimizer/requirements.txt
index 2dba5b90b..0849dc1b5 100644
--- a/model-optimizer/requirements.txt
+++ b/model-optimizer/requirements.txt
@@ -1,6 +1,6 @@
tensorflow>=1.2.0,<2.0.0
mxnet>=1.0.0,<=1.3.1
-networkx>=1.11
+networkx>=1.11,<2.4
numpy>=1.12.0
protobuf==3.6.1
onnx>=1.1.2
diff --git a/model-optimizer/requirements_caffe.txt b/model-optimizer/requirements_caffe.txt
index a032f83c8..12e20e2a5 100644
--- a/model-optimizer/requirements_caffe.txt
+++ b/model-optimizer/requirements_caffe.txt
@@ -1,4 +1,4 @@
-networkx>=1.11
+networkx>=1.11,<2.4
numpy>=1.12.0
protobuf==3.6.1
defusedxml>=0.5.0
diff --git a/model-optimizer/requirements_kaldi.txt b/model-optimizer/requirements_kaldi.txt
index acd2c87d4..b0c27bd49 100644
--- a/model-optimizer/requirements_kaldi.txt
+++ b/model-optimizer/requirements_kaldi.txt
@@ -1,3 +1,3 @@
-networkx>=1.11
+networkx>=1.11,<2.4
numpy==1.13.0
defusedxml>=0.5.0
diff --git a/model-optimizer/requirements_mxnet.txt b/model-optimizer/requirements_mxnet.txt
index 883ec69b3..c59ae05ad 100644
--- a/model-optimizer/requirements_mxnet.txt
+++ b/model-optimizer/requirements_mxnet.txt
@@ -1,4 +1,4 @@
mxnet>=1.0.0,<=1.3.1
-networkx>=1.11
+networkx>=1.11,<2.4
numpy>=1.12.0
defusedxml>=0.5.0
diff --git a/model-optimizer/requirements_onnx.txt b/model-optimizer/requirements_onnx.txt
index e0ed76ec1..c5974d483 100644
--- a/model-optimizer/requirements_onnx.txt
+++ b/model-optimizer/requirements_onnx.txt
@@ -1,4 +1,4 @@
onnx>=1.1.2
-networkx>=1.11
+networkx>=1.11,<2.4
numpy>=1.12.0
defusedxml>=0.5.0
diff --git a/model-optimizer/requirements_tf.txt b/model-optimizer/requirements_tf.txt
index 2accfb75a..97697bc8a 100644
--- a/model-optimizer/requirements_tf.txt
+++ b/model-optimizer/requirements_tf.txt
@@ -1,4 +1,4 @@
tensorflow>=1.2.0,<2.0.0
-networkx>=1.11
+networkx>=1.11,<2.4
numpy>=1.12.0
defusedxml>=0.5.0
diff --git a/tools/calibration/logging.py b/tools/calibration/logging.py
index f05f83f4e..0d99436ad 100644
--- a/tools/calibration/logging.py
+++ b/tools/calibration/logging.py
@@ -91,7 +91,7 @@ logging.config.dictConfig(_LOGGING_CONFIGURATION)
default_logger = logging.getLogger(_DEFAULT_LOGGER_NAME)
-def _warning_handler(message, category, filename, lineno):
+def _warning_handler(message, category, filename, lineno, *args, **kwargs):
s = warnings.formatwarning(message, category, filename, lineno)
default_logger.warning(s)