/ haystack / components / converters / openapi_functions.py
openapi_functions.py
  1  # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
  2  #
  3  # SPDX-License-Identifier: Apache-2.0
  4  
  5  import json
  6  import os
  7  from pathlib import Path
  8  from typing import Any
  9  
 10  import yaml
 11  
 12  from haystack import component, logging
 13  from haystack.dataclasses.byte_stream import ByteStream
 14  from haystack.lazy_imports import LazyImport
 15  
 16  logger = logging.getLogger(__name__)
 17  
 18  with LazyImport("Run 'pip install jsonref'") as openapi_imports:
 19      import jsonref
 20  
 21  
 22  @component
 23  class OpenAPIServiceToFunctions:
 24      """
 25      Converts OpenAPI service definitions to a format suitable for OpenAI function calling.
 26  
 27      The definition must respect OpenAPI specification 3.0.0 or higher.
 28      It can be specified in JSON or YAML format.
 29      Each function must have:
 30          - unique operationId
 31          - description
 32          - requestBody and/or parameters
 33          - schema for the requestBody and/or parameters
 34      For more details on OpenAPI specification see the [official documentation](https://github.com/OAI/OpenAPI-Specification).
 35      For more details on OpenAI function calling see the [official documentation](https://platform.openai.com/docs/guides/function-calling).
 36  
 37      Usage example:
 38      ```python
 39      from haystack.components.converters import OpenAPIServiceToFunctions
 40      from haystack.dataclasses.byte_stream import ByteStream
 41  
 42      converter = OpenAPIServiceToFunctions()
 43      spec = ByteStream.from_string(
 44          '{"openapi":"3.0.0","info":{"title":"API","version":"1.0.0"},"paths":{"/search":{"get":{"operationId":"search","summary":"Search","parameters":[{"name":"q","in":"query","required":true,"schema":{"type":"string"}}]}}}}'
 45      )
 46      result = converter.run(sources=[spec])
 47      assert result["functions"]
 48      ```
 49      """
 50  
 51      MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
 52  
 53      def __init__(self) -> None:
 54          """
 55          Create an OpenAPIServiceToFunctions component.
 56          """
 57          openapi_imports.check()
 58  
 59      @component.output_types(functions=list[dict[str, Any]], openapi_specs=list[dict[str, Any]])
 60      def run(self, sources: list[str | Path | ByteStream]) -> dict[str, Any]:
 61          """
 62          Converts OpenAPI definitions in OpenAI function calling format.
 63  
 64          :param sources:
 65              File paths or ByteStream objects of OpenAPI definitions (in JSON or YAML format).
 66  
 67          :returns:
 68              A dictionary with the following keys:
 69              - functions: Function definitions in JSON object format
 70              - openapi_specs: OpenAPI specs in JSON/YAML object format with resolved references
 71  
 72          :raises RuntimeError:
 73              If the OpenAPI definitions cannot be downloaded or processed.
 74          :raises ValueError:
 75              If the source type is not recognized or no functions are found in the OpenAPI definitions.
 76          """
 77          all_extracted_fc_definitions: list[dict[str, Any]] = []
 78          all_openapi_specs = []
 79          for source in sources:
 80              openapi_spec_content = None
 81              if isinstance(source, (str, Path)):
 82                  if os.path.exists(source):
 83                      try:
 84                          with open(source) as f:
 85                              openapi_spec_content = f.read()
 86                      except OSError as e:
 87                          logger.warning(
 88                              "IO error reading OpenAPI specification file: {source}. Error: {e}", source=source, e=e
 89                          )
 90                  else:
 91                      logger.warning("OpenAPI specification file not found: {source}", source=source)
 92              elif isinstance(source, ByteStream):
 93                  openapi_spec_content = source.data.decode("utf-8")
 94                  if not openapi_spec_content:
 95                      logger.warning(
 96                          "Invalid OpenAPI specification content provided: {openapi_spec_content}",
 97                          openapi_spec_content=openapi_spec_content,
 98                      )
 99              else:
