# src/revolve/data_types.py
import json
import operator
from datetime import datetime
from typing import Annotated, Callable, List, Optional

from langchain_core.messages import AnyMessage
from pydantic import BaseModel, Field
from typing_extensions import Literal, TypedDict
  8  
  9  
 10  
 11  class Trace(TypedDict):
 12      node_name: str
 13      node_type: str
 14      node_input:str
 15      node_output:str
 16      description: str
 17      trace_timestamp:datetime
 18  
 19  
 20  class ForeignKey(TypedDict):
 21      """ Captures details about a foreign key relation to be used for API generation """
 22      links_to_table: str
 23      foreign_column: str
 24      rel_type:str
 25  
 26  class Column(TypedDict):
 27      """ Captures details about a specific column in a table for which API is being generated """
 28      column: str
 29      type: str
 30      is_primary_key: bool
 31      foreign_key: ForeignKey
 32      is_unique: bool
 33      is_nullable: bool
 34      is_uid: bool
 35      enum_values: list[str]
 36  
 37  
 38  class Table(TypedDict):
 39      """ Captures details about an individual table for which API is being generated """
 40      individual_prompt: str
 41      table_name: str
 42      columns: list[Column]
 43  
 44  class DBSchema(TypedDict):
 45      """Captures a list of tables for which APIs will get generated"""
 46      tables: list[Table]
 47  
 48  class ClassifyUserRequest(BaseModel):
 49      """ Classify the user prompt. if user request is a valid task, use "create_crud_task" or "other_tasks". If not use "respond_back" and provide a proper message to the user. """
 50      classification: Literal["respond_back", "create_crud_task", "other_tasks"]
 51      message: str
 52  
 53  
 54  class ApiRoute(TypedDict):
 55      uri: str
 56      resource_object: str
 57  
 58  class Resource(TypedDict):
 59      resource_file_name: str
 60      resource_code:str
 61      api_route: List[ApiRoute]
 62  
 63  class NextNode(BaseModel):
 64      name:Literal["generate_prompt_for_code_generation", "do_stuff", "do_other_stuff", "__end__"] = Field(
 65          None, description="The next step in the routing process"
 66      )
 67  
 68  class CodeHistoryMessage(BaseModel):
 69      new_code:str = Field(
 70          ...,
 71          description="The complete code revised by the system. This must include the entire code, not just the part that was changed or fixed."
 72      )
 73      what_was_the_problem:str  = Field(
 74          ...,
 75          description="The problem that caused the code to be revised",
 76      )
 77      what_is_fixed:str = Field(
 78          ...,
 79          description="The problem that was fixed by the new code",
 80      )
 81      code_type:Literal["resource", "test", "api"] = Field(
 82          ...,
 83          description="The type of code that was revised",
 84      )
 85  
 86  class GeneratedCode(BaseModel):
 87      """ Captures the generated code for a test """
 88      full_test_code:str = Field(
 89          ...,
 90          description="The full test code generated by the system",
 91      )
 92  class CodeHistory(TypedDict):
 93      history_type:str # generation, revision
 94      code:CodeHistoryMessage
 95      test_report_before_revising: str
 96      test_report_after_revising:str
 97      iteration_index:int
 98      test_revising_input_prompt:List[AnyMessage]
 99  
class TestStatus(TypedDict):
    # Tracks test generation for one resource file (resource_file_name /
    # resource_code), including every attempt recorded in code_history.
    resource_file_name: str
    resource_code: str
    test_generation_input_prompt: List[AnyMessage]
    test_revising_prompt: str
    test_file_name: str
    test_code: str
    status: str  # expected values: success, failed, fixed, in_progress
    messages: list[AnyMessage]
    code_history: list[CodeHistory]
    iteration_count: int
    table: Table  # presumably the table this resource was generated for
112   
class State(TypedDict):
    # Shared state object. The Annotated[..., operator.add] fields carry
    # operator.add as metadata -- presumably a LangGraph-style reducer that
    # concatenates list updates instead of overwriting them (confirm).
    messages: Annotated[list[AnyMessage], operator.add]
    trace: Annotated[list, operator.add]  # NOTE(review): elements presumably Trace
    resources: Annotated[list[Resource], operator.add]
    DBSchema: DBSchema  # the key deliberately shares the DBSchema class's name
    next_node: str
    test_status: list[TestStatus]
    send: Callable[[dict], None]  # callback taking a dict; its purpose is not visible here
    test_mode: bool
    classification: str  # presumably a ClassifyUserRequest.classification value
