src/evidently/legacy/features/hf_feature.py
hf_feature.py
  1  from functools import partial
  2  from typing import Any
  3  from typing import Callable
  4  from typing import ClassVar
  5  from typing import Dict
  6  from typing import List
  7  from typing import Optional
  8  from typing import Tuple
  9  
 10  import pandas as pd
 11  
 12  from evidently.legacy.core import ColumnType
 13  from evidently.legacy.features.generated_features import DataFeature
 14  from evidently.legacy.features.generated_features import FeatureTypeFieldMixin
 15  from evidently.legacy.utils.data_preprocessing import DataDefinition
 16  
 17  
 18  class HuggingFaceFeature(FeatureTypeFieldMixin, DataFeature):
 19      class Config:
 20          type_alias = "evidently:feature:HuggingFaceFeature"
 21  
 22      column_name: str
 23      model: str
 24      params: dict
 25  
 26      def __init__(self, *, column_name: str, model: str, params: dict, display_name: str):
 27          self.column_name = column_name
 28          self.model = model
 29          self.params = params
 30          self.display_name = display_name
 31          super().__init__(feature_type=_model_type(model))
 32  
 33      def generate_data(self, data: pd.DataFrame, data_definition: DataDefinition) -> pd.Series:
 34          val = _models.get(self.model)
 35          if val is None:
 36              raise ValueError(f"Model {self.model} not found. Available models: {', '.join(_models.keys())}")
 37          _, available_params, func = val
 38          result = func(data[self.column_name], **{param: self.params.get(param, None) for param in available_params})
 39          return result
 40  
 41      def __hash__(self):
 42          return DataFeature.__hash__(self)
 43  
 44  
 45  class HuggingFaceToxicityFeature(DataFeature):
 46      class Config:
 47          type_alias = "evidently:feature:HuggingFaceToxicityFeature"
 48  
 49      __feature_type__: ClassVar = ColumnType.Numerical
 50      column_name: str
 51      model: Optional[str]
 52      toxic_label: Optional[str]
 53  
 54      def __init__(
 55          self,
 56          *,
 57          column_name: str,
 58          display_name: str,
 59          model: Optional[str] = None,
 60          toxic_label: Optional[str] = None,
 61      ):
 62          self.column_name = column_name
 63          self.model = model
 64          self.toxic_label = toxic_label
 65          super().__init__(display_name=display_name)
 66  
 67      def generate_data(self, data: pd.DataFrame, data_definition: DataDefinition) -> pd.Series:
 68          return _toxicity(self.model, self.toxic_label, data[self.column_name])
 69  
 70  
 71  def _samlowe_roberta_base_go_emotions(data: pd.Series, label: str) -> pd.Series:
 72      from transformers import pipeline
 73  
 74      def _convert_labels(row):
 75          return {x["label"]: x["score"] for x in row}
 76  
 77      classifier = pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=None)
 78      model_outputs = classifier(data.fillna("").tolist())
 79      return pd.Series([_convert_labels(out).get(label, None) for out in model_outputs], index=data.index)
 80  
 81  
 82  def _openai_detector(data: pd.Series, score_threshold: float) -> pd.Series:
 83      from transformers import pipeline
 84  
 85      def _get_label(row):
 86          return row["label"] if row["score"] > score_threshold else "Unknown"
 87  
 88      pipe = pipeline("text-classification", model="roberta-base-openai-detector")
 89      return pd.Series([_get_label(x) for x in pipe(data.fillna("").tolist())], index=data.index)
 90  
 91  
 92  def _map_labels(labels: List[str], scores: List[float], threshold: float) -> Optional[str]:
 93      if len(labels) == 0:
 94          return None
 95      if len(labels) == 1:
 96          if scores[0] > threshold:
 97              return labels[0]
 98          else:
 99              return "not " + labels[0]
100      label = max(zip(labels, scores), key=lambda x: x[1])
101      return label[0] if label[1] > threshold else "unknown"
102  
103  
def _lmnli_fever(data: pd.Series, labels: List[str], threshold: Optional[float]) -> pd.Series:
    """Zero-shot classify each text against ``labels``.

    Uses ``MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli`` with
    ``multi_label=False``. Each result is reduced to a single label by ``_map_labels``;
    ``threshold`` defaults to 0.5 when ``None``. Missing texts are replaced with ``""``.
    The result keeps ``data``'s index.
    """
    from transformers import pipeline

    if threshold is None:
        threshold = 0.5

    classifier = pipeline(
        "zero-shot-classification",
        model="MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
    )
    texts = data.fillna("").tolist()
    predictions = classifier(texts, labels, multi_label=False)

    chosen = [_map_labels(pred["labels"], pred["scores"], threshold) for pred in predictions]
    return pd.Series(chosen, index=data.index)
