Skip to content

Commit 98175d0

Browse files
committed
check if k >= num_classes
1 parent ab5f009 commit 98175d0

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

libmultilabel/linear/metrics.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,17 @@ def get_metrics(monitor_metrics: list[str], num_classes: int, multiclass: bool =
261261
if monitor_metrics is None:
262262
monitor_metrics = []
263263
metrics = {}
264+
264265
for metric in monitor_metrics:
266+
metric_at_k = re.match(r"(?:P|R|RP|NDCG)@(\d+)$", metric.upper())
267+
if metric_at_k:
268+
top_k = int(metric_at_k.groups()[0])
269+
print(f"top_k: {top_k}")
270+
if top_k >= num_classes:
271+
raise ValueError(
272+
f"Invalid metric: {metric}. top_k ({top_k}) is greater than num_classes({num_classes})."
273+
)
274+
265275
if re.match("P@\d+", metric):
266276
metrics[metric] = Precision(num_classes, average="samples", top_k=int(metric[2:]))
267277
elif re.match("R@\d+", metric):

0 commit comments

Comments
 (0)