summaryrefslogtreecommitdiff
path: root/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp')
-rw-r--r--inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp17
1 files changed, 11 insertions, 6 deletions
diff --git a/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp b/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp
index 00550a7d3..4f1850705 100644
--- a/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp
+++ b/inference-engine/src/inference_engine/shape_infer/ie_reshaper.hpp
@@ -1,5 +1,4 @@
// Copyright (C) 2018 Intel Corporation
-//
// SPDX-License-Identifier: Apache-2.0
//
@@ -13,9 +12,11 @@
#include <memory>
#include <ie_layers.h>
+#include <ie_context.hpp>
+#include "../ie_network.hpp"
#include "details/caseless.hpp"
-#include "shape_infer/built-in/ie_built_in_holder.hpp"
#include "ie_reshape_launcher.hpp"
+#include "ie_icnn_network.hpp"
namespace InferenceEngine {
namespace ShapeInfer {
@@ -61,6 +62,8 @@ public:
explicit Reshaper(ICNNNetwork& network,
const LauncherCreator::Ptr& creator = std::make_shared<LauncherCreator>());
+ Reshaper(const Context& context, details::Network::Ptr& network);
+
virtual ~Reshaper() = default;
/**
@@ -74,20 +77,22 @@ public:
* Throws if shape infer failed without corruption of original shapes
* @param inputShapes - Map of input names (data) to their input shapes.
*/
- void run(const std::map<std::string, SizeVector>& inputShapes);
-
- using Ptr = std::shared_ptr<Reshaper>;
+ StatusCode run(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp = nullptr);
private:
ReshapeLauncher::Ptr getLauncherByLayerName(const std::string& layerName) const;
+ StatusCode networkShapeInfer(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp);
+
static InferenceEngine::details::caseless_set<std::string> getTypeNamesFromExtension(const IShapeInferExtensionPtr& extension);
-private:
std::vector<IShapeInferExtensionPtr> _extensions;
std::set<ReshapeLauncher::Ptr> _launchers;
std::vector<CNNLayerPtr> _allSortedLayers{};
std::set<CNNLayerPtr> _inputLayers{};
InferenceEngine::details::caseless_set<std::string> _allTypes;
+
+ Context ctx;
+ details::Network::Ptr network;
};
} // namespace ShapeInfer