# src/evidently/llm/datagen/rag.py
  1  import random
  2  from abc import ABC
  3  from math import ceil
  4  from typing import Any
  5  from typing import ClassVar
  6  from typing import List
  7  from typing import Optional
  8  from typing import Sequence
  9  from typing import Tuple
 10  from typing import Union
 11  
 12  import pandas as pd
 13  
 14  from evidently.legacy.options.base import AnyOptions
 15  from evidently.legacy.options.base import Options
 16  from evidently.llm.datagen.base import BaseLLMDatasetGenerator
 17  from evidently.llm.datagen.base import DatasetGeneratorResult
 18  from evidently.llm.datagen.config import GenerationSpec
 19  from evidently.llm.datagen.config import ServiceSpec
 20  from evidently.llm.datagen.config import UserProfile
 21  from evidently.llm.rag.index import DataCollection
 22  from evidently.llm.rag.index import DataCollectionProvider
 23  from evidently.llm.rag.splitter import Chunk
 24  from evidently.llm.rag.splitter import ChunkSet
 25  from evidently.llm.utils.blocks import PromptBlock
 26  from evidently.llm.utils.prompt_render import PreparedTemplate
 27  from evidently.llm.utils.templates import StrPromptTemplate
 28  from evidently.llm.utils.templates import WithSystemPrompt
 29  from evidently.llm.utils.templates import prompt_contract
 30  
 31  RAGQuery = str
 32  RAGResponse = str
 33  RAGGeneration = Tuple[RAGQuery, RAGResponse, Chunk]
 34  
 35  
 36  class RagQueryPromptTemplate(WithSystemPrompt, StrPromptTemplate):
 37      system_prompt: str = "You are an assistant who generates questions based on provided context"
 38  
 39      query_spec: ClassVar[GenerationSpec]
 40      additional_prompt_blocks: ClassVar[List[PromptBlock]]
 41  
 42      @prompt_contract
 43      def generate(self, context: str, number: int) -> List[str]:
 44          """
 45          {query_spec}
 46  
 47          {additional_prompt_blocks}
 48  
 49          Here is a context
 50          {% input(context,tag=True) %}
 51  
 52          {% datagen_instruction('{number}') %}
 53  
 54          {% output_string_list(query_spec.kind, tagged=True) %}
 55          """
 56          return []
 57  
 58  
 59  class RagResponsePromptTemplate(WithSystemPrompt, StrPromptTemplate):
 60      system_prompt: str = "You are a helpful assistant that answer a given question directly without any preamble"
 61  
 62      query_spec: ClassVar[GenerationSpec]
 63      response_spec: ClassVar[GenerationSpec]
 64      additional_prompt_blocks: ClassVar[List[PromptBlock]]
 65  
 66      @prompt_contract
 67      def generate(self, input_value: str, context: str):
 68          """
 69          {response_spec}
 70          {additional_prompt_blocks}
 71  
 72          Your task is to generate {response_spec.kind} to the following {query_spec.kind}:
 73  
 74          {% input(input_value, tag=True) %}
 75  
 76          You have access to the following documents which are meant to provide context as you answer the query:
 77  
 78          {% input(context,tag=True) %}
 79  
 80          Please remain faithful to the underlying context,
 81          and deviate from it only if you haven't found the answer in the provided context.
 82          Avoid providing any preamble!
 83          Avoid providing any closing statement!,
 84  
 85          {% output_string(response_spec.kind) %}
 86          """
 87  
 88  
 89  def generate_chunksets(documents: DataCollection, count: int, chunks_per_set: int) -> List[ChunkSet]:
 90      return [[random.choice(documents.chunks) for _ in range(chunks_per_set)] for _ in range(count)]
 91  
 92  
 93  class BaseRagDatasetGenerator(BaseLLMDatasetGenerator, ABC):
 94      data_collection: DataCollectionProvider
 95  
 96  
 97  class RagQueryDatasetGenerator(BaseRagDatasetGenerator):
 98      """Dataset generator for RAG queries based on document chunks.
 99  
100      Generates questions/queries that are relevant to provided document chunks,
101      useful for creating evaluation datasets for RAG systems.
102      """
103  
104      query_template: RagQueryPromptTemplate = RagQueryPromptTemplate()
105      """Prompt template for generating queries."""
106      query_spec: GenerationSpec = GenerationSpec(kind="question")
107      """Specification for query generation."""
108      additional_prompt_blocks: List[PromptBlock] = []
109      """Additional prompt blocks to include."""
110      count: int = 10
111      """Number of queries to generate."""
112      chunks_per_query: int = 5
113      """Number of document chunks to use per query."""
114  
115      def __init__(
116          self,
117          data_collection: DataCollectionProvider,
118          count: int = 10,
119          model="gpt-4o-mini",
120          provider="openai",
121          options: AnyOptions = None,
122          complexity: Optional[str] = None,
123          query_spec: GenerationSpec = GenerationSpec(kind="question"),
124          user: Union[str, UserProfile, None] = None,
125          service: Union[str, ServiceSpec, None] = None,
126          chunks_per_query: int = 5,
127          additional_prompt_blocks: Optional[List[PromptBlock]] = None,
128          query_template: Union[str, RagQueryPromptTemplate, None] = None,
129          **data: Any,
130      ):
131          self.data_collection = data_collection
132          self.count = count
133          self.chunks_per_query = chunks_per_query
134          additional: List[PromptBlock] = additional_prompt_blocks or []
135          if user is not None:
136              additional.append(user if isinstance(user, UserProfile) else UserProfile(role=user))
137          if service is not None:
138              additional.append(
139                  service if isinstance(service, ServiceSpec) else ServiceSpec(kind="RAG", description=service)
140              )
141          self.additional_prompt_blocks = additional
142  
143          if query_spec is not None:
144              self.query_spec = query_spec
145          else:
146              self.query_spec = GenerationSpec(kind="question", complexity=complexity or "medium")
147  
148          self.provider = provider
149          self.options = Options.from_any_options(options)
150          self.model = model
151          if isinstance(query_template, str):
152              self.query_template = RagQueryPromptTemplate(prompt_template=query_template)
153          else:
154              self.query_template = query_template or RagQueryPromptTemplate()
155          super().__init__(**data)
156  
157      def get_chunks_and_query_count(self, all_chunks_count: int) -> Tuple[int, int, int]:
158          """Calculate chunk and query counts for generation.
159  
160          Args:
161          * `all_chunks_count`: Total number of available chunks.
162  
163          Returns:
164          * Tuple of (chunk_set_count, chunks_in_set_count, questions_per_chunkset).
165          """
166          questions_per_chunkset = min(10, self.count)
167          chunk_set_count = ceil(self.count / questions_per_chunkset)
168          chunks_in_set_count = min(self.chunks_per_query, all_chunks_count)
169          return chunk_set_count, chunks_in_set_count, questions_per_chunkset
170  
171      async def generate_queries_with_context(self) -> Tuple[DataCollection, List[RAGQuery]]:
172          """Generate queries along with their document context.
173  
174          Returns:
175          * Tuple of (data_collection, list of generated queries).
176          """
177          documents = self.data_collection.get_data_collection()
178          chunk_set_count, chunks_in_set_count, questions_per_chunkset = self.get_chunks_and_query_count(
179              len(documents.chunks)
180          )
181          chunk_sets = generate_chunksets(documents, chunk_set_count, chunks_in_set_count)
182          queries: List[RAGQuery] = await self.generate_queries(chunk_sets, questions_per_chunkset)
183          return documents, queries
184  
185      async def agenerate(self) -> DatasetGeneratorResult:
186          """Generate dataset asynchronously.
187  
188          Returns:
189          * `pd.DataFrame` with generated queries in a "queries" column.
190          """
191          _, queries = await self.generate_queries_with_context()
192          return pd.DataFrame({"queries": queries})
193  
194      async def generate_queries(self, chunk_sets: Sequence[List[Chunk]], questions_per_chunkset: int) -> List[RAGQuery]:
195          """Generate queries from chunk sets.
196  
197          Args:
198          * `chunk_sets`: Sequence of chunk lists to generate queries from.
199          * `questions_per_chunkset`: Number of questions to generate per chunk set.
200  
201          Returns:
202          * List of generated query strings.
203          """
204          with self.query_template.with_context(
205              query_spec=self.query_spec, additional_prompt_blocks=self.additional_prompt_blocks
206          ):
207              requests = [
208                  self.query_template.generate(context="\n\n".join(chunks), number=questions_per_chunkset)
209                  for chunks in chunk_sets
210              ]
211          questions = await self.wrapper.run_batch(requests)
212          return [q for qs in questions for q in qs][: self.count]
213  
214      @property
215      def prepared_query_template(self) -> PreparedTemplate:
216          """Get the prepared query template with all context applied.
217  
218          Returns:
219          * `PreparedTemplate` ready for use with the LLM.
220          """
221          return self.query_template.prepare(
222              query_spec=self.query_spec, additional_prompt_blocks=self.additional_prompt_blocks
223          )
224  
225  
class RagResponseDatasetGenerator(BaseRagDatasetGenerator):
    """Dataset generator for RAG responses based on queries and document chunks.

    Generates responses to queries using relevant document chunks as context,
    useful for creating evaluation datasets for RAG systems.
    """

    response_template: RagResponsePromptTemplate = RagResponsePromptTemplate()
    """Prompt template for generating responses."""
    response_spec: GenerationSpec = GenerationSpec(kind="response")
    """Specification for response generation."""
    query_spec: GenerationSpec = GenerationSpec(kind="question")
    """Specification for query type."""
    queries: List[RAGQuery]
    """List of queries to generate responses for."""
    additional_prompt_blocks: List[PromptBlock] = []
    """Additional prompt blocks to include."""
    include_context: bool = False
    """Whether to include context chunks in output DataFrame."""

    def __init__(
        self,
        data_collection: DataCollectionProvider,
        model="gpt-4o-mini",
        provider="openai",
        options: AnyOptions = None,
        include_context: bool = False,
        complexity: Optional[str] = None,
        query_spec: Optional[GenerationSpec] = None,
        response_spec: Optional[GenerationSpec] = None,
        user: Union[str, UserProfile, None] = None,
        service: Union[str, ServiceSpec, None] = None,
        additional_prompt_blocks: Optional[List[PromptBlock]] = None,
        response_template: Union[str, RagResponsePromptTemplate, None] = None,
        **data: Any,
    ):
        """Configure the response generator.

        Args:
        * `data_collection`: provider of the document chunks used as context.
        * `model` / `provider` / `options`: LLM wrapper configuration.
        * `include_context`: add a "context" column to the output DataFrame.
        * `complexity`: generation complexity; used only when the corresponding spec is not given.
        * `query_spec` / `response_spec`: explicit specs; take precedence over `complexity`.
        * `user` / `service`: optional persona/service prompt blocks (strings are wrapped).
        * `additional_prompt_blocks`: extra prompt blocks (copied, never mutated).
        * `response_template`: custom template, as an object or a raw template string.
        * `data`: remaining model fields, notably `queries` (required).
        """
        self.data_collection = data_collection
        self.include_context = include_context
        # Copy so appending user/service blocks never mutates the caller's list.
        additional: List[PromptBlock] = list(additional_prompt_blocks or [])
        if user is not None:
            additional.append(user if isinstance(user, UserProfile) else UserProfile(role=user))
        if service is not None:
            additional.append(
                service if isinstance(service, ServiceSpec) else ServiceSpec(kind="RAG", description=service)
            )
        self.additional_prompt_blocks = additional

        # `None` defaults (instead of shared GenerationSpec instances created at
        # import time) make the fallback branches reachable, so `complexity` is
        # honored, and avoid cross-instance sharing of mutable specs.
        if query_spec is not None:
            self.query_spec = query_spec
        else:
            self.query_spec = GenerationSpec(kind="question", complexity=complexity or "medium")

        if response_spec is not None:
            self.response_spec = response_spec
        else:
            self.response_spec = GenerationSpec(kind="answer", complexity=complexity or "medium")

        self.provider = provider
        self.options = Options.from_any_options(options)
        self.model = model
        if isinstance(response_template, str):
            self.response_template = RagResponsePromptTemplate(prompt_template=response_template)
        else:
            self.response_template = response_template or RagResponsePromptTemplate()
        super().__init__(**data)

    async def agenerate(self) -> DatasetGeneratorResult:
        """Generate dataset asynchronously.

        Returns:
        * `pd.DataFrame` with generated responses (and optionally context).
        """
        documents = self.data_collection.get_data_collection()
        relevant_chunks = [documents.find_relevant_chunks(q) for q in self.queries]
        responses = await self.generate_responses(self.queries, relevant_chunks)
        data = {"responses": responses}
        if self.include_context:
            data["context"] = [";".join(chunks) for chunks in relevant_chunks]
        return pd.DataFrame(data)

    async def generate_responses(self, queries: List[RAGQuery], relevant_chunks: List[List[Chunk]]) -> List[str]:
        """Generate responses for queries using relevant chunks as context.

        Args:
        * `queries`: List of query strings.
        * `relevant_chunks`: List of chunk lists, one per query.

        Returns:
        * List of generated response strings.
        """
        with self.response_template.with_context(
            response_spec=self.response_spec,
            query_spec=self.query_spec,
            additional_prompt_blocks=self.additional_prompt_blocks,
        ):
            requests = [
                self.response_template.generate(input_value=question, context="\n".join(chunks))
                for question, chunks in zip(queries, relevant_chunks)
            ]
        return await self.wrapper.run_batch(requests)

    @property
    def prepared_response_template(self) -> PreparedTemplate:
        """Get the prepared response template with all context applied.

        Returns:
        * `PreparedTemplate` ready for use with the LLM.
        """
        return self.response_template.prepare(
            query_spec=self.query_spec,
            additional_prompt_blocks=self.additional_prompt_blocks,
            response_spec=self.response_spec,
        )
