/ src / evidently / ui / service / tracing / api.py
api.py
  1  import logging
  2  import uuid
  3  from collections import defaultdict
  4  from enum import Enum
  5  from typing import Annotated
  6  from typing import Callable
  7  from typing import Dict
  8  from typing import List
  9  from typing import Optional
 10  
 11  import litestar
 12  from litestar import Router
 13  from litestar.exceptions import HTTPException
 14  from litestar.exceptions import NotFoundException
 15  from litestar.params import Dependency
 16  from opentelemetry.proto.collector.trace.v1 import trace_service_pb2
 17  
 18  from evidently._pydantic_compat import BaseModel
 19  from evidently.core.datasets import ServiceColumns
 20  from evidently.legacy.ui.type_aliases import UserID
 21  from evidently.ui.service.datasets.metadata import DatasetTracingParams
 22  from evidently.ui.service.managers.datasets import DatasetManager
 23  from evidently.ui.service.tracing.storage.base import HumanFeedbackModel
 24  from evidently.ui.service.tracing.storage.base import SpanModel
 25  from evidently.ui.service.tracing.storage.base import TraceModel
 26  from evidently.ui.service.tracing.storage.base import TracingStorage
 27  
 28  
 29  class SessionDefinition(BaseModel):
 30      type: str = "session"
 31      session_field: str = "session_id"
 32      user_field: str = "user_id"
 33      time_split_sec: int = 30 * 60
 34  
 35  
 36  @litestar.post("/")
 37  async def export(
 38      user_id: UserID,
 39      request: litestar.Request,
 40      tracing_storage: TracingStorage,
 41      dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
 42  ) -> litestar.Response:
 43      data = await request.body()
 44      message = trace_service_pb2.ExportTraceServiceRequest()
 45      message.ParseFromString(data)
 46      export_id = None
 47      service_name = ""
 48      for resource in message.resource_spans:
 49          for res_attr in resource.resource.attributes:
 50              if res_attr.key == "service.name":
 51                  service_name = res_attr.value.string_value
 52              if res_attr.key == "evidently.export_id":
 53                  export_id = uuid.UUID(res_attr.value.string_value)
 54      if export_id is None:
 55          raise ValueError("Export ID is missing from trace data")
 56  
 57      tracing_storage.save(export_id, service_name, message)
 58  
 59      return litestar.Response(
 60          trace_service_pb2.ExportTraceServiceResponse().SerializeToString(),
 61          media_type="application/x-protobuf",
 62      )
 63  
 64  
 65  @litestar.delete("/{export_id:uuid}/{trace_id:str}")
 66  async def delete_trace(
 67      trace_id: str,
 68      tracing_storage: TracingStorage,
 69      user_id: UserID,
 70      export_id: uuid.UUID,
 71  ) -> None:
 72      await tracing_storage.delete_trace(export_id, trace_id)
 73  
 74  
 75  def _get_first_span(trace: TraceModel) -> Optional[SpanModel]:
 76      for span in trace.spans:
 77          if span.parent_span_id == "":
 78              return span
 79      return None
 80  
 81  
 82  class TraceListType(Enum):
 83      Auto = "auto"
 84      Ungrouped = "ungrouped"
 85      Session = "session"
 86      User = "user"
 87  
 88  
 89  class TraceListGetterType(Enum):
 90      with_filters_from_metadata = "with_filters_from_metadata"
 91      ungrouped = "ungrouped"
 92  
 93  
 94  class DatasetMetadata(BaseModel):
 95      id: uuid.UUID
 96      name: str
 97      params: Optional[DatasetTracingParams] = None
 98  
 99  
