# utils.py
# Copyright (c) 2024-2026 Tencent Zhuque Lab. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Requirement: Any integration or derivative work must explicitly attribute
# Tencent Zhuque Lab (https://github.com/Tencent/AI-Infra-Guard) in its
# documentation or user interface, as detailed in the NOTICE file.

from typing import Any, Optional, Type

from pydantic import BaseModel
from .schema import SyntheticData, SyntheticDataList

from deepeval.metrics.utils import trimAndLoadJson, initialize_model
from deepeval.models import DeepEvalBaseLLM


def _parse_schema_response(res: Any, schema: Type[BaseModel]) -> BaseModel:
    """
    Build a ``schema`` instance from a raw, unstructured model response.

    Used when a custom model's ``generate``/``a_generate`` does not accept the
    ``schema`` keyword. The raw response is converted to JSON data with
    deepeval's ``trimAndLoadJson`` and then passed to ``schema``.

    Args:
        res: The raw response returned by the model.
        schema: The pydantic model class to instantiate.

    Returns:
        An instance of ``schema`` populated from the parsed response.
    """
    data = trimAndLoadJson(res)
    if schema is SyntheticDataList:
        # Build each item explicitly as SyntheticData from the "data" list.
        return SyntheticDataList(
            data=[SyntheticData(**item) for item in data["data"]]
        )
    return schema(**data)


def generate_schema(
    prompt: str,
    schema: Type[BaseModel],
    model: Optional[DeepEvalBaseLLM] = None,
) -> BaseModel:
    """
    Generate a structured response that conforms to ``schema``.

    Args:
        prompt: The prompt to send to the model.
        schema: The pydantic model class the response must conform to.
        model: The model to use. It is passed through deepeval's
            ``initialize_model`` and the model returned from that call is the
            one invoked, so ``None`` uses whatever default it resolves to.

    Returns:
        The validated schema object.
    """
    # initialize_model may hand back a different object than it was given
    # (e.g. when model is None), so always call the model it returns.
    model, using_native_model = initialize_model(model=model)

    if using_native_model:
        # Native models return a 2-tuple; only the first element is needed.
        res, _ = model.generate(prompt, schema=schema)
        return res

    try:
        return model.generate(prompt, schema=schema)
    except TypeError:
        # The custom model's generate() does not take ``schema``: request
        # plain output and parse the JSON ourselves.
        res = model.generate(prompt)
    return _parse_schema_response(res, schema)


async def a_generate_schema(
    prompt: str,
    schema: Type[BaseModel],
    model: Optional[DeepEvalBaseLLM] = None,
) -> BaseModel:
    """
    Asynchronously generate a structured response that conforms to ``schema``.

    Args:
        prompt: The prompt to send to the model.
        schema: The pydantic model class the response must conform to.
        model: The model to use. It is passed through deepeval's
            ``initialize_model`` and the model returned from that call is the
            one invoked, so ``None`` uses whatever default it resolves to.

    Returns:
        The validated schema object.
    """
    # initialize_model may hand back a different object than it was given
    # (e.g. when model is None), so always call the model it returns.
    model, using_native_model = initialize_model(model=model)

    if using_native_model:
        # Native models return a 2-tuple; only the first element is needed.
        res, _ = await model.a_generate(prompt, schema=schema)
        return res

    try:
        return await model.a_generate(prompt, schema=schema)
    except TypeError:
        # The custom model's a_generate() does not take ``schema``: request
        # plain output and parse the JSON ourselves.
        res = await model.a_generate(prompt)
    return _parse_schema_response(res, schema)