# simple.py
"""Log a GPT-2 text-generation pipeline with MLflow and reload it for inference.

The pipeline is logged together with an input example that carries the
generation parameters. It is then reloaded as a generic ``pyfunc`` model and
used to generate text for two prompts with those same parameters.
"""

import transformers

import mlflow

task = "text-generation"

generation_pipeline = transformers.pipeline(
    task=task,
    model="gpt2",
)

input_example = ["prompt 1", "prompt 2", "prompt 3"]

# Generation settings applied to the pipeline at inference time (see predict() below).
parameters = {"max_length": 512, "do_sample": True}

with mlflow.start_run() as run:
    model_info = mlflow.transformers.log_model(
        transformers_model=generation_pipeline,
        name="text_generator",
        # Reuse ``input_example`` instead of repeating the literal, and pair it
        # with ``parameters`` so the logged example also records the inference
        # parameters that are passed to predict() below.
        input_example=(input_example, parameters),
    )

sentence_generator = mlflow.pyfunc.load_model(model_info.model_uri)

print(
    sentence_generator.predict(
        ["tell me a story about rocks", "Tell me a joke about a dog that likes spaghetti"],
        # pass in additional parameters applied to the pipeline during inference
        params=parameters,
    )
)