summaryrefslogtreecommitdiff
path: root/caffe2
diff options
context:
space:
mode:
authorHector Yuen <hyz@fb.com>2019-02-20 13:07:08 -0800
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-20 13:14:11 -0800
commit075c7b1fef537205a0bdb8e45d2f800e6c024603 (patch)
tree4db84a082769257dad146d2b6174581b7c23e62e /caffe2
parentdb1d61a5c3f80a5b05ec861ea59b638130355a48 (diff)
downloadpytorch-075c7b1fef537205a0bdb8e45d2f800e6c024603.tar.gz
pytorch-075c7b1fef537205a0bdb8e45d2f800e6c024603.tar.bz2
pytorch-075c7b1fef537205a0bdb8e45d2f800e6c024603.zip
make the threshold for accuracy more precise (#17194)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17194. We found that there is a per-row absolute error introduced by int8 quantization, and a table-wide relative error in case fp16 is used. Reviewed By: csummersea Differential Revision: D14113353 fbshipit-source-id: c7065aa9d15c453c2e5609f421ad0155145af889
Diffstat (limited to 'caffe2')
-rw-r--r--caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py36
1 file changed, 23 insertions, 13 deletions
diff --git a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py
index 8f36dc02b7..8835092366 100644
--- a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py
+++ b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py
@@ -7,18 +7,28 @@ from caffe2.python import core, workspace
from hypothesis import given
-def compare_rowwise(emb_orig, emb_reconstructed):
+def compare_rowwise(emb_orig, emb_reconstructed, fp16):
+ # there is an absolute error introduced per row through int8 quantization
+ # and a relative error introduced when quantizing back from fp32 to fp16
assert(emb_orig.shape == emb_reconstructed.shape)
- range = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1)
- # TOOO: figure out the right threshold, this has to do with the
- # fact that the data types are float16, in float32, it should be /1.9
- threshold = range / 255.0 / 1.5
- diff = np.amax(np.abs(emb_orig - emb_reconstructed), axis=1)
- n_violated = ((threshold - diff) < 0).sum()
- if n_violated > 0:
- print(n_violated, threshold, diff, threshold < diff, emb_orig,
- emb_reconstructed, emb_orig - emb_reconstructed)
- assert(n_violated == 0)
+ rtol = 1e-8
+ if fp16:
+ rtol = 1e-3
+ erange = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1)
+
+ threshold = erange / 255.0 / 1.9
+
+ for i in range(emb_orig.shape[0]):
+ r_orig = emb_orig[i, :]
+ r_reconstructed = emb_reconstructed[i, :]
+
+ isclose = np.isclose(r_orig, r_reconstructed, atol=threshold[i], rtol=rtol)
+ n_violated = isclose.size - isclose.sum()
+
+ if n_violated > 0:
+ print(isclose, threshold[i])
+ print(i, r_orig, r_reconstructed, threshold[i], r_orig - r_reconstructed)
+ assert(n_violated == 0)
class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase):
@@ -102,7 +112,7 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase):
dequantized_data = workspace.FetchBlob("dequantized_data")
np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data"))
- compare_rowwise(input_data, dequantized_data)
+ compare_rowwise(input_data, dequantized_data, fp16)
sum_reference = workspace.FetchBlob("sum_reference")
sum_quantized = workspace.FetchBlob("sum_quantized")
@@ -179,7 +189,7 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase):
dequantized_data = workspace.FetchBlob("dequantized_data")
np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data"))
- compare_rowwise(input_data, dequantized_data)
+ compare_rowwise(input_data, dequantized_data, fp16)
mean_reference = workspace.FetchBlob("mean_reference")
mean_quantized = workspace.FetchBlob("mean_quantized")