generate_onnx_models.py
"""
Generates the following test resources:

- tf_model_multiple_inputs_float32.onnx
- tf_model_multiple_inputs_float64.onnx
- sklearn_model.onnx

Usage: python generate_onnx_models.py
"""

import numpy as np
import onnx
import onnxmltools
import pandas as pd
import tensorflow.compat.v1 as tf
import tf2onnx
from skl2onnx.common.data_types import FloatTensorType
from sklearn import datasets
from sklearn.linear_model import LogisticRegression

tf.disable_v2_behavior()


def _generate_tf_onnx_model_multiple_inputs(dtype, output_path):
    """Build a two-input element-wise multiply graph and save it as ONNX.

    The graph has two placeholders of shape ``10`` named ``first_input`` and
    ``second_input``; their product is exposed through an identity op named
    ``output``.

    Args:
        dtype: TensorFlow dtype used for both placeholders.
        output_path: Path the serialized ONNX model is written to.
    """
    graph = tf.Graph()
    with graph.as_default():
        t_in1 = tf.placeholder(dtype, 10, name="first_input")
        t_in2 = tf.placeholder(dtype, 10, name="second_input")
        t_out = tf.multiply(t_in1, t_in2)
        # The identity op only exists to give the result a stable tensor name.
        tf.identity(t_out, name="output")

    # Use the session as a context manager so it is closed once the graph has
    # been converted, instead of being left open until interpreter exit.
    with tf.Session(graph=graph) as sess:
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(
            sess.graph,
            input_names=["first_input:0", "second_input:0"],
            output_names=["output:0"],
        )
    model_proto = onnx_graph.make_model("test")

    onnx.save_model(model_proto, output_path)


def generate_tf_onnx_model_multiple_inputs_float64():
    """Write ``tf_model_multiple_inputs_float64.onnx`` to the working directory."""
    _generate_tf_onnx_model_multiple_inputs(
        tf.float64, "tf_model_multiple_inputs_float64.onnx"
    )


def generate_tf_onnx_model_multiple_inputs_float32():
    """Write ``tf_model_multiple_inputs_float32.onnx`` to the working directory."""
    _generate_tf_onnx_model_multiple_inputs(
        tf.float32, "tf_model_multiple_inputs_float32.onnx"
    )


def generate_sklearn_onnx_model():
    """Fit a logistic regression on iris and write ``sklearn_model.onnx``.

    The model is trained on the four iris feature columns and converted with a
    single float tensor input named ``float_input`` of shape ``[None, 4]``.
    """
    iris = datasets.load_iris()
    data = pd.DataFrame(
        data=np.c_[iris["data"], iris["target"]],
        columns=iris["feature_names"] + ["target"],
    )
    y = data["target"]
    x = data.drop("target", axis=1)

    model = LogisticRegression()
    model.fit(x, y)

    initial_type = [("float_input", FloatTensorType([None, 4]))]
    onx = onnxmltools.convert_sklearn(model, initial_types=initial_type)
    onnx.save_model(onx, "sklearn_model.onnx")


# Executed at module level on purpose: the script is meant to be run directly
# (see "Usage" in the module docstring) and regenerates all three resources.
generate_tf_onnx_model_multiple_inputs_float32()
generate_tf_onnx_model_multiple_inputs_float64()
generate_sklearn_onnx_model()