diff options
author | Roy Li <royboy@fb.com> | 2019-02-01 10:55:00 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-01 11:00:16 -0800 |
commit | 7e642dfff3a4bbf80826f331312acff2672a3e02 (patch) | |
tree | 2b8499a1824bf467cfb152ec1089be517bcae520 /aten | |
parent | 64186e06eccad8721cfa741cae31a39842a3dfe3 (diff) | |
download | pytorch-7e642dfff3a4bbf80826f331312acff2672a3e02.tar.gz pytorch-7e642dfff3a4bbf80826f331312acff2672a3e02.tar.bz2 pytorch-7e642dfff3a4bbf80826f331312acff2672a3e02.zip |
Introduce backend extensions (overriding operators on custom backends)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15153
Reviewed By: gchanan
Differential Revision: D13445571
fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/core/Type.h | 1 | ||||
-rw-r--r-- | aten/src/ATen/function_wrapper.py | 35 | ||||
-rw-r--r-- | aten/src/ATen/gen.py | 53 | ||||
-rw-r--r-- | aten/src/ATen/templates/ExtensionBackendRegistration.h | 19 | ||||
-rw-r--r-- | aten/src/ATen/templates/TypeExtension.cpp | 51 | ||||
-rw-r--r-- | aten/src/ATen/templates/TypeExtension.h | 49 | ||||
-rw-r--r-- | aten/src/ATen/test/CMakeLists.txt | 3 | ||||
-rw-r--r-- | aten/src/ATen/test/extension_backend_test.cpp | 66 | ||||
-rwxr-xr-x | aten/tools/run_tests.sh | 1 |
9 files changed, 276 insertions, 2 deletions
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 30dd0129c7..f7b20e7e53 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -74,6 +74,7 @@ enum class TypeID { SparseCUDAInt, SparseCUDALong, SparseCUDAShort, + MSNPU, CPUComplexFloat, CPUComplexDouble, CUDAComplexFloat, diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 4978f8e5f4..2ee5aeb76f 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -116,6 +116,13 @@ TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\ ${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals}); """) +# Overrideable stubs to be used in user-extendable backends +TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\ +${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const { + return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals}); +} +""") + # add non-virtual declaration to Tensor.h TENSOR_METHOD_DECLARATION = CodeTemplate("""\ ${return_type} ${api_name}(${method_formals_with_defaults})${const_mark}; @@ -489,6 +496,7 @@ FunctionOption = TypedDict('FunctionOption', { 'formals_list': List[AtFormal], 'formals_with_defaults': List[str], 'formals': List[str], + 'formals_types': List[str], 'inferred_type': str, 'inplace': bool, 'matches_jit_signature': bool, @@ -513,6 +521,8 @@ FunctionOption = TypedDict('FunctionOption', { 'return': ReturnDecl, 'returns': List[ReturnType], 'scalar_check': str, + # schema used for extension backend operator registration + 'schema': str, 'sparse': bool, 'type_definition_body': List[str], 'type_method_actuals': List[str], @@ -1595,3 +1605,28 @@ def create_derived(backend_type_env, declarations): except NYIError: pass return type_object_declarations, type_object_definitions + + +def create_extension_backend(backend_type_env, declarations): + # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]] + type_object_declarations = [] + type_object_definitions = [] + + for declaration in declarations: + for option in declaration['options']: + if not option.get('skip', False): + try: + option['formals_types'] = [f['type'] for f in option['formals_list']] + option['native_actuals'] = [f['name'] for f in option['formals_list']] + schema_args = ", ".join( + ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']]) + return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type']) + option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type) + env = nested_dict(option, backend_type_env) + type_object_declarations.append( + TYPE_DERIVED_DECLARATION.substitute(env)) + type_object_definitions.append( + TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env)) + except NYIError: + pass + return type_object_declarations, type_object_definitions diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index 32255e39c1..c8207710bb 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -121,6 +121,8 @@ TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h") TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h") TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h") TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp") +TYPE_EXTENSION_BACKEND_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.h") +TYPE_EXTENSION_BACKEND_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.cpp") LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h") LEGACY_TH_DISPATCHER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.cpp") @@ -141,10 +143,18 @@ LEGACY_TH_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctio NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h") +EXTENSION_BACKEND_REGISTRATION_H = CodeTemplate.from_file(TEMPLATE_PATH + "/ExtensionBackendRegistration.h") + TYPE_REGISTER = CodeTemplate("""\ context->registerType(Backend::${backend}, ScalarType::${scalar_type}, new ${type_name}()); """) +EXTENSION_BACKEND_REGISTER_SWITCH = CodeTemplate("""\ +case Backend::${Backend}: + ${Type}Dispatch::register_function(schema, fn); + break; +""") + core_file_manager = FileManager(core_install_dir) file_manager = FileManager() cuda_file_manager = FileManager() @@ -164,6 +174,7 @@ generators = { backends = ['CPU', 'CUDA'] densities = ['Dense', 'Sparse'] +extension_backends = ['MSNPU'] # scalar_name, c_type, accreal, th_scalar_type, is_floating_type scalar_types = [ @@ -193,6 +204,8 @@ top_env = { 'function_definitions': [], 'type_ids': [], 'native_function_declarations': [], + 'extension_backend_headers': [], + 'extension_backend_register_switches': [], } @@ -347,6 +360,37 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations return env +def generate_type_extension_backend(backend, declarations): + env = {} + env['Type'] = "{}Type".format(backend) + env['Backend'] = backend + env['DeviceType'] = backend + env['is_extension_backend'] = True + env['TypeID'] = 'TypeID::' + backend + top_env['type_ids'].append(backend + ',') + + declarations, definitions = function_wrapper.create_extension_backend( + env, declarations) + env['type_method_declarations'] = declarations + env['type_method_definitions'] = definitions + + fm = file_manager + fm.write(env['Type'] + ".cpp", TYPE_EXTENSION_BACKEND_CPP, env) + fm.write(env['Type'] + ".h", TYPE_EXTENSION_BACKEND_H, env) + + for scalar_name, _, _, _, _ in scalar_types: + type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type']) + top_env['cpu_type_registrations'].append(type_register) + extension_backend_register_switch = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env) + top_env['extension_backend_register_switches'].append(extension_backend_register_switch) + top_env['extension_backend_headers'].append( + '#include <ATen/{}.h>'.format(env['Type'])) + top_env['cpu_type_headers'].append( + '#include "ATen/{}.h"'.format(env['Type'])) + + return env + + def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations): assert density != 'Sparse' scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type @@ -384,7 +428,7 @@ def declare_outputs(): core_file_manager.will_write(f) files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h', 'LegacyTHDispatcher.h', 'LegacyTHDispatcher.cpp', 'LegacyTHFunctions.h', - 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h'] + 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h', 'ExtensionBackendRegistration.h'] for f in files: file_manager.will_write(f) cuda_files = ['RegisterCUDA.cpp', 'RegisterCUDA.h'] @@ -411,6 +455,9 @@ def declare_outputs(): if density != 'Sparse': fm.will_write("{}{}{}{}.h".format('LegacyTH', full_backend, scalar_name, 'Dispatcher')) fm.will_write("{}{}{}{}.cpp".format('LegacyTH', full_backend, scalar_name, 'Dispatcher')) + for backend in extension_backends: + file_manager.will_write("{}Type.h".format(backend)) + file_manager.will_write("{}Type.cpp".format(backend)) def filter_by_extension(files, *extensions): @@ -472,6 +519,8 @@ def generate_outputs(): for backend, density, scalar_type in iterate_types(): all_types.append(generate_storage_type_and_tensor( backend, density, scalar_type, declarations)) + for backend in extension_backends: + all_types.append(generate_type_extension_backend(backend, declarations)) all_legacy_th_dispatchers = [] for backend, density, scalar_type in iterate_types(): @@ -506,6 +555,8 @@ def generate_outputs(): file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env) + file_manager.write('ExtensionBackendRegistration.h', EXTENSION_BACKEND_REGISTRATION_H, top_env) + file_manager.check_all_files_written() cuda_file_manager.check_all_files_written() diff --git a/aten/src/ATen/templates/ExtensionBackendRegistration.h b/aten/src/ATen/templates/ExtensionBackendRegistration.h new file mode 100644 index 0000000000..2f9f731922 --- /dev/null +++ b/aten/src/ATen/templates/ExtensionBackendRegistration.h @@ -0,0 +1,19 @@ +#pragma once +#include <ATen/Backend.h> +${extension_backend_headers} + +namespace at { + +template <typename FnPtr> +inline void register_extension_backend_op( + Backend backend, + const char * schema, + FnPtr fn) { + switch (backend) { + ${extension_backend_register_switches} + default: + AT_ERROR("Invalid extension backend: ", toString(backend)); + } +} + +} // namespace at diff --git a/aten/src/ATen/templates/TypeExtension.cpp b/aten/src/ATen/templates/TypeExtension.cpp new file mode 100644 index 0000000000..313b865fc0 --- /dev/null +++ b/aten/src/ATen/templates/TypeExtension.cpp @@ -0,0 +1,51 @@ +#include <ATen/${Type}.h> + +namespace at { + +std::unordered_map<std::string, void *>& ${Type}Dispatch::get_fn_table() { + static std::unordered_map<std::string, void *> fn_table; + return fn_table; +} + +${Type}::${Type}() + : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} + +Allocator* ${Type}::allocator() const { + AT_ERROR("allocator is not implemented for ${Type}"); +} + +Device ${Type}::getDeviceFromPtr(void * data) const { + return DeviceType::${DeviceType}; +} + +std::unique_ptr<Generator> ${Type}::generator() const { + AT_ERROR("generator is not implemented for ${Type}"); +} + +ScalarType ${Type}::scalarType() const { + AT_ERROR("scalarType is not implemented for ${Type}"); +} + +caffe2::TypeMeta ${Type}::typeMeta() const { + AT_ERROR("typeMeta is not implemented for ${Type}"); +} + +Backend ${Type}::backend() const { + return Backend::${Backend}; +} + +const char * ${Type}::toString() const { + return "${Type}"; +} + +TypeID ${Type}::ID() const { + return ${TypeID}; +} + +size_t ${Type}::elementSizeInBytes() const { + AT_ERROR("elementSizeInBytes is not implemented for ${Type}"); +} + +${type_method_definitions} + +} // namespace at diff --git a/aten/src/ATen/templates/TypeExtension.h b/aten/src/ATen/templates/TypeExtension.h new file mode 100644 index 0000000000..fe45960225 --- /dev/null +++ b/aten/src/ATen/templates/TypeExtension.h @@ -0,0 +1,49 @@ +#pragma once +#include <ATen/TypeDefault.h> + +namespace at { + +// This dispatch class holds static map in which function pointers are +// registered by schema. +// TODO: Check for invalid schemas prior to registration. +struct CAFFE2_API ${Type}Dispatch { + template<typename FnPtr> + static FnPtr get_function(const std::string& schema) { + auto & fn_table = get_fn_table(); + auto it = fn_table.find(schema); + if (it != fn_table.end()) { + return reinterpret_cast<FnPtr>(it->second); + } + AT_ERROR("No function registered for schema: ", schema); + } + + template<typename FnPtr> + static void register_function(const std::string& schema, FnPtr fn) { + auto & fn_table = get_fn_table(); + if (fn_table.find(schema) != fn_table.end()) { + AT_ERROR("Function already registered for schema: ", schema); + } + fn_table[schema] = reinterpret_cast<void *>(fn); + } + + static std::unordered_map<std::string, void *>& get_fn_table(); +}; + +struct CAFFE2_API ${Type} : public TypeDefault { + explicit ${Type}(); + + Allocator* allocator() const override; + Device getDeviceFromPtr(void * data) const override; + std::unique_ptr<Generator> generator() const override; + + virtual ScalarType scalarType() const override; + virtual caffe2::TypeMeta typeMeta() const override; + virtual Backend backend() const override; + virtual const char * toString() const override; + virtual size_t elementSizeInBytes() const override; + virtual TypeID ID() const override; + + ${type_method_declarations} +}; + +} // namespace at diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 1d2457154b..72a98f7ace 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -20,7 +20,8 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp) list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp new file mode 100644 index 0000000000..be8d262f6a --- /dev/null +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -0,0 +1,66 @@ +#include <gtest/gtest.h> + +#include <ATen/ATen.h> +#include <ATen/NativeFunctions.h> +#include <ATen/ExtensionBackendRegistration.h> + +using namespace at; + +static int test_int; + +Tensor empty_override(IntList size, const TensorOptions & options) { + test_int = 1; + auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>( + Storage( + caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false), + MSNPUTensorId(), + false); + return Tensor(std::move(tensor_impl)); +} + +Tensor empty_like_override(const Tensor & self, const TensorOptions & options) { + test_int = 2; + return self; +} + +Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { + test_int = 3; + return a; +} + +TEST(BackendExtensionTest, TestRegisterOp) { + EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU)); + register_extension_backend_op( + Backend::MSNPU, + "empty(IntList size, TensorOptions options) -> Tensor", &empty_override); + Tensor a = empty({5, 5}, at::kMSNPU); + ASSERT_EQ(a.device().type(), at::kMSNPU); + ASSERT_EQ(a.device().index(), 1); + ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>()); + ASSERT_EQ(test_int, 1); + + EXPECT_ANY_THROW(empty_like(a, at::kMSNPU)); + register_extension_backend_op( + Backend::MSNPU, + "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override); + Tensor b = empty_like(a, at::kMSNPU); + ASSERT_EQ(test_int, 2); + + EXPECT_ANY_THROW(add(a, b)); + register_extension_backend_op( + Backend::MSNPU, + "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override); + add(a, b); + ASSERT_EQ(test_int, 3); + + // Ensure that non-MSNPU operator still works + Tensor d = empty({5, 5}, at::kCPU); + ASSERT_EQ(d.device().type(), at::kCPU); + + // Attempt to register on a schema that has already has a function + EXPECT_ANY_THROW( + register_extension_backend_op( + Backend::MSNPU, + "empty(IntList size, TensorOptions options) -> Tensor", &empty_override) + ); +} diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh index 8ed5a21e2d..e2df276220 100755 --- a/aten/tools/run_tests.sh +++ b/aten/tools/run_tests.sh @@ -17,6 +17,7 @@ VALGRIND=${VALGRIND:=ON} ./scalar_tensor_test ./tensor_interop_test ./undefined_tensor_test +./extension_backend_test if [[ -x ./cudnn_test ]]; then ./cudnn_test fi |