# examples/transformers/sentence_transformer.py
import torch
from transformers import BertModel, BertTokenizerFast, pipeline

import mlflow
 5  
 6  sentence_transformers_architecture = "sentence-transformers/all-MiniLM-L12-v2"
 7  task = "feature-extraction"
 8  
 9  model = BertModel.from_pretrained(sentence_transformers_architecture)
10  tokenizer = BertTokenizerFast.from_pretrained(sentence_transformers_architecture)
11  
12  sentence_transformer_pipeline = pipeline(task=task, model=model, tokenizer=tokenizer)
13  
14  with mlflow.start_run():
15      model_info = mlflow.transformers.log_model(
16          transformers_model=sentence_transformer_pipeline,
17          name="sentence_transformer",
18          framework="pt",
19          torch_dtype=torch.bfloat16,
20      )
21  
22  loaded = mlflow.transformers.load_model(model_info.model_uri, return_type="components")
23  
24  
def pool_and_normalize_encodings(input_sentences, model, tokenizer, **kwargs):
    """Encode sentences as mean-pooled, L2-normalized embedding vectors.

    The sentences are tokenized as one padded batch, run through ``model``
    without gradient tracking, averaged over the token axis using the
    attention mask as weights, and finally scaled to unit L2 norm per row.

    Args:
        input_sentences: Text input forwarded unchanged to ``tokenizer``.
        model: Callable invoked as ``model(**encoded)``. Element ``0`` of its
            output is used as the per-token embeddings (presumably the last
            hidden state for a BERT-style model -- confirm for other models).
        tokenizer: Callable invoked with ``padding=True``, ``truncation=True``
            and ``return_tensors="pt"``. Its result must be unpackable with
            ``**`` and expose an ``"attention_mask"`` entry.
        **kwargs: Accepted and ignored. This lets a caller unpack a mapping
            that carries entries besides ``model`` and ``tokenizer``.

    Returns:
        A tensor with one row per input sentence, normalized along ``dim=1``.
    """

    def pool(model_output, attention_mask):
        """Average the token embeddings, counting only unmasked positions."""
        embeddings = model_output[0]
        # Broadcast the mask across the embedding axis so that masked (padding)
        # positions contribute zero to the sum.
        expanded_mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
        # The clamp keeps the division finite when a row has no unmasked tokens.
        return torch.sum(embeddings * expanded_mask, 1) / torch.clamp(
            expanded_mask.sum(1), min=1e-9
        )

    encoded = tokenizer(
        input_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        model_output = model(**encoded)

    pooled = pool(model_output, encoded["attention_mask"])
    return torch.nn.functional.normalize(pooled, p=2, dim=1)
44  
45  
# Example inputs to embed.
sentences = [
    "He said that he's sinking all of his investment budget into coconuts.",
    "No matter how deep you dig, there's going to be a point when it just gets too hot.",
    "She said that there isn't a noticeable difference between a 10 year and a 15 year whisky.",
]

# `loaded` is unpacked as keyword arguments, so it has to supply `model` and
# `tokenizer`; any further entries are absorbed by the function's `**kwargs`.
encoded_sentences = pool_and_normalize_encodings(sentences, **loaded)

print(encoded_sentences)