embeddings.py
"""
Embeddings module
"""

from smolagents import Tool

from ...embeddings import Embeddings


class EmbeddingsTool(Tool):
    """
    Agent tool that runs a search against an embeddings instance.
    """

    def __init__(self, config):
        """
        Builds an EmbeddingsTool from a configuration mapping.

        Args:
            config: embeddings tool configuration, must provide "name" and "description". An optional
                    "target" entry supplies a ready-made embeddings instance, otherwise the configuration
                    is forwarded to Embeddings.load.
        """

        # Tool parameters
        name = config["name"]
        summary = config["description"]

        self.name = name
        self.description = f"""{summary}. Results are returned as a list of dict elements.
Each result has keys 'id', 'text', 'score'."""

        # Input and output descriptions - a single string "query" input, untyped output
        query = {"type": "string", "description": "The search query to perform."}
        self.inputs = {"query": query}
        self.output_type = "any"

        # Resolve the embeddings instance backing this tool
        self.embeddings = self.load(config)

        # Base class initialization runs last, after the tool attributes above are set
        super().__init__()

    # pylint: disable=W0221
    def forward(self, query):
        """
        Executes a search for the input query.

        Args:
            query: input query

        Returns:
            search results
        """

        # Second search argument is fixed at 5
        limit = 5

        results = self.embeddings.search(query, limit)
        return results

    def load(self, config):
        """
        Resolves the embeddings instance described by config.

        Args:
            config: embeddings tool configuration

        Returns:
            config["target"] when that key is present, otherwise a new Embeddings instance loaded with config
        """

        if "target" not in config:
            # No existing instance supplied, create one and load it using the full configuration as keyword arguments
            instance = Embeddings()
            instance.load(**config)

            return instance

        # Reuse the instance supplied by the caller
        return config["target"]