From 7e642dfff3a4bbf80826f331312acff2672a3e02 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Fri, 1 Feb 2019 10:55:00 -0800 Subject: Introduce backend extensions (overriding operators on custom backends) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15153 Reviewed By: gchanan Differential Revision: D13445571 fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708 --- aten/src/ATen/core/Type.h | 1 + aten/src/ATen/function_wrapper.py | 35 ++++++++++++ aten/src/ATen/gen.py | 53 ++++++++++++++++- .../ATen/templates/ExtensionBackendRegistration.h | 19 +++++++ aten/src/ATen/templates/TypeExtension.cpp | 51 +++++++++++++++++ aten/src/ATen/templates/TypeExtension.h | 49 ++++++++++++++++ aten/src/ATen/test/CMakeLists.txt | 3 +- aten/src/ATen/test/extension_backend_test.cpp | 66 ++++++++++++++++++++++ aten/tools/run_tests.sh | 1 + c10/core/Backend.h | 19 ++++++- c10/core/DeviceType.cpp | 3 + c10/core/DeviceType.h | 3 +- c10/core/TensorImpl.h | 4 +- c10/core/TensorOptions.h | 18 ++++-- c10/core/TensorTypeIdRegistration.cpp | 1 + c10/core/TensorTypeIdRegistration.h | 1 + caffe2/proto/caffe2.proto | 3 +- 17 files changed, 317 insertions(+), 13 deletions(-) create mode 100644 aten/src/ATen/templates/ExtensionBackendRegistration.h create mode 100644 aten/src/ATen/templates/TypeExtension.cpp create mode 100644 aten/src/ATen/templates/TypeExtension.h create mode 100644 aten/src/ATen/test/extension_backend_test.cpp diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h index 30dd0129c7..f7b20e7e53 100644 --- a/aten/src/ATen/core/Type.h +++ b/aten/src/ATen/core/Type.h @@ -74,6 +74,7 @@ enum class TypeID { SparseCUDAInt, SparseCUDALong, SparseCUDAShort, + MSNPU, CPUComplexFloat, CPUComplexDouble, CUDAComplexFloat, diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 4978f8e5f4..2ee5aeb76f 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -116,6 +116,13 @@ TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\ ${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals}); """) +# Overrideable stubs to be used in user-extendable backends +TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\ +${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const { + return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals}); +} +""") + # add non-virtual declaration to Tensor.h TENSOR_METHOD_DECLARATION = CodeTemplate("""\ ${return_type} ${api_name}(${method_formals_with_defaults})${const_mark}; @@ -489,6 +496,7 @@ FunctionOption = TypedDict('FunctionOption', { 'formals_list': List[AtFormal], 'formals_with_defaults': List[str], 'formals': List[str], + 'formals_types': List[str], 'inferred_type': str, 'inplace': bool, 'matches_jit_signature': bool, @@ -513,6 +521,8 @@ FunctionOption = TypedDict('FunctionOption', { 'return': ReturnDecl, 'returns': List[ReturnType], 'scalar_check': str, + # schema used for extension backend operator registration + 'schema': str, 'sparse': bool, 'type_definition_body': List[str], 'type_method_actuals': List[str], @@ -1595,3 +1605,28 @@ def create_derived(backend_type_env, declarations): except NYIError: pass return type_object_declarations, type_object_definitions + + +def create_extension_backend(backend_type_env, declarations): + # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]] + type_object_declarations = [] + type_object_definitions = [] + + for declaration in declarations: + for option in declaration['options']: + if not option.get('skip', False): + try: + option['formals_types'] = [f['type'] for f in option['formals_list']] + option['native_actuals'] = [f['name'] for f in option['formals_list']] + schema_args = ", ".join( + ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']]) + return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type']) + option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type) + env = nested_dict(option, backend_type_env) + type_object_declarations.append( + TYPE_DERIVED_DECLARATION.substitute(env)) + type_object_definitions.append( + TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env)) + except NYIError: + pass + return type_object_declarations, type_object_definitions diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index 32255e39c1..c8207710bb 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -121,6 +121,8 @@ TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h") TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h") TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h") TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp") +TYPE_EXTENSION_BACKEND_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.h") +TYPE_EXTENSION_BACKEND_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.cpp") LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h") LEGACY_TH_DISPATCHER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.cpp") @@ -141,10 +143,18 @@ LEGACY_TH_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctio NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h") +EXTENSION_BACKEND_REGISTRATION_H = CodeTemplate.from_file(TEMPLATE_PATH + "/ExtensionBackendRegistration.h") + TYPE_REGISTER = CodeTemplate("""\ context->registerType(Backend::${backend}, ScalarType::${scalar_type}, new ${type_name}()); """) +EXTENSION_BACKEND_REGISTER_SWITCH = CodeTemplate("""\ +case Backend::${Backend}: + ${Type}Dispatch::register_function(schema, fn); + break; +""") + core_file_manager = FileManager(core_install_dir) file_manager = FileManager() cuda_file_manager = FileManager() @@ -164,6 +174,7 @@ generators = { backends = ['CPU', 'CUDA'] densities = ['Dense', 'Sparse'] +extension_backends = ['MSNPU'] # scalar_name, c_type, accreal, th_scalar_type, is_floating_type scalar_types = [ @@ -193,6 +204,8 @@ top_env = { 'function_definitions': [], 'type_ids': [], 'native_function_declarations': [], + 'extension_backend_headers': [], + 'extension_backend_register_switches': [], } @@ -347,6 +360,37 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations return env +def generate_type_extension_backend(backend, declarations): + env = {} + env['Type'] = "{}Type".format(backend) + env['Backend'] = backend + env['DeviceType'] = backend + env['is_extension_backend'] = True + env['TypeID'] = 'TypeID::' + backend + top_env['type_ids'].append(backend + ',') + + declarations, definitions = function_wrapper.create_extension_backend( + env, declarations) + env['type_method_declarations'] = declarations + env['type_method_definitions'] = definitions + + fm = file_manager + fm.write(env['Type'] + ".cpp", TYPE_EXTENSION_BACKEND_CPP, env) + fm.write(env['Type'] + ".h", TYPE_EXTENSION_BACKEND_H, env) + + for scalar_name, _, _, _, _ in scalar_types: + type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type']) + top_env['cpu_type_registrations'].append(type_register) + extension_backend_register_switch = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env) + top_env['extension_backend_register_switches'].append(extension_backend_register_switch) + top_env['extension_backend_headers'].append( + '#include '.format(env['Type'])) + top_env['cpu_type_headers'].append( + '#include "ATen/{}.h"'.format(env['Type'])) + + return env + + def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations): assert density != 'Sparse' scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type @@ -384,7 +428,7 @@ def declare_outputs(): core_file_manager.will_write(f) files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h', 'LegacyTHDispatcher.h', 'LegacyTHDispatcher.cpp', 'LegacyTHFunctions.h', - 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h'] + 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h', 'ExtensionBackendRegistration.h'] for f in files: file_manager.will_write(f) cuda_files = ['RegisterCUDA.cpp', 'RegisterCUDA.h'] @@ -411,6 +455,9 @@ def declare_outputs(): if density != 'Sparse': fm.will_write("{}{}{}{}.h".format('LegacyTH', full_backend, scalar_name, 'Dispatcher')) fm.will_write("{}{}{}{}.cpp".format('LegacyTH', full_backend, scalar_name, 'Dispatcher')) + for backend in extension_backends: + file_manager.will_write("{}Type.h".format(backend)) + file_manager.will_write("{}Type.cpp".format(backend)) def filter_by_extension(files, *extensions): @@ -472,6 +519,8 @@ def generate_outputs(): for backend, density, scalar_type in iterate_types(): all_types.append(generate_storage_type_and_tensor( backend, density, scalar_type, declarations)) + for backend in extension_backends: + all_types.append(generate_type_extension_backend(backend, declarations)) all_legacy_th_dispatchers = [] for backend, density, scalar_type in iterate_types(): @@ -506,6 +555,8 @@ def generate_outputs(): file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env) + file_manager.write('ExtensionBackendRegistration.h', EXTENSION_BACKEND_REGISTRATION_H, top_env) + file_manager.check_all_files_written() cuda_file_manager.check_all_files_written() diff --git a/aten/src/ATen/templates/ExtensionBackendRegistration.h b/aten/src/ATen/templates/ExtensionBackendRegistration.h new file mode 100644 index 0000000000..2f9f731922 --- /dev/null +++ b/aten/src/ATen/templates/ExtensionBackendRegistration.h @@ -0,0 +1,19 @@ +#pragma once +#include +${extension_backend_headers} + +namespace at { + +template +inline void register_extension_backend_op( + Backend backend, + const char * schema, + FnPtr fn) { + switch (backend) { + ${extension_backend_register_switches} + default: + AT_ERROR("Invalid extension backend: ", toString(backend)); + } +} + +} // namespace at diff --git a/aten/src/ATen/templates/TypeExtension.cpp b/aten/src/ATen/templates/TypeExtension.cpp new file mode 100644 index 0000000000..313b865fc0 --- /dev/null +++ b/aten/src/ATen/templates/TypeExtension.cpp @@ -0,0 +1,51 @@ +#include + +namespace at { + +std::unordered_map& ${Type}Dispatch::get_fn_table() { + static std::unordered_map fn_table; + return fn_table; +} + +${Type}::${Type}() + : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {} + +Allocator* ${Type}::allocator() const { + AT_ERROR("allocator is not implemented for ${Type}"); +} + +Device ${Type}::getDeviceFromPtr(void * data) const { + return DeviceType::${DeviceType}; +} + +std::unique_ptr ${Type}::generator() const { + AT_ERROR("generator is not implemented for ${Type}"); +} + +ScalarType ${Type}::scalarType() const { + AT_ERROR("scalarType is not implemented for ${Type}"); +} + +caffe2::TypeMeta ${Type}::typeMeta() const { + AT_ERROR("typeMeta is not implemented for ${Type}"); +} + +Backend ${Type}::backend() const { + return Backend::${Backend}; +} + +const char * ${Type}::toString() const { + return "${Type}"; +} + +TypeID ${Type}::ID() const { + return ${TypeID}; +} + +size_t ${Type}::elementSizeInBytes() const { + AT_ERROR("elementSizeInBytes is not implemented for ${Type}"); +} + +${type_method_definitions} + +} // namespace at diff --git a/aten/src/ATen/templates/TypeExtension.h b/aten/src/ATen/templates/TypeExtension.h new file mode 100644 index 0000000000..fe45960225 --- /dev/null +++ b/aten/src/ATen/templates/TypeExtension.h @@ -0,0 +1,49 @@ +#pragma once +#include + +namespace at { + +// This dispatch class holds static map in which function pointers are +// registered by schema. +// TODO: Check for invalid schemas prior to registration. +struct CAFFE2_API ${Type}Dispatch { + template + static FnPtr get_function(const std::string& schema) { + auto & fn_table = get_fn_table(); + auto it = fn_table.find(schema); + if (it != fn_table.end()) { + return reinterpret_cast(it->second); + } + AT_ERROR("No function registered for schema: ", schema); + } + + template + static void register_function(const std::string& schema, FnPtr fn) { + auto & fn_table = get_fn_table(); + if (fn_table.find(schema) != fn_table.end()) { + AT_ERROR("Function already registered for schema: ", schema); + } + fn_table[schema] = reinterpret_cast(fn); + } + + static std::unordered_map& get_fn_table(); +}; + +struct CAFFE2_API ${Type} : public TypeDefault { + explicit ${Type}(); + + Allocator* allocator() const override; + Device getDeviceFromPtr(void * data) const override; + std::unique_ptr generator() const override; + + virtual ScalarType scalarType() const override; + virtual caffe2::TypeMeta typeMeta() const override; + virtual Backend backend() const override; + virtual const char * toString() const override; + virtual size_t elementSizeInBytes() const override; + virtual TypeID ID() const override; + + ${type_method_declarations} +}; + +} // namespace at diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 1d2457154b..72a98f7ace 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -20,7 +20,8 @@ list(APPEND ATen_CPU_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp) list(APPEND ATen_CUDA_TEST_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp new file mode 100644 index 0000000000..be8d262f6a --- /dev/null +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -0,0 +1,66 @@ +#include + +#include +#include +#include + +using namespace at; + +static int test_int; + +Tensor empty_override(IntList size, const TensorOptions & options) { + test_int = 1; + auto tensor_impl = c10::make_intrusive( + Storage( + caffe2::TypeMeta::Make(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false), + MSNPUTensorId(), + false); + return Tensor(std::move(tensor_impl)); +} + +Tensor empty_like_override(const Tensor & self, const TensorOptions & options) { + test_int = 2; + return self; +} + +Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { + test_int = 3; + return a; +} + +TEST(BackendExtensionTest, TestRegisterOp) { + EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU)); + register_extension_backend_op( + Backend::MSNPU, + "empty(IntList size, TensorOptions options) -> Tensor", &empty_override); + Tensor a = empty({5, 5}, at::kMSNPU); + ASSERT_EQ(a.device().type(), at::kMSNPU); + ASSERT_EQ(a.device().index(), 1); + ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make()); + ASSERT_EQ(test_int, 1); + + EXPECT_ANY_THROW(empty_like(a, at::kMSNPU)); + register_extension_backend_op( + Backend::MSNPU, + "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override); + Tensor b = empty_like(a, at::kMSNPU); + ASSERT_EQ(test_int, 2); + + EXPECT_ANY_THROW(add(a, b)); + register_extension_backend_op( + Backend::MSNPU, + "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override); + add(a, b); + ASSERT_EQ(test_int, 3); + + // Ensure that non-MSNPU operator still works + Tensor d = empty({5, 5}, at::kCPU); + ASSERT_EQ(d.device().type(), at::kCPU); + + // Attempt to register on a schema that has already has a function + EXPECT_ANY_THROW( + register_extension_backend_op( + Backend::MSNPU, + "empty(IntList size, TensorOptions options) -> Tensor", &empty_override) + ); +} diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh index 8ed5a21e2d..e2df276220 100755 --- a/aten/tools/run_tests.sh +++ b/aten/tools/run_tests.sh @@ -17,6 +17,7 @@ VALGRIND=${VALGRIND:=ON} ./scalar_tensor_test ./tensor_interop_test ./undefined_tensor_test +./extension_backend_test if [[ -x ./cudnn_test ]]; then ./cudnn_test fi diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 991c8f65aa..54a22cc499 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -20,7 +20,7 @@ namespace c10 { * would make sense in your use case. If it doesn't make sense, maybe * you want DeviceType. */ -enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, Undefined, NumOptions }; +enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, Undefined, NumOptions }; static inline Backend toSparse(Backend b) { switch (b) { @@ -49,6 +49,8 @@ static inline Backend toDense(Backend b) { return Backend::CUDA; case Backend::HIP: return Backend::HIP; + case Backend::MSNPU: + return Backend::MSNPU; case Backend::SparseCPU: return Backend::CPU; case Backend::SparseCUDA: @@ -67,6 +69,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) { return Backend::CUDA; } else if (t == HIPTensorId()) { return Backend::HIP; + } else if (t == MSNPUTensorId()) { + return Backend::MSNPU; } else if (t == SparseCPUTensorId()) { return Backend::SparseCPU; } else if (t == SparseCUDATensorId()) { @@ -88,6 +92,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) { return CUDATensorId(); case Backend::HIP: return HIPTensorId(); + case Backend::MSNPU: + return MSNPUTensorId(); case Backend::SparseCPU: return SparseCPUTensorId(); case Backend::SparseCUDA: @@ -109,6 +115,8 @@ static inline DeviceType backendToDeviceType(Backend b) { return DeviceType::CUDA; case Backend::HIP: return DeviceType::HIP; + case Backend::MSNPU: + return DeviceType::MSNPU; case Backend::SparseCPU: return DeviceType::CPU; case Backend::SparseCUDA: @@ -130,6 +138,8 @@ static inline Backend deviceTypeToBackend(DeviceType d) { return Backend::CUDA; case DeviceType::HIP: return Backend::HIP; + case DeviceType::MSNPU: + return Backend::MSNPU; default: AT_ERROR("Unknown device type ", d); } @@ -149,6 +159,8 @@ static inline Backend backendToCPU(Backend b) { return Backend::SparseCPU; case Backend::SparseHIP: return Backend::SparseCPU; + case Backend::MSNPU: + return Backend::CPU; case Backend::Undefined: return Backend::Undefined; default: @@ -161,6 +173,7 @@ static inline Backend backendToCUDA(Backend b) { case Backend::CPU: case Backend::CUDA: case Backend::HIP: + case Backend::MSNPU: return Backend::CUDA; case Backend::SparseCPU: case Backend::SparseCUDA: @@ -178,6 +191,7 @@ static inline Backend backendToHIP(Backend b) { case Backend::CPU: case Backend::CUDA: case Backend::HIP: + case Backend::MSNPU: return Backend::HIP; case Backend::SparseCPU: case Backend::SparseCUDA: @@ -193,6 +207,7 @@ static inline Backend backendToHIP(Backend b) { constexpr DeviceType kCPU = DeviceType::CPU; constexpr DeviceType kCUDA = DeviceType::CUDA; constexpr DeviceType kHIP = DeviceType::HIP; +constexpr DeviceType kMSNPU = DeviceType::MSNPU; static inline const char* toString(Backend b) { switch (b) { @@ -202,6 +217,8 @@ static inline const char* toString(Backend b) { return "CUDA"; case Backend::HIP: return "HIP"; + case Backend::MSNPU: + return "MSNPU"; case Backend::SparseCPU: return "SparseCPU"; case Backend::SparseCUDA: diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index e4b6415124..fd24b70113 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -23,6 +23,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { return lower_case ? "hip" : "HIP"; case DeviceType::FPGA: return lower_case ? "fpga" : "FPGA"; + case DeviceType::MSNPU: + return lower_case ? "msnpu" : "MSNPU"; default: AT_ERROR( "Unknown device: ", @@ -53,6 +55,7 @@ bool isValidDeviceType(DeviceType d) { case DeviceType::IDEEP: case DeviceType::HIP: case DeviceType::FPGA: + case DeviceType::MSNPU: return true; default: return false; diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h index 916319ab19..de7f387f04 100644 --- a/c10/core/DeviceType.h +++ b/c10/core/DeviceType.h @@ -21,11 +21,12 @@ enum class DeviceType : int16_t { IDEEP = 5, // IDEEP. HIP = 6, // AMD HIP FPGA = 7, // FPGA + MSNPU = 8, // MSNPU // NB: If you add more devices: // - Change the implementations of DeviceTypeName and isValidDeviceType // in DeviceType.cpp // - Change the number below - COMPILE_TIME_MAX_DEVICE_TYPES = 8, + COMPILE_TIME_MAX_DEVICE_TYPES = 9, ONLY_FOR_TEST = 20901, // This device type is only for test. }; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index fd2e79c5f9..3e258519b8 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -357,7 +357,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { int64_t get_device() const { // NB: This method is not virtual and tries to avoid dispatches in the common case for perf. const auto tid = type_id(); - if (tid == CUDATensorId() || tid == HIPTensorId()) { + if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) { // TODO: #12934 investigate caching device on TensorImpl to avoid this vdispatch. return storage().device().index(); } @@ -369,7 +369,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // TODO: This is a little convoluted so it would be good to investigate // caching device on TensorImpl (#12934) to speed up device() calls in all cases. const auto tid = type_id(); - if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId()) { + if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) { // NB: storage(), not storage_, b/c of Variable. const auto& mystorage = storage(); if (mystorage) { diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 82701ef9eb..cd2b464c91 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -326,13 +326,15 @@ struct C10_API TensorOptions { // Resolves the ATen backend specified by the current construction axes. Backend backend() const noexcept { - Backend backend; - if (device().type() == Device::Type::CPU) { - backend = (layout() == kStrided) ? Backend::CPU : Backend::SparseCPU; - } else { - backend = (layout() == kStrided) ? Backend::CUDA : Backend::SparseCUDA; + Backend backend = deviceTypeToBackend(device().type()); + switch (layout()) { + case kStrided: + return backend; + case kSparse: + return toSparse(backend); + default: + return backend; } - return backend; } private: @@ -507,6 +509,8 @@ inline TensorTypeId computeTensorTypeId(TensorOptions options) { return IDEEPTensorId(); case DeviceType::HIP: return HIPTensorId(); + case DeviceType::MSNPU: + return MSNPUTensorId(); default: AT_ERROR("Unsupported device type for dense layout: ", options.device().type()); } @@ -543,6 +547,8 @@ inline DeviceType computeDeviceType(TensorTypeId tid) { return DeviceType::IDEEP; } else if (tid == HIPTensorId()) { return DeviceType::HIP; + } else if (tid == MSNPUTensorId()) { + return DeviceType::MSNPU; } else if (tid == SparseCPUTensorId()) { return DeviceType::CPU; } else if (tid == SparseCUDATensorId()) { diff --git a/c10/core/TensorTypeIdRegistration.cpp b/c10/core/TensorTypeIdRegistration.cpp index ac3f62346c..2333d04606 100644 --- a/c10/core/TensorTypeIdRegistration.cpp +++ b/c10/core/TensorTypeIdRegistration.cpp @@ -68,5 +68,6 @@ C10_DEFINE_TENSOR_TYPE(OpenCLTensorId); C10_DEFINE_TENSOR_TYPE(IDEEPTensorId); C10_DEFINE_TENSOR_TYPE(HIPTensorId); C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId); +C10_DEFINE_TENSOR_TYPE(MSNPUTensorId); } // namespace c10 diff --git a/c10/core/TensorTypeIdRegistration.h b/c10/core/TensorTypeIdRegistration.h index 16a427e1db..9d7d36eed8 100644 --- a/c10/core/TensorTypeIdRegistration.h +++ b/c10/core/TensorTypeIdRegistration.h @@ -107,6 +107,7 @@ C10_DECLARE_TENSOR_TYPE(OpenCLTensorId); // Caffe2 only C10_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only C10_DECLARE_TENSOR_TYPE(HIPTensorId); // PyTorch/Caffe2 supported C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only +C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only } // namespace c10 diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto index 1c466200f7..210e55f37f 100644 --- a/caffe2/proto/caffe2.proto +++ b/caffe2/proto/caffe2.proto @@ -178,8 +178,9 @@ enum DeviceTypeProto { PROTO_IDEEP = 5; // IDEEP. PROTO_HIP = 6; // AMD HIP PROTO_FPGA = 7; // FPGA + PROTO_MSNPU = 8; // MSNPU // Change the following number if you add more devices in the code. - PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 8; + PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 9; PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test. } -- cgit v1.2.3