"""
Sentence Transformers module
"""

# Conditional import - sentence-transformers is an optional dependency
try:
    from sentence_transformers import SentenceTransformer

    SENTENCE_TRANSFORMERS = True
except ImportError:
    SENTENCE_TRANSFORMERS = False

from ...models import Models

from ..base import Vectors


class STVectors(Vectors):
    """
    Builds vectors using sentence-transformers (aka SBERT).
    """

    def __init__(self, config, scoring, models):
        """
        Creates a sentence-transformers vectors instance.

        Args:
            config: vectors configuration
            scoring: scoring instance
            models: models cache

        Raises:
            ImportError: if sentence-transformers is not installed
        """

        # Validate the optional dependency up front - the parent constructor calls loadmodel
        if not SENTENCE_TRANSFORMERS:
            raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

        # Initialize pool attribute before the parent constructor, which triggers loadmodel
        self.pool = None

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        """
        Loads a sentence-transformers model from path.

        Args:
            path: model path

        Returns:
            sentence-transformers model
        """

        # Single-GPU mode by default, no multiprocessing pool
        pool = False
        gpu = self.config.get("gpu", True)

        # "all" spawns a worker process per accelerator device when more than one is present
        if isinstance(gpu, str) and gpu == "all":
            count = Models.acceleratorcount()

            pool = count > 1
            gpu = not pool

        # Resolve tensor device id and additional model keyword arguments
        deviceid = Models.deviceid(gpu)
        modelargs = self.config.get("vectors", {})

        # Instantiate the encoder on the resolved device
        model = self.loadencoder(path, device=Models.device(deviceid), **modelargs)

        # Spin up the multi-GPU process pool when enabled
        if pool:
            self.pool = model.start_multi_process_pool()

        return model

    def encode(self, data, category=None):
        """
        Encodes data into embeddings vectors.

        Args:
            data: input data
            category: optional category ("query" or "data") selecting a specialized encode method

        Returns:
            embeddings vectors
        """

        # Resolve the encode method lazily per category - specialized methods are only
        # accessed when their category is requested
        if category == "query":
            method = self.model.encode_query
        elif category == "data":
            method = self.model.encode_document
        else:
            method = self.model.encode

        # Optional extra keyword arguments for encoding
        encodeargs = self.config.get("encodeargs", {})

        # Run the sentence-transformers encoder
        return method(data, pool=self.pool, batch_size=self.encodebatch, **encodeargs)

    def close(self):
        """
        Closes this vectors instance, shutting down the process pool before the model.
        """

        # Pool must stop while the model is still alive
        if self.pool:
            self.model.stop_multi_process_pool(self.pool)
            self.pool = None

        super().close()

    def loadencoder(self, path, device, **kwargs):
        """
        Loads the embeddings encoder model from path.

        Args:
            path: model path
            device: tensor device
            kwargs: additional keyword args

        Returns:
            embeddings encoder
        """

        return SentenceTransformer(path, device=device, **kwargs)