diff options
Diffstat (limited to 'runtimes/neurun/core/src/compiler/BackendResolver.h')
-rw-r--r-- | runtimes/neurun/core/src/compiler/BackendResolver.h | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/runtimes/neurun/core/src/compiler/BackendResolver.h b/runtimes/neurun/core/src/compiler/BackendResolver.h new file mode 100644 index 000000000..248ef2f2e --- /dev/null +++ b/runtimes/neurun/core/src/compiler/BackendResolver.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __NEURUN_COMPILER_BACKEND_RESOLVER_H__ +#define __NEURUN_COMPILER_BACKEND_RESOLVER_H__ + +#include <unordered_map> +#include <typeindex> + +#include "util/logging.h" +#include "backend/Backend.h" +#include "backend/BackendManager.h" +#include "backend/ITensorBuilder.h" +#include "model/OperationIndexMap.h" + +namespace neurun +{ +namespace compiler +{ + +class BackendResolver +{ +public: + BackendResolver(const model::Operands &operands, + const std::vector<const backend::Backend *> &backends, + const std::shared_ptr<backend::custom::KernelRegistry> ®istry) + { + for (const auto backend : backends) + { + _context_manager.emplace(backend, backend->newContext(operands, registry)); + } + } + + ~BackendResolver() = default; + BackendResolver(const BackendResolver &obj); + BackendResolver(BackendResolver &&obj) = default; + BackendResolver &operator=(const BackendResolver &obj); + BackendResolver &operator=(BackendResolver &&obj) = default; + +public: + const backend::BackendContext *getBackendContext(const model::OperationIndex &index) const + { + return _context_manager.at(_gen_map.at(index)).get(); + } + + const backend::BackendContext *getBackendContext(const backend::Backend *backend) const + { + return _context_manager.at(backend).get(); + } + + backend::TensorBuilderSet tensor_builders() const + { + backend::TensorBuilderSet ret; + for (const auto &e : _context_manager) + { + ret.insert(e.second->tensor_builder); + } + return ret; + } + + const backend::Backend *getBackend(const model::OperationIndex &index) const + { + return getBackendContext(index)->backend; + } + + void setBackend(const model::OperationIndex &index, const backend::Backend *backend) + { + _gen_map[index] = backend; + } + + void iterate(const std::function<void(const model::OperationIndex &, + const backend::BackendContext &)> &fn) const + { + for (const auto &e : _gen_map) + { + fn(e.first, *_context_manager.at(e.second)); + } + } + +private: + std::unordered_map<const backend::Backend *, std::unique_ptr<backend::BackendContext>> + _context_manager; + model::OperationIndexMap<const backend::Backend *> _gen_map; +}; + +} // namespace compiler +} // namespace neurun + +#endif // __NEURUN_COMPILER_BACKEND_RESOLVER_H__ |