author     Dmytro Dzhulgakov <dzhulgakov@fb.com>  2019-04-05 11:14:11 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-04-05 11:20:13 -0700
commit     ef779b33979018f31485b935dfa130e8763040b2 (patch)
tree       cdeac0695bd5c761d3268a21c35ccfcc4e74fc31 /aten
parent     c3a559deb72626b664729192fd551f9bf269a43c (diff)
Wrap workaround for cpp custom types a bit prettier and add an example (#18791)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18791

This serves as a temporary demonstration of how to extend this hack further until custom C++ types are ready.

Reviewed By: jamesr66a

Differential Revision: D14742020

fbshipit-source-id: 0f2fd83ae56ab2abe16977a1829ed421e6abe74b
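For context, the pattern this change formalizes has two halves. The producer side registers the C++ type with Caffe2's type system and then moves a std::unique_ptr into a ByteTensor via cpp_custom_type_hack::create. A minimal sketch of that half, assuming a hypothetical type MyPackedState that is not part of this patch:

    #include <cstdint>
    #include <memory>
    #include <utility>
    #include <vector>

    #include "ATen/ATen.h"
    #include "ATen/cpp_custom_type_hack.h"

    // Hypothetical C++ state we want to smuggle through TorchScript.
    struct MyPackedState {
      std::vector<int8_t> packed_data;
    };

    namespace caffe2 {
    // Registration is mandatory: cast<T>() matches the storage's deleter
    // against the deleter registered for T.
    CAFFE_KNOWN_TYPE(MyPackedState);
    } // namespace caffe2

    // Wraps the state in a ByteTensor that owns it via a custom deleter.
    at::Tensor pack_state(std::vector<int8_t> data) {
      auto state = std::make_unique<MyPackedState>();
      state->packed_data = std::move(data);
      return at::cpp_custom_type_hack::create(std::move(state));
    }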
Diffstat (limited to 'aten')
-rw-r--r--  aten/src/ATen/cpp_custom_type_hack.h     | 48
-rw-r--r--  aten/src/ATen/native/QuantizedLinear.cpp | 37
2 files changed, 61 insertions(+), 24 deletions(-)
diff --git a/aten/src/ATen/cpp_custom_type_hack.h b/aten/src/ATen/cpp_custom_type_hack.h
new file mode 100644
index 0000000000..211f1a0b53
--- /dev/null
+++ b/aten/src/ATen/cpp_custom_type_hack.h
@@ -0,0 +1,48 @@
+// WARNING! WARNING! WARNING!
+// This file is a temporary hack to enable development of PyTorch quantization.
+//
+// It's a stub for wrapping arbitrary C++ types in TorchScript. The proper
+// implementation (under development) is to use TorchScript custom types.
+// In the meantime, we abuse a ByteTensor with a custom deleter for this purpose.
+//
+// The template argument <T> has to be registered via the CAFFE_KNOWN_TYPE mechanism.
+
+#include "ATen/ATen.h"
+
+namespace at {
+namespace cpp_custom_type_hack {
+
+template<typename T>
+T& cast(const Tensor& packed) {
+ AT_CHECK(
+ packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
+ AT_CHECK(
+ packed.storage().data_ptr().get_deleter() ==
+ caffe2::TypeMeta::Make<T>().deleteFn(),
+ "Expected temporary cpp type wrapper of type ",
+ caffe2::TypeMeta::TypeName<T>());
+ return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
+}
+
+template<typename T>
+Tensor create(std::unique_ptr<T> ptr) {
+ // We store this instance away in a Tensor and register a deleter function
+ // so that we do not leak memory. On the other side, we pull out the storage's
+ // data_ptr and get the right typed pointer.
+ void* raw_ptr = ptr.release();
+ at::DataPtr at_ptr(
+ raw_ptr,
+ raw_ptr,
+ caffe2::TypeMeta::Make<T>().deleteFn(),
+ at::kCPU);
+
+ // The size doesn't really matter, but we align it to the actual size;
+ // we return a variable because one likely wants to use this hack from Python.
+ auto retval = at::empty(
+ {sizeof(T)},
+ at::device(kCPU).dtype(at::kByte).is_variable(true).requires_grad(false));
+ retval.storage().set_data_ptr(std::move(at_ptr));
+ return retval;
+}
+} // namespace cpp_custom_type_hack
+} // namespace at
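The consumer half of the round trip is cpp_custom_type_hack::cast<T>(), which recovers a typed reference after the two AT_CHECKs above have verified that the tensor really is one of these wrappers. A sketch under the same hypothetical MyPackedState assumption as above:

    // Given a tensor produced by pack_state() above, recover the state.
    // cast<T>() checks both the dtype (kByte) and that the storage's deleter
    // is the one registered for T, then reinterprets the data pointer.
    void consume_state(const at::Tensor& wrapped) {
      auto& state = at::cpp_custom_type_hack::cast<MyPackedState>(wrapped);
      // Use state.packed_data here; the wrapping tensor keeps the object
      // alive for as long as the tensor itself is held.
      (void)state;
    }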
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index a9c1b3d56c..58f64dec52 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -1,6 +1,7 @@
#include "ATen/ATen.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtilsMulti.h"
+#include "ATen/cpp_custom_type_hack.h"
#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
@@ -16,6 +17,14 @@
#include <vector>
#include <chrono>
+
+namespace caffe2 {
+#ifdef USE_FBGEMM
+// Required for cpp_custom_type_hack to work
+CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
+#endif // USE_FBGEMM
+} // namespace caffe2
+
namespace at {
namespace native {
@@ -127,13 +136,12 @@ Tensor fbgemm_linear_int8_weight(
auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
// Pull out the PackBMatrix instance from the owning tensor
- auto* packB = reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(
- packed.storage().data_ptr().get());
+ auto& packB = cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
// Do the GEMM
fbgemm::fbgemmPacked(
/*packA=*/packA,
- /*packB=*/*packB,
+ /*packB=*/packB,
/*C=*/output.data<float>(),
/*C_buffer=*/buffer.data<int32_t>(),
/*ldc=*/N,
@@ -233,7 +241,7 @@ Tensor fbgemm_pack_quantized_matrix(
AT_ASSERTM(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
auto weight_contig = weight.contiguous();
auto contiguous_ptr = weight_contig.data<int8_t>();
- auto* ptr = new fbgemm::PackBMatrix<int8_t>(
+ auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
/*trans=*/fbgemm::matrix_op_t::Transpose,
/*nRow=*/K,
/*nCol=*/N,
@@ -241,26 +249,7 @@ Tensor fbgemm_pack_quantized_matrix(
/*ld=*/K,
/*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
/*groups=*/1);
-
- // We store this instance away in a Tensor and register a deleter function
- // so that we do not leak memory. On the other side, we pull out the storage's
- // data_ptr and get the PackBMatrix's pointer.
- at::DataPtr at_ptr(
- ptr,
- ptr,
- [](void* ptr) {
- fbgemm::PackBMatrix<int8_t>* typed_ptr =
- reinterpret_cast<fbgemm::PackBMatrix<int8_t>*>(ptr);
- delete typed_ptr;
- },
- at::kCPU);
-
- auto retval = at::empty(
- {sizeof(fbgemm::PackBMatrix<int8_t>)}, weight.options().dtype(at::kByte));
-
- retval.storage().set_data_ptr(std::move(at_ptr));
-
- return retval;
+ return cpp_custom_type_hack::create(std::move(ptr));
}
#else // USE_FBGEMM