1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
/*******************************************************************************
* Copyright 2017 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CPU_JIT_UNI_SOFTMAX_KERNEL_F32_HPP
#define CPU_JIT_UNI_SOFTMAX_KERNEL_F32_HPP
#include <cfloat>
#include "c_types_map.hpp"
#include "jit_generator.hpp"
#include "type_helpers.hpp"
#include "jit_primitive_conf.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {
using namespace Xbyak;
/**
 * JIT code generator for an f32 softmax kernel, parameterised on the target
 * instruction set.
 *
 * Constructing an instance copies the supplied configuration, runs
 * generate() and keeps the resulting code pointer; operator() then calls
 * that code with a jit_softmax_call_s argument block.
 *
 * Only declarations live in this header. The emitter bodies are defined
 * elsewhere, so the notes on individual registers below describe what the
 * names suggest and should be confirmed against the implementation.
 */
template <cpu_isa_t isa>
struct jit_uni_softmax_kernel_f32 : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_softmax_kernel_f32)

    /// Vector register class for the ISA: Xmm for sse42, Ymm for avx2,
    /// Zmm for every other value of `isa`.
    using Vmm = typename utils::conditional3<isa == sse42, Xmm,
            isa == avx2, Ymm, Zmm>::type;

    /// Kernel configuration, copied from the constructor argument.
    jit_softmax_conf_t jpp;

    /// Stores the configuration, emits the kernel and caches its entry
    /// point. The code buffer address returned by getCode() is
    /// reinterpreted as the kernel's function-pointer type.
    jit_uni_softmax_kernel_f32(jit_softmax_conf_t ajpp) : jpp(ajpp) {
        generate();
        jit_ker = (decltype(jit_ker))this->getCode();
    }

    /// Declared here, defined elsewhere. Takes the configuration by
    /// non-const reference together with the softmax descriptor and the
    /// source/destination memory descriptor wrappers, and reports a
    /// status_t. Presumably fills `jpp` from the descriptors -- TODO confirm.
    static status_t init_conf(jit_softmax_conf_t &jpp,
            const softmax_desc_t &pd,
            const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &dst_d);

    /// Invokes the generated code with the given argument block.
    void operator()(jit_softmax_call_s *arg) { (*jit_ker)(arg); }

    // Code-emission helpers (definitions are outside this header).
    void prepare_table();
    void simd_expf(const Vmm &vmm_src);
    void scalar_expf(const Xmm &xmm_src);

    // Vector loops; each takes an unroll factor and a label tag.
    void simd_loop_max(int ur_inner, char label_tag);
    void simd_loop_exp(int ur_inner, char label_tag);
    void simd_loop_div(int ur_inner, char label_tag);

    // Scalar counterparts of the loops above.
    void scalar_loop_max();
    void scalar_loop_exp();
    void scalar_loop_div();

private:
    /// Emits the whole kernel; called once from the constructor.
    void generate();

    // Register-index helpers and their Vmm-returning counterparts, all
    // keyed by the unroll factor.
    int id_vreg_max(int ur_inner);
    int id_vreg_denom(int ur_inner);
    int id_vreg_src(int ur_inner);
    auto vreg_max(int ur_inner) -> Vmm;
    auto vreg_denom(int ur_inner) -> Vmm;
    auto vreg_src(int ur_inner) -> Vmm;

    /// Entry point of the generated code; assigned in the constructor.
    void (*jit_ker)(jit_softmax_call_s *);

    /// Number of floats per vector register for this ISA.
    const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
    /// Vector length for this ISA, as reported by cpu_isa_traits.
    const int vlen = cpu_isa_traits<isa>::vlen;

    // General-purpose register assignments. The role of each register is
    // suggested by its name only -- verify against the emitter code.
    Reg64 reg_work_amount = rax;
    Reg64 reg_src_base_ptr = rbx;
    Reg64 reg_dst_base_ptr = rsi;
    Reg64 reg_src_ptr = r8;
    Reg64 reg_dst_ptr = r9;
    Reg64 reg_channels = r12;
    Reg64 reg_ch_work = r13;
    Reg64 reg_min = rdx;
    Reg64 imm_addr64 = r14;

    // Vector temporaries (indices 0..2).
    Vmm vmm_aux0{0};
    Vmm vmm_aux1{1};
    Vmm vmm_aux2{2};

    // Xmm views with the same indices 0..2 as the vector temporaries.
    Xmm xmm_aux0{0};
    Xmm xmm_aux1{1};
    Xmm xmm_aux2{2};

    Xmm xmm_float_min{3};

    // Index 4 is referenced both as an Xmm and as a full-width Vmm.
    Xmm xmm_one{4};
    Vmm vmm_one{4};

    Xmm xmm_max{5};
    Xmm xmm_denom{6};
    Xmm xmm_src{7};

    Opmask k_mask_tmp{2};

    /// 14 when isa is avx512_common, otherwise 6. Presumably the
    /// "greater-than" predicate immediate for the compare instruction
    /// used on each ISA -- TODO confirm against the emitter code.
    unsigned char _cmp_gt_os = (isa == avx512_common) ? 14 : 6;

    // Labels available to the generated code.
    Label loop_simd_unroll;
    Label loop_simd;
    Label loop_scalar;
    Label loop_end;
    Label l_table;

    /// Presumably the rounding-control immediate that selects "floor"
    /// -- TODO confirm where it is used.
    unsigned char _op_floor = 1;
};
}
}
}
#endif
|