/ src / revolve / tools.py
tools.py
 1  from revolve.utils import read_python_code
 2  from revolve.functions import get_file_list, run_pytest
 3  from revolve.db import get_adapter
 4  from revolve.external import get_db_type
 5  
 6  from langchain_core.tools import tool
 7  from langchain_core.tools.structured import StructuredTool
 8  
 9  
10  
11  
@tool
def _get_file_list() -> list:
    """Get a list of files in the source folder."""
    # NOTE: the docstring above doubles as the LangChain tool description, so
    # it is intentionally kept short and unchanged.
    # Only these extensions are surfaced to the agent (the same set that
    # _read_file accepts). Presumably get_file_list() yields path strings —
    # confirm against revolve.functions.
    allowed_suffixes = ('.py', '.md', '.json')
    return [entry for entry in get_file_list() if entry.endswith(allowed_suffixes)]
18  
@tool
def _read_file(file_name: str) -> str:
    """Read the content of a file. Only supports .py, .md, and .json files."""
    # NOTE: the docstring above is the LangChain tool description; keep it as is.
    # Unsupported extensions get an explanatory message back (not an
    # exception) so the model can correct its request.
    # NOTE(review): only the extension is checked; absolute paths or paths
    # containing '..' are forwarded to read_python_code unchanged — confirm
    # that read_python_code confines reads to the source folder.
    is_supported = file_name.endswith(('.py', '.md', '.json'))
    if is_supported:
        return read_python_code(file_name)
    return "File name is not valid or file type is not supported. Please provide a valid file name with .py, .md, or .json extension."
27  
28  
@tool
def _run_test(test_file: str) -> str:
    """Run pytest on a test file."""
    # NOTE: the docstring above is the LangChain tool description; keep it as is.
    # Only pytest-style names are accepted: the whole string must start with
    # "test_" and end with ".py" (so a value like "tests/test_x.py" is
    # rejected because of its directory prefix).
    looks_like_test = test_file.startswith("test_") and test_file.endswith(".py")
    if looks_like_test:
        return run_pytest(test_file)
    return "Test file name is not valid. Please provide a valid test file name starting with 'test_' and ending with '.py'."
37  
def get_tools():
    """Build the list of tools exposed to the agent.

    The list starts with the file/test helper tools defined in this module and
    is extended with one ``StructuredTool`` for every adapter method that
    ``get_functions`` returns for the adapter selected by ``get_db_type()``.

    Returns:
        list: ``_get_file_list``, ``_read_file`` and ``_run_test``, followed by
        the adapter tools in the order ``get_functions`` returns them.
    """
    import inspect

    adapter = get_adapter(get_db_type())
    tool_methods = [
        _get_file_list,
        _read_file,
        _run_test,
    ]

    for method in get_functions(adapter):
        # The description is sent to the model, so strip the docstring's
        # source indentation and surrounding blank lines; fall back to a
        # placeholder when the docstring is missing or blank.
        description = inspect.cleandoc(method.__doc__ or "") or "No description available."
        # Named ``db_tool`` so it does not shadow the module-level ``tool``
        # decorator imported from langchain_core.
        db_tool = StructuredTool.from_function(
            method,
            name=method.__name__,
            description=description,
        )
        tool_methods.append(db_tool)

    return tool_methods
57  
def get_functions(adapter):
    """
    Retrieve the adapter's callables that are marked as database tools.

    A member is included when its name does not start with a double
    underscore, it can be resolved with ``getattr``, it is callable, and it
    carries a truthy ``_db_tool`` attribute.

    Args:
        adapter: Object whose members are scanned via ``dir()``.

    Returns:
        list: The matching attributes as fetched from ``adapter`` (bound
        methods for regular methods), in ``dir()`` order, which is sorted
        alphabetically.
    """
    methods = []

    for name in dir(adapter):
        if name.startswith("__"):
            continue

        # dir() may list names that cannot actually be fetched (e.g. unset
        # __slots__ entries or properties raising AttributeError). Skip them
        # instead of aborting discovery of every other tool. callable(None)
        # is False, so the default falls through to the check below.
        attr = getattr(adapter, name, None)
        if not callable(attr):
            continue

        if not getattr(attr, "_db_tool", False):
            continue

        methods.append(attr)

    return methods