diff options
author | Wei Yang <38509346+weiyangfb@users.noreply.github.com> | 2018-06-07 17:17:18 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-07 17:17:18 -0700 |
commit | 4c2a1a1a64db08f24219ce5507ae93c0de926173 (patch) | |
tree | f4abf17d2ce5cb7983e25d301b5dcf5374fd52cf /test/test_nn.py | |
parent | ce122cc2d34135f6cf7fa2fc02335346722118c1 (diff) | |
download | pytorch-4c2a1a1a64db08f24219ce5507ae93c0de926173.tar.gz pytorch-4c2a1a1a64db08f24219ce5507ae93c0de926173.tar.bz2 pytorch-4c2a1a1a64db08f24219ce5507ae93c0de926173.zip |
Added backward function for kl_div target (#7839)
* added backward fn for target
* added module test for kl_div target, and assuming targets are probabilities
Diffstat (limited to 'test/test_nn.py')
-rw-r--r-- | test/test_nn.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py index 7cbecbcc27..bea6820aba 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5806,6 +5806,18 @@ def bce_with_logistic_no_reduce_scalar_test(): pickle=False) +def kldivloss_with_target_no_reduce_test(): + i = torch.rand(10, 10).log() + return dict( + fullname='KLDivLoss_with_target_no_reduce', + constructor=wrap_functional( + lambda t: F.kl_div(i.type_as(t), t, reduce=False)), + input_fn=lambda: torch.rand(10, 10), + reference_fn=lambda t, _: + loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduce=False), + pickle=False) + + def kldivloss_no_reduce_test(): t = torch.randn(10, 10) return dict( @@ -6258,6 +6270,7 @@ new_module_tests = [ bceloss_no_reduce_scalar_test(), bceloss_weights_no_reduce_scalar_test(), bce_with_logistic_no_reduce_scalar_test(), + kldivloss_with_target_no_reduce_test(), kldivloss_no_reduce_test(), kldivloss_no_reduce_scalar_test(), l1loss_no_reduce_test(), |