Skip to content

Commit 3143650

Browse files
committed
fix(tests): use _reference_perplexity and matching seq lengths in accumulation test
1 parent c7a3720 commit 3143650

1 file changed

Lines changed: 5 additions & 6 deletions

File tree

tests/ignite/metrics/nlp/test_perplexity.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,16 @@ def test_token_weighted_accumulation(n_times, available_device):
6969

7070
b1_pred = torch.randn(2, 5, 4)
7171
b1_y = torch.randint(0, 5, (2, 4))
72-
b2_pred = torch.randn(3, 5, 10)
73-
b2_y = torch.randint(0, 5, (3, 10))
72+
b2_pred = torch.randn(3, 5, 4)
73+
b2_y = torch.randint(0, 5, (3, 4))
7474

7575
ppl.reset()
7676
ppl.update((b1_pred, b1_y))
7777
ppl.update((b2_pred, b2_y))
7878

79-
nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum").item()
80-
nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum").item()
81-
total_tokens = b1_y.numel() + b2_y.numel()
82-
ppl_ref = torch.exp(torch.tensor((nll1 + nll2) / total_tokens)).item()
79+
combined_pred = torch.cat([b1_pred, b2_pred], dim=0)
80+
combined_y = torch.cat([b1_y, b2_y], dim=0)
81+
ppl_ref = _reference_perplexity(combined_pred, combined_y)
8382

8483
assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref
8584

0 commit comments

Comments
 (0)