embeddings.py
"""
Contains code for embedding chunks of text into selected vector databases.
"""
import os
from argparse import ArgumentParser
from functools import cache
from pathlib import Path

from loguru import logger
from langchain_core.documents import Document
from langchain_chroma.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings

from src.setup.paths import CHROMA_DIR
from src.setup.config import embed_config
from src.data_processing.cleaning import Cleaner
from src.data_processing.chunking import split_documents
from src.data_preparation.sourcing import Author, get_sources


class ChromaAPI:
    """Manage a per-author Chroma vector store and populate it with that author's texts."""

    def __init__(self, author: Author) -> None:
        """
        Create (or reopen) the Chroma collection for ``author``.

        Args:
            author: The author whose texts are embedded. ``author.name`` selects the
                persistence directory under ``CHROMA_DIR`` and (with spaces replaced by
                underscores) the Chroma collection name.
        """
        self.author: Author = author
        self.embeddings_directory: Path = CHROMA_DIR.joinpath(author.name)

        self.vector_store: Chroma = Chroma(
            # Chroma does not permit collection names to have whitespace in them
            collection_name=author.name.replace(" ", "_"),
            persist_directory=str(self.embeddings_directory),
            embedding_function=get_embedding_model(),
        )

    def _embeddings_exist(self) -> bool:
        """Return True if the persistence directory appears to already hold embeddings."""
        if not self.embeddings_directory.is_dir():
            return False
        # Heuristic kept from the original: more than one entry in the directory is taken
        # to mean embeddings were written previously.
        # NOTE(review): presumably Chroma writes one file on open and more once data is
        # added — confirm against the Chroma version in use.
        return len(os.listdir(self.embeddings_directory)) > 1

    def embed_books(self, chunk: bool) -> list[str] | None:
        """
        Clean the author's texts and add them to the vector store.

        Args:
            chunk: If True, split the cleaned documents with ``split_documents`` before
                embedding; otherwise embed the cleaned documents whole.

        Returns:
            The ids returned by ``Chroma.add_documents``. Returns None when embeddings
            already exist, or when chunking/embedding in the ``chunk=True`` path fails
            (the failure is logged with its traceback).

        Raises:
            RuntimeError: If the cleaner returns no documents.
        """
        if self._embeddings_exist():
            logger.success(f"Embeddings have already been made for {self.author.name}'s texts")
            return None

        cleaner = Cleaner(author=self.author)
        documents: list[Document] | None = cleaner.execute()
        if documents is None:
            raise RuntimeError(f"Unable to retrieve cleaned text for {self.author.name}")

        logger.info(f"Creating vector embeddings of the texts by {self.author.name}")
        if chunk:
            try:
                chunks: list[Document] = split_documents(documents=documents)
                ids = self.vector_store.add_documents(documents=chunks)
            except Exception:
                # Best-effort, as in the original: log and carry on with other authors,
                # but don't report success or fall through to an unbound ``ids``.
                logger.exception(f"Failed to embed the chunked texts of {self.author.name}")
                return None
        else:
            ids = self.vector_store.add_documents(documents=documents)

        logger.success(f"Successfully embedded the {'chunks of ' if chunk else ''}text using ChromaDB.")
        return ids


@cache
def get_embedding_model() -> HuggingFaceEmbeddings:
    """
    Return the HuggingFace embedding model named in ``embed_config``.

    Memoized so the model is loaded once per process rather than once per author;
    every ``ChromaAPI`` instance shares the same embeddings object.
    """
    return HuggingFaceEmbeddings(model_name=embed_config.embedding_model_name)


if __name__ == "__main__":

    parser = ArgumentParser(description="Embed each author's texts into a per-author Chroma store.")
    _ = parser.add_argument("--chunk", action="store_true", help="Split documents into chunks before embedding.")
    args = parser.parse_args()

    for author in get_sources():
        author.download_books()
        api = ChromaAPI(author=author)
        _ = api.embed_books(chunk=args.chunk)