"""Haystack news-fetching agent.

Builds a chat pipeline in which an OpenAI generator may call a SerperDev
web-search tool. Generator replies carrying tool calls are routed to the tool
invoker and the tool results are fed back into the generator via a message
collector; replies without tool calls exit as the final answer. Execution is
traced with RagaAI Catalyst.
"""

import argparse
import os
import sys
from typing import Any, Dict, List

from dotenv import load_dotenv
from haystack import Pipeline, component
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.routers import ConditionalRouter
from haystack.components.tools import ToolInvoker
from haystack.components.websearch import SerperDevWebSearch
from haystack.core.component.types import Variadic
from haystack.dataclasses import ChatMessage
from haystack.tools import ComponentTool

# Make the repository root importable so `ragaai_catalyst` resolves when this
# script is run directly from its own directory.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')))

from ragaai_catalyst import RagaAICatalyst, Tracer, init_tracing  # noqa: E402

# Load environment variables from .env file
load_dotenv()

# Setup Raga AI Catalyst for enhanced monitoring and tracing
catalyst = RagaAICatalyst(
    access_key=os.getenv('RAGAAI_CATALYST_ACCESS_KEY'),
    secret_key=os.getenv('RAGAAI_CATALYST_SECRET_KEY'),
    base_url=os.getenv('RAGAAI_CATALYST_BASE_URL')
)

tracer = Tracer(
    project_name='prompt_metric_dataset',  # os.getenv("RAGAAI_PROJECT_NAME"),
    dataset_name='pytest_dataset',  # os.getenv("RAGAAI_DATASET_NAME"),
    tracer_type="agentic/haystack",
)

# Initialize tracing to track system performance and activities
init_tracing(catalyst=catalyst, tracer=tracer)


# Component to collect and store messages temporarily
@component()
class MessageCollector:
    """Accumulates chat messages across pipeline loop iterations."""

    def __init__(self):
        # Running conversation history shared across `run` calls.
        self._messages: List[ChatMessage] = []

    @component.output_types(messages=List[ChatMessage])
    def run(self, messages: Variadic[List[ChatMessage]]) -> Dict[str, Any]:
        """Flatten the variadic lists of messages into the stored history.

        Returns the full accumulated history, not just the new messages, so
        the generator always sees the whole conversation.
        """
        self._messages.extend([msg for inner in messages for msg in inner])
        return {"messages": self._messages}

    def clear(self):
        """Reset the stored message history."""
        self._messages = []


# Component tool for web search, using SerperDev
web_tool = ComponentTool(
    component=SerperDevWebSearch(top_k=3)
)

# Routing conditions to handle replies with or without tool calls
routes = [
    {
        "condition": "{{replies[0].tool_calls | length > 0}}",
        "output": "{{replies}}",
        "output_name": "there_are_tool_calls",
        "output_type": List[ChatMessage],
    },
    {
        "condition": "{{replies[0].tool_calls | length == 0}}",
        "output": "{{replies}}",
        "output_name": "final_replies",
        "output_type": List[ChatMessage],
    },
]

# Setup the pipeline for processing user queries
tool_agent = Pipeline()
tool_agent.add_component("message_collector", MessageCollector())
tool_agent.add_component("generator", OpenAIChatGenerator(model="gpt-4o-mini", tools=[web_tool]))
tool_agent.add_component("router", ConditionalRouter(routes, unsafe=True))
tool_agent.add_component("tool_invoker", ToolInvoker(tools=[web_tool]))

# Define connections in the pipeline: tool-calling replies loop through the
# tool invoker and back into the generator via the collector; tool-free
# replies leave the pipeline as `final_replies`.
tool_agent.connect("generator.replies", "router")
tool_agent.connect("router.there_are_tool_calls", "tool_invoker")
tool_agent.connect("router.there_are_tool_calls", "message_collector")
tool_agent.connect("tool_invoker.tool_messages", "message_collector")
tool_agent.connect("message_collector", "generator.messages")

# Example messages to simulate user interaction
messages = [
    ChatMessage.from_system("Hello! Ask me anything about current news or information."),
    ChatMessage.from_user("What is the latest news on the Mars Rover mission?")
]


def main(info: str):
    """Run the example conversation through the agent and print the answer.

    Args:
        info: Free-form label echoed to stdout before the run (used to tag
            test invocations).
    """
    print(f"Info: {info}")
    # Run the pipeline with the provided example messages
    result = tool_agent.run({"messages": messages})

    # Print the final reply from the agent
    print(result["router"]["final_replies"][0].text)


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Test the news_fetching.py script.")
    parser.add_argument("--info", type=str, default="testing-news-fetching", help="The info to use (e.g., testing-news-fetching)")
    args = parser.parse_args()

    main(args.info)