339  
340  
class RagDatasetGenerator(BaseRagDatasetGenerator):
    """Complete RAG dataset generator for query-response pairs.

    Generates both queries and responses using document chunks, creating
    a complete evaluation dataset for RAG systems.
    """

    query_template: RagQueryPromptTemplate = RagQueryPromptTemplate()
    """Prompt template for generating queries."""
    query_spec: GenerationSpec = GenerationSpec(kind="question")
    """Specification for query generation."""
    response_spec: GenerationSpec = GenerationSpec(kind="response")
    """Specification for response generation."""
    response_template: RagResponsePromptTemplate = RagResponsePromptTemplate()
    """Prompt template for generating responses."""
    additional_prompt_blocks: List[PromptBlock] = []
    """Additional prompt blocks to include."""
    include_context: bool = False
    """Whether to include context chunks in output DataFrame."""
    count: int
    """Number of query-response pairs to generate."""

    def __init__(
        self,
        data_collection: DataCollectionProvider,
        count: int = 10,
        model="gpt-4o-mini",
        provider="openai",
        options: AnyOptions = None,
        include_context: bool = False,
        complexity: Optional[str] = None,
        query_spec: Optional[GenerationSpec] = None,
        response_spec: Optional[GenerationSpec] = None,
        user: Union[str, UserProfile, None] = None,
        service: Union[str, ServiceSpec, None] = None,
        additional_prompt_blocks: Optional[List[PromptBlock]] = None,
        query_template: Union[str, RagQueryPromptTemplate, None] = None,
        response_template: Union[str, RagResponsePromptTemplate, None] = None,
        **data: Any,
    ):
        """Configure the combined query+response generator.

        Args:
        * `data_collection`: provider of the document chunks.
        * `count`: number of query-response pairs to generate.
        * `model` / `provider` / `options`: LLM wrapper configuration.
        * `include_context`: add a "context" column to the output DataFrame.
        * `complexity`: generation complexity; used only when the corresponding spec is not given.
        * `query_spec` / `response_spec`: explicit specs; take precedence over `complexity`.
        * `user` / `service`: optional persona/service prompt blocks (strings are wrapped).
        * `additional_prompt_blocks`: extra prompt blocks (copied, never mutated).
        * `query_template` / `response_template`: custom templates, as objects or raw strings.
        """
        self.data_collection = data_collection
        self.include_context = include_context
        self.count = count
        # Copy so appending user/service blocks never mutates the caller's list.
        additional: List[PromptBlock] = list(additional_prompt_blocks or [])
        if user is not None:
            additional.append(user if isinstance(user, UserProfile) else UserProfile(role=user))
        if service is not None:
            # `description=` (not `purpose=`) for consistency with the query
            # and response generators, which build ServiceSpec the same way.
            additional.append(
                service if isinstance(service, ServiceSpec) else ServiceSpec(kind="RAG", description=service)
            )
        self.additional_prompt_blocks = additional

        # `None` defaults (instead of shared GenerationSpec instances created at
        # import time) make the fallback branches reachable, so `complexity` is
        # honored, and avoid cross-instance sharing of mutable specs.
        if query_spec is not None:
            self.query_spec = query_spec
        else:
            self.query_spec = GenerationSpec(kind="question", complexity=complexity or "medium")

        if response_spec is not None:
            self.response_spec = response_spec
        else:
            # kind="response" matches this class's declared default and field;
            # the old unreachable fallback said "answer".
            self.response_spec = GenerationSpec(kind="response", complexity=complexity or "medium")

        self.provider = provider
        self.options = Options.from_any_options(options)
        self.model = model
        if isinstance(query_template, str):
            self.query_template = RagQueryPromptTemplate(prompt_template=query_template)
        else:
            self.query_template = query_template or RagQueryPromptTemplate()
        if isinstance(response_template, str):
            self.response_template = RagResponsePromptTemplate(prompt_template=response_template)
        else:
            self.response_template = response_template or RagResponsePromptTemplate()
        super().__init__(**data)

    async def agenerate(self) -> DatasetGeneratorResult:
        """Generate complete RAG dataset with queries and responses.

        Returns:
        * `pd.DataFrame` with "queries" and "responses" columns (and optionally "context").
        """
        documents, queries = await self.query_generator.generate_queries_with_context()
        relevant_chunks = [documents.find_relevant_chunks(q) for q in queries]

        response_generator = self.response_generator(queries)
        responses = await response_generator.generate_responses(queries, relevant_chunks)
        data = {"queries": queries, "responses": responses}
        if self.include_context:
            data["context"] = [";".join(chunks) for chunks in relevant_chunks]
        return pd.DataFrame(data)

    @property
    def query_generator(self) -> RagQueryDatasetGenerator:
        """Get the query generator instance.

        Returns:
        * `RagQueryDatasetGenerator` configured with this generator's settings.
        """
        return RagQueryDatasetGenerator(
            data_collection=self.data_collection,
            count=self.count,
            query_spec=self.query_spec,
            query_template=self.query_template,
            additional_prompt_blocks=self.additional_prompt_blocks,
            options=self.options,
            provider=self.provider,
            model=self.model,
        )

    def response_generator(self, queries: List[RAGQuery]) -> RagResponseDatasetGenerator:
        """Get the response generator instance for given queries.

        Args:
        * `queries`: List of queries to generate responses for.

        Returns:
        * `RagResponseDatasetGenerator` configured with this generator's settings.
        """
        return RagResponseDatasetGenerator(
            data_collection=self.data_collection,
            query_spec=self.query_spec,
            response_spec=self.response_spec,
            response_template=self.response_template,
            additional_prompt_blocks=self.additional_prompt_blocks,
            queries=queries,
            options=self.options,
            provider=self.provider,
            model=self.model,
        )

    @property
    def prepared_query_template(self) -> PreparedTemplate:
        """Get the prepared query template.

        Returns:
        * `PreparedTemplate` ready for use with the LLM.
        """
        return self.query_generator.prepared_query_template

    @property
    def prepared_response_template(self) -> PreparedTemplate:
        """Get the prepared response template.

        Returns:
        * `PreparedTemplate` ready for use with the LLM.
        """
        return self.response_generator([]).prepared_response_template