summaryrefslogtreecommitdiff
path: root/test/test_nn.py
diff options
context:
space:
mode:
authorWei Yang <38509346+weiyangfb@users.noreply.github.com>2018-06-07 17:17:18 -0700
committerGitHub <noreply@github.com>2018-06-07 17:17:18 -0700
commit4c2a1a1a64db08f24219ce5507ae93c0de926173 (patch)
treef4abf17d2ce5cb7983e25d301b5dcf5374fd52cf /test/test_nn.py
parentce122cc2d34135f6cf7fa2fc02335346722118c1 (diff)
downloadpytorch-4c2a1a1a64db08f24219ce5507ae93c0de926173.tar.gz
pytorch-4c2a1a1a64db08f24219ce5507ae93c0de926173.tar.bz2
pytorch-4c2a1a1a64db08f24219ce5507ae93c0de926173.zip
Added backward function for kl_div target (#7839)
* added backward fn for target * added module test for kl_div target, and assuming targets are probabilities
Diffstat (limited to 'test/test_nn.py')
-rw-r--r--test/test_nn.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/test/test_nn.py b/test/test_nn.py
index 7cbecbcc27..bea6820aba 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5806,6 +5806,18 @@ def bce_with_logistic_no_reduce_scalar_test():
pickle=False)
+def kldivloss_with_target_no_reduce_test():
+ i = torch.rand(10, 10).log()
+ return dict(
+ fullname='KLDivLoss_with_target_no_reduce',
+ constructor=wrap_functional(
+ lambda t: F.kl_div(i.type_as(t), t, reduce=False)),
+ input_fn=lambda: torch.rand(10, 10),
+ reference_fn=lambda t, _:
+ loss_reference_fns['KLDivLoss'](i.type_as(t), t, reduce=False),
+ pickle=False)
+
+
def kldivloss_no_reduce_test():
t = torch.randn(10, 10)
return dict(
@@ -6258,6 +6270,7 @@ new_module_tests = [
bceloss_no_reduce_scalar_test(),
bceloss_weights_no_reduce_scalar_test(),
bce_with_logistic_no_reduce_scalar_test(),
+ kldivloss_with_target_no_reduce_test(),
kldivloss_no_reduce_test(),
kldivloss_no_reduce_scalar_test(),
l1loss_no_reduce_test(),