  1  """Evaluation engine for AI projects using DeepEval metrics."""
  2  
  3  import json
  4  import logging
  5  import time
  6  from datetime import datetime, timezone
  7  from typing import Optional
  8  
  9  from llama_index.core.llms.llm import LLM
 10  from deepeval.models.base_model import DeepEvalBaseLLM
 11  from deepeval.metrics import AnswerRelevancyMetric, FaithfulnessMetric, GEval
 12  from deepeval.test_case import LLMTestCase, LLMTestCaseParams
 13  
 14  from restai.models.databasemodels import (
 15      EvalRunDatabase,
 16      EvalTestCaseDatabase,
 17      EvalResultDatabase,
 18  )
 19  
 20  logger = logging.getLogger(__name__)
 21  
 22  VALID_METRICS = {"answer_relevancy", "faithfulness", "correctness"}
 23  
 24  
class DeepEvalLLM(DeepEvalBaseLLM):
    """Adapter wrapping a LlamaIndex LLM for use with DeepEval."""

    def __init__(self, model: LLM, *args, **kwargs):
        self._llm = model
        super().__init__(*args, **kwargs)

    def load_model(self):
        return self._llm

    def generate(self, prompt: str) -> str:
        return self._llm.complete(prompt).text

    async def a_generate(self, prompt: str) -> str:
        # complete() is synchronous and not awaitable; use the async acomplete() API.
        res = await self._llm.acomplete(prompt)
        return res.text

    def get_model_name(self):
        return "Custom LlamaIndex LLM"

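
# Illustrative standalone use of the adapter above (not part of the eval flow; assumes the
# llama-index OpenAI integration is installed, but any LlamaIndex LLM works the same way):
#
#   from llama_index.llms.openai import OpenAI
#
#   eval_llm = DeepEvalLLM(model=OpenAI(model="gpt-4o-mini"))
#   metric = AnswerRelevancyMetric(threshold=0.5, model=eval_llm, include_reason=True, async_mode=False)
#   metric.measure(LLMTestCase(input="What is RESTai?", actual_output="RESTai is an AIaaS platform."))
#   print(metric.score, metric.reason)
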
def eval_rag(question, response, llm):
    """Legacy single-question RAG evaluation (kept for backward compatibility)."""
    if response is None:
        return None

    actual_output = response.response
    retrieval_context = [node.get_content() for node in response.source_nodes]

    test_case = LLMTestCase(
        input=question, actual_output=actual_output, retrieval_context=retrieval_context
    )

    eval_llm = DeepEvalLLM(model=llm)

    metric = AnswerRelevancyMetric(
        threshold=0.5, model=eval_llm, include_reason=True, async_mode=False
    )
    metric.measure(test_case)

    return metric

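
# Illustrative call (names are placeholders: `query_engine` is any LlamaIndex query engine
# whose response exposes `.response` and `.source_nodes`, `project_llm` the project's LLM):
#
#   response = query_engine.query("What does RESTai do?")
#   metric = eval_rag("What does RESTai do?", response, project_llm)
#   if metric is not None:
#       print(metric.score, metric.reason)
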
def _build_metric(metric_name: str, eval_llm: DeepEvalLLM):
    """Create a DeepEval metric instance by name."""
    if metric_name == "answer_relevancy":
        return AnswerRelevancyMetric(
            threshold=0.5, model=eval_llm, include_reason=True, async_mode=False
        )
    elif metric_name == "faithfulness":
        return FaithfulnessMetric(
            threshold=0.5, model=eval_llm, include_reason=True, async_mode=False
        )
    elif metric_name == "correctness":
        return GEval(
            name="Correctness",
            criteria="Determine whether the actual output is factually correct and matches the expected output.",
            evaluation_params=[
                LLMTestCaseParams.INPUT,
                LLMTestCaseParams.ACTUAL_OUTPUT,
                LLMTestCaseParams.EXPECTED_OUTPUT,
            ],
            threshold=0.5,
            model=eval_llm,
            async_mode=False,
        )
    else:
        raise ValueError(f"Unknown metric: {metric_name}")

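
# What each metric needs (run_evaluation() below skips a metric when its inputs are missing):
#   answer_relevancy - the question and the actual answer only
#   faithfulness     - additionally the retrieval context
#   correctness      - additionally the expected answer (GEval comparison)
#
# Illustrative validation sketch (the `requested` list is hypothetical; unknown names raise
# ValueError in _build_metric):
#
#   requested = ["answer_relevancy", "faithfulness", "typo_metric"]
#   metrics = [_build_metric(name, eval_llm) for name in requested if name in VALID_METRICS]
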
async def _get_project_answer(project, question: str, brain, user, db):
    """Call a project's question method directly and return (answer_text, sources_list, latency_ms)."""
    from restai.models.models import QuestionModel

    q = QuestionModel(question=question, stream=False)
    start = time.perf_counter()

    # Determine the project type handler
    match project.props.type:
        case "rag":
            from restai.projects.rag import RAG
            handler = RAG(brain)
        case "agent":
            from restai.projects.agent import Agent
            handler = Agent(brain)
        case "block":
            from restai.projects.block import Block
            handler = Block(brain)
        case _:
            return "", [], 0

    try:
        output_generator = handler.question(project, q, user, db)
        # Non-streaming request: only the first yielded item is needed.
        result = None
        async for line in output_generator:
            result = line
            break

        latency_ms = int((time.perf_counter() - start) * 1000)

        if result is None:
            return "", [], latency_ms

        answer = result.get("answer", "") if isinstance(result, dict) else str(result)
        sources = []
        if isinstance(result, dict) and "sources" in result:
            for s in result["sources"]:
                if isinstance(s, dict) and "text" in s:
                    sources.append(s["text"])
                elif isinstance(s, str):
                    sources.append(s)

        return answer, sources, latency_ms
    except Exception as e:
        latency_ms = int((time.perf_counter() - start) * 1000)
        logger.exception("Error getting project answer: %s", e)
        return f"Error: {e}", [], latency_ms

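
# Illustrative call to _get_project_answer (awaited inside the event loop; `project`,
# `brain`, `user` and `db` come from the running app, as in run_evaluation() below):
#
#   answer, sources, latency_ms = await _get_project_answer(project, "What is RESTai?", brain, user, db)
#   # answer: str, sources: list[str] of retrieved chunks, latency_ms: int
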
async def run_evaluation(run_id: int, app):
    """Execute an evaluation run in the background.

    Args:
        run_id: ID of the EvalRunDatabase record to execute.
        app: FastAPI app instance (for accessing brain via app.state.brain).
    """
    from restai.database import get_db_wrapper
    from restai.models.models import User

    db = get_db_wrapper()
    try:
        run = db.db.query(EvalRunDatabase).filter(EvalRunDatabase.id == run_id).first()
        if run is None:
            logger.error("Eval run %d not found", run_id)
            return

        run.status = "running"
        run.started_at = datetime.now(timezone.utc)
        db.db.commit()

        metrics_list = json.loads(run.metrics) if isinstance(run.metrics, str) else run.metrics
        test_cases = (
            db.db.query(EvalTestCaseDatabase)
            .filter(EvalTestCaseDatabase.dataset_id == run.dataset_id)
            .all()
        )

        if not test_cases:
            run.status = "completed"
            run.summary = json.dumps({})
            run.completed_at = datetime.now(timezone.utc)
            db.db.commit()
            return

        brain = app.state.brain
        project = brain.find_project(run.project_id, db)
        if project is None:
            run.status = "failed"
            run.error = "Project not found or could not be loaded"
            run.completed_at = datetime.now(timezone.utc)
            db.db.commit()
            return

        # Apply prompt version if specified, or record active version
        if run.prompt_version_id:
            pv = db.get_prompt_version(run.prompt_version_id)
            if pv and pv.project_id == run.project_id:
                project.props.system = pv.system_prompt
        else:
            active_pv = db.get_active_prompt_version(run.project_id)
            if active_pv:
                run.prompt_version_id = active_pv.id
                db.db.commit()

        # Get the eval LLM: use the project's own LLM
        eval_llm = None
        if project.props.llm:
            llm_model = brain.get_llm(project.props.llm, db)
            if llm_model:
                eval_llm = DeepEvalLLM(model=llm_model.llm)

        # Create a synthetic user for the eval (use the project creator)
        user_db = db.get_user_by_id(project.props.creator) if project.props.creator else None
        if user_db is None:
            # Fallback to admin
            user_db = db.get_user_by_username("admin")
        user = User.model_validate(user_db)

        score_totals = {}
        score_counts = {}

        for tc in test_cases:
            try:
                answer, sources, latency_ms = await _get_project_answer(
                    project, tc.question, brain, user, db
                )

                context = None
                if tc.context:
                    try:
                        context = json.loads(tc.context) if isinstance(tc.context, str) else tc.context
                    except (json.JSONDecodeError, TypeError):
                        context = None

                retrieval_context = sources if sources else (context or [])

                for metric_name in metrics_list:
                    try:
                        if metric_name == "faithfulness" and not retrieval_context:
                            # Skip faithfulness if no context available
                            continue
                        if metric_name == "correctness" and not tc.expected_answer:
                            # Skip correctness if no expected answer
                            continue

                        test = LLMTestCase(
                            input=tc.question,
                            actual_output=answer,
                            expected_output=tc.expected_answer,
                            retrieval_context=retrieval_context if retrieval_context else None,
                        )

                        if eval_llm:
                            metric = _build_metric(metric_name, eval_llm)
                            metric.measure(test)
                            score = metric.score
                            reason = getattr(metric, "reason", None)
                        else:
                            score = 0.0
                            reason = "No LLM available for evaluation"

                        passed = score >= 0.5 if score is not None else False

                        result = EvalResultDatabase(
                            run_id=run.id,
                            test_case_id=tc.id,
                            actual_answer=answer,
                            retrieval_context=json.dumps(retrieval_context) if retrieval_context else None,
                            metric_name=metric_name,
                            score=score,
                            reason=reason,
                            passed=passed,
                            latency_ms=latency_ms,
                        )
                        db.db.add(result)

                        if score is not None:
                            score_totals[metric_name] = score_totals.get(metric_name, 0) + score
                            score_counts[metric_name] = score_counts.get(metric_name, 0) + 1

                    except Exception as e:
                        logger.exception("Error evaluating metric '%s' for test case %d: %s", metric_name, tc.id, e)
                        result = EvalResultDatabase(
                            run_id=run.id,
                            test_case_id=tc.id,
                            actual_answer=answer,
                            metric_name=metric_name,
                            score=0.0,
                            reason=f"Evaluation error: {e}",
                            passed=False,
                            latency_ms=latency_ms,
                        )
                        db.db.add(result)

                db.db.commit()

            except Exception as e:
                logger.exception("Error processing test case %d: %s", tc.id, e)
                continue

        # Per-metric average score across all evaluated test cases
        summary = {
            k: round(score_totals[k] / score_counts[k], 4)
            for k in score_totals
            if score_counts.get(k, 0) > 0
        }

        run.status = "completed"
        run.summary = json.dumps(summary)
        run.completed_at = datetime.now(timezone.utc)
        db.db.commit()
        try:
            from restai.webhooks import emit_event_for_project_id
            emit_event_for_project_id(run.project_id, "eval_completed", {
                "run_id": run.id, "status": "completed", "summary": summary,
            })
        except Exception:
            pass

    except Exception as e:
        logger.exception("Eval run %d failed: %s", run_id, e)
        try:
            run = db.db.query(EvalRunDatabase).filter(EvalRunDatabase.id == run_id).first()
            if run:
                run.status = "failed"
                run.error = str(e)
                run.completed_at = datetime.now(timezone.utc)
                db.db.commit()
                try:
                    from restai.webhooks import emit_event_for_project_id
                    emit_event_for_project_id(run.project_id, "eval_completed", {
                        "run_id": run.id, "status": "failed", "error": str(e)[:500],
                    })
                except Exception:
                    pass
        except Exception:
            pass
    finally:
        db.db.close()
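
# Illustrative scheduling sketch (an assumption about the caller, not necessarily how the
# RESTai endpoints wire it up): run_evaluation is meant to run in the background once the
# EvalRunDatabase record exists, e.g.
#
#   background_tasks.add_task(run_evaluation, run.id, request.app)   # FastAPI BackgroundTasks
#   # or, with plain asyncio inside a running event loop:
#   asyncio.create_task(run_evaluation(run.id, app))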