summaryrefslogtreecommitdiff
path: root/aten
diff options
context:
space:
mode:
authorRoy Li <royboy@fb.com>2019-02-01 10:55:00 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-01 11:00:16 -0800
commit7e642dfff3a4bbf80826f331312acff2672a3e02 (patch)
tree2b8499a1824bf467cfb152ec1089be517bcae520 /aten
parent64186e06eccad8721cfa741cae31a39842a3dfe3 (diff)
downloadpytorch-7e642dfff3a4bbf80826f331312acff2672a3e02.tar.gz
pytorch-7e642dfff3a4bbf80826f331312acff2672a3e02.tar.bz2
pytorch-7e642dfff3a4bbf80826f331312acff2672a3e02.zip
Introduce backend extensions (overriding operators on custom backends)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15153 Reviewed By: gchanan Differential Revision: D13445571 fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708
Diffstat (limited to 'aten')
-rw-r--r--aten/src/ATen/core/Type.h1
-rw-r--r--aten/src/ATen/function_wrapper.py35
-rw-r--r--aten/src/ATen/gen.py53
-rw-r--r--aten/src/ATen/templates/ExtensionBackendRegistration.h19
-rw-r--r--aten/src/ATen/templates/TypeExtension.cpp51
-rw-r--r--aten/src/ATen/templates/TypeExtension.h49
-rw-r--r--aten/src/ATen/test/CMakeLists.txt3
-rw-r--r--aten/src/ATen/test/extension_backend_test.cpp66
-rwxr-xr-xaten/tools/run_tests.sh1
9 files changed, 276 insertions, 2 deletions
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 30dd0129c7..f7b20e7e53 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -74,6 +74,7 @@ enum class TypeID {
SparseCUDAInt,
SparseCUDALong,
SparseCUDAShort,
+ MSNPU,
CPUComplexFloat,
CPUComplexDouble,
CUDAComplexFloat,
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 4978f8e5f4..2ee5aeb76f 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -116,6 +116,13 @@ TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\
${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals});
""")
+# Overrideable stubs to be used in user-extendable backends
+TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\
+${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
+ return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals});
+}
+""")
+
# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
@@ -489,6 +496,7 @@ FunctionOption = TypedDict('FunctionOption', {
'formals_list': List[AtFormal],
'formals_with_defaults': List[str],
'formals': List[str],
+ 'formals_types': List[str],
'inferred_type': str,
'inplace': bool,
'matches_jit_signature': bool,
@@ -513,6 +521,8 @@ FunctionOption = TypedDict('FunctionOption', {
'return': ReturnDecl,
'returns': List[ReturnType],
'scalar_check': str,
+ # schema used for extension backend operator registration
+ 'schema': str,
'sparse': bool,
'type_definition_body': List[str],
'type_method_actuals': List[str],
@@ -1595,3 +1605,28 @@ def create_derived(backend_type_env, declarations):
except NYIError:
pass
return type_object_declarations, type_object_definitions
+
+
+def create_extension_backend(backend_type_env, declarations):
+ # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
+ type_object_declarations = []
+ type_object_definitions = []
+
+ for declaration in declarations:
+ for option in declaration['options']:
+ if not option.get('skip', False):
+ try:
+ option['formals_types'] = [f['type'] for f in option['formals_list']]
+ option['native_actuals'] = [f['name'] for f in option['formals_list']]
+ schema_args = ", ".join(
+ ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']])
+ return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
+ option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type)
+ env = nested_dict(option, backend_type_env)
+ type_object_declarations.append(
+ TYPE_DERIVED_DECLARATION.substitute(env))
+ type_object_definitions.append(
+ TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
+ except NYIError:
+ pass
+ return type_object_declarations, type_object_definitions
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 32255e39c1..c8207710bb 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -121,6 +121,8 @@ TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h")
TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h")
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
+TYPE_EXTENSION_BACKEND_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.h")
+TYPE_EXTENSION_BACKEND_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.cpp")
LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h")
LEGACY_TH_DISPATCHER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.cpp")
@@ -141,10 +143,18 @@ LEGACY_TH_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctio
NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h")
+EXTENSION_BACKEND_REGISTRATION_H = CodeTemplate.from_file(TEMPLATE_PATH + "/ExtensionBackendRegistration.h")
+
TYPE_REGISTER = CodeTemplate("""\
context->registerType(Backend::${backend}, ScalarType::${scalar_type}, new ${type_name}());
""")
+EXTENSION_BACKEND_REGISTER_SWITCH = CodeTemplate("""\
+case Backend::${Backend}:
+ ${Type}Dispatch::register_function(schema, fn);
+ break;
+""")
+
core_file_manager = FileManager(core_install_dir)
file_manager = FileManager()
cuda_file_manager = FileManager()
@@ -164,6 +174,7 @@ generators = {
backends = ['CPU', 'CUDA']
densities = ['Dense', 'Sparse']
+extension_backends = ['MSNPU']
# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
@@ -193,6 +204,8 @@ top_env = {
'function_definitions': [],
'type_ids': [],
'native_function_declarations': [],
+ 'extension_backend_headers': [],
+ 'extension_backend_register_switches': [],
}
@@ -347,6 +360,37 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
return env
+def generate_type_extension_backend(backend, declarations):
+ env = {}
+ env['Type'] = "{}Type".format(backend)
+ env['Backend'] = backend
+ env['DeviceType'] = backend
+ env['is_extension_backend'] = True
+ env['TypeID'] = 'TypeID::' + backend
+ top_env['type_ids'].append(backend + ',')
+
+ declarations, definitions = function_wrapper.create_extension_backend(
+ env, declarations)
+ env['type_method_declarations'] = declarations
+ env['type_method_definitions'] = definitions
+
+ fm = file_manager
+ fm.write(env['Type'] + ".cpp", TYPE_EXTENSION_BACKEND_CPP, env)
+ fm.write(env['Type'] + ".h", TYPE_EXTENSION_BACKEND_H, env)
+
+ for scalar_name, _, _, _, _ in scalar_types:
+ type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type'])
+ top_env['cpu_type_registrations'].append(type_register)
+ extension_backend_register_switch = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env)
+ top_env['extension_backend_register_switches'].append(extension_backend_register_switch)
+ top_env['extension_backend_headers'].append(
+ '#include <ATen/{}.h>'.format(env['Type']))
+ top_env['cpu_type_headers'].append(
+ '#include "ATen/{}.h"'.format(env['Type']))
+
+ return env
+
+
def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations):
assert density != 'Sparse'
scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
@@ -384,7 +428,7 @@ def declare_outputs():
core_file_manager.will_write(f)
files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h',
'LegacyTHDispatcher.h', 'LegacyTHDispatcher.cpp', 'LegacyTHFunctions.h',
- 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h']
+ 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h', 'ExtensionBackendRegistration.h']
for f in files:
file_manager.will_write(f)
cuda_files = ['RegisterCUDA.cpp', 'RegisterCUDA.h']
@@ -411,6 +455,9 @@ def declare_outputs():
if density != 'Sparse':
fm.will_write("{}{}{}{}.h".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
fm.will_write("{}{}{}{}.cpp".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
+ for backend in extension_backends:
+ file_manager.will_write("{}Type.h".format(backend))
+ file_manager.will_write("{}Type.cpp".format(backend))
def filter_by_extension(files, *extensions):
@@ -472,6 +519,8 @@ def generate_outputs():
for backend, density, scalar_type in iterate_types():
all_types.append(generate_storage_type_and_tensor(
backend, density, scalar_type, declarations))
+ for backend in extension_backends:
+ all_types.append(generate_type_extension_backend(backend, declarations))
all_legacy_th_dispatchers = []
for backend, density, scalar_type in iterate_types():
@@ -506,6 +555,8 @@ def generate_outputs():
file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env)
+ file_manager.write('ExtensionBackendRegistration.h', EXTENSION_BACKEND_REGISTRATION_H, top_env)
+
file_manager.check_all_files_written()
cuda_file_manager.check_all_files_written()
diff --git a/aten/src/ATen/templates/ExtensionBackendRegistration.h b/aten/src/ATen/templates/ExtensionBackendRegistration.h
new file mode 100644
index 0000000000..2f9f731922
--- /dev/null
+++ b/aten/src/ATen/templates/ExtensionBackendRegistration.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <ATen/Backend.h>
+${extension_backend_headers}
+
+namespace at {
+
+template <typename FnPtr>
+inline void register_extension_backend_op(
+ Backend backend,
+ const char * schema,
+ FnPtr fn) {
+ switch (backend) {
+ ${extension_backend_register_switches}
+ default:
+ AT_ERROR("Invalid extension backend: ", toString(backend));
+ }
+}
+
+} // namespace at
diff --git a/aten/src/ATen/templates/TypeExtension.cpp b/aten/src/ATen/templates/TypeExtension.cpp
new file mode 100644
index 0000000000..313b865fc0
--- /dev/null
+++ b/aten/src/ATen/templates/TypeExtension.cpp
@@ -0,0 +1,51 @@
+#include <ATen/${Type}.h>
+
+namespace at {
+
+std::unordered_map<std::string, void *>& ${Type}Dispatch::get_fn_table() {
+ static std::unordered_map<std::string, void *> fn_table;
+ return fn_table;
+}
+
+${Type}::${Type}()
+ : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
+
+Allocator* ${Type}::allocator() const {
+ AT_ERROR("allocator is not implemented for ${Type}");
+}
+
+Device ${Type}::getDeviceFromPtr(void * data) const {
+ return DeviceType::${DeviceType};
+}
+
+std::unique_ptr<Generator> ${Type}::generator() const {
+ AT_ERROR("generator is not implemented for ${Type}");
+}
+
+ScalarType ${Type}::scalarType() const {
+ AT_ERROR("scalarType is not implemented for ${Type}");
+}
+
+caffe2::TypeMeta ${Type}::typeMeta() const {
+ AT_ERROR("typeMeta is not implemented for ${Type}");
+}
+
+Backend ${Type}::backend() const {
+ return Backend::${Backend};
+}
+
+const char * ${Type}::toString() const {
+ return "${Type}";
+}
+
+TypeID ${Type}::ID() const {
+ return ${TypeID};
+}
+
+size_t ${Type}::elementSizeInBytes() const {
+ AT_ERROR("elementSizeInBytes is not implemented for ${Type}");
+}
+
+${type_method_definitions}
+
+} // namespace at
diff --git a/aten/src/ATen/templates/TypeExtension.h b/aten/src/ATen/templates/TypeExtension.h
new file mode 100644
index 0000000000..fe45960225
--- /dev/null
+++ b/aten/src/ATen/templates/TypeExtension.h
@@ -0,0 +1,49 @@
+#pragma once
+#include <ATen/TypeDefault.h>
+
+namespace at {
+
+// This dispatch class holds static map in which function pointers are
+// registered by schema.
+// TODO: Check for invalid schemas prior to registration.
+struct CAFFE2_API ${Type}Dispatch {
+ template<typename FnPtr>
+ static FnPtr get_function(const std::string& schema) {
+ auto & fn_table = get_fn_table();
+ auto it = fn_table.find(schema);
+ if (it != fn_table.end()) {
+ return reinterpret_cast<FnPtr>(it->second);
+ }
+ AT_ERROR("No function registered for schema: ", schema);
+ }
+
+ template<typename FnPtr>
+ static void register_function(const std::string& schema, FnPtr fn) {
+ auto & fn_table = get_fn_table();
+ if (fn_table.find(schema) != fn_table.end()) {
+ AT_ERROR("Function already registered for schema: ", schema);
+ }
+ fn_table[schema] = reinterpret_cast<void *>(fn);
+ }
+
+ static std::unordered_map<std::string, void *>& get_fn_table();
+};
+
+struct CAFFE2_API ${Type} : public TypeDefault {
+ explicit ${Type}();
+
+ Allocator* allocator() const override;
+ Device getDeviceFromPtr(void * data) const override;
+ std::unique_ptr<Generator> generator() const override;
+
+ virtual ScalarType scalarType() const override;
+ virtual caffe2::TypeMeta typeMeta() const override;
+ virtual Backend backend() const override;
+ virtual const char * toString() const override;
+ virtual size_t elementSizeInBytes() const override;
+ virtual TypeID ID() const override;
+
+ ${type_method_declarations}
+};
+
+} // namespace at
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 1d2457154b..72a98f7ace 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -20,7 +20,8 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp)
+ ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)
list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp
new file mode 100644
index 0000000000..be8d262f6a
--- /dev/null
+++ b/aten/src/ATen/test/extension_backend_test.cpp
@@ -0,0 +1,66 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/ExtensionBackendRegistration.h>
+
+using namespace at;
+
+static int test_int;
+
+Tensor empty_override(IntList size, const TensorOptions & options) {
+ test_int = 1;
+ auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
+ Storage(
+ caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
+ MSNPUTensorId(),
+ false);
+ return Tensor(std::move(tensor_impl));
+}
+
+Tensor empty_like_override(const Tensor & self, const TensorOptions & options) {
+ test_int = 2;
+ return self;
+}
+
+Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
+ test_int = 3;
+ return a;
+}
+
+TEST(BackendExtensionTest, TestRegisterOp) {
+ EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "empty(IntList size, TensorOptions options) -> Tensor", &empty_override);
+ Tensor a = empty({5, 5}, at::kMSNPU);
+ ASSERT_EQ(a.device().type(), at::kMSNPU);
+ ASSERT_EQ(a.device().index(), 1);
+ ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
+ ASSERT_EQ(test_int, 1);
+
+ EXPECT_ANY_THROW(empty_like(a, at::kMSNPU));
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override);
+ Tensor b = empty_like(a, at::kMSNPU);
+ ASSERT_EQ(test_int, 2);
+
+ EXPECT_ANY_THROW(add(a, b));
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
+ add(a, b);
+ ASSERT_EQ(test_int, 3);
+
+ // Ensure that non-MSNPU operator still works
+ Tensor d = empty({5, 5}, at::kCPU);
+ ASSERT_EQ(d.device().type(), at::kCPU);
+
+ // Attempt to register on a schema that has already has a function
+ EXPECT_ANY_THROW(
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "empty(IntList size, TensorOptions options) -> Tensor", &empty_override)
+ );
+}
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index 8ed5a21e2d..e2df276220 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -17,6 +17,7 @@ VALGRIND=${VALGRIND:=ON}
./scalar_tensor_test
./tensor_interop_test
./undefined_tensor_test
+./extension_backend_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi