retrieval_qa_chain.py
import os
import tempfile

from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

import mlflow

assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

with tempfile.TemporaryDirectory() as temp_dir:
    persist_dir = os.path.join(temp_dir, "faiss_index")

    # Build the vector DB and persist it to a local filesystem folder
    loader = TextLoader("tests/langchain/state_of_the_union.txt")
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = FAISS.from_documents(docs, embeddings)
    db.save_local(persist_dir)

    # Create the RetrievalQA chain
    retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=db.as_retriever())

    # Loader function that reconstructs the retriever from the persisted FAISS
    # index; MLflow calls this when the logged model is loaded back.
    def load_retriever(persist_directory):
        embeddings = OpenAIEmbeddings()
        vectorstore = FAISS.load_local(persist_directory, embeddings)
        return vectorstore.as_retriever()

    # Log the retrievalQA chain
    with mlflow.start_run() as run:
        logged_model = mlflow.langchain.log_model(
            retrievalQA,
            name="retrieval_qa",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
        )

    # Load the retrievalQA chain as a pyfunc model and run a query
    loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
    print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))