summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/clDNN/src/include/deconvolution_inst.h
blob: a2e1516f2927e8f7eca75974bf76a5d89b048c4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
/*
// Copyright (c) 2016 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "api/CPP/deconvolution.hpp"
#include "primitive_inst.h"

namespace cldnn
{

/// @brief Program-graph node specialization for the deconvolution primitive.
///
/// Dependency indices, as used by the accessors of this struct:
///   - 0                          : input
///   - 1 .. split                 : weights, one per split
///   - 1 + split .. 2 * split     : biases, one per split (only when bias_term() is true)
///   - the index right after the
///     weights (and biases)       : optional fused-sum input
template <>
struct typed_program_node<deconvolution> : public typed_program_node_base<deconvolution>
{
    using parent = typed_program_node_base<deconvolution>;

public:
    /// @brief Builds the node; the split count is taken from the primitive and
    ///        the depthwise-separable optimization flag starts out disabled.
    typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
        : parent(prim, prog)
        , split(this->get_primitive()->split())
        , depthwise_sep_opt(false)
    {
    }

    /// @brief Overrides the split count cached at construction.
    void set_split(int32_t node_split) { split = node_split; }
    /// @brief Returns the split count currently stored in this node.
    int32_t get_split() const { return split; }

    /// @brief Sets the depthwise-separable optimization flag.
    void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
    /// @brief Returns the depthwise-separable optimization flag.
    bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }

    /// @brief Returns the input node (dependency 0).
    program_node& input() const { return get_dependency(0); }

    /// @brief Returns the weights node for split @p idx.
    /// @throws std::range_error when @p idx is not smaller than get_split().
    program_node& weights(size_t idx = 0) const
    {
        if (!is_split_index(idx))
            throw std::range_error("weights offset too big");

        return get_dependency(1 + idx);
    }

    /// @brief Returns the bias node for split @p idx.
    /// @note This accessor does not consult bias_term(); it only validates @p idx.
    /// @throws std::range_error when @p idx is not smaller than get_split().
    program_node& bias(size_t idx = 0) const
    {
        if (!is_split_index(idx))
            throw std::range_error("bias offset too big");

        return get_dependency(1 + this->get_split() + idx);
    }

    /// @brief Tells whether the primitive lists any bias.
    bool bias_term() const
    {
        return get_primitive()->bias.size() != 0;
    }

    /// @brief Returns the fused-sum input node.
    /// @throws std::range_error when @p idx is anything other than 0.
    program_node& fused_sum(size_t idx = 0) const
    {
        // Compare as size_t: narrowing idx to a signed 32-bit value first could
        // turn a huge index into 0 or a negative number and bypass this guard.
        if (idx != 0)
            throw std::range_error("Only one input for fused sum is supported");

        return get_dependency(fused_sum_dependency_index());
    }

    /// @brief True when the dependency list ends exactly one entry past the
    ///        weights (and biases, if present), i.e. a fused-sum input exists.
    bool has_fused_sum() const
    {
        return static_cast<int>(dependencies.size()) == (fused_sum_dependency_index() + 1);
    }

private:
    /// @brief Checks that @p idx addresses one of the get_split() per-split inputs.
    ///
    /// The comparison is done in size_t so that indices above INT32_MAX are
    /// rejected instead of being truncated; a non-positive split rejects
    /// every index, exactly like the former signed comparison did.
    bool is_split_index(size_t idx) const
    {
        const int32_t count = get_split();
        return count > 0 && idx < static_cast<size_t>(count);
    }

    /// @brief Dependency index of the fused-sum input: input + weights (+ biases).
    int fused_sum_dependency_index() const
    {
        const int per_kind = get_split();
        return 1 + per_kind + (bias_term() ? per_kind : 0);
    }

    int32_t split;
    bool depthwise_sep_opt;
};

/// Shorthand for the deconvolution specialization of typed_program_node.
using deconvolution_node = typed_program_node<deconvolution>;

/// @brief Primitive instance specialization for the deconvolution primitive.
///
/// Memory dependencies are addressed as: 0 for the input, 1 + index for the
/// weights of a split and 1 + split + index for the bias of a split.
template <>
class typed_primitive_inst<deconvolution> : public typed_primitive_inst_base<deconvolution>
{
    using parent = typed_primitive_inst_base<deconvolution>;

public:
    static layout calc_output_layout(deconvolution_node const& node);
    static std::string to_string(deconvolution_node const& node);

public:
    typed_primitive_inst(network_impl& network, deconvolution_node const& node);

    /// @brief Returns the weights memory for split @p index.
    /// @throws std::range_error when @p index is not smaller than the node's split.
    memory_impl& weights_memory(size_t index) const
    {
        if (!is_split_index(index))
            throw std::range_error("weights offset too big");

        return dep_memory(1 + index);
    }

    /// @brief Returns the bias memory for split @p index.
    /// @throws std::range_error when the primitive has no bias, or when
    ///         @p index is not smaller than the node's split.
    memory_impl& bias_memory(size_t index) const
    {
        // Without bias there is no dependency slot to hand out, whatever the index;
        // returning dep_memory(1 + split + index) would yield an unrelated dependency.
        if (!bias_term())
            throw std::range_error("no bias data");

        // Valid indices are 0 .. split-1, same bound as weights_memory().
        if (!is_split_index(index))
            throw std::range_error("bias offset too big");

        return dep_memory(1 + node.get_split() + index);
    }

    /// @brief Tells whether the primitive lists any bias.
    bool bias_term() const
    {
        return argument.bias.size() != 0;
    }

private:
    /// @brief Checks that @p index addresses one of the node's per-split inputs.
    ///
    /// Compared in size_t so that indices above INT32_MAX are rejected rather
    /// than truncated; a non-positive split rejects every index.
    bool is_split_index(size_t index) const
    {
        const int32_t count = node.get_split();
        return count > 0 && index < static_cast<size_t>(count);
    }
};

/// Shorthand for the deconvolution specialization of typed_primitive_inst.
using deconvolution_inst = typed_primitive_inst<deconvolution>;

}