# retriever_chain.py
"""Build a FAISS retriever over a text file, log it to MLflow with a loader
function, then reload it as a pyfunc model and run a sample query.

Requires the OPENAI_API_KEY environment variable (used by OpenAIEmbeddings).
"""
import os
import tempfile

from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

import mlflow

# Validate the environment with a real exception instead of `assert`:
# asserts are stripped under `python -O`, which would let the script proceed
# and fail later inside the OpenAI client with a less helpful error.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Please set the OPENAI_API_KEY environment variable.")

with tempfile.TemporaryDirectory() as temp_dir:
    persist_dir = os.path.join(temp_dir, "faiss_index")

    # Create the vector database and persist it to a local filesystem folder.
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(persist_dir)

    # Loader function MLflow invokes at model-load time to rebuild the
    # retriever from the persisted vectorstore. It must be self-contained
    # (re-creates its own embeddings) because it runs in a fresh process.
    def load_retriever(persist_directory):
        """Return a retriever reconstructed from *persist_directory*.

        Args:
            persist_directory: Path to the FAISS index saved via
                ``db.save_local`` (MLflow copies ``persist_dir`` into the
                model artifacts and passes its location here).
        """
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.load_local(persist_directory, embeddings)
        return vectorstore.as_retriever()

    # Log the retriever with the loader function; MLflow snapshots the
    # persisted index from `persist_dir` alongside the model.
    with mlflow.start_run() as run:
        logged_model = mlflow.langchain.log_model(
            db.as_retriever(),
            name="retriever",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
        )

    # Load the retriever back through the generic pyfunc interface and
    # run a sample query against it.
    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
    print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))