diff options
Diffstat (limited to 'inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp')
-rw-r--r-- | inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp | 17 |
1 files changed, 11 insertions, 6 deletions
diff --git a/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp b/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp index 00550a7d3..4f1850705 100644 --- a/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp +++ b/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp @@ -1,5 +1,4 @@ // Copyright (C) 2018 Intel Corporation -// // SPDX-License-Identifier: Apache-2.0 // @@ -13,9 +12,11 @@ #include <memory> #include <ie_layers.h> +#include <ie_context.hpp> +#include "../ie_network.hpp" #include "details/caseless.hpp" -#include "shape_infer/built-in/ie_built_in_holder.hpp" #include "ie_reshape_launcher.hpp" +#include "ie_icnn_network.hpp" namespace InferenceEngine { namespace ShapeInfer { @@ -61,6 +62,8 @@ public: explicit Reshaper(ICNNNetwork& network, const LauncherCreator::Ptr& creator = std::make_shared<LauncherCreator>()); + Reshaper(const Context& context, details::Network::Ptr& network); + virtual ~Reshaper() = default; /** @@ -74,20 +77,22 @@ public: * Throws if shape infer failed without corruption of original shapes * @param inputShapes - Map of input names (data) to their input shapes. */ - void run(const std::map<std::string, SizeVector>& inputShapes); - - using Ptr = std::shared_ptr<Reshaper>; + StatusCode run(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr); private: ReshapeLauncher::Ptr getLauncherByLayerName(const std::string& layerName) const; + StatusCode networkShapeInfer(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp); + static InferenceEngine::details::caseless_set<std::string> getTypeNamesFromExtension(const IShapeInferExtensionPtr& extension); -private: std::vector<IShapeInferExtensionPtr> _extensions; std::set<ReshapeLauncher::Ptr> _launchers; std::vector<CNNLayerPtr> _allSortedLayers{}; std::set<CNNLayerPtr> _inputLayers{}; InferenceEngine::details::caseless_set<std::string> _allTypes; + + Context ctx; + details::Network::Ptr network; }; } // namespace ShapeInfer |