| author | Dmytro Dzhulgakov <dzhulgakov@fb.com> | 2019-04-05 11:14:11 -0700 |
|---|---|---|
| committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-05 11:20:13 -0700 |
| commit | ef779b33979018f31485b935dfa130e8763040b2 (patch) | |
| tree | cdeac0695bd5c761d3268a21c35ccfcc4e74fc31 /aten | |
| parent | c3a559deb72626b664729192fd551f9bf269a43c (diff) | |
Wrap workaround for cpp custom types a bit prettier and add an example (#18791)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18791
This serves as a temporary demonstration of how to extend this hack further until custom C++ types are ready.
Reviewed By: jamesr66a
Differential Revision: D14742020
fbshipit-source-id: 0f2fd83ae56ab2abe16977a1829ed421e6abe74b
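
For context, the round-trip pattern the new `cpp_custom_type_hack.h` header enables looks roughly like the sketch below. `MyPackedState`, `pack_state`, and `state_rows` are hypothetical names used only for illustration; just `CAFFE_KNOWN_TYPE`, `create`, and `cast` come from the diff itself.

```cpp
#include <memory>
#include <utility>
#include <vector>

#include "ATen/ATen.h"
#include "ATen/cpp_custom_type_hack.h"

// Hypothetical C++ type we want to pass through TorchScript as an opaque value.
struct MyPackedState {
  std::vector<float> scales;
  int64_t rows;
};

namespace caffe2 {
// Registration is mandatory: create<T>/cast<T> look up T's deleter via TypeMeta.
CAFFE_KNOWN_TYPE(MyPackedState);
}

// Wrap: ownership of the instance moves into a ByteTensor with a custom deleter.
at::Tensor pack_state(std::vector<float> scales, int64_t rows) {
  auto state = std::make_unique<MyPackedState>();
  state->scales = std::move(scales);
  state->rows = rows;
  return at::cpp_custom_type_hack::create(std::move(state));
}

// Unwrap: cast<T> checks the dtype and the deleter before reinterpreting.
int64_t state_rows(const at::Tensor& packed) {
  auto& state = at::cpp_custom_type_hack::cast<MyPackedState>(packed);
  return state.rows;
}
```

The wrapper is an ordinary ByteTensor on the outside, which is what lets it flow through TorchScript unchanged until proper custom C++ type support lands.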
Diffstat (limited to 'aten')
-rw-r--r-- | aten/src/ATen/cpp_custom_type_hack.h | 48 |
-rw-r--r-- | aten/src/ATen/native/QuantizedLinear.cpp | 37 |
2 files changed, 61 insertions, 24 deletions
```diff
diff --git a/aten/src/ATen/cpp_custom_type_hack.h b/aten/src/ATen/cpp_custom_type_hack.h
new file mode 100644
index 0000000000..211f1a0b53
--- /dev/null
+++ b/aten/src/ATen/cpp_custom_type_hack.h
@@ -0,0 +1,48 @@
+// WARNING! WARNING! WARNING!
+// This file is a temporary hack to enable development of pytorch quantization
+//
+// It's a stub for wrapping arbitrary cpp types in TorchScript. Proper
+// implementation (under development) is to use TorchScript custom types.
+// In the meantime, we abuse ByteTensor with custom deleter for this purpose.
+//
+// Template argument <T> has to be registered with CAFFE_KNOWN_TYPE mechanism.
+
+#include "ATen/ATen.h"
+
+namespace at {
+namespace cpp_custom_type_hack {
+
+template<typename T>
+T& cast(const Tensor& packed) {
+  AT_CHECK(
+      packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
+  AT_CHECK(
+      packed.storage().data_ptr().get_deleter() ==
+          caffe2::TypeMeta::Make<T>().deleteFn(),
+      "Expected temporary cpp type wrapper of type ",
+      caffe2::TypeMeta::TypeName<T>());
+  return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
+}
+
+template<typename T>
+Tensor create(std::unique_ptr<T> ptr) {
+  // We store this instance away in a Tensor and register a deleter function
+  // so that we do not leak memory. On the other side, we pull out the storage's
+  // data_ptr and get the right typed pointer.
+  void* raw_ptr = ptr.release();
+  at::DataPtr at_ptr(
+      raw_ptr,
+      raw_ptr,
+      caffe2::TypeMeta::Make<T>().deleteFn(),
+      at::kCPU);
+
+  // size doesn't really matter, but we can align it to the actual size
+  // returning variables because one likely want to use this hack from python
+  auto retval = at::empty(
+      {sizeof(T)},
+      at::device(kCPU).dtype(at::kByte).is_variable(true).requires_grad(false));
+  retval.storage().set_data_ptr(std::move(at_ptr));
+  return retval;
+}
+}
+}
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index a9c1b3d56c..58f64dec52 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -1,6 +1,7 @@
 #include "ATen/ATen.h"
 #include "ATen/NativeFunctions.h"
 #include "ATen/WrapDimUtilsMulti.h"
+#include "ATen/cpp_custom_type_hack.h"
 
 #ifdef USE_FBGEMM
 #include "fbgemm/Fbgemm.h"
@@ -16,6 +17,14 @@
 #include <vector>
 #include <chrono>
 
+namespace caffe2 {
+#ifdef USE_FBGEMM
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
+#endif // USE_FBGEMM
+}
+
 namespace at {
 namespace native {
 
@@ -127,13 +136,12 @@ Tensor fbgemm_linear_int8_weight(
   auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
 
   // Pull out the PackBMatrix instance from the owning tensor
-  auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
-      packed.storage().data_ptr().get());
+  auto& packB = cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
 
   // Do the GEMM
   fbgemm::fbgemmPacked(
       /*packA=*/packA,
-      /*packB=*/*packB,
+      /*packB=*/packB,
       /*C=*/output.data<float>(),
       /*C_buffer=*/buffer.data<int32_t>(),
       /*ldc=*/N,
@@ -233,7 +241,7 @@ Tensor fbgemm_pack_quantized_matrix(
   AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
   auto weight_contig = weight.contiguous();
   auto contiguous_ptr = weight_contig.data<int8_t>();
-  auto* ptr = new fbgemm::PackBMatrix<int8_t>(
+  auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
       /*trans=*/fbgemm::matrix_op_t::Transpose,
       /*nRow=*/K,
       /*nCol=*/N,
       /*smat=*/contiguous_ptr,
       /*ld=*/K,
       /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
       /*groups=*/1);
-
-  // We store this instance away in a Tensor and register a deleter function
-  // so that we do not leak memory. On the other side, we pull out the storage's
-  // data_ptr and get the PackBMatrix's pointer.
-  at::DataPtr at_ptr(
-      ptr,
-      ptr,
-      [](void* ptr) {
-        fbgemm::PackBMatrix<int8_t>* typed_ptr =
-            reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
-        delete typed_ptr;
-      },
-      at::kCPU);
-
-  auto retval = at::empty(
-      {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
-
-  retval.storage().set_data_ptr(std::move(at_ptr));
-
-  return retval;
+  return cpp_custom_type_hack::create(std::move(ptr));
 }
 
 #else // USE_FBGEMM
```
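
What `cast<T>` buys over the old raw `reinterpret_cast` in `fbgemm_linear_int8_weight` is the pair of `AT_CHECK`s: a byte tensor that was not produced by `create<T>` now fails the deleter comparison loudly instead of being reinterpreted as garbage. A small illustration of that failure mode, reusing the hypothetical `MyPackedState` from the sketch above:

```cpp
// An ordinary byte tensor carries the default allocator's deleter, not the
// TypeMeta deleter for MyPackedState, so the second AT_CHECK in cast<T>()
// fires before any reinterpret_cast happens.
at::Tensor not_packed = at::zeros({8}, at::kByte);
// Throws (message roughly): "Expected temporary cpp type wrapper of type MyPackedState"
// auto& bad = at::cpp_custom_type_hack::cast<MyPackedState>(not_packed);
```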