33import concurrent .futures
44import os
55from functools import partial
6- from typing import Any , Callable , Dict , List , Literal , Optional , Tuple , Union
6+ from typing import Any , Callable , Dict , Literal , Optional , Union
77
88import torch
99import torch .distributed
@@ -56,9 +56,9 @@ class RetrievalRecallAtK(Metric):
5656 higher_is_better : bool = True
5757 full_state_update : bool = False
5858
59- indexes : List [torch .Tensor ]
60- x : List [torch .Tensor ]
61- y : List [torch .Tensor ]
59+ indexes : list [torch .Tensor ]
60+ x : list [torch .Tensor ]
61+ y : list [torch .Tensor ]
6262 num_samples : torch .Tensor
6363
6464 def __init__ (
@@ -272,9 +272,9 @@ def _recall_at_k(
272272 Parameters
273273 ----------
274274 scores : torch.Tensor
275- Compatibility score between embeddings (num_x, num_y).
275+ Compatibility score between embeddings `` (num_x, num_y)`` .
276276 positive_pairs : torch.Tensor
277- Boolean matrix of positive pairs (num_x, num_y).
277+ Boolean matrix of positive pairs `` (num_x, num_y)`` .
278278 k : int
279279 Consider only the top k elements for each query.
280280
@@ -293,7 +293,7 @@ def _update_batch_inputs(
293293 x : torch .Tensor ,
294294 y : torch .Tensor ,
295295 indexes : torch .Tensor ,
296- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
296+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
297297 """Update and returns variables required to compute Retrieval Recall.
298298
299299 Checks for same shape of input tensors.
@@ -309,7 +309,7 @@ def _update_batch_inputs(
309309
310310 Returns
311311 -------
312- Tuple [torch.Tensor, torch.Tensor, torch.Tensor]
312+ tuple [torch.Tensor, torch.Tensor, torch.Tensor]
313313 Returns updated tensors required to compute Retrieval Recall.
314314
315315 """
0 commit comments