diff options
author | Jongsoo Park <jongsoo@fb.com> | 2018-12-03 12:14:47 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2018-12-03 12:18:19 -0800 |
commit | b5181ba1df13762b47ae6a83d38e3497ef9d04aa (patch) | |
tree | f68322f9ee7608dbf98e7f409ab745a159584259 /caffe2/perfkernels | |
parent | 4b90702037cc07b756c273adfa3bbec550cd3cc1 (diff) | |
download | pytorch-b5181ba1df13762b47ae6a83d38e3497ef9d04aa.tar.gz pytorch-b5181ba1df13762b47ae6a83d38e3497ef9d04aa.tar.bz2 pytorch-b5181ba1df13762b47ae6a83d38e3497ef9d04aa.zip |
add avx512 option (but no avx512 kernel yet) (#14664)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14664
This diff just adds a framework to add avx512 kernels.
Please be really really careful about using avx512 kernels unless you're convinced using avx512 will bring good enough *overall* speedups because it can backfire because of cpu frequency going down.
Reviewed By: duc0
Differential Revision: D13281944
fbshipit-source-id: 04fce8619c63f814944b727a99fbd7d35538eac6
Diffstat (limited to 'caffe2/perfkernels')
-rw-r--r-- | caffe2/perfkernels/CMakeLists.txt | 14 | ||||
-rw-r--r-- | caffe2/perfkernels/common.h | 27 | ||||
-rw-r--r-- | caffe2/perfkernels/common_avx512.cc | 23 | ||||
-rw-r--r-- | caffe2/perfkernels/embedding_lookup.cc | 60 |
4 files changed, 105 insertions, 19 deletions
diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index a5701da807..18ae10dad2 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -2,9 +2,11 @@ file(GLOB common_srcs *.cc) file(GLOB avx_srcs *_avx.cc) file(GLOB avx2_srcs *_avx2.cc) -# exclude avx and avx2 srcs from common_srcs +file(GLOB avx512_srcs *_avx512.cc) +# exclude avx, avx2, and avx512 srcs from common_srcs exclude(common_srcs "${common_srcs}" ${avx_srcs}) exclude(common_srcs "${common_srcs}" ${avx2_srcs}) +exclude(common_srcs "${common_srcs}" ${avx512_srcs}) # We will always build common srcs. set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs}) @@ -24,6 +26,7 @@ if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS) Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "/arch:AVX") set_target_properties( Caffe2_perfkernels_avx2 PROPERTIES COMPILE_FLAGS "/arch:AVX2") + # Currently MSVC doesn't support AVX512 else() set_target_properties( Caffe2_perfkernels_avx PROPERTIES COMPILE_FLAGS "-mavx -mf16c") @@ -33,6 +36,15 @@ if (NOT MSVC AND CAFFE2_COMPILER_SUPPORTS_AVX2_EXTENSIONS) set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} $<TARGET_OBJECTS:Caffe2_perfkernels_avx> $<TARGET_OBJECTS:Caffe2_perfkernels_avx2>) + + if (CAFFE2_COMPILER_SUPPORTS_AVX512_EXTENSIONS) + add_library(Caffe2_perfkernels_avx512 OBJECT ${avx512_srcs}) + add_dependencies(Caffe2_perfkernels_avx512 Caffe2_PROTO c10) + set_target_properties( + Caffe2_perfkernels_avx512 PROPERTIES COMPILE_FLAGS "-mavx512f -mavx512dq -mavx2 -mfma -mavx -mf16c") + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} + $<TARGET_OBJECTS:Caffe2_perfkernels_avx512>) + endif() endif() # TODO(jiayq): currently, we only implement the very base files for the diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h index 70d5945495..7fdf64fda6 100644 --- a/caffe2/perfkernels/common.h +++ b/caffe2/perfkernels/common.h @@ -7,6 +7,11 @@ implement a functionality called void foo(int a, float b). In foo.h, do: void foo(int a, float b); +In foo_avx512.cc, do: + void foo__avx512(int a, float b) { + [actual avx512 implementation] + } + In foo_avx2.cc, do: void foo__avx2(int a, float b) { [actual avx2 implementation] @@ -25,6 +30,7 @@ In foo.cc, do: void foo(int a, float b) { // You should always order things by their preference, faster // implementations earlier in the function. + AVX512_DO(foo, a, b); AVX2_DO(foo, a, b); AVX_DO(foo, a, b); BASE_DO(foo, a, b); @@ -35,11 +41,12 @@ In foo.cc, do: // and run time architecture support. // // During build time: -// The build system should provide flags CAFFE2_PERF_WITH_AVX2 and -// CAFFE2_PERF_WITH_AVX that corresponds to the __AVX__ and __AVX2__ flags -// the compiler provides. Note that we do not use the compiler flags but -// rely on the build system flags, because the common files (like foo.cc -// above) will always be built without __AVX__ and __AVX2__. +// The build system should provide flags CAFFE2_PERF_WITH_AVX512, +// CAFFE2_PERF_WITH_AVX2, and CAFFE2_PERF_WITH_AVX that corresponds to the +// __AVX512F__, __AVX512DQ__, __AVX__, and __AVX2__ flags the compiler +// provides. Note that we do not use the compiler flags but rely on the build +// system flags, because the common files (like foo.cc above) will always be +// built without __AVX512F__, __AVX512DQ__, __AVX__ and __AVX2__. // During run time: // we use cpuid to identify cpu support and run the proper functions. @@ -52,6 +59,16 @@ In foo.cc, do: #define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__); +#ifdef CAFFE2_PERF_WITH_AVX512 +#define AVX512_DO(funcname, ...) \ + decltype(funcname##__base) funcname##__avx512; \ + if (GetCpuId().avx512f() && GetCpuId().avx512dq()) { \ + return funcname##__avx512(__VA_ARGS__); \ + } +#else // CAFFE2_PERF_WITH_AVX512 +#define AVX512_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_AVX512 + #ifdef CAFFE2_PERF_WITH_AVX2 #define AVX2_DO(funcname, ...) \ decltype(funcname##__base) funcname##__avx2; \ diff --git a/caffe2/perfkernels/common_avx512.cc b/caffe2/perfkernels/common_avx512.cc new file mode 100644 index 0000000000..055f95d775 --- /dev/null +++ b/caffe2/perfkernels/common_avx512.cc @@ -0,0 +1,23 @@ +// This file is here merely to check that the flags are not mixed up: for +// example, if your compiler did not specify -mavx512f and -mavx512dq, +// you should not provide the CAFFE2_PERF_WITH_AVX512 macro. + +#include "caffe2/core/common.h" + +#ifdef CAFFE2_PERF_WITH_AVX512 +#if !defined(__AVX512F__) || !defined(__AVX512DQ__) +#error( \ + "You found a build system error: CAFFE2_PERF_WITH_AVX512 is defined" \ + "but __AVX512F__ or __AVX512DQ__ is not defined" \ + "(via e.g. -mavx512f and -mavx512dq)."); +#endif +#endif // CAFFE2_PERF_WITH_AVX512 + +#if defined(__AVX512F__) && defined(__AVX512DQ__) +#ifndef CAFFE2_PERF_WITH_AVX512 +#error( \ + "You found a build system error: __AVX512F__ and __AVX512DQ__ is defined" \ + "(via e.g. -mavx512f and -mavx512dq) " \ + "but CAFFE2_PERF_WITH_AVX512 is not defined."); +#endif // CAFFE2_PERF_WITH_AVX512 +#endif diff --git a/caffe2/perfkernels/embedding_lookup.cc b/caffe2/perfkernels/embedding_lookup.cc index e98bc51c25..fa93ae7496 100644 --- a/caffe2/perfkernels/embedding_lookup.cc +++ b/caffe2/perfkernels/embedding_lookup.cc @@ -82,13 +82,19 @@ static void EmbeddingLookupGenericSlow( // Proxy back to generic implementation #define EMBEDDING_SPECIALIZATION( \ - IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType, IS_WEIGHT_POSITIONAL) \ + IndexTypeName, \ + IndexType, \ + InTypeName, \ + InType, \ + OutTypeName, \ + OutType, \ + IS_WEIGHT_POSITIONAL) \ void \ EmbeddingLookup_##IndexTypeName##_##InTypeName##_##OutTypeName##_##IS_WEIGHT_POSITIONAL##__base( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ const int* lengths, \ @@ -115,10 +121,10 @@ static void EmbeddingLookupGenericSlow( } \ template <> \ void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>( \ - const int64_t block_size, \ - const int64_t output_size, \ - const int64_t index_size, \ - const int64_t data_size, \ + const int64_t block_size, \ + const int64_t output_size, \ + const int64_t index_size, \ + const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ const int* lengths, \ @@ -158,15 +164,43 @@ EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, false); EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, false); EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, false); EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, false); -EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, false); -EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, false); +EMBEDDING_SPECIALIZATION( + int32_t, + int32_t, + uint8_t, + uint8_t, + float, + float, + false); +EMBEDDING_SPECIALIZATION( + int64_t, + int64_t, + uint8_t, + uint8_t, + float, + float, + false); EMBEDDING_SPECIALIZATION(int32_t, int32_t, float, float, float, float, true); EMBEDDING_SPECIALIZATION(int64_t, int64_t, float, float, float, float, true); EMBEDDING_SPECIALIZATION(int32_t, int32_t, half, at::Half, float, float, true); EMBEDDING_SPECIALIZATION(int64_t, int64_t, half, at::Half, float, float, true); -EMBEDDING_SPECIALIZATION(int32_t, int32_t, uint8_t, uint8_t, float, float, true); -EMBEDDING_SPECIALIZATION(int64_t, int64_t, uint8_t, uint8_t, float, float, true); +EMBEDDING_SPECIALIZATION( + int32_t, + int32_t, + uint8_t, + uint8_t, + float, + float, + true); +EMBEDDING_SPECIALIZATION( + int64_t, + int64_t, + uint8_t, + uint8_t, + float, + float, + true); #undef EMBEDDING_SPECIALIZATION |