100  class TraceSessionsResponse(BaseModel):
101      sessions: Dict[str, List[TraceModel]]
102      metadata: DatasetMetadata
103  
104  
105  @litestar.post("/metadata")
106  async def update_metadata(
107      dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
108      user_id: UserID,
109      export_id: uuid.UUID,
110      data: DatasetTracingParams,
111  ) -> None:
112      dataset = await dataset_manager.get_dataset_metadata(user_id, export_id)
113      if dataset is None:
114          raise NotFoundException(f"Dataset {export_id} not found")
115  
116      await dataset_manager.update_dataset_tracing_metadata(user_id, export_id, data)
117  
118  
119  @litestar.get("/list")
120  async def trace_sessions(
121      tracing_storage: TracingStorage,
122      dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
123      user_id: UserID,
124      export_id: uuid.UUID,
125      getter_type: TraceListGetterType,
126  ) -> TraceSessionsResponse:
127      traces = tracing_storage.read_traces_with_filter(export_id, None, None)
128      dataset = await dataset_manager.get_dataset_metadata(user_id, export_id)
129      if dataset is None:
130          raise NotFoundException(f"Dataset {export_id} not found")
131  
132      metadata = DatasetMetadata(id=dataset.id, name=dataset.name)
133  
134      tracing_session_params = (
135          await _determine_session_info(traces) if dataset.tracing_params is None else dataset.tracing_params
136      )
137  
138      metadata.params = tracing_session_params
139  
140      if getter_type == TraceListGetterType.ungrouped:
141          return TraceSessionsResponse(sessions={"undefined": traces}, metadata=metadata)
142  
143      type = TraceListType.Ungrouped
144      if metadata.params.session_type is not None:
145          type = TraceListType(tracing_session_params.session_type)
146  
147      session_field = metadata.params.session_field or "session_id"
148      user_field = metadata.params.user_field or "user_id"
149      time_split_sec = metadata.params.dialog_split_time_seconds or 30 * 60
150  
151      if type == TraceListType.Session:
152          return TraceSessionsResponse(sessions=await _split_by_session(traces, session_field), metadata=metadata)
153      if type == TraceListType.User:
154          return TraceSessionsResponse(
155              sessions=await _split_by_user(traces, user_field, time_split_sec), metadata=metadata
156          )
157  
158      return TraceSessionsResponse(sessions={"undefined": traces}, metadata=metadata)
159  
160  
161  class AddHumanFeedbackRequest(BaseModel):
162      trace_id: str
163      feedback: HumanFeedbackModel
164  
165  
166  @litestar.post("/human_feedback")
167  async def add_human_feedback(
168      tracing_storage: TracingStorage,
169      dataset_manager: Annotated[DatasetManager, Dependency(skip_validation=True)],
170      user_id: UserID,
171      export_id: uuid.UUID,
172      data: AddHumanFeedbackRequest,
173  ) -> None:
174      metadata = await dataset_manager.get_dataset_metadata(user_id, export_id)
175      try:
176          span_name = await tracing_storage.add_feedback(export_id, data.trace_id, data.feedback)
177      except NotImplementedError:
178          raise HTTPException(status_code=501, detail="Human feedback supported only for SQLite storage")
179      data_definition = metadata.data_definition
180      if data_definition.service_columns is None:
181          data_definition.service_columns = ServiceColumns()
182      if (
183          data_definition.service_columns.human_feedback_label is None
184          or not data_definition.service_columns.human_feedback_label.startswith(span_name)
185      ):
186          if (
187              data_definition.service_columns.human_feedback_label
188              and data_definition.service_columns.human_feedback_label.startswith(span_name)
189          ):
190              logging.warning(
191                  f"{metadata.name} dataset has different human feedback label:"
192                  f" was: {data_definition.service_columns.human_feedback_label}, but expected:"
193                  f" {span_name}.human_feedback_label. Replacing it with a new one..."
194              )
195          data_definition.service_columns.human_feedback_label = f"{span_name}.human_feedback_label"
196          data_definition.service_columns.human_feedback_comment = f"{span_name}.human_feedback_comment"
197          await dataset_manager.update_dataset(user_id, export_id, None, None, data_definition, None, None)
198  
199  
200  async def _determine_session_info(traces: List[TraceModel]) -> DatasetTracingParams:
201      sample = traces[:10]
202      params = DatasetTracingParams()
203      for trace in sample:
204          span = _get_first_span(trace)
205          if span is None:
206              continue
207          if params.session_type is None:
208              if "session_id" in span.attributes:
209                  params.session_type = "session"
210                  params.session_field = "session_id"
211              if "user_id" in span.attributes:
212                  params.session_type = "user"
213                  params.user_field = "user_id"
214                  params.dialog_split_time_seconds = 30 * 60
215          if params.user_message_field is None:
216              for possible_input_field in ["input", "question"]:
217                  if possible_input_field in span.attributes:
218                      params.user_message_field = f"{span.span_name}:{possible_input_field}"
219                      break
220          if params.assistant_message_field is None:
221              for possible_output_field in ["output", "answer", "result"]:
222                  if possible_output_field in span.attributes:
223                      params.assistant_message_field = f"{span.span_name}:{possible_output_field}"
224                      break
225                  if f"{possible_output_field}.data" in span.attributes:
226                      params.assistant_message_field = f"{span.span_name}:{possible_output_field}.data"
227                      break
228      return params
229  
230  
231  async def _split_by_session(traces: List[TraceModel], session_field: str) -> Dict[str, List[TraceModel]]:
232      result: Dict[str, List[TraceModel]] = defaultdict(list)
233      for trace in traces:
234          span = _get_first_span(trace)
235          if span is None:
236              continue
237          session_id = str(span.attributes.get(session_field, "undefined"))
238          result[session_id].append(trace)
239      return result
240  
241  
242  async def _split_by_user(
243      traces: List[TraceModel], user_field: str, dialog_split_sec: int
244  ) -> Dict[str, List[TraceModel]]:
245      result: Dict[str, List[TraceModel]] = defaultdict(list)
246      traces_by_users: Dict[str, List[TraceModel]] = defaultdict(list)
247      for trace in traces:
248          span = _get_first_span(trace)
249          if span is None:
250              continue
251          user_id = str(span.attributes.get(user_field, "undefined"))
252          traces_by_users[user_id].append(trace)
253      for user_id, user_traces in traces_by_users.items():
254          start_time = user_traces[0].start_time.timestamp()
255          session_id = user_id + f":{user_traces[0].start_time.isoformat()}"
256          for trace in user_traces:
257              trace_start = trace.start_time.timestamp()
258              if (trace_start - start_time) > dialog_split_sec:
259                  session_id = user_id + f":{trace.start_time.isoformat()}"
260              result[session_id].append(trace)
261              start_time = trace_start
262      return result
263  
264  
265  def tracing_api(guard: Callable) -> Router:
266      return Router(
267          "/v1/traces",
268          route_handlers=[
269              # read
270              Router(
271                  "",
272                  route_handlers=[
273                      trace_sessions,
274                  ],
275              ),
276              # write
277              Router(
278                  "",
279                  route_handlers=[
280                      export,
281                      delete_trace,
282                      update_metadata,
283                      add_human_feedback,
284                  ],
285                  guards=[guard],
286              ),
287          ],
288      )