# examples/langchain/retriever_chain.py
 1  import os
 2  import tempfile
 3  
 4  from langchain.document_loaders import TextLoader
 5  from langchain.embeddings.openai import OpenAIEmbeddings
 6  from langchain.text_splitter import CharacterTextSplitter
 7  from langchain.vectorstores import FAISS
 8  
 9  import mlflow
10  
# Fail fast with a clear error when the OpenAI credential is missing.
# An `if`/`raise` is used instead of `assert` because assert statements are
# stripped when Python runs with optimizations enabled (`python -O`), which
# would let the script proceed and fail later with a confusing API error.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Please set the OPENAI_API_KEY environment variable.")

# Build a FAISS vector store from the sample transcript, persist it to disk,
# and log the resulting retriever to MLflow together with a loader function
# that can reconstruct the retriever from the persisted index.
with tempfile.TemporaryDirectory() as scratch_dir:
    index_path = os.path.join(scratch_dir, "faiss_index")

    # Load the source document, split it into chunks, embed the chunks,
    # and persist the resulting FAISS index to the local filesystem.
    raw_documents = TextLoader("tests/langchain/state_of_the_union.txt").load()
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_documents(raw_documents)
    vector_db = FAISS.from_documents(chunks, OpenAIEmbeddings())
    vector_db.save_local(index_path)

    def load_retriever(persist_directory):
        """Rebuild the retriever from a persisted FAISS index.

        MLflow serializes this function alongside the model and invokes it
        at load time, passing the directory given via ``persist_dir`` below.
        """
        store = FAISS.load_local(persist_directory, OpenAIEmbeddings())
        return store.as_retriever()

    # Log the retriever; `loader_fn` + `persist_dir` tell MLflow how to
    # restore it when the model is loaded back.
    with mlflow.start_run() as run:
        logged_model = mlflow.langchain.log_model(
            vector_db.as_retriever(),
            name="retriever",
            loader_fn=load_retriever,
            persist_dir=index_path,
        )

# Reload the logged retriever as a generic pyfunc model and run a query.
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
query = [{"query": "What did the president say about Ketanji Brown Jackson"}]
print(loaded_model.predict(query))