# simple_chain.py
"""Log a LangChain LLMChain as an MLflow model, reload it, and run a prediction.

Requires the ``OPENAI_API_KEY`` environment variable and a reachable MLflow
tracking backend (defaults to a local ``mlruns`` directory).
"""
import os

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate

import mlflow

# Fail fast with a real exception. An `assert` is stripped under `python -O`,
# so it must not be relied on to validate required configuration.
if "OPENAI_API_KEY" not in os.environ:
    raise EnvironmentError("Please set the OPENAI_API_KEY environment variable.")

# High temperature favors more varied / creative name suggestions.
llm = OpenAI(temperature=0.9)

# Prompt template with a single input variable, `product`.
prompt = PromptTemplate(
    input_variables=["product"],
    template="What is a good name for a company that makes {product}?",
)

# Combine the model and prompt into a runnable chain.
chain = LLMChain(llm=llm, prompt=prompt)

# Log the chain as an MLflow model; the context manager ends the run cleanly
# even if logging raises.
with mlflow.start_run():
    logged_model = mlflow.langchain.log_model(chain, name="langchain_model")

# Reload through the generic pyfunc flavor — the same interface any MLflow
# model exposes regardless of its native framework.
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

# predict() takes a list of input dicts keyed by the prompt's input variables.
print(loaded_model.predict([{"product": "colorful socks"}]))