# data_types.py
"""Shared type definitions and a small TypedDict -> JSON Schema converter.

NOTE(review): several class docstrings and ``Field(description=...)`` strings
below read like instructions to a language model (e.g. ClassifyUserRequest,
CodeHistoryMessage), so they may be sent as schema descriptions at runtime.
They are kept verbatim on purpose; edit them as prompt text, not as ordinary
code documentation.
"""

import json
import operator
import types
import typing
from datetime import datetime
from typing import Annotated, Callable, List, get_type_hints

from langchain_core.messages import AnyMessage
from pydantic import BaseModel, Field

# TypedDict/Literal come from typing_extensions for every definition in this
# module; do not re-import them from ``typing`` further down, which would
# rebind the module-level names.
from typing_extensions import Literal, TypedDict


# One recorded step: node identity, its input/output and a timestamp.
class Trace(TypedDict):
    node_name: str
    node_type: str
    node_input: str
    node_output: str
    description: str
    trace_timestamp: datetime


class ForeignKey(TypedDict):
    """ Captures details about a foreign key relation to be used for API generation """
    links_to_table: str
    foreign_column: str
    rel_type: str


class Column(TypedDict):
    """ Captures details about a specific column in a table for which API is being generated """
    column: str
    type: str
    is_primary_key: bool
    foreign_key: ForeignKey
    is_unique: bool
    is_nullable: bool
    is_uid: bool
    enum_values: list[str]


class Table(TypedDict):
    """ Captures details about an individual table for which API is being generated """
    individual_prompt: str
    table_name: str
    columns: list[Column]


class DBSchema(TypedDict):
    """Captures a list of tables for which APIs will get generated"""
    tables: list[Table]


class ClassifyUserRequest(BaseModel):
    """ Classify the user prompt. if user request is a valid task, use "create_crud_task" or "other_tasks". If not use "respond_back" and provide a proper message to the user. 
    """
    classification: Literal["respond_back", "create_crud_task", "other_tasks"]
    message: str


# A URI together with the name of the resource object that serves it.
class ApiRoute(TypedDict):
    uri: str
    resource_object: str


# A generated resource source file plus the routes it exposes.
class Resource(TypedDict):
    resource_file_name: str
    resource_code: str
    api_route: List[ApiRoute]


# NOTE(review): the default is None, which is not one of the Literal values.
# Pydantic does not validate defaults unless configured to (confirm for the
# pinned version), so an unset ``name`` is presumably None at runtime —
# consider Optional[Literal[...]] if that is intended.
class NextNode(BaseModel):
    name: Literal["generate_prompt_for_code_generation", "do_stuff", "do_other_stuff", "__end__"] = Field(
        None, description="The next step in the routing process"
    )


class CodeHistoryMessage(BaseModel):
    new_code: str = Field(
        ...,
        description="The complete code revised by the system. This must include the entire code, not just the part that was changed or fixed."
    )
    what_was_the_problem: str = Field(
        ...,
        description="The problem that caused the code to be revised",
    )
    what_is_fixed: str = Field(
        ...,
        description="The problem that was fixed by the new code",
    )
    code_type: Literal["resource", "test", "api"] = Field(
        ...,
        description="The type of code that was revised",
    )


class GeneratedCode(BaseModel):
    """ Captures the generated code for a test """
    full_test_code: str = Field(
        ...,
        description="The full test code generated by the system",
    )


# One entry in a test's generation/revision history.
class CodeHistory(TypedDict):
    history_type: str  # generation, revision
    code: CodeHistoryMessage
    test_report_before_revising: str
    test_report_after_revising: str
    iteration_index: int
    test_revising_input_prompt: List[AnyMessage]


# Per-resource test-generation progress, including its revision history.
class TestStatus(TypedDict):
    resource_file_name: str
    resource_code: str
    test_generation_input_prompt: List[AnyMessage]
    test_revising_prompt: str
    test_file_name: str
    test_code: str
    status: str  # success, failed, fixed, in_progress
    messages: list[AnyMessage]
    code_history: list[CodeHistory]
    iteration_count: int
    table: Table


# Top-level state. Fields annotated with ``operator.add`` presumably use it as
# a reducer so that list updates are concatenated (LangGraph convention) —
# confirm against the graph definition.
class State(TypedDict):
    messages: Annotated[list[AnyMessage], operator.add]
    trace: Annotated[list, operator.add]
    resources: Annotated[list[Resource], operator.add]
    DBSchema: DBSchema
    next_node: str
    test_status: list[TestStatus]
    send: Callable[[dict], None]
    test_mode: bool
    classification: str


