Skip to content

Commit 401ed85

Browse files
committed
Refactor type hints from List and Tuple to list and tuple for consistency
1 parent ea7325a commit 401ed85

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

mmlearn/modules/metrics/retrieval_recall.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import concurrent.futures
44
import os
55
from functools import partial
6-
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, Literal, Optional, Union
77

88
import torch
99
import torch.distributed
@@ -56,9 +56,9 @@ class RetrievalRecallAtK(Metric):
5656
higher_is_better: bool = True
5757
full_state_update: bool = False
5858

59-
indexes: List[torch.Tensor]
60-
x: List[torch.Tensor]
61-
y: List[torch.Tensor]
59+
indexes: list[torch.Tensor]
60+
x: list[torch.Tensor]
61+
y: list[torch.Tensor]
6262
num_samples: torch.Tensor
6363

6464
def __init__(
@@ -272,9 +272,9 @@ def _recall_at_k(
272272
Parameters
273273
----------
274274
scores : torch.Tensor
275-
Compatibility score between embeddings (num_x, num_y).
275+
Compatibility score between embeddings ``(num_x, num_y)``.
276276
positive_pairs : torch.Tensor
277-
Boolean matrix of positive pairs (num_x, num_y).
277+
Boolean matrix of positive pairs ``(num_x, num_y)``.
278278
k : int
279279
Consider only the top k elements for each query.
280280
@@ -293,7 +293,7 @@ def _update_batch_inputs(
293293
x: torch.Tensor,
294294
y: torch.Tensor,
295295
indexes: torch.Tensor,
296-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
296+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
297297
"""Update and returns variables required to compute Retrieval Recall.
298298
299299
Checks for same shape of input tensors.
@@ -309,7 +309,7 @@ def _update_batch_inputs(
309309
310310
Returns
311311
-------
312-
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
312+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
313313
Returns updated tensors required to compute Retrieval Recall.
314314
315315
"""

0 commit comments

Comments
 (0)