/ examples / langchain / simple_chain.py
simple_chain.py
 1  import os
 2  
 3  from langchain.chains import LLMChain
 4  from langchain.llms import OpenAI
 5  from langchain.prompts import PromptTemplate
 6  
 7  import mlflow
 8  
 9  # Ensure the OpenAI API key is set in the environment
10  assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."
11  
12  # Initialize the OpenAI model and the prompt template
13  llm = OpenAI(temperature=0.9)
14  prompt = PromptTemplate(
15      input_variables=["product"],
16      template="What is a good name for a company that makes {product}?",
17  )
18  
19  # Create the LLMChain with the specified model and prompt
20  chain = LLMChain(llm=llm, prompt=prompt)
21  
22  # Log the LangChain LLMChain in an MLflow run
23  with mlflow.start_run():
24      logged_model = mlflow.langchain.log_model(chain, name="langchain_model")
25  
26  # Load the logged model using MLflow's Python function flavor
27  loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
28  
29  # Predict using the loaded model
30  print(loaded_model.predict([{"product": "colorful socks"}]))