graph_agent_workflow_test.erl
%%%-------------------------------------------------------------------
%%% @doc Complex agent workflow test.
%%%
%%% Simulates a ReAct-style agent workflow:
%%%   1. Parse the user's intent
%%%   2. Plan the execution steps
%%%   3. Execute tool calls (possibly looping several times)
%%%   4. Generate the final response
%%%
%%% Tests the correctness of dual-engine execution.
%%% @end
%%%-------------------------------------------------------------------
-module(graph_agent_workflow_test).

-include_lib("eunit/include/eunit.hrl").

-export([run/0, run_with_engine/1]).

%%====================================================================
%% Test entry points
%%====================================================================

%% @doc Builds the agent graph and runs every scripted scenario on the
%% `pregel' engine, printing a report for each one. Returns `ok'.
%%
%% Inputs are written as `<<"..."/utf8>>': without the `utf8' type every
%% code point above 255 in a binary literal is truncated to a single byte,
%% so the text would not be UTF-8.
run() ->
    io:format("~n=== 复杂 Agent 工作流测试 ===~n~n"),

    %% Build the workflow graph.
    {ok, Graph} = build_react_agent_graph(),

    %% Scenarios: display name, user input, and the number of tool calls
    %% the scenario is expected to need.
    TestCases = [
        #{
            name => "简单问候",
            input => <<"你好,请问你是谁?"/utf8>>,
            expected_tool_calls => 0
        },
        #{
            name => "天气查询",
            input => <<"北京今天天气怎么样?"/utf8>>,
            expected_tool_calls => 1
        },
        #{
            name => "复杂计算任务",
            input => <<"帮我计算 123 * 456,然后查询结果对应的城市人口"/utf8>>,
            expected_tool_calls => 2
        },
        #{
            name => "多步骤任务",
            input => <<"先搜索 Erlang 并发模型,然后总结要点"/utf8>>,
            expected_tool_calls => 1
        }
    ],

    %% Run the scenarios with the Pregel engine.
    io:format("--- 测试 Pregel 引擎 ---~n"),
    run_test_cases(Graph, TestCases, pregel),

    io:format("~n=== 所有测试完成 ===~n"),
    ok.

