extractive.py
1 # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> 2 # 3 # SPDX-License-Identifier: Apache-2.0 4 5 import math 6 from dataclasses import replace 7 from pathlib import Path 8 from typing import Any 9 10 from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging 11 from haystack.lazy_imports import LazyImport 12 from haystack.utils import ComponentDevice, Device, DeviceMap, Secret 13 from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_device_map, serialize_hf_model_kwargs 14 15 with LazyImport("Run 'pip install transformers[torch,sentencepiece]'") as torch_and_transformers_import: 16 import accelerate # noqa: F401 # the library is used but not directly referenced 17 import torch 18 from tokenizers import Encoding 19 from transformers import AutoModelForQuestionAnswering, AutoTokenizer 20 21 22 logger = logging.getLogger(__name__) 23 24 25 @component 26 class ExtractiveReader: 27 """ 28 Locates and extracts answers to a given query from Documents. 29 30 The ExtractiveReader component performs extractive question answering. 31 It assigns a score to every possible answer span independently of other answer spans. 32 This fixes a common issue of other implementations which make comparisons across documents harder by normalizing 33 each document's answers independently. 34 35 Example usage: 36 ```python 37 from haystack import Document 38 from haystack.components.readers import ExtractiveReader 39 40 docs = [ 41 Document(content="Python is a popular programming language"), 42 Document(content="python ist eine beliebte Programmiersprache"), 43 ] 44 45 reader = ExtractiveReader() 46 47 question = "What is a popular programming language?" 48 result = reader.run(query=question, documents=docs) 49 assert "Python" in result["answers"][0].data 50 ``` 51 """ 52 53 def __init__( 54 self, 55 model: Path | str = "deepset/roberta-base-squad2-distilled", 56 device: ComponentDevice | None = None, 57 token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), 58 top_k: int = 20, 59 score_threshold: float | None = None, 60 max_seq_length: int = 384, 61 stride: int = 128, 62 max_batch_size: int | None = None, 63 answers_per_seq: int | None = None, 64 no_answer: bool = True, 65 calibration_factor: float = 0.1, 66 overlap_threshold: float | None = 0.01, 67 model_kwargs: dict[str, Any] | None = None, 68 ) -> None: 69 """ 70 Creates an instance of ExtractiveReader. 71 72 :param model: 73 A Hugging Face transformers question answering model. 74 Can either be a path to a folder containing the model files or an identifier for the Hugging Face hub. 75 :param device: 76 The device on which the model is loaded. If `None`, the default device is automatically selected. 77 :param token: 78 The API token used to download private models from Hugging Face. 79 :param top_k: 80 Number of answers to return per query. It is required even if score_threshold is set. 81 An additional answer with no text is returned if no_answer is set to True (default). 82 :param score_threshold: 83 Returns only answers with the probability score above this threshold. 84 :param max_seq_length: 85 Maximum number of tokens. If a sequence exceeds it, the sequence is split. 86 :param stride: 87 Number of tokens that overlap when sequence is split because it exceeds max_seq_length. 88 :param max_batch_size: 89 Maximum number of samples that are fed through the model at the same time. 90 :param answers_per_seq: 91 Number of answer candidates to consider per sequence. 92 This is relevant when a Document was split into multiple sequences because of max_seq_length. 93 :param no_answer: 94 Whether to return an additional `no answer` with an empty text and a score representing the 95 probability that the other top_k answers are incorrect. 96 :param calibration_factor: 97 Factor used for calibrating probabilities. 98 :param overlap_threshold: 99 If set this will remove duplicate answers if they have an overlap larger than the 100 supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove 101 one of these answers since the second answer has a 100% (1.0) overlap with the first answer. 102 However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so 103 both of these answers could be kept if this variable is set to 0.24 or lower. 104 If None is provided then all answers are kept. 105 :param model_kwargs: 106 Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained` 107 when loading the model specified in `model`. For details on what kwargs you can pass, 108 see the model's documentation. 109 """ 110 torch_and_transformers_import.check() 111 self.model_name_or_path = str(model) 112 self.model = None 113 self.tokenizer: Any = None 114 self.device: ComponentDevice | None = None 115 self.token = token 116 self.max_seq_length = max_seq_length 117 self.top_k = top_k 118 self.score_threshold = score_threshold 119 self.stride = stride 120 self.max_batch_size = max_batch_size 121 self.answers_per_seq = answers_per_seq 122 self.no_answer = no_answer 123 self.calibration_factor = calibration_factor 124 self.overlap_threshold = overlap_threshold 125 126 model_kwargs = resolve_hf_device_map(device=device, model_kwargs=model_kwargs) 127 self.model_kwargs = model_kwargs 128 129 def _get_telemetry_data(self) -> dict[str, Any]: 130 """ 131 Data that is sent to Posthog for usage analytics. 132 """ 133 return {"model": self.model_name_or_path} 134 135 def to_dict(self) -> dict[str, Any]: 136 """ 137 Serializes the component to a dictionary. 138 139 :returns: 140 Dictionary with serialized data. 141 """ 142 serialization_dict = default_to_dict( 143 self, 144 model=self.model_name_or_path, 145 device=None, 146 token=self.token, 147 max_seq_length=self.max_seq_length, 148 top_k=self.top_k, 149 score_threshold=self.score_threshold, 150 stride=self.stride, 151 max_batch_size=self.max_batch_size, 152 answers_per_seq=self.answers_per_seq, 153 no_answer=self.no_answer, 154 calibration_factor=self.calibration_factor, 155 model_kwargs=self.model_kwargs, 156 ) 157 158 serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"]) 159 return serialization_dict 160 161 @classmethod 162 def from_dict(cls, data: dict[str, Any]) -> "ExtractiveReader": 163 """ 164 Deserializes the component from a dictionary. 165 166 :param data: 167 Dictionary to deserialize from. 168 :returns: 169 Deserialized component. 170 """ 171 init_params = data["init_parameters"] 172 if init_params.get("model_kwargs") is not None: 173 deserialize_hf_model_kwargs(init_params["model_kwargs"]) 174 175 return default_from_dict(cls, data) 176 177 def warm_up(self) -> None: 178 """ 179 Initializes the component. 180 """ 181 # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device. 182 if self.model is None: 183 self.model = AutoModelForQuestionAnswering.from_pretrained( 184 self.model_name_or_path, token=self.token.resolve_value() if self.token else None, **self.model_kwargs 185 ) 186 self.tokenizer = AutoTokenizer.from_pretrained( 187 self.model_name_or_path, token=self.token.resolve_value() if self.token else None 188 ) 189 assert self.model is not None # mypy doesn't know this is set in the line above 190 # hf_device_map appears to only be set now when mixed devices are actually used. 191 # So if it's missing then we can use the device attribute which is set even for single-device models. 192 if hf_device_map := getattr(self.model, "hf_device_map", None): 193 self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(hf_device_map)) 194 else: 195 self.device = ComponentDevice.from_single(Device.from_str(str(self.model.device))) 196 197 @staticmethod 198 def _flatten_documents( 199 queries: list[str], documents: list[list[Document]] 200 ) -> tuple[list[str], list[Document], list[int]]: 201 """ 202 Flattens queries and Documents so all query-document pairs are arranged along one batch axis. 203 """ 204 flattened_queries = [query for documents_, query in zip(documents, queries, strict=True) for _ in documents_] 205 flattened_documents = [document for documents_ in documents for document in documents_] 206 query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_] 207 return flattened_queries, flattened_documents, query_ids 208 209 def _preprocess( 210 self, *, queries: list[str], documents: list[Document], max_seq_length: int, query_ids: list[int], stride: int 211 ) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["Encoding"], list[int], list[int]]: 212 """ 213 Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs. 214 """ 215 texts = [] 216 document_ids = [] 217 document_contents = [] 218 for i, doc in enumerate(documents): 219 if doc.content is None: 220 logger.warning( 221 "Document with id {doc_id} was passed to ExtractiveReader. The Document doesn't " 222 "contain any text and it will be ignored.", 223 doc_id=doc.id, 224 ) 225 continue 226 texts.append(doc.content) 227 document_ids.append(i) 228 document_contents.append(doc.content) 229 230 # mypy doesn't know this is set in warm_up 231 encodings_pt = self.tokenizer( 232 queries, 233 document_contents, 234 padding=True, 235 truncation=True, 236 max_length=max_seq_length, 237 return_tensors="pt", 238 return_overflowing_tokens=True, 239 stride=stride, 240 ) 241 242 # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device. 243 # mypy doesn't know this is set in warm_up 244 first_device = self.device.first_device.to_torch() # type: ignore[union-attr] 245 246 input_ids = encodings_pt.input_ids.to(first_device) 247 attention_mask = encodings_pt.attention_mask.to(first_device) 248 249 query_ids = [query_ids[index] for index in encodings_pt.overflow_to_sample_mapping] 250 document_ids = [document_ids[sample_id] for sample_id in encodings_pt.overflow_to_sample_mapping] 251 252 encodings = encodings_pt.encodings 253 sequence_ids = torch.tensor( 254 [[id_ if id_ is not None else -1 for id_ in encoding.sequence_ids] for encoding in encodings] 255 ).to(first_device) 256 257 return input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids 258 259 def _postprocess( 260 self, 261 *, 262 start: "torch.Tensor", 263 end: "torch.Tensor", 264 sequence_ids: "torch.Tensor", 265 attention_mask: "torch.Tensor", 266 answers_per_seq: int, 267 encodings: list["Encoding"], 268 ) -> tuple[list[list[int]], list[list[int]], "torch.Tensor"]: 269 """ 270 Turns start and end logits into probabilities for each answer span. 271 272 Unlike most other implementations, it doesn't normalize the scores in each split to make them easier to 273 compare across different splits. Returns the top k answer spans. 274 """ 275 mask = sequence_ids == 1 # Only keep tokens from the context (should ignore special tokens) 276 mask = torch.logical_and(mask, attention_mask == 1) # Definitely remove special tokens 277 start = torch.where(mask, start, -torch.inf) # Apply the mask on the start logits 278 end = torch.where(mask, end, -torch.inf) # Apply the mask on the end logits 279 start = start.unsqueeze(-1) 280 end = end.unsqueeze(-2) 281 282 logits = start + end # shape: (batch_size, seq_length (start), seq_length (end)) 283 284 # The mask here onwards is the same for all instances in the batch 285 # As such we do away with the batch dimension 286 mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=logits.device) 287 mask = torch.triu(mask) # End shouldn't be before start 288 masked_logits = torch.where(mask, logits, -torch.inf) 289 probabilities = torch.sigmoid(masked_logits * self.calibration_factor) 290 291 flat_probabilities = probabilities.flatten(-2, -1) # necessary for top-k 292 293 # top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates 294 # We only keep probability > 0 candidates later on 295 candidates = torch.topk(flat_probabilities, answers_per_seq) 296 seq_length = logits.shape[-1] 297 start_candidates = candidates.indices // seq_length # Recover indices from flattening 298 end_candidates = candidates.indices % seq_length 299 candidates_values = candidates.values.cpu() 300 start_candidates = start_candidates.cpu() 301 end_candidates = end_candidates.cpu() 302 303 start_candidates_tokens_to_chars = [] 304 end_candidates_tokens_to_chars = [] 305 for i, (s_candidates, e_candidates, encoding) in enumerate( 306 zip(start_candidates, end_candidates, encodings, strict=True) 307 ): 308 # Those with probabilities > 0 are valid 309 valid = candidates_values[i] > 0 310 s_char_spans = [] 311 e_char_spans = [] 312 for start_token, end_token in zip(s_candidates[valid], e_candidates[valid], strict=True): 313 # token_to_chars returns `None` for special tokens 314 # But we shouldn't have special tokens in the answers at this point 315 # The whole span is given by the start of the start_token (index 0) 316 # and the end of the end token (index 1) 317 s_char_spans.append(encoding.token_to_chars(start_token)[0]) 318 e_char_spans.append(encoding.token_to_chars(end_token)[1]) 319 start_candidates_tokens_to_chars.append(s_char_spans) 320 end_candidates_tokens_to_chars.append(e_char_spans) 321 322 return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, candidates_values 323 324 def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer: 325 if answer.meta is None: 326 answer.meta = {} 327 328 if answer.document_offset is None: 329 return answer 330 331 if not answer.document or "page_number" not in answer.document.meta: 332 return answer 333 334 if not isinstance(answer.document.meta["page_number"], int): 335 logger.warning( 336 "Document's page_number must be int but is {type}. No page number will be added to the answer.", 337 type=type(answer.document.meta["page_number"]), 338 ) 339 return answer 340 341 # Calculate the answer page number 342 if answer.document.content: 343 ans_start = answer.document_offset.start 344 answer_page_number = answer.document.meta["page_number"] + answer.document.content[:ans_start].count("\f") 345 answer.meta.update({"answer_page_number": answer_page_number}) 346 347 return answer 348 349 def _nest_answers( 350 self, 351 *, 352 start: list[list[int]], 353 end: list[list[int]], 354 probabilities: "torch.Tensor", 355 flattened_documents: list[Document], 356 queries: list[str], 357 answers_per_seq: int, 358 top_k: int | None, 359 score_threshold: float | None, 360 query_ids: list[int], 361 document_ids: list[int], 362 no_answer: bool, 363 overlap_threshold: float | None, 364 ) -> list[list[ExtractedAnswer]]: 365 """ 366 Reconstructs the nested structure that existed before flattening. 367 368 Also computes a no answer score. This score is different from most other implementations because it does not 369 consider the no answer logit introduced with SQuAD 2. Instead, it just computes the probability that the 370 answer does not exist in the top k or top p. 371 """ 372 answers_without_query = [] 373 for document_id, start_candidates_, end_candidates_, probabilities_ in zip( 374 document_ids, start, end, probabilities, strict=True 375 ): 376 for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_, strict=True): 377 doc = flattened_documents[document_id] 378 answers_without_query.append( 379 ExtractedAnswer( 380 query="", # Can't be None but we'll add it later 381 data=doc.content[start_:end_], # type: ignore 382 document=doc, 383 score=probability.item(), 384 document_offset=ExtractedAnswer.Span(start_, end_), 385 meta={}, 386 ) 387 ) 388 i = 0 389 nested_answers = [] 390 for query_id in range(query_ids[-1] + 1): 391 current_answers = [] 392 while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id: 393 current_answers.append(replace(answers_without_query[i], query=queries[query_id])) 394 i += 1 395 current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True) 396 current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold) 397 current_answers = current_answers[:top_k] 398 399 # Calculate the answer page number and add it to meta 400 current_answers = [self._add_answer_page_number(answer=answer) for answer in current_answers] 401 402 if no_answer: 403 no_answer_score = math.prod(1 - answer.score for answer in current_answers) 404 answer_ = ExtractedAnswer( 405 data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score 406 ) 407 current_answers.append(answer_) 408 current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True) 409 if score_threshold is not None: 410 current_answers = [answer for answer in current_answers if answer.score >= score_threshold] 411 nested_answers.append(current_answers) 412 413 return nested_answers 414 415 def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int: 416 """ 417 Calculates the amount of overlap (in number of characters) between two answer offsets. 418 419 This Stack overflow 420 [post](https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964) 421 explains how to calculate the overlap between two ranges. 422 """ 423 # Check for overlap: (StartA <= EndB) and (StartB <= EndA) 424 if answer1_start <= answer2_end and answer2_start <= answer1_end: 425 return min( 426 answer1_end - answer1_start, 427 answer1_end - answer2_start, 428 answer2_end - answer1_start, 429 answer2_end - answer2_start, 430 ) 431 return 0 432 433 def _should_keep( 434 self, candidate_answer: ExtractedAnswer, current_answers: list[ExtractedAnswer], overlap_threshold: float 435 ) -> bool: 436 """ 437 Determines if the answer should be kept based on how much it overlaps with previous answers. 438 439 NOTE: We might want to avoid throwing away answers that only have a few character (or word) overlap: 440 - E.g. The answers "the river in" and "in Maine" from the context "I want to go to the river in Maine." 441 might both want to be kept. 442 443 :param candidate_answer: 444 Candidate answer that will be checked if it should be kept. 445 :param current_answers: 446 Current list of answers that will be kept. 447 :param overlap_threshold: 448 If the overlap between two answers is greater than this threshold then return False. 449 """ 450 keep = True 451 452 # If the candidate answer doesn't have a document keep it 453 if not candidate_answer.document: 454 return keep 455 456 for ans in current_answers: 457 # If an answer in current_answers doesn't have a document skip the comparison 458 if not ans.document: 459 continue 460 461 # If offset is missing then keep both 462 if ans.document_offset is None: 463 continue 464 465 # If offset is missing then keep both 466 if candidate_answer.document_offset is None: 467 continue 468 469 # If the answers come from different documents then keep both 470 if candidate_answer.document.id != ans.document.id: 471 continue 472 473 overlap_len = self._calculate_overlap( 474 answer1_start=ans.document_offset.start, 475 answer1_end=ans.document_offset.end, 476 answer2_start=candidate_answer.document_offset.start, 477 answer2_end=candidate_answer.document_offset.end, 478 ) 479 480 # If overlap is 0 then keep 481 if overlap_len == 0: 482 continue 483 484 overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start) 485 overlap_frac_answer2 = overlap_len / ( 486 candidate_answer.document_offset.end - candidate_answer.document_offset.start 487 ) 488 489 if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold: 490 keep = False 491 break 492 493 return keep 494 495 def deduplicate_by_overlap( 496 self, answers: list[ExtractedAnswer], overlap_threshold: float | None 497 ) -> list[ExtractedAnswer]: 498 """ 499 De-duplicates overlapping Extractive Answers. 500 501 De-duplicates overlapping Extractive Answers from the same document based on how much the spans of the 502 answers overlap. 503 504 :param answers: 505 List of answers to be deduplicated. 506 :param overlap_threshold: 507 If set this will remove duplicate answers if they have an overlap larger than the 508 supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove 509 one of these answers since the second answer has a 100% (1.0) overlap with the first answer. 510 However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so 511 both of these answers could be kept if this variable is set to 0.24 or lower. 512 If None is provided then all answers are kept. 513 :returns: 514 List of deduplicated answers. 515 """ 516 if overlap_threshold is None: 517 return answers 518 519 # Initialize with the first answer and its offsets_in_document 520 deduplicated_answers = [answers[0]] 521 522 # Loop over remaining answers to check for overlaps 523 for ans in answers[1:]: 524 keep = self._should_keep( 525 candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold 526 ) 527 if keep: 528 deduplicated_answers.append(ans) 529 530 return deduplicated_answers 531 532 @component.output_types(answers=list[ExtractedAnswer]) 533 def run( 534 self, 535 query: str, 536 documents: list[Document], 537 top_k: int | None = None, 538 score_threshold: float | None = None, 539 max_seq_length: int | None = None, 540 stride: int | None = None, 541 max_batch_size: int | None = None, 542 answers_per_seq: int | None = None, 543 no_answer: bool | None = None, 544 overlap_threshold: float | None = None, 545 ) -> dict[str, Any]: 546 """ 547 Locates and extracts answers from the given Documents using the given query. 548 549 :param query: 550 Query string. 551 :param documents: 552 List of Documents in which you want to search for an answer to the query. 553 :param top_k: 554 The maximum number of answers to return. 555 An additional answer is returned if no_answer is set to True (default). 556 :param score_threshold: 557 Returns only answers with the score above this threshold. 558 :param max_seq_length: 559 Maximum number of tokens. If a sequence exceeds it, the sequence is split. 560 :param stride: 561 Number of tokens that overlap when sequence is split because it exceeds max_seq_length. 562 :param max_batch_size: 563 Maximum number of samples that are fed through the model at the same time. 564 :param answers_per_seq: 565 Number of answer candidates to consider per sequence. 566 This is relevant when a Document was split into multiple sequences because of max_seq_length. 567 :param no_answer: 568 Whether to return no answer scores. 569 :param overlap_threshold: 570 If set this will remove duplicate answers if they have an overlap larger than the 571 supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove 572 one of these answers since the second answer has a 100% (1.0) overlap with the first answer. 573 However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so 574 both of these answers could be kept if this variable is set to 0.24 or lower. 575 If None is provided then all answers are kept. 576 :returns: 577 List of answers sorted by (desc.) answer score. 578 """ 579 if self.model is None: 580 self.warm_up() 581 582 if not documents: 583 return {"answers": []} 584 585 queries = [query] # Temporary solution until we have decided what batching should look like in v2 586 nested_documents = [documents] 587 top_k = top_k or self.top_k 588 score_threshold = score_threshold or self.score_threshold 589 max_seq_length = max_seq_length or self.max_seq_length 590 stride = stride or self.stride 591 max_batch_size = max_batch_size or self.max_batch_size 592 answers_per_seq = answers_per_seq or self.answers_per_seq or 20 593 no_answer = no_answer if no_answer is not None else self.no_answer 594 overlap_threshold = overlap_threshold or self.overlap_threshold 595 596 flattened_queries, flattened_documents, query_ids = ExtractiveReader._flatten_documents( 597 queries, nested_documents 598 ) 599 input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess( 600 queries=flattened_queries, 601 documents=flattened_documents, 602 max_seq_length=max_seq_length, 603 query_ids=query_ids, 604 stride=stride, 605 ) 606 607 num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1 608 batch_size = max_batch_size or input_ids.shape[0] 609 610 start_logits_list = [] 611 end_logits_list = [] 612 613 for i in range(num_batches): 614 start_index = i * batch_size 615 end_index = start_index + batch_size 616 cur_input_ids = input_ids[start_index:end_index] 617 cur_attention_mask = attention_mask[start_index:end_index] 618 619 with torch.inference_mode(): 620 # mypy doesn't know this is set in warm_up 621 output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask) # type: ignore[misc] 622 cur_start_logits = output.start_logits 623 cur_end_logits = output.end_logits 624 if num_batches != 1: 625 cur_start_logits = cur_start_logits.cpu() 626 cur_end_logits = cur_end_logits.cpu() 627 start_logits_list.append(cur_start_logits) 628 end_logits_list.append(cur_end_logits) 629 630 start_logits = torch.cat(start_logits_list) 631 end_logits = torch.cat(end_logits_list) 632 633 start, end, probabilities = self._postprocess( 634 start=start_logits, 635 end=end_logits, 636 sequence_ids=sequence_ids, 637 attention_mask=attention_mask, 638 answers_per_seq=answers_per_seq, 639 encodings=encodings, 640 ) 641 642 answers = self._nest_answers( 643 start=start, 644 end=end, 645 probabilities=probabilities, 646 flattened_documents=flattened_documents, 647 queries=queries, 648 answers_per_seq=answers_per_seq, 649 top_k=top_k, 650 score_threshold=score_threshold, 651 query_ids=query_ids, 652 document_ids=document_ids, 653 no_answer=no_answer, 654 overlap_threshold=overlap_threshold, 655 ) 656 657 return {"answers": answers[0]} # same temporary batching fix as above