Commit 099d0de
fix: Dynamic max_answers for SquadProcessor (fixes IndexError when max_answers is less than the number of answers in the dataset) (#4817)

* #4320 implemented dynamic max_answers for SquadProcessor, fixed IndexError when max_answers is less than the number of answers in the dataset
* #4320 added two unit tests for dataset_from_dicts testing default and manual max_answers
* apply suggestions from code review
* simplify comment, fix mypy & pylint errors, fix old test
* adjust max_answers to each dataset individually

Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent 8fbfca9 commit 099d0de

2 files changed: +62, -11 lines

haystack/modeling/data_handler/processor.py

Lines changed: 25 additions & 9 deletions
@@ -35,7 +35,6 @@
 from haystack.modeling.data_handler.input_features import sample_to_features_text
 from haystack.utils.experiment_tracking import Tracker as tracker

-
 DOWNSTREAM_TASK_MAP = {
     "squad20": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/squad20.tar.gz",
     "covidqa": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/covidqa.tar.gz",
@@ -381,7 +380,7 @@ def __init__(
         doc_stride: int = 128,
         max_query_length: int = 64,
         proxies: Optional[dict] = None,
-        max_answers: int = 6,
+        max_answers: Optional[int] = None,
         **kwargs,
     ):
         """
@@ -403,7 +402,9 @@ def __init__(
         :param max_query_length: Maximum length of the question (in number of subword tokens)
         :param proxies: proxy configuration to allow downloads of remote datasets.
                         Format as in "requests" library: https://2.python-requests.org//en/latest/user/advanced/#proxies
-        :param max_answers: number of answers to be converted. QA dev or train sets can contain multi-way annotations, which are converted to arrays of max_answer length
+        :param max_answers: Number of answers to be converted. QA sets can contain multi-way annotations, which are converted to arrays of max_answer length.
+                            Adjusts to the maximum number of answers in each processed dataset if not set.
+                            Truncates or pads to max_answer length if set.
         :param kwargs: placeholder for passing generic parameters
         """
         self.ph_output_type = "per_token_squad"
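
For orientation, the two modes the new signature enables, shown as a minimal sketch; the model name and call pattern are borrowed from the updated tests further down:

from transformers import AutoTokenizer
from haystack.modeling.data_handler.processor import SquadProcessor

tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")

# New default: max_answers=None, so label arrays adapt to each processed dataset.
dynamic_processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None)

# Opt-in fixed size: labels are truncated or padded to exactly 6 answers,
# matching the previous hard-coded default.
fixed_processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=6)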
@@ -469,12 +470,19 @@ def dataset_from_dicts(
         # Split documents into smaller passages to fit max_seq_len
         baskets = self._split_docs_into_passages(baskets)

+        # Determine max_answers if not set
+        max_answers = (
+            self.max_answers
+            if self.max_answers is not None
+            else max(max(len(basket.raw["answers"]) for basket in baskets), 1)
+        )
+
         # Convert answers from string to token space, skip this step for inference
         if not return_baskets:
-            baskets = self._convert_answers(baskets)
+            baskets = self._convert_answers(baskets, max_answers)

         # Convert internal representation (nested baskets + samples with mixed types) to pytorch features (arrays of numbers)
-        baskets = self._passages_to_pytorch_features(baskets, return_baskets)
+        baskets = self._passages_to_pytorch_features(baskets, return_baskets, max_answers)

         # Convert features into pytorch dataset, this step also removes potential errors during preprocessing
         dataset, tensor_names, baskets = self._create_dataset(baskets)
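
The max(..., 1) guard in the hunk above keeps at least one label row even when every question in the batch is unanswerable. A standalone sketch of the inference, with plain dicts standing in for basket.raw:

# Stand-ins for SampleBasket.raw; only the "answers" field matters here.
basket_raws = [
    {"answers": [{"text": "Berlin", "answer_start": 0}]},
    {"answers": [{"text": "Paris", "answer_start": 4}, {"text": "paris", "answer_start": 4}]},
    {"answers": []},  # unanswerable question (SQuAD 2.0 style)
]

configured_max_answers = None  # i.e. the processor was built with max_answers=None
max_answers = (
    configured_max_answers
    if configured_max_answers is not None
    else max(max(len(raw["answers"]) for raw in basket_raws), 1)
)
print(max_answers)  # -> 2: the largest answer count in this batch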
@@ -607,7 +615,7 @@ def _split_docs_into_passages(self, baskets: List[SampleBasket]):
 
         return baskets

-    def _convert_answers(self, baskets: List[SampleBasket]):
+    def _convert_answers(self, baskets: List[SampleBasket], max_answers: int):
         """
         Converts answers that are pure strings into the token based representation with start and end token offset.
         Can handle multiple answers per question document pair as is common for development/test sets
@@ -617,14 +625,22 @@ def _convert_answers(self, baskets: List[SampleBasket]):
             for sample in basket.samples:  # type: ignore
                 # Dealing with potentially multiple answers (e.g. Squad dev set)
                 # Initializing a numpy array of shape (max_answers, 2), filled with -1 for missing values
-                label_idxs = np.full((self.max_answers, 2), fill_value=-1)
+                label_idxs = np.full((max_answers, 2), fill_value=-1)

                 if error_in_answer or (len(basket.raw["answers"]) == 0):
                     # If there are no answers we set
                     label_idxs[0, :] = 0
                 else:
                     # For all other cases we use start and end token indices, that are relative to the passage
                     for i, answer in enumerate(basket.raw["answers"]):
+                        if i >= max_answers:
+                            logger.warning(
+                                "Found a sample with more answers (%d) than "
+                                "max_answers (%d). These will be ignored.",
+                                len(basket.raw["answers"]),
+                                max_answers,
+                            )
+                            break
                         # Calculate start and end relative to document
                         answer_len_c = len(answer["text"])
                         answer_start_c = answer["answer_start"]
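
The warning-and-break added above is what replaces the old IndexError: surplus annotations are skipped once the label array is full. A toy illustration with made-up token spans:

import numpy as np

max_answers = 2
answer_spans = [(5, 7), (5, 8), (6, 7)]  # three annotations, but only two slots

label_idxs = np.full((max_answers, 2), fill_value=-1)
for i, span in enumerate(answer_spans):
    if i >= max_answers:
        # Previously label_idxs[i] would index past the array and raise
        # IndexError; now the extra annotation is dropped with a warning.
        break
    label_idxs[i] = span

print(label_idxs)  # [[5 7]
                   #  [5 8]]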
@@ -691,7 +707,7 @@ def _convert_answers(self, baskets: List[SampleBasket]):
 
         return baskets

-    def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_baskets: bool):
+    def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_baskets: bool, max_answers: int):
         """
         Convert internal representation (nested baskets + samples with mixed types) to pytorch features (arrays of numbers).
         We first join question and passages into one large vector.
@@ -769,7 +785,7 @@ def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_bask
                     len(input_ids) == len(padding_mask) == len(segment_ids) == len(start_of_word) == len(span_mask)
                 )
                 id_check = len(sample_id) == 3
-                label_check = return_baskets or len(sample.tokenized.get("labels", [])) == self.max_answers  # type: ignore
+                label_check = return_baskets or len(sample.tokenized.get("labels", [])) == max_answers  # type: ignore
                 # labels are set to -100 when answer cannot be found
                 label_check2 = return_baskets or np.all(sample.tokenized["labels"] > -99)  # type: ignore
                 if len_check and id_check and label_check and label_check2:

test/modeling/test_processor.py

Lines changed: 37 additions & 2 deletions
@@ -1,5 +1,7 @@
+import copy
 import logging

+import pytest
 from transformers import AutoTokenizer

 from haystack.modeling.data_handler.processor import SquadProcessor
@@ -233,7 +235,7 @@ def test_batch_encoding_flatten_rename():
     pass


-def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None):
+def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None):
     if caplog:
         caplog.set_level(logging.CRITICAL)

@@ -248,7 +250,7 @@ def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None):
 
     for model in models:
         tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model)
-        processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None)
+        processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=6)

         for sample_type in sample_types:
             dicts = processor.file_to_dicts(samples_path / "qa" / f"{sample_type}.json")
@@ -296,3 +298,36 @@ def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None):
                 12,
                 12,
             ], f"Processing labels for {model} has changed."
+
+
+@pytest.mark.integration
+def test_dataset_from_dicts_auto_determine_max_answers(samples_path, caplog=None):
+    """
+    SquadProcessor should determine the number of answers for the pytorch dataset based on
+    the maximum number of answers for each question. vanilla.json has one question with two answers,
+    so the number of answers should be two.
+    """
+    model = "deepset/roberta-base-squad2"
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model)
+    processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None)
+    dicts = processor.file_to_dicts(samples_path / "qa" / "vanilla.json")
+    dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts, indices=[1])
+    assert len(dataset[0][tensor_names.index("labels")]) == 2
+    # check that max_answers is adjusted when processing a different dataset with the same SquadProcessor
+    dicts_more_answers = copy.deepcopy(dicts)
+    dicts_more_answers[0]["qas"][0]["answers"] = dicts_more_answers[0]["qas"][0]["answers"] * 3
+    dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts_more_answers, indices=[1])
+    assert len(dataset[0][tensor_names.index("labels")]) == 6
+
+
+@pytest.mark.integration
+def test_dataset_from_dicts_truncate_max_answers(samples_path, caplog=None):
+    """
+    Test that it is possible to manually set the number of answers, truncating the answers in the data.
+    """
+    model = "deepset/roberta-base-squad2"
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model)
+    processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=1)
+    dicts = processor.file_to_dicts(samples_path / "qa" / "vanilla.json")
+    dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts, indices=[1])
+    assert len(dataset[0][tensor_names.index("labels")]) == 1
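
Assuming a source checkout with the QA sample fixtures in place, the two new tests can be run through pytest's integration marker; the -k filter is just a convenience:

pytest -m integration test/modeling/test_processor.py -k max_answers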
