path: root/caffe2/operators/quantized
author     Marat Dukhan <marat@fb.com>  2018-10-25 12:38:35 -0700
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2018-10-25 12:43:00 -0700
commit     9cb4bce847af688ca344caaa876d3c47a36f28bf (patch)
tree       bb08da3db0087d543417e81b9b2d42f87c325672 /caffe2/operators/quantized
parent     faa354e10243e9e6d6428e1bd0f317bb226c42bd (diff)
Open-source Caffe2 Int8 ops (#13065)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13065

- Open-source Caffe2 Int8 (quantized) operators

Reviewed By: Yangqing

Differential Revision: D10524381

fbshipit-source-id: 6daa153dc247572900c91e37262d033c368b382d
Diffstat (limited to 'caffe2/operators/quantized')
-rw-r--r--  caffe2/operators/quantized/init_qnnpack.cc | 17
-rw-r--r--  caffe2/operators/quantized/int8_add_op.cc | 68
-rw-r--r--  caffe2/operators/quantized/int8_add_op.h | 222
-rw-r--r--  caffe2/operators/quantized/int8_average_pool_op.cc | 65
-rw-r--r--  caffe2/operators/quantized/int8_average_pool_op.h | 197
-rw-r--r--  caffe2/operators/quantized/int8_channel_shuffle_op.cc | 14
-rw-r--r--  caffe2/operators/quantized/int8_channel_shuffle_op.h | 164
-rw-r--r--  caffe2/operators/quantized/int8_concat_op.cc | 22
-rw-r--r--  caffe2/operators/quantized/int8_concat_op.h | 90
-rw-r--r--  caffe2/operators/quantized/int8_conv_op.cc | 81
-rw-r--r--  caffe2/operators/quantized/int8_conv_op.h | 171
-rw-r--r--  caffe2/operators/quantized/int8_conv_transpose_op.cc | 49
-rw-r--r--  caffe2/operators/quantized/int8_conv_transpose_op.h | 169
-rw-r--r--  caffe2/operators/quantized/int8_dequantize_op.cc | 14
-rw-r--r--  caffe2/operators/quantized/int8_dequantize_op.h | 52
-rw-r--r--  caffe2/operators/quantized/int8_fc_op.cc | 41
-rw-r--r--  caffe2/operators/quantized/int8_fc_op.h | 133
-rw-r--r--  caffe2/operators/quantized/int8_flatten_op.cc | 30
-rw-r--r--  caffe2/operators/quantized/int8_flatten_op.h | 47
-rw-r--r--  caffe2/operators/quantized/int8_given_tensor_fill_op.cc | 32
-rw-r--r--  caffe2/operators/quantized/int8_given_tensor_fill_op.h | 114
-rw-r--r--  caffe2/operators/quantized/int8_leaky_relu_op.cc | 24
-rw-r--r--  caffe2/operators/quantized/int8_leaky_relu_op.h | 64
-rw-r--r--  caffe2/operators/quantized/int8_max_pool_op.cc | 63
-rw-r--r--  caffe2/operators/quantized/int8_max_pool_op.h | 183
-rw-r--r--  caffe2/operators/quantized/int8_quantize_op.cc | 16
-rw-r--r--  caffe2/operators/quantized/int8_quantize_op.h | 91
-rw-r--r--  caffe2/operators/quantized/int8_relu_op.cc | 37
-rw-r--r--  caffe2/operators/quantized/int8_relu_op.h | 43
-rw-r--r--  caffe2/operators/quantized/int8_reshape_op.cc | 31
-rw-r--r--  caffe2/operators/quantized/int8_reshape_op.h | 47
-rw-r--r--  caffe2/operators/quantized/int8_resize_nearest_op.cc | 25
-rw-r--r--  caffe2/operators/quantized/int8_resize_nearest_op.h | 72
-rw-r--r--  caffe2/operators/quantized/int8_roi_align_op.cc | 45
-rw-r--r--  caffe2/operators/quantized/int8_roi_align_op.h | 341
-rw-r--r--  caffe2/operators/quantized/int8_roi_align_op_test.cc | 62
-rw-r--r--  caffe2/operators/quantized/int8_simd.h | 63
-rw-r--r--  caffe2/operators/quantized/int8_slice_op.cc | 44
-rw-r--r--  caffe2/operators/quantized/int8_slice_op.h | 71
-rw-r--r--  caffe2/operators/quantized/int8_softmax_op.cc | 46
-rw-r--r--  caffe2/operators/quantized/int8_softmax_op.h | 227
-rw-r--r--  caffe2/operators/quantized/int8_test.cc | 858
-rw-r--r--  caffe2/operators/quantized/int8_test_utils.h | 118
-rw-r--r--  caffe2/operators/quantized/int8_utils.h | 177
44 files changed, 4540 insertions, 0 deletions
diff --git a/caffe2/operators/quantized/init_qnnpack.cc b/caffe2/operators/quantized/init_qnnpack.cc
new file mode 100644
index 0000000000..8b1356b366
--- /dev/null
+++ b/caffe2/operators/quantized/init_qnnpack.cc
@@ -0,0 +1,17 @@
+#include <mutex>
+
+#include <qnnpack.h>
+
+#include "caffe2/core/logging.h"
+
+namespace caffe2 {
+
+void initQNNPACK() {
+ static std::once_flag once;
+ static enum qnnp_status qnnpackStatus = qnnp_status_uninitialized;
+ std::call_once(once, []() { qnnpackStatus = qnnp_initialize(); });
+ CAFFE_ENFORCE(
+ qnnpackStatus == qnnp_status_success, "failed to initialize QNNPACK");
+}
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_add_op.cc b/caffe2/operators/quantized/int8_add_op.cc
new file mode 100644
index 0000000000..225a7a1f6b
--- /dev/null
+++ b/caffe2/operators/quantized/int8_add_op.cc
@@ -0,0 +1,68 @@
+#include <climits>
+
+#include "caffe2/operators/quantized/int8_add_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Add, int8::Int8AddOp<int8::Activation::NONE>);
+REGISTER_CPU_OPERATOR(Int8AddRelu, int8::Int8AddOp<int8::Activation::RELU>);
+
+REGISTER_CPU_OPERATOR(Int8Sum, int8::Int8AddOp<int8::Activation::NONE>);
+REGISTER_CPU_OPERATOR(Int8SumRelu, int8::Int8AddOp<int8::Activation::RELU>);
+
+OPERATOR_SCHEMA(Int8Add)
+ .NumInputs(2)
+ .NumOutputs(1)
+ .AllowInplace({{0, 0}, {1, 0}})
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .SetDoc(R"DOC(
+ Performs element-wise binary Add (with no broadcast support).
+)DOC")
+ .Input(
+ 0,
+ "A",
+ "First operand, should share the type with the second operand.")
+ .Input(1, "B", "Second operand. It should be of the same size as A.")
+ .Output(0, "C", "Result, has same dimensions and type as A");
+
+OPERATOR_SCHEMA(Int8AddRelu)
+ .NumInputs(2)
+ .NumOutputs(1)
+ .AllowInplace({{0, 0}, {1, 0}})
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .SetDoc(R"DOC(
+ Performs element-wise binary Add (with no broadcast support).
+ Output will go through rectified linear
+ function, where y = max(0, x).
+)DOC")
+ .Input(
+ 0,
+ "A",
+ "First operand, should share the type with the second operand.")
+ .Input(1, "B", "Second operand. It should be of the same size as A.")
+ .Output(0, "C", "Result, has same dimensions and type as A");
+
+/*
+ * These ops are defined as alias of Int8Add/Int8AddRelu for compatibility
+ * with current production models. In the future these ops will be changed
+ * to an equivalent of Sum op, which does reduction of a single argument.
+ * We deliberately omit the schema for Int8Sum/Int8SumRelu so they can
+ * temporarily use either the legacy or the new semantics depending on the engine.
+ */
+OPERATOR_SCHEMA(Int8Sum)
+ .NumInputs(1, std::numeric_limits<int>::max())
+ .NumOutputs(1)
+ .AllowInplace({{0, 0}, {1, 0}})
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset");
+
+OPERATOR_SCHEMA(Int8SumRelu)
+ .NumInputs(1, std::numeric_limits<int>::max())
+ .NumOutputs(1)
+ .AllowInplace({{0, 0}, {1, 0}})
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_add_op.h b/caffe2/operators/quantized/int8_add_op.h
new file mode 100644
index 0000000000..b847b012c0
--- /dev/null
+++ b/caffe2/operators/quantized/int8_add_op.h
@@ -0,0 +1,222 @@
+#ifndef CAFFE2_OPERATORS_INT8_ADD_OP_H_
+#define CAFFE2_OPERATORS_INT8_ADD_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+/*
+ * Implementation based on TensorFlow Lite kernels:
+ * - Repo: https://github.com/tensorflow/tensorflow
+ * - Path: tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+ * - Hash: d4ad9c73969c45d1a224ebfc43eb645b9860216b
+ */
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+constexpr size_t kAddLeftShift = 20;
+
+void Int8Add(
+ const uint8_t* X0_data,
+ size_t N,
+ size_t D,
+ int32_t X0_offset,
+ int32_t X0_multiplier,
+ int X0_shift,
+ const uint8_t* X1_data,
+ int32_t X1_offset,
+ int32_t X1_multiplier,
+ int X1_shift,
+ int32_t Y_offset,
+ int32_t Y_multiplier,
+ int Y_shift,
+ uint8_t* Y_data,
+ uint8_t Y_activation_min,
+ uint8_t Y_activation_max,
+ C2GEMMContext* gemm_context) {
+ CHECK_GT(X0_offset, -256);
+ CHECK_GT(X1_offset, -256);
+ CHECK_LT(X0_offset, 256);
+ CHECK_LT(X1_offset, 256);
+ static_assert(kAddLeftShift > 5, "");
+ static_assert(kAddLeftShift <= 20, "");
+
+ auto f = [&](int, size_t n) {
+ size_t d = 0;
+
+#ifdef INT8_NEON_SIMD
+ constexpr size_t kIntermediateAddLeftShift = 4;
+ const auto X0_offset_val =
+ vshlq_n_s16(vdupq_n_s16(X0_offset), kIntermediateAddLeftShift);
+ const auto X1_offset_val =
+ vshlq_n_s16(vdupq_n_s16(X1_offset), kIntermediateAddLeftShift);
+ const auto X0_shift_dup = vdupq_n_s32(-X0_shift);
+ const auto X1_shift_dup = vdupq_n_s32(-X1_shift);
+ const auto DUnroll = (D / 8) * 8;
+
+ for (; d < DUnroll; d += 8) {
+ const auto X0_val_original = vld1_u8(X0_data + n * D + d);
+ const auto X1_val_original = vld1_u8(X1_data + n * D + d);
+
+ // Load input
+ // Widen to int16
+ // Add int16 offset.
+ // Widen to int32
+ // Shift left by 20.
+ // Alternatively, we can widening-shift X by 4, shift X0_offset by 4,
+ // add, then shift by 16. Safe as X << 5 + X_offset << 5 can't overflow
+ // uint16, as X ~ 8 bit, X_offset ~ 10 bit, so 15 bits total from X +
+ // X_offset
+ const auto X0_val_s16 = vreinterpretq_s16_u16(
+ vshll_n_u8(X0_val_original, kIntermediateAddLeftShift));
+ const auto X1_val_s16 = vreinterpretq_s16_u16(
+ vshll_n_u8(X1_val_original, kIntermediateAddLeftShift));
+ const auto X0_val = vaddq_s16(X0_val_s16, X0_offset_val);
+ const auto X1_val = vaddq_s16(X1_val_s16, X1_offset_val);
+ const auto X0_val_high = vget_high_s16(X0_val);
+ const auto X0_val_low = vget_low_s16(X0_val);
+ const auto X1_val_high = vget_high_s16(X1_val);
+ const auto X1_val_low = vget_low_s16(X1_val);
+ auto x11 =
+ vshll_n_s16(X0_val_low, kAddLeftShift - kIntermediateAddLeftShift);
+ auto x12 =
+ vshll_n_s16(X0_val_high, kAddLeftShift - kIntermediateAddLeftShift);
+ auto x21 =
+ vshll_n_s16(X1_val_low, kAddLeftShift - kIntermediateAddLeftShift);
+ auto x22 =
+ vshll_n_s16(X1_val_high, kAddLeftShift - kIntermediateAddLeftShift);
+ x11 = vqrdmulhq_n_s32(x11, X0_multiplier);
+ x12 = vqrdmulhq_n_s32(x12, X0_multiplier);
+ x21 = vqrdmulhq_n_s32(x21, X1_multiplier);
+ x22 = vqrdmulhq_n_s32(x22, X1_multiplier);
+ x11 = vshlq_s32(x11, X0_shift_dup);
+ x12 = vshlq_s32(x12, X0_shift_dup);
+ x21 = vshlq_s32(x21, X1_shift_dup);
+ x22 = vshlq_s32(x22, X1_shift_dup);
+ auto s1 = vaddq_s32(x11, x21);
+ auto s2 = vaddq_s32(x12, x22);
+ s1 = vqrdmulhq_n_s32(s1, Y_multiplier);
+ s2 = vqrdmulhq_n_s32(s2, Y_multiplier);
+ using gemmlowp::RoundingDivideByPOT;
+ s1 = RoundingDivideByPOT(s1, Y_shift);
+ s2 = RoundingDivideByPOT(s2, Y_shift);
+ const auto s1_narrowed = vmovn_s32(s1);
+ const auto s2_narrowed = vmovn_s32(s2);
+ const auto s = vaddq_s16(
+ vcombine_s16(s1_narrowed, s2_narrowed), vdupq_n_s16(Y_offset));
+ auto ss = vqmovun_s16(s);
+ ss = vmin_u8(ss, vdup_n_u8(Y_activation_max));
+ ss = vmax_u8(ss, vdup_n_u8(Y_activation_min));
+ vst1_u8(Y_data + n * D + d, ss);
+ }
+#endif // NEON
+
+ for (; d < D; d++) {
+ const int32_t X0_val = X0_offset + X0_data[n * D + d];
+ const int32_t X1_val = X1_offset + X1_data[n * D + d];
+ const int32_t shifted_X0_val = X0_val * (1 << kAddLeftShift);
+ const int32_t shifted_X1_val = X1_val * (1 << kAddLeftShift);
+ const int32_t scaled_X0_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_X0_val, X0_multiplier, X0_shift);
+ const int32_t scaled_X1_val = MultiplyByQuantizedMultiplierSmallerThanOne(
+ shifted_X1_val, X1_multiplier, X1_shift);
+ const int32_t raw_sum = scaled_X0_val + scaled_X1_val;
+ const int32_t raw_Y = MultiplyByQuantizedMultiplierSmallerThanOne(
+ raw_sum, Y_multiplier, Y_shift) +
+ Y_offset;
+ const int32_t clamped_Y = std::min<int32_t>(
+ Y_activation_max, std::max<int32_t>(Y_activation_min, raw_Y));
+ Y_data[n * D + d] = static_cast<uint8_t>(clamped_Y);
+ }
+ };
+ gemm_context->threadPool()->run(f, N);
+}
+
+} // namespace
+
+template <Activation Ac>
+class Int8AddOp final : public Operator<CPUContext> {
+ public:
+ Int8AddOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws),
+ gemm_context_(ws->GetThreadPool()) {}
+
+ bool RunOnDevice() override {
+ CAFFE_ENFORCE_EQ(Inputs().size(), 2);
+ const auto& X0 = Inputs()[0]->template Get<Int8TensorCPU>();
+ const auto& X1 = Inputs()[1]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ auto X0_offset = -X0.zero_point;
+ auto X1_offset = -X1.zero_point;
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ const double twice_max_input_scale = 2 * std::max(X0.scale, X1.scale);
+ const double real_X0_multiplier = X0.scale / twice_max_input_scale;
+ const double real_X1_multiplier = X1.scale / twice_max_input_scale;
+ const double real_Y_multiplier =
+ twice_max_input_scale / ((1 << kAddLeftShift) * Y_scale);
+
+ Y->t.ResizeLike(X0.t);
+ Y->zero_point = Y_offset;
+ Y->scale = Y_scale;
+
+ int32_t X0_multiplier;
+ int X0_shift;
+ QuantizeMultiplierSmallerThanOne(
+ real_X0_multiplier, &X0_multiplier, &X0_shift);
+ int32_t X1_multiplier;
+ int X1_shift;
+ QuantizeMultiplierSmallerThanOne(
+ real_X1_multiplier, &X1_multiplier, &X1_shift);
+ int32_t Y_multiplier;
+ int Y_shift;
+ QuantizeMultiplierSmallerThanOne(
+ real_Y_multiplier, &Y_multiplier, &Y_shift);
+
+ Int8Add(
+ X0.t.template data<uint8_t>(),
+ X0.t.size() / X0.t.dim(X0.t.ndim() - 1),
+ X0.t.dim(X0.t.ndim() - 1),
+ X0_offset,
+ X0_multiplier,
+ X0_shift,
+ X1.t.template data<uint8_t>(),
+ X1_offset,
+ X1_multiplier,
+ X1_shift,
+ Y_offset,
+ Y_multiplier,
+ Y_shift,
+ Y->t.template mutable_data<uint8_t>(),
+ activationLimits(Y->scale, Y->zero_point, Ac).first,
+ activationLimits(Y->scale, Y->zero_point, Ac).second,
+ &gemm_context_);
+ return true;
+ }
+
+ private:
+ C2GEMMContext gemm_context_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_ADD_OP_H_
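For reference, the fixed-point pipeline above (shift the offset-adjusted inputs left by kAddLeftShift, scale each by a quantized multiplier, add, then requantize and clamp) approximates the plain floating-point computation below. This is a standalone illustrative sketch, not code from the patch; the function name is made up.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Floating-point reference for one element of Int8Add: dequantize both
// inputs, add, requantize into the output's (scale, zero_point), and clamp
// to the activation range. The NEON and scalar paths above compute the same
// result in fixed-point arithmetic.
inline uint8_t Int8AddReference(
    uint8_t x0, float x0_scale, int32_t x0_zero_point,
    uint8_t x1, float x1_scale, int32_t x1_zero_point,
    float y_scale, int32_t y_zero_point,
    uint8_t y_min, uint8_t y_max) {
  const float sum = (static_cast<int32_t>(x0) - x0_zero_point) * x0_scale +
                    (static_cast<int32_t>(x1) - x1_zero_point) * x1_scale;
  const int32_t y =
      static_cast<int32_t>(std::lrintf(sum / y_scale)) + y_zero_point;
  return static_cast<uint8_t>(
      std::min<int32_t>(y_max, std::max<int32_t>(y_min, y)));
}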
diff --git a/caffe2/operators/quantized/int8_average_pool_op.cc b/caffe2/operators/quantized/int8_average_pool_op.cc
new file mode 100644
index 0000000000..70df83f6fc
--- /dev/null
+++ b/caffe2/operators/quantized/int8_average_pool_op.cc
@@ -0,0 +1,65 @@
+#include "caffe2/operators/quantized/int8_average_pool_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(
+ Int8AveragePool,
+ int8::Int8AveragePoolOp<int8::Activation::NONE>);
+REGISTER_CPU_OPERATOR(
+ Int8AveragePoolRelu,
+ int8::Int8AveragePoolOp<int8::Activation::RELU>);
+
+const char kAveragePoolDoc_int8[] = R"DOC(
+consumes an input blob X and applies average pooling across
+the blob according to kernel sizes, stride sizes, and pad lengths defined by the
+ConvPoolOpBase operator. Average pooling consists of averaging all values of a
+subset of the input tensor according to the kernel size and downsampling the
+data into the output blob Y for further processing.
+)DOC";
+
+std::function<void(OpSchema&)> AveragePoolDocGenerator(
+ const char* dim,
+ bool relu_fused = false) {
+ auto suffix = relu_fused ? " Output will go through rectified linear "
+ "function, where y = max(0, x)."
+ : "";
+ return [=](OpSchema& schema) {
+ string doc = "AveragePool{dim} {pool_doc}";
+ c10::ReplaceAll(doc, "{dim}", dim);
+ c10::ReplaceAll(doc, "{pool_doc}", kAveragePoolDoc_int8);
+ schema.SetDoc(doc);
+ string output_doc =
+ "Output data tensor from average pooling across the input "
+ "tensor. Dimensions will vary based on various kernel, stride, and pad "
+ "sizes.{suffix}";
+ c10::ReplaceAll(output_doc, "{suffix}", suffix);
+ schema.Input(
+ 0,
+ "X",
+ "Input data tensor from the previous operator; dimensions depend on "
+ "whether the NCHW or NHWC operators are being used. For example, in "
+ "the former, the input has size (N x C x H x W), where N is the batch "
+ "size, C is the number of channels, and H and W are the height and the "
+ "width of the data. The corresponding permutation of dimensions is "
+ "used in the latter case.");
+ schema.Output(0, "Y", output_doc.c_str());
+ };
+}
+
+OPERATOR_SCHEMA(Int8AveragePool)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
+ .FillUsing(AveragePoolDocGenerator(""));
+
+OPERATOR_SCHEMA(Int8AveragePoolRelu)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
+ .FillUsing(AveragePoolDocGenerator("", true));
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_average_pool_op.h b/caffe2/operators/quantized/int8_average_pool_op.h
new file mode 100644
index 0000000000..c8bcb66d93
--- /dev/null
+++ b/caffe2/operators/quantized/int8_average_pool_op.h
@@ -0,0 +1,197 @@
+#ifndef CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_
+#define CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+/*
+ * Implementation based on TensorFlow Lite kernels:
+ * - Repo: https://github.com/tensorflow/tensorflow
+ * - Path: tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+ * - Hash: d4ad9c73969c45d1a224ebfc43eb645b9860216b
+ */
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+void Int8AveragePool(
+ const uint8_t* input_data,
+ at::IntList input_dims,
+ int stride_width,
+ int stride_height,
+ int pad_width,
+ int pad_height,
+ int filter_width,
+ int filter_height,
+ uint8_t* output_data,
+ at::IntList output_dims,
+ uint8_t output_activation_min,
+ uint8_t output_activation_max) {
+ DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = input_dims[0];
+ const int depth = input_dims[3];
+ const int input_height = input_dims[1];
+ const int input_width = input_dims[2];
+ const int output_height = output_dims[1];
+ const int output_width = output_dims[2];
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ const int filter_count =
+ (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start);
+ // 1280 required by Inception v3
+ static constexpr int kAccBufferMaxSize = 2048;
+ DCHECK_LE(depth, kAccBufferMaxSize);
+ uint16_t acc[kAccBufferMaxSize];
+ memset(acc, 0, depth * sizeof(acc[0]));
+ const uint8_t* input_ptr = input_data + depth * in_x_origin +
+ depth * input_width * in_y_origin +
+ depth * input_width * input_height * batch;
+ for (int fy = filter_y_start; fy < filter_y_end; fy++) {
+ const uint8_t* input_row_ptr =
+ input_ptr + fy * input_width * depth + filter_x_start * depth;
+ for (int fx = filter_x_start; fx < filter_x_end; fx++) {
+ int channel = 0;
+#ifdef INT8_NEON_SIMD
+ for (; channel <= depth - 16; channel += 16) {
+ uint16x8_t acc_reg[2];
+ for (int i = 0; i < 2; i++) {
+ acc_reg[i] = vld1q_u16(acc + channel + 8 * i);
+ }
+ uint8x16_t input_reg = vld1q_u8(input_row_ptr);
+ input_row_ptr += 16;
+ acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg));
+ acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg));
+ for (int i = 0; i < 2; i++) {
+ vst1q_u16(acc + channel + 8 * i, acc_reg[i]);
+ }
+ }
+ for (; channel <= depth - 8; channel += 8) {
+ uint16x8_t acc_reg = vld1q_u16(acc + channel);
+ uint8x8_t input_reg = vld1_u8(input_row_ptr);
+ input_row_ptr += 8;
+ acc_reg = vaddw_u8(acc_reg, input_reg);
+ vst1q_u16(acc + channel, acc_reg);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ acc[channel] += *input_row_ptr++;
+ }
+ }
+ }
+ uint8_t* output_ptr = output_data + out_x * depth +
+ out_y * depth * output_width +
+ batch * depth * output_width * output_height;
+ int channel = 0;
+#ifdef INT8_NEON_SIMD
+#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \
+ if (filter_count == FILTER_COUNT) { \
+ for (; channel <= depth - 8; channel += 8) { \
+ uint16_t buf[8]; \
+ for (int i = 0; i < 8; i++) { \
+ buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \
+ } \
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \
+ buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \
+ buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \
+ vst1_u8(output_ptr + channel, buf8); \
+ } \
+ }
+ AVGPOOL_DIVIDING_BY(9)
+ AVGPOOL_DIVIDING_BY(15)
+#undef AVGPOOL_DIVIDING_BY
+ for (; channel <= depth - 8; channel += 8) {
+ uint16_t buf[8];
+ for (int i = 0; i < 8; i++) {
+ buf[i] = (acc[channel + i] + filter_count / 2) / filter_count;
+ }
+ uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf));
+ buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max));
+ buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min));
+ vst1_u8(output_ptr + channel, buf8);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ uint16_t a = (acc[channel] + filter_count / 2) / filter_count;
+ a = std::max<uint16_t>(a, output_activation_min);
+ a = std::min<uint16_t>(a, output_activation_max);
+ output_ptr[channel] = static_cast<uint8_t>(a);
+ }
+ }
+ }
+ }
+}
+
+} // namespace
+
+template <Activation Ac>
+class Int8AveragePoolOp final : public ConvPoolOpBase<CPUContext> {
+ public:
+ Int8AveragePoolOp(const OperatorDef& operator_def, Workspace* ws)
+ : ConvPoolOpBase<CPUContext>(operator_def, ws) {
+ OPERATOR_NEEDS_FEATURE(
+ this->order_ == StorageOrder::NHWC, "Int8 only supports NHWC order.");
+ }
+
+ bool RunOnDeviceWithOrderNHWC() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->scale = X.scale;
+ Y->zero_point = X.zero_point;
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+
+ CHECK_EQ(X.t.ndim(), 4);
+ const int channels = X.t.dim32(3);
+ ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);
+
+ Int8AveragePool(
+ X.t.template data<uint8_t>(),
+ X.t.sizes(),
+ stride_w(),
+ stride_h(),
+ pad_l(),
+ pad_t(),
+ kernel_w(),
+ kernel_h(),
+ Y->t.template mutable_data<uint8_t>(),
+ Y->t.sizes(),
+ activationLimits(Y->scale, Y->zero_point, Ac).first,
+ activationLimits(Y->scale, Y->zero_point, Ac).second);
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_
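The scalar tail above averages with rounding to nearest via (acc + filter_count / 2) / filter_count before clamping. A standalone sketch of just that step (illustrative, not part of the patch):

#include <algorithm>
#include <cstdint>

// Rounded integer average of 'count' uint8 samples accumulated in 'acc',
// clamped to the activation range, as in the scalar tail loop above.
inline uint8_t RoundedAverage(
    uint32_t acc, uint32_t count, uint8_t lo, uint8_t hi) {
  const uint32_t avg = (acc + count / 2) / count;
  return static_cast<uint8_t>(
      std::min<uint32_t>(hi, std::max<uint32_t>(lo, avg)));
}

// Example: acc = 10, count = 4 gives (10 + 2) / 4 = 3.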
diff --git a/caffe2/operators/quantized/int8_channel_shuffle_op.cc b/caffe2/operators/quantized/int8_channel_shuffle_op.cc
new file mode 100644
index 0000000000..a133e2bf74
--- /dev/null
+++ b/caffe2/operators/quantized/int8_channel_shuffle_op.cc
@@ -0,0 +1,14 @@
+#include "caffe2/operators/quantized/int8_channel_shuffle_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8ChannelShuffle, int8::Int8ChannelShuffleOp);
+
+OPERATOR_SCHEMA(Int8ChannelShuffle)
+ .IdenticalTypeAndShape()
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .NumInputs(1)
+ .NumOutputs(1);
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_channel_shuffle_op.h b/caffe2/operators/quantized/int8_channel_shuffle_op.h
new file mode 100644
index 0000000000..4589c7d9e8
--- /dev/null
+++ b/caffe2/operators/quantized/int8_channel_shuffle_op.h
@@ -0,0 +1,164 @@
+#ifndef CAFFE2_OPERATORS_INT8_CHANNEL_SHUFFLE_OP_H_
+#define CAFFE2_OPERATORS_INT8_CHANNEL_SHUFFLE_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+template <size_t TileSizeK, size_t TileSizeG>
+inline void
+TransposeTile(const uint8_t* X_tile, size_t K, size_t G, uint8_t* Y_tile) {
+#ifdef INT8_NEON_SIMD
+ static_assert(TileSizeK == 8, "");
+ static_assert(TileSizeG == 4, "");
+ auto Transpose8x4_NEON =
+ [](uint8x8_t* a0, uint8x8_t* a1, uint8x8_t* a2, uint8x8_t* a3) {
+ const uint8x8x2_t b0 = vtrn_u8(*a0, *a1);
+ const uint8x8x2_t b1 = vtrn_u8(*a2, *a3);
+ const uint16x4x2_t c0 = vtrn_u16(
+ vreinterpret_u16_u8(b0.val[0]), vreinterpret_u16_u8(b1.val[0]));
+ const uint16x4x2_t c1 = vtrn_u16(
+ vreinterpret_u16_u8(b0.val[1]), vreinterpret_u16_u8(b1.val[1]));
+ *a0 = vreinterpret_u8_u16(c0.val[0]);
+ *a1 = vreinterpret_u8_u16(c1.val[0]);
+ *a2 = vreinterpret_u8_u16(c0.val[1]);
+ *a3 = vreinterpret_u8_u16(c1.val[1]);
+ };
+
+ uint8x8_t g0 = vld1_u8(X_tile + 0 * K);
+ uint8x8_t g1 = vld1_u8(X_tile + 1 * K);
+ uint8x8_t g2 = vld1_u8(X_tile + 2 * K);
+ uint8x8_t g3 = vld1_u8(X_tile + 3 * K);
+ Transpose8x4_NEON(&g0, &g1, &g2, &g3);
+ uint8_t tile[TileSizeK / 2][2][TileSizeG];
+ vst1_u8(&tile[0][0][0], g0);
+ vst1_u8(&tile[1][0][0], g1);
+ vst1_u8(&tile[2][0][0], g2);
+ vst1_u8(&tile[3][0][0], g3);
+ for (auto kkk = 0; kkk < 2; ++kkk) {
+ for (auto kk = 0; kk < TileSizeK / 2; ++kk) {
+ const auto k = TileSizeK / 2 * kkk + kk;
+ for (auto g = 0; g < TileSizeG; ++g) {
+ Y_tile[k * G + g] = tile[kk][kkk][g];
+ }
+ }
+ }
+#else
+ uint8_t tile[TileSizeG][TileSizeK];
+ for (auto g = 0; g < TileSizeG; ++g) {
+ for (auto k = 0; k < TileSizeK; ++k) {
+ tile[g][k] = X_tile[g * K + k];
+ }
+ }
+ for (auto k = 0; k < TileSizeK; ++k) {
+ for (auto g = 0; g < TileSizeG; ++g) {
+ Y_tile[k * G + g] = tile[g][k];
+ }
+ }
+#endif
+}
+
+void Int8ChannelShuffle(
+ const uint8_t* X_data,
+ size_t B,
+ size_t K,
+ size_t G,
+ uint8_t* Y_data,
+ C2GEMMContext* gemm_context) {
+ auto divRoundUp = [](size_t n, size_t d) { return (n + d - 1) / d; };
+ constexpr size_t kTileSizeG = 4;
+ constexpr size_t kTileSizeK = 8;
+ auto f = [&](int, size_t b) {
+ for (auto kk = 0; kk < divRoundUp(K, kTileSizeK); ++kk) {
+ for (auto gg = 0; gg < divRoundUp(G, kTileSizeG); ++gg) {
+ const auto g = gg * kTileSizeG;
+ const auto k = kk * kTileSizeK;
+ const auto X_tile = X_data + b * G * K + g * K + k;
+ auto* Y_tile = Y_data + b * G * K + k * G + g;
+ if (kk * kTileSizeK + kTileSizeK <= K &&
+ gg * kTileSizeG + kTileSizeG <= G) {
+ // Complete tile.
+ TransposeTile<kTileSizeK, kTileSizeG>(X_tile, K, G, Y_tile);
+ } else {
+ uint8_t Xp[kTileSizeG][kTileSizeK];
+ uint8_t Yp[kTileSizeK][kTileSizeG];
+ for (auto kt = 0; kt < kTileSizeK; ++kt) {
+ for (auto gt = 0; gt < kTileSizeG; ++gt) {
+ if (k + kt < K && g + gt < G) {
+ Xp[gt][kt] = X_tile[gt * K + kt];
+ }
+ }
+ }
+ TransposeTile<kTileSizeK, kTileSizeG>(
+ &Xp[0][0], kTileSizeK, kTileSizeG, &Yp[0][0]);
+ for (auto kt = 0; kt < kTileSizeK; ++kt) {
+ for (auto gt = 0; gt < kTileSizeG; ++gt) {
+ if (k + kt < K && g + gt < G) {
+ Y_tile[kt * G + gt] = Yp[kt][gt];
+ }
+ }
+ }
+ }
+ }
+ }
+ };
+ gemm_context->threadPool()->run(f, B);
+}
+
+} // namespace
+
+class Int8ChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
+ public:
+ Int8ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
+ : ConvPoolOpBase<CPUContext>(operator_def, ws),
+ gemm_context_(ws->GetThreadPool()) {
+ OPERATOR_NEEDS_FEATURE(
+ this->order_ == StorageOrder::NHWC,
+ "Int8ChannelShuffleOp only supports NHWC order");
+ }
+
+ bool RunOnDeviceWithOrderNHWC() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->t.ResizeLike(X.t);
+ Y->scale = X.scale;
+ Y->zero_point = X.zero_point;
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+ CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
+ CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
+
+ const auto C = X.t.dim32(3);
+ CAFFE_ENFORCE(C % this->group_ == 0, "");
+ const auto G = this->group_;
+ const auto K = C / G;
+ const auto B = X.t.dim32(0) * X.t.dim32(1) * X.t.dim32(2);
+ Int8ChannelShuffle(
+ X.t.data<uint8_t>(),
+ B,
+ K,
+ G,
+ Y->t.mutable_data<uint8_t>(),
+ &gemm_context_);
+ return true;
+ }
+
+ private:
+ C2GEMMContext gemm_context_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_CHANNEL_SHUFFLE_OP_H_
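The tiled transpose above is a per-pixel channel shuffle. For clarity, a plain scalar sketch of the same permutation (standalone and illustrative, not part of the patch):

#include <cstddef>
#include <cstdint>

// For each of the B = N*H*W pixels, view the C = G*K channels as a G x K
// matrix stored row-major and emit its transpose (K x G), i.e.
// Y[b][k * G + g] = X[b][g * K + k].
void ChannelShuffleReference(
    const uint8_t* X, size_t B, size_t G, size_t K, uint8_t* Y) {
  for (size_t b = 0; b < B; ++b) {
    for (size_t g = 0; g < G; ++g) {
      for (size_t k = 0; k < K; ++k) {
        Y[b * G * K + k * G + g] = X[b * G * K + g * K + k];
      }
    }
  }
}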
diff --git a/caffe2/operators/quantized/int8_concat_op.cc b/caffe2/operators/quantized/int8_concat_op.cc
new file mode 100644
index 0000000000..8950d41427
--- /dev/null
+++ b/caffe2/operators/quantized/int8_concat_op.cc
@@ -0,0 +1,22 @@
+#include "caffe2/operators/quantized/int8_concat_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Concat, int8::Int8ConcatOp);
+
+OPERATOR_SCHEMA(Int8Concat)
+ .NumInputs(1, INT_MAX)
+ .NumOutputs(1, 2)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Arg("axis", "Which axis to concat on")
+ .Arg(
+ "add_axis",
+ "Pass 1 to add the axis specified in arg 'axis' to all "
+ "input tensors")
+ .SetDoc("Concatenate a list of tensors into a single tensor")
+ .Output(0, "concat_result", "Concatenated tensor")
+ .Output(1, "split_info", "The dimensions of the inputs.")
+ .InheritOnnxSchema("Concat");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_concat_op.h b/caffe2/operators/quantized/int8_concat_op.h
new file mode 100644
index 0000000000..939376f795
--- /dev/null
+++ b/caffe2/operators/quantized/int8_concat_op.h
@@ -0,0 +1,90 @@
+#ifndef CAFFE2_OPERATORS_INT8_CONCAT_OP_H_
+#define CAFFE2_OPERATORS_INT8_CONCAT_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8ConcatOp final : public Operator<CPUContext> {
+ public:
+ Int8ConcatOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws) {
+ // concat supports more than NHWC format
+ if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
+ // Default to C axis
+ axis_ = this->template GetSingleArgument<int>("axis", 3);
+ CHECK_GE(axis_, 0);
+ CHECK_LT(axis_, 4);
+ } else if (
+ this->template GetSingleArgument<string>("order", "") == "NCHW") {
+ axis_ = this->template GetSingleArgument<int>("axis", 1);
+ CHECK_GE(axis_, 0);
+ CHECK_LT(axis_, 4);
+ } else {
+ axis_ = this->template GetSingleArgument<int>("axis", 0);
+ }
+ }
+
+ bool RunOnDevice() override {
+ const auto& X0 = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->scale = X0.scale;
+ Y->zero_point = X0.zero_point;
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X0.zero_point);
+ CHECK_EQ(Y_scale, X0.scale);
+ CHECK_GE(X0.zero_point, std::numeric_limits<uint8_t>::min());
+ CHECK_LE(X0.zero_point, std::numeric_limits<uint8_t>::max());
+ auto Y_dims = X0.t.sizes().vec();
+ if (this->template GetSingleArgument<string>("order", "") == "NHWC") {
+ CHECK_EQ(Y_dims.size(), 4);
+ }
+ for (auto i = 1; i < InputSize(); ++i) {
+ const auto& Xi = Inputs()[i]->template Get<Int8TensorCPU>();
+ CHECK_EQ(Xi.t.ndim(), Y_dims.size());
+ for (auto j = 0; j < Y_dims.size(); ++j) {
+ if (j != axis_) {
+ CHECK_EQ(Xi.t.dim(j), Y_dims[j]);
+ }
+ }
+ Y_dims[axis_] += Xi.t.dim(axis_);
+ }
+ Y->t.Resize(Y_dims);
+ int before = X0.t.size_to_dim(axis_);
+ int after = X0.t.size_from_dim(axis_ + 1);
+ const auto C_total = Y_dims[axis_];
+ size_t C_offset = 0;
+ for (auto i = 0; i < InputSize(); ++i) {
+ const auto& Xi = Inputs()[i]->template Get<Int8TensorCPU>();
+ // Copy the NxHxWxC input slice to NxHxWx[C_offset:C_offset + C].
+ const auto Ci = Xi.t.dim(axis_);
+ math::CopyMatrix<CPUContext>(
+ Xi.t.itemsize(),
+ before,
+ Ci * after,
+ Xi.t.template data<uint8_t>(),
+ Ci * after,
+ Y->t.template mutable_data<uint8_t>() + C_offset,
+ C_total * after,
+ &context_,
+ Xi.t.meta().copy());
+ C_offset += Ci * after * Xi.t.itemsize();
+ }
+ return true;
+ }
+
+ private:
+ int axis_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_CONCAT_OP_H_
diff --git a/caffe2/operators/quantized/int8_conv_op.cc b/caffe2/operators/quantized/int8_conv_op.cc
new file mode 100644
index 0000000000..4bbd83f5ab
--- /dev/null
+++ b/caffe2/operators/quantized/int8_conv_op.cc
@@ -0,0 +1,81 @@
+#include "caffe2/operators/quantized/int8_conv_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Conv, int8::Int8ConvOp<int8::Activation::NONE>);
+REGISTER_CPU_OPERATOR(Int8ConvRelu, int8::Int8ConvOp<int8::Activation::RELU>);
+
+const char kConvDoc_int8[] = R"DOC(
+[Only NHWC order is supported now] Note that other parameters, such as the stride and
+kernel size, or the pads' sizes in each direction are not necessary for input
+because they are provided by the ConvPoolOpBase operator. Various dimension
+checks are done implicitly, and the sizes are specified in the Input docs for
+this operator. As is expected, the filter is convolved with a subset of the
+image and the bias is added; this is done throughout the image data and the
+output is computed. As a side note on the implementation layout:
+conv_op_impl.h is the templated implementation of the conv_op.h file, which is
+why they are separate files.
+)DOC";
+
+std::function<void(OpSchema&)> ConvDocGenerator(
+ const char* dim,
+ bool relu_fused = false) {
+ auto suffix = relu_fused ? " Output will go through rectified linear "
+ "function, where y = max(0, x)."
+ : "";
+ return [=](OpSchema& schema) {
+ string doc = R"DOC(
+The convolution operator consumes an input vector, a {dim}filter blob
+and a bias blob and computes the output. {conv_doc})DOC";
+ c10::ReplaceAll(doc, "{dim}", dim);
+ c10::ReplaceAll(doc, "{conv_doc}", kConvDoc_int8);
+ schema.SetDoc(doc);
+ string output_doc =
+ "Output data blob that contains the result of the "
+ "convolution. The output dimensions are functions of the kernel size, "
+ "stride size, and pad lengths.{suffix}";
+ c10::ReplaceAll(output_doc, "{suffix}", suffix);
+ schema.Input(
+ 0,
+ "X",
+ "Input data blob from previous layer; has size (N x C x H x W), "
+ "where N is the batch size, C is the number of channels, "
+ "and H and W are the height and width. Note that this is for the NCHW "
+ "usage. On the other hand, the NHWC Op has a different set of "
+ "dimension constraints. ");
+ schema.Input(
+ 1,
+ "filter",
+ "The filter blob that will be used in the "
+ "convolutions; has size (M x C x kH x kW), where C is the number of "
+ "channels, and kH and kW are the height and width of the kernel.");
+ schema.Input(
+ 2,
+ "bias",
+ "The 1D bias blob that is added through the "
+ "convolution; has size (M).");
+ schema.Output(0, "Y", output_doc.c_str());
+ };
+}
+
+OPERATOR_SCHEMA(Int8Conv)
+ .NumInputs(2, 3)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
+ .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
+ ConvPoolOpBase<CPUContext>::CostInferenceForConv))
+ .FillUsing(ConvDocGenerator(""));
+
+OPERATOR_SCHEMA(Int8ConvRelu)
+ .NumInputs(2, 3)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
+ .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
+ ConvPoolOpBase<CPUContext>::CostInferenceForConv))
+ .FillUsing(ConvDocGenerator("", true));
+
+} // namespace caffe2
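The output dimensions mentioned in the docs above follow the usual convolution arithmetic. A minimal sketch of the per-dimension formula, assuming the default (non-legacy) padding behavior of ConvPoolOpBase; the helper is illustrative, not part of the patch:

#include <cstdint>

// Output extent of a convolution along one spatial dimension.
// effective_kernel accounts for dilation; integer division truncates.
inline int64_t ConvOutputSize(
    int64_t in, int64_t kernel, int64_t stride,
    int64_t pad_head, int64_t pad_tail, int64_t dilation) {
  const int64_t effective_kernel = dilation * (kernel - 1) + 1;
  return (in + pad_head + pad_tail - effective_kernel) / stride + 1;
}

// Example: a 3x3 kernel, stride 2, pad 1, dilation 1 on a 224x224 input
// gives (224 + 1 + 1 - 3) / 2 + 1 = 112 along each spatial dimension.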
diff --git a/caffe2/operators/quantized/int8_conv_op.h b/caffe2/operators/quantized/int8_conv_op.h
new file mode 100644
index 0000000000..befa85bc25
--- /dev/null
+++ b/caffe2/operators/quantized/int8_conv_op.h
@@ -0,0 +1,171 @@
+#ifndef CAFFE2_OPERATORS_INT8_CONV_OP_H_
+#define CAFFE2_OPERATORS_INT8_CONV_OP_H_
+
+#include <qnnpack.h>
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_op_shared.h"
+#include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+template <Activation Ac>
+class Int8ConvOp final : public ConvPoolOpBase<CPUContext> {
+ public:
+ USE_CONV_POOL_BASE_FUNCTIONS(CPUContext);
+ Int8ConvOp(const OperatorDef& def, Workspace* ws)
+ : ConvPoolOpBase(def, ws), gemm_context_(ws->GetThreadPool()) {
+ OPERATOR_NEEDS_FEATURE(
+ this->order_ == StorageOrder::NHWC,
+ "Int8Conv only supports NHWC order");
+ createSharedBuffer<CPUContext>(ws_);
+ }
+
+ ~Int8ConvOp() {
+ if (this->qnnpackObject_ != nullptr) {
+ qnnp_delete_operator(this->qnnpackObject_);
+ this->qnnpackObject_ = nullptr;
+ }
+ }
+
+ bool RunOnDeviceWithOrderNHWC() override {
+ CAFFE_ENFORCE_EQ(Inputs().size(), 3);
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ const auto& W = Inputs()[1]->template Get<Int8TensorCPU>();
+ const auto& B = Inputs()[2]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ const int32_t Y_offset =
+ this->template GetSingleArgument<int>("Y_zero_point", 0);
+ double Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+
+ ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), W.t.dim32(0));
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+
+ const auto M = W.t.dim(0);
+ const auto KH = W.t.dim(1);
+ const auto KW = W.t.dim(2);
+ const auto KC = W.t.dim(3);
+ const auto C = X.t.dim32(3);
+ const bool isDepthwise = this->group_ > 1 && this->group_ == M &&
+ this->group_ == C && KC == 1 && KH * KW == 9 && dilation_w() == 1;
+
+ CHECK_EQ(Y->t.dim32(3), M);
+ runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
+ initQNNPACK();
+
+ pthreadpool_t threadpool =
+ reinterpret_cast<pthreadpool_t>(gemm_context_.threadPool());
+
+ if (this->qnnpackObject_ == nullptr) {
+ CAFFE_ENFORCE(
+ C % this->group_ == 0,
+ "number of input channels must be divisible by groups count");
+ CAFFE_ENFORCE(
+ M % this->group_ == 0,
+ "number of output channels must be divisible by groups count");
+ const qnnp_status createStatus = qnnp_create_convolution2d_nhwc_q8(
+ pad_t(),
+ pad_r(),
+ pad_b(),
+ pad_l(),
+ KH,
+ KW,
+ stride_h(),
+ stride_w(),
+ dilation_h(),
+ dilation_w(),
+ this->group_,
+ C / this->group_,
+ M / this->group_,
+ X.zero_point,
+ X.scale,
+ W.zero_point,
+ W.scale,
+ W.t.template data<uint8_t>(),
+ B.t.template data<int32_t>(),
+ Y->zero_point,
+ Y->scale,
+ activationLimits(Y->scale, Y->zero_point, Ac).first,
+ activationLimits(Y->scale, Y->zero_point, Ac).second,
+ &this->qnnpackObject_);
+ CAFFE_ENFORCE(
+ createStatus == qnnp_status_success,
+ "failed to create QNNPACK convolution object");
+ CAFFE_ENFORCE(this->qnnpackObject_ != nullptr);
+ }
+
+ uint8_t* inputPtr = X.t.template mutable_data<uint8_t>();
+ if ((isDepthwise && this->group_ < 8) ||
+ (!isDepthwise && C / this->group_ < 8)) {
+ buffer->Resize(std::vector<int64_t>{X.t.size() + 8});
+ inputPtr = buffer->template mutable_data<uint8_t>() + 8;
+ memcpy(inputPtr, X.t.template data<uint8_t>(), X.t.size());
+ }
+
+ if (lastBatchSize_ != static_cast<size_t>(X.t.dim(0)) ||
+ lastInputHeight_ != static_cast<size_t>(X.t.dim(1)) ||
+ lastInputWidth_ != static_cast<size_t>(X.t.dim(2)) ||
+ lastInputPointer_ != inputPtr ||
+ lastOutputPointer_ != Y->t.template mutable_data<uint8_t>()) {
+ const qnnp_status setupStatus = qnnp_setup_convolution2d_nhwc_q8(
+ this->qnnpackObject_,
+ X.t.dim(0),
+ X.t.dim(1),
+ X.t.dim(2),
+ inputPtr,
+ X.t.dim(3) /* input pixel stride */,
+ Y->t.template mutable_data<uint8_t>(),
+ Y->t.dim(3) /* output pixel stride */,
+ nullptr /* threadpool */);
+ CAFFE_ENFORCE(
+ setupStatus == qnnp_status_success,
+ "failed to setup QNNPACK convolution object");
+
+ lastBatchSize_ = static_cast<size_t>(X.t.dim(0));
+ lastInputHeight_ = static_cast<size_t>(X.t.dim(1));
+ lastInputWidth_ = static_cast<size_t>(X.t.dim(2));
+ lastInputPointer_ = inputPtr;
+ lastOutputPointer_ = Y->t.template mutable_data<uint8_t>();
+ }
+
+#ifdef FBCODE_CAFFE2
+ const qnnp_status runStatus =
+ qnnp_run_operator(this->qnnpackObject_, nullptr /* thread pool */);
+#else
+ const qnnp_status runStatus =
+ qnnp_run_operator(this->qnnpackObject_, threadpool);
+#endif
+ CAFFE_ENFORCE(
+ runStatus == qnnp_status_success,
+ "failed to run QNNPACK convolution");
+ });
+ return true;
+ }
+
+ private:
+ C2GEMMContext gemm_context_;
+ // QNNPACK convolution object
+ qnnp_operator_t qnnpackObject_{nullptr};
+ // batch size in the previous call to RunOnDeviceWithOrderNHWC
+ size_t lastBatchSize_{0};
+ // input height in the previous call to RunOnDeviceWithOrderNHWC
+ size_t lastInputHeight_{0};
+ // input width in the previous call to RunOnDeviceWithOrderNHWC
+ size_t lastInputWidth_{0};
+ // input pointer in the previous call to RunOnDeviceWithOrderNHWC
+ const void* lastInputPointer_{nullptr};
+ // output pointer in the previous call to RunOnDeviceWithOrderNHWC
+ void* lastOutputPointer_{nullptr};
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_CONV_OP_H_
diff --git a/caffe2/operators/quantized/int8_conv_transpose_op.cc b/caffe2/operators/quantized/int8_conv_transpose_op.cc
new file mode 100644
index 0000000000..6431852567
--- /dev/null
+++ b/caffe2/operators/quantized/int8_conv_transpose_op.cc
@@ -0,0 +1,49 @@
+#include "caffe2/operators/quantized/int8_conv_transpose_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8ConvTranspose, int8::Int8ConvTransposeOp);
+
+OPERATOR_SCHEMA(Int8ConvTranspose)
+ .NumInputs(2, 3)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .SetDoc(R"DOC(
+The transposed convolution consumes an input vector, the filter blob, and
+the bias blob, and computes the output. Note that other parameters, such as
+the stride and kernel size, or the pads' sizes in each direction are not
+necessary for input because they are provided by the
+ConvTransposeUnpoolOpBase operator. Various dimension checks are done
+implicitly, and the sizes are specified in the Input docs for this operator.
+As is expected, the filter is deconvolved with a subset of the
+image and the bias is added; this is done throughout the image data and the
+output is computed. As a side note on the implementation layout:
+conv_transpose_op_impl.h is the templated implementation of the
+conv_transpose_op.h file, which is why they are separate files.
+ )DOC")
+ .Input(
+ 0,
+ "X",
+ "Input data blob from previous layer; has size "
+ "(N x H x W x C), where N is the batch size, C is the number of channels, and"
+ " H and W are the height and width. Note that NHWC is supported now")
+ .Input(
+ 1,
+ "filter",
+ "The filter blob that will be used in the transposed "
+ "convolution; has size (M x kH x kW x C), where C is the number of channels,"
+ " and kH and kW are the height and width of the kernel.")
+ .Input(
+ 2,
+ "bias",
+ "The 1D bias blob that is added through the convolution;"
+ "has size (C). Optional, if not passed, will treat it as all 0.")
+ .Output(
+ 0,
+ "Y",
+ "Output data blob that contains the result of the "
+ "transposed convolution. The output dimensions are functions of the kernel"
+ " size, stride size, and pad lengths.");
+
+} // namespace caffe2
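For the transposed convolution, the output spatial size inverts the formula used for Int8Conv. The sketch below assumes the usual convention with strides, pads, and the output adjustments (adj_h/adj_w seen in the setup call above); it is an illustrative helper, not code from ConvTransposeUnpoolBase:

#include <cstdint>

// Output extent of a transposed convolution along one spatial dimension,
// with an output adjustment 'adj' to disambiguate sizes that the forward
// convolution would have mapped to the same input extent.
inline int64_t ConvTransposeOutputSize(
    int64_t in, int64_t kernel, int64_t stride,
    int64_t pad_head, int64_t pad_tail, int64_t adj) {
  return (in - 1) * stride + kernel + adj - pad_head - pad_tail;
}

// Example: in = 7, kernel = 4, stride = 2, pads = 1, adj = 0
// gives (7 - 1) * 2 + 4 + 0 - 1 - 1 = 14.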
diff --git a/caffe2/operators/quantized/int8_conv_transpose_op.h b/caffe2/operators/quantized/int8_conv_transpose_op.h
new file mode 100644
index 0000000000..e3c78df9a8
--- /dev/null
+++ b/caffe2/operators/quantized/int8_conv_transpose_op.h
@@ -0,0 +1,169 @@
+#ifndef CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_
+#define CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_
+
+#include <qnnpack.h>
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_op_shared.h"
+#include "caffe2/operators/conv_transpose_unpool_op_base.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
+ public:
+ USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(CPUContext);
+ Int8ConvTransposeOp(const OperatorDef& def, Workspace* ws)
+ : ConvTransposeUnpoolBase(def, ws), gemm_context_(ws->GetThreadPool()) {
+ OPERATOR_NEEDS_FEATURE(
+ this->order_ == StorageOrder::NHWC,
+ "Int8ConvTransposeOp only supports NHWC order");
+ createSharedBuffer<CPUContext>(ws_);
+ }
+
+ ~Int8ConvTransposeOp() {
+ if (this->qnnpackObject_ != nullptr) {
+ qnnp_delete_operator(this->qnnpackObject_);
+ this->qnnpackObject_ = nullptr;
+ }
+ }
+
+ bool RunOnDeviceWithOrderNHWC() override {
+ CAFFE_ENFORCE_EQ(Inputs().size(), 3);
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ const auto& W = Inputs()[1]->template Get<Int8TensorCPU>();
+ const auto& B = Inputs()[2]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ const auto X_offset = -X.zero_point;
+ const auto W_offset = -W.zero_point;
+ const int32_t Y_offset =
+ this->template GetSingleArgument<int>("Y_zero_point", 0);
+ double Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+
+ const auto N = X.t.dim(0);
+ const auto IH = X.t.dim(1);
+ const auto IW = X.t.dim(2);
+ const auto IC = X.t.dim(3);
+
+ CHECK_EQ(IC, W.t.dim(0));
+ const auto KH = W.t.dim(1);
+ const auto KW = W.t.dim(2);
+ const auto OC = W.t.dim(3);
+
+ ConvTransposeUnpoolBase<CPUContext>::SetOutputSize(X.t, &(Y->t), OC);
+ CHECK_EQ(OC, Y->t.dim(3));
+
+ runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
+ initQNNPACK();
+
+ pthreadpool_t threadpool =
+ reinterpret_cast<pthreadpool_t>(gemm_context_.threadPool());
+
+ if (this->qnnpackObject_ == nullptr) {
+ const qnnp_status createStatus = qnnp_create_deconvolution2d_nhwc_q8(
+ pad_t(),
+ pad_r(),
+ pad_b(),
+ pad_l(),
+ adj_h(),
+ adj_w(),
+ KH,
+ KW,
+ stride_h(),
+ stride_w(),
+ 1 /* dilation height */,
+ 1 /* dilation width */,
+ 1 /* groups */,
+ IC,
+ OC,
+ X.zero_point,
+ X.scale,
+ W.zero_point,
+ W.scale,
+ W.t.template data<uint8_t>(),
+ B.t.template data<int32_t>(),
+ Y->zero_point,
+ Y->scale,
+ std::numeric_limits<uint8_t>::min(),
+ std::numeric_limits<uint8_t>::max(),
+ &this->qnnpackObject_);
+ CAFFE_ENFORCE(
+ createStatus == qnnp_status_success,
+ "failed to create QNNPACK convolution object");
+ CAFFE_ENFORCE(this->qnnpackObject_ != nullptr);
+ }
+
+ uint8_t* inputPtr = X.t.template mutable_data<uint8_t>();
+ if (IC < 8) {
+ buffer->Resize(std::vector<int64_t>{X.t.size() + 8});
+ inputPtr = buffer->template mutable_data<uint8_t>() + 8;
+ memcpy(inputPtr, X.t.template data<uint8_t>(), X.t.size());
+ }
+
+ if (lastBatchSize_ != static_cast<size_t>(X.t.dim(0)) ||
+ lastInputHeight_ != static_cast<size_t>(X.t.dim(1)) ||
+ lastInputWidth_ != static_cast<size_t>(X.t.dim(2)) ||
+ lastInputPointer_ != inputPtr ||
+ lastOutputPointer_ != Y->t.template mutable_data<uint8_t>()) {
+ const qnnp_status setupStatus = qnnp_setup_deconvolution2d_nhwc_q8(
+ this->qnnpackObject_,
+ X.t.dim(0),
+ X.t.dim(1),
+ X.t.dim(2),
+ inputPtr,
+ X.t.dim(3) /* input pixel stride */,
+ Y->t.template mutable_data<uint8_t>(),
+ Y->t.dim(3) /* output pixel stride */,
+ nullptr /* threadpool */);
+ CAFFE_ENFORCE(
+ setupStatus == qnnp_status_success,
+ "failed to setup QNNPACK convolution object");
+
+ lastBatchSize_ = static_cast<size_t>(X.t.dim(0));
+ lastInputHeight_ = static_cast<size_t>(X.t.dim(1));
+ lastInputWidth_ = static_cast<size_t>(X.t.dim(2));
+ lastInputPointer_ = inputPtr;
+ lastOutputPointer_ = Y->t.template mutable_data<uint8_t>();
+ }
+
+#ifdef FBCODE_CAFFE2
+ const qnnp_status runStatus =
+ qnnp_run_operator(this->qnnpackObject_, nullptr /* thread pool */);
+#else
+ const qnnp_status runStatus =
+ qnnp_run_operator(this->qnnpackObject_, threadpool);
+#endif
+ CAFFE_ENFORCE(
+ runStatus == qnnp_status_success,
+ "failed to run QNNPACK convolution");
+ });
+ return true;
+ }
+
+ private:
+ C2GEMMContext gemm_context_;
+ // QNNPACK deconvolution object
+ qnnp_operator_t qnnpackObject_{nullptr};
+ // batch size in the previous call to RunOnDeviceWithOrderNHWC
+ size_t lastBatchSize_{0};
+ // input height in the previous call to RunOnDeviceWithOrderNHWC
+ size_t lastInputHeight_{0};
+ // input width in the previous call to RunOnDeviceWithOrderNHWC
+ size_t lastInputWidth_{0};
+ // input pointer in the previous call to RunOnDeviceWithOrderNHWC
+ const void* lastInputPointer_{nullptr};
+ // output pointer in the previous call to RunOnDeviceWithOrderNHWC
+ void* lastOutputPointer_{nullptr};
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_
diff --git a/caffe2/operators/quantized/int8_dequantize_op.cc b/caffe2/operators/quantized/int8_dequantize_op.cc
new file mode 100644
index 0000000000..f000b8f06c
--- /dev/null
+++ b/caffe2/operators/quantized/int8_dequantize_op.cc
@@ -0,0 +1,14 @@
+#include "caffe2/operators/quantized/int8_dequantize_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Dequantize, int8::Int8DequantizeOp);
+
+OPERATOR_SCHEMA(Int8Dequantize)
+ .IdenticalTypeAndShape()
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Input(0, "qX", "Int8 Tensor qX.")
+ .Output(0, "Y", "FP32 Tensor that represents mapped real value of qX.");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_dequantize_op.h b/caffe2/operators/quantized/int8_dequantize_op.h
new file mode 100644
index 0000000000..bb6c3421e5
--- /dev/null
+++ b/caffe2/operators/quantized/int8_dequantize_op.h
@@ -0,0 +1,52 @@
+#ifndef CAFFE2_OPERATORS_INT8_DEQUANTIZE_OP_H_
+#define CAFFE2_OPERATORS_INT8_DEQUANTIZE_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+void Int8Dequantize(
+ const uint8_t* in,
+ float* out,
+ const int64_t N,
+ const float X_scale,
+ const int32_t X_offset) {
+ for (auto i = 0; i < N; ++i) {
+ out[i] = (static_cast<int32_t>(in[i]) - X_offset) * X_scale;
+ }
+}
+
+} // namespace
+
+class Int8DequantizeOp final : public Operator<CPUContext> {
+ public:
+ using Operator<CPUContext>::Operator;
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Output(0);
+ Y->ResizeLike(X.t);
+ int32_t X_offset = X.zero_point;
+ auto X_scale = X.scale;
+ Int8Dequantize(
+ X.t.data<uint8_t>(),
+ Y->mutable_data<float>(),
+ X.t.size(),
+ X_scale,
+ X_offset);
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_DEQUANTIZE_OP_H_
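Int8Dequantize applies the affine mapping real = (quantized - zero_point) * scale, as in the loop above. A tiny self-checking sketch with concrete numbers (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const float scale = 0.5f;
  const int32_t zero_point = 128;
  const uint8_t q = 130;
  // (130 - 128) * 0.5 = 1.0
  const float real = (static_cast<int32_t>(q) - zero_point) * scale;
  assert(real == 1.0f);
  return 0;
}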
diff --git a/caffe2/operators/quantized/int8_fc_op.cc b/caffe2/operators/quantized/int8_fc_op.cc
new file mode 100644
index 0000000000..ee7d6059ee
--- /dev/null
+++ b/caffe2/operators/quantized/int8_fc_op.cc
@@ -0,0 +1,41 @@
+#include "caffe2/operators/quantized/int8_fc_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8FC, int8::Int8FCOp);
+
+OPERATOR_SCHEMA(Int8FC)
+ .NumInputs(3)
+ .NumOutputs(1)
+ .SetDoc(R"DOC(
+Computes the result of passing an input vector X into a fully
+connected layer with 2D weight matrix W and 1D bias vector b. That is,
+the layer computes Y = X * W^T + b, where X has size (M x K),
+W has size (N x K), b has size (N), and Y has size (M x N),
+where M is often the batch size.
+
+
+NOTE: X does not need to explicitly be a 2D matrix; rather, it will be
+coerced into one: an arbitrary n-dimensional tensor X is treated as a
+2D matrix of shape (a_0, a_1 * ... * a_{n-1}). Only this case is supported!
+Lastly, even though b is a 1D vector of size N, it is copied/resized to
+be size (M x N) implicitly and added to each vector in the batch.
+Each of these dimensions must be matched correctly, or else the operator
+will throw errors.
+)DOC")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Input(
+ 0,
+ "X",
+ "input tensor that's coerced into a 2D matrix of size (MxK) "
+ "as described above")
+ .Input(
+ 1,
+ "W",
+ "A tensor that is coerced into a 2D blob of size (KxN) "
+ "containing fully connected weight matrix")
+ .Input(2, "b", "1D blob containing bias vector")
+ .Output(0, "Y", "2D output tensor");
+
+} // namespace caffe2
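A scalar reference for the fully connected computation described above (Y = X * W^T + b on dequantized values, requantized into Y's scale and zero point). It assumes the QNNPACK q8 convention that the int32 bias is expressed in the accumulator scale X_scale * W_scale; that convention is inferred from the int32 bias input rather than stated in this diff, so treat the sketch as an assumption, not a restatement of QNNPACK.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Y[m][n] = requantize( sum_k (X[m][k] - X_zp) * (W[n][k] - W_zp) + B[n] )
// with X: M x K, W: N x K, B: N (int32, scale assumed X_scale * W_scale).
void Int8FCReference(
    const uint8_t* X, const uint8_t* W, const int32_t* B,
    size_t M, size_t N, size_t K,
    int32_t X_zp, float X_scale,
    int32_t W_zp, float W_scale,
    int32_t Y_zp, float Y_scale,
    uint8_t* Y) {
  const float requant_scale = X_scale * W_scale / Y_scale;
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      int32_t acc = B[n];
      for (size_t k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(X[m * K + k]) - X_zp) *
               (static_cast<int32_t>(W[n * K + k]) - W_zp);
      }
      const int32_t y =
          static_cast<int32_t>(std::lrintf(acc * requant_scale)) + Y_zp;
      Y[m * N + n] =
          static_cast<uint8_t>(std::min<int32_t>(255, std::max<int32_t>(0, y)));
    }
  }
}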
diff --git a/caffe2/operators/quantized/int8_fc_op.h b/caffe2/operators/quantized/int8_fc_op.h
new file mode 100644
index 0000000000..963dd9bd56
--- /dev/null
+++ b/caffe2/operators/quantized/int8_fc_op.h
@@ -0,0 +1,133 @@
+#ifndef CAFFE2_OPERATORS_INT8_FC_OP_H_
+#define CAFFE2_OPERATORS_INT8_FC_OP_H_
+
+#include <qnnpack.h>
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_op_shared.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8FCOp final : public Operator<CPUContext> {
+ public:
+ Int8FCOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws),
+ ws_(ws),
+ gemm_context_(ws->GetThreadPool()) {
+ createSharedBuffer<CPUContext>(ws_);
+ }
+
+ ~Int8FCOp() {
+ if (this->qnnpackObject_ != nullptr) {
+ qnnp_delete_operator(this->qnnpackObject_);
+ this->qnnpackObject_ = nullptr;
+ }
+ }
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->Get<Int8TensorCPU>();
+ const auto& W = Inputs()[1]->Get<Int8TensorCPU>();
+ const auto& B = Inputs()[2]->Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+ // X is viewed as an M x K matrix (M = N*H*W, K = C); W is N x K; Y is M x N.
+ const auto K = X.t.size_from_dim(1);
+ const auto N = W.t.dim(0);
+ CHECK_EQ(K, W.t.dim(1));
+ CHECK_EQ(N, B.t.size());
+ const auto M = X.t.size() / K;
+ Y->t.Resize(M, N);
+
+ runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
+ initQNNPACK();
+
+ pthreadpool_t threadpool =
+ reinterpret_cast<pthreadpool_t>(gemm_context_.threadPool());
+
+ if (this->qnnpackObject_ == nullptr) {
+ const qnnp_status createStatus = qnnp_create_fully_connected_nc_q8(
+ K,
+ N,
+ X.zero_point,
+ X.scale,
+ W.zero_point,
+ W.scale,
+ W.t.template data<uint8_t>(),
+ B.t.template data<int32_t>(),
+ Y->zero_point,
+ Y->scale,
+ std::numeric_limits<uint8_t>::min(),
+ std::numeric_limits<uint8_t>::max(),
+ &this->qnnpackObject_);
+ CAFFE_ENFORCE(
+ createStatus == qnnp_status_success,
+ "failed to create QNNPACK fully connected operator");
+ CAFFE_ENFORCE(this->qnnpackObject_ != nullptr);
+ }
+
+ uint8_t* inputPtr = X.t.template mutable_data<uint8_t>();
+ if (K < 8) {
+ buffer->Resize(std::vector<int64_t>{X.t.size() + 8});
+ inputPtr = buffer->template mutable_data<uint8_t>() + 8;
+ memcpy(inputPtr, X.t.template data<uint8_t>(), X.t.size());
+ }
+
+ if (lastBatchSize_ != static_cast<size_t>(M) ||
+ lastInputPointer_ != inputPtr ||
+ lastOutputPointer_ != Y->t.template mutable_data<uint8_t>()) {
+ const qnnp_status setupStatus = qnnp_setup_fully_connected_nc_q8(
+ this->qnnpackObject_,
+ M,
+ inputPtr,
+ K /* input stride */,
+ Y->t.template mutable_data<uint8_t>(),
+ N /* output stride */,
+ nullptr /* threadpool */);
+ CAFFE_ENFORCE(
+ setupStatus == qnnp_status_success,
+ "failed to setup QNNPACK fully connected operator");
+
+ lastBatchSize_ = static_cast<size_t>(M);
+ lastInputPointer_ = inputPtr;
+ lastOutputPointer_ = Y->t.template mutable_data<uint8_t>();
+ }
+
+#ifdef FBCODE_CAFFE2
+ const qnnp_status runStatus =
+ qnnp_run_operator(this->qnnpackObject_, nullptr /* thread pool */);
+#else
+ const qnnp_status runStatus =
+ qnnp_run_operator(this->qnnpackObject_, threadpool);
+#endif
+ CAFFE_ENFORCE(
+ runStatus == qnnp_status_success, "failed to run QNNPACK operator");
+ });
+ return true;
+ }
+
+ private:
+ Workspace* ws_;
+ C2GEMMContext gemm_context_;
+  // QNNPACK fully connected operator object
+  qnnp_operator_t qnnpackObject_{nullptr};
+  // batch size in the previous call to RunOnDevice
+  size_t lastBatchSize_{0};
+  // input pointer in the previous call to RunOnDevice
+  const void* lastInputPointer_{nullptr};
+  // output pointer in the previous call to RunOnDevice
+  void* lastOutputPointer_{nullptr};
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_FC_OP_H_
diff --git a/caffe2/operators/quantized/int8_flatten_op.cc b/caffe2/operators/quantized/int8_flatten_op.cc
new file mode 100644
index 0000000000..a9d5a5ede2
--- /dev/null
+++ b/caffe2/operators/quantized/int8_flatten_op.cc
@@ -0,0 +1,30 @@
+#include "caffe2/operators/quantized/int8_flatten_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Flatten, int8::Int8FlattenOp);
+
+OPERATOR_SCHEMA(Int8Flatten)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .SetDoc(R"DOC(
+Flattens the input tensor into a 2D matrix. If input tensor has shape
+(d_0, d_1, ... d_n) then the output will have shape
+(d_0 * d_1 * ... * d_{axis-1}, d_axis * d_{axis+1} * ... * d_n)
+)DOC")
+ .Input(0, "input", "A Int8 tensor of rank >= axis.")
+ .Output(
+ 0,
+ "output",
+ "A 2D Int8 tensor with the contents of the input tensor, "
+ "with input dimensions up to axis flattened to the outer dimension "
+ "of the output and remaining input dimensions flattened into the inner "
+ "dimension of the output.")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Arg(
+ "axis",
+ "(Default to 1) Indicate up to which input dimensions "
+ "(exclusive) should be flattened to the outer dimension of the output");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_flatten_op.h b/caffe2/operators/quantized/int8_flatten_op.h
new file mode 100644
index 0000000000..6aacb2ac7a
--- /dev/null
+++ b/caffe2/operators/quantized/int8_flatten_op.h
@@ -0,0 +1,47 @@
+#ifndef CAFFE2_OPERATORS_INT8_FLATTEN_OP_H_
+#define CAFFE2_OPERATORS_INT8_FLATTEN_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8FlattenOp : public Operator<CPUContext> {
+ public:
+ Int8FlattenOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws),
+ axis_(this->template GetSingleArgument<int>("axis", 1)) {}
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
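+    // Flatten only rearranges data, so the requested output quantization
+    // parameters must match the input's; no requantization is performed.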
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+ CAFFE_ENFORCE_GE(
+ X.t.sizes().size(), axis_, "The rank of the tensor must be >= axis.");
+ Y->t.Resize(X.t.size_to_dim(axis_), X.t.size_from_dim(axis_));
+ context_.CopyItemsToCPU(
+ X.t.meta(),
+ X.t.size(),
+ X.t.raw_data(),
+ Y->t.raw_mutable_data(X.t.meta()));
+ return true;
+ }
+
+ private:
+ int axis_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_FLATTEN_OP_H_
diff --git a/caffe2/operators/quantized/int8_given_tensor_fill_op.cc b/caffe2/operators/quantized/int8_given_tensor_fill_op.cc
new file mode 100644
index 0000000000..4840b4880b
--- /dev/null
+++ b/caffe2/operators/quantized/int8_given_tensor_fill_op.cc
@@ -0,0 +1,32 @@
+#include "int8_given_tensor_fill_op.h"
+
+namespace caffe2 {
+
+OPERATOR_SCHEMA(Int8GivenTensorFill)
+ .NumInputs(0)
+ .NumOutputs(1)
+ .Arg("value", "Input array of type char(byte)")
+ .Arg("shape", "Input tensor shape")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .SetDoc(R"DOC(
+ Creates a quantized tensor of type char(byte) with scale and zero point info.
+)DOC")
+ .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info");
+
+OPERATOR_SCHEMA(Int8GivenIntTensorFill)
+ .NumInputs(0)
+ .NumOutputs(1)
+ .Arg("value", "Input array of type int32")
+ .Arg("shape", "Input tensor shape")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .SetDoc(R"DOC(
+ Creates a quantized tensor of type int32 with scale and zero point info.
+)DOC")
+ .Output(0, "Tensor", "An Int8TensorCPU with scale and zero point info");
+
+REGISTER_CPU_OPERATOR(Int8GivenTensorFill, int8::Int8GivenTensorFillOp);
+REGISTER_CPU_OPERATOR(Int8GivenIntTensorFill, int8::Int8GivenIntTensorFillOp);
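+
+// Usage sketch (illustrative only, values chosen here for exposition): an
+// Int8GivenTensorFill producing a 2x2 uint8 tensor would carry args roughly
+// like
+//   shape = [2, 2]
+//   values = a 4-byte string, one raw byte per element
+//   Y_scale = 0.01, Y_zero_point = 128
+// with each stored byte q dequantizing to (q - 128) * 0.01.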
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_given_tensor_fill_op.h b/caffe2/operators/quantized/int8_given_tensor_fill_op.h
new file mode 100644
index 0000000000..5352844db6
--- /dev/null
+++ b/caffe2/operators/quantized/int8_given_tensor_fill_op.h
@@ -0,0 +1,114 @@
+#ifndef CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_
+#define CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/filler_op.h"
+#include "caffe2/utils/cast.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+namespace int8 {
+
+class Int8GivenTensorFillOp final : public Operator<CPUContext> {
+ public:
+ Int8GivenTensorFillOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws),
+ scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
+ zero_point_(
+ this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
+ shape_(this->template GetRepeatedArgument<int64_t>("shape")) {
+ ExtractValues();
+ }
+
+ bool RunOnDevice() override {
+ auto* output = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ output->t.Resize(shape_);
+ output->scale = scale_;
+ output->zero_point = zero_point_;
+ return Fill(output);
+ }
+
+ private:
+ void ExtractValues() {
+ auto source_values = this->template GetSingleArgument<string>("values", "");
+ values_.Resize(source_values.size());
+ uint8_t* values_data = values_.template mutable_data<uint8_t>();
+ for (int i = 0; i < source_values.size(); i++) {
+ values_data[i] = static_cast<uint8_t>(source_values[i]);
+ }
+ }
+
+ bool Fill(Int8TensorCPU* output) {
+ DCHECK_EQ(output->t.size(), values_.size())
+ << "output size: " << output->t.size()
+ << " given size: " << values_.size();
+ auto* data = output->t.template mutable_data<uint8_t>();
+ const uint8_t* values_data = values_.template data<uint8_t>();
+ if (output->t.size()) {
+ context_.template CopySameDevice<uint8_t>(
+ output->t.size(), values_data, data);
+ }
+ return true;
+ }
+
+ float scale_;
+ int32_t zero_point_;
+ vector<int64_t> shape_;
+ Tensor values_{CPU};
+};
+
+class Int8GivenIntTensorFillOp final : public Operator<CPUContext> {
+ public:
+ Int8GivenIntTensorFillOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws),
+ scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
+ zero_point_(
+ this->template GetSingleArgument<int32_t>("Y_zero_point", 0)),
+ shape_(this->template GetRepeatedArgument<int64_t>("shape")) {
+ ExtractValues();
+ }
+
+ bool RunOnDevice() override {
+ auto* output = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ output->t.Resize(shape_);
+ output->scale = scale_;
+ output->zero_point = zero_point_;
+ return Fill(output);
+ }
+
+ private:
+ void ExtractValues() {
+ auto source_values = this->template GetRepeatedArgument<int32_t>("values");
+ values_.Resize(source_values.size());
+ auto* values_data = values_.template mutable_data<int32_t>();
+ for (int i = 0; i < source_values.size(); i++) {
+ values_data[i] = static_cast<int32_t>(source_values[i]);
+ }
+ }
+
+ bool Fill(Int8TensorCPU* output) {
+ DCHECK_EQ(output->t.size(), values_.size())
+ << "output size: " << output->t.size()
+ << " given size: " << values_.size();
+ auto* data = output->t.template mutable_data<int32_t>();
+ const auto* values_data = values_.template data<int32_t>();
+ if (output->t.size()) {
+ context_.template CopySameDevice<int32_t>(
+ output->t.size(), values_data, data);
+ }
+ return true;
+ }
+
+ float scale_;
+ int32_t zero_point_;
+ vector<int64_t> shape_;
+ Tensor values_{CPU};
+};
+
+} // namespace int8
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_GIVEN_TENSOR_FILL_OP_H_
diff --git a/caffe2/operators/quantized/int8_leaky_relu_op.cc b/caffe2/operators/quantized/int8_leaky_relu_op.cc
new file mode 100644
index 0000000000..b6bde8b78e
--- /dev/null
+++ b/caffe2/operators/quantized/int8_leaky_relu_op.cc
@@ -0,0 +1,24 @@
+#include "caffe2/operators/quantized/int8_leaky_relu_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8LeakyRelu, int8::Int8LeakyReluOp);
+
+OPERATOR_SCHEMA(Int8LeakyRelu)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("alpha", "Coefficient of leakage, default value is 0.01")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .AllowInplace({{0, 0}})
+ .CostInferenceFunction(PointwiseCostInference<2>)
+ .IdenticalTypeAndShape()
+ .SetDoc(R"DOC(
+LeakyRelu takes input data (Tensor<T>) and an argument alpha, and produces one
+output data (Tensor<T>) where the function `f(x) = alpha * x for x < 0`,
+`f(x) = x for x >= 0`, is applied to the data tensor elementwise.
+)DOC")
+ .Input(0, "X", "1D input tensor")
+ .Output(0, "Y", "1D input tensor");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_leaky_relu_op.h b/caffe2/operators/quantized/int8_leaky_relu_op.h
new file mode 100644
index 0000000000..68e6a46a12
--- /dev/null
+++ b/caffe2/operators/quantized/int8_leaky_relu_op.h
@@ -0,0 +1,64 @@
+#ifndef CAFFE2_OPERATORS_INT8_LEAKY_RELU_OP_H_
+#define CAFFE2_OPERATORS_INT8_LEAKY_RELU_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8LeakyReluOp final : public Operator<CPUContext> {
+ public:
+ Int8LeakyReluOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws) {
+ double alpha = this->template GetSingleArgument<float>("alpha", 0.01);
+ CAFFE_ENFORCE_GT(alpha, 0.0);
+ CAFFE_ENFORCE_LT(alpha, 1.0);
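+    // alpha is stored in fixed point as an int32 multiplier plus a right
+    // shift, i.e. alpha ~= (multiplier_ / 2^31) * 2^-shift_. For example
+    // (roughly), alpha = 0.01 becomes multiplier_ ~= 0.64 * 2^31 with
+    // shift_ = 6, since 0.64 * 2^-6 = 0.01.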
+ QuantizeMultiplierSmallerThanOne(alpha, &multiplier_, &shift_);
+ }
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->t.ResizeLike(X.t);
+ Y->scale = X.scale;
+ Y->zero_point = X.zero_point;
+ CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
+ CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+
+ const uint8_t* Xdata = X.t.data<uint8_t>();
+ uint8_t* Ydata = Y->t.mutable_data<uint8_t>();
+
+ // For x < zero_point:
+ // (y - zero_point) * scale = alpha * (x - zero_point) * scale
+    // y = alpha * (x - zero_point) + zero_point
+ for (int i = 0; i < X.t.size(); i++) {
+ if (Xdata[i] < X.zero_point) {
+ int32_t out = MultiplyByQuantizedMultiplierSmallerThanOne(
+ Xdata[i] - X.zero_point, multiplier_, shift_) +
+ X.zero_point;
+ Ydata[i] = static_cast<uint8_t>(out);
+ } else {
+ Ydata[i] = Xdata[i];
+ }
+ }
+ return true;
+ }
+
+ private:
+ int32_t multiplier_;
+ int shift_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_LEAKY_RELU_OP_H_
diff --git a/caffe2/operators/quantized/int8_max_pool_op.cc b/caffe2/operators/quantized/int8_max_pool_op.cc
new file mode 100644
index 0000000000..64a1507b84
--- /dev/null
+++ b/caffe2/operators/quantized/int8_max_pool_op.cc
@@ -0,0 +1,63 @@
+#include "caffe2/operators/quantized/int8_max_pool_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8MaxPool, int8::Int8MaxPoolOp<int8::Activation::NONE>);
+REGISTER_CPU_OPERATOR(
+ Int8MaxPoolRelu,
+ int8::Int8MaxPoolOp<int8::Activation::RELU>);
+
+const char kMaxPoolDoc_int8[] = R"DOC(
+consumes an input blob X and applies max pooling across the blob according
+to kernel sizes, stride sizes, and pad lengths defined by the ConvPoolOpBase
+operator. Max pooling consists of taking the maximum value of a subset of the
+input tensor according to the kernel size and downsampling the data into the
+output blob Y for further processing.
+)DOC";
+
+std::function<void(OpSchema&)> MaxPoolDocGenerator(
+ const char* dim,
+ bool relu_fused = false) {
+ auto suffix = relu_fused ? " Output will go through rectified linear "
+ "function, where y = max(0, x)."
+ : "";
+ return [=](OpSchema& schema) {
+ string doc = "MaxPool{dim} {pool_doc}";
+ c10::ReplaceAll(doc, "{dim}", dim);
+ c10::ReplaceAll(doc, "{pool_doc}", kMaxPoolDoc_int8);
+ string output_doc =
+ "Output data tensor from max pooling across the input "
+ "tensor. Dimensions will vary based on various kernel, stride, and pad "
+ "sizes.{suffix}";
+ c10::ReplaceAll(output_doc, "{suffix}", suffix);
+ schema.SetDoc(doc);
+ schema.Input(
+ 0,
+ "X",
+ "Input data tensor from the previous operator; dimensions depend on "
+ "whether the NCHW or NHWC operators are being used. For example, in "
+ "the former, the input has size (N x C x H x W), where N is the batch "
+ "size, C is the number of channels, and H and W are the height and the "
+ "width of the data. The corresponding permutation of dimensions is "
+ "used in the latter case.");
+ schema.Output(0, "Y", output_doc.c_str());
+ };
+}
+
+OPERATOR_SCHEMA(Int8MaxPool)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
+ .FillUsing(MaxPoolDocGenerator(""));
+
+OPERATOR_SCHEMA(Int8MaxPoolRelu)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForPool)
+ .FillUsing(MaxPoolDocGenerator("", true));
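+
+// Output spatial size follows the usual ConvPoolOpBase arithmetic; ignoring
+// dilation and legacy padding modes, this is roughly
+//   out = floor((in + pad_head + pad_tail - kernel) / stride) + 1
+// for each of the height and width dimensions.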
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_max_pool_op.h b/caffe2/operators/quantized/int8_max_pool_op.h
new file mode 100644
index 0000000000..26f2fa3d45
--- /dev/null
+++ b/caffe2/operators/quantized/int8_max_pool_op.h
@@ -0,0 +1,183 @@
+#ifndef CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_
+#define CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+/*
+ * Implementation based on TensorFlow Lite kernels:
+ * - Repo: https://github.com/tensorflow/tensorflow
+ * - Path: tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+ * - Hash: d4ad9c73969c45d1a224ebfc43eb645b9860216b
+ */
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+void Int8MaxPool(
+ const uint8_t* input_data,
+ at::IntList input_dims,
+ int stride_width,
+ int stride_height,
+ int pad_width,
+ int pad_height,
+ int filter_width,
+ int filter_height,
+ uint8_t* output_data,
+ at::IntList output_dims,
+ uint8_t output_activation_min,
+ uint8_t output_activation_max) {
+ const int batches = input_dims[0];
+ const int depth = input_dims[3];
+ const int input_height = input_dims[1];
+ const int input_width = input_dims[2];
+ const int output_height = output_dims[1];
+ const int output_width = output_dims[2];
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ const int filter_x_start = std::max(0, -in_x_origin);
+ const int filter_x_end =
+ std::min(filter_width, input_width - in_x_origin);
+ const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_end =
+ std::min(filter_height, input_height - in_y_origin);
+ // 2048 required by Inception v3
+ static constexpr int kAccBufferMaxSize = 2048;
+ CHECK_LE(depth, kAccBufferMaxSize);
+ uint8_t acc[kAccBufferMaxSize];
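+        // uint8 values are never below 0, so zero-initializing the
+        // accumulator acts as the identity for the running max below.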
+ memset(acc, 0, depth * sizeof(acc[0]));
+
+ const uint8_t* input_ptr =
+ &input_data
+ [in_x_origin * depth + in_y_origin * input_width * depth +
+ batch * input_height * input_width * depth];
+
+ for (int fy = filter_y_start; fy < filter_y_end; fy++) {
+ const uint8_t* input_row_ptr =
+ &input_ptr[fy * input_width * depth + filter_x_start * depth];
+
+ for (int fx = filter_x_start; fx < filter_x_end; fx++) {
+ int channel = 0;
+#ifdef INT8_NEON_SIMD
+ for (; channel <= depth - 16; channel += 16) {
+ uint8x16_t acc_reg = vld1q_u8(acc + channel);
+ uint8x16_t input_reg = vld1q_u8(input_row_ptr);
+ input_row_ptr += 16;
+ acc_reg = vmaxq_u8(acc_reg, input_reg);
+ vst1q_u8(acc + channel, acc_reg);
+ }
+
+ for (; channel <= depth - 8; channel += 8) {
+ uint8x8_t acc_reg = vld1_u8(acc + channel);
+ uint8x8_t input_reg = vld1_u8(input_row_ptr);
+ input_row_ptr += 8;
+ acc_reg = vmax_u8(acc_reg, input_reg);
+ vst1_u8(acc + channel, acc_reg);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ acc[channel] = std::max(acc[channel], *input_row_ptr++);
+ }
+ }
+ }
+ uint8_t* output_ptr =
+ &output_data
+ [out_x * depth + out_y * output_width * depth +
+ batch * output_height * output_width * depth];
+ int channel = 0;
+#ifdef INT8_NEON_SIMD
+ for (; channel <= depth - 16; channel += 16) {
+ uint8x16_t a = vld1q_u8(acc + channel);
+ a = vminq_u8(a, vdupq_n_u8(output_activation_max));
+ a = vmaxq_u8(a, vdupq_n_u8(output_activation_min));
+ vst1q_u8(output_ptr + channel, a);
+ }
+ for (; channel <= depth - 8; channel += 8) {
+ uint8x8_t a = vld1_u8(acc + channel);
+ a = vmin_u8(a, vdup_n_u8(output_activation_max));
+ a = vmax_u8(a, vdup_n_u8(output_activation_min));
+ vst1_u8(output_ptr + channel, a);
+ }
+#endif
+ for (; channel < depth; ++channel) {
+ uint8_t a = acc[channel];
+ a = std::max<uint8_t>(a, output_activation_min);
+ a = std::min<uint8_t>(a, output_activation_max);
+ output_ptr[channel] = static_cast<uint8_t>(a);
+ }
+ }
+ }
+ }
+}
+
+} // namespace
+
+template <Activation Ac>
+class Int8MaxPoolOp final : public ConvPoolOpBase<CPUContext> {
+ public:
+ Int8MaxPoolOp(const OperatorDef& operator_def, Workspace* ws)
+ : ConvPoolOpBase<CPUContext>(operator_def, ws) {
+ OPERATOR_NEEDS_FEATURE(
+        this->order_ == StorageOrder::NHWC, "Int8 only supports NHWC order.");
+ }
+
+ bool RunOnDeviceWithOrderNHWC() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->scale = X.scale;
+ Y->zero_point = X.zero_point;
+ const int32_t Y_offset =
+ this->template GetSingleArgument<int>("Y_zero_point", 0);
+ const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+
+ CHECK_EQ(X.t.ndim(), 4);
+ const int height = X.t.dim32(1);
+ const int width = X.t.dim32(2);
+ const int channels = X.t.dim32(3);
+ ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);
+
+ Int8MaxPool(
+ X.t.template data<uint8_t>(),
+ X.t.sizes(),
+ stride_w(),
+ stride_h(),
+ pad_l(),
+ pad_t(),
+ kernel_w(),
+ kernel_h(),
+ Y->t.template mutable_data<uint8_t>(),
+ Y->t.sizes(),
+ activationLimits(Y->scale, Y->zero_point, Ac).first,
+ activationLimits(Y->scale, Y->zero_point, Ac).second);
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_
diff --git a/caffe2/operators/quantized/int8_quantize_op.cc b/caffe2/operators/quantized/int8_quantize_op.cc
new file mode 100644
index 0000000000..cbabe4613c
--- /dev/null
+++ b/caffe2/operators/quantized/int8_quantize_op.cc
@@ -0,0 +1,16 @@
+#include "caffe2/operators/quantized/int8_quantize_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Quantize, int8::Int8QuantizeOp);
+
+OPERATOR_SCHEMA(Int8Quantize)
+ .IdenticalTypeAndShape()
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Input(0, "X", "FP32 Tensor X.")
+ .Output(0, "Y", "Int8 Tensor qX representing X with linear quantization.");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_quantize_op.h b/caffe2/operators/quantized/int8_quantize_op.h
new file mode 100644
index 0000000000..53a5ce2dfc
--- /dev/null
+++ b/caffe2/operators/quantized/int8_quantize_op.h
@@ -0,0 +1,91 @@
+#ifndef CAFFE2_OPERATORS_INT8_QUANTIZE_OP_H_
+#define CAFFE2_OPERATORS_INT8_QUANTIZE_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+void Int8Quantize(
+ const float* in,
+ uint8_t* out,
+ const int64_t N,
+ const float Y_scale,
+ const int32_t Y_offset) {
+ const float inv_scale = 1.0f / Y_scale;
+ uint32_t i = 0;
+#ifdef INT8_NEON_SIMD
+ const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
+ // magic float and magic int to take care of rounding
+ // int magic_round(float f): interpret_int32(f + 12582912.0f) - 0x4B400000
+ // Some detail:
+ // 12582912.0f is 2**23 + 2**22. The trick is based on the fact that when you
+ // add a small number to a large number, the result rounds to the precision of
+ // the least significant bit of the large number. For IEEE-754
+  // single-precision numbers the mantissa has 23 bits, and adding 2**23 causes
+  // rounding to the nearest integer (ties to even). Then we cast to int and
+  // subtract the same number (0x4B400000 is the integer representation of
+  // 12582912.0f) to get only the mantissa. This works if -2**22 < x < 2**22,
+  // but preserves the sign for negative numbers.
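+  // Worked example (illustrative): for x = 3.7, 3.7f + 12582912.0f rounds to
+  // 12582916.0f (the float spacing is 1.0 at this magnitude); its bit pattern
+  // is 0x4B400004, so subtracting 0x4B400000 yields 4 == round(3.7).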
+ const int32x4_t voffset = vdupq_n_s32(Y_offset - 0x4B400000);
+ const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
+ for (i = 0; i + 8 < N; i += 8) {
+ const float32x4_t vin0123 = vld1q_f32(in);
+ in += 4;
+ const float32x4_t vin4567 = vld1q_f32(in);
+ in += 4;
+ const int32x4_t vraw0123 = vaddq_s32(
+ voffset,
+ vreinterpretq_s32_f32(
+ vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
+ const int32x4_t vraw4567 = vaddq_s32(
+ voffset,
+ vreinterpretq_s32_f32(
+ vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
+ const int16x8_t vraw01234567 =
+ vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+ const uint8x8_t vout01234567 = vqmovun_s16(vraw01234567);
+ vst1_u8(out, vout01234567);
+ out += 8;
+ }
+#endif
+ for (; i < N; ++i) {
+ (*out++) = QuantizeUint8(Y_scale, Y_offset, (*in++));
+ }
+}
+
+} // namespace
+
+class Int8QuantizeOp final : public Operator<CPUContext> {
+ public:
+ using Operator<CPUContext>::Operator;
+
+ bool RunOnDevice() override {
+ const auto& X = Input(0);
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->t.ResizeLike(X);
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+ Int8Quantize(
+ X.data<float>(),
+ Y->t.mutable_data<uint8_t>(),
+ X.size(),
+ Y_scale,
+ Y_offset);
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_QUANTIZE_OP_H_
diff --git a/caffe2/operators/quantized/int8_relu_op.cc b/caffe2/operators/quantized/int8_relu_op.cc
new file mode 100644
index 0000000000..f069d45100
--- /dev/null
+++ b/caffe2/operators/quantized/int8_relu_op.cc
@@ -0,0 +1,37 @@
+#include "caffe2/operators/quantized/int8_relu_op.h"
+
+namespace caffe2 {
+
+namespace {
+
+OpSchema::Cost CostInferenceForRelu(
+ const OperatorDef& def,
+ const vector<TensorShape>& in) {
+ struct OpSchema::Cost cost = PointwiseCostInference<0>(def, in);
+ cost.params_bytes = 0;
+ return cost;
+}
+
+} // namespace
+
+REGISTER_CPU_OPERATOR(Int8Relu, int8::Int8ReluOp);
+
+// Input: X, output: Y
+OPERATOR_SCHEMA(Int8Relu)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .AllowInplace({{0, 0}})
+ .CostInferenceFunction(CostInferenceForRelu)
+ .IdenticalTypeAndShape()
+ .SetDoc(R"DOC(
+Relu takes one input data (Tensor<T>) and produces one output data
+(Tensor<T>) where the rectified linear function, y = max(0, x), is applied to
+the tensor elementwise.
+)DOC")
+ .Input(0, "X", "1D input tensor")
+ .Output(0, "Y", "1D input tensor")
+ .InheritOnnxSchema("Relu");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_relu_op.h b/caffe2/operators/quantized/int8_relu_op.h
new file mode 100644
index 0000000000..9e88bcc75e
--- /dev/null
+++ b/caffe2/operators/quantized/int8_relu_op.h
@@ -0,0 +1,43 @@
+#ifndef CAFFE2_OPERATORS_INT8_RELU_OP_H_
+#define CAFFE2_OPERATORS_INT8_RELU_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+#include "caffe2/utils/eigen_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8ReluOp final : public Operator<CPUContext> {
+ public:
+ using Operator<CPUContext>::Operator;
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ Y->t.ResizeLike(X.t);
+ Y->scale = X.scale;
+ Y->zero_point = X.zero_point;
+ CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
+ CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
+ const int32_t Y_offset =
+ this->template GetSingleArgument<int>("Y_zero_point", 0);
+ const float Y_scale =
+ this->template GetSingleArgument<float>("Y_scale", 1.0f);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
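+    // In the quantized domain the real value 0.0 maps to X.zero_point, so
+    // ReLU reduces to clamping each uint8 element from below at the zero
+    // point (QuantizeUint8(X.scale, X.zero_point, 0) == X.zero_point).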
+ EigenVectorMap<uint8_t>(Y->t.mutable_data<uint8_t>(), X.t.size()) =
+ ConstEigenVectorMap<uint8_t>(X.t.data<uint8_t>(), X.t.size())
+ .cwiseMax(QuantizeUint8(X.scale, X.zero_point, 0));
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_RELU_OP_H_
diff --git a/caffe2/operators/quantized/int8_reshape_op.cc b/caffe2/operators/quantized/int8_reshape_op.cc
new file mode 100644
index 0000000000..a385313666
--- /dev/null
+++ b/caffe2/operators/quantized/int8_reshape_op.cc
@@ -0,0 +1,31 @@
+#include "caffe2/operators/quantized/int8_reshape_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Reshape, int8::Int8ReshapeOp);
+
+OPERATOR_SCHEMA(Int8Reshape)
+ .NumInputs(1, 2)
+ .NumOutputs(2)
+ .AllowInplace({{0, 0}})
+ .SetDoc(R"DOC(
+Reshapes the input tensor, similar to numpy.reshape.
+
+It takes a tensor as input and an optional tensor specifying the new shape.
+When the second input is absent, an extra argument `shape` must be specified.
+It outputs the reshaped tensor as well as the original shape.
+
+At most one dimension of the new shape can be -1. In this case, the value is
+inferred from the size of the tensor and the remaining dimensions. A dimension
+could also be 0, in which case the actual dimension value is going to be copied
+from the input tensor.
+)DOC")
+ .Arg("shape", "New shape")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Input(0, "data", "An input tensor.")
+ .Input(1, "new_shape", "New shape.")
+ .Output(0, "reshaped", "Reshaped data.")
+ .Output(1, "old_shape", "Original shape.");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_reshape_op.h b/caffe2/operators/quantized/int8_reshape_op.h
new file mode 100644
index 0000000000..0aa2000ea7
--- /dev/null
+++ b/caffe2/operators/quantized/int8_reshape_op.h
@@ -0,0 +1,47 @@
+#ifndef CAFFE2_OPERATORS_INT8_RESHAPE_OP_H_
+#define CAFFE2_OPERATORS_INT8_RESHAPE_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+#include "caffe2/operators/reshape_op.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8ReshapeOp final : public ReshapeOp<uint8_t, CPUContext> {
+ public:
+ Int8ReshapeOp(const OperatorDef& operator_def, Workspace* ws)
+ : ReshapeOp(operator_def, ws) {}
+
+ bool RunOnDevice() override {
+ if (InputSize() == 2) {
+ return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
+ }
+ CAFFE_ENFORCE(
+ OperatorBase::HasArgument("shape"), "Argument `shape` is missing.");
+ return this->template DoRunWithType<int64_t>();
+ }
+
+ template <typename T>
+ bool DoRunWithType() {
+ auto& X = Inputs()[0]->Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+ DoRunWithTypeImpl<T>(X.t, &Y->t);
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_RESHAPE_OP_H_
diff --git a/caffe2/operators/quantized/int8_resize_nearest_op.cc b/caffe2/operators/quantized/int8_resize_nearest_op.cc
new file mode 100644
index 0000000000..fd5e2fd89b
--- /dev/null
+++ b/caffe2/operators/quantized/int8_resize_nearest_op.cc
@@ -0,0 +1,25 @@
+#include "caffe2/operators/quantized/int8_resize_nearest_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8ResizeNearest, int8::Int8ResizeNearestOp);
+
+// Input: X, output: Y
+OPERATOR_SCHEMA(Int8ResizeNearest)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Arg("width_scale", "Scale along width dimension")
+ .Arg("height_scale", "Scale along height dimension")
+ .SetDoc(R"DOC(
+Resizes the spatial dimensions of the input using nearest neighbor
+interpolation. The `width_scale` and `height_scale` arguments
+control the size of the output, which is given by:
+output_width = floor(input_width * width_scale)
+output_height = floor(input_height * height_scale)
+)DOC")
+ .Input(0, "X", "Input Int8 tensor")
+ .Output(0, "Y", "Output Int8 tensor");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_resize_nearest_op.h b/caffe2/operators/quantized/int8_resize_nearest_op.h
new file mode 100644
index 0000000000..eebcc33064
--- /dev/null
+++ b/caffe2/operators/quantized/int8_resize_nearest_op.h
@@ -0,0 +1,72 @@
+#ifndef CAFFE2_OPERATORS_INT8_RESIZE_NEAREST_OP_H_
+#define CAFFE2_OPERATORS_INT8_RESIZE_NEAREST_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8ResizeNearestOp final : public Operator<CPUContext> {
+ public:
+ Int8ResizeNearestOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws) {
+ width_scale_ = this->template GetSingleArgument<float>("width_scale", 1);
+ height_scale_ = this->template GetSingleArgument<float>("height_scale", 1);
+ CAFFE_ENFORCE_GT(width_scale_, 0);
+ CAFFE_ENFORCE_GT(height_scale_, 0);
+ }
+
+ bool RunOnDevice() override {
+ // Assume NHWC layout.
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+
+ CAFFE_ENFORCE_EQ(4, X.t.ndim());
+ const int N = X.t.dim32(0);
+ const int IH = X.t.dim32(1);
+ const int IW = X.t.dim32(2);
+ const int C = X.t.dim32(3);
+ const int OW = IW * width_scale_;
+ const int OH = IH * height_scale_;
+
+ Y->t.Resize(N, OH, OW, C);
+ Y->scale = X.scale;
+ Y->zero_point = X.zero_point;
+
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+
+ const uint8_t* Xdata = X.t.data<uint8_t>();
+ uint8_t* Ydata = Y->t.mutable_data<uint8_t>();
+
+ for (int n = 0; n < N; ++n) {
+ for (int y = 0; y < OH; ++y) {
+ const int in_y = std::min((int)(y / height_scale_), (IH - 1));
+ for (int x = 0; x < OW; ++x) {
+ const int in_x = std::min((int)(x / width_scale_), (IW - 1));
+ std::memcpy(
+ &Ydata[C * x + C * OW * y + C * OW * OH * n],
+ &Xdata[C * in_x + C * IW * in_y + C * IW * IH * n],
+ C);
+ }
+ }
+ }
+ return true;
+ }
+
+ private:
+ float width_scale_;
+ float height_scale_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_RESIZE_NEAREST_OP_H_
diff --git a/caffe2/operators/quantized/int8_roi_align_op.cc b/caffe2/operators/quantized/int8_roi_align_op.cc
new file mode 100644
index 0000000000..2caf91706f
--- /dev/null
+++ b/caffe2/operators/quantized/int8_roi_align_op.cc
@@ -0,0 +1,45 @@
+#include "caffe2/operators/quantized/int8_roi_align_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8RoIAlign, int8::Int8RoIAlignOp);
+
+OPERATOR_SCHEMA(Int8RoIAlign)
+ .NumInputs(2)
+ .NumOutputs(1)
+ .SetDoc(R"DOC(
+Region of Interest (RoI) align operation as used in Mask R-CNN.
+)DOC")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Arg(
+ "spatial_scale",
+ "(float) default 1.0; Spatial scale of the input feature map X "
+ "relative to the input image. E.g., 0.0625 if X has a stride of 16 "
+ "w.r.t. the input image.")
+ .Arg("pooled_h", "(int) default 1; Pooled output Y's height.")
+ .Arg("pooled_w", "(int) default 1; Pooled output Y's width.")
+ .Arg(
+ "sampling_ratio",
+ "(int) default -1; number of sampling points in the interpolation grid "
+ "used to compute the output value of each pooled output bin. If > 0, "
+ "then exactly sampling_ratio x sampling_ratio grid points are used. If "
+ "<= 0, then an adaptive number of grid points are used (computed as "
+ "ceil(roi_width / pooled_w), and likewise for height).")
+ .Input(0, "X", "4D Int8 Tensor feature map input of shape (N, C, H, W).")
+ .Input(
+ 1,
+ "RoIs",
+ "2D input of shape (R, 4 or 5) specifying R RoIs "
+ "representing: batch index in [0, N - 1], x1, y1, x2, y2. The RoI "
+ "coordinates are in the coordinate system of the input image. For "
+ "inputs corresponding to a single image, batch index can be excluded "
+ "to have just 4 columns.")
+ .Output(
+ 0,
+ "Y",
+ "4D Int8 Tensor output of shape (R, C, pooled_h, pooled_w). "
+ "The r-th batch element "
+ "is a pooled feature map cooresponding to the r-th RoI.");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_roi_align_op.h b/caffe2/operators/quantized/int8_roi_align_op.h
new file mode 100644
index 0000000000..93d36166d9
--- /dev/null
+++ b/caffe2/operators/quantized/int8_roi_align_op.h
@@ -0,0 +1,341 @@
+#ifndef CAFFE2_OPERATORS_INT8_ROI_ALIGN_OP_H_
+#define CAFFE2_OPERATORS_INT8_ROI_ALIGN_OP_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/operator_schema.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+struct PreCalc {
+ int pos1;
+ int pos2;
+ int pos3;
+ int pos4;
+ uint8_t w1;
+ uint8_t w2;
+ uint8_t w3;
+ uint8_t w4;
+};
+
+void pre_calc_for_bilinear_interpolate(
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int iy_upper,
+ const int ix_upper,
+ float roi_start_h,
+ float roi_start_w,
+ float bin_size_h,
+ float bin_size_w,
+ int roi_bin_grid_h,
+ int roi_bin_grid_w,
+ std::vector<PreCalc>& pre_calc) {
+ int pre_calc_index = 0;
+  // boltnn uses a smaller multiplier here; with it, w can sometimes shrink to 0.
+ const float w_multiplier = 255.0;
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ for (int iy = 0; iy < iy_upper; iy++) {
+ const float yy = roi_start_h + ph * bin_size_h +
+ static_cast<float>(iy + .5f) * bin_size_h /
+ static_cast<float>(roi_bin_grid_h); // e.g., 0.5, 1.5
+ for (int ix = 0; ix < ix_upper; ix++) {
+ const float xx = roi_start_w + pw * bin_size_w +
+ static_cast<float>(ix + .5f) * bin_size_w /
+ static_cast<float>(roi_bin_grid_w);
+
+ float x = xx;
+ float y = yy;
+ // deal with: inverse elements are out of feature map boundary
+ if (y < -1.0 || y > height || x < -1.0 || x > width) {
+ // empty
+ PreCalc pc;
+ pc.pos1 = 0;
+ pc.pos2 = 0;
+ pc.pos3 = 0;
+ pc.pos4 = 0;
+ pc.w1 = 0;
+ pc.w2 = 0;
+ pc.w3 = 0;
+ pc.w4 = 0;
+ pre_calc[pre_calc_index] = pc;
+ pre_calc_index += 1;
+ continue;
+ }
+
+ if (y <= 0) {
+ y = 0;
+ }
+ if (x <= 0) {
+ x = 0;
+ }
+
+ int y_low = (int)y;
+ int x_low = (int)x;
+ int y_high;
+ int x_high;
+
+ if (y_low >= height - 1) {
+ y_high = y_low = height - 1;
+ y = (float)y_low;
+ } else {
+ y_high = y_low + 1;
+ }
+
+ if (x_low >= width - 1) {
+ x_high = x_low = width - 1;
+ x = (float)x_low;
+ } else {
+ x_high = x_low + 1;
+ }
+
+ float ly = y - y_low;
+ float lx = x - x_low;
+ float hy = 1. - ly, hx = 1. - lx;
+          // the w's are not necessarily 1
+ uint8_t w1 = static_cast<uint8_t>(Round(hy * hx * w_multiplier));
+ uint8_t w2 = static_cast<uint8_t>(Round(hy * lx * w_multiplier));
+ uint8_t w3 = static_cast<uint8_t>(Round(ly * hx * w_multiplier));
+ uint8_t w4 = static_cast<uint8_t>(Round(ly * lx * w_multiplier));
+
+          // save weights and indices
+ PreCalc pc;
+ pc.pos1 = y_low * width + x_low;
+ pc.pos2 = y_low * width + x_high;
+ pc.pos3 = y_high * width + x_low;
+ pc.pos4 = y_high * width + x_high;
+
+ pc.w1 = w1;
+ pc.w2 = w2;
+ pc.w3 = w3;
+ pc.w4 = w4;
+ pre_calc[pre_calc_index] = pc;
+
+ pre_calc_index += 1;
+ }
+ }
+ }
+ }
+}
+
+void ROIAlignForward(
+ const int nthreads,
+ const uint8_t* bottom_data,
+ const float& spatial_scale,
+ const int channels,
+ const int height,
+ const int width,
+ const int pooled_height,
+ const int pooled_width,
+ const int sampling_ratio,
+ const float* bottom_rois,
+ int roi_cols,
+ uint8_t* top_data,
+ const float x_scale,
+ const float y_scale,
+ const int32_t x_offset,
+ const int32_t y_offset,
+ StorageOrder order) {
+ DCHECK(roi_cols == 4 || roi_cols == 5);
+
+ int n_rois = nthreads / channels / pooled_width / pooled_height;
+
+ for (int n = 0; n < n_rois; n++) {
+ int index_n = n * channels * pooled_width * pooled_height;
+
+ // roi could have 4 or 5 columns
+ const float* offset_bottom_rois = bottom_rois + n * roi_cols;
+ int roi_batch_ind = 0;
+ if (roi_cols == 5) {
+ roi_batch_ind = offset_bottom_rois[0];
+ offset_bottom_rois++;
+ }
+
+    // Do not use rounding; this implementation detail is critical
+ float roi_start_w = offset_bottom_rois[0] * spatial_scale;
+ float roi_start_h = offset_bottom_rois[1] * spatial_scale;
+ float roi_end_w = offset_bottom_rois[2] * spatial_scale;
+ float roi_end_h = offset_bottom_rois[3] * spatial_scale;
+
+ // Force malformed ROIs to be 1x1
+ float roi_width = std::max(roi_end_w - roi_start_w, (float)1.);
+ float roi_height = std::max(roi_end_h - roi_start_h, (float)1.);
+ float bin_size_h =
+ static_cast<float>(roi_height) / static_cast<float>(pooled_height);
+ float bin_size_w =
+ static_cast<float>(roi_width) / static_cast<float>(pooled_width);
+
+ // We use roi_bin_grid to sample the grid and mimic integral
+ int roi_bin_grid_h = (sampling_ratio > 0)
+ ? sampling_ratio
+ : ceil(roi_height / pooled_height); // e.g., = 2
+ int roi_bin_grid_w =
+ (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
+
+ // We do average (integral) pooling inside a bin
+ const float count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
+
+ // calculate multiplier
+ double real_multiplier = x_scale / (y_scale * 255.0 * count);
+ int32_t Y_multiplier;
+ int Y_shift;
+ QuantizeMultiplierSmallerThanOne(real_multiplier, &Y_multiplier, &Y_shift);
+
+    // we want to precalculate indices and weights shared by all channels;
+    // this is the key point of the optimization
+ std::vector<PreCalc> pre_calc(
+ roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
+ pre_calc_for_bilinear_interpolate(
+ height,
+ width,
+ pooled_height,
+ pooled_width,
+ roi_bin_grid_h,
+ roi_bin_grid_w,
+ roi_start_h,
+ roi_start_w,
+ bin_size_h,
+ bin_size_w,
+ roi_bin_grid_h,
+ roi_bin_grid_w,
+ pre_calc);
+
+ const uint8_t* offset_bottom_data =
+ bottom_data + roi_batch_ind * channels * height * width;
+ int pre_calc_index = 0;
+ for (int ph = 0; ph < pooled_height; ph++) {
+ for (int pw = 0; pw < pooled_width; pw++) {
+ vector<int32_t> acc_buffer(channels, 0);
+
+ for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+ for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+ PreCalc pc = pre_calc[pre_calc_index];
+
+ const uint8_t* data_1 = offset_bottom_data + channels * pc.pos1;
+ const uint8_t* data_2 = offset_bottom_data + channels * pc.pos2;
+ const uint8_t* data_3 = offset_bottom_data + channels * pc.pos3;
+ const uint8_t* data_4 = offset_bottom_data + channels * pc.pos4;
+ for (int c = 0; c < channels; ++c) {
+ acc_buffer[c] += (uint32_t)(pc.w1) * (uint32_t)(data_1[c]);
+ acc_buffer[c] += (uint32_t)(pc.w2) * (uint32_t)(data_2[c]);
+ acc_buffer[c] += (uint32_t)(pc.w3) * (uint32_t)(data_3[c]);
+ acc_buffer[c] += (uint32_t)(pc.w4) * (uint32_t)(data_4[c]);
+
+ // w_1..4 are all multiplied by 255.0
+ acc_buffer[c] -= x_offset * 255.0;
+ }
+
+ pre_calc_index += 1;
+ }
+ }
+ int index_nhw = index_n + (ph * pooled_width + pw) * channels;
+ uint8_t* out_ptr = top_data + index_nhw;
+ for (int c = 0; c < channels; ++c) {
+ int32_t a_mul = MultiplyByQuantizedMultiplierSmallerThanOne(
+ acc_buffer[c], Y_multiplier, Y_shift) +
+ y_offset;
+ int32_t clamped_a =
+ std::min<int32_t>(255, std::max<int32_t>(0, a_mul));
+ out_ptr[c] = static_cast<uint8_t>(clamped_a);
+ }
+ } // for pw
+ } // for ph
+ } // for n
+}
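+
+// Requantization sketch (illustrative): the bilinear weights w1..w4 sum to
+// roughly 255, so each grid sample adds about 255 * (x_q - x_offset), i.e.
+// 255 * x_real / x_scale, to the accumulator. Scaling the bin total by
+// x_scale / (255 * count * y_scale) and adding y_offset therefore yields the
+// bin average expressed in the output quantization.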
+
+} // namespace
+
+class Int8RoIAlignOp final : public Operator<CPUContext> {
+ public:
+ Int8RoIAlignOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws),
+ order_(StringToStorageOrder(
+ this->template GetSingleArgument<string>("order", "NHWC"))),
+ spatial_scale_(
+ this->template GetSingleArgument<float>("spatial_scale", 1.)),
+ pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
+ pooled_width_(this->template GetSingleArgument<int>("pooled_w", 1)),
+ sampling_ratio_(
+ this->template GetSingleArgument<int>("sampling_ratio", -1)) {
+ DCHECK_GT(spatial_scale_, 0);
+ DCHECK_GT(pooled_height_, 0);
+ DCHECK_GT(pooled_width_, 0);
+ DCHECK_GE(sampling_ratio_, 0);
+ // only supports NHWC
+ CAFFE_ENFORCE(order_ == StorageOrder::NHWC);
+ }
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>(); // Input, NHWC
+ auto& R = Input(1); // RoIs
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>(); // RoI pooled
+ // calculate multiplier
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+
+ if (R.size() == 0) {
+ // Handle empty rois
+ Y->t.Resize(0, pooled_height_, pooled_width_, X.t.dim32(3));
+ // The following mutable_data calls are needed to allocate the tensors
+ Y->t.mutable_data<uint8_t>();
+ return true;
+ }
+
+ CAFFE_ENFORCE_EQ(R.ndim(), 2);
+    // if R has 5 columns, the first column is the batch index; otherwise the
+    // batch index is taken to be 0
+ CAFFE_ENFORCE(R.dim32(1) == 4 || R.dim32(1) == 5);
+
+ assert(sampling_ratio_ >= 0);
+
+ // only supports NHWC now
+ Y->t.Resize(R.dim32(0), pooled_height_, pooled_width_, X.t.dim32(3));
+ int output_size = Y->t.size();
+
+ ROIAlignForward(
+ output_size,
+ X.t.data<uint8_t>(),
+ spatial_scale_,
+ X.t.dim32(3),
+ X.t.dim32(1),
+ X.t.dim32(2),
+ pooled_height_,
+ pooled_width_,
+ sampling_ratio_,
+ R.data<float>(),
+ R.dim32(1),
+ Y->t.mutable_data<uint8_t>(),
+ X.scale,
+ Y_scale,
+ X.zero_point,
+ Y_offset,
+ order_);
+
+ return true;
+ }
+
+ protected:
+ StorageOrder order_;
+ float spatial_scale_;
+ int pooled_height_;
+ int pooled_width_;
+ int sampling_ratio_;
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_ROI_ALIGN_OP_H_
diff --git a/caffe2/operators/quantized/int8_roi_align_op_test.cc b/caffe2/operators/quantized/int8_roi_align_op_test.cc
new file mode 100644
index 0000000000..e00c4aeebe
--- /dev/null
+++ b/caffe2/operators/quantized/int8_roi_align_op_test.cc
@@ -0,0 +1,62 @@
+#include "caffe2/operators/quantized/int8_test_utils.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+TEST(Int8RoIAlign, RoIAlign) {
+ const int N = 2;
+ const int C = 3;
+ const int H = 100;
+ const int W = 110;
+ auto XQ = q({N, H, W, C});
+ XQ->scale = 0.01f;
+ XQ->zero_point = 127;
+ auto X = dq(*XQ);
+ const int n_rois = 10;
+ Workspace ws;
+ vector<float> rois_array;
+ for (int n = 0; n < n_rois; n++) {
+ rois_array.push_back(randomInt(0, N - 1));
+ int w1 = randomInt(0, W);
+ int w2 = randomInt(0, W);
+ int h1 = randomInt(0, H);
+ int h2 = randomInt(0, H);
+    // RoI row layout is [batch_index, x1, y1, x2, y2]
+    rois_array.push_back(std::min(w1, w2));
+    rois_array.push_back(std::min(h1, h2));
+    rois_array.push_back(std::max(w1, w2));
+    rois_array.push_back(std::max(h1, h2));
+ }
+ add_input({n_rois, 5}, rois_array, "RoIs", &ws);
+ auto xop = CreateOperatorDef(
+ "RoIAlign",
+ "",
+ {"X", "RoIs"},
+ {"Y"},
+ {MakeArgument<float>("spatial_scale", 0.25f),
+ MakeArgument<int>("pooled_h", 2),
+ MakeArgument<int>("pooled_w", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("sampling_ratio", 2)});
+ auto op = CreateOperatorDef(
+ "Int8RoIAlign",
+ "",
+ {"XQ", "RoIs"},
+ {"YQ"},
+ {MakeArgument<float>("spatial_scale", 0.25f),
+ MakeArgument<int>("pooled_h", 2),
+ MakeArgument<int>("pooled_w", 2),
+ MakeArgument<int>("sampling_ratio", 2),
+ MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 0.01f)});
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+  // we can't guarantee the delta is within XQ->scale since there is interpolation
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 4 * XQ->scale);
+}
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_simd.h b/caffe2/operators/quantized/int8_simd.h
new file mode 100644
index 0000000000..63d6fd0cab
--- /dev/null
+++ b/caffe2/operators/quantized/int8_simd.h
@@ -0,0 +1,63 @@
+#pragma once
+
+// We want to allow 128-bit wide SIMD if either NEON is available (as
+// detected by GEMMLOWP_NEON), or when SSE4.2 and Clang are
+// available (in which case we will use the neon_sse.h library to
+// share source between the two implementations). We use SSE4.2 to
+// ensure we can use the full neon2sse library, and we use Clang as
+// GCC has issues correctly compiling some parts of the neon2sse
+// library.
+
+// Otherwise, the INT8_NEON_SIMD macro will be left undefined.
+
+#include "gemmlowp/fixedpoint/fixedpoint.h"
+#include "gemmlowp/public/gemmlowp.h"
+
+#ifdef GEMMLOWP_NEON
+#define INT8_NEON_SIMD
+#endif
+
+#if defined(__SSE4_2__) && defined(__clang__)
+#define INT8_NEON_SIMD
+
+#include "neon2sse.h"
+// Add GEMMLOWP SIMD type wrappers for the NEON2SSE SIMD types.
+
+namespace gemmlowp {
+template <>
+struct FixedPointRawTypeTraits<int32x4_t> {
+ typedef std::int32_t ScalarRawType;
+ static const int kLanes = 4;
+};
+
+template <>
+inline int32x4_t Dup<int32x4_t>(std::int32_t x) {
+ return vdupq_n_s32(x);
+}
+
+template <>
+inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
+ return vandq_s32(a, b);
+}
+
+template <>
+inline int32x4_t Add(int32x4_t a, int32x4_t b) {
+ return vaddq_s32(a, b);
+}
+
+template <>
+inline int32x4_t ShiftRight(int32x4_t a, int offset) {
+ return vshlq_s32(a, vdupq_n_s32(-offset));
+}
+
+template <>
+inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
+ return vreinterpretq_s32_u32(vcltq_s32(a, b));
+}
+
+template <>
+inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
+ return vreinterpretq_s32_u32(vcgtq_s32(a, b));
+}
+} // namespace gemmlowp
+#endif
diff --git a/caffe2/operators/quantized/int8_slice_op.cc b/caffe2/operators/quantized/int8_slice_op.cc
new file mode 100644
index 0000000000..38e9714b63
--- /dev/null
+++ b/caffe2/operators/quantized/int8_slice_op.cc
@@ -0,0 +1,44 @@
+#include "caffe2/operators/quantized/int8_slice_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Slice, int8::Int8SliceOp);
+
+OPERATOR_SCHEMA(Int8Slice)
+ .NumInputs(1, 3)
+ .NumOutputs(1)
+ .SetDoc(R"DOC(
+Produces a slice of the input Int8 tensor. Currently, only slicing in a single
+dimension is supported.
+Slices are passed either as two 1D input tensors or as two keyword argument
+lists (`starts` and `ends`) with start and end indices for each dimension of
+the input `data` tensor. If a negative
+value is passed for any of the start or end indices, it represents the number of
+elements before the end of that dimension. End indices are non-inclusive unless
+negative (end index -1 means up to and including the last element).
+
+
+Example:
+
+ data = [
+ [1, 2, 3, 4],
+ [5, 6, 7, 8],
+ ]
+ starts = [0, 1]
+ ends = [-1, 3]
+
+ result = [
+ [2, 3],
+ [6, 7],
+ ]
+)DOC")
+ .Input(0, "data", "Int8 Tensor of data to extract slices from.")
+ .Input(1, "starts", "1D tensor: start-indices for each dimension of data.")
+ .Input(2, "ends", "1D tensor: end-indices for each dimension of data.")
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .Arg("starts", "List of starting indices")
+ .Arg("ends", "List of ending indices")
+ .Output(0, "output", "Sliced Int8 data tensor.")
+ .InheritOnnxSchema("Slice");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_slice_op.h b/caffe2/operators/quantized/int8_slice_op.h
new file mode 100644
index 0000000000..10b5b05141
--- /dev/null
+++ b/caffe2/operators/quantized/int8_slice_op.h
@@ -0,0 +1,71 @@
+#ifndef CAFFE2_OPERATORS_INT8_SLICE_OP_H_
+#define CAFFE2_OPERATORS_INT8_SLICE_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+#include "caffe2/operators/slice_op.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+class Int8SliceOp final : public SliceOp<CPUContext> {
+ public:
+ Int8SliceOp(const OperatorDef& operator_def, Workspace* ws)
+ : SliceOp(operator_def, ws) {}
+
+ bool RunOnDevice() override {
+ if (InputSize() > 1) {
+ return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(1));
+ } else {
+ return DoRunWithType<int64_t>();
+ }
+ }
+
+ template <typename SIndex>
+ bool DoRunWithType() {
+ if (InputSize() > 1) {
+ starts_host_.CopyFrom(Input(1));
+ ends_host_.CopyFrom(Input(2));
+ } else {
+ if (!statically_inited_) {
+ CAFFE_ENFORCE(HasArgument("starts"));
+ CAFFE_ENFORCE(HasArgument("ends"));
+ CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
+
+ starts_host_.Resize(starts_.size());
+ ends_host_.Resize(ends_.size());
+
+ memcpy(
+ starts_host_.template mutable_data<SIndex>(),
+ starts_.data(),
+ sizeof(SIndex) * starts_.size());
+ memcpy(
+ ends_host_.template mutable_data<SIndex>(),
+ ends_.data(),
+ sizeof(SIndex) * ends_.size());
+ statically_inited_ = true;
+ }
+ }
+
+ auto& X = Inputs()[0]->Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->GetMutable<Int8TensorCPU>();
+ int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
+ auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, X.zero_point);
+ CHECK_EQ(Y_scale, X.scale);
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+
+ return SliceImpl<SIndex, CPUContext>(
+ &Y->t, X.t, starts_host_, ends_host_, &context_);
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_SLICE_OP_H_
diff --git a/caffe2/operators/quantized/int8_softmax_op.cc b/caffe2/operators/quantized/int8_softmax_op.cc
new file mode 100644
index 0000000000..42285a3b63
--- /dev/null
+++ b/caffe2/operators/quantized/int8_softmax_op.cc
@@ -0,0 +1,46 @@
+#include "caffe2/operators/quantized/int8_softmax_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Int8Softmax, int8::Int8SoftmaxOp);
+
+OPERATOR_SCHEMA(Int8Softmax)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("Y_scale", "Output tensor quantization scale")
+ .Arg("Y_zero_point", "Output tensor quantization offset")
+ .IdenticalTypeAndShape()
+ .SetDoc(R"DOC(
+The operator computes the softmax normalized values for each layer in the batch
+ of the given input. The input is a 2-D Int8 tensor of size
+(batch_size x input_feature_dimensions). The output tensor has the same shape
+and contains the softmax normalized values of the corresponding input.
+
+X does not need to explicitly be a 2D vector; rather, it will be
+coerced into one. For an arbitrary n-dimensional tensor
+X \in [a_0, a_1, ..., a_{k-1}, a_k, ..., a_{n-1}] and k is
+the axis provided, then X will be coerced into a 2-dimensional tensor with
+dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. For the default
+case where axis=1, this means the X tensor will be coerced into a 2D tensor
+of dimensions [a_0, a_1 * ... * a_{n-1}], where a_0 is often the batch size.
+In this situation, we must have a_0 = N and a_1 * ... * a_{n-1} = D.
+Each of these dimensions must be matched correctly, or else the operator
+will throw errors.
+)DOC")
+ .Arg(
+ "axis",
+ "(int) default to 1; describes the axis of the inputs when coerced "
+ "to 2D; defaults to one because the 0th axis most likely describes "
+ "the batch_size")
+ .Input(
+ 0,
+ "input",
+ "The input tensor that's coerced into a 2D matrix of size (NxD) "
+ "as described above.")
+ .Output(
+ 0,
+ "output",
+ "The softmax normalized output values with the same "
+ "shape as input tensor.");
+
+} // namespace caffe2
diff --git a/caffe2/operators/quantized/int8_softmax_op.h b/caffe2/operators/quantized/int8_softmax_op.h
new file mode 100644
index 0000000000..0b6b9d68df
--- /dev/null
+++ b/caffe2/operators/quantized/int8_softmax_op.h
@@ -0,0 +1,227 @@
+#ifndef CAFFE2_OPERATORS_INT8_SOFTMAX_OP_H_
+#define CAFFE2_OPERATORS_INT8_SOFTMAX_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+#include "caffe2/operators/reshape_op.h"
+
+namespace caffe2 {
+
+namespace int8 {
+
+namespace {
+
+/*
+ * Implementation based on TensorFlow Lite kernels:
+ * - Repo: https://github.com/tensorflow/tensorflow
+ * - Path: tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+ * - Hash: d4ad9c73969c45d1a224ebfc43eb645b9860216b
+ */
+
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+void QuantizeMultiplierGreaterThanOne(
+ double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* left_shift) {
+ CHECK(double_multiplier > 1.);
+ const double q = std::frexp(double_multiplier, left_shift);
+ auto q_fixed = static_cast<int64_t>(Round(q * (1ll << 31)));
+ CHECK(q_fixed <= (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ ++*left_shift;
+ }
+ CHECK_GE(*left_shift, 0);
+ CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+inline int32_t MultiplyByQuantizedMultiplierGreaterThanOne(
+ int32_t x,
+ int32_t quantized_multiplier,
+ int left_shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return SaturatingRoundingDoublingHighMul(
+ x * (1 << left_shift), quantized_multiplier);
+}
+
+void PreprocessSoftmaxScaling(
+ double beta,
+ double input_scale,
+ int input_integer_bits,
+ int32_t* quantized_multiplier,
+ int* left_shift) {
+  // If the overall multiplier (input scale times beta) is large, then exp() of
+  // an input difference of 1 scaled by it will be large. In other words, we
+  // can cap the multiplier and know that, when it is used, the output will
+  // round to zero wherever the input is not at the maximum value.
+
+  // If the overall scale is less than one and input_integer_bits=0, then the
+  // result is the double equivalent of Q0.31 (actually with more precision).
+  // Thus this generates a Q(input_integer_bits).(31-input_integer_bits)
+  // representation.
+ const double input_beta_real_multiplier = std::min(
+ beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+
+ QuantizeMultiplierGreaterThanOne(
+ input_beta_real_multiplier, quantized_multiplier, left_shift);
+}
+
+int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+ const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
+ (1ll << (31 - input_integer_bits)) / (1ll << input_left_shift);
+ // Tighten bound using floor. Suppose that we could use the exact value.
+ // After scaling the difference, the result would be at the maximum. Thus we
+ // must ensure that our value has lower magnitude.
+ return static_cast<int>(std::floor(max_input_rescaled));
+}
+
+void Int8Softmax(
+ const uint8_t* input_data,
+ const size_t N,
+ const size_t D,
+ int32_t input_beta_multiplier,
+ int32_t input_beta_left_shift,
+ int diff_min,
+ uint8_t* output_data) {
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32_t, kScaledDiffIntegerBits>;
+ using FixedPointAccum =
+ gemmlowp::FixedPoint<int32_t, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
+
+ for (int n = 0; n < N; ++n) {
+ uint8_t max_in_row = 0;
+ for (int c = 0; c < D; ++c) {
+ max_in_row = std::max(max_in_row, input_data[n * D + c]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < D; ++c) {
+ int32_t input_diff =
+ static_cast<int32_t>(input_data[n * D + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32_t input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps = sum_of_exps +
+ gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32_t fixed_sum_of_exps = sum_of_exps.raw();
+ // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead.
+ int headroom_plus_one =
+ __builtin_clz(static_cast<uint32_t>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32_t shifted_sum_minus_one = static_cast<int32_t>(
+ (static_cast<uint32_t>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32_t>(1) << 31));
+
+ FixedPoint0 shifted_scale;
+    // gemmlowp::one_over_one_plus_x_for_x_in_0_1 is defined on (0, 1),
+    // not [0, 1), so we need to handle the case where
+    // shifted_sum_minus_one is exactly 0.
+ if (shifted_sum_minus_one == 0) {
+ shifted_scale = FixedPoint0::One();
+ } else {
+ shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+ }
+
+ for (int c = 0; c < D; ++c) {
+ int32_t input_diff =
+ static_cast<int32_t>(input_data[n * D + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32_t input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32_t unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[n * D + c] = std::max(std::min(unsat_output, 255), 0);
+
+ } else {
+ output_data[n * D + c] = 0;
+ }
+ }
+ }
+}
+
+} // namespace
+
+class Int8SoftmaxOp final : public Operator<CPUContext> {
+ public:
+ Int8SoftmaxOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CPUContext>(operator_def, ws) {}
+
+ bool RunOnDevice() override {
+ const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
+ auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
+ const int32_t Y_offset =
+ this->template GetSingleArgument<int>("Y_zero_point", 0);
+ const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
+ CHECK_EQ(Y_offset, 0);
+ CHECK_EQ(Y_scale, 1. / 256);
+
+ static const int kScaledDiffIntegerBits = 5;
+ Y->scale = Y_scale;
+ Y->zero_point = Y_offset;
+ Y->t.ResizeLike(X.t);
+ int32_t input_multiplier;
+ int input_left_shift;
+ PreprocessSoftmaxScaling(
+ 1.0 /*params->beta*/,
+ X.scale,
+ kScaledDiffIntegerBits,
+ &input_multiplier,
+ &input_left_shift);
+ const int diff_min =
+ -1.0 * CalculateInputRadius(kScaledDiffIntegerBits, input_left_shift);
+ Int8Softmax(
+ X.t.data<uint8_t>(),
+ X.t.dim(0),
+ X.t.size() / X.t.dim(0),
+ input_multiplier,
+ input_left_shift,
+ diff_min,
+ Y->t.mutable_data<uint8_t>());
+ return true;
+ }
+};
+
+} // namespace int8
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_INT8_SOFTMAX_OP_H_
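As a sanity check on the fixed-point scaling above, the standalone sketch below re-does the arithmetic of PreprocessSoftmaxScaling and CalculateInputRadius for an assumed input with X.scale = 1/256 and beta = 1 (the value RunOnDevice hard-codes); the results in the comments follow directly from the definitions.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double beta = 1.0;
  const double input_scale = 1.0 / 256.0;
  const int input_integer_bits = 5; // kScaledDiffIntegerBits

  // PreprocessSoftmaxScaling: real multiplier = 1 * (1/256) * 2^26 = 2^18.
  const double real_multiplier =
      beta * input_scale * (1 << (31 - input_integer_bits));
  int left_shift = 0;
  const double q = std::frexp(real_multiplier, &left_shift); // 0.5, shift 19
  const auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  std::printf(
      "multiplier = %lld, left_shift = %d\n",
      static_cast<long long>(q_fixed),
      left_shift);
  // Prints: multiplier = 1073741824, left_shift = 19. With these values,
  // CalculateInputRadius(5, 19) = floor(31 * 2^26 / 2^19) = 3968, so
  // diff_min = -3968 and every uint8 difference (at least -255) is kept.
  return 0;
}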
diff --git a/caffe2/operators/quantized/int8_test.cc b/caffe2/operators/quantized/int8_test.cc
new file mode 100644
index 0000000000..85933956e3
--- /dev/null
+++ b/caffe2/operators/quantized/int8_test.cc
@@ -0,0 +1,858 @@
+#include "caffe2/operators/conv_pool_op_base.h"
+#include "caffe2/operators/conv_transpose_unpool_op_base.h"
+#include "caffe2/operators/spatial_batch_norm_op.h"
+
+#include "caffe2/operators/quantized/int8_test_utils.h"
+#include "caffe2/operators/quantized/int8_utils.h"
+
+namespace caffe2 {
+
+// How to test:
+//
+// 1. Generate a random quantized (int8) tensor.
+// 2. Dequantize it to fp32.
+// 3. Run the operator with the int8 backend on the quantized input.
+// 4. Dequantize the int8 result to fp32.
+// 5. Run the reference operator with the fp32 backend on the dequantized input.
+// 6. Compare the two fp32 results.
+
+// For quantized Add, the error shouldn't exceed 2 * scale.
+
+TEST(Int8, ReLU) {
+ auto XQ = q({1, 224, 224, 3});
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef("Relu", "", {"X"}, {"Y"});
+ auto op = CreateOperatorDef(
+ "Int8Relu",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_EQ(*YA, YE);
+}
+
+// LeakyRelu isn't built in xplat, so this fails the buck test
+// xplat/caffe2:caffe2_testAndroid
+TEST(Int8, DISABLED_LeakyReLU) {
+ auto XQ = q({1, 224, 224, 3});
+ auto X = dq(*XQ);
+ const float alpha = 0.1;
+ auto xop = CreateOperatorDef(
+ "LeakyRelu", "", {"X"}, {"Y"}, {MakeArgument<float>("alpha", alpha)});
+ auto op = CreateOperatorDef(
+ "Int8LeakyRelu",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<float>("alpha", alpha),
+ MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, addErrorTolerance(YQ.scale));
+}
+
+TEST(Int8, Softmax) {
+ auto XQ = q({1, 2, 1, 3});
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef("Softmax", "", {"X"}, {"Y"});
+ auto op = CreateOperatorDef(
+ "Int8Softmax",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<int>("Y_zero_point", 0),
+ MakeArgument<float>("Y_scale", 1.0 / 256)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ EXPECT_EQ(YQ.scale, 1.0 / 256);
+ EXPECT_EQ(YQ.zero_point, 0);
+
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, addErrorTolerance(YQ.scale));
+}
+
+TEST(Int8, MaxPool) {
+ auto XQ = q({1, 25, 25, 16});
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef(
+ "MaxPool",
+ "",
+ {"X"},
+ {"Y"},
+ {MakeArgument<int>("kernel", 2), MakeArgument<string>("order", "NHWC")});
+ auto op = CreateOperatorDef(
+ "Int8MaxPool",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_EQ(*YA, YE);
+}
+
+TEST(Int8, AveragePool) {
+ auto XQ = q({1, 25, 25, 16});
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef(
+ "AveragePool",
+ "",
+ {"X"},
+ {"Y"},
+ {MakeArgument<int>("kernel", 2), MakeArgument<string>("order", "NHWC")});
+ auto op = CreateOperatorDef(
+ "Int8AveragePool",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, addErrorTolerance(XQ->scale));
+}
+
+TEST(Int8, ResizeNearest) {
+ auto XQ = q({1, 25, 25, 16});
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef(
+ "ResizeNearest",
+ "",
+ {"XT"},
+ {"YT"},
+ {MakeArgument<float>("width_scale", 2),
+ MakeArgument<float>("height_scale", 2)});
+ auto op = CreateOperatorDef(
+ "Int8ResizeNearest",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<float>("width_scale", 2),
+ MakeArgument<float>("height_scale", 2),
+ MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"X"}, {"XT"}));
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(CreateOperatorDef("NCHW2NHWC", "", {"YT"}, {"Y"}));
+ ws.RunOperatorOnce(op);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_EQ(*YA, YE);
+}
+
+TEST(Int8, ChannelShuffle) {
+ auto XQ = q({2, 25, 25, 32});
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef(
+ "ChannelShuffle",
+ "",
+ {"XT"},
+ {"YT"},
+ {
+ MakeArgument<int>("kernel", 1),
+ MakeArgument<int>("group", 4),
+ MakeArgument<std::string>("order", "NCHW"),
+ });
+ auto op = CreateOperatorDef(
+ "Int8ChannelShuffle",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {
+ MakeArgument<int>("kernel", 1),
+ MakeArgument<int>("group", 4),
+ MakeArgument<std::string>("order", "NHWC"),
+ MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale),
+ });
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"X"}, {"XT"}));
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(CreateOperatorDef("NCHW2NHWC", "", {"YT"}, {"Y"}));
+ ws.RunOperatorOnce(op);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_EQ(*YA, YE);
+}
+
+TEST(Int8, Concat) {
+ auto XQ0 = q({2, 25, 25, 16});
+ auto X0 = dq(*XQ0);
+ auto XQ1 = q({2, 25, 25, 24});
+ auto X1 = dq(*XQ1);
+ auto xop = CreateOperatorDef(
+ "Concat",
+ "",
+ {"XT0", "XT1"},
+ {"YT", "_"},
+ {
+ MakeArgument<std::string>("order", "NCHW"),
+ });
+ auto op = CreateOperatorDef(
+ "Int8Concat",
+ "",
+ {"XQ0", "XQ1"},
+ {"YQ", "_"},
+ {
+ MakeArgument<std::string>("order", "NHWC"),
+ MakeArgument<int>("Y_zero_point", XQ0->zero_point),
+ MakeArgument<float>("Y_scale", XQ0->scale),
+ });
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ0")->GetMutable<int8::Int8TensorCPU>(), *XQ0);
+ int8Copy(ws.CreateBlob("XQ1")->GetMutable<int8::Int8TensorCPU>(), *XQ1);
+ BlobGetMutableTensor(ws.CreateBlob("X0"), CPU)->CopyFrom(*X0);
+ BlobGetMutableTensor(ws.CreateBlob("X1"), CPU)->CopyFrom(*X1);
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"X0"}, {"XT0"}));
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"X1"}, {"XT1"}));
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(CreateOperatorDef("NCHW2NHWC", "", {"YT"}, {"Y"}));
+ ws.RunOperatorOnce(op);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_EQ(*YA, YE);
+}
+
+TEST(Int8, Add) {
+ auto XQ0 = q({1, 10, 10, 20});
+ auto XQ1 = q({1, 10, 10, 20});
+ auto X0 = dq(*XQ0);
+ auto X1 = dq(*XQ1);
+ auto xop = CreateOperatorDef("Add", "", {"X0", "X1"}, {"Y"});
+ const auto Y_scale = 2 * std::max(XQ0->scale, XQ1->scale);
+ auto op = CreateOperatorDef(
+ "Int8Add",
+ "",
+ {"XQ0", "XQ1"},
+ {"YQ"},
+ {MakeArgument<int>("Y_zero_point", XQ0->zero_point),
+ MakeArgument<float>("Y_scale", Y_scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ0")->GetMutable<int8::Int8TensorCPU>(), *XQ0);
+ int8Copy(ws.CreateBlob("XQ1")->GetMutable<int8::Int8TensorCPU>(), *XQ1);
+ BlobGetMutableTensor(ws.CreateBlob("X0"), CPU)->CopyFrom(*X0);
+ BlobGetMutableTensor(ws.CreateBlob("X1"), CPU)->CopyFrom(*X1);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, addErrorTolerance(Y_scale));
+}
+
+TEST(Int8, SumRelu) {
+ auto XQ0 = q({1, 10, 10, 20});
+ auto XQ1 = q({1, 10, 10, 20});
+ auto X0 = dq(*XQ0);
+ auto X1 = dq(*XQ1);
+ auto xop = CreateOperatorDef("Sum", "", {"X0", "X1"}, {"Y"});
+ auto rlxop = CreateOperatorDef("Relu", "", {"Y"}, {"Y"});
+ const auto Y_scale = 2 * std::max(XQ0->scale, XQ1->scale);
+ auto op = CreateOperatorDef(
+ "Int8SumRelu",
+ "",
+ {"XQ0", "XQ1"},
+ {"YQ"},
+ {MakeArgument<int>("Y_zero_point", XQ0->zero_point),
+ MakeArgument<float>("Y_scale", Y_scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ0")->GetMutable<int8::Int8TensorCPU>(), *XQ0);
+ int8Copy(ws.CreateBlob("XQ1")->GetMutable<int8::Int8TensorCPU>(), *XQ1);
+ BlobGetMutableTensor(ws.CreateBlob("X0"), CPU)->CopyFrom(*X0);
+ BlobGetMutableTensor(ws.CreateBlob("X1"), CPU)->CopyFrom(*X1);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(rlxop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, addErrorTolerance(Y_scale));
+}
+
+void setq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
+ CHECK_EQ(vs.size(), dst->t.size());
+ for (auto i = 0; i < vs.size(); ++i) {
+    // Clamp in float before narrowing so out-of-range values saturate
+    // instead of wrapping around in the cast.
+    const float v = int8::Round(dst->zero_point + vs[i] / dst->scale);
+    const uint8_t vq = static_cast<uint8_t>(std::max(
+        static_cast<float>(std::numeric_limits<uint8_t>::min()),
+        std::min(static_cast<float>(std::numeric_limits<uint8_t>::max()), v)));
+    dst->t.mutable_data<uint8_t>()[i] = vq;
+ }
+}
+
+void biassetq(int8::Int8TensorCPU* dst, const std::vector<float>& vs) {
+ CHECK_EQ(vs.size(), dst->t.size());
+ for (auto i = 0; i < vs.size(); ++i) {
+    // Clamp before narrowing so out-of-range values saturate in the cast.
+    const double v = int8::Round(
+        dst->zero_point + vs[i] / static_cast<double>(dst->scale));
+    const int32_t vq = static_cast<int32_t>(std::max(
+        static_cast<double>(std::numeric_limits<int32_t>::min()),
+        std::min(static_cast<double>(std::numeric_limits<int32_t>::max()), v)));
+    dst->t.mutable_data<int32_t>()[i] = vq;
+ }
+}
+
+// Use TFLite test vectors to ensure compatibility.
+TEST(Int8, Conv) {
+ auto XQ = q({2, 2, 4, 1});
+ XQ->scale = 0.5;
+ XQ->zero_point = 127;
+ setq(
+ XQ.get(),
+ std::vector<float>{1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 3, 4, 1, 2, 3, 4});
+ auto WQ = q({3, 2, 2, 1});
+ WQ->scale = 0.5;
+ WQ->zero_point = 127;
+ setq(
+ WQ.get(),
+ {
+ 1,
+ 2,
+ 3,
+ 4,
+ -1,
+ 1,
+ -1,
+ 1,
+ -1,
+ -1,
+ 1,
+ 1,
+ });
+ auto BQ = biasq({3}, XQ->scale * WQ->scale);
+ biassetq(BQ.get(), {1, 2, 3});
+ auto X = dq(*XQ);
+ auto W = dq(*WQ);
+ auto B = biasdq(*BQ);
+ auto xop = CreateOperatorDef(
+ "Conv",
+ "",
+ {"X", "W", "B"},
+ {"Y"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("stride", 2)});
+ auto op = CreateOperatorDef(
+ "Int8Conv",
+ "",
+ {"XQ", "WQ", "BQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("stride", 2),
+ MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 1.0)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ int8Copy(ws.CreateBlob("WQ")->GetMutable<int8::Int8TensorCPU>(), *WQ);
+ int8Copy(ws.CreateBlob("BQ")->GetMutable<int8::Int8TensorCPU>(), *BQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ BlobGetMutableTensor(ws.CreateBlob("W"), CPU)->CopyFrom(*W);
+ BlobGetMutableTensor(ws.CreateBlob("B"), CPU)->CopyFrom(*B);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TRUE(
+ (std::vector<uint8_t>(
+ YQ.t.data<uint8_t>(), YQ.t.data<uint8_t>() + YQ.t.size()) ==
+ std::vector<uint8_t>{
+ 145, 129, 132, 145, 129, 132, 144, 131, 130, 164, 131, 130}));
+
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+}
+
+TEST(Int8, Grouped1x1Conv) {
+ auto XQ = q({1, 3, 2, 4});
+ XQ->scale = 0.5;
+ XQ->zero_point = 127;
+ setq(XQ.get(), std::vector<float>{1, 4, 3, 2, 9, 3, 8, 2, 6, 7, 8, 2,
+ 3, 8, 1, 7, 4, 2, 1, 3, 8, 5, 3, 1});
+
+ // G = 2
+ auto WQ = q({4, 1, 1, 2});
+ WQ->scale = 0.5;
+ WQ->zero_point = 127;
+ setq(WQ.get(), {1, 2, 3, 4, -1, -2, -3, -4});
+ auto BQ = biasq({4}, XQ->scale * WQ->scale);
+ biassetq(BQ.get(), {1, 2, 3, 4});
+ auto X = dq(*XQ);
+ auto W = dq(*WQ);
+ auto B = biasdq(*BQ);
+ auto xop = CreateOperatorDef(
+ "Conv",
+ "",
+ {"XT", "WT", "B"},
+ {"YT"},
+ {MakeArgument<int>("kernel", 1),
+ MakeArgument<string>("order", "NCHW"),
+ MakeArgument<int>("group", 2)});
+ auto op = CreateOperatorDef(
+ "Int8Conv",
+ "",
+ {"XQ", "WQ", "BQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 1),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("group", 2),
+ MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 1.0)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ int8Copy(ws.CreateBlob("WQ")->GetMutable<int8::Int8TensorCPU>(), *WQ);
+ int8Copy(ws.CreateBlob("BQ")->GetMutable<int8::Int8TensorCPU>(), *BQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ BlobGetMutableTensor(ws.CreateBlob("W"), CPU)->CopyFrom(*W);
+ BlobGetMutableTensor(ws.CreateBlob("B"), CPU)->CopyFrom(*B);
+ ws.RunOperatorOnce(op);
+
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"X"}, {"XT"}));
+ // Need to transpose MxKHxKWx1 to Mx1xKHxKW
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"W"}, {"WT"}));
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(CreateOperatorDef("NCHW2NHWC", "", {"YT"}, {"Y"}));
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+
+ // test repacking between runs
+ std::unique_ptr<OperatorBase> op_ptr(CreateOperator(op, &ws));
+ EXPECT_TRUE(op_ptr != nullptr);
+ for (auto it = 0; it < 3; ++it) {
+ EXPECT_TRUE(op_ptr->Run());
+ const auto& temp_YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto temp_YA = dq(temp_YQ);
+ const auto& temp_YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*temp_YA, temp_YE, 1.0e-5);
+ }
+}
+
+TEST(Int8, Conv2) {
+ auto XQ = q({1, 3, 6, 1});
+ XQ->scale = 0.5;
+ XQ->zero_point = 127;
+ setq(
+ XQ.get(),
+ std::vector<float>{
+ 3, 2, 1, -1, -2, -3, 4, 3, 2, -2, -3, -4, 5, 4, 3, -3, -4, -5});
+ auto WQ = q({1, 2, 2, 1});
+ WQ->scale = 0.5;
+ WQ->zero_point = 127;
+ setq(WQ.get(), {1, 2, 3, 4});
+ auto BQ = biasq({1}, XQ->scale * WQ->scale);
+ biassetq(BQ.get(), {-1});
+ auto X = dq(*XQ);
+ auto W = dq(*WQ);
+ auto B = biasdq(*BQ);
+ auto xop = CreateOperatorDef(
+ "Conv",
+ "",
+ {"X", "W", "B"},
+ {"Y"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("stride_w", 3),
+ MakeArgument<int>("stride_h", 1)});
+ auto op = CreateOperatorDef(
+ "Int8Conv",
+ "",
+ {"XQ", "WQ", "BQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("stride_w", 3),
+ MakeArgument<int>("stride_h", 1),
+ MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 1.0)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ int8Copy(ws.CreateBlob("WQ")->GetMutable<int8::Int8TensorCPU>(), *WQ);
+ int8Copy(ws.CreateBlob("BQ")->GetMutable<int8::Int8TensorCPU>(), *BQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ BlobGetMutableTensor(ws.CreateBlob("W"), CPU)->CopyFrom(*W);
+ BlobGetMutableTensor(ws.CreateBlob("B"), CPU)->CopyFrom(*B);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TRUE(
+ (std::vector<uint8_t>(
+ YQ.t.data<uint8_t>(), YQ.t.data<uint8_t>() + YQ.t.size()) ==
+ std::vector<uint8_t>{157, 103, 167, 93}));
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+}
+
+TEST(Int8, DepthwiseConv) {
+ auto XQ = q({1, 3, 2, 2});
+ XQ->scale = 0.5;
+ XQ->zero_point = 127;
+ // setq(XQ.get(), std::vector<float>{1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12});
+ setq(XQ.get(), std::vector<float>{1, 4, 3, 2, 9, 3, 8, 2, 6, 7, 8, 2});
+
+ auto WQ = q({2, 2, 2, 1});
+ WQ->scale = 0.5;
+ WQ->zero_point = 127;
+ setq(WQ.get(), {1, 2, 3, 4, -9, 10, -11, 12});
+ auto BQ = biasq({2}, XQ->scale * WQ->scale);
+ biassetq(BQ.get(), {1, 2});
+ auto X = dq(*XQ);
+ auto W = dq(*WQ);
+ auto B = biasdq(*BQ);
+ auto xop = CreateOperatorDef(
+ "Conv",
+ "",
+ {"XT", "WT", "B"},
+ {"YT"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NCHW"),
+ MakeArgument<int>("group", 2)});
+ auto op = CreateOperatorDef(
+ "Int8Conv",
+ "",
+ {"XQ", "WQ", "BQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("group", 2),
+ MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 1.0)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ int8Copy(ws.CreateBlob("WQ")->GetMutable<int8::Int8TensorCPU>(), *WQ);
+ int8Copy(ws.CreateBlob("BQ")->GetMutable<int8::Int8TensorCPU>(), *BQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ BlobGetMutableTensor(ws.CreateBlob("W"), CPU)->CopyFrom(*W);
+ BlobGetMutableTensor(ws.CreateBlob("B"), CPU)->CopyFrom(*B);
+ ws.RunOperatorOnce(op);
+
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"X"}, {"XT"}));
+ // Need to transpose MxKHxKWx1 to Mx1xKHxKW
+ ws.RunOperatorOnce(CreateOperatorDef("NHWC2NCHW", "", {"W"}, {"WT"}));
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(CreateOperatorDef("NCHW2NHWC", "", {"YT"}, {"Y"}));
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ for (auto i = 0; i < YA->size(); ++i) {
+ LOG(INFO) << YA->data<float>()[i];
+ LOG(INFO) << YE.data<float>()[i];
+ }
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+}
+
+TEST(Int8, ConvTranspose) {
+ auto XQ = q({1, 3, 6, 1});
+ XQ->scale = 0.5;
+ XQ->zero_point = 127;
+ setq(
+ XQ.get(),
+ std::vector<float>{
+ 3, 2, 1, -1, -2, -3, 4, 3, 2, -2, -3, -4, 5, 4, 3, -3, -4, -5});
+ auto WQ = q({1, 2, 2, 1});
+ WQ->scale = 0.5;
+ WQ->zero_point = 127;
+ setq(WQ.get(), {1, 2, 3, 4});
+ auto BQ = biasq({1}, XQ->scale * WQ->scale);
+ biassetq(BQ.get(), {-1});
+ auto X = dq(*XQ);
+ auto W = dq(*WQ);
+ auto B = biasdq(*BQ);
+ auto xop = CreateOperatorDef(
+ "ConvTranspose",
+ "",
+ {"X", "W", "B"},
+ {"Y"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("stride_w", 1),
+ MakeArgument<int>("stride_h", 2)});
+ auto op = CreateOperatorDef(
+ "Int8ConvTranspose",
+ "",
+ {"XQ", "WQ", "BQ"},
+ {"YQ"},
+ {MakeArgument<int>("kernel", 2),
+ MakeArgument<string>("order", "NHWC"),
+ MakeArgument<int>("stride_w", 1),
+ MakeArgument<int>("stride_h", 2),
+ MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 1.0)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ int8Copy(ws.CreateBlob("WQ")->GetMutable<int8::Int8TensorCPU>(), *WQ);
+ int8Copy(ws.CreateBlob("BQ")->GetMutable<int8::Int8TensorCPU>(), *BQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ BlobGetMutableTensor(ws.CreateBlob("W"), CPU)->CopyFrom(*W);
+ BlobGetMutableTensor(ws.CreateBlob("B"), CPU)->CopyFrom(*B);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+}
+
+TEST(Int8, FC) {
+ auto XQ = q({2, 10});
+ XQ->scale = 0.5;
+ XQ->zero_point = 127;
+ setq(XQ.get(), {1, 2, 3, 4, 5, 6, 7, 8, -9, -10,
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10});
+ auto WQ = q({3, 10});
+ WQ->scale = 0.5;
+ WQ->zero_point = 127;
+ setq(
+ WQ.get(),
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+          1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ });
+ auto BQ = biasq({3}, XQ->scale * WQ->scale);
+ biassetq(BQ.get(), {1, 2, 3});
+ auto X = dq(*XQ);
+ auto W = dq(*WQ);
+ auto B = biasdq(*BQ);
+ auto xop = CreateOperatorDef("FC", "", {"X", "W", "B"}, {"Y"}, {});
+ auto op = CreateOperatorDef(
+ "Int8FC",
+ "",
+ {"XQ", "WQ", "BQ"},
+ {"YQ"},
+ {MakeArgument<int>("Y_zero_point", 127),
+ MakeArgument<float>("Y_scale", 1.0)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ int8Copy(ws.CreateBlob("WQ")->GetMutable<int8::Int8TensorCPU>(), *WQ);
+ int8Copy(ws.CreateBlob("BQ")->GetMutable<int8::Int8TensorCPU>(), *BQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ BlobGetMutableTensor(ws.CreateBlob("W"), CPU)->CopyFrom(*W);
+ BlobGetMutableTensor(ws.CreateBlob("B"), CPU)->CopyFrom(*B);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ for (auto i = 0; i < YA->size(); ++i) {
+ LOG(INFO) << YA->data<float>()[i];
+ LOG(INFO) << YE.data<float>()[i];
+ }
+ EXPECT_TRUE(
+ (std::vector<uint8_t>(
+ YQ.t.data<uint8_t>(), YQ.t.data<uint8_t>() + YQ.t.size()) ==
+ std::vector<uint8_t>{151, 152, 153, 185, 186, 187}));
+}
+
+TEST(Int8, GivenTensorFill) {
+ vector<int64_t> shape = {1, 25, 25, 16};
+ auto XQ = q(shape);
+ auto X = dq(*XQ);
+ vector<float> v(
+ X->template data<float>(), X->template data<float>() + X->size());
+ std::string vq(
+ XQ->t.template data<uint8_t>(),
+ XQ->t.template data<uint8_t>() + XQ->t.size());
+ auto op = CreateOperatorDef(
+ "GivenTensorFill",
+ "",
+ {},
+ {"Y"},
+ {MakeArgument<vector<int64_t>>("shape", shape),
+ MakeArgument<vector<float>>("values", v)});
+ auto xop = CreateOperatorDef(
+ "Int8GivenTensorFill",
+ "",
+ {},
+ {"YQ"},
+ {MakeArgument<vector<int64_t>>("shape", shape),
+ MakeArgument<string>("values", vq),
+ MakeArgument<float>("Y_scale", XQ->scale),
+ MakeArgument<int32_t>("Y_zero_point", XQ->zero_point)});
+ Workspace ws;
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+}
+
+TEST(Int8, GivenIntTensorFill) {
+ vector<int64_t> shape = {32};
+ auto XQ = biasq(shape, 1. / 255 * 1. / 255);
+ auto X = biasdq(*XQ);
+ vector<float> v(
+ X->template data<float>(), X->template data<float>() + X->size());
+ vector<int32_t> vq(
+ XQ->t.template data<int32_t>(),
+ XQ->t.template data<int32_t>() + XQ->t.size());
+ auto op = CreateOperatorDef(
+ "GivenTensorFill",
+ "",
+ {},
+ {"Y"},
+ {MakeArgument<vector<int64_t>>("shape", shape),
+ MakeArgument<vector<float>>("values", v)});
+ auto xop = CreateOperatorDef(
+ "Int8GivenIntTensorFill",
+ "",
+ {},
+ {"YQ"},
+ {MakeArgument<vector<int64_t>>("shape", shape),
+ MakeArgument<vector<int32_t>>("values", vq),
+ MakeArgument<float>("Y_scale", XQ->scale),
+ MakeArgument<int32_t>("Y_zero_point", XQ->zero_point)});
+ Workspace ws;
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = biasdq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*YA, YE, 1.0e-5);
+}
+
+TEST(Int8, QuantDeQuant) {
+ vector<int64_t> shape = {1, 25, 25, 16};
+ auto XQ = q(shape);
+ auto X = dq(*XQ);
+ auto xop = CreateOperatorDef(
+ "Int8Quantize",
+ "",
+ {"X"},
+ {"XQ"},
+ {MakeArgument<float>("Y_scale", XQ->scale),
+ MakeArgument<int32_t>("Y_zero_point", XQ->zero_point)});
+ auto op = CreateOperatorDef("Int8Dequantize", "", {"XQ"}, {"X_x"});
+ Workspace ws;
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(xop);
+ ws.RunOperatorOnce(op);
+ const auto& X_x = ws.GetBlob("X_x")->Get<TensorCPU>();
+ EXPECT_TENSOR_APPROX_EQ(*X, X_x, XQ->scale);
+}
+
+TEST(Int8, Reshape) {
+ auto XQ = q({1, 25, 25, 16});
+ auto xop = CreateOperatorDef(
+ "Int8Reshape",
+ "",
+ {"XQ"},
+ {"YQ", "old_shape"},
+ {MakeArgument("shape", vector<int64_t>{0, -1, 2000}),
+ MakeArgument<float>("Y_scale", XQ->scale),
+ MakeArgument<int32_t>("Y_zero_point", XQ->zero_point)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ EXPECT_EQ(YQ.t.sizes(), (vector<int64_t>{1, 5, 2000}));
+ EXPECT_EQ(YQ.scale, XQ->scale);
+ EXPECT_EQ(YQ.zero_point, XQ->zero_point);
+}
+
+TEST(Int8, Flatten) {
+ auto XQ = q({1, 25, 25, 16});
+ auto xop = CreateOperatorDef(
+ "Int8Flatten",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<int>("axis", 2),
+ MakeArgument<float>("Y_scale", XQ->scale),
+ MakeArgument<int32_t>("Y_zero_point", XQ->zero_point)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ EXPECT_EQ(YQ.t.sizes(), (vector<int64_t>{25, 400}));
+ EXPECT_EQ(YQ.scale, XQ->scale);
+ EXPECT_EQ(YQ.zero_point, XQ->zero_point);
+}
+
+TEST(Int8, Slice) {
+ auto XQ = q({1, 25, 25, 16});
+ auto X = dq(*XQ);
+ vector<int> starts = {0, 3, 0, 0};
+ vector<int> ends = {-1, 5, -1, -1};
+ auto xop = CreateOperatorDef(
+ "Slice",
+ "",
+ {"X"},
+ {"Y"},
+ {MakeArgument<vector<int>>("starts", starts),
+ MakeArgument<vector<int>>("ends", ends)});
+ auto op = CreateOperatorDef(
+ "Int8Slice",
+ "",
+ {"XQ"},
+ {"YQ"},
+ {MakeArgument<vector<int>>("starts", starts),
+ MakeArgument<vector<int>>("ends", ends),
+ MakeArgument<int>("Y_zero_point", XQ->zero_point),
+ MakeArgument<float>("Y_scale", XQ->scale)});
+ Workspace ws;
+ int8Copy(ws.CreateBlob("XQ")->GetMutable<int8::Int8TensorCPU>(), *XQ);
+ BlobGetMutableTensor(ws.CreateBlob("X"), CPU)->CopyFrom(*X);
+ ws.RunOperatorOnce(op);
+ ws.RunOperatorOnce(xop);
+ const auto& YQ = ws.GetBlob("YQ")->Get<int8::Int8TensorCPU>();
+ auto YA = dq(YQ);
+ const auto& YE = ws.GetBlob("Y")->Get<TensorCPU>();
+ EXPECT_TENSOR_EQ(*YA, YE);
+ EXPECT_EQ(YQ.t.sizes(), (vector<int64_t>{1, 2, 25, 16}));
+ EXPECT_EQ(YQ.scale, XQ->scale);
+ EXPECT_EQ(YQ.zero_point, XQ->zero_point);
+}
+} // namespace caffe2
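All of the tests above follow the dequantize-and-compare recipe described at the top of the file. The standalone sketch below restates the underlying round-trip arithmetic with plain standard-library calls, assuming the same scale (0.01) and zero point (127) as the q() helper in int8_test_utils.h; a single quantize/dequantize round trip is accurate to half a quantization step, which is roughly where the 2 * scale tolerance for Add comes from (two inputs plus a re-quantized output).

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  const float scale = 0.01f;      // as in the q() test helper
  const int32_t zero_point = 127; // as in the q() test helper
  const float x = 0.1234f;

  // Quantize (clamped to the uint8 range), then dequantize.
  const float q = std::min(
      255.0f, std::max(0.0f, std::nearbyint(x / scale) + zero_point));
  const auto x_q = static_cast<uint8_t>(q);
  const float x_roundtrip = (static_cast<int32_t>(x_q) - zero_point) * scale;

  // A single round trip is accurate to half a quantization step.
  assert(std::fabs(x_roundtrip - x) <= scale / 2 + 1e-6f);
  return 0;
}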
diff --git a/caffe2/operators/quantized/int8_test_utils.h b/caffe2/operators/quantized/int8_test_utils.h
new file mode 100644
index 0000000000..7cd5bf2d58
--- /dev/null
+++ b/caffe2/operators/quantized/int8_test_utils.h
@@ -0,0 +1,118 @@
+#ifndef CAFFE2_INT8_TEST_UTILS_H_
+#define CAFFE2_INT8_TEST_UTILS_H_
+
+#include "caffe2/core/common.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/core/tensor_int8.h"
+
+#include <array>
+#include <cmath>
+#include <random>
+
+#include "gtest/gtest.h"
+
+namespace caffe2 {
+
+// For quantized Add, the error shouldn't exceed 2 * scale.
+inline float addErrorTolerance(float scale) {
+ return 2 * scale;
+}
+
+inline std::unique_ptr<int8::Int8TensorCPU> q(
+ const std::vector<int64_t>& dims) {
+ auto r = caffe2::make_unique<int8::Int8TensorCPU>();
+ r->scale = 0.01;
+ r->zero_point = static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) / 2;
+ r->t.Resize(dims);
+ std::random_device rd;
+ std::mt19937 gen(rd());
+  std::uniform_int_distribution<int> dis(0, 255); // char types are not valid here
+  for (auto i = 0; i < r->t.size(); ++i) {
+    r->t.mutable_data<uint8_t>()[i] = static_cast<uint8_t>(dis(gen));
+  }
+ return r;
+}
+
+inline std::unique_ptr<int8::Int8TensorCPU> biasq(
+ const std::vector<int64_t>& dims,
+ double scale) {
+ auto r = caffe2::make_unique<int8::Int8TensorCPU>();
+ r->scale = scale;
+ r->zero_point = 0;
+ r->t.Resize(dims);
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::uniform_real_distribution<float> dis(-1, 1);
+ for (auto i = 0; i < r->t.size(); ++i) {
+ r->t.mutable_data<int32_t>()[i] =
+ static_cast<int32_t>(dis(gen) / scale + r->zero_point);
+ }
+ return r;
+}
+
+inline std::unique_ptr<TensorCPU> dq(const int8::Int8TensorCPU& XQ) {
+ auto r = caffe2::make_unique<Tensor>(CPU);
+ r->Resize(XQ.t.sizes());
+ for (auto i = 0; i < r->size(); ++i) {
+ r->mutable_data<float>()[i] =
+ (static_cast<int32_t>(XQ.t.data<uint8_t>()[i]) - XQ.zero_point) *
+ XQ.scale;
+ }
+ return r;
+}
+
+inline std::unique_ptr<TensorCPU> biasdq(const int8::Int8TensorCPU& XQ) {
+ auto r = caffe2::make_unique<Tensor>(CPU);
+ r->Resize(XQ.t.sizes());
+ for (auto i = 0; i < r->size(); ++i) {
+ r->mutable_data<float>()[i] =
+ (XQ.t.data<int32_t>()[i] - XQ.zero_point) * XQ.scale;
+ }
+ return r;
+}
+
+#define EXPECT_TENSOR_EQ(_YA, _YE) \
+ do { \
+ EXPECT_TRUE((_YA).sizes() == (_YE).sizes()); \
+ for (auto i = 0; i < (_YA).size(); ++i) { \
+ EXPECT_FLOAT_EQ((_YA).data<float>()[i], (_YE).data<float>()[i]); \
+ } \
+  } while (0)
+
+#define EXPECT_TENSOR_APPROX_EQ(_YA, _YE, _tol) \
+ do { \
+ EXPECT_TRUE((_YA).sizes() == (_YE).sizes()); \
+ for (auto i = 0; i < (_YA).size(); ++i) { \
+ EXPECT_NEAR((_YA).data<float>()[i], (_YE).data<float>()[i], (_tol)); \
+ } \
+  } while (0)
+
+inline void int8Copy(int8::Int8TensorCPU* dst, const int8::Int8TensorCPU& src) {
+ dst->zero_point = src.zero_point;
+ dst->scale = src.scale;
+ dst->t.CopyFrom(src.t);
+}
+
+inline void add_input(
+ const vector<int64_t>& shape,
+ const vector<float>& values,
+ const string& name,
+ Workspace* ws) {
+ // auto* t = ws->CreateBlob(name)->GetMutable<TensorCPU>();
+ auto t = caffe2::make_unique<Tensor>(CPU);
+ t->Resize(shape);
+ std::copy(values.begin(), values.end(), t->mutable_data<float>());
+ BlobGetMutableTensor(ws->CreateBlob(name), CPU)->CopyFrom(*t);
+}
+
+inline int randomInt(int a, int b) {
+ static std::random_device rd;
+ static std::mt19937 gen(rd());
+ return std::uniform_int_distribution<int>(a, b)(gen);
+}
+
+} // namespace caffe2
+
+#endif // CAFFE2_INT8_TEST_UTILS_H_
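A note on the biasq()/biasdq() helpers above: they quantize the bias with scale = X.scale * W.scale and zero_point = 0, the convention the Conv and FC tests rely on. The standalone sketch below uses the Conv test's scales (X.scale = W.scale = 0.5) and made-up centered values to show why a bias quantized at that scale can be added directly to the int32 accumulator before requantization.

#include <cassert>
#include <cstdint>

int main() {
  // Each (x_q - x_zero_point) * (w_q - w_zero_point) product lands in the
  // int32 accumulator in units of X.scale * W.scale.
  const float x_scale = 0.5f;
  const float w_scale = 0.5f;
  const float acc_scale = x_scale * w_scale; // 0.25

  const int32_t x_centered = 4; // x = 2.0 with the zero point removed
  const int32_t w_centered = 6; // w = 3.0 with the zero point removed
  const int32_t bias_q = static_cast<int32_t>(1.0f / acc_scale); // bias = 1.0

  const int32_t acc = x_centered * w_centered + bias_q; // 24 + 4 = 28
  assert(acc * acc_scale == 2.0f * 3.0f + 1.0f);        // 28 * 0.25 = 7.0
  return 0;
}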
diff --git a/caffe2/operators/quantized/int8_utils.h b/caffe2/operators/quantized/int8_utils.h
new file mode 100644
index 0000000000..2cacb079d9
--- /dev/null
+++ b/caffe2/operators/quantized/int8_utils.h
@@ -0,0 +1,177 @@
+#ifndef CAFFE2_INT8_UTILS_H_
+#define CAFFE2_INT8_UTILS_H_
+
+#include <gemmlowp/public/gemmlowp.h>
+
+#include "caffe2/utils/threadpool/ThreadPool.h"
+#include "caffe2/utils/threadpool/WorkersPool.h"
+
+namespace caffe2 {
+
+/*
+ * Initializes QNNPACK (only once).
+ * Throws if initialization fails.
+ */
+void initQNNPACK();
+
+namespace int8 {
+
+/*
+ * Code here is partially derived from the gemmlowp library
+ * (https://github.com/google/gemmlowp).
+ */
+
+// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+inline int32_t MultiplyByQuantizedMultiplierSmallerThanOne(
+ int32_t x,
+ int32_t quantized_multiplier,
+ int right_shift) {
+ using gemmlowp::RoundingDivideByPOT;
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ return RoundingDivideByPOT(
+ SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift);
+}
+
+#if defined(__ANDROID__) && !defined(__NDK_MAJOR__)
+template <class T>
+inline float Round(const float x) {
+ return ::nearbyintf(x);
+}
+inline double Round(const double x) {
+ return ::nearbyint(x);
+}
+#else
+template <class T>
+inline T Round(const T x) {
+ return std::nearbyint(x);
+}
+#endif
+
+inline uint8_t QuantizeUint8(float scale, int32_t zero_point, float value) {
+ const int32_t qmin = std::numeric_limits<uint8_t>::min();
+ const int32_t qmax = std::numeric_limits<uint8_t>::max();
+
+ auto r = zero_point + static_cast<int32_t>(Round(value / scale));
+ r = std::max(r, qmin);
+ r = std::min(r, qmax);
+ return static_cast<uint8_t>(r);
+}
+
+inline void QuantizeMultiplierSmallerThanOne(
+ double double_multiplier,
+ int32_t* quantized_multiplier,
+ int* right_shift) {
+ CHECK(double_multiplier >= 0.);
+ CHECK(double_multiplier < 1.);
+ if (double_multiplier == 0.) {
+ *quantized_multiplier = 0;
+ *right_shift = 0;
+ return;
+ }
+ CHECK(double_multiplier > 0.);
+ const double q = std::frexp(double_multiplier, right_shift);
+ *right_shift *= -1;
+
+ auto q_fixed = static_cast<int64_t>(Round(q * (1ll << 31)));
+ CHECK(q_fixed <= (1ll << 31));
+ if (q_fixed == (1ll << 31)) {
+ q_fixed /= 2;
+ --*right_shift;
+ }
+ CHECK_GE(*right_shift, 0);
+ CHECK_LE(q_fixed, std::numeric_limits<int32_t>::max());
+ *quantized_multiplier = static_cast<int32_t>(q_fixed);
+}
+
+// An adaptor to use the Caffe2 WorkersPool implementation for gemmlowp
+// multithreading functions.
+class C2GEMMContext : public gemmlowp::SingleThreadGemmContext {
+ class C2WorkersPool;
+
+ public:
+ C2GEMMContext(ThreadPool* pool) : threadPool_(pool), workersPool_(pool) {}
+ int max_num_threads() const {
+ CHECK(threadPool_);
+ return threadPool_->getNumThreads();
+ }
+ C2WorkersPool* workers_pool() {
+ return &workersPool_;
+ }
+
+ ThreadPool* threadPool() {
+ return threadPool_;
+ }
+
+ private:
+ class C2WorkersPool {
+ public:
+ C2WorkersPool(ThreadPool* pool) : pool_(pool) {}
+ void Execute(const std::vector<gemmlowp::Task*>& tasks) {
+ class C2Task : public Task {
+ public:
+        C2Task(gemmlowp::Task* task) : task_(task) {}
+ virtual void Run() override {
+ CHECK(task_);
+ task_->Run();
+ }
+
+ private:
+ gemmlowp::Task* task_;
+ };
+ std::vector<std::shared_ptr<Task>> c2tasks;
+ c2tasks.reserve(tasks.size());
+ std::vector<gemmlowp::Allocator> allocators(tasks.size());
+
+ for (size_t i = 0; i < tasks.size(); ++i) {
+ auto* task = tasks[i];
+ CHECK(task);
+ task->local_allocator = &allocators[i];
+ c2tasks.push_back(std::shared_ptr<Task>(new C2Task(task)));
+ }
+ CHECK(pool_);
+ pool_->withPool([&](WorkersPool* pool) { pool->Execute(c2tasks); });
+ for (auto* t : tasks) {
+ delete t;
+ }
+ }
+
+ private:
+ ThreadPool* pool_;
+ };
+ ThreadPool* threadPool_;
+ C2WorkersPool workersPool_;
+};
+
+enum class Activation : uint8_t { NONE = 0, RELU = 1 };
+
+inline std::pair<uint8_t, uint8_t>
+activationLimits(float scale, int32_t zero_point, Activation Ac) {
+ switch (Ac) {
+ case Activation::NONE:
+ return {std::numeric_limits<uint8_t>::min(),
+ std::numeric_limits<uint8_t>::max()};
+ case Activation::RELU:
+ return {QuantizeUint8(scale, zero_point, 0.0),
+ std::numeric_limits<uint8_t>::max()};
+ default:
+ __builtin_unreachable();
+ }
+}
+
+} // namespace int8
+} // namespace caffe2
+
+#endif // CAFFE2_INT8_UTILS_H_
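Finally, a worked sketch of QuantizeMultiplierSmallerThanOne above for a typical requantization multiplier. Assuming the Conv test's parameters (X.scale = W.scale = 0.5, Y.scale = 1.0), the real multiplier X.scale * W.scale / Y.scale = 0.25 decomposes into a Q0.31 fixed-point multiplier and a right shift; the snippet repeats the std::frexp arithmetic outside the library so it compiles on its own.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const double real_multiplier = 0.5 * 0.5 / 1.0; // X.scale * W.scale / Y.scale
  int right_shift = 0;
  const double q = std::frexp(real_multiplier, &right_shift); // 0.5, exponent -1
  right_shift = -right_shift;                                  // 1
  const auto q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  std::printf(
      "quantized_multiplier = %lld, right_shift = %d\n",
      static_cast<long long>(q_fixed),
      right_shift);
  // Prints: quantized_multiplier = 1073741824, right_shift = 1, so
  // MultiplyByQuantizedMultiplierSmallerThanOne(x, 1073741824, 1) ~= x / 4.
  return 0;
}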