/ src / python / txtai / vectors / dense / m2v.py
m2v.py
 1  """
 2  Model2Vec module
 3  """
 4  
 5  import json
 6  
 7  from huggingface_hub.errors import HFValidationError
 8  from transformers.utils import cached_file
 9  
10  # Conditional import
11  try:
12      from model2vec import StaticModel
13  
14      MODEL2VEC = True
15  except ImportError:
16      MODEL2VEC = False
17  
18  from ..base import Vectors
19  
20  
class Model2Vec(Vectors):
    """
    Builds vectors using Model2Vec.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a Model2Vec model.

        The path is resolved with transformers' cached_file and the resulting config.json
        is checked for model_type == "model2vec". Any failure to resolve or parse the
        config means "not a Model2Vec model" rather than an error.

        Args:
            path: input path

        Returns:
            True if this is a Model2Vec model, False otherwise
        """

        try:
            # Download file (or resolve it locally) and parse JSON
            configpath = cached_file(path_or_repo_id=path, filename="config.json")
            if configpath:
                with open(configpath, encoding="utf-8") as f:
                    config = json.load(f)

                # Valid JSON is not necessarily an object - only a dict can carry model_type
                return isinstance(config, dict) and config.get("model_type") == "model2vec"

        # Ignore these errors - invalid repo or directory (HFValidationError, OSError) or an
        # unreadable config.json (ValueError covers json.JSONDecodeError and UnicodeDecodeError)
        except (HFValidationError, OSError, ValueError):
            pass

        return False

    def __init__(self, config, scoring, models):
        """
        Creates a new Model2Vec vectors instance.

        Args:
            config: vectors configuration, forwarded unchanged to the parent Vectors constructor
            scoring: forwarded unchanged to the parent Vectors constructor
            models: forwarded unchanged to the parent Vectors constructor

        Raises:
            ImportError: if the model2vec package is not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not MODEL2VEC:
            raise ImportError('Model2Vec is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        """
        Loads a Model2Vec model.

        Args:
            path: model path, passed to StaticModel.from_pretrained

        Returns:
            StaticModel instance
        """

        return StaticModel.from_pretrained(path)

    def encode(self, data, category=None):
        """
        Encodes data into embeddings with the loaded StaticModel.

        Args:
            data: input data, passed to StaticModel.encode
            category: category of data - ignored by this implementation

        Returns:
            embeddings as returned by StaticModel.encode
        """

        # Additional model arguments. An explicit null value (e.g. an empty YAML key) means
        # "no arguments" instead of failing on ** unpacking of None.
        modelargs = self.config.get("vectors") or {}

        # encodebatch is the default batch size; a batch_size set in the model arguments overrides
        # it instead of raising "got multiple values for keyword argument 'batch_size'"
        encodeargs = {"batch_size": self.encodebatch, **modelargs}

        # Encode data
        return self.model.encode(data, **encodeargs)
66          # Encode data
67          return self.model.encode(data, batch_size=self.encodebatch, **modelargs)