100                  logger.warning(
101                      "Invalid source type {source}. Only str, Path, and ByteStream are supported.", source=type(source)
102                  )
103                  continue
104  
105              if openapi_spec_content:
106                  try:
107                      service_openapi_spec = self._parse_openapi_spec(openapi_spec_content)
108                      functions: list[dict[str, Any]] = self._openapi_to_functions(service_openapi_spec)
109                      all_extracted_fc_definitions.extend(functions)
110                      all_openapi_specs.append(service_openapi_spec)
111                  except Exception as e:
112                      logger.exception(
113                          "Error processing OpenAPI specification from source {source}: {error}", source=source, error=e
114                      )
115  
116          if not all_extracted_fc_definitions:
117              logger.warning("No OpenAI function definitions extracted from the provided OpenAPI specification sources.")
118  
119          return {"functions": all_extracted_fc_definitions, "openapi_specs": all_openapi_specs}
120  
121      def _openapi_to_functions(self, service_openapi_spec: dict[str, Any]) -> list[dict[str, Any]]:
122          """
123          OpenAPI to OpenAI function conversion.
124  
125          Extracts functions from the OpenAPI specification of the service and converts them into a format
126          suitable for OpenAI function calling.
127  
128          :param service_openapi_spec: The OpenAPI specification from which functions are to be extracted.
129          :type service_openapi_spec: dict[str, Any]
130          :return: A list of dictionaries, each representing a function. Each dictionary includes the function's
131                   name, description, and a schema of its parameters.
132          :rtype: list[dict[str, Any]]
133          """
134  
135          # Doesn't enforce rigid spec validation because that would require a lot of dependencies
136          # We check the version and require minimal fields to be present, so we can extract functions
137          spec_version = service_openapi_spec.get("openapi")
138          if not spec_version:
139              raise ValueError(f"Invalid OpenAPI spec provided. Could not extract version from {service_openapi_spec}")
140          service_openapi_spec_version = int(spec_version.split(".")[0])
141  
142          # Compare the versions
143          if service_openapi_spec_version < OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION:
144              raise ValueError(
145                  f"Invalid OpenAPI spec version {service_openapi_spec_version}. Must be "
146                  f"at least {OpenAPIServiceToFunctions.MIN_REQUIRED_OPENAPI_SPEC_VERSION}."
147              )
148  
149          functions: list[dict[str, Any]] = []
150          for paths in service_openapi_spec["paths"].values():
151              for path_spec in paths.values():
152                  function_dict = self._parse_endpoint_spec(path_spec)
153                  if function_dict:
154                      functions.append(function_dict)
155          return functions
156  
157      def _parse_endpoint_spec(self, resolved_spec: dict[str, Any]) -> dict[str, Any] | None:
158          if not isinstance(resolved_spec, dict):
159              logger.warning("Invalid OpenAPI spec format provided. Could not extract function.")
160              return {}
161  
162          function_name = resolved_spec.get("operationId")
163          description = resolved_spec.get("description") or resolved_spec.get("summary", "")
164  
165          schema: dict[str, Any] = {"type": "object", "properties": {}}
166  
167          # requestBody section
168          req_body_schema = (
169              resolved_spec.get("requestBody", {}).get("content", {}).get("application/json", {}).get("schema", {})
170          )
171          if "properties" in req_body_schema:
172              for prop_name, prop_schema in req_body_schema["properties"].items():
173                  schema["properties"][prop_name] = self._parse_property_attributes(prop_schema)
174  
175              if "required" in req_body_schema:
176                  schema.setdefault("required", []).extend(req_body_schema["required"])
177  
178          # parameters section
179          for param in resolved_spec.get("parameters", []):
180              if "schema" in param:
181                  schema_dict = self._parse_property_attributes(param["schema"])
182                  # these attributes are not in param[schema] level but on param level
183                  useful_attributes = ["description", "pattern", "enum"]
184                  schema_dict.update({key: param[key] for key in useful_attributes if param.get(key)})
185                  schema["properties"][param["name"]] = schema_dict
186                  if param.get("required", False):
187                      schema.setdefault("required", []).append(param["name"])
188  
189          if function_name and description and schema["properties"]:
190              return {"name": function_name, "description": description, "parameters": schema}
191  
192          logger.warning(
193              "Invalid OpenAPI spec format provided. Could not extract function from {spec}", spec=resolved_spec
194          )
195          return {}
196  
197      def _parse_property_attributes(
198          self, property_schema: dict[str, Any], include_attributes: list[str] | None = None
199      ) -> dict[str, Any]:
200          """
201          Parses the attributes of a property schema.
202  
203          Recursively parses the attributes of a property schema, including nested objects and arrays,
204          and includes specified attributes like description, pattern, etc.
205  
206          :param property_schema: The schema of the property to parse.
207          :param include_attributes: The list of attributes to include in the parsed schema.
208          :return: The parsed schema of the property including the specified attributes.
209          """
210          include_attributes = include_attributes or ["description", "pattern", "enum"]
211  
212          schema_type = property_schema.get("type")
213  
214          parsed_schema = {"type": schema_type} if schema_type else {}
215          for attr in include_attributes:
216              if attr in property_schema:
217                  parsed_schema[attr] = property_schema[attr]
218  
219          if schema_type == "object":
220              properties = property_schema.get("properties", {})
221              parsed_properties = {
222                  prop_name: self._parse_property_attributes(prop, include_attributes)
223                  for prop_name, prop in properties.items()
224              }
225              parsed_schema["properties"] = parsed_properties
226  
227              if "required" in property_schema:
228                  parsed_schema["required"] = property_schema["required"]
229  
230          elif schema_type == "array":
231              items = property_schema.get("items", {})
232              parsed_schema["items"] = self._parse_property_attributes(items, include_attributes)
233  
234          return parsed_schema
235  
236      def _parse_openapi_spec(self, content: str) -> dict[str, Any]:
237          """
238          Parses OpenAPI specification content, supporting both JSON and YAML formats.
239  
240          :param content: The content of the OpenAPI specification.
241          :return: The parsed OpenAPI specification.
242          """
243          open_api_spec_content = None
244          try:
245              open_api_spec_content = json.loads(content)
246              return jsonref.replace_refs(open_api_spec_content)
247          except json.JSONDecodeError as json_error:
248              # heuristic to confirm that the content is likely malformed JSON
249              if content.strip().startswith(("{", "[")):
250                  raise json_error
251  
252          try:
253              open_api_spec_content = yaml.safe_load(content)
254          except yaml.YAMLError as e:
255              error_message = (
256                  "Failed to parse the OpenAPI specification. The content does not appear to be valid JSON or YAML.\n\n"
257              )
258              raise RuntimeError(error_message, content) from e
259  
260          # Replace references in the object with their resolved values, if any
261          return jsonref.replace_refs(open_api_spec_content)