author    Jongsoo Park <jongsoo@fb.com>  2019-02-06 15:14:17 -0800
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-06 15:32:01 -0800
commit    8105aaca8611acd9e33707d6b57c6b1f144e4ab4 (patch)
tree      3a344ec0b3368f842655301b160a8a6d4973ac57
parent    30ab1773f9b0e32e3423c2f23b6561dac752252c (diff)
int8 SpatialBN (#16796)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16796

SpatialBN int8 version

Reviewed By: dskhudia

Differential Revision: D13971224

fbshipit-source-id: e55fd608c161069daaa4e62c618bc14b01f32cb7
-rw-r--r--  caffe2/quantization/server/CMakeLists.txt                        |   1
-rw-r--r--  caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc      | 141
-rw-r--r--  caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h       |  43
-rw-r--r--  caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py | 110
4 files changed, 295 insertions, 0 deletions
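
The new op folds batch normalization into a per-channel affine transform and evaluates it directly on quantized values. Below is a minimal standalone C++ sketch of that arithmetic, mirroring ComputeFusedParam_ and the inner loop of RunOnDevice in the diff that follows; the QParams struct, the quantization parameters, and all numeric values are illustrative assumptions, not part of the Caffe2/DNNLOWP API.

    // Sketch only: illustrates the fused-BN math of this commit, not the
    // Caffe2 operator interface. QParams and the constants are made up.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct QParams {
      float scale;
      int32_t zero_point;
    };

    int main() {
      // Per-channel BN inputs, all fp32 (matching the op's assumption).
      std::vector<float> scale = {1.0f, 0.5f};
      std::vector<float> bias  = {0.1f, -0.2f};
      std::vector<float> mean  = {0.3f, 0.4f};
      std::vector<float> var   = {1.0f, 0.25f};
      const double epsilon = 1e-5;
      const QParams in_q{0.5f, 10};   // hypothetical input quantization
      const QParams out_q{0.25f, 5};  // hypothetical output quantization

      const int C = static_cast<int>(scale.size());
      std::vector<float> alpha(C), beta(C);
      for (int c = 0; c < C; ++c) {
        // Fold BN into y = alpha[c] * x + beta[c] in real-valued terms...
        alpha[c] = scale[c] / std::sqrt(var[c] + epsilon);
        beta[c]  = bias[c] - alpha[c] * mean[c];
        // ...then absorb the input/output quantization scales so the same
        // affine transform maps quantized inputs to quantized outputs.
        alpha[c] *= in_q.scale / out_q.scale;
        beta[c]  /= out_q.scale;
      }

      // Apply the fused transform to one uint8 element of channel 0.
      const uint8_t x_q = 42;
      const long y_down = out_q.zero_point +
          std::lrintf(alpha[0] * (x_q - in_q.zero_point) + beta[0]);
      const uint8_t y_q =
          static_cast<uint8_t>(std::min(255L, std::max(0L, y_down)));
      std::printf("y_q = %u\n", static_cast<unsigned>(y_q));
      return 0;
    }

With exact fp32 statistics, the only error left is the rounding in lrintf plus the final 8-bit saturation, which is what the unit test at the bottom of this commit checks against the fp32 reference.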
diff --git a/caffe2/quantization/server/CMakeLists.txt b/caffe2/quantization/server/CMakeLists.txt
index 8aedc5a8f2..b21eab5332 100644
--- a/caffe2/quantization/server/CMakeLists.txt
+++ b/caffe2/quantization/server/CMakeLists.txt
@@ -37,6 +37,7 @@ list(APPEND Caffe2_CPU_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/quantize_dnnlowp_op.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/relu_dnnlowp_op.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/sigmoid_dnnlowp_op.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/spatial_batch_norm_dnnlowp_op.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/tanh_dnnlowp_op.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/utility_dnnlowp_ops.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/dynamic_histogram.cc"
diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc
new file mode 100644
index 0000000000..414089f1ce
--- /dev/null
+++ b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.cc
@@ -0,0 +1,141 @@
+#include "caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h"
+
+#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h"
+
+namespace caffe2 {
+
+template <typename T>
+SpatialBNDNNLowPOp<T>::SpatialBNDNNLowPOp(
+ const OperatorDef& operator_def,
+ Workspace* ws)
+ : DNNLowPOp<T, SpatialBNOp<CPUContext>>(operator_def, ws),
+ OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
+ order_(StringToStorageOrder(
+ this->template GetSingleArgument<std::string>("order", "NCHW"))) {
+ bool is_test = this->template GetSingleArgument<bool>("is_test", false);
+ OPERATOR_NEEDS_FEATURE(
+ is_test, "SpatialBN DNNLOWP op only works for inference.");
+ CAFFE_ENFORCE_NE(
+ order_,
+ StorageOrder::UNKNOWN,
+ "order should be either \"NCHW\" or \"NHWC\".");
+ CAFFE_ENFORCE(OutputSize() == 1);
+ CAFFE_ENFORCE_GT(epsilon_, 0);
+}
+
+template <typename T>
+void SpatialBNDNNLowPOp<T>::ComputeFusedParam_(
+ const int C,
+ const float* scale,
+ const float* bias,
+ const float* mean,
+ const float* var,
+ float* alpha,
+ float* beta) {
+ EigenVectorArrayMap<float> alpha_arr(alpha, C);
+ EigenVectorArrayMap<float> beta_arr(beta, C);
+ alpha_arr = ConstEigenVectorArrayMap<float>(scale, C) *
+ (ConstEigenVectorArrayMap<float>(var, C) + epsilon_).rsqrt();
+ beta_arr = ConstEigenVectorArrayMap<float>(bias, C) -
+ alpha_arr * ConstEigenVectorArrayMap<float>(mean, C);
+
+ // Adjust alpha and beta considering quantization scales
+ alpha_arr = alpha_arr * (in_qparams_[0].scale / out_qparams_.scale);
+ beta_arr = beta_arr / out_qparams_.scale;
+}
+
+template <typename T>
+bool SpatialBNDNNLowPOp<T>::RunOnDevice() {
+ const auto& X = InputTensorCPU_(INPUT);
+ const auto& scale = Input(SCALE);
+ const auto& bias = Input(BIAS);
+
+ const int ndim = X.dim();
+ CAFFE_ENFORCE_GE(ndim, 3);
+ const int N = X.dim32(0);
+ const int C = (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1));
+ const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
+ const int HxW =
+ std::accumulate(
+ X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies<int>()) /
+ C;
+ CAFFE_ENFORCE_EQ(scale.numel(), C);
+ CAFFE_ENFORCE_EQ(bias.numel(), C);
+
+ GetOutputQuantizationParams_();
+
+ in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
+
+ const float* scale_data = scale.template data<float>();
+ const float* bias_data = bias.template data<float>();
+ ReinitializeTensor(
+ &alpha_, {C}, at::dtype<float>().device(CPUContext::GetDeviceType()));
+ ReinitializeTensor(
+ &beta_, {C}, at::dtype<float>().device(CPUContext::GetDeviceType()));
+ float* alpha_data = alpha_.template mutable_data<float>();
+ float* beta_data = beta_.template mutable_data<float>();
+ if (N == 0) {
+ return true;
+ }
+ const auto& mean = Input(EST_MEAN);
+ const auto& var = Input(EST_VAR);
+ CAFFE_ENFORCE_EQ(mean.numel(), C);
+ CAFFE_ENFORCE_EQ(var.numel(), C);
+ ComputeFusedParam_(
+ C,
+ scale_data,
+ bias_data,
+ mean.template data<float>(),
+ var.template data<float>(),
+ alpha_data,
+ beta_data);
+
+ vector<T> X_temp;
+ const T* X_data =
+ dnnlowp::QuantizeInputIfNeeded(this, 0, in_qparams_[0], X_temp);
+ auto* Y = OutputTensorCPU_(OUTPUT);
+ Y->Resize(X.sizes());
+ T* Y_data = GetQuantizedOutputData_();
+
+ if (order_ == StorageOrder::NCHW) {
+ for (int c = 0; c < C; ++c) {
+ for (int i = 0; i < N; ++i) {
+ for (int j = 0; j < HxW; ++j) {
+ long quantized_down = out_qparams_.zero_point +
+ std::lrintf(alpha_data[c] *
+ (X_data[(i * C + c) * HxW + j] -
+ in_qparams_[0].zero_point) +
+ beta_data[c]);
+ Y_data[(i * C + c) * HxW + j] =
+ fbgemm::clamp<long, T>(quantized_down, 8);
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < N * HxW; ++i) {
+ for (int c = 0; c < C; ++c) {
+ long quantized_down = out_qparams_.zero_point +
+ std::lrintf(alpha_data[c] *
+ (X_data[i * C + c] - in_qparams_[0].zero_point) +
+ beta_data[c]);
+ Y_data[i * C + c] = fbgemm::clamp<long, T>(quantized_down, 8);
+ }
+ }
+ }
+
+ RunOnDeviceEpilogue_();
+
+ return true;
+}
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+ SpatialBN,
+ DNNLOWP,
+ SpatialBNDNNLowPOp<uint8_t>);
+
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+ Int8SpatialBN,
+ DNNLOWP,
+ SpatialBNDNNLowPOp<uint8_t>);
+
+} // namespace caffe2
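
Both layout branches above round in fp32, widen to long, and then saturate into the 8-bit output type with fbgemm::clamp<long, T>(quantized_down, 8). Widening matters: alpha * (x - zero_point) + beta can land well outside the representable range and must clip rather than wrap. A rough stand-in for that saturation, assuming an unsigned 8-bit output (a sketch of the intended behavior, not fbgemm's implementation):

    #include <algorithm>
    #include <cstdint>

    // Assumed behavior only: saturate a widened intermediate into the
    // [0, 2^precision - 1] range of an unsigned quantized type.
    inline uint8_t clamp_u8(long src, int precision = 8) {
      const long hi = (1L << precision) - 1;  // 255 when precision == 8
      return static_cast<uint8_t>(std::min(hi, std::max(0L, src)));
    }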
diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h
new file mode 100644
index 0000000000..076e91d41e
--- /dev/null
+++ b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op.h
@@ -0,0 +1,43 @@
+#pragma once
+
+#include "caffe2/operators/spatial_batch_norm_op.h"
+#include "caffe2/quantization/server/dnnlowp_op.h"
+
+namespace caffe2 {
+
+/**
+ * Note: this implementation assumes the SCALE, BIAS, EST_MEAN, and EST_VAR
+ * inputs are still in fp32, as is the epsilon argument.
+ */
+template <typename T>
+class SpatialBNDNNLowPOp final : public DNNLowPOp<T, SpatialBNOp<CPUContext>> {
+ public:
+ USE_OPERATOR_FUNCTIONS(CPUContext);
+ USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, SpatialBNOp<CPUContext>);
+ SpatialBNDNNLowPOp(const OperatorDef& operator_def, Workspace* ws);
+
+ ~SpatialBNDNNLowPOp() override = default;
+
+ bool RunOnDevice() override;
+
+ private:
+ void ComputeFusedParam_(
+ const int C,
+ const float* scale,
+ const float* bias,
+ const float* mean,
+ const float* var,
+ float* alpha,
+ float* beta);
+
+ double epsilon_;
+ const StorageOrder order_;
+
+ Tensor alpha_;
+ Tensor beta_;
+
+ INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR);
+ OUTPUT_TAGS(OUTPUT);
+};
+
+} // namespace caffe2
diff --git a/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py
new file mode 100644
index 0000000000..5dbcfd7b10
--- /dev/null
+++ b/caffe2/quantization/server/spatial_batch_norm_dnnlowp_op_test.py
@@ -0,0 +1,110 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import collections
+
+import caffe2.python.hypothesis_test_util as hu
+import hypothesis.strategies as st
+import numpy as np
+from caffe2.python import core, dyndep, utils, workspace
+from caffe2.quantization.server import utils as dnnlowp_utils
+from dnnlowp_test_utils import check_quantized_results_close
+from hypothesis import given
+
+
+dyndep.InitOpsLibrary("//caffe2/caffe2/quantization/server:dnnlowp_ops")
+workspace.GlobalInit(["caffe2", "--caffe2_omp_num_threads=11"])
+
+
+class DNNLowPOpSpatialBNTest(hu.HypothesisTestCase):
+ # correctness test with no quantization error in inputs
+ @given(
+ size=st.integers(10, 16),
+ input_channels=st.integers(2, 16),
+ output_channels=st.integers(2, 16),
+ batch_size=st.integers(1, 3),
+ order=st.sampled_from(["NCHW", "NHWC"]),
+ in_quantized=st.booleans(),
+ out_quantized=st.booleans(),
+ **hu.gcs_cpu_only
+ )
+ def test_dnnlowp_spatial_bn_int(
+ self,
+ size,
+ input_channels,
+ output_channels,
+ batch_size,
+ order,
+ in_quantized,
+ out_quantized,
+ gc,
+ dc,
+ ):
+ X_min = -77
+ X_max = X_min + 255
+ X = np.round(np.random.rand(batch_size, size, size, input_channels)).astype(
+ np.float32
+ )
+ X[0, 0, 0, 0] = X_min
+ X[0, 0, 0, 1] = X_max
+
+ epsilon = np.abs(np.random.rand())
+ scale = np.random.rand(input_channels).astype(np.float32)
+ bias = np.random.rand(input_channels).astype(np.float32)
+ mean = np.random.rand(input_channels).astype(np.float32)
+ var = np.random.rand(input_channels).astype(np.float32)
+
+ if order == "NCHW":
+ X = utils.NHWC2NCHW(X)
+
+ Output = collections.namedtuple("Output", ["Y", "op_type", "engine"])
+ outputs = []
+
+ op_engine_list = [
+ ("SpatialBN", ""),
+ ("SpatialBN", "DNNLOWP"),
+ ("Int8SpatialBN", "DNNLOWP"),
+ ]
+
+ for op_type, engine in op_engine_list:
+ net = core.Net("test_net")
+
+ do_quantize = "DNNLOWP" in engine and in_quantized
+ do_dequantize = "DNNLOWP" in engine and out_quantized
+
+ if do_quantize:
+ quantize = core.CreateOperator(
+ "Quantize", ["X"], ["X_q"], engine=engine
+ )
+ net.Proto().op.extend([quantize])
+
+ bn = core.CreateOperator(
+ op_type,
+ ["X_q" if do_quantize else "X", "scale", "bias", "mean", "var"],
+ ["Y_q" if do_dequantize else "Y"],
+ is_test=True,
+ epsilon=epsilon,
+ order=order,
+ engine=engine,
+ dequantize_output=not do_dequantize,
+ )
+ net.Proto().op.extend([bn])
+ if "DNNLOWP" in engine:
+ dnnlowp_utils.add_quantization_param_args(bn, outputs[0][0])
+
+ if do_dequantize:
+ dequantize = core.CreateOperator(
+ "Dequantize", ["Y_q"], ["Y"], engine=engine
+ )
+ net.Proto().op.extend([dequantize])
+
+ self.ws.create_blob("X").feed(X, device_option=gc)
+ self.ws.create_blob("scale").feed(scale, device_option=gc)
+ self.ws.create_blob("bias").feed(bias, device_option=gc)
+ self.ws.create_blob("mean").feed(mean, device_option=gc)
+ self.ws.create_blob("var").feed(var, device_option=gc)
+ self.ws.run(net)
+ outputs.append(
+ Output(Y=self.ws.blobs["Y"].fetch(), op_type=op_type, engine=engine)
+ )
+
+ check_quantized_results_close(outputs)
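
The test runs identical inputs through fp32 SpatialBN, the DNNLOWP-engine SpatialBN, and Int8SpatialBN, optionally wrapping the quantized runs in explicit Quantize/Dequantize ops, and checks with check_quantized_results_close that the quantized outputs track the fp32 reference. The fp32 output of the first run also seeds the output quantization parameters via add_quantization_param_args, so both quantized engines share a consistent output scale.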