summaryrefslogtreecommitdiff
path: root/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp')
-rw-r--r--inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp64
1 files changed, 52 insertions, 12 deletions
diff --git a/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp b/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp
index 65d546232..2cb5b55f4 100644
--- a/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp
+++ b/inference-engine/thirdparty/mkl-dnn/tests/benchdnn/ip/ip.hpp
@@ -24,22 +24,63 @@
#include "mkldnn_memory.hpp"
namespace ip {
+const size_t max_prb_len = 392;
+const size_t max_desc_len = 196;
-struct prb_t {
+struct desc_t {
+ int mb, oc, ic, id, ih, iw;
+ const char *name;
+};
+
+typedef struct dt_conf_t {
+ mkldnn_data_type_t dt;
+ double min, max; /* representative */
+ int f_min, f_max; /* fill range */
+ int f_base; /* fill base, use 0 */
+ int f_step; /* fill step, use 1 */
+ double f_sparsity; /* amount of non-zeros, default 0.25 */
+ double eps; /* acceptable error */
+} _dt_conf_t[DAT_TOTAL];
+
+extern const _dt_conf_t conf_f32;
+extern const _dt_conf_t conf_s32s16s16s32;
+extern const _dt_conf_t conf_u8s8s32s32;
+extern const _dt_conf_t conf_u8s8s8s32;
+extern const _dt_conf_t conf_u8s8u8s32;
+
+struct prb_t : public desc_t {
+ prb_t(const desc_t &desc, int mb, dir_t dir, const dt_conf_t *cfg,
+ const attr_t &attr)
+ : desc_t(desc), dir(dir), cfg(cfg), attr(attr), scales(NULL) {
+ if (mb)
+ this->mb = mb;
+ generate_oscales();
+ }
+ ~prb_t() { if (scales) zfree(scales); }
dir_t dir;
- int mb;
- int ic, ih, iw;
- int oc;
+ const dt_conf_t *cfg;
+ attr_t attr;
+ float *scales;
- mkldnn_data_type_t src_dt, wei_dt, acc_dt, dst_dt;
+ void generate_oscales();
};
-inline size_t src_off_f(const prb_t *p, int mb, int ic, int ih, int iw) {
- return ((mb * p->ic + ic) * p->ih + ih) * p->iw + iw;
+int str2desc(desc_t *desc, const char *str);
+void prb2str(const prb_t *p, char *buffer, bool canonical = false);
+void perf_report(const prb_t *p, const res_t *r, const char *pstr);
+const dt_conf_t *str2cfg(const char *str);
+const char *cfg2str(const dt_conf_t *cfg);
+
+extern const char *perf_template; /* performance output template */
+
+inline size_t src_off_f(
+ const prb_t *p, int mb, int ic, int id, int ih, int iw) {
+ return ((((size_t)mb * p->ic + ic) * p->id + id) * p->ih + ih) * p->iw + iw;
}
-inline size_t wei_off_f(const prb_t *p, int oc, int ic, int ih, int iw) {
- return ((oc * p->ic + ic) * p->ih + ih) * p->iw + iw;
+inline size_t wei_off_f(
+ const prb_t *p, int oc, int ic, int id, int ih, int iw) {
+ return ((((size_t)oc * p->ic + ic) * p->id + id) * p->ih + ih) * p->iw + iw;
}
inline size_t bia_off_f(const prb_t *p, int oc) { return oc; }
@@ -55,10 +96,9 @@ void compute_ref_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m, dnn_mem_t &wei_m,
void compute_ref_bwd_w(const prb_t *p, dnn_mem_t &src_m, dnn_mem_t &diff_wei_m,
dnn_mem_t &diff_bia_m, dnn_mem_t &diff_dst_m);
-int doit(prb_t *p, res_t *res);
-
-int bench(int argc, char **argv);
+int doit(const prb_t *p, res_t *res);
+int bench(int argc, char **argv, bool main_bench = true);
}
#endif