# src/python/txtai/vectors/dense/sbert.py
 1  """
 2  Sentence Transformers module
 3  """
 4  
 5  # Conditional import
 6  try:
 7      from sentence_transformers import SentenceTransformer
 8  
 9      SENTENCE_TRANSFORMERS = True
10  except ImportError:
11      SENTENCE_TRANSFORMERS = False
12  
13  from ...models import Models
14  
15  from ..base import Vectors
16  
17  
class STVectors(Vectors):
    """
    Builds vectors using sentence-transformers (aka SBERT).
    """

    def __init__(self, config, scoring, models):
        # Validate the optional dependency up front - the parent constructor
        # will call loadmodel, which requires sentence-transformers
        if not SENTENCE_TRANSFORMERS:
            raise ImportError('sentence-transformers is not available - install "vectors" extra to enable')

        # Multiprocessing pool handle; must exist before the parent
        # constructor runs, since that constructor invokes loadmodel
        self.pool = None

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        """
        Loads a sentence-transformers encoder, optionally starting a
        multi-GPU process pool.

        Args:
            path: model path

        Returns:
            sentence-transformers encoder
        """

        # Read gpu setting; defaults to single-device mode
        gpu = self.config.get("gpu", True)
        pool = False

        # "all" requests one worker process per accelerator device
        if isinstance(gpu, str) and gpu == "all":
            # Number of accelerator devices on this machine
            count = Models.acceleratorcount()

            # Only enable multiprocessing pooling with multiple devices;
            # otherwise fall back to standard single-device placement
            pool = count > 1
            gpu = not pool

        # Resolve the tensor device id
        deviceid = Models.deviceid(gpu)

        # Additional keyword arguments for the encoder constructor
        modelargs = self.config.get("vectors", {})

        # Instantiate the encoder on the target device
        model = self.loadencoder(path, device=Models.device(deviceid), **modelargs)

        # Spawn worker processes when multi-GPU pooling is enabled
        if pool:
            self.pool = model.start_multi_process_pool()

        return model

    def encode(self, data, category=None):
        """
        Encodes data into vectors, dispatching to a category-specific
        encode method when available.

        Args:
            data: input data
            category: optional category ("query" or "data")

        Returns:
            encoded vectors
        """

        # Select the encode method by input category
        if category == "query":
            method = self.model.encode_query
        elif category == "data":
            method = self.model.encode_document
        else:
            method = self.model.encode

        # Additional keyword arguments for the encode call
        kwargs = self.config.get("encodeargs", {})

        # Run the sentence-transformers encoder
        return method(data, pool=self.pool, batch_size=self.encodebatch, **kwargs)

    def close(self):
        """
        Closes this instance, shutting down the worker pool before the
        parent method releases the model.
        """

        if self.pool:
            self.model.stop_multi_process_pool(self.pool)
            self.pool = None

        super().close()

    def loadencoder(self, path, device, **kwargs):
        """
        Loads the embeddings encoder model from path.

        Args:
            path: model path
            device: tensor device
            kwargs: additional keyword args

        Returns:
            embeddings encoder
        """

        return SentenceTransformer(path, device=device, **kwargs)