/ src / revolve / nodes / check_user_request.py
check_user_request.py
 1  from revolve.prompts import get_user_intent_prompt, get_user_intent_prompt_ft
 2  from revolve.functions import log
 3  from revolve.llm import invoke_llm
 4  from revolve.data_types import State, ClassifyUserRequest
 5  from revolve.utils import create_ft_data
 6  
 7  from datetime import datetime
 8  import json
 9  import os
10  
11  from revolve.utils_git import init_or_attach_git_repo, create_branch_with_timestamp
12  
13  
14  
def check_user_request(state: State):
    """Classify the latest user message and emit a trace record for it.

    When the conversation holds exactly two messages, a git repository is
    initialised/attached and a timestamped branch is created before the
    request is classified.

    Args:
        state: Workflow state. Reads ``state["messages"]`` (a list whose last
            element is a dict with a ``"content"`` key) and the optional
            ``state["send"]`` value, which is forwarded to ``log``.

    Returns:
        A partial state update with:
          * ``"classification"`` -- the ``classification`` field of the
            ``ClassifyUserRequest`` returned by the LLM;
          * ``"trace"`` -- a one-element list holding this node's trace record.

    Side effects:
        * git repo init + branch creation (only when there are two messages);
        * fine-tune data is written via ``create_ft_data`` when the
          ``FT_SAVE_MODE`` environment variable is ``"true"`` (any case);
        * the LLM's reply message is logged at ``"workflow"`` level when the
          request is classified as ``"respond_back"``.
    """
    send = state.get("send")
    log("Started", send)

    message_history = state["messages"]
    # NOTE(review): two messages presumably marks the first user turn
    # (e.g. system prompt + first request) -- TODO confirm with the caller.
    if len(message_history) == 2:
        init_or_attach_git_repo()
        branch_name = create_branch_with_timestamp()
        log(f"Branch created: {branch_name}", send)

    # Understand user intent. Read the last message before calling the LLM so
    # an empty history fails fast, exactly as before.
    last_message_content = message_history[-1]["content"]

    messages = get_user_intent_prompt(message_history)
    structured_db_response = invoke_llm(
        messages,
        max_attempts=3,
        validation_class=ClassifyUserRequest,
        method="function_calling",
        manual_validation=True,
    )

    classification = structured_db_response.classification
    is_respond_back = classification == "respond_back"
    # A "respond_back" classification carries a direct reply for the user;
    # every other classification is treated as a task to execute.
    description = (
        structured_db_response.message
        if is_respond_back
        else "Prompt classifed as a task. Task is in progress."
    )

    new_trace = {
        "node_name": "check_user_request",
        "node_type": "classify_user_request",
        "node_input": last_message_content,
        # Literal placeholder; the classification itself is returned below.
        "node_output": "place_holder",
        "trace_timestamp": datetime.now(),
        "description": description,
    }

    # Accept "true" in any case / surrounding whitespace so e.g. FT_SAVE_MODE=True
    # does not silently disable fine-tune data collection.
    if os.environ.get("FT_SAVE_MODE", "false").strip().lower() == "true":
        messages_ft = get_user_intent_prompt_ft(message_history)
        messages_ft.append({
            "role": "assistant",
            "content": structured_db_response.model_dump_json(),
        })
        create_ft_data({"custom_ft_data": [messages_ft]})

    if is_respond_back:
        log(description, send=send, level="workflow")

    return {
        "classification": classification,
        "trace": [new_trace],
    }