# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import math
from dataclasses import replace
from pathlib import Path
from typing import Any

from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Device, DeviceMap, Secret, deserialize_secrets_inplace
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs

with LazyImport("Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import:
    import accelerate  # noqa: F401 # the library is used but not directly referenced
    import torch
    from tokenizers import Encoding
    from transformers import AutoModelForQuestionAnswering, AutoTokenizer


logger = logging.getLogger(__name__)


@component
class ExtractiveReader:
    """
    Locates and extracts answers to a given query from Documents.

    The ExtractiveReader component performs extractive question answering.
    It assigns a score to every possible answer span independently of other answer spans.
    This avoids a common issue of other implementations, which make comparisons across documents harder by
    normalizing each document's answers independently.

    Example usage:
    ```python
    from haystack import Document
    from haystack.components.readers import ExtractiveReader

    docs = [
        Document(content="Python is a popular programming language"),
        Document(content="python ist eine beliebte Programmiersprache"),
    ]

    reader = ExtractiveReader()

    question = "What is a popular programming language?"
    result = reader.run(query=question, documents=docs)
    assert "Python" in result["answers"][0].data
    ```
    """

    def __init__(
        self,
        model: Path | str = "deepset/roberta-base-squad2-distilled",
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        top_k: int = 20,
        score_threshold: float | None = None,
        max_seq_length: int = 384,
        stride: int = 128,
        max_batch_size: int | None = None,
        answers_per_seq: int | None = None,
        no_answer: bool = True,
        calibration_factor: float = 0.1,
        overlap_threshold: float | None = 0.01,
        model_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Creates an instance of ExtractiveReader.

        :param model:
            A Hugging Face transformers question answering model.
            Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically selected.
        :param token:
            The API token used to download private models from Hugging Face.
        :param top_k:
            Number of answers to return per query. It is applied even if score_threshold is set.
            An additional answer with no text is returned if no_answer is set to True (default).
        :param score_threshold:
            Returns only answers with the probability score above this threshold.
        :param max_seq_length:
            Maximum number of tokens. If a sequence exceeds it, the sequence is split.
        :param stride:
            Number of tokens that overlap when a sequence is split because it exceeds max_seq_length.
        :param max_batch_size:
            Maximum number of samples that are fed through the model at the same time.
        :param answers_per_seq:
            Number of answer candidates to consider per sequence.
            This is relevant when a Document was split into multiple sequences because of max_seq_length.
        :param no_answer:
            Whether to return an additional `no answer` with an empty text and a score representing the
            probability that the other top_k answers are incorrect.
        :param calibration_factor:
            Factor used for calibrating probabilities.
        :param overlap_threshold:
            If set, removes duplicate answers whose spans overlap by more than this fraction.
            For example, for the answers "in the river in Maine" and "the river", one of them is removed
            since the second answer overlaps 100% (1.0) with the first.
            For the answers "the river in" and "in Maine", the maximum overlap is only 25%, so
            both answers are kept if this parameter is set to 0.24 or lower.
            If None is provided, all answers are kept.
        :param model_kwargs:
            Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
            when loading the model specified in `model`. For details on what kwargs you can pass,
            see the model's documentation.
        """
        torch_and_transformers_import.check()
        self.model_name_or_path = str(model)
        self.model = None
        self.tokenizer: Any = None
        self.device: ComponentDevice | None = None
        self.token = token
        self.max_seq_length = max_seq_length
        self.top_k = top_k
        self.score_threshold = score_threshold
        self.stride = stride
        self.max_batch_size = max_batch_size
        self.answers_per_seq = answers_per_seq
        self.no_answer = no_answer
        self.calibration_factor = calibration_factor
        self.overlap_threshold = overlap_threshold

        model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
        self.model_kwargs = model_kwargs

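    # A minimal construction sketch (the values below are illustrative assumptions,
    # not recommended settings):
    #
    #     reader = ExtractiveReader(
    #         model="deepset/roberta-base-squad2-distilled",
    #         top_k=5,  # return at most 5 extracted answers per query
    #         score_threshold=0.8,  # keep only answers scoring at least 0.8
    #         no_answer=False,  # skip the synthetic "no answer" entry
    #         overlap_threshold=0.5,  # drop spans overlapping an earlier span by more than 50%
    #     )
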
    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model_name_or_path}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            model=self.model_name_or_path,
            device=None,
            token=self.token.to_dict() if self.token else None,
            max_seq_length=self.max_seq_length,
            top_k=self.top_k,
            score_threshold=self.score_threshold,
            stride=self.stride,
            max_batch_size=self.max_batch_size,
            answers_per_seq=self.answers_per_seq,
            no_answer=self.no_answer,
            calibration_factor=self.calibration_factor,
            model_kwargs=self.model_kwargs,
        )

        serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ExtractiveReader":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        deserialize_secrets_inplace(init_params, keys=["token"])
        if init_params.get("model_kwargs") is not None:
            deserialize_hf_model_kwargs(init_params["model_kwargs"])

        return default_from_dict(cls, data)

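    # A serialization round-trip sketch (hypothetical snippet, assuming default
    # construction): `to_dict` produces a plain dictionary that `from_dict` can restore.
    #
    #     reader = ExtractiveReader()
    #     data = reader.to_dict()
    #     restored = ExtractiveReader.from_dict(data)
    #     assert restored.model_name_or_path == reader.model_name_or_path
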
    def warm_up(self) -> None:
        """
        Initializes the component by loading the model and tokenizer.
        """
        if self.model is None:
            self.model = AutoModelForQuestionAnswering.from_pretrained(
                self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name_or_path, token=self.token.resolve_value() if self.token else None
            )
            assert self.model is not None  # mypy doesn't know this is set above
            # hf_device_map appears to only be set when mixed devices are actually used,
            # so if it's missing we can use the device attribute, which is set even for single-device models.
            if hf_device_map := getattr(self.model, "hf_device_map", None):
                self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(hf_device_map))
            else:
                self.device = ComponentDevice.from_single(Device.from_str(str(self.model.device)))

    @staticmethod
    def _flatten_documents(
        queries: list[str], documents: list[list[Document]]
    ) -> tuple[list[str], list[Document], list[int]]:
        """
        Flattens queries and Documents so all query-document pairs are arranged along one batch axis.
        """
        flattened_queries = [query for documents_, query in zip(documents, queries, strict=True) for _ in documents_]
        flattened_documents = [document for documents_ in documents for document in documents_]
        query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
        return flattened_queries, flattened_documents, query_ids

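    # A flattening sketch (hypothetical inputs; doc_a, doc_b, doc_c stand for Document
    # instances). Two queries with two and one Documents respectively yield:
    #
    #     _flatten_documents(["q0", "q1"], [[doc_a, doc_b], [doc_c]])
    #     # -> (["q0", "q0", "q1"], [doc_a, doc_b, doc_c], [0, 0, 1])
    #
    # The returned query IDs record which query each flattened Document belongs to.
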
    def _preprocess(
        self, *, queries: list[str], documents: list[Document], max_seq_length: int, query_ids: list[int], stride: int
    ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["Encoding"], list[int], list[int]]:
        """
        Splits and tokenizes Documents and preserves the batch structure by returning mappings to query and
        Document IDs.
        """
        # Filter out Documents without content, keeping queries and query IDs aligned with the
        # remaining Documents so the tokenizer pairs each query with the right Document.
        filtered_queries = []
        filtered_query_ids = []
        document_ids = []
        document_contents = []
        for i, doc in enumerate(documents):
            if doc.content is None:
                logger.warning(
                    "Document with id {doc_id} was passed to ExtractiveReader. The Document doesn't "
                    "contain any text and it will be ignored.",
                    doc_id=doc.id,
                )
                continue
            filtered_queries.append(queries[i])
            filtered_query_ids.append(query_ids[i])
            document_ids.append(i)
            document_contents.append(doc.content)

        # mypy doesn't know this is set in warm_up
        encodings_pt = self.tokenizer(
            filtered_queries,
            document_contents,
            padding=True,
            truncation=True,
            max_length=max_seq_length,
            return_tensors="pt",
            return_overflowing_tokens=True,
            stride=stride,
        )

        # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device.
        # mypy doesn't know this is set in warm_up
        first_device = self.device.first_device.to_torch()  # type: ignore[union-attr]

        input_ids = encodings_pt.input_ids.to(first_device)
        attention_mask = encodings_pt.attention_mask.to(first_device)

        query_ids = [filtered_query_ids[index] for index in encodings_pt.overflow_to_sample_mapping]
        document_ids = [document_ids[sample_id] for sample_id in encodings_pt.overflow_to_sample_mapping]

        encodings = encodings_pt.encodings
        sequence_ids = torch.tensor(
            [[id_ if id_ is not None else -1 for id_ in encoding.sequence_ids] for encoding in encodings]
        ).to(first_device)

        return input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids

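    # A splitting sketch (illustrative numbers): with max_seq_length=384 and stride=128,
    # a query-document pair that tokenizes to 600 tokens is split into two overlapping
    # sequences, and overflow_to_sample_mapping == [0, 0] maps both back to pair 0, so
    # the corresponding query and Document IDs are repeated accordingly.
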
    def _postprocess(
        self,
        *,
        start: "torch.Tensor",
        end: "torch.Tensor",
        sequence_ids: "torch.Tensor",
        attention_mask: "torch.Tensor",
        answers_per_seq: int,
        encodings: list["Encoding"],
    ) -> tuple[list[list[int]], list[list[int]], "torch.Tensor"]:
        """
        Turns start and end logits into probabilities for each answer span.

        Unlike most other implementations, it doesn't normalize the scores within each split, which makes the
        scores easier to compare across different splits. Returns the top k answer spans.
        """
        mask = sequence_ids == 1  # Only keep tokens from the context (should ignore special tokens)
        mask = torch.logical_and(mask, attention_mask == 1)  # Definitely remove special tokens
        start = torch.where(mask, start, -torch.inf)  # Apply the mask on the start logits
        end = torch.where(mask, end, -torch.inf)  # Apply the mask on the end logits
        start = start.unsqueeze(-1)
        end = end.unsqueeze(-2)

        logits = start + end  # shape: (batch_size, seq_length (start), seq_length (end))

        # The mask from here onwards is the same for all instances in the batch,
        # so we do away with the batch dimension
        mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=logits.device)
        mask = torch.triu(mask)  # End shouldn't be before start
        masked_logits = torch.where(mask, logits, -torch.inf)
        probabilities = torch.sigmoid(masked_logits * self.calibration_factor)

        flat_probabilities = probabilities.flatten(-2, -1)  # necessary for top-k

        # top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates.
        # We only keep candidates with probability > 0 later on.
        candidates = torch.topk(flat_probabilities, answers_per_seq)
        seq_length = logits.shape[-1]
        start_candidates = candidates.indices // seq_length  # Recover indices from flattening
        end_candidates = candidates.indices % seq_length
        candidates_values = candidates.values.cpu()
        start_candidates = start_candidates.cpu()
        end_candidates = end_candidates.cpu()

        start_candidates_tokens_to_chars = []
        end_candidates_tokens_to_chars = []
        for i, (s_candidates, e_candidates, encoding) in enumerate(
            zip(start_candidates, end_candidates, encodings, strict=True)
        ):
            # Candidates with probability > 0 are valid
            valid = candidates_values[i] > 0
            s_char_spans = []
            e_char_spans = []
            for start_token, end_token in zip(s_candidates[valid], e_candidates[valid], strict=True):
                # token_to_chars returns `None` for special tokens,
                # but we shouldn't have special tokens in the answers at this point.
                # The whole span is given by the start of the start_token (index 0)
                # and the end of the end_token (index 1).
                s_char_spans.append(encoding.token_to_chars(start_token)[0])
                e_char_spans.append(encoding.token_to_chars(end_token)[1])
            start_candidates_tokens_to_chars.append(s_char_spans)
            end_candidates_tokens_to_chars.append(e_char_spans)

        return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values

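    # An index-recovery sketch (illustrative numbers): with seq_length=384, the
    # flattened top-k index 770 decodes to start token 770 // 384 == 2 and end token
    # 770 % 384 == 2, i.e. the single-token span starting and ending at token 2.
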
    def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:
        if answer.meta is None:
            answer.meta = {}

        if answer.document_offset is None:
            return answer

        if not answer.document or "page_number" not in answer.document.meta:
            return answer

        if not isinstance(answer.document.meta["page_number"], int):
            logger.warning(
                "Document's page_number must be int but is {type}. No page number will be added to the answer.",
                type=type(answer.document.meta["page_number"]),
            )
            return answer

        # Calculate the answer page number
        if answer.document.content:
            ans_start = answer.document_offset.start
            answer_page_number = answer.document.meta["page_number"] + answer.document.content[:ans_start].count("\f")
            answer.meta.update({"answer_page_number": answer_page_number})

        return answer

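    # A page-number sketch (hypothetical Document): if a Document has meta
    # {"page_number": 3} and its content contains two form feeds ("\f") before the
    # answer's start offset, the answer is annotated with answer_page_number == 5.
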
    def _nest_answers(
        self,
        *,
        start: list[list[int]],
        end: list[list[int]],
        probabilities: "torch.Tensor",
        flattened_documents: list[Document],
        queries: list[str],
        answers_per_seq: int,
        top_k: int | None,
        score_threshold: float | None,
        query_ids: list[int],
        document_ids: list[int],
        no_answer: bool,
        overlap_threshold: float | None,
    ) -> list[list[ExtractedAnswer]]:
        """
        Reconstructs the nested structure that existed before flattening.

        Also computes a no answer score. This score differs from most other implementations because it does not
        consider the no answer logit introduced with SQuAD 2. Instead, it just computes the probability that the
        answer does not exist in the top k or top p.
        """
        answers_without_query = []
        for document_id, start_candidates_, end_candidates_, probabilities_ in zip(
            document_ids, start, end, probabilities, strict=True
        ):
            # start_candidates_ and end_candidates_ only contain valid (probability > 0) spans,
            # so they can be shorter than probabilities_; zip truncates to the valid candidates.
            for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_, strict=False):
                doc = flattened_documents[document_id]
                answers_without_query.append(
                    ExtractedAnswer(
                        query="",  # Can't be None, but we'll add the real query later
                        data=doc.content[start_:end_],  # type: ignore
                        document=doc,
                        score=probability.item(),
                        document_offset=ExtractedAnswer.Span(start_, end_),
                        meta={},
                    )
                )
        i = 0
        nested_answers = []
        for query_id in range(query_ids[-1] + 1):
            current_answers = []
            while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id:
                current_answers.append(replace(answers_without_query[i], query=queries[query_id]))
                i += 1
            current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
            current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold)
            current_answers = current_answers[:top_k]

            # Calculate the answer page number and add it to meta
            current_answers = [self._add_answer_page_number(answer=answer) for answer in current_answers]

            if no_answer:
                no_answer_score = math.prod(1 - answer.score for answer in current_answers)
                answer_ = ExtractedAnswer(
                    data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
                )
                current_answers.append(answer_)
            current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
            if score_threshold is not None:
                current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
            nested_answers.append(current_answers)

        return nested_answers

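    # A no-answer sketch (illustrative scores): if the top answers for a query score
    # 0.9 and 0.8, the no-answer score is (1 - 0.9) * (1 - 0.8) == 0.02, i.e. the
    # probability that every extracted answer is incorrect.
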
    def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int:
        """
        Calculates the amount of overlap (in number of characters) between two answer offsets.

        This Stack Overflow
        [post](https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964)
        explains how to calculate the overlap between two ranges.
        """
        # Check for overlap: (StartA <= EndB) and (StartB <= EndA)
        if answer1_start <= answer2_end and answer2_start <= answer1_end:
            return min(
                answer1_end - answer1_start,
                answer1_end - answer2_start,
                answer2_end - answer1_start,
                answer2_end - answer2_start,
            )
        return 0

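    # An overlap sketch (hypothetical offsets): for the spans [0, 12) and [4, 9),
    # _calculate_overlap(0, 12, 4, 9) returns min(12, 8, 9, 5) == 5 characters,
    # because the second span lies entirely inside the first.
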
    def _should_keep(
        self, candidate_answer: ExtractedAnswer, current_answers: list[ExtractedAnswer], overlap_threshold: float
    ) -> bool:
        """
        Determines if the answer should be kept based on how much it overlaps with previous answers.

        NOTE: We might want to avoid throwing away answers that overlap by only a few characters (or words):
            - E.g. the answers "the river in" and "in Maine" from the context "I want to go to the river in Maine."
            might both be worth keeping.

        :param candidate_answer:
            Candidate answer that will be checked if it should be kept.
        :param current_answers:
            Current list of answers that will be kept.
        :param overlap_threshold:
            If the overlap between two answers is greater than this threshold then return False.
        """
        keep = True

        # If the candidate answer doesn't have a document, keep it
        if not candidate_answer.document:
            return keep

        for ans in current_answers:
            # If an answer in current_answers doesn't have a document, skip the comparison
            if not ans.document:
                continue

            # If the existing answer's offset is missing, keep both
            if ans.document_offset is None:
                continue

            # If the candidate's offset is missing, keep both
            if candidate_answer.document_offset is None:
                continue

            # If the answers come from different documents, keep both
            if candidate_answer.document.id != ans.document.id:
                continue

            overlap_len = self._calculate_overlap(
                answer1_start=ans.document_offset.start,
                answer1_end=ans.document_offset.end,
                answer2_start=candidate_answer.document_offset.start,
                answer2_end=candidate_answer.document_offset.end,
            )

            # If there is no overlap, keep both
            if overlap_len == 0:
                continue

            overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start)
            overlap_frac_answer2 = overlap_len / (
                candidate_answer.document_offset.end - candidate_answer.document_offset.start
            )

            if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold:
                keep = False
                break

        return keep

    def deduplicate_by_overlap(
        self, answers: list[ExtractedAnswer], overlap_threshold: float | None
    ) -> list[ExtractedAnswer]:
        """
        De-duplicates overlapping Extractive Answers.

        De-duplicates overlapping Extractive Answers from the same document based on how much the spans of the
        answers overlap.

        :param answers:
            List of answers to be deduplicated.
        :param overlap_threshold:
            If set, removes duplicate answers whose spans overlap by more than this fraction.
            For example, for the answers "in the river in Maine" and "the river", one of them is removed
            since the second answer overlaps 100% (1.0) with the first.
            For the answers "the river in" and "in Maine", the maximum overlap is only 25%, so
            both answers are kept if this parameter is set to 0.24 or lower.
            If None is provided, all answers are kept.
        :returns:
            List of deduplicated answers.
        """
        if overlap_threshold is None or not answers:
            return answers

        # Initialize with the first answer; callers pass answers sorted by descending score
        deduplicated_answers = [answers[0]]

        # Loop over remaining answers to check for overlaps
        for ans in answers[1:]:
            keep = self._should_keep(
                candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold
            )
            if keep:
                deduplicated_answers.append(ans)

        return deduplicated_answers

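    # A deduplication sketch (hypothetical offsets, answers sorted by descending score):
    # with overlap_threshold=0.5, a kept span covering characters 16-37 suppresses a
    # later span covering 23-32, since the shorter span overlaps it completely (1.0),
    # while a disjoint span from the same document is kept.
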
    @component.output_types(answers=list[ExtractedAnswer])
    def run(
        self,
        query: str,
        documents: list[Document],
        top_k: int | None = None,
        score_threshold: float | None = None,
        max_seq_length: int | None = None,
        stride: int | None = None,
        max_batch_size: int | None = None,
        answers_per_seq: int | None = None,
        no_answer: bool | None = None,
        overlap_threshold: float | None = None,
    ) -> dict[str, Any]:
        """
        Locates and extracts answers from the given Documents using the given query.

        :param query:
            Query string.
        :param documents:
            List of Documents in which you want to search for an answer to the query.
        :param top_k:
            The maximum number of answers to return.
            An additional answer is returned if no_answer is set to True (default).
        :param score_threshold:
            Returns only answers with the score above this threshold.
        :param max_seq_length:
            Maximum number of tokens. If a sequence exceeds it, the sequence is split.
        :param stride:
            Number of tokens that overlap when a sequence is split because it exceeds max_seq_length.
        :param max_batch_size:
            Maximum number of samples that are fed through the model at the same time.
        :param answers_per_seq:
            Number of answer candidates to consider per sequence.
            This is relevant when a Document was split into multiple sequences because of max_seq_length.
        :param no_answer:
            Whether to return an additional `no answer` score.
        :param overlap_threshold:
            If set, removes duplicate answers whose spans overlap by more than this fraction.
            For example, for the answers "in the river in Maine" and "the river", one of them is removed
            since the second answer overlaps 100% (1.0) with the first.
            For the answers "the river in" and "in Maine", the maximum overlap is only 25%, so
            both answers are kept if this parameter is set to 0.24 or lower.
            If None is provided, all answers are kept.
        :returns:
            A dictionary with the following key:
            - `answers`: List of answers sorted by (descending) answer score.
        """
        if self.model is None:
            self.warm_up()

        if not documents:
            return {"answers": []}

        queries = [query]  # Temporary solution until we have decided what batching should look like in v2
        nested_documents = [documents]
        top_k = top_k or self.top_k
        score_threshold = score_threshold if score_threshold is not None else self.score_threshold
        max_seq_length = max_seq_length or self.max_seq_length
        stride = stride or self.stride
        max_batch_size = max_batch_size or self.max_batch_size
        answers_per_seq = answers_per_seq or self.answers_per_seq or 20
        no_answer = no_answer if no_answer is not None else self.no_answer
        overlap_threshold = overlap_threshold if overlap_threshold is not None else self.overlap_threshold

        flattened_queries, flattened_documents, query_ids = ExtractiveReader._flatten_documents(
            queries, nested_documents
        )
        input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
            queries=flattened_queries,
            documents=flattened_documents,
            max_seq_length=max_seq_length,
            query_ids=query_ids,
            stride=stride,
        )

        num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
        batch_size = max_batch_size or input_ids.shape[0]

        start_logits_list = []
        end_logits_list = []

        for i in range(num_batches):
            start_index = i * batch_size
            end_index = start_index + batch_size
            cur_input_ids = input_ids[start_index:end_index]
            cur_attention_mask = attention_mask[start_index:end_index]

            with torch.inference_mode():
                # mypy doesn't know this is set in warm_up
                output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask)  # type: ignore[misc]
            cur_start_logits = output.start_logits
            cur_end_logits = output.end_logits
            if num_batches != 1:
                cur_start_logits = cur_start_logits.cpu()
                cur_end_logits = cur_end_logits.cpu()
            start_logits_list.append(cur_start_logits)
            end_logits_list.append(cur_end_logits)

        start_logits = torch.cat(start_logits_list)
        end_logits = torch.cat(end_logits_list)

        start, end, probabilities = self._postprocess(
            start=start_logits,
            end=end_logits,
            sequence_ids=sequence_ids,
            attention_mask=attention_mask,
            answers_per_seq=answers_per_seq,
            encodings=encodings,
        )

        answers = self._nest_answers(
            start=start,
            end=end,
            probabilities=probabilities,
            flattened_documents=flattened_documents,
            queries=queries,
            answers_per_seq=answers_per_seq,
            top_k=top_k,
            score_threshold=score_threshold,
            query_ids=query_ids,
            document_ids=document_ids,
            no_answer=no_answer,
            overlap_threshold=overlap_threshold,
        )

        return {"answers": answers[0]}  # same temporary batching fix as above
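
# A usage sketch (hypothetical snippet; assumes the default model can be downloaded
# from the Hugging Face hub). The reader is typically fed the documents returned by
# a retriever:
#
#     from haystack import Document
#     from haystack.components.readers import ExtractiveReader
#
#     reader = ExtractiveReader(top_k=3)
#     reader.warm_up()
#     docs = [Document(content="Paris is the capital of France.")]
#     result = reader.run(query="What is the capital of France?", documents=docs)
#     for answer in result["answers"]:
#         print(answer.data, answer.score)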