m2v.py
"""
Model2Vec module
"""

import json

from huggingface_hub.errors import HFValidationError
from transformers.utils import cached_file

# Conditional import
try:
    from model2vec import StaticModel

    MODEL2VEC = True
except ImportError:
    MODEL2VEC = False

from ..base import Vectors


class Model2Vec(Vectors):
    """
    Builds vectors using Model2Vec.
    """

    @staticmethod
    def ismodel(path):
        """
        Checks if path is a Model2Vec model.

        Resolves `config.json` for path with transformers `cached_file` and compares its
        `model_type` field against "model2vec". Any failure to locate, read or parse that
        file is treated as "not a Model2Vec model" rather than raised.

        Args:
            path: input path

        Returns:
            True if this is a Model2Vec model, False otherwise
        """

        try:
            # Download file (or resolve it locally) and parse JSON
            configpath = cached_file(path_or_repo_id=path, filename="config.json")
            if not configpath:
                return False

            with open(configpath, encoding="utf-8") as f:
                config = json.load(f)

        # Ignore these errors - invalid repo or directory (HFValidationError, OSError), or an
        # unreadable / malformed config.json (ValueError covers JSONDecodeError and UnicodeDecodeError)
        except (HFValidationError, OSError, ValueError):
            return False

        # Valid JSON is not necessarily an object (could be a list, string, number, ...)
        return isinstance(config, dict) and config.get("model_type") == "model2vec"

    def __init__(self, config, scoring, models):
        """
        Creates a new Model2Vec vectors instance.

        Args:
            config: vector configuration, passed through to the Vectors constructor
            scoring: passed through to the Vectors constructor
            models: passed through to the Vectors constructor

        Raises:
            ImportError: if the model2vec package is not installed
        """

        # Check before parent constructor since it calls loadmodel
        if not MODEL2VEC:
            raise ImportError('Model2Vec is not available - install "vectors" extra to enable')

        super().__init__(config, scoring, models)

    def loadmodel(self, path):
        """
        Loads a Model2Vec model.

        Args:
            path: model path, passed to StaticModel.from_pretrained

        Returns:
            StaticModel instance
        """

        return StaticModel.from_pretrained(path)

    def encode(self, data, category=None):
        """
        Encodes data into vectors with the loaded StaticModel.

        Args:
            data: input data passed to StaticModel.encode
            category: not used by this implementation; presumably kept to match the
                      Vectors.encode signature (confirm against the base class)

        Returns:
            result of StaticModel.encode
        """

        # Additional model arguments - `or {}` also tolerates an explicit null "vectors" entry,
        # which would otherwise fail when unpacked with **
        modelargs = self.config.get("vectors") or {}

        # Encode data. self.encodebatch is assumed to be set by the Vectors base class.
        return self.model.encode(data, batch_size=self.encodebatch, **modelargs)