class Readme(TypedDict):
    md_content: str


# ---------------------------------------------------------------------------
# TypedDict -> JSON Schema conversion
# ---------------------------------------------------------------------------

# Values typing.get_origin() returns for Literal[...] (from either typing or
# typing_extensions) and for unions (Union/Optional and, on 3.10+, ``X | Y``).
_LITERAL_ORIGINS = (typing.Literal, Literal)
_UNION_ORIGINS = (typing.Union, getattr(types, "UnionType", typing.Union))


def is_typed_dict(cls: type) -> bool:
    """Return True if *cls* looks like a TypedDict class.

    The check is structural: *cls* must be a class that carries both
    ``__annotations__`` and ``__total__``.
    """
    return (
        isinstance(cls, type)
        and hasattr(cls, "__annotations__")
        and hasattr(cls, "__total__")
    )


def _python_type_to_json_type(tp: object) -> dict:
    """Return a JSON Schema fragment for one resolved type annotation.

    Unrecognised annotations (including ``typing.Any``) fall back to
    ``{"type": "string"}``.
    """
    origin = typing.get_origin(tp)
    args = typing.get_args(tp)

    # list[X] / List[X]; a bare ``list`` is treated like ``list[Any]``.
    if origin is list or tp is list:
        item_type = args[0] if args else typing.Any
        return {"type": "array", "items": _python_type_to_json_type(item_type)}

    # dict[K, V] / Dict[K, V] / bare ``dict``: key/value types are not described.
    if origin is dict or tp is dict:
        return {"type": "object"}

    # Literal["a", "b"] -> enum, typed by the values when they share one type.
    if origin in _LITERAL_ORIGINS:
        values = list(args)
        value_types = {type(value) for value in values}
        base = (
            _python_type_to_json_type(value_types.pop())
            if len(value_types) == 1
            else {}
        )
        return {**base, "enum": values}

    # Union[...] / Optional[X] / X | Y -> anyOf over the member types.
    if origin in _UNION_ORIGINS:
        return {"anyOf": [_python_type_to_json_type(arg) for arg in args]}

    if tp is type(None):
        return {"type": "null"}

    if is_typed_dict(tp):
        return typed_dict_to_json_schema(tp)

    if isinstance(tp, type):
        # bool must be tested before int: bool is a subclass of int, so the
        # int branch would otherwise swallow every boolean field.
        if issubclass(tp, bool):
            return {"type": "boolean"}
        if issubclass(tp, str):
            return {"type": "string"}
        if issubclass(tp, int):
            return {"type": "integer"}
        if issubclass(tp, float):
            return {"type": "number"}
        if issubclass(tp, datetime):
            return {"type": "string", "format": "date-time"}

    return {"type": "string"}


def typed_dict_to_json_schema(typed_dict_cls: type) -> dict:
    """Build a JSON Schema object describing a TypedDict class.

    Nested TypedDicts are expanded inline. Keys listed in the class's
    ``__required_keys__`` are emitted under ``"required"`` (the key is omitted
    when no field is required). Recursive (self-referencing) TypedDicts are
    not supported: expansion has no cycle guard.

    Args:
        typed_dict_cls: The TypedDict class to describe.

    Returns:
        A JSON-serialisable dict with ``"type": "object"`` and ``"properties"``.

    Raises:
        TypeError: If *typed_dict_cls* is not a TypedDict class.
    """
    if not is_typed_dict(typed_dict_cls):
        raise TypeError("Expected a TypedDict class")

    # get_type_hints resolves forward references and strips Annotated[...].
    annotations = get_type_hints(typed_dict_cls)
    required_keys = getattr(typed_dict_cls, "__required_keys__", set())

    schema = {
        "type": "object",
        "properties": {},
    }

    required = []
    for key, tp in annotations.items():
        schema["properties"][key] = _python_type_to_json_type(tp)
        if key in required_keys:
            required.append(key)

    if required:
        schema["required"] = required

    return schema


def typed_dict_dump_schema_json(typed_dict_cls: type, **json_kwargs) -> str:
    """Return the JSON Schema of *typed_dict_cls* serialised as a string.

    Extra keyword arguments (e.g. ``indent=2``) are forwarded to ``json.dumps``.
    Raises TypeError if *typed_dict_cls* is not a TypedDict class.
    """
    schema = typed_dict_to_json_schema(typed_dict_cls)
    return json.dumps(schema, **json_kwargs)


if __name__ == "__main__":
    print(typed_dict_dump_schema_json(DBSchema, indent=2))