summary refs log tree commit diff
path: root/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_softmax_kernel_f32.hpp
blob: b34b9de35ddb44167e792434330c530a9df2c5c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*******************************************************************************
* Copyright 2017 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_JIT_UNI_SOFTMAX_KERNEL_F32_HPP
#define CPU_JIT_UNI_SOFTMAX_KERNEL_F32_HPP

#include <cfloat>

#include "c_types_map.hpp"
#include "jit_generator.hpp"
#include "type_helpers.hpp"

#include "jit_primitive_conf.hpp"

namespace mkldnn {
namespace impl {
namespace cpu {

// NOTE(review): this using-directive lives in a header, inside namespace
// mkldnn::impl::cpu, so every translation unit that includes this file gets
// the Xbyak names injected into that namespace. The declarations in this
// header use unqualified names such as Xmm, Reg64, Opmask and Label
// (presumably provided by Xbyak), so the directive cannot simply be dropped
// without qualifying those uses first.
using namespace Xbyak;

/**
 * JIT code generator for an f32 softmax kernel, parameterized by the target
 * instruction set @p isa.
 *
 * Constructing an object calls generate() and caches the address returned by
 * getCode() as the kernel entry point; operator() then calls that entry point
 * with a jit_softmax_call_s argument block.
 *
 * Only declarations live in this header: init_conf(), generate() and the
 * individual code emitters are defined elsewhere.
 */
template <cpu_isa_t isa>
struct jit_uni_softmax_kernel_f32 : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_softmax_kernel_f32)

    /// Vector register type chosen from the ISA: Xmm for sse42, Ymm for avx2,
    /// and Zmm for every other value of @p isa.
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
            isa == avx2, Ymm, Zmm>::type;

    /// Keeps a copy of @p ajpp, generates the kernel code, and stores the
    /// generated code's address as the callable entry point.
    jit_uni_softmax_kernel_f32(jit_softmax_conf_t ajpp) : jpp(ajpp) {
        generate();
        const auto code_start = this->getCode();
        jit_ker = (decltype(jit_ker))code_start;
    }

    /// Kernel configuration copied in by the constructor.
    jit_softmax_conf_t jpp;

    /// Fills @p jpp from the softmax descriptor and the source/destination
    /// memory descriptors. Defined outside this header.
    static status_t init_conf(jit_softmax_conf_t &jpp,
            const softmax_desc_t &pd,
            const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d);

    /// Runs the generated kernel on the given argument block.
    void operator()(jit_softmax_call_s *arg) {
        jit_ker(arg);
    }

    // Code emitters; their bodies are defined outside this header.
    void prepare_table();
    void simd_expf(const Vmm &vmm_src);
    void scalar_expf(const Xmm &xmm_src);

    // Vectorized loop emitters, one per softmax phase (max / exp / div),
    // taking an unroll factor and a label tag.
    void simd_loop_max(int ur_inner, char label_tag);
    void simd_loop_exp(int ur_inner, char label_tag);
    void simd_loop_div(int ur_inner, char label_tag);

    // Scalar counterparts of the three loop emitters above.
    void scalar_loop_max();
    void scalar_loop_exp();
    void scalar_loop_div();

private:
    /// Entry point of the generated code; assigned in the constructor.
    void (*jit_ker)(jit_softmax_call_s *);

    /// Number of float elements per vector: vlen / sizeof(float).
    const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
    /// Vector length reported by cpu_isa_traits for this ISA
    /// (presumably in bytes, given the division above -- confirm).
    const int vlen = cpu_isa_traits<isa>::vlen;

    // General-purpose register assignments. The roles are only suggested by
    // the names here; the code that uses them is not in this header.
    Reg64 reg_work_amount = rax;
    Reg64 reg_src_base_ptr = rbx;
    Reg64 reg_dst_base_ptr = rsi;
    Reg64 reg_src_ptr = r8;
    Reg64 reg_dst_ptr = r9;
    Reg64 reg_channels = r12;
    Reg64 reg_ch_work = r13;
    Reg64 reg_min = rdx;
    Reg64 imm_addr64 = r14;

    // Scratch vector registers 0..2, exposed both as full-width Vmm and as
    // Xmm views with the same register indices.
    Vmm vmm_aux0 = Vmm(0);
    Vmm vmm_aux1 = Vmm(1);
    Vmm vmm_aux2 = Vmm(2);
    Xmm xmm_aux0 = Xmm(0);
    Xmm xmm_aux1 = Xmm(1);
    Xmm xmm_aux2 = Xmm(2);

    // Register 3 and register 4; the latter is declared in both Xmm and Vmm
    // form under the name "one".
    Xmm xmm_float_min = Xmm(3);
    Xmm xmm_one = Xmm(4);
    Vmm vmm_one = Vmm(4);

    // Xmm registers 5..7 (named for the running max, the denominator and the
    // current source value).
    Xmm xmm_max = Xmm(5);
    Xmm xmm_denom = Xmm(6);
    Xmm xmm_src = Xmm(7);

    /// Opmask register with index 2.
    Opmask k_mask_tmp = Opmask(2);

    /// Comparison immediate: 14 when isa is avx512_common, 6 otherwise.
    /// NOTE(review): presumably the GT_OS and NLE_US compare-predicate
    /// encodings respectively -- confirm against the ISA reference.
    unsigned char _cmp_gt_os = (isa == avx512_common) ? 14 : 6;

    // Register-index helpers keyed by the unroll position.
    int id_vreg_max(int ur_inner);
    int id_vreg_denom(int ur_inner);
    int id_vreg_src(int ur_inner);

    // Vmm-returning counterparts of the index helpers above.
    Vmm vreg_max(int ur_inner);
    Vmm vreg_denom(int ur_inner);
    Vmm vreg_src(int ur_inner);

    // Jump/data labels referenced by the emitted code.
    Label loop_simd_unroll;
    Label loop_simd;
    Label loop_scalar;
    Label loop_end;
    Label l_table;

    /// Immediate value 1.
    /// NOTE(review): the name suggests a "floor" rounding-mode selector for a
    /// round instruction -- confirm where it is used.
    unsigned char _op_floor = 1;

    /// Emits the full kernel; invoked once from the constructor.
    void generate();
};

}
}
}

#endif