# src/vector_store/embeddings.py
 1  """
 2  Contains code for embedding chunks of text into selected vector databases.
 3  """
 4  import os
 5  from pathlib import Path
 6  from loguru import logger 
 7  from argparse import ArgumentParser 
 8  
 9  from langchain_core.documents import Document
10  from langchain_chroma.vectorstores import Chroma
11  from langchain_huggingface import HuggingFaceEmbeddings
12  
13  from src.setup.paths import CHROMA_DIR
14  from src.setup.config import embed_config
15  from src.data_processing.cleaning import Cleaner 
16  from src.data_processing.chunking import split_documents 
17  from src.data_preparation.sourcing import Author, get_sources 
18  
19      
class ChromaAPI:
    """Per-author Chroma vector store that embeds an author's cleaned texts."""

    def __init__(self, author: Author) -> None:
        """
        Set up a Chroma vector store persisted under a per-author directory.

        Args:
            author: the author whose texts will be embedded.
        """
        self.author: Author = author
        self.embeddings_directory: Path = CHROMA_DIR.joinpath(author.name)

        self.vector_store: Chroma = Chroma(
            # Chroma does not permit collection names to have whitespace in them
            collection_name=author.name.replace(" ", "_"),
            persist_directory=str(self.embeddings_directory),
            embedding_function=get_embedding_model()
        )

    def embed_books(self, chunk: bool) -> list[str] | None:
        """
        Embed the author's cleaned texts into the Chroma vector store.

        Args:
            chunk: when True, split the documents into chunks before embedding.

        Returns:
            The IDs of the documents added to the store, or None when the
            embeddings already exist or the add operation failed.

        Raises:
            Exception: if the cleaned text for the author could not be retrieved.
        """
        # Chroma populates the persist directory on construction; more than one
        # entry there means embeddings were already created on a previous run.
        if len(os.listdir(self.embeddings_directory)) > 1:
            logger.success(f"Embeddings have already been made for {self.author.name}'s texts")
            return None

        cleaner = Cleaner(author=self.author)
        documents: list[Document] | None = cleaner.execute()

        # Bug fix: compare to None with `is`, not `==`.
        if documents is None:
            raise Exception(f"Unable to retrieve cleaned text for {self.author.name}")

        logger.info(f"Creating vector embeddings of the texts by {self.author.name}")
        if chunk:
            documents = split_documents(documents=documents)

        try:
            ids = self.vector_store.add_documents(documents=documents)
        except Exception as e:
            # Best-effort: log the failure and report None rather than crashing
            # a multi-author batch run (the original chunk path already logged
            # instead of raising; the non-chunk path is now consistent with it).
            logger.error(e)
            return None

        # Bug fix: the success log and return were previously nested inside the
        # non-chunk branch, so chunked runs never returned their ids and the
        # "chunks of" wording was unreachable.
        logger.success(f"Successfully embedded the {'chunks of' if chunk else ''} text using ChromaDB.")
        return ids
54  
55  
def get_embedding_model() -> HuggingFaceEmbeddings:
    """Build the HuggingFace embedding model named in the project config."""
    model_name = embed_config.embedding_model_name
    return HuggingFaceEmbeddings(model_name=model_name)
58  
59  
if __name__ == "__main__":
    # CLI entry point: embed every sourced author's books, optionally chunked.
    parser = ArgumentParser()
    _ = parser.add_argument("--chunk", action="store_true")
    args = parser.parse_args()

    for author in get_sources():
        author.download_books()
        _ = ChromaAPI(author=author).embed_books(chunk=args.chunk)
70