utils.py
 1  # Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
 2  #
 3  # Licensed under the Apache License, Version 2.0 (the "License");
 4  # you may not use this file except in compliance with the License.
 5  # You may obtain a copy of the License at
 6  #
 7  #     http://www.apache.org/licenses/LICENSE-2.0
 8  #
 9  # Unless required by applicable law or agreed to in writing, software
10  # distributed under the License is distributed on an "AS IS" BASIS,
11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  # See the License for the specific language governing permissions and
13  # limitations under the License.
14  #
15  # Requirement: Any integration or derivative work must explicitly attribute
16  # Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
17  # documentation or user interface, as detailed in the NOTICE file.
18  
19  from pydantic import BaseModel
20  from .schema import SyntheticData, SyntheticDataList
21  
22  from deepeval.metrics.utils import trimAndLoadJson, initialize_model
23  from deepeval.models import DeepEvalBaseLLM
24  
25  
def generate_schema(
    prompt: str,
    schema: BaseModel,
    model: DeepEvalBaseLLM = None,
) -> BaseModel:
    """
    Generate an instance of ``schema`` from ``prompt`` using an LLM.

    Args:
        prompt: The prompt to send to the model.
        schema: The pydantic model class the response should be parsed into.
        model: The model to use. It is resolved through deepeval's
            ``initialize_model``, and the model that call returns is the
            one actually invoked (so the default ``None`` works too).

    Returns:
        An instance of ``schema`` built from the model's response.

    Notes:
        For non-native models that do not accept a ``schema`` keyword
        (signalled by ``TypeError``), the raw text response is parsed with
        ``trimAndLoadJson`` and the schema is built from the parsed data.
        Errors from parsing or schema construction propagate to the caller.
    """
    # Use the model returned by initialize_model rather than the raw
    # argument: when ``model`` is None (the default), the argument itself
    # has no ``generate`` method.
    model, using_native_model = initialize_model(model=model)

    if using_native_model:
        # Native models return a 2-tuple; only the first element (the parsed
        # result) is used. NOTE(review): second element presumably the cost.
        res, _ = model.generate(prompt, schema=schema)
        return res

    try:
        # Custom models may or may not support structured output via ``schema``.
        return model.generate(prompt, schema=schema)
    except TypeError:
        # Fall back to plain text generation and parse the JSON ourselves.
        res = model.generate(prompt)

    data = trimAndLoadJson(res)
    if schema is SyntheticDataList:
        # Build each nested item explicitly from the "data" entry.
        data_list = [SyntheticData(**item) for item in data["data"]]
        return SyntheticDataList(data=data_list)
    return schema(**data)
59  
60  
async def a_generate_schema(
    prompt: str,
    schema: BaseModel,
    model: DeepEvalBaseLLM = None,
) -> BaseModel:
    """
    Asynchronously generate an instance of ``schema`` from ``prompt``.

    Args:
        prompt: The prompt to send to the model.
        schema: The pydantic model class the response should be parsed into.
        model: The model to use. It is resolved through deepeval's
            ``initialize_model``, and the model that call returns is the
            one actually invoked (so the default ``None`` works too).

    Returns:
        An instance of ``schema`` built from the model's response.

    Notes:
        For non-native models that do not accept a ``schema`` keyword
        (signalled by ``TypeError``), the raw text response is parsed with
        ``trimAndLoadJson`` and the schema is built from the parsed data.
        Errors from parsing or schema construction propagate to the caller.
    """
    # Use the model returned by initialize_model rather than the raw
    # argument: when ``model`` is None (the default), the argument itself
    # has no ``a_generate`` method.
    model, using_native_model = initialize_model(model=model)

    if using_native_model:
        # Native models return a 2-tuple; only the first element (the parsed
        # result) is used. NOTE(review): second element presumably the cost.
        res, _ = await model.a_generate(prompt, schema=schema)
        return res

    try:
        # Custom models may or may not support structured output via ``schema``.
        return await model.a_generate(prompt, schema=schema)
    except TypeError:
        # Fall back to plain text generation and parse the JSON ourselves.
        res = await model.a_generate(prompt)

    data = trimAndLoadJson(res)
    if schema is SyntheticDataList:
        # Build each nested item explicitly from the "data" entry.
        data_list = [SyntheticData(**item) for item in data["data"]]
        return SyntheticDataList(data=data_list)
    return schema(**data)