/ examples / langchain / chain_as_code.py
chain_as_code.py
 1  # This example demonstrates defining a model directly from code.
 2  # This feature allows for defining model logic within a python script, module, or notebook that is stored
 3  # directly as serialized code, as opposed to object serialization that would otherwise occur when saving
 4  # or logging a model object.
 5  # This script defines the model's logic and specifies which class within the file contains the model code.
# The companion example to this, chain_as_code_driver.py, is the driver code that performs the logging and
 7  # loading of this model definition.
 8  
 9  import os
10  from operator import itemgetter
11  
12  from langchain_core.output_parsers import StrOutputParser
13  from langchain_core.prompts import PromptTemplate
14  from langchain_core.runnables import RunnableLambda
15  from langchain_openai import OpenAI
16  
17  import mlflow
18  
# Enable MLflow autologging for LangChain so chain invocations are traced automatically.
mlflow.langchain.autolog()

# Fail fast with an explicit exception if the OpenAI credential is missing.
# An `assert` is unsuitable for input/environment validation because it is
# stripped when Python runs with optimizations (`python -O`).
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Please set the OPENAI_API_KEY environment variable.")
22  
23  
# Pull the text of the newest message out of the chat payload.
def extract_user_query_string(chat_messages_array):
    """Return the ``content`` string of the most recent chat message."""
    latest_message = chat_messages_array[-1]
    return latest_message["content"]
27  
28  
29  # Return the chat history, which is is everything before the last question
30  def extract_chat_history(chat_messages_array):
31      return chat_messages_array[:-1]
32  
33  
# Prompt that frames the assistant as a "hello world" bot; the single
# {question} variable is filled from the chain's input mapping below.
# NOTE(review): only "question" is declared here, yet the chain also computes
# "chat_history" — presumably the extra key is ignored by the template; verify.
prompt = PromptTemplate(
    template="You are a hello world bot.  Respond with a reply to the user's question that is fun and interesting to the user.  User's question: {question}",
    input_variables=["question"],
)

# Completion model; temperature 0.9 favors varied, playful replies.
model = OpenAI(temperature=0.9)
40  
# LCEL pipeline: map the raw {"messages": [...]} input into prompt variables
# (each via itemgetter piped into a helper), render the prompt, call the
# model, and parse the completion down to a plain string.
chain = (
    {
        "question": itemgetter("messages") | RunnableLambda(extract_user_query_string),
        "chat_history": itemgetter("messages") | RunnableLambda(extract_chat_history),
    }
    | prompt
    | model
    | StrOutputParser()
)
50  
# A sample request in the chat-messages schema the chain expects:
# a list of role/content dicts under the "messages" key.
question = {
    "messages": [{"role": "user", "content": "what is rag?"}]
}

# Exercise the chain once; with autologging enabled this records a trace.
chain.invoke(question)

# IMPORTANT: The model code needs to call `mlflow.models.set_model()` to set the model,
# which will be loaded back using `mlflow.langchain.load_model` for inference.
mlflow.models.set_model(model=chain)