check_user_request.py
from datetime import datetime
import json
import os

from revolve.prompts import get_user_intent_prompt, get_user_intent_prompt_ft
from revolve.functions import log
from revolve.llm import invoke_llm
from revolve.data_types import State, ClassifyUserRequest
from revolve.utils import create_ft_data
from revolve.utils_git import init_or_attach_git_repo, create_branch_with_timestamp


# Classification label for requests answered directly instead of starting a task.
_RESPOND_BACK = "respond_back"

# Description recorded (in the trace) when the request is classified as a task.
_TASK_DESCRIPTION = "Prompt classified as a task. Task is in progress."


def _ft_save_mode_enabled() -> bool:
    """Return True when the FT_SAVE_MODE env var is "true" (case-insensitive)."""
    # Accept "true", "True", "TRUE", " true " ... rather than only the exact
    # lowercase spelling, which previously disabled saving silently.
    return os.environ.get("FT_SAVE_MODE", "false").strip().lower() == "true"


def check_user_request(state: State):
    """Classify the latest user message and record a trace entry for it.

    When the message history holds exactly two entries, a git repository is
    initialised/attached and a timestamped branch is created first.

    Args:
        state: Workflow state. Reads ``state["messages"]`` (a list of dicts
            whose last element has a ``"content"`` key) and the optional
            ``state["send"]`` value, which is forwarded to ``log``.

    Returns:
        A dict with ``"classification"`` (the label produced by the LLM) and
        ``"trace"`` (a one-element list holding this node's trace record).
    """
    send = state.get("send")
    log("Started", send)

    message_history = state["messages"]
    # NOTE(review): a two-message history is treated as the start of a
    # session (presumably system prompt + first user message) — confirm.
    if len(message_history) == 2:
        init_or_attach_git_repo()
        branch_name = create_branch_with_timestamp()
        log(f"Branch created: {branch_name}", send)

    # Understand user intent.
    intent_prompt = get_user_intent_prompt(state["messages"])
    intent = invoke_llm(
        intent_prompt,
        max_attempts=3,
        validation_class=ClassifyUserRequest,
        method="function_calling",
        manual_validation=True,
    )
    responds_back = intent.classification == _RESPOND_BACK
    description = intent.message if responds_back else _TASK_DESCRIPTION

    # Read after the LLM call so the trace reflects the message list as it
    # stands once the prompt helpers have run.
    last_message_content = state["messages"][-1]["content"]
    new_trace = {
        "node_name": "check_user_request",
        "node_type": "classify_user_request",
        "node_input": last_message_content,
        # NOTE(review): output is a fixed placeholder string, not the result.
        "node_output": "place_holder",
        "trace_timestamp": datetime.now(),
        "description": description,
    }

    if _ft_save_mode_enabled():
        # Pair the fine-tuning prompt with the model's structured answer and
        # hand it to create_ft_data as a single custom example.
        messages_ft = get_user_intent_prompt_ft(state["messages"])
        messages_ft.append({
            "role": "assistant",
            "content": intent.model_dump_json(),
        })
        create_ft_data({"custom_ft_data": [messages_ft]})

    if responds_back:
        log(description, send=send, level="workflow")

    return {
        "classification": intent.classification,
        "trace": [new_trace],
    }