author     Roy Li <royboy@fb.com>  2019-02-01 10:55:00 -0800
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-01 11:00:16 -0800
commit     7e642dfff3a4bbf80826f331312acff2672a3e02 (patch)
tree       2b8499a1824bf467cfb152ec1089be517bcae520
parent     64186e06eccad8721cfa741cae31a39842a3dfe3 (diff)
Introduce backend extensions (overriding operators on custom backends)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15153

Reviewed By: gchanan

Differential Revision: D13445571

fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708
-rw-r--r--  aten/src/ATen/core/Type.h                                |  1
-rw-r--r--  aten/src/ATen/function_wrapper.py                        | 35
-rw-r--r--  aten/src/ATen/gen.py                                     | 53
-rw-r--r--  aten/src/ATen/templates/ExtensionBackendRegistration.h   | 19
-rw-r--r--  aten/src/ATen/templates/TypeExtension.cpp                | 51
-rw-r--r--  aten/src/ATen/templates/TypeExtension.h                  | 49
-rw-r--r--  aten/src/ATen/test/CMakeLists.txt                        |  3
-rw-r--r--  aten/src/ATen/test/extension_backend_test.cpp            | 66
-rwxr-xr-x  aten/tools/run_tests.sh                                  |  1
-rw-r--r--  c10/core/Backend.h                                       | 19
-rw-r--r--  c10/core/DeviceType.cpp                                  |  3
-rw-r--r--  c10/core/DeviceType.h                                    |  3
-rw-r--r--  c10/core/TensorImpl.h                                    |  4
-rw-r--r--  c10/core/TensorOptions.h                                 | 18
-rw-r--r--  c10/core/TensorTypeIdRegistration.cpp                    |  1
-rw-r--r--  c10/core/TensorTypeIdRegistration.h                      |  1
-rw-r--r--  caffe2/proto/caffe2.proto                                |  3
17 files changed, 317 insertions(+), 13 deletions(-)
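
For review convenience, here is the intended end-to-end usage in one place: an extension defines ordinary functions matching ATen schemas and registers them per backend through the new generated header. This is a minimal sketch condensed from extension_backend_test.cpp in this diff (the device index 1 and float dtype are simply what the test happens to use; init_msnpu_ops is a hypothetical init hook, since registration is a runtime call):

#include <ATen/ATen.h>
#include <ATen/ExtensionBackendRegistration.h>
using namespace at;

// Kernel supplied by the extension; here it fabricates an empty MSNPU tensor,
// exactly as the test override below does.
Tensor empty_override(IntList size, const TensorOptions & options) {
  auto impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      Storage(caffe2::TypeMeta::Make<float>(), 0,
              at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
      MSNPUTensorId(), /*is_variable=*/false);
  return Tensor(std::move(impl));
}

void init_msnpu_ops() {
  register_extension_backend_op(
      Backend::MSNPU,
      "empty(IntList size, TensorOptions options) -> Tensor", &empty_override);
  // After registration, ordinary ATen calls on the extension backend
  // dispatch to the override:
  Tensor t = empty({5, 5}, at::kMSNPU);
}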
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 30dd0129c7..f7b20e7e53 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -74,6 +74,7 @@ enum class TypeID {
SparseCUDAInt,
SparseCUDALong,
SparseCUDAShort,
+ MSNPU,
CPUComplexFloat,
CPUComplexDouble,
CUDAComplexFloat,
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 4978f8e5f4..2ee5aeb76f 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -116,6 +116,13 @@ TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\
${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals});
""")
+# Overrideable stubs to be used in user-extendable backends
+TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\
+${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
+ return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals});
+}
+""")
+
# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
@@ -489,6 +496,7 @@ FunctionOption = TypedDict('FunctionOption', {
'formals_list': List[AtFormal],
'formals_with_defaults': List[str],
'formals': List[str],
+ 'formals_types': List[str],
'inferred_type': str,
'inplace': bool,
'matches_jit_signature': bool,
@@ -513,6 +521,8 @@ FunctionOption = TypedDict('FunctionOption', {
'return': ReturnDecl,
'returns': List[ReturnType],
'scalar_check': str,
+ # schema used for extension backend operator registration
+ 'schema': str,
'sparse': bool,
'type_definition_body': List[str],
'type_method_actuals': List[str],
@@ -1595,3 +1605,28 @@ def create_derived(backend_type_env, declarations):
except NYIError:
pass
return type_object_declarations, type_object_definitions
+
+
+def create_extension_backend(backend_type_env, declarations):
+ # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
+ type_object_declarations = []
+ type_object_definitions = []
+
+ for declaration in declarations:
+ for option in declaration['options']:
+ if not option.get('skip', False):
+ try:
+ option['formals_types'] = [f['type'] for f in option['formals_list']]
+ option['native_actuals'] = [f['name'] for f in option['formals_list']]
+ schema_args = ", ".join(
+ ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']])
+ return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
+ option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type)
+ env = nested_dict(option, backend_type_env)
+ type_object_declarations.append(
+ TYPE_DERIVED_DECLARATION.substitute(env))
+ type_object_definitions.append(
+ TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
+ except NYIError:
+ pass
+ return type_object_declarations, type_object_definitions
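
To make the new template concrete: for the add schema exercised by the test in this diff, create_extension_backend would stamp out roughly the following stub in MSNPUType.cpp (a sketch; method_prefix_derived is empty here, and the exact formals are taken from Declarations):

Tensor MSNPUType::add(const Tensor & self, const Tensor & other, Scalar alpha) const {
  // Look up the user-registered kernel by schema string; throws if none is registered.
  return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, Scalar)>(
      "add(Tensor self, Tensor other, Scalar alpha) -> Tensor")(self, other, alpha);
}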
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 32255e39c1..c8207710bb 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -121,6 +121,8 @@ TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h")
TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h")
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
+TYPE_EXTENSION_BACKEND_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.h")
+TYPE_EXTENSION_BACKEND_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.cpp")
LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h")
LEGACY_TH_DISPATCHER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.cpp")
@@ -141,10 +143,18 @@ LEGACY_TH_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctio
NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h")
+EXTENSION_BACKEND_REGISTRATION_H = CodeTemplate.from_file(TEMPLATE_PATH + "/ExtensionBackendRegistration.h")
+
TYPE_REGISTER = CodeTemplate("""\
context->registerType(Backend::${backend}, ScalarType::${scalar_type}, new ${type_name}());
""")
+EXTENSION_BACKEND_REGISTER_SWITCH = CodeTemplate("""\
+case Backend::${Backend}:
+ ${Type}Dispatch::register_function(schema, fn);
+ break;
+""")
+
core_file_manager = FileManager(core_install_dir)
file_manager = FileManager()
cuda_file_manager = FileManager()
@@ -164,6 +174,7 @@ generators = {
backends = ['CPU', 'CUDA']
densities = ['Dense', 'Sparse']
+extension_backends = ['MSNPU']
# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
@@ -193,6 +204,8 @@ top_env = {
'function_definitions': [],
'type_ids': [],
'native_function_declarations': [],
+ 'extension_backend_headers': [],
+ 'extension_backend_register_switches': [],
}
@@ -347,6 +360,37 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
return env
+def generate_type_extension_backend(backend, declarations):
+ env = {}
+ env['Type'] = "{}Type".format(backend)
+ env['Backend'] = backend
+ env['DeviceType'] = backend
+ env['is_extension_backend'] = True
+ env['TypeID'] = 'TypeID::' + backend
+ top_env['type_ids'].append(backend + ',')
+
+ declarations, definitions = function_wrapper.create_extension_backend(
+ env, declarations)
+ env['type_method_declarations'] = declarations
+ env['type_method_definitions'] = definitions
+
+ fm = file_manager
+ fm.write(env['Type'] + ".cpp", TYPE_EXTENSION_BACKEND_CPP, env)
+ fm.write(env['Type'] + ".h", TYPE_EXTENSION_BACKEND_H, env)
+
+ for scalar_name, _, _, _, _ in scalar_types:
+ type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type'])
+ top_env['cpu_type_registrations'].append(type_register)
+ extension_backend_register_switch = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env)
+ top_env['extension_backend_register_switches'].append(extension_backend_register_switch)
+ top_env['extension_backend_headers'].append(
+ '#include <ATen/{}.h>'.format(env['Type']))
+ top_env['cpu_type_headers'].append(
+ '#include "ATen/{}.h"'.format(env['Type']))
+
+ return env
+
+
def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations):
assert density != 'Sparse'
scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
@@ -384,7 +428,7 @@ def declare_outputs():
core_file_manager.will_write(f)
files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h',
'LegacyTHDispatcher.h', 'LegacyTHDispatcher.cpp', 'LegacyTHFunctions.h',
- 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h']
+ 'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h', 'ExtensionBackendRegistration.h']
for f in files:
file_manager.will_write(f)
cuda_files = ['RegisterCUDA.cpp', 'RegisterCUDA.h']
@@ -411,6 +455,9 @@ def declare_outputs():
if density != 'Sparse':
fm.will_write("{}{}{}{}.h".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
fm.will_write("{}{}{}{}.cpp".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
+ for backend in extension_backends:
+ file_manager.will_write("{}Type.h".format(backend))
+ file_manager.will_write("{}Type.cpp".format(backend))
def filter_by_extension(files, *extensions):
@@ -472,6 +519,8 @@ def generate_outputs():
for backend, density, scalar_type in iterate_types():
all_types.append(generate_storage_type_and_tensor(
backend, density, scalar_type, declarations))
+ for backend in extension_backends:
+ all_types.append(generate_type_extension_backend(backend, declarations))
all_legacy_th_dispatchers = []
for backend, density, scalar_type in iterate_types():
@@ -506,6 +555,8 @@ def generate_outputs():
file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env)
+ file_manager.write('ExtensionBackendRegistration.h', EXTENSION_BACKEND_REGISTRATION_H, top_env)
+
file_manager.check_all_files_written()
cuda_file_manager.check_all_files_written()
diff --git a/aten/src/ATen/templates/ExtensionBackendRegistration.h b/aten/src/ATen/templates/ExtensionBackendRegistration.h
new file mode 100644
index 0000000000..2f9f731922
--- /dev/null
+++ b/aten/src/ATen/templates/ExtensionBackendRegistration.h
@@ -0,0 +1,19 @@
+#pragma once
+#include <ATen/Backend.h>
+${extension_backend_headers}
+
+namespace at {
+
+template <typename FnPtr>
+inline void register_extension_backend_op(
+ Backend backend,
+ const char * schema,
+ FnPtr fn) {
+ switch (backend) {
+ ${extension_backend_register_switches}
+ default:
+ AT_ERROR("Invalid extension backend: ", toString(backend));
+ }
+}
+
+} // namespace at
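
Since gen.py currently lists a single extension backend (MSNPU), the generated switch in this header reduces to the following (sketch, after substituting Backend=MSNPU, Type=MSNPUType into EXTENSION_BACKEND_REGISTER_SWITCH):

switch (backend) {
  case Backend::MSNPU:
    MSNPUTypeDispatch::register_function(schema, fn);
    break;
  default:
    AT_ERROR("Invalid extension backend: ", toString(backend));
}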
diff --git a/aten/src/ATen/templates/TypeExtension.cpp b/aten/src/ATen/templates/TypeExtension.cpp
new file mode 100644
index 0000000000..313b865fc0
--- /dev/null
+++ b/aten/src/ATen/templates/TypeExtension.cpp
@@ -0,0 +1,51 @@
+#include <ATen/${Type}.h>
+
+namespace at {
+
+std::unordered_map<std::string, void *>& ${Type}Dispatch::get_fn_table() {
+ static std::unordered_map<std::string, void *> fn_table;
+ return fn_table;
+}
+
+${Type}::${Type}()
+ : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
+
+Allocator* ${Type}::allocator() const {
+ AT_ERROR("allocator is not implemented for ${Type}");
+}
+
+Device ${Type}::getDeviceFromPtr(void * data) const {
+ return DeviceType::${DeviceType};
+}
+
+std::unique_ptr<Generator> ${Type}::generator() const {
+ AT_ERROR("generator is not implemented for ${Type}");
+}
+
+ScalarType ${Type}::scalarType() const {
+ AT_ERROR("scalarType is not implemented for ${Type}");
+}
+
+caffe2::TypeMeta ${Type}::typeMeta() const {
+ AT_ERROR("typeMeta is not implemented for ${Type}");
+}
+
+Backend ${Type}::backend() const {
+ return Backend::${Backend};
+}
+
+const char * ${Type}::toString() const {
+ return "${Type}";
+}
+
+TypeID ${Type}::ID() const {
+ return ${TypeID};
+}
+
+size_t ${Type}::elementSizeInBytes() const {
+ AT_ERROR("elementSizeInBytes is not implemented for ${Type}");
+}
+
+${type_method_definitions}
+
+} // namespace at
diff --git a/aten/src/ATen/templates/TypeExtension.h b/aten/src/ATen/templates/TypeExtension.h
new file mode 100644
index 0000000000..fe45960225
--- /dev/null
+++ b/aten/src/ATen/templates/TypeExtension.h
@@ -0,0 +1,49 @@
+#pragma once
+#include <ATen/TypeDefault.h>
+
+namespace at {
+
+// This dispatch class holds a static map in which function pointers are
+// registered by schema.
+// TODO: Check for invalid schemas prior to registration.
+struct CAFFE2_API ${Type}Dispatch {
+ template<typename FnPtr>
+ static FnPtr get_function(const std::string& schema) {
+ auto & fn_table = get_fn_table();
+ auto it = fn_table.find(schema);
+ if (it != fn_table.end()) {
+ return reinterpret_cast<FnPtr>(it->second);
+ }
+ AT_ERROR("No function registered for schema: ", schema);
+ }
+
+ template<typename FnPtr>
+ static void register_function(const std::string& schema, FnPtr fn) {
+ auto & fn_table = get_fn_table();
+ if (fn_table.find(schema) != fn_table.end()) {
+ AT_ERROR("Function already registered for schema: ", schema);
+ }
+ fn_table[schema] = reinterpret_cast<void *>(fn);
+ }
+
+ static std::unordered_map<std::string, void *>& get_fn_table();
+};
+
+struct CAFFE2_API ${Type} : public TypeDefault {
+ explicit ${Type}();
+
+ Allocator* allocator() const override;
+ Device getDeviceFromPtr(void * data) const override;
+ std::unique_ptr<Generator> generator() const override;
+
+ virtual ScalarType scalarType() const override;
+ virtual caffe2::TypeMeta typeMeta() const override;
+ virtual Backend backend() const override;
+ virtual const char * toString() const override;
+ virtual size_t elementSizeInBytes() const override;
+ virtual TypeID ID() const override;
+
+ ${type_method_declarations}
+};
+
+} // namespace at
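
One caveat worth noting for reviewers: the table stores kernels as void*, so get_function's reinterpret_cast is only sound when retrieval names exactly the function-pointer type that was registered. The generated stubs guarantee this by construction; hand-written callers must match signatures themselves, e.g. (hypothetical snippet):

using EmptyFn = Tensor (*)(IntList, const TensorOptions &);
// Must be the identical signature used at registration time; a mismatched
// FnPtr is undefined behavior rather than a caught error.
auto fn = MSNPUTypeDispatch::get_function<EmptyFn>(
    "empty(IntList size, TensorOptions options) -> Tensor");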
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 1d2457154b..72a98f7ace 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -20,7 +20,8 @@ list(APPEND ATen_CPU_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp)
+ ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)
list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp
new file mode 100644
index 0000000000..be8d262f6a
--- /dev/null
+++ b/aten/src/ATen/test/extension_backend_test.cpp
@@ -0,0 +1,66 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/ExtensionBackendRegistration.h>
+
+using namespace at;
+
+static int test_int;
+
+Tensor empty_override(IntList size, const TensorOptions & options) {
+ test_int = 1;
+ auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
+ Storage(
+ caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
+ MSNPUTensorId(),
+ false);
+ return Tensor(std::move(tensor_impl));
+}
+
+Tensor empty_like_override(const Tensor & self, const TensorOptions & options) {
+ test_int = 2;
+ return self;
+}
+
+Tensor add_override(const Tensor & a, const Tensor & b, Scalar c) {
+ test_int = 3;
+ return a;
+}
+
+TEST(BackendExtensionTest, TestRegisterOp) {
+ EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "empty(IntList size, TensorOptions options) -> Tensor", &empty_override);
+ Tensor a = empty({5, 5}, at::kMSNPU);
+ ASSERT_EQ(a.device().type(), at::kMSNPU);
+ ASSERT_EQ(a.device().index(), 1);
+ ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
+ ASSERT_EQ(test_int, 1);
+
+ EXPECT_ANY_THROW(empty_like(a, at::kMSNPU));
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override);
+ Tensor b = empty_like(a, at::kMSNPU);
+ ASSERT_EQ(test_int, 2);
+
+ EXPECT_ANY_THROW(add(a, b));
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
+ add(a, b);
+ ASSERT_EQ(test_int, 3);
+
+ // Ensure that a non-MSNPU operator still works
+ Tensor d = empty({5, 5}, at::kCPU);
+ ASSERT_EQ(d.device().type(), at::kCPU);
+
+ // Attempt to register on a schema that already has a function
+ EXPECT_ANY_THROW(
+ register_extension_backend_op(
+ Backend::MSNPU,
+ "empty(IntList size, TensorOptions options) -> Tensor", &empty_override)
+ );
+}
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index 8ed5a21e2d..e2df276220 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -17,6 +17,7 @@ VALGRIND=${VALGRIND:=ON}
./scalar_tensor_test
./tensor_interop_test
./undefined_tensor_test
+./extension_backend_test
if [[ -x ./cudnn_test ]]; then
./cudnn_test
fi
diff --git a/c10/core/Backend.h b/c10/core/Backend.h
index 991c8f65aa..54a22cc499 100644
--- a/c10/core/Backend.h
+++ b/c10/core/Backend.h
@@ -20,7 +20,7 @@ namespace c10 {
* would make sense in your use case. If it doesn't make sense, maybe
* you want DeviceType.
*/
-enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, Undefined, NumOptions };
+enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, Undefined, NumOptions };
static inline Backend toSparse(Backend b) {
switch (b) {
@@ -49,6 +49,8 @@ static inline Backend toDense(Backend b) {
return Backend::CUDA;
case Backend::HIP:
return Backend::HIP;
+ case Backend::MSNPU:
+ return Backend::MSNPU;
case Backend::SparseCPU:
return Backend::CPU;
case Backend::SparseCUDA:
@@ -67,6 +69,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
return Backend::CUDA;
} else if (t == HIPTensorId()) {
return Backend::HIP;
+ } else if (t == MSNPUTensorId()) {
+ return Backend::MSNPU;
} else if (t == SparseCPUTensorId()) {
return Backend::SparseCPU;
} else if (t == SparseCUDATensorId()) {
@@ -88,6 +92,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) {
return CUDATensorId();
case Backend::HIP:
return HIPTensorId();
+ case Backend::MSNPU:
+ return MSNPUTensorId();
case Backend::SparseCPU:
return SparseCPUTensorId();
case Backend::SparseCUDA:
@@ -109,6 +115,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::CUDA;
case Backend::HIP:
return DeviceType::HIP;
+ case Backend::MSNPU:
+ return DeviceType::MSNPU;
case Backend::SparseCPU:
return DeviceType::CPU;
case Backend::SparseCUDA:
@@ -130,6 +138,8 @@ static inline Backend deviceTypeToBackend(DeviceType d) {
return Backend::CUDA;
case DeviceType::HIP:
return Backend::HIP;
+ case DeviceType::MSNPU:
+ return Backend::MSNPU;
default:
AT_ERROR("Unknown device type ", d);
}
@@ -149,6 +159,8 @@ static inline Backend backendToCPU(Backend b) {
return Backend::SparseCPU;
case Backend::SparseHIP:
return Backend::SparseCPU;
+ case Backend::MSNPU:
+ return Backend::CPU;
case Backend::Undefined:
return Backend::Undefined;
default:
@@ -161,6 +173,7 @@ static inline Backend backendToCUDA(Backend b) {
case Backend::CPU:
case Backend::CUDA:
case Backend::HIP:
+ case Backend::MSNPU:
return Backend::CUDA;
case Backend::SparseCPU:
case Backend::SparseCUDA:
@@ -178,6 +191,7 @@ static inline Backend backendToHIP(Backend b) {
case Backend::CPU:
case Backend::CUDA:
case Backend::HIP:
+ case Backend::MSNPU:
return Backend::HIP;
case Backend::SparseCPU:
case Backend::SparseCUDA:
@@ -193,6 +207,7 @@ static inline Backend backendToHIP(Backend b) {
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
+constexpr DeviceType kMSNPU = DeviceType::MSNPU;
static inline const char* toString(Backend b) {
switch (b) {
@@ -202,6 +217,8 @@ static inline const char* toString(Backend b) {
return "CUDA";
case Backend::HIP:
return "HIP";
+ case Backend::MSNPU:
+ return "MSNPU";
case Backend::SparseCPU:
return "SparseCPU";
case Backend::SparseCUDA:
diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp
index e4b6415124..fd24b70113 100644
--- a/c10/core/DeviceType.cpp
+++ b/c10/core/DeviceType.cpp
@@ -23,6 +23,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
return lower_case ? "hip" : "HIP";
case DeviceType::FPGA:
return lower_case ? "fpga" : "FPGA";
+ case DeviceType::MSNPU:
+ return lower_case ? "msnpu" : "MSNPU";
default:
AT_ERROR(
"Unknown device: ",
@@ -53,6 +55,7 @@ bool isValidDeviceType(DeviceType d) {
case DeviceType::IDEEP:
case DeviceType::HIP:
case DeviceType::FPGA:
+ case DeviceType::MSNPU:
return true;
default:
return false;
diff --git a/c10/core/DeviceType.h b/c10/core/DeviceType.h
index 916319ab19..de7f387f04 100644
--- a/c10/core/DeviceType.h
+++ b/c10/core/DeviceType.h
@@ -21,11 +21,12 @@ enum class DeviceType : int16_t {
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
+ MSNPU = 8, // MSNPU
// NB: If you add more devices:
// - Change the implementations of DeviceTypeName and isValidDeviceType
// in DeviceType.cpp
// - Change the number below
- COMPILE_TIME_MAX_DEVICE_TYPES = 8,
+ COMPILE_TIME_MAX_DEVICE_TYPES = 9,
ONLY_FOR_TEST = 20901, // This device type is only for test.
};
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index fd2e79c5f9..3e258519b8 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -357,7 +357,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
int64_t get_device() const {
// NB: This method is not virtual and tries to avoid dispatches in the common case for perf.
const auto tid = type_id();
- if (tid == CUDATensorId() || tid == HIPTensorId()) {
+ if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
// TODO: #12934 investigate caching device on TensorImpl to avoid this vdispatch.
return storage().device().index();
}
@@ -369,7 +369,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// TODO: This is a little convoluted so it would be good to investigate
// caching device on TensorImpl (#12934) to speed up device() calls in all cases.
const auto tid = type_id();
- if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId()) {
+ if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
// NB: storage(), not storage_, b/c of Variable.
const auto& mystorage = storage();
if (mystorage) {
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 82701ef9eb..cd2b464c91 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -326,13 +326,15 @@ struct C10_API TensorOptions {
// Resolves the ATen backend specified by the current construction axes.
Backend backend() const noexcept {
- Backend backend;
- if (device().type() == Device::Type::CPU) {
- backend = (layout() == kStrided) ? Backend::CPU : Backend::SparseCPU;
- } else {
- backend = (layout() == kStrided) ? Backend::CUDA : Backend::SparseCUDA;
+ Backend backend = deviceTypeToBackend(device().type());
+ switch (layout()) {
+ case kStrided:
+ return backend;
+ case kSparse:
+ return toSparse(backend);
+ default:
+ return backend;
}
- return backend;
}
private:
@@ -507,6 +509,8 @@ inline TensorTypeId computeTensorTypeId(TensorOptions options) {
return IDEEPTensorId();
case DeviceType::HIP:
return HIPTensorId();
+ case DeviceType::MSNPU:
+ return MSNPUTensorId();
default:
AT_ERROR("Unsupported device type for dense layout: ", options.device().type());
}
@@ -543,6 +547,8 @@ inline DeviceType computeDeviceType(TensorTypeId tid) {
return DeviceType::IDEEP;
} else if (tid == HIPTensorId()) {
return DeviceType::HIP;
+ } else if (tid == MSNPUTensorId()) {
+ return DeviceType::MSNPU;
} else if (tid == SparseCPUTensorId()) {
return DeviceType::CPU;
} else if (tid == SparseCUDATensorId()) {
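
The backend() rewrite above is what lets TensorOptions resolve extension devices at all: the old code hard-wired every non-CPU device to the CUDA branch, while the new path goes through deviceTypeToBackend and then sparsifies by layout. For example (sketch, relying on the default strided layout):

TensorOptions opts = TensorOptions().device(at::kMSNPU);
// Previously this fell into the CUDA branch; now:
assert(opts.backend() == Backend::MSNPU);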
diff --git a/c10/core/TensorTypeIdRegistration.cpp b/c10/core/TensorTypeIdRegistration.cpp
index ac3f62346c..2333d04606 100644
--- a/c10/core/TensorTypeIdRegistration.cpp
+++ b/c10/core/TensorTypeIdRegistration.cpp
@@ -68,5 +68,6 @@ C10_DEFINE_TENSOR_TYPE(OpenCLTensorId);
C10_DEFINE_TENSOR_TYPE(IDEEPTensorId);
C10_DEFINE_TENSOR_TYPE(HIPTensorId);
C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId);
+C10_DEFINE_TENSOR_TYPE(MSNPUTensorId);
} // namespace c10
diff --git a/c10/core/TensorTypeIdRegistration.h b/c10/core/TensorTypeIdRegistration.h
index 16a427e1db..9d7d36eed8 100644
--- a/c10/core/TensorTypeIdRegistration.h
+++ b/c10/core/TensorTypeIdRegistration.h
@@ -107,6 +107,7 @@ C10_DECLARE_TENSOR_TYPE(OpenCLTensorId); // Caffe2 only
C10_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only
C10_DECLARE_TENSOR_TYPE(HIPTensorId); // PyTorch/Caffe2 supported
C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only
+C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only
} // namespace c10
diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto
index 1c466200f7..210e55f37f 100644
--- a/caffe2/proto/caffe2.proto
+++ b/caffe2/proto/caffe2.proto
@@ -178,8 +178,9 @@ enum DeviceTypeProto {
PROTO_IDEEP = 5; // IDEEP.
PROTO_HIP = 6; // AMD HIP
PROTO_FPGA = 7; // FPGA
+ PROTO_MSNPU = 8; // MSNPU
// Change the following number if you add more devices in the code.
- PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 8;
+ PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 9;
PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test.
}