# examples/langchain/retrieval_qa_chain.py
 1  import os
 2  import tempfile
 3  
 4  from langchain.chains import RetrievalQA
 5  from langchain.document_loaders import TextLoader
 6  from langchain.embeddings.openai import OpenAIEmbeddings
 7  from langchain.llms import OpenAI
 8  from langchain.text_splitter import CharacterTextSplitter
 9  from langchain.vectorstores import FAISS
10  
11  import mlflow
12  
# Fail fast with a clear error when the OpenAI credential is missing.
# A bare `assert` is stripped when Python runs with -O, so raise explicitly.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Please set the OPENAI_API_KEY environment variable.")
14  
# Build a FAISS index from the example document, persist it to a temporary
# folder, create a RetrievalQA chain on top of it, and log the chain to MLflow.
with tempfile.TemporaryDirectory() as temp_dir:
    persist_dir = os.path.join(temp_dir, "faiss_index")

    # Load the source document, chunk it, embed the chunks, and persist the
    # resulting FAISS vector store to the temporary directory.
    raw_docs = TextLoader("tests/langchain/state_of_the_union.txt").load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_documents(raw_docs)
    vectorstore = FAISS.from_documents(chunks, OpenAIEmbeddings())
    vectorstore.save_local(persist_dir)

    # Assemble the RetrievalQA chain: an OpenAI LLM answering over the
    # FAISS-backed retriever.
    qa_chain = RetrievalQA.from_llm(llm=OpenAI(), retriever=vectorstore.as_retriever())

    def load_retriever(persist_directory):
        """Reload the persisted FAISS store and return it as a retriever.

        MLflow serializes this function with the model so the retriever can be
        reconstructed when the logged chain is loaded back.
        """
        store = FAISS.load_local(persist_directory, OpenAIEmbeddings())
        return store.as_retriever()

    # Log the chain together with the persisted index under an MLflow run.
    with mlflow.start_run():
        logged_model = mlflow.langchain.log_model(
            qa_chain,
            name="retrieval_qa",
            loader_fn=load_retriever,
            persist_dir=persist_dir,
        )
43  
# Reload the logged chain as a generic pyfunc model and run a sample query.
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
question = {"query": "What did the president say about Ketanji Brown Jackson"}
print(loaded_model.predict([question]))