summaryrefslogtreecommitdiff
path: root/caffe2/perfkernels
diff options
context:
space:
mode:
author: Jongsoo Park <jongsoo@fb.com> 2018-12-03 12:14:47 -0800
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com> 2018-12-03 12:18:19 -0800
commit: b5181ba1df13762b47ae6a83d38e3497ef9d04aa (patch)
tree: f68322f9ee7608dbf98e7f409ab745a159584259 /caffe2/perfkernels
parent: 4b90702037cc07b756c273adfa3bbec550cd3cc1 (diff)
download: pytorch-b5181ba1df13762b47ae6a83d38e3497ef9d04aa.tar.gz
pytorch-b5181ba1df13762b47ae6a83d38e3497ef9d04aa.tar.bz2
pytorch-b5181ba1df13762b47ae6a83d38e3497ef9d04aa.zip
add avx512 option (but no avx512 kernel yet) (#14664)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14664 This diff only adds the framework for AVX512 kernels; it does not add any AVX512 kernel yet. Be very careful about adding AVX512 kernels: use them only when you are convinced they bring a good enough *overall* speedup, because AVX512 can backfire by lowering the CPU frequency. Reviewed By: duc0 Differential Revision: D13281944 fbshipit-source-id: 04fce8619c63f814944b727a99fbd7d35538eac6
Diffstat (limited to 'caffe2/perfkernels')
-rw-r--r-- caffe2/perfkernels/CMakeLists.txt 14
-rw-r--r-- caffe2/perfkernels/common.h 27
-rw-r--r-- caffe2/perfkernels/common_avx512.cc 23
-rw-r--r-- caffe2/perfkernels/embedding_lookup.cc 60
4 files changed, 105 insertions, 19 deletions
diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt
index a5701da807..18ae10dad2 100644
--- a/caffe2/perfkernels/CMakeLists.txt
+++ b/caffe2/perfkernels/CMakeLists.txt
@@ -2,9 +2,11 @@
file(GLOB common_srcs *.cc)
file(GLOB avx_srcs *_avx.cc)
file(GLOB avx2_srcs *_avx2.cc)
-# exclude avx and avx2 srcs from common_srcs
+file(GLOB avx512_srcs *_avx512.cc)
+# exclude avx, avx2, and avx512 srcs from common_srcs
exclude(common_srcs "${common_srcs}" ${avx_srcs})
exclude(common_srcs "${common_srcs}" ${avx2_srcs})
+exclude(common_srcs "${common_srcs}" ${avx512_srcs})
# We will always build common srcs.
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs})
@@ -24,6 +26,7 @@ if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "/arch:AVX")
set_target_properties(
Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "/arch:AVX2")
+ # Currently MSVC doesn't support AVX512
else()
set_target_properties(
Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "-mavx -mf16c")
@@ -33,6 +36,15 @@ if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS)
set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
$<TARGET_OBJECTS:Caffe2_perfkernels_avx>
$<TARGET_OBJECTS:Caffe2_perfkernels_avx2>)
+
+ if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS)
+ add_library(Caffe2_perfkernels_avx512 OBJECT ${avx512_srcs})
+ add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO c10)
+ set_target_properties(
+ Caffe2_perfkernels_avx512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx2 -mfma -mavx -mf16c")
+ set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS}
+ $<TARGET_OBJECTS:Caffe2_perfkernels_avx512>)
+ endif()
endif()
# TODO(jiayq): currently, we only implement the very base files for the
diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h
index 70d5945495..7fdf64fda6 100644
--- a/caffe2/perfkernels/common.h
+++ b/caffe2/perfkernels/common.h
@@ -7,6 +7,11 @@ implement a functionality called void foo(int a, float b).
In foo.h, do:
void foo(int a, float b);
+In foo_avx512.cc, do:
+ void foo__avx512(int a, float b) {
+ [actual avx512 implementation]
+ }
+
In foo_avx2.cc, do:
void foo__avx2(int a, float b) {
[actual avx2 implementation]
@@ -25,6 +30,7 @@ In foo.cc, do:
void foo(int a, float b) {
// You should always order things by their preference, faster
// implementations earlier in the function.
+ AVX512_DO(foo, a, b);
AVX2_DO(foo, a, b);
AVX_DO(foo, a, b);
BASE_DO(foo, a, b);
@@ -35,11 +41,12 @@ In foo.cc, do:
// and run time architecture support.
//
// During build time:
-// The build system should provide flags CAFFE2_PERF_WITH_AVX2 and
-// CAFFE2_PERF_WITH_AVX that corresponds to the __AVX__ and __AVX2__ flags
-// the compiler provides. Note that we do not use the compiler flags but
-// rely on the build system flags, because the common files (like foo.cc
-// above) will always be built without __AVX__ and __AVX2__.
+// The build system should provide flags CAFFE2_PERF_WITH_AVX512,
+// CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that correspond to the
+// __AVX512F__, __AVX512DQ__, __AVX__, and __AVX2__ flags the compiler
+// provides. Note that we do not use the compiler flags but rely on the build
+// system flags, because the common files (like foo.cc above) will always be
+// built without __AVX512F__, __AVX512DQ__, __AVX__ and __AVX2__.
// During run time:
// we use cpuid to identify cpu support and run the proper functions.
@@ -52,6 +59,16 @@ In foo.cc, do:
#define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__);
+#ifdef CAFFE2_PERF_WITH_AVX512
+#define AVX512_DO(funcname, ...) \
+ decltype(funcname##__base) funcname##__avx512; \
+ if (GetCpuId().avx512f() && GetCpuId().avx512dq()) { \
+ return funcname##__avx512(__VA_ARGS__); \
+ }
+#else // CAFFE2_PERF_WITH_AVX512
+#define AVX512_DO(funcname, ...)
+#endif // CAFFE2_PERF_WITH_AVX512
+
#ifdef CAFFE2_PERF_WITH_AVX2
#define AVX2_DO(funcname, ...) \
decltype(funcname##__base) funcname##__avx2; \
diff --git a/caffe2/perfkernels/common_avx512.cc b/caffe2/perfkernels/common_avx512.cc
new file mode 100644
index 0000000000..055f95d775
--- /dev/null
+++ b/caffe2/perfkernels/common_avx512.cc
@@ -0,0 +1,23 @@
+// This file is here merely to check that the flags are not mixed up: for
+// example, if your compiler did not specify -mavx512f and -mavx512dq,
+// you should not provide the CAFFE2_PERF_WITH_AVX512 macro.
+
+#include "caffe2/core/common.h"
+
+#ifdef CAFFE2_PERF_WITH_AVX512
+#if !defined(__AVX512F__) || !defined(__AVX512DQ__)
+#error( \
+ "You found a build system error: CAFFE2_PERF_WITH_AVX512 is defined" \
+ "but __AVX512F__ or __AVX512DQ__ is not defined" \
+ "(via e.g. -mavx512f and -mavx512dq).");
+#endif
+#endif // CAFFE2_PERF_WITH_AVX512
+
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+#ifndef CAFFE2_PERF_WITH_AVX512
+#error( \
+ "You found a build system error: __AVX512F__ and __AVX512DQ__ is defined" \
+ "(via e.g. -mavx512f and -mavx512dq) " \
+ "but CAFFE2_PERF_WITH_AVX512 is not defined.");
+#endif // CAFFE2_PERF_WITH_AVX512
+#endif
diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc
index e98bc51c25..fa93ae7496 100644
--- a/caffe2/perfkernels/embedding_lookup.cc
+++ b/caffe2/perfkernels/embedding_lookup.cc
@@ -82,13 +82,19 @@ static void EmbeddingLookupGenericSlow(
// Proxy back to generic implementation
#define EMBEDDING_SPECIALIZATION( \
- IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType, IS_WEIGHT_POSITIONAL) \
+ IndexTypeName, \
+ IndexType, \
+ InTypeName, \
+ InType, \
+ OutTypeName, \
+ OutType, \
+ IS_WEIGHT_POSITIONAL) \
void \
EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \
- const int64_t block_size, \
- const int64_t output_size, \
- const int64_t index_size, \
- const int64_t data_size, \
+ const int64_t block_size, \
+ const int64_t output_size, \
+ const int64_t index_size, \
+ const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int* lengths, \
@@ -115,10 +121,10 @@ static void EmbeddingLookupGenericSlow(
} \
template <> \
void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>( \
- const int64_t block_size, \
- const int64_t output_size, \
- const int64_t index_size, \
- const int64_t data_size, \
+ const int64_t block_size, \
+ const int64_t output_size, \
+ const int64_t index_size, \
+ const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const int* lengths, \
@@ -158,15 +164,43 @@ EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, false);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, false);
EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, false);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, false);
-EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, false);
-EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, false);
+EMBEDDING_SPECIALIZATION(
+ int32_t,
+ int32_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ false);
+EMBEDDING_SPECIALIZATION(
+ int64_t,
+ int64_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ false);
EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, true);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, true);
EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, true);
EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, true);
-EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, true);
-EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, true);
+EMBEDDING_SPECIALIZATION(
+ int32_t,
+ int32_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ true);
+EMBEDDING_SPECIALIZATION(
+ int64_t,
+ int64_t,
+ uint8_t,
+ uint8_t,
+ float,
+ float,
+ true);
#undef EMBEDDING_SPECIALIZATION