Skip to content

Commit a84a35b

Browse files
committed
check if k >= num_classes
1 parent ab5f009 commit a84a35b

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

libmultilabel/linear/metrics.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,15 @@ def get_metrics(monitor_metrics: list[str], num_classes: int, multiclass: bool =
262262
monitor_metrics = []
263263
metrics = {}
264264
for metric in monitor_metrics:
265+
metric_at_k = re.match(r"(?:P|R|RP|NDCG)@(\d+)$", metric.upper())
266+
if metric_at_k:
267+
top_k = int(metric_at_k.groups()[0])
268+
print(f"top_k: {top_k}")
269+
if top_k >= num_classes:
270+
raise ValueError(
271+
f"Invalid metric: {metric}. top_k ({top_k}) is greater than num_classes({num_classes})."
272+
)
273+
265274
if re.match("P@\d+", metric):
266275
metrics[metric] = Precision(num_classes, average="samples", top_k=int(metric[2:]))
267276
elif re.match("R@\d+", metric):

0 commit comments

Comments
 (0)