%% @doc Runs one fixed multi-step request (compare the weather of two
%% cities) with the given engine, prints the outcome and returns the value
%% produced by `graph:run/3'.
%%
%% `Engine' is passed through unchanged as the `engine' run option.
run_with_engine(Engine) ->
    {ok, Graph} = build_react_agent_graph(),
    Input = <<"查询北京和上海的天气,然后比较哪个更热"/utf8>>,
    InitState = graph:state(#{
        input => Input,
        messages => [],
        tool_calls => 0,
        max_tool_calls => 5
    }),

    io:format("~n=== 使用 ~p 引擎执行复杂任务 ===~n", [Engine]),
    %% ~ts (not ~s) so the UTF-8 binary is printed as text, not as bytes.
    %% The input is printed from the local variable rather than read back
    %% out of the state, so no assumption is made about the state's shape.
    io:format("输入: ~ts~n~n", [Input]),

    Result = graph:run(Graph, InitState, #{engine => Engine}),
    print_result(Result),
    Result.
75 76 %%==================================================================== 77 %% 工作流构建 78 %%==================================================================== 79 80 build_react_agent_graph() -> 81 %% 节点函数定义 82 ParseFun = fun parse_intent/1, 83 PlanFun = fun plan_actions/1, 84 ToolFun = fun execute_tool/1, 85 CheckFun = fun check_completion/1, 86 RespondFun = fun generate_response/1, 87 88 %% 路由函数 89 PlanRouter = fun(State) -> 90 case graph:get(State, needs_tool) of 91 true -> execute_tool; 92 false -> generate_response 93 end 94 end, 95 96 CheckRouter = fun(State) -> 97 case graph:get(State, is_complete) of 98 true -> generate_response; 99 false -> plan_actions %% 循环回规划阶段 100 end 101 end, 102 103 %% 构建图 104 B0 = graph:builder(#{max_iterations => 20}), 105 106 %% 添加节点 107 B1 = graph:add_node(B0, parse_intent, ParseFun), 108 B2 = graph:add_node(B1, plan_actions, PlanFun), 109 B3 = graph:add_node(B2, execute_tool, ToolFun), 110 B4 = graph:add_node(B3, check_completion, CheckFun), 111 B5 = graph:add_node(B4, generate_response, RespondFun), 112 113 %% 添加边 114 B6 = graph:add_edge(B5, parse_intent, plan_actions), 115 B7 = graph:add_conditional_edge(B6, plan_actions, PlanRouter), 116 B8 = graph:add_edge(B7, execute_tool, check_completion), 117 B9 = graph:add_conditional_edge(B8, check_completion, CheckRouter), 118 B10 = graph:add_edge(B9, generate_response, '__end__'), 119 120 %% 设置入口并编译 121 B11 = graph:set_entry(B10, parse_intent), 122 graph:compile(B11). 
123 124 %%==================================================================== 125 %% 节点函数实现 126 %%==================================================================== 127 128 %% 解析用户意图 129 parse_intent(State) -> 130 Input = graph:get(State, input, <<>>), 131 132 %% 简单的意图识别 133 Intents = analyze_intent(Input), 134 135 %% 添加消息到历史 136 Messages = graph:get(State, messages, []), 137 NewMessage = #{role => user, content => Input, timestamp => erlang:timestamp()}, 138 139 State1 = graph:set(State, intents, Intents), 140 State2 = graph:set(State1, messages, [NewMessage | Messages]), 141 State3 = graph:set(State2, current_step, 0), 142 143 io:format(" [解析意图] 识别到意图: ~p~n", [Intents]), 144 {ok, State3}. 145 146 %% 规划执行动作 147 plan_actions(State) -> 148 Intents = graph:get(State, intents, []), 149 CurrentStep = graph:get(State, current_step, 0), 150 ToolResults = graph:get(State, tool_results, []), 151 152 %% 决定下一步动作 153 {NeedsTool, ToolName, ToolArgs} = plan_next_action(Intents, CurrentStep, ToolResults), 154 155 State1 = graph:set(State, needs_tool, NeedsTool), 156 State2 = graph:set(State1, next_tool, ToolName), 157 State3 = graph:set(State2, tool_args, ToolArgs), 158 159 case NeedsTool of 160 true -> 161 io:format(" [规划动作] 需要调用工具: ~p~n", [ToolName]); 162 false -> 163 io:format(" [规划动作] 无需工具,准备生成响应~n") 164 end, 165 166 {ok, State3}. 
%% Node: execute a tool.
%%
%% Reads `next_tool' and `tool_args' (default #{}) from the state, runs the
%% simulated tool, prepends a #{tool, args, result} record to `tool_results',
%% and increments both `tool_calls' and `current_step'.
%% Returns {ok, NewState}.
execute_tool(State) ->
    ToolName = graph:get(State, next_tool),
    ToolArgs = graph:get(State, tool_args, #{}),
    ToolCalls = graph:get(State, tool_calls, 0),

    io:format(" [执行工具] ~p(~p)~n", [ToolName, ToolArgs]),

    %% Simulated tool execution.
    Result = simulate_tool_call(ToolName, ToolArgs),

    %% Update the state; results are kept newest first.
    ToolResults = graph:get(State, tool_results, []),
    NewResult = #{tool => ToolName, args => ToolArgs, result => Result},

    State1 = graph:set(State, tool_results, [NewResult | ToolResults]),
    State2 = graph:set(State1, tool_calls, ToolCalls + 1),
    State3 = graph:set(State2, current_step, graph:get(State, current_step, 0) + 1),

    %% ~ts: the result is text; ~s would print a UTF-8 binary byte by byte.
    io:format(" [工具结果] ~ts~n", [Result]),
    {ok, State3}.

%% Node: check whether the task is finished.
%%
%% Stores the boolean returned by check_if_complete/4 under `is_complete'.
%% `max_tool_calls' defaults to 3 when the state does not carry it.
%% Returns {ok, NewState}.
check_completion(State) ->
    Intents = graph:get(State, intents, []),
    ToolResults = graph:get(State, tool_results, []),
    ToolCalls = graph:get(State, tool_calls, 0),
    MaxToolCalls = graph:get(State, max_tool_calls, 3),

    IsComplete = check_if_complete(Intents, ToolResults, ToolCalls, MaxToolCalls),

    State1 = graph:set(State, is_complete, IsComplete),

    io:format(" [检查完成] 已调用工具 ~p 次, 完成状态: ~p~n", [ToolCalls, IsComplete]),
    {ok, State1}.

%% Node: generate the final response.
%%
%% Builds the reply with build_response/3, stores it under `response' and
%% prepends an assistant message to `messages'. Returns {ok, NewState}.
generate_response(State) ->
    Intents = graph:get(State, intents, []),
    ToolResults = graph:get(State, tool_results, []),
    Input = graph:get(State, input, <<>>),

    Response = build_response(Input, Intents, ToolResults),

    %% Append to the message history (newest first).
    Messages = graph:get(State, messages, []),
    NewMessage = #{role => assistant, content => Response, timestamp => erlang:timestamp()},

    State1 = graph:set(State, response, Response),
    State2 = graph:set(State1, messages, [NewMessage | Messages]),

    %% ~ts: the response is text; ~s would print a UTF-8 binary byte by byte.
    io:format(" [生成响应] ~ts~n", [Response]),
    {ok, State2}.
223 224 %%==================================================================== 225 %% 辅助函数 226 %%==================================================================== 227 228 analyze_intent(Input) when is_binary(Input) -> 229 %% 使用 binary 模式匹配,支持 UTF-8 230 Patterns = [ 231 {weather, [<<"天气">>, <<"温度">>, <<"下雨">>, <<"晴天">>, <<"weather">>]}, 232 {search, [<<"搜索">>, <<"查询">>, <<"查找">>, <<"找">>, <<"search">>]}, 233 {calculate, [<<"计算">>, <<"加">>, <<"减">>, <<"乘">>, <<"除">>, <<"*">>, <<"+">>, <<"-">>, <<"/">>]}, 234 {compare, [<<"比较">>, <<"对比">>, <<"哪个">>, <<"compare">>]}, 235 {greeting, [<<"你好">>, <<"您好">>, <<"hi">>, <<"hello">>, <<"Hi">>, <<"Hello">>]} 236 ], 237 238 lists:filtermap( 239 fun({Intent, Keywords}) -> 240 HasKeyword = lists:any( 241 fun(Kw) -> binary:match(Input, Kw) =/= nomatch end, 242 Keywords 243 ), 244 case HasKeyword of 245 true -> {true, Intent}; 246 false -> false 247 end 248 end, 249 Patterns 250 ). 251 252 plan_next_action(Intents, CurrentStep, ToolResults) -> 253 CompletedTools = [maps:get(tool, R) || R <- ToolResults], 254 255 %% 根据意图和已完成的工具决定下一步 256 case {Intents, CurrentStep} of 257 {[greeting | _], _} -> 258 {false, none, #{}}; 259 {_, _} when CurrentStep >= length(Intents) -> 260 {false, none, #{}}; 261 {[weather | _], 0} -> 262 {true, weather_api, #{city => <<"北京">>}}; 263 {[search | _], 0} -> 264 {true, search_api, #{query => <<"Erlang">>}}; 265 {[calculate | _], 0} -> 266 {true, calculator, #{expression => <<"123 * 456">>}}; 267 {[compare, weather | _], Step} when Step < 2 -> 268 Cities = [<<"北京">>, <<"上海">>], 269 City = lists:nth(Step + 1, Cities), 270 case lists:member(weather_api, CompletedTools) of 271 true when length(CompletedTools) < 2 -> 272 {true, weather_api, #{city => City}}; 273 _ -> 274 {false, none, #{}} 275 end; 276 _ -> 277 {false, none, #{}} 278 end. 
%% Simulates a tool call and returns its textual result as a UTF-8 binary.
%%
%% The text is produced with unicode:characters_to_binary/1 because the
%% format strings contain code points above 255, which iolist_to_binary/1
%% rejects with badarg. Arguments are formatted with ~ts so that UTF-8
%% binaries are treated as text.
%%
%% Clauses: weather_api (needs `city'; unknown cities report 20),
%% search_api (needs `query'), calculator (needs `expression'), and a
%% generic fallback for any other tool/argument combination.
simulate_tool_call(weather_api, #{city := City}) ->
    Temps = #{<<"北京"/utf8>> => 25, <<"上海"/utf8>> => 28, <<"广州"/utf8>> => 32},
    Temp = maps:get(City, Temps, 20),
    unicode:characters_to_binary(
        io_lib:format("~ts今天天气晴,温度~p°C", [City, Temp]));

simulate_tool_call(search_api, #{query := Query}) ->
    unicode:characters_to_binary(
        io_lib:format("搜索结果: 找到关于 ~ts 的 10 条结果", [Query]));

simulate_tool_call(calculator, #{expression := Expr}) ->
    unicode:characters_to_binary(
        io_lib:format("计算结果: ~ts = 56088", [Expr]));

simulate_tool_call(Tool, Args) ->
    unicode:characters_to_binary(
        io_lib:format("工具 ~p 执行完成,参数: ~p", [Tool, Args])).

%% Returns true when the workflow should stop calling tools:
%%   * the tool-call budget is used up, or
%%   * the only intent is a greeting, or
%%   * a non-comparison task has at least one tool result, or
%%   * a comparison task has at least two tool results.
check_if_complete(Intents, ToolResults, ToolCalls, MaxToolCalls) ->
    IsCompare = lists:member(compare, Intents),
    ToolCalls >= MaxToolCalls orelse
        Intents =:= [greeting] orelse
        (ToolResults =/= [] andalso not IsCompare) orelse
        (IsCompare andalso length(ToolResults) >= 2).

%% Builds the final reply (a UTF-8 binary) from the intents and tool results.
%%
%% Input is accepted for interface compatibility but not used.
%% A comparison is detected by membership of `compare' in Intents, not by
%% its position, and then every result is listed in the order given.
%% Otherwise only the first element of ToolResults is reported.
build_response(_Input, [greeting | _], _ToolResults) ->
    <<"你好!我是一个基于 Erlang Graph 框架的智能助手。有什么可以帮助你的吗?"/utf8>>;
build_response(_Input, _Intents, []) ->
    <<"我已经理解了你的问题,但目前没有足够的信息来回答。"/utf8>>;
build_response(_Input, Intents, ToolResults) ->
    case {lists:member(compare, Intents), ToolResults} of
        {true, [_, _ | _]} ->
            ResultTexts = [maps:get(result, R) || R <- ToolResults],
            iolist_to_binary([
                <<"根据查询结果:\n"/utf8>>,
                lists:join(<<"\n">>, ResultTexts),
                <<"\n\n综合分析完成。"/utf8>>
            ]);
        {_, [#{result := Latest} | _]} ->
            iolist_to_binary([<<"根据查询,"/utf8>>, Latest])
    end.
323 324 %%==================================================================== 325 %% 测试运行 326 %%==================================================================== 327 328 run_test_cases(Graph, TestCases, Engine) -> 329 lists:foreach( 330 fun(#{name := Name, input := Input, expected_tool_calls := Expected}) -> 331 io:format("~n测试: ~s~n", [Name]), 332 io:format("输入: ~s~n", [Input]), 333 334 InitState = graph:state(#{ 335 input => Input, 336 messages => [], 337 tool_calls => 0, 338 max_tool_calls => 5 339 }), 340 341 Result = graph:run(Graph, InitState, #{engine => Engine}), 342 343 Status = maps:get(status, Result), 344 FinalState = maps:get(final_state, Result), 345 ToolCalls = graph:get(FinalState, tool_calls, 0), 346 Response = graph:get(FinalState, response, <<"无响应">>), 347 348 io:format("状态: ~p, 工具调用次数: ~p (期望: ~p)~n", [Status, ToolCalls, Expected]), 349 io:format("响应: ~s~n", [Response]), 350 351 %% 验证 352 case Status of 353 completed -> io:format("[PASS] 执行完成~n"); 354 _ -> io:format("[WARN] 执行状态异常: ~p~n", [Status]) 355 end 356 end, 357 TestCases 358 ). 359 360 print_result(Result) -> 361 Status = maps:get(status, Result), 362 FinalState = maps:get(final_state, Result), 363 364 io:format("~n--- 执行结果 ---~n"), 365 io:format("状态: ~p~n", [Status]), 366 io:format("工具调用次数: ~p~n", [graph:get(FinalState, tool_calls, 0)]), 367 io:format("响应: ~s~n", [graph:get(FinalState, response, <<"无响应">>)]), 368 369 %% 打印消息历史 370 Messages = graph:get(FinalState, messages, []), 371 io:format("~n消息历史 (~p 条):~n", [length(Messages)]), 372 lists:foreach( 373 fun(#{role := Role, content := Content}) -> 374 io:format(" [~p] ~s~n", [Role, Content]) 375 end, 376 lists:reverse(Messages) 377 ). 
378 379 %%==================================================================== 380 %% EUnit Tests 381 %%==================================================================== 382 383 agent_workflow_test_() -> 384 {timeout, 60, fun() -> 385 {ok, Graph} = build_react_agent_graph(), 386 387 %% 测试简单问候 - 使用英文避免编码问题 388 GreetingState = graph:state(#{ 389 input => <<"Hello there!">>, 390 messages => [], 391 tool_calls => 0, 392 max_tool_calls => 3 393 }), 394 395 %% 测试 Pregel 引擎 396 Result = graph:run(Graph, GreetingState), 397 ?assertEqual(completed, maps:get(status, Result)), 398 ?assertEqual(0, graph:get(maps:get(final_state, Result), tool_calls, 0)), 399 400 ok 401 end}.