123  
class Readme(TypedDict):
    # README body; the field name suggests Markdown text.
    md_content: str
126  
127  
128  import json
129  import typing
130  from typing import TypedDict, get_type_hints
131  
132  
def is_typed_dict(cls: type) -> bool:
    """Return True if *cls* is a class carrying TypedDict's marker attributes.

    The check is structural (``__annotations__`` and ``__total__`` on a
    class), so it accepts TypedDicts built with either ``typing`` or
    ``typing_extensions``. Non-class values always yield False.
    """
    return (
        isinstance(cls, type)
        and hasattr(cls, "__annotations__")
        and hasattr(cls, "__total__")
    )


def _python_type_to_json_type(tp) -> dict:
    """Translate one type annotation into a JSON Schema fragment.

    Anything not recognised below (unions/``Optional``, ``Any``, callables,
    classes such as ``datetime``) falls back to ``{"type": "string"}``.
    """
    origin = typing.get_origin(tp)

    # list[X] / typing.List[X], plus the bare ``list`` builtin, which is
    # treated like list[Any] (its items fall back to "string").
    if origin is list or tp is list:
        args = typing.get_args(tp)
        item_type = args[0] if args else typing.Any
        return {"type": "array", "items": _python_type_to_json_type(item_type)}

    # dict[K, V] / typing.Dict[K, V], plus the bare ``dict`` builtin. Key and
    # value types are not described.
    if origin is dict or tp is dict:
        return {"type": "object"}

    # Nested TypedDicts are expanded inline.
    if is_typed_dict(tp):
        return typed_dict_to_json_schema(tp)

    if isinstance(tp, type):
        # bool must be tested before int: bool is a subclass of int, so testing
        # int first would map every bool field to "integer".
        if issubclass(tp, bool):
            return {"type": "boolean"}
        if issubclass(tp, str):
            return {"type": "string"}
        if issubclass(tp, int):
            return {"type": "integer"}
        if issubclass(tp, float):
            return {"type": "number"}

    return {"type": "string"}


def typed_dict_to_json_schema(typed_dict_cls: type) -> dict:
    """Build an ``object`` JSON Schema describing a TypedDict class.

    Every annotated key becomes a property. Annotations are resolved with
    ``typing.get_type_hints`` (which also strips ``Annotated[...]`` metadata).
    Keys in the class's ``__required_keys__`` are listed under ``"required"``
    in declaration order; ``"required"`` is omitted when no key is required.
    Nested TypedDicts are expanded recursively.

    NOTE: self-referencing TypedDicts are not supported and would recurse
    without bound.

    Args:
        typed_dict_cls: A TypedDict class (as accepted by ``is_typed_dict``).

    Returns:
        A JSON-serialisable dict.

    Raises:
        TypeError: If ``typed_dict_cls`` is not a TypedDict class.
    """
    if not is_typed_dict(typed_dict_cls):
        raise TypeError("Expected a TypedDict class")

    annotations = get_type_hints(typed_dict_cls)
    required_keys = getattr(typed_dict_cls, "__required_keys__", set())

    schema = {
        "type": "object",
        "properties": {
            key: _python_type_to_json_type(tp) for key, tp in annotations.items()
        },
    }

    required = [key for key in annotations if key in required_keys]
    if required:
        schema["required"] = required

    return schema
188  
189  
def typed_dict_dump_schema_json(typed_dict_cls: type, **json_kwargs) -> str:
    """Return the JSON Schema of ``typed_dict_cls`` serialised as a string.

    Extra keyword arguments (for example ``indent=2``) are forwarded
    unchanged to ``json.dumps``.
    """
    return json.dumps(typed_dict_to_json_schema(typed_dict_cls), **json_kwargs)
193  
194  
if __name__ == "__main__":
    # Manual check: pretty-print the JSON Schema derived from DBSchema.
    dbschema_json = typed_dict_dump_schema_json(DBSchema, indent=2)
    print(dbschema_json)
197  
198