"""Log a Hugging Face text-generation pipeline with MLflow and run inference.

Source path: examples/transformers/simple.py

The script does the following:

1. Builds a GPT-2 ``text-generation`` pipeline.
2. Logs it with ``mlflow.transformers.log_model``, attaching an input example
   together with default inference parameters.
3. Reloads it as a generic pyfunc model.
4. Prints generations for two prompts, passing the same parameters at
   inference time.
"""

import transformers

import mlflow

task = "text-generation"

generation_pipeline = transformers.pipeline(
    task=task,
    model="gpt2",
)

input_example = ["prompt 1", "prompt 2", "prompt 3"]

# Inference parameters forwarded to the pipeline at predict time.
parameters = {"max_length": 512, "do_sample": True}

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=generation_pipeline,
        name="text_generator",
        # A (data, params) tuple lets MLflow infer a signature that covers the
        # inference parameters in addition to the input schema. Reuse the
        # variable above instead of repeating the literal so they cannot drift.
        input_example=(input_example, parameters),
    )

sentence_generator = mlflow.pyfunc.load_model(model_info.model_uri)

print(
    sentence_generator.predict(
        ["tell me a story about rocks", "Tell me a joke about a dog that likes spaghetti"],
        # pass in additional parameters applied to the pipeline during inference
        params=parameters,
    )
)