test_top_p.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import random

import pytest

from haystack import Document
from haystack.components.samplers.top_p import TopPSampler


@pytest.fixture
def documents_with_score_field() -> list[Document]:
    """Return three documents whose scores are stored under ``meta["similarity_score"]``."""
    city_scores = (("Sarajevo", 0.7), ("Belgrade", 0.01), ("Berlin", 0.001))
    return [Document(content=city, meta={"similarity_score": value}) for city, value in city_scores]


@pytest.fixture
def documents_with_score() -> list[Document]:
    """Return three documents whose scores are stored in the ``score`` attribute."""
    city_scores = (("Sarajevo", 0.7), ("Belgrade", 0.01), ("Berlin", 0.001))
    return [Document(content=city, score=value) for city, value in city_scores]


class TestTopPSampler:
    """Tests for ``TopPSampler`` construction and its ``run`` method."""

    def test_init_raises_value_error(self) -> None:
        """Constructing the sampler with ``top_p=2.0`` must raise ``ValueError``."""
        with pytest.raises(ValueError):
            TopPSampler(top_p=2.0)

    def test_run_raises_value_error(self, documents_with_score: list[Document]) -> None:
        """Passing ``top_p=2.0`` to ``run`` must raise ``ValueError`` even if the init value was accepted."""
        valid_sampler = TopPSampler(top_p=0.95)
        with pytest.raises(ValueError):
            valid_sampler.run(documents=documents_with_score, top_p=2.0)

    def test_run_score_field(self, documents_with_score_field: list[Document]) -> None:
        """Scores read from the configured meta field yield the two top documents, in order."""
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        selected = sampler.run(documents=documents_with_score_field)["documents"]
        assert [doc.content for doc in selected] == ["Sarajevo", "Belgrade"]

    def test_run_score_field_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
        """A document whose meta score is ``None`` is left out and a message is logged."""
        sampler = TopPSampler(top_p=1.0, score_field="similarity_score")
        candidates = [
            Document(content="Sarajevo", meta={"similarity_score": 0.7}),
            Document(content="Belgrade", meta={"similarity_score": 0.01}),
            Document(content="Berlin", meta={"similarity_score": None}),
        ]
        selected = sampler.run(documents=candidates)["documents"]
        assert [doc.content for doc in selected] == ["Sarajevo", "Belgrade"]
        assert "Score field" in caplog.text

    def test_run(self, documents_with_score: list[Document]) -> None:
        """Shuffled input comes back as the two highest-scoring documents in descending score order."""
        sampler = TopPSampler(top_p=0.99)
        shuffled = documents_with_score
        random.shuffle(shuffled)
        # Capture the expected ordering before the sampler sees the documents.
        descending_scores = sorted((doc.score for doc in shuffled), reverse=True)

        # With these scores, top_p = 0.99 is expected to keep the two best documents.
        selected = sampler.run(documents=shuffled)["documents"]
        assert [doc.content for doc in selected] == ["Sarajevo", "Belgrade"]

        assert [doc.score for doc in selected] == descending_scores[:2]

    def test_run_top_p_1(self, documents_with_score: list[Document]) -> None:
        """With ``top_p=1.0`` every document is returned, sorted by descending score."""
        sampler = TopPSampler(top_p=1.0)
        shuffled = documents_with_score
        random.shuffle(shuffled)
        selected = sampler.run(documents=shuffled)["documents"]
        assert len(selected) == len(shuffled)
        assert selected[0].content == "Sarajevo"
        expected_scores = sorted((doc.score for doc in shuffled), reverse=True)
        assert [doc.score for doc in selected] == expected_scores

    def test_run_top_p_0(self, caplog: pytest.LogCaptureFixture, documents_with_score: list[Document]) -> None:
        """With ``top_p=0.0`` only the best document is returned and a message is logged."""
        sampler = TopPSampler(top_p=0.0)
        selected = sampler.run(documents=documents_with_score)["documents"]
        assert [doc.content for doc in selected] == ["Sarajevo"]
        assert "Top-p sampling with p=" in caplog.text

    def test_run_returns_empty_list_no_documents(self) -> None:
        """Running a default sampler on an empty list returns an empty list."""
        result = TopPSampler().run(documents=[])
        assert result["documents"] == []

    def test_run_no_score_field(self, caplog: pytest.LogCaptureFixture, documents_with_score: list[Document]) -> None:
        """When no document carries the configured meta field, all documents are returned and a message is logged."""
        sampler = TopPSampler(top_p=0.95, score_field="similarity_score")
        selected = sampler.run(documents=documents_with_score)["documents"]
        assert len(selected) == 3
        assert selected[0].content == "Sarajevo"
        assert "Score field 'similarity_score' not found" in caplog.text

    def test_run_missing_scores(self, caplog: pytest.LogCaptureFixture) -> None:
        """A document with ``score=None`` is left out; only the top document is returned and a message is logged."""
        sampler = TopPSampler(top_p=0.95)
        candidates = [
            Document(content="Sarajevo", score=0.7),
            Document(content="Belgrade", score=0.01),
            Document(content="Berlin", score=None),
        ]
        selected = sampler.run(documents=candidates)["documents"]
        assert [doc.content for doc in selected] == ["Sarajevo"]
        assert "Ensure all documents have a valid score value" in caplog.text

    def test_run_min_top_k(self, documents_with_score: list[Document]) -> None:
        """With ``min_top_k=2`` and ``top_p=0.2`` the two best documents are returned, in order."""
        sampler = TopPSampler(min_top_k=2, top_p=0.2)
        selected = sampler.run(documents=documents_with_score)["documents"]
        assert [doc.content for doc in selected] == ["Sarajevo", "Belgrade"]