diff --git a/libmultilabel/linear/metrics.py b/libmultilabel/linear/metrics.py index 8fbe6c61..8c1d98b9 100644 --- a/libmultilabel/linear/metrics.py +++ b/libmultilabel/linear/metrics.py @@ -262,6 +262,15 @@ def get_metrics(monitor_metrics: list[str], num_classes: int, multiclass: bool = monitor_metrics = [] metrics = {} for metric in monitor_metrics: + metric_at_k = re.match(r"(?:P|R|RP|NDCG)@(\d+)$", metric.upper()) + if metric_at_k: + top_k = int(metric_at_k.groups()[0]) + + if top_k >= num_classes: + raise ValueError( + f"Invalid metric: {metric}. top_k ({top_k}) is greater than num_classes({num_classes})." + ) + if re.match("P@\d+", metric): metrics[metric] = Precision(num_classes, average="samples", top_k=int(metric[2:])) elif re.match("R@\d+", metric):