summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h
diff options
context:
space:
mode:
authoropenvino-pushbot <openvino_pushbot@intel.com>2018-10-16 13:45:03 +0300
committeropenvino-pushbot <openvino_pushbot@intel.com>2018-10-16 13:45:03 +0300
commit866530fb047cd17af6bd2dbbde5f5cb35f876840 (patch)
tree91451785d290a2481d82ed8dfe175aade3a0f727 /inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h
parentc37d4661a27afb408a45f7752acea968032afcc0 (diff)
downloaddldt-866530fb047cd17af6bd2dbbde5f5cb35f876840.tar.gz
dldt-866530fb047cd17af6bd2dbbde5f5cb35f876840.tar.bz2
dldt-866530fb047cd17af6bd2dbbde5f5cb35f876840.zip
Publishing R3
Diffstat (limited to 'inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h')
-rw-r--r--inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h66
1 files changed, 66 insertions, 0 deletions
diff --git a/inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h b/inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h
new file mode 100644
index 000000000..1973b739d
--- /dev/null
+++ b/inference-engine/thirdparty/clDNN/src/include/batch_norm_inst.h
@@ -0,0 +1,66 @@
+/*
+// Copyright (c) 2016 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+*/
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+#pragma once
+#include "api/CPP/batch_norm.hpp"
+#include "primitive_inst.h"
+
+namespace cldnn
+{
+
+template <>
+struct typed_program_node<batch_norm> : public typed_program_node_base<batch_norm>
+{
+ using parent = typed_program_node_base<batch_norm>;
+
+public:
+ using parent::parent;
+
+ decltype(auto) input() const { return get_dependency(0); }
+ decltype(auto) mean() const { return get_dependency(1); }
+ decltype(auto) variance() const { return get_dependency(2); }
+ decltype(auto) inv_variance() const { return get_dependency(1); };
+ bool variance_term() const { return !get_primitive()->variance.empty(); }
+ bool use_global_stats() const { return !get_primitive()->mean.empty() && !get_primitive()->variance.empty(); };
+ bool forwad_pass() const { return !get_primitive()->inv_variance.empty(); };
+
+};
+
+using batch_norm_node = typed_program_node<batch_norm>;
+
+template <>
+class typed_primitive_inst<batch_norm> : public typed_primitive_inst_base<batch_norm>
+{
+ using parent = typed_primitive_inst_base<batch_norm>;
+
+public:
+ static layout calc_output_layout(batch_norm_node const& node);
+ static std::string to_string(batch_norm_node const& node);
+
+public:
+ typed_primitive_inst(network_impl& network, batch_norm_node const& node);
+
+ decltype(auto) mean_memory() const { return dep_memory(1); }
+ decltype(auto) variance_memory() const { return dep_memory(2); }
+ decltype(auto) inv_variance_memory() const { return dep_memory(1); };
+ bool use_global_stats() const { return !argument.mean.empty() && !argument.variance.empty(); };
+ bool forwad_pass() const { return !argument.inv_variance.empty(); };
+};
+
+using batch_norm_inst = typed_primitive_inst<batch_norm>;
+
+}