/ test / graph_agent_workflow_test.erl
graph_agent_workflow_test.erl
  1  %%%-------------------------------------------------------------------
  2  %%% @doc 复杂 Agent 工作流测试
  3  %%%
  4  %%% 模拟 ReAct 风格的 Agent 工作流:
  5  %%% 1. 解析用户意图
  6  %%% 2. 规划执行步骤
  7  %%% 3. 执行工具调用 (可能多次循环)
  8  %%% 4. 生成最终响应
  9  %%%
 10  %%% 测试双引擎执行的正确性
 11  %%% @end
 12  %%%-------------------------------------------------------------------
 13  -module(graph_agent_workflow_test).
 14  
 15  -include_lib("eunit/include/eunit.hrl").
 16  
 17  -export([run/0, run_with_engine/1]).
 18  
 19  %%====================================================================
 20  %% 测试入口
 21  %%====================================================================
 22  
 23  run() ->
 24      io:format("~n=== 复杂 Agent 工作流测试 ===~n~n"),
 25  
 26      %% 构建工作流图
 27      {ok, Graph} = build_react_agent_graph(),
 28  
 29      %% 测试用例
 30      TestCases = [
 31          #{
 32              name => "简单问候",
 33              input => <<"你好,请问你是谁?">>,
 34              expected_tool_calls => 0
 35          },
 36          #{
 37              name => "天气查询",
 38              input => <<"北京今天天气怎么样?">>,
 39              expected_tool_calls => 1
 40          },
 41          #{
 42              name => "复杂计算任务",
 43              input => <<"帮我计算 123 * 456,然后查询结果对应的城市人口">>,
 44              expected_tool_calls => 2
 45          },
 46          #{
 47              name => "多步骤任务",
 48              input => <<"先搜索 Erlang 并发模型,然后总结要点">>,
 49              expected_tool_calls => 1
 50          }
 51      ],
 52  
 53      %% 使用 Pregel 引擎测试
 54      io:format("--- 测试 Pregel 引擎 ---~n"),
 55      run_test_cases(Graph, TestCases, pregel),
 56  
 57      io:format("~n=== 所有测试完成 ===~n"),
 58      ok.
 59  
 60  run_with_engine(Engine) ->
 61      {ok, Graph} = build_react_agent_graph(),
 62      InitState = graph:state(#{
 63          input => <<"查询北京和上海的天气,然后比较哪个更热">>,
 64          messages => [],
 65          tool_calls => 0,
 66          max_tool_calls => 5
 67      }),
 68  
 69      io:format("~n=== 使用 ~p 引擎执行复杂任务 ===~n", [Engine]),
 70      io:format("输入: ~s~n~n", [maps:get(input, InitState)]),
 71  
 72      Result = graph:run(Graph, InitState, #{engine => Engine}),
 73      print_result(Result),
 74      Result.
 75  
 76  %%====================================================================
 77  %% 工作流构建
 78  %%====================================================================
 79  
 80  build_react_agent_graph() ->
 81      %% 节点函数定义
 82      ParseFun = fun parse_intent/1,
 83      PlanFun = fun plan_actions/1,
 84      ToolFun = fun execute_tool/1,
 85      CheckFun = fun check_completion/1,
 86      RespondFun = fun generate_response/1,
 87  
 88      %% 路由函数
 89      PlanRouter = fun(State) ->
 90          case graph:get(State, needs_tool) of
 91              true -> execute_tool;
 92              false -> generate_response
 93          end
 94      end,
 95  
 96      CheckRouter = fun(State) ->
 97          case graph:get(State, is_complete) of
 98              true -> generate_response;
 99              false -> plan_actions  %% 循环回规划阶段
100          end
101      end,
102  
103      %% 构建图
104      B0 = graph:builder(#{max_iterations => 20}),
105  
106      %% 添加节点
107      B1 = graph:add_node(B0, parse_intent, ParseFun),
108      B2 = graph:add_node(B1, plan_actions, PlanFun),
109      B3 = graph:add_node(B2, execute_tool, ToolFun),
110      B4 = graph:add_node(B3, check_completion, CheckFun),
111      B5 = graph:add_node(B4, generate_response, RespondFun),
112  
113      %% 添加边
114      B6 = graph:add_edge(B5, parse_intent, plan_actions),
115      B7 = graph:add_conditional_edge(B6, plan_actions, PlanRouter),
116      B8 = graph:add_edge(B7, execute_tool, check_completion),
117      B9 = graph:add_conditional_edge(B8, check_completion, CheckRouter),
118      B10 = graph:add_edge(B9, generate_response, '__end__'),
119  
120      %% 设置入口并编译
121      B11 = graph:set_entry(B10, parse_intent),
122      graph:compile(B11).
123  
124  %%====================================================================
125  %% 节点函数实现
126  %%====================================================================
127  
128  %% 解析用户意图
129  parse_intent(State) ->
130      Input = graph:get(State, input, <<>>),
131  
132      %% 简单的意图识别
133      Intents = analyze_intent(Input),
134  
135      %% 添加消息到历史
136      Messages = graph:get(State, messages, []),
137      NewMessage = #{role => user, content => Input, timestamp => erlang:timestamp()},
138  
139      State1 = graph:set(State, intents, Intents),
140      State2 = graph:set(State1, messages, [NewMessage | Messages]),
141      State3 = graph:set(State2, current_step, 0),
142  
143      io:format("  [解析意图] 识别到意图: ~p~n", [Intents]),
144      {ok, State3}.
145  
146  %% 规划执行动作
147  plan_actions(State) ->
148      Intents = graph:get(State, intents, []),
149      CurrentStep = graph:get(State, current_step, 0),
150      ToolResults = graph:get(State, tool_results, []),
151  
152      %% 决定下一步动作
153      {NeedsTool, ToolName, ToolArgs} = plan_next_action(Intents, CurrentStep, ToolResults),
154  
155      State1 = graph:set(State, needs_tool, NeedsTool),
156      State2 = graph:set(State1, next_tool, ToolName),
157      State3 = graph:set(State2, tool_args, ToolArgs),
158  
159      case NeedsTool of
160          true ->
161              io:format("  [规划动作] 需要调用工具: ~p~n", [ToolName]);
162          false ->
163              io:format("  [规划动作] 无需工具,准备生成响应~n")
164      end,
165  
166      {ok, State3}.
167  
168  %% 执行工具
169  execute_tool(State) ->
170      ToolName = graph:get(State, next_tool),
171      ToolArgs = graph:get(State, tool_args, #{}),
172      ToolCalls = graph:get(State, tool_calls, 0),
173  
174      io:format("  [执行工具] ~p(~p)~n", [ToolName, ToolArgs]),
175  
176      %% 模拟工具执行
177      Result = simulate_tool_call(ToolName, ToolArgs),
178  
179      %% 更新状态
180      ToolResults = graph:get(State, tool_results, []),
181      NewResult = #{tool => ToolName, args => ToolArgs, result => Result},
182  
183      State1 = graph:set(State, tool_results, [NewResult | ToolResults]),
184      State2 = graph:set(State1, tool_calls, ToolCalls + 1),
185      State3 = graph:set(State2, current_step, graph:get(State, current_step, 0) + 1),
186  
187      io:format("  [工具结果] ~s~n", [Result]),
188      {ok, State3}.
189  
190  %% 检查是否完成
191  check_completion(State) ->
192      Intents = graph:get(State, intents, []),
193      ToolResults = graph:get(State, tool_results, []),
194      ToolCalls = graph:get(State, tool_calls, 0),
195      MaxToolCalls = graph:get(State, max_tool_calls, 3),
196  
197      %% 判断是否完成
198      IsComplete = check_if_complete(Intents, ToolResults, ToolCalls, MaxToolCalls),
199  
200      State1 = graph:set(State, is_complete, IsComplete),
201  
202      io:format("  [检查完成] 已调用工具 ~p 次, 完成状态: ~p~n", [ToolCalls, IsComplete]),
203      {ok, State1}.
204  
205  %% 生成最终响应
206  generate_response(State) ->
207      Intents = graph:get(State, intents, []),
208      ToolResults = graph:get(State, tool_results, []),
209      Input = graph:get(State, input, <<>>),
210  
211      %% 生成响应
212      Response = build_response(Input, Intents, ToolResults),
213  
214      %% 添加到消息历史
215      Messages = graph:get(State, messages, []),
216      NewMessage = #{role => assistant, content => Response, timestamp => erlang:timestamp()},
217  
218      State1 = graph:set(State, response, Response),
219      State2 = graph:set(State1, messages, [NewMessage | Messages]),
220  
221      io:format("  [生成响应] ~s~n", [Response]),
222      {ok, State2}.
223  
224  %%====================================================================
225  %% 辅助函数
226  %%====================================================================
227  
228  analyze_intent(Input) when is_binary(Input) ->
229      %% 使用 binary 模式匹配,支持 UTF-8
230      Patterns = [
231          {weather, [<<"天气">>, <<"温度">>, <<"下雨">>, <<"晴天">>, <<"weather">>]},
232          {search, [<<"搜索">>, <<"查询">>, <<"查找">>, <<"找">>, <<"search">>]},
233          {calculate, [<<"计算">>, <<"加">>, <<"减">>, <<"乘">>, <<"除">>, <<"*">>, <<"+">>, <<"-">>, <<"/">>]},
234          {compare, [<<"比较">>, <<"对比">>, <<"哪个">>, <<"compare">>]},
235          {greeting, [<<"你好">>, <<"您好">>, <<"hi">>, <<"hello">>, <<"Hi">>, <<"Hello">>]}
236      ],
237  
238      lists:filtermap(
239          fun({Intent, Keywords}) ->
240              HasKeyword = lists:any(
241                  fun(Kw) -> binary:match(Input, Kw) =/= nomatch end,
242                  Keywords
243              ),
244              case HasKeyword of
245                  true -> {true, Intent};
246                  false -> false
247              end
248          end,
249          Patterns
250      ).
251  
252  plan_next_action(Intents, CurrentStep, ToolResults) ->
253      CompletedTools = [maps:get(tool, R) || R <- ToolResults],
254  
255      %% 根据意图和已完成的工具决定下一步
256      case {Intents, CurrentStep} of
257          {[greeting | _], _} ->
258              {false, none, #{}};
259          {_, _} when CurrentStep >= length(Intents) ->
260              {false, none, #{}};
261          {[weather | _], 0} ->
262              {true, weather_api, #{city => <<"北京">>}};
263          {[search | _], 0} ->
264              {true, search_api, #{query => <<"Erlang">>}};
265          {[calculate | _], 0} ->
266              {true, calculator, #{expression => <<"123 * 456">>}};
267          {[compare, weather | _], Step} when Step < 2 ->
268              Cities = [<<"北京">>, <<"上海">>],
269              City = lists:nth(Step + 1, Cities),
270              case lists:member(weather_api, CompletedTools) of
271                  true when length(CompletedTools) < 2 ->
272                      {true, weather_api, #{city => City}};
273                  _ ->
274                      {false, none, #{}}
275              end;
276          _ ->
277              {false, none, #{}}
278      end.
279  
280  simulate_tool_call(weather_api, #{city := City}) ->
281      Temps = #{<<"北京">> => 25, <<"上海">> => 28, <<"广州">> => 32},
282      Temp = maps:get(City, Temps, 20),
283      iolist_to_binary(io_lib:format("~s今天天气晴,温度~p°C", [City, Temp]));
284  
285  simulate_tool_call(search_api, #{query := Query}) ->
286      iolist_to_binary(io_lib:format("搜索结果: 找到关于 ~s 的 10 条结果", [Query]));
287  
288  simulate_tool_call(calculator, #{expression := Expr}) ->
289      iolist_to_binary(io_lib:format("计算结果: ~s = 56088", [Expr]));
290  
291  simulate_tool_call(Tool, Args) ->
292      iolist_to_binary(io_lib:format("工具 ~p 执行完成,参数: ~p", [Tool, Args])).
293  
294  check_if_complete(Intents, ToolResults, ToolCalls, MaxToolCalls) ->
295      %% 超过最大调用次数
296      ToolCalls >= MaxToolCalls orelse
297      %% 简单任务不需要工具
298      Intents =:= [greeting] orelse
299      %% 所有需要的工具都已执行
300      (length(ToolResults) > 0 andalso
301       not lists:member(compare, Intents)) orelse
302      %% 比较任务需要两次工具调用
303      (lists:member(compare, Intents) andalso length(ToolResults) >= 2).
304  
305  build_response(Input, Intents, ToolResults) ->
306      case {Intents, ToolResults} of
307          {[greeting | _], _} ->
308              <<"你好!我是一个基于 Erlang Graph 框架的智能助手。有什么可以帮助你的吗?">>;
309          {_, []} ->
310              <<"我已经理解了你的问题,但目前没有足够的信息来回答。">>;
311          {[compare | _], Results} when length(Results) >= 2 ->
312              ResultTexts = [maps:get(result, R) || R <- Results],
313              iolist_to_binary([
314                  <<"根据查询结果:\n">>,
315                  lists:join(<<"\n">>, ResultTexts),
316                  <<"\n\n综合分析完成。">>
317              ]);
318          {_, [#{result := Result} | _]} ->
319              iolist_to_binary([
320                  <<"根据查询,">>, Result
321              ])
322      end.
323  
324  %%====================================================================
325  %% 测试运行
326  %%====================================================================
327  
328  run_test_cases(Graph, TestCases, Engine) ->
329      lists:foreach(
330          fun(#{name := Name, input := Input, expected_tool_calls := Expected}) ->
331              io:format("~n测试: ~s~n", [Name]),
332              io:format("输入: ~s~n", [Input]),
333  
334              InitState = graph:state(#{
335                  input => Input,
336                  messages => [],
337                  tool_calls => 0,
338                  max_tool_calls => 5
339              }),
340  
341              Result = graph:run(Graph, InitState, #{engine => Engine}),
342  
343              Status = maps:get(status, Result),
344              FinalState = maps:get(final_state, Result),
345              ToolCalls = graph:get(FinalState, tool_calls, 0),
346              Response = graph:get(FinalState, response, <<"无响应">>),
347  
348              io:format("状态: ~p, 工具调用次数: ~p (期望: ~p)~n", [Status, ToolCalls, Expected]),
349              io:format("响应: ~s~n", [Response]),
350  
351              %% 验证
352              case Status of
353                  completed -> io:format("[PASS] 执行完成~n");
354                  _ -> io:format("[WARN] 执行状态异常: ~p~n", [Status])
355              end
356          end,
357          TestCases
358      ).
359  
360  print_result(Result) ->
361      Status = maps:get(status, Result),
362      FinalState = maps:get(final_state, Result),
363  
364      io:format("~n--- 执行结果 ---~n"),
365      io:format("状态: ~p~n", [Status]),
366      io:format("工具调用次数: ~p~n", [graph:get(FinalState, tool_calls, 0)]),
367      io:format("响应: ~s~n", [graph:get(FinalState, response, <<"无响应">>)]),
368  
369      %% 打印消息历史
370      Messages = graph:get(FinalState, messages, []),
371      io:format("~n消息历史 (~p 条):~n", [length(Messages)]),
372      lists:foreach(
373          fun(#{role := Role, content := Content}) ->
374              io:format("  [~p] ~s~n", [Role, Content])
375          end,
376          lists:reverse(Messages)
377      ).
378  
379  %%====================================================================
380  %% EUnit Tests
381  %%====================================================================
382  
383  agent_workflow_test_() ->
384      {timeout, 60, fun() ->
385          {ok, Graph} = build_react_agent_graph(),
386  
387          %% 测试简单问候 - 使用英文避免编码问题
388          GreetingState = graph:state(#{
389              input => <<"Hello there!">>,
390              messages => [],
391              tool_calls => 0,
392              max_tool_calls => 3
393          }),
394  
395          %% 测试 Pregel 引擎
396          Result = graph:run(Graph, GreetingState),
397          ?assertEqual(completed, maps:get(status, Result)),
398          ?assertEqual(0, graph:get(maps:get(final_state, Result), tool_calls, 0)),
399  
400          ok
401      end}.