116  
117  
def _toxicity(model_name: Optional[str], toxic_label: Optional[str], data: pd.Series) -> pd.Series:
    """Score each text with the ``evaluate`` "toxicity" measurement.

    Args:
        model_name: configuration name passed to ``evaluate.load``; may be ``None``.
        toxic_label: forwarded to ``compute`` as ``toxic_label`` only when not ``None``.
        data: texts to score; missing values are replaced with ``""``.

    Returns:
        Series of the measurement's ``"toxicity"`` scores, aligned with ``data.index``.
    """
    import evaluate

    # Replace missing values with "" as the transformers-based helpers in this module do,
    # so NaN is not handed to the text classifier. Casting to object first keeps fillna("")
    # valid for any dtype (e.g. categorical columns reject a new fill value).
    column_data: List[Any] = data.astype(object).fillna("").tolist()
    model = evaluate.load("toxicity", model_name, module_type="measurement")
    compute_kwargs: Dict[str, Any] = {}
    if toxic_label is not None:
        compute_kwargs["toxic_label"] = toxic_label
    scores = model.compute(predictions=column_data, **compute_kwargs)
    return pd.Series(scores["toxicity"], index=data.index)
128  
129  
def _dfp(data: pd.Series, threshold: Optional[float]) -> pd.Series:
    """Detect PII entities in each text with ``lakshyakh93/deberta_finetuned_pii``.

    Runs a token-classification pipeline. For every text, the entity labels and scores
    of the returned tokens are reduced to one label by ``_map_labels``; ``threshold``
    defaults to 0.5 when ``None``. Missing texts are replaced with ``""``.

    Returns:
        Series aligned with ``data.index``.
    """
    from transformers import pipeline

    if threshold is None:
        threshold = 0.5

    model = pipeline("token-classification", "lakshyakh93/deberta_finetuned_pii")
    # Replace missing values with "" (as the other transformers helpers here do) instead of
    # passing NaN to the tokenizer. Casting to object first keeps fillna("") dtype-agnostic.
    texts = data.astype(object).fillna("").tolist()
    output = model(texts)
    converted_output = [
        _map_labels(
            [x["entity"] for x in entities],
            [x["score"] for x in entities],
            threshold,
        )
        for entities in output
    ]
    return pd.Series(converted_output, index=data.index)
147  
148  
def _model_type(model: str) -> ColumnType:
    """Return the column type registered for ``model`` in ``_models``, or ``ColumnType.Unknown``."""
    if model not in _models:
        return ColumnType.Unknown
    column_type, _params, _fn = _models[model]
    return column_type
151  
152  
# Registry of Hugging Face models supported by HuggingFaceFeature.
# Each value is (output column type, names of the keyword parameters read from
# HuggingFaceFeature.params, function invoked as fn(column_series, **params)).
# Insertion order is kept: it is the order shown in the "model not found" error.
_models: Dict[str, Tuple[ColumnType, List[str], Callable[..., pd.Series]]] = {
    # Per-row score for the requested emotion label.
    "SamLowe/roberta-base-go_emotions": (
        ColumnType.Numerical,
        ["label"],
        _samlowe_roberta_base_go_emotions,
    ),
    # Detector label per row, or "Unknown" when the score is not above the threshold.
    "openai-community/roberta-base-openai-detector": (
        ColumnType.Categorical,
        ["score_threshold"],
        _openai_detector,
    ),
    # Zero-shot classification against caller-supplied candidate labels.
    "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli": (
        ColumnType.Categorical,
        ["labels", "threshold"],
        _lmnli_fever,
    ),
    # Toxicity measurements: model name and toxic label are bound via partial,
    # so no user parameters are read.
    "DaNLP/da-electra-hatespeech-detection": (
        ColumnType.Numerical,
        [],
        partial(_toxicity, "DaNLP/da-electra-hatespeech-detection", "offensive"),
    ),
    "facebook/roberta-hate-speech-dynabench-r4-target": (
        ColumnType.Numerical,
        [],
        partial(_toxicity, "facebook/roberta-hate-speech-dynabench-r4-target", "hate"),
    ),
    # PII token classification reduced to a single label per row.
    "lakshyakh93/deberta_finetuned_pii": (
        ColumnType.Categorical,
        ["threshold"],
        _dfp,
    ),
}