/ test / components / samplers / test_top_p.py
test_top_p.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import random
  6  
  7  import pytest
  8  
  9  from haystack import Document
 10  from haystack.components.samplers.top_p import TopPSampler
 11  
 12  
 13  @pytest.fixture
 14  def documents_with_score_field() -> list[Document]:
 15      return [
 16          Document(content="Sarajevo", meta={"similarity_score": 0.7}),
 17          Document(content="Belgrade", meta={"similarity_score": 0.01}),
 18          Document(content="Berlin", meta={"similarity_score": 0.001}),
 19      ]
 20  
 21  
 22  @pytest.fixture
 23  def documents_with_score() -> list[Document]:
 24      return [
 25          Document(content="Sarajevo", score=0.7),
 26          Document(content="Belgrade", score=0.01),
 27          Document(content="Berlin", score=0.001),
 28      ]
 29  
 30  
 31  class TestTopPSampler:
 32      def test_init_raises_value_error(self) -> None:
 33          with pytest.raises(ValueError):
 34              TopPSampler(top_p=2.0)
 35  
 36      def test_run_raises_value_error(self, documents_with_score: list[Document]) -> None:
 37          sampler = TopPSampler(top_p=0.95)
 38          with pytest.raises(ValueError):
 39              sampler.run(documents=documents_with_score, top_p=2.0)
 40  
 41      def test_run_score_field(self, documents_with_score_field: list[Document]) -> None:
 42          sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
 43          docs = documents_with_score_field
 44          output = sampler.run(documents=docs)
 45          docs = output["documents"]
 46          assert len(docs) == 2
 47          assert docs[0].content == "Sarajevo"
 48          assert docs[1].content == "Belgrade"
 49  
 50      def test_run_score_field_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
 51          sampler = TopPSampler(top_p=1.0, score_field="similarity_score")
 52          docs = [
 53              Document(content="Sarajevo", meta={"similarity_score": 0.7}),
 54              Document(content="Belgrade", meta={"similarity_score": 0.01}),
 55              Document(content="Berlin", meta={"similarity_score": None}),
 56          ]
 57          output = sampler.run(documents=docs)
 58          docs = output["documents"]
 59          assert len(docs) == 2
 60          assert docs[0].content == "Sarajevo"
 61          assert docs[1].content == "Belgrade"
 62          assert "Score field" in caplog.text
 63  
 64      def test_run(self, documents_with_score: list[Document]) -> None:
 65          sampler = TopPSampler(top_p=0.99)
 66          docs = documents_with_score
 67          random.shuffle(docs)
 68          sorted_scores = sorted([doc.score for doc in docs], reverse=True)
 69  
 70          # top_p = 0.99 will get the top 1 document
 71          output = sampler.run(documents=docs)
 72          docs_filtered = output["documents"]
 73          assert len(docs_filtered) == 2
 74          assert docs_filtered[0].content == "Sarajevo"
 75          assert docs_filtered[1].content == "Belgrade"
 76  
 77          assert [doc.score for doc in docs_filtered] == sorted_scores[:2]
 78  
 79      def test_run_top_p_1(self, documents_with_score: list[Document]) -> None:
 80          sampler = TopPSampler(top_p=1.0)
 81          docs = documents_with_score
 82          random.shuffle(docs)
 83          output = sampler.run(documents=docs)
 84          docs_filtered = output["documents"]
 85          assert len(docs_filtered) == len(docs)
 86          assert docs_filtered[0].content == "Sarajevo"
 87          assert [doc.score for doc in docs_filtered] == sorted([doc.score for doc in docs], reverse=True)
 88  
 89      def test_run_top_p_0(self, caplog: pytest.LogCaptureFixture, documents_with_score: list[Document]) -> None:
 90          sampler = TopPSampler(top_p=0.0)
 91          docs = documents_with_score
 92          output = sampler.run(documents=docs)
 93          docs = output["documents"]
 94          assert len(docs) == 1
 95          assert docs[0].content == "Sarajevo"
 96          assert "Top-p sampling with p=" in caplog.text
 97  
 98      def test_run_returns_empty_list_no_documents(self) -> None:
 99          sampler = TopPSampler()
100          output = sampler.run(documents=[])
101          assert output["documents"] == []
102  
103      def test_run_no_score_field(self, caplog: pytest.LogCaptureFixture, documents_with_score: list[Document]) -> None:
104          sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
105          docs = documents_with_score
106          output = sampler.run(documents=docs)
107          docs = output["documents"]
108          assert len(docs) == 3
109          assert docs[0].content == "Sarajevo"
110          assert "Score field 'similarity_score' not found" in caplog.text
111  
112      def test_run_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
113          sampler = TopPSampler(top_p=0.95)
114          docs = [
115              Document(content="Sarajevo", score=0.7),
116              Document(content="Belgrade", score=0.01),
117              Document(content="Berlin", score=None),
118          ]
119          output = sampler.run(documents=docs)
120          docs = output["documents"]
121          assert len(docs) == 1
122          assert docs[0].content == "Sarajevo"
123          assert "Ensure all documents have a valid score value" in caplog.text
124  
125      def test_run_min_top_k(self, documents_with_score: list[Document]) -> None:
126          sampler = TopPSampler(min_top_k=2, top_p=0.2)
127          docs = documents_with_score
128          output = sampler.run(documents=docs)
129          docs = output["documents"]
130          assert len(docs) == 2
131          assert docs[0].content == "Sarajevo"
132          assert docs[1].content == "Belgrade"