path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm_u8s8s32x_inner_product.cpp
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "mkldnn_types.h"
#include "mkldnn_thread.hpp"
#include "simple_q10n.hpp"
#include "gemm_u8s8s32x_inner_product.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace math;
using namespace memory_format;

template <data_type_t dst_type>
void gemm_u8s8s32x_inner_product_fwd_t<dst_type>::execute_forward() {
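    // Compiled only when MKL provides the integer GEMM (USE_MKL_IGEMM).
    // The forward pass computes acc(MB x OC) = src(MB x IC) * weights^T,
    // then adds bias, applies the output scales and an optional ReLU
    // post-op, and down-converts the result to dst_type.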
#if USE_MKL_IGEMM
    auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
    auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
    auto bias = reinterpret_cast<const char *>(this->input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(this->memory());

    const int MB = conf_.MB();
    const int OC = conf_.OC();

    bool wei_tr = utils::one_of(conf_.weights_pd()->desc()->format,
             oihw, oidhw, oi);

    const int M = OC;
    const int N = MB;
    const int K = conf_.IC_total_padded();
    const int8_t off_a = 0, off_b = 0;
    const int32_t off_c = 0;

    // mask_ == (1 << 1) requests one output scale per output channel;
    // otherwise a single common scale is used (scale_idx_mult is 0 or 1).
    const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
    const float *scales = conf_.attr()->output_scales_.scales_;
    const auto rmode = conf_.attr()->round_mode_;

    const auto &post_ops = conf_.attr()->post_ops_;
    const bool do_relu = post_ops.len_ == 1; // at most one post-op (ReLU) is accepted by the pd
    const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;

    acc_data_t *acc = this->dst_is_acc_
        ? (acc_data_t *)dst
        : (acc_data_t *)this->scratchpad_->get();

    auto get_bias = [=, &bias](size_t off) -> acc_data_t {
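        // The bias tensor may be stored as s8, u8, s32 or f32; read element
        // `off` in its native type and convert it to the accumulator type.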
#       define CASE(dt) case dt: return (acc_data_t)\
        (*((const prec_traits<dt>::type *)bias + off))
        switch (conf_.desc()->bias_desc.data_type) {
        CASE(data_type::s8);
        CASE(data_type::u8);
        CASE(data_type::s32);
        CASE(data_type::f32);
        default: assert(!"unimplemented");
        }
#       undef CASE
        return 0;
    };

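    // Column-major integer GEMM: acc(OC x MB) = weights(OC x IC) * src(IC x MB)
    // with zero offsets, i.e. the row-major acc[mb * OC + oc] consumed below.
    // Weights stored in oi-like formats are passed as transposed.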
    cblas_gemm_s8u8s32(CblasColMajor, wei_tr ? CblasTrans : CblasNoTrans,
            CblasNoTrans, CblasFixOffset, M, N, K, 1., weights,
            wei_tr ? K : M, off_a, src, K, off_b, 0., acc, M, &off_c);

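    // Post-process each output element in parallel: add bias, apply the
    // (possibly per-channel) output scale and the negative-slope ReLU,
    // then quantize and saturate into the destination data type.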
    parallel_nd(MB, OC, [&](int mb, int oc) {
        size_t dst_off = mb * OC + oc;
        float d = (float)acc[dst_off];
        if (bias)
            d += get_bias(oc);
        d *= scales[oc * scale_idx_mult];
        if (do_relu && d < 0)
            d *= nslope;
        dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode);
    });
#endif
}

using namespace data_type;

template struct gemm_u8s8s32x_inner_product_fwd_t<f32>;
template struct gemm_u8s8s32x_inner_product_fwd_t<s32>;
template struct gemm_u8s8s32x_inner_product_fwd_t<s8>;
template struct gemm_u8s8s32x_inner_product_fwd_t<u8>;
} // namespace cpu
} // namespace impl
} // namespace mkldnn