# sentence_transformer.py
import torch
from transformers import BertModel, BertTokenizerFast, pipeline

import mlflow

# Pretrained sentence-embedding checkpoint and the transformers pipeline task
# it is served under.
sentence_transformers_architecture = "sentence-transformers/all-MiniLM-L12-v2"
task = "feature-extraction"

model = BertModel.from_pretrained(sentence_transformers_architecture)
tokenizer = BertTokenizerFast.from_pretrained(sentence_transformers_architecture)

sentence_transformer_pipeline = pipeline(task=task, model=model, tokenizer=tokenizer)

# Log the pipeline to MLflow inside a run.
# NOTE(review): `torch_dtype=torch.bfloat16` controls the dtype the weights are
# saved/reloaded with -- confirm the target serving hardware supports bfloat16.
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=sentence_transformer_pipeline,
        name="sentence_transformer",
        framework="pt",
        torch_dtype=torch.bfloat16,
    )

# Reload as raw components (model + tokenizer dict) rather than a pipeline so
# the custom pooling below can drive the model directly.
loaded = mlflow.transformers.load_model(model_info.model_uri, return_type="components")


def pool_and_normalize_encodings(input_sentences, model, tokenizer, **kwargs):
    """Encode sentences into L2-normalized, mean-pooled embeddings.

    Args:
        input_sentences: A string or list of strings to embed.
        model: An encoder whose first output is the token-embeddings tensor
            (assumed shape (batch, seq_len, hidden) -- TODO confirm for
            non-BERT checkpoints).
        tokenizer: The matching tokenizer; invoked with padding and truncation
            and asked for PyTorch tensors.
        **kwargs: Absorbs any extra components returned by
            ``mlflow.transformers.load_model`` beyond model/tokenizer.

    Returns:
        A (batch, hidden) ``torch.Tensor`` of unit-norm sentence embeddings.
    """

    def pool(model_output, attention_mask):
        # Mean-pool token embeddings, masking out padding positions.
        embeddings = model_output[0]
        expanded_mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
        # clamp guards against division by zero for fully-masked rows.
        return torch.sum(embeddings * expanded_mask, 1) / torch.clamp(
            expanded_mask.sum(1), min=1e-9
        )

    encoded = tokenizer(
        input_sentences,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only -- no gradient tracking needed.
    with torch.no_grad():
        model_output = model(**encoded)

    pooled = pool(model_output, encoded["attention_mask"])
    # Unit-normalize so cosine similarity reduces to a plain dot product.
    return torch.nn.functional.normalize(pooled, p=2, dim=1)


sentences = [
    "He said that he's sinking all of his investment budget into coconuts.",
    "No matter how deep you dig, there's going to be a point when it just gets too hot.",
    "She said that there isn't a noticeable difference between a 10 year and a 15 year whisky.",
]

encoded_sentences = pool_and_normalize_encodings(sentences, **loaded)

print(encoded_sentences)