diff options
author | Hector Yuen <hyz@fb.com> | 2019-02-20 13:07:08 -0800 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-20 13:14:11 -0800 |
commit | 075c7b1fef537205a0bdb8e45d2f800e6c024603 (patch) | |
tree | 4db84a082769257dad146d2b6174581b7c23e62e /caffe2 | |
parent | db1d61a5c3f80a5b05ec861ea59b638130355a48 (diff) | |
download | pytorch-075c7b1fef537205a0bdb8e45d2f800e6c024603.tar.gz pytorch-075c7b1fef537205a0bdb8e45d2f800e6c024603.tar.bz2 pytorch-075c7b1fef537205a0bdb8e45d2f800e6c024603.zip |
make the threshold for acurracy more precise (#17194)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17194
we found that there is a per row absolute error due to int8 quant
and a relative error table-wide in case fp16 is used
Reviewed By: csummersea
Differential Revision: D14113353
fbshipit-source-id: c7065aa9d15c453c2e5609f421ad0155145af889
Diffstat (limited to 'caffe2')
-rw-r--r-- | caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py | 36 |
1 files changed, 23 insertions, 13 deletions
diff --git a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py index 8f36dc02b7..8835092366 100644 --- a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py +++ b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py @@ -7,18 +7,28 @@ from caffe2.python import core, workspace from hypothesis import given -def compare_rowwise(emb_orig, emb_reconstructed): +def compare_rowwise(emb_orig, emb_reconstructed, fp16): + # there is an absolute error introduced per row through int8 quantization + # and a relative error introduced when quantizing back from fp32 to fp16 assert(emb_orig.shape == emb_reconstructed.shape) - range = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1) - # TOOO: figure out the right threshold, this has to do with the - # fact that the data types are float16, in float32, it should be /1.9 - threshold = range / 255.0 / 1.5 - diff = np.amax(np.abs(emb_orig - emb_reconstructed), axis=1) - n_violated = ((threshold - diff) < 0).sum() - if n_violated > 0: - print(n_violated, threshold, diff, threshold < diff, emb_orig, - emb_reconstructed, emb_orig - emb_reconstructed) - assert(n_violated == 0) + rtol = 1e-8 + if fp16: + rtol = 1e-3 + erange = np.amax(emb_orig, axis=1) - np.amin(emb_orig, axis=1) + + threshold = erange / 255.0 / 1.9 + + for i in range(emb_orig.shape[0]): + r_orig = emb_orig[i, :] + r_reconstructed = emb_reconstructed[i, :] + + isclose = np.isclose(r_orig, r_reconstructed, atol=threshold[i], rtol=rtol) + n_violated = isclose.size - isclose.sum() + + if n_violated > 0: + print(isclose, threshold[i]) + print(i, r_orig, r_reconstructed, threshold[i], r_orig - r_reconstructed) + assert(n_violated == 0) class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): @@ -102,7 +112,7 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): dequantized_data = workspace.FetchBlob("dequantized_data") np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data")) - compare_rowwise(input_data, dequantized_data) + compare_rowwise(input_data, dequantized_data, fp16) sum_reference = workspace.FetchBlob("sum_reference") sum_quantized = workspace.FetchBlob("sum_quantized") @@ -179,7 +189,7 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): dequantized_data = workspace.FetchBlob("dequantized_data") np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data")) - compare_rowwise(input_data, dequantized_data) + compare_rowwise(input_data, dequantized_data, fp16) mean_reference = workspace.FetchBlob("mean_reference") mean_quantized = workspace.FetchBlob("mean_quantized") |