chain_as_code.py
# This example demonstrates defining a model directly from code.
# This feature allows for defining model logic within a python script, module, or notebook that is stored
# directly as serialized code, as opposed to object serialization that would otherwise occur when saving
# or logging a model object.
# This script defines the model's logic and designates which object in the file is the model by passing
# it to `mlflow.models.set_model()` at the end of the script.
# The companion example to this, chain_as_code_driver.py, is the driver code that performs the logging and
# loading of this model definition.

import os
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_openai import OpenAI

import mlflow

# Turn on MLflow's LangChain autologging before the chain is built and invoked.
mlflow.langchain.autolog()

# Fail fast with a clear message if the OpenAI credentials are missing. An explicit
# raise is used instead of `assert` because assertions are stripped when Python runs
# with -O, which would silently remove this check.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Please set the OPENAI_API_KEY environment variable.")


def extract_user_query_string(chat_messages_array):
    """Return the ``"content"`` of the last message in ``chat_messages_array``.

    The role of that message is not checked; it is assumed to be the user's
    latest question. Indexing with ``[-1]`` raises ``IndexError`` for an empty
    list.
    """
    return chat_messages_array[-1]["content"]


def extract_chat_history(chat_messages_array):
    """Return every message before the last one, i.e. the conversation history."""
    return chat_messages_array[:-1]


# The template only references ``{question}``; chat history is not rendered into it.
prompt = PromptTemplate(
    template=(
        "You are a hello world bot. Respond with a reply to the user's question that is "
        "fun and interesting to the user. User's question: {question}"
    ),
    input_variables=["question"],
)

# temperature=0.9 is presumably chosen for more varied, playful replies.
model = OpenAI(temperature=0.9)

# Input:  {"messages": [{"role": ..., "content": ...}, ...]}
# Output: the model's completion, converted to a plain string by StrOutputParser.
chain = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        # NOTE(review): chat_history is computed here, but the prompt template has no
        # {chat_history} placeholder, so it does not reach the model. Add the
        # placeholder to the template if multi-turn context is intended.
        "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
    }
    | prompt
    | model
    | StrOutputParser()
)

# Sample request in the chat-messages format the chain expects.
question = {
    "messages": [
        {
            "role": "user",
            "content": "what is rag?",
        },
    ]
}

# Smoke-test invocation: runs every time this file is executed and makes a live
# OpenAI request whose result is discarded. With autologging enabled above, the call
# is presumably captured as a trace — confirm this is wanted on every load.
chain.invoke(question)

# IMPORTANT: The model code needs to call `mlflow.models.set_model()` to set the model,
# which will be loaded back using `mlflow.langchain.load_model` for inference.
mlflow.models.set_model(model=chain)