Skip to content

Commit dc85d0f

Browse files
Fix/hitrate empty topk (#3719)
Fixes #3718 Description: - Added validation for empty `top_k` list raising `ValueError` instead of crashing with `IndexError` in `update()` - Added support for `top_k: list[int] | int` so users can pass a single int without wrapping in a list - Added type validation raising `ValueError` if `top_k` is neither int nor list - Fixed leading space in existing error message Check list: - [x] New tests are added (if a new feature is added) - [x] New doc strings: description and/or example code are in RST format - [x] Documentation is updated (if required) --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent fd1efbe commit dc85d0f

2 files changed

Lines changed: 53 additions & 3 deletions

File tree

ignite/metrics/rec_sys/hitrate.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ class HitRate(Metric):
2828
- returns a list of HitRate ordered by the sorted values of ``top_k``.
2929
3030
Args:
31-
top_k: a list of sorted positive integers that specifies `k` for calculating hitrate@top-k.
31+
top_k: a single positive integer or a list of positive integers that specifies `k` for
32+
calculating hitrate@top-k. If a single int is provided, it will be wrapped in a list.
33+
Default is 1.
3234
ignore_zero_hits: if True, users with no relevant items (ground truth tensor being all zeros)
3335
are ignored in computation of HitRate. if set False, such users are counted as a miss.
3436
By default, True.
@@ -97,22 +99,50 @@ class HitRate(Metric):
9799
98100
[0.0, 0.5, 0.5, 0.5]
99101
102+
int top_k case
103+
104+
.. testcode:: 3
105+
106+
metric = HitRate(top_k=2)
107+
metric.attach(default_evaluator, "hit_rate")
108+
y_pred = torch.Tensor([
109+
[4.0, 2.0, 3.0, 1.0],
110+
])
111+
y_true = torch.Tensor([
112+
[0.0, 0.0, 1.0, 0.0],
113+
])
114+
state = default_evaluator.run([(y_pred, y_true)])
115+
print(state.metrics["hit_rate"])
116+
117+
.. testoutput:: 3
118+
119+
[1.0]
120+
100121
.. versionadded:: 0.5.4
122+
.. versionchanged:: 0.5.4
123+
`top_k` now accepts a single positive integer in addition to a list of integers.
101124
"""
102125

103126
required_output_keys = ("y_pred", "y")
104127
_state_dict_all_req_keys = ("_hits_per_k", "_num_examples")
105128

106129
def __init__(
107130
self,
108-
top_k: list[int],
131+
top_k: list[int] | int = 1,
109132
ignore_zero_hits: bool = True,
110133
output_transform: Callable = lambda x: x,
111134
device: str | torch.device = torch.device("cpu"),
112135
skip_unrolling: bool = False,
113136
):
137+
if not isinstance(top_k, (int, list)):
138+
raise ValueError("top_k must be either int or a list[int]")
139+
140+
top_k = [top_k] if isinstance(top_k, int) else top_k
141+
142+
if len(top_k) == 0:
143+
raise ValueError("top_k must have at least one positive value")
114144
if any(k <= 0 for k in top_k):
115-
raise ValueError(" top_k must be list of positive integers only.")
145+
raise ValueError("top_k must be list of positive integers only.")
116146

117147
self.top_k = sorted(top_k)
118148
self.ignore_zero_hits = ignore_zero_hits

tests/ignite/metrics/rec_sys/test_hitrate_metric.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,26 @@ def test_shape_mismatch():
5454
metric.update((y_pred, y))
5555

5656

57+
def test_empty_top_k():
58+
with pytest.raises(ValueError, match="top_k must have at least one positive value"):
59+
HitRate(top_k=[])
60+
61+
62+
def test_invalid_top_k_type():
63+
with pytest.raises(ValueError, match="top_k must be either int or a list"):
64+
HitRate(top_k="invalid")
65+
66+
67+
def test_int_top_k(available_device):
68+
metric = HitRate(top_k=2, device=available_device)
69+
y_pred = torch.tensor([[4.0, 2.0, 3.0, 1.0]])
70+
y_true = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
71+
metric.update((y_pred, y_true))
72+
res = metric.compute()
73+
expected = manual_hit_rate(y_pred.numpy(), y_true.numpy(), [2])
74+
np.testing.assert_allclose(res, expected)
75+
76+
5777
@pytest.mark.parametrize("top_k", [[1], [1, 2, 4]])
5878
@pytest.mark.parametrize("ignore_zero_hits", [True, False])
5979
def test_compute(top_k, ignore_zero_hits, available_device):

0 commit comments

Comments
 (0)