summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/clDNN/src/include/convolution_inst.h
blob: b47e09fab9ad03254bf2610aab32802a17b31be7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
/*
// Copyright (c) 2016 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#pragma once
#include "api/CPP/convolution.hpp"
#include "primitive_inst.h"

#include <memory>

namespace cldnn
{

/// Program-graph node specialization for the convolution primitive.
///
/// Dependency slots as indexed by the accessors below (S = get_split()):
///   0                    input
///   1        .. S        weights, one per split group
///   1 + S    .. 2S       biases
///   1 + 2S   .. 3S       weights quantization factors
///   1 + 3S   .. 4S       output calibration factors
///
/// NOTE(review): the offsets are fixed multiples of S, so each optional group appears to
/// assume that every preceding group is present - confirm against the way the primitive
/// lists its dependencies.
template <>
struct typed_program_node<convolution> : public typed_program_node_base<convolution>
{
    using parent = typed_program_node_base<convolution>;

public:
    /// Builds the node and caches per-node settings read from the primitive: the split count
    /// and the input/output quantization factors. Both optimization flags start as false.
    typed_program_node(std::shared_ptr<primitive> prim, program_impl& prog)
        : parent(prim, prog)
        , split(this->get_primitive()->split())
        , depthwise_sep_opt(false)
        , transposed(false)
        , input_qf(this->get_primitive()->input_quantization_factor)
        , output_qf(this->get_primitive()->output_quantization_factor)
    {
    }

    /// Overrides the split count cached at construction; it bounds every indexed accessor.
    void set_split(int32_t node_split) { split = node_split; }
    int32_t get_split() const { return split; }

    void set_depthwise_sep_opt(bool node_depthwise_sep_opt) { depthwise_sep_opt = node_depthwise_sep_opt; }
    bool get_depthwise_sep_opt() const { return depthwise_sep_opt; }

    void set_transposed(bool node_transposed) { transposed = node_transposed; }
    bool get_transposed() const { return transposed; }

    /// Returns dependency 0.
    program_node& input() const { return get_dependency(0); }

    /// Returns the weights dependency of split group idx.
    /// @throws std::range_error if idx is not smaller than get_split().
    program_node& weights(size_t idx = 0) const
    {
        validate_split_index(idx, "weights offset too big");
        return get_dependency(1 + idx);
    }

    /// Returns the bias dependency of split group idx.
    /// @throws std::range_error if idx is not smaller than get_split().
    program_node& bias(size_t idx = 0) const
    {
        validate_split_index(idx, "bias offset too big");
        return get_dependency(1 + this->get_split() + idx);
    }

    /// Returns the weights quantization factors dependency of split group idx.
    /// @throws std::range_error if idx is not smaller than get_split().
    program_node& weights_quantization_factors(size_t idx = 0) const
    {
        validate_split_index(idx, "quantization factor offset too big");
        return get_dependency(1 + 2*this->get_split() + idx);
    }

    /// Returns the output calibration factors dependency of split group idx.
    /// @throws std::range_error if idx is not smaller than get_split().
    program_node& output_calibration_factors(size_t idx = 0) const
    {
        validate_split_index(idx, "calibration factor offset too big");
        return get_dependency(1 + 3 * this->get_split() + idx);
    }

    /// True when the primitive lists at least one bias.
    bool bias_term() const
    {
        return get_primitive()->bias.size() > 0;
    }

    /// True when the primitive lists at least one weights quantization factor.
    bool weights_quantization_term() const
    {
        return get_primitive()->weights_quantization_factors.size() > 0;
    }

    /// True when the primitive lists at least one output calibration factor.
    bool output_calibration_term() const
    {
        return get_primitive()->output_calibration_factors.size() > 0;
    }

    float get_input_qf() const { return input_qf; }
    float get_output_qf() const { return output_qf; }

private:
    /// Throws std::range_error(message) unless idx addresses one of the get_split() groups.
    /// The comparison is done in size_t: narrowing idx to int32_t first would turn values
    /// above INT32_MAX into negative numbers that slip past the check.
    void validate_split_index(size_t idx, const char* message) const
    {
        const int32_t group_count = this->get_split();
        if (group_count <= 0 || idx >= static_cast<size_t>(group_count))
            throw std::range_error(message);
    }

    int32_t split;
    bool depthwise_sep_opt;
    bool transposed;
    float input_qf;
    float output_qf;
};

/// Shorthand for the convolution specialization of typed_program_node.
using convolution_node = typed_program_node<convolution>;

/// Primitive instance specialization for the convolution primitive.
///
/// The indexed *_memory accessors map a split-group index to a dependency memory slot
/// (S = node.get_split()):
///   1        .. S        weights
///   1 + S    .. 2S       biases
///   1 + 2S   .. 3S       weights quantization factors
///   1 + 3S   .. 4S       output calibration factors
///
/// NOTE(review): the offsets are fixed multiples of S, so each optional group appears to
/// assume that every preceding group is present - confirm against the node's dependency list.
template <>
class typed_primitive_inst<convolution> : public typed_primitive_inst_base<convolution>
{
    using parent = typed_primitive_inst_base<convolution>;

public:
    static layout calc_output_layout(convolution_node const& node);
    static std::string to_string(convolution_node const& node);

public:
    typed_primitive_inst(network_impl& network, convolution_node const& node);

    /// Returns the weights memory of split group index.
    /// @throws std::range_error if index is not smaller than node.get_split().
    memory_impl& weights_memory(size_t index) const
    {
        validate_split_index(index, "weights offset too big");
        return dep_memory(1 + index);
    }

    /// Returns the bias memory of split group index.
    /// @throws std::range_error if index is not smaller than node.get_split().
    memory_impl& bias_memory(size_t index) const
    {
        validate_split_index(index, "bias offset too big");
        return dep_memory(1 + node.get_split() + index);
    }

    /// Returns the weights quantization factors memory of split group index.
    /// @throws std::range_error if index is not smaller than node.get_split().
    memory_impl& weights_quantization_factors_memory(size_t index) const
    {
        validate_split_index(index, "quantization factors offset too big");
        return dep_memory(1 + 2*node.get_split() + index);
    }

    /// Returns the output calibration factors memory of split group index.
    /// @throws std::range_error if index is not smaller than node.get_split().
    memory_impl& output_calibration_factors_memory(size_t index) const
    {
        // The message names calibration factors; it previously repeated the quantization text.
        validate_split_index(index, "calibration factors offset too big");
        return dep_memory(1 + 3 * node.get_split() + index);
    }

    /// Forwards to node.bias_term().
    bool bias_term() const
    {
        return node.bias_term();
    }

    /// Forwards to node.weights_quantization_term().
    bool weights_quantization_factors_term() const
    {
        return node.weights_quantization_term();
    }

    /// Forwards to node.output_calibration_term().
    bool output_calibration_factors_term() const
    {
        return node.output_calibration_term();
    }

private:
    /// Throws std::range_error(message) unless index addresses one of the node's split groups.
    /// The comparison is done in size_t: narrowing index to int32_t first would turn values
    /// above INT32_MAX into negative numbers that slip past the check.
    void validate_split_index(size_t index, const char* message) const
    {
        const int32_t group_count = node.get_split();
        if (group_count <= 0 || index >= static_cast<size_t>(group_count))
            throw std::range_error(message);
    }
};

/// Shorthand for the convolution specialization of typed_primitive_inst.
using convolution_inst = typed_primitive_inst<convolution>;

}