openapi_functions.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import json
import os
from pathlib import Path
from typing import Any

import yaml

from haystack import component, logging
from haystack.dataclasses.byte_stream import ByteStream
from haystack.lazy_imports import LazyImport

logger = logging.getLogger(__name__)

with LazyImport("Run 'pip install jsonref'") as openapi_imports:
    import jsonref


@component
class OpenAPIServiceToFunctions:
    """
    Converts OpenAPI service definitions to a format suitable for OpenAI function calling.

    The definition must respect OpenAPI specification 3.0.0 or higher.
    It can be specified in JSON or YAML format.
    Each function must have:
        - unique operationId
        - description (the operation's summary is used as a fallback)
        - requestBody and/or parameters
        - schema for the requestBody and/or parameters
    Parameters declared at path level are applied to every operation of that path, unless an operation
    declares a parameter with the same name and location.
    For more details on OpenAPI specification see the [official documentation](https://github.com/OAI/OpenAPI-Specification).
    For more details on OpenAI function calling see the [official documentation](https://platform.openai.com/docs/guides/function-calling).

    Usage example:
    ```python
    from haystack.components.converters import OpenAPIServiceToFunctions
    from haystack.dataclasses.byte_stream import ByteStream

    converter = OpenAPIServiceToFunctions()
    spec = ByteStream.from_string(
        '{"openapi":"3.0.0","info":{"title":"API","version":"1.0.0"},"paths":{"/search":{"get":{"operationId":"search","summary":"Search","parameters":[{"name":"q","in":"query","required":true,"schema":{"type":"string"}}]}}}}'
    )
    result = converter.run(sources=[spec])
    assert result["functions"]
    ```
    """

    MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3

    # Keys of an OpenAPI Path Item Object that hold Operation Objects. All other path-item keys
    # ("summary", "description", "servers", "parameters", "x-*", ...) are not operations.
    _OPERATION_METHODS = frozenset({"get", "put", "post", "delete", "options", "head", "patch", "trace"})

    # Attributes that live on the Parameter Object itself (not inside its "schema") but are still
    # useful in the generated property schema.
    _PARAMETER_LEVEL_ATTRIBUTES = ("description", "pattern", "enum")

    def __init__(self) -> None:
        """
        Create an OpenAPIServiceToFunctions component.
        """
        openapi_imports.check()

    @component.output_types(functions=list[dict[str, Any]], openapi_specs=list[dict[str, Any]])
    def run(self, sources: list[str | Path | ByteStream]) -> dict[str, Any]:
        """
        Converts OpenAPI definitions into OpenAI function calling format.

        Processing is best effort: a source that cannot be read, decoded, parsed, or converted is logged
        and skipped, and the remaining sources are still processed.

        :param sources:
            File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format).

        :returns:
            A dictionary with the following keys:
            - functions: Function definitions in JSON object format
            - openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references
        """
        all_extracted_fc_definitions: list[dict[str, Any]] = []
        all_openapi_specs: list[dict[str, Any]] = []
        for source in sources:
            openapi_spec_content = self._read_source(source)
            if not openapi_spec_content:
                # Problems were already logged by _read_source (an empty file is skipped silently).
                continue

            try:
                service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
                functions: list[dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
            except Exception as e:
                # One broken spec must not prevent the other sources from being converted.
                logger.exception(
                    "Error processing OpenAPI specification from source {source}: {error}", source=source, error=e
                )
                continue

            all_extracted_fc_definitions.extend(functions)
            all_openapi_specs.append(service_openapi_spec)

        if not all_extracted_fc_definitions:
            logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.")

        return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs}

    def _read_source(self, source: Any) -> str | None:
        """
        Returns the text content of one source, logging a warning instead of raising when it is unusable.

        :param source: A file path (`str` or `Path`) or a `ByteStream`.
        :returns: The content decoded as UTF-8. The result may be an empty string for an empty file or
            ByteStream. It is `None` when the file is missing, cannot be read or decoded, or the source type
            is unsupported.
        """
        if isinstance(source, (str, Path)):
            if not os.path.exists(source):
                logger.warning("OpenAPI specification file not found: {source}", source=source)
                return None
            try:
                # Decode as UTF-8 explicitly (like ByteStream sources) rather than with the locale encoding.
                with open(source, encoding="utf-8") as f:
                    return f.read()
            except OSError as e:
                logger.warning(
                    "IO error reading OpenAPI specification file: {source}. Error: {e}", source=source, e=e
                )
            except UnicodeDecodeError as e:
                logger.warning(
                    "Could not decode OpenAPI specification file {source} as UTF-8. Error: {e}", source=source, e=e
                )
            return None

        if isinstance(source, ByteStream):
            try:
                content = source.data.decode("utf-8")
            except UnicodeDecodeError as e:
                logger.warning("Could not decode OpenAPI specification ByteStream as UTF-8. Error: {e}", e=e)
                return None
            if not content:
                logger.warning(
                    "Invalid OpenAPI specification content provided: {openapi_spec_content}",
                    openapi_spec_content=content,
                )
            return content

        logger.warning(
            "Invalid source type {source}. Only str, Path, and ByteStream are supported.", source=type(source)
        )
        return None

    def _openapi_to_functions(self, service_openapi_spec: dict[str, Any]) -> list[dict[str, Any]]:
        """
        OpenAPI to OpenAI function conversion.

        Extracts functions from the OpenAPI specification of the service and converts them into a format
        suitable for OpenAI function calling. Only HTTP-method entries of each path item are treated as
        operations. Path-level parameters are merged into each operation of that path.

        :param service_openapi_spec: The OpenAPI specification from which functions are to be extracted.
        :returns: A list of dictionaries, each representing a function. Each dictionary includes the function's
            name, description, and a schema of its parameters.
        :raises ValueError: If the spec is not a mapping, has no `openapi` version, has a version whose major
            part is not an integer, or has a major version below `MIN_REQUIRED_OPENAPI_SPEC_VERSION`.
        :raises KeyError: If the spec has no `paths` section.
        """
        # Doesn't enforce rigid spec validation because that would require a lot of dependencies
        # We check the version and require minimal fields to be present, so we can extract functions
        if not isinstance(service_openapi_spec, dict):
            raise ValueError(
                "Invalid OpenAPI spec provided. Expected a mapping at the top level, "
                f"got {type(service_openapi_spec).__name__}."
            )

        spec_version = service_openapi_spec.get("openapi")
        if not spec_version:
            raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}")
        # str(): an unquoted `openapi: 3.0` in YAML is loaded as a float, `openapi: 3` as an int.
        major_version = str(spec_version).split(".")[0]
        try:
            service_openapi_spec_version = int(major_version)
        except ValueError as e:
            raise ValueError(
                f"Invalid OpenAPI spec version {spec_version!r}. Could not parse the major version."
            ) from e

        # Compare the versions
        if service_openapi_spec_version < OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION:
            raise ValueError(
                f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
                f"at least {OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
            )

        functions: list[dict[str, Any]] = []
        for path, path_item in service_openapi_spec["paths"].items():
            if not isinstance(path_item, dict):
                logger.warning("Invalid OpenAPI path item for {path}. Could not extract functions.", path=path)
                continue
            # Path-level parameters apply to every operation under this path (OpenAPI Path Item Object).
            path_parameters = path_item.get("parameters") or []
            for method, operation in path_item.items():
                if not (isinstance(method, str) and method.lower() in OpenAPIServiceToFunctions._OPERATION_METHODS):
                    continue
                function_dict = self._parse_endpoint_spec(operation, path_parameters)
                if function_dict:
                    functions.append(function_dict)
        return functions

    @staticmethod
    def _merge_parameters(path_parameters: list[Any], operation_parameters: list[Any]) -> list[Any]:
        """
        Combines path-level and operation-level parameters, path-level ones first.

        A path-level parameter is dropped when the operation declares a parameter with the same `name` and
        `in` (location), because the OpenAPI specification lets operations override path-level parameters.

        :param path_parameters: Parameters declared on the path item.
        :param operation_parameters: Parameters declared on the operation.
        :returns: The merged parameter list.
        """
        if not path_parameters:
            return list(operation_parameters)

        overridden = {(p.get("name"), p.get("in")) for p in operation_parameters if isinstance(p, dict)}
        inherited = [
            p for p in path_parameters if not (isinstance(p, dict) and (p.get("name"), p.get("in")) in overridden)
        ]
        return inherited + list(operation_parameters)

    def _parse_endpoint_spec(
        self, resolved_spec: dict[str, Any], path_parameters: list[Any] | None = None
    ) -> dict[str, Any]:
        """
        Converts a single OpenAPI Operation Object into an OpenAI function definition.

        :param resolved_spec: The operation object, with references resolved.
        :param path_parameters: Parameters declared at path level. They are merged with the operation's own
            parameters, and the operation wins on a (name, in) clash.
        :returns: A dict with `name`, `description`, and `parameters` keys. It is empty (after a warning is
            logged) when the operation is not a mapping, or lacks an operationId, a description/summary, or any
            parameter or JSON request-body property.
        """
        if not isinstance(resolved_spec, dict):
            logger.warning("Invalid OpenAPI spec format provided. Could not extract function.")
            return {}

        function_name = resolved_spec.get("operationId")
        description = resolved_spec.get("description") or resolved_spec.get("summary", "")

        schema: dict[str, Any] = {"type": "object", "properties": {}}

        # requestBody section: only an application/json body schema is considered
        req_body_schema = (
            resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {})
        )
        if "properties" in req_body_schema:
            for prop_name, prop_schema in req_body_schema["properties"].items():
                schema["properties"][prop_name] = self._parse_property_attributes(prop_schema)

            if "required" in req_body_schema:
                schema.setdefault("required", []).extend(req_body_schema["required"])

        # parameters section (`or []` also covers an explicit `parameters: null`)
        parameters = self._merge_parameters(path_parameters or [], resolved_spec.get("parameters") or [])
        for param in parameters:
            if "schema" in param:
                schema_dict = self._parse_property_attributes(param["schema"])
                # these attributes are not in param[schema] level but on param level
                schema_dict.update(
                    {
                        key: param[key]
                        for key in OpenAPIServiceToFunctions._PARAMETER_LEVEL_ATTRIBUTES
                        if param.get(key)
                    }
                )
                schema["properties"][param["name"]] = schema_dict
                if param.get("required", False):
                    schema.setdefault("required", []).append(param["name"])

        if function_name and description and schema["properties"]:
            return {"name": function_name, "description": description, "parameters": schema}

        logger.warning(
            "Invalid OpenAPI spec format provided. Could not extract function from {spec}", spec=resolved_spec
        )
        return {}

    def _parse_property_attributes(
        self, property_schema: dict[str, Any], include_attributes: list[str] | None = None
    ) -> dict[str, Any]:
        """
        Parses the attributes of a property schema.

        Recursively parses the attributes of a property schema, including nested objects and arrays,
        and includes specified attributes like description, pattern, etc.

        :param property_schema: The schema of the property to parse.
        :param include_attributes: The list of attributes to include in the parsed schema. `None` selects the
            defaults ("description", "pattern", "enum"). An explicit empty list includes no extra attributes.
        :returns: The parsed schema of the property including the specified attributes.
        """
        if include_attributes is None:
            include_attributes = ["description", "pattern", "enum"]

        schema_type = property_schema.get("type")

        parsed_schema: dict[str, Any] = {"type": schema_type} if schema_type else {}
        for attr in include_attributes:
            if attr in property_schema:
                parsed_schema[attr] = property_schema[attr]

        if schema_type == "object":
            properties = property_schema.get("properties", {})
            parsed_schema["properties"] = {
                prop_name: self._parse_property_attributes(prop, include_attributes)
                for prop_name, prop in properties.items()
            }

            if "required" in property_schema:
                parsed_schema["required"] = property_schema["required"]

        elif schema_type == "array":
            items = property_schema.get("items", {})
            parsed_schema["items"] = self._parse_property_attributes(items, include_attributes)

        return parsed_schema

    def _parse_openapi_spec(self, content: str) -> dict[str, Any]:
        """
        Parses OpenAPI specification content, supporting both JSON and YAML formats.

        JSON is tried first. If that fails and the content does not look like JSON, it is parsed as YAML.
        References are then replaced using `jsonref.replace_refs`.

        :param content: The content of the OpenAPI specification.
        :returns: The parsed OpenAPI specification with references replaced.
        :raises json.JSONDecodeError: If the content starts with `{` or `[` but is not valid JSON.
        :raises RuntimeError: If the content is neither valid JSON nor valid YAML.
        """
        try:
            return jsonref.replace_refs(json.loads(content))
        except json.JSONDecodeError:
            # heuristic to confirm that the content is likely malformed JSON
            if content.strip().startswith(("{", "[")):
                raise

        # safe_load (not yaml.load): spec content may be untrusted and must not build arbitrary Python objects
        try:
            open_api_spec_content = yaml.safe_load(content)
        except yaml.YAMLError as e:
            error_message = (
                "Failed to parse the OpenAPI specification. The content does not appear to be valid JSON or YAML.\n\n"
            )
            raise RuntimeError(error_message, content) from e

        # Replace references in the object with their resolved values, if any
        return jsonref.replace_refs(open_api_spec_content)