We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c7a3720 · commit 3143650 (Copy full SHA for 3143650)
1 file changed
tests/ignite/metrics/nlp/test_perplexity.py
@@ -69,17 +69,16 @@ def test_token_weighted_accumulation(n_times, available_device):
69
70
b1_pred = torch.randn(2, 5, 4)
71
b1_y = torch.randint(0, 5, (2, 4))
72
- b2_pred = torch.randn(3, 5, 10)
73
- b2_y = torch.randint(0, 5, (3, 10))
+ b2_pred = torch.randn(3, 5, 4)
+ b2_y = torch.randint(0, 5, (3, 4))
74
75
ppl.reset()
76
ppl.update((b1_pred, b1_y))
77
ppl.update((b2_pred, b2_y))
78
79
- nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item()
80
- nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item()
81
- total_tokens = b1_y.numel() + b2_y.numel()
82
- ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item()
+ combined_pred = torch.cat([b1_pred, b2_pred], dim=0)
+ combined_y = torch.cat([b1_y, b2_y], dim=0)
+ ppl_ref = _reference_perplexity(combined_pred, combined_y)
83
84
assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref
85
0 commit comments