random_forest.py
"""Train H2O random forests on the wine-quality data, tracking each run in MLflow.

Running this module as a script trains one forest per tree count in a small
sweep and records the parameter, three error metrics, and the fitted model
for every run.
"""

import h2o
from h2o.estimators.random_forest import H2ORandomForestEstimator

import mlflow
import mlflow.h2o

# Rows whose random draw falls below this threshold go to the training frame;
# every other row goes to the hold-out frame.
TRAIN_FRACTION = 0.7

h2o.init()

wine = h2o.import_file(path="wine-quality.csv")

# One random draw per row, used to assign each row to exactly one split.
r = wine["quality"].runif()
train = wine[r < TRAIN_FRACTION]
# The hold-out mask must be the exact complement of the training mask.
# Using a lower cut-off here (e.g. 0.3) would put the rows in between into
# both frames and leak training data into validation.
test = wine[r >= TRAIN_FRACTION]


def train_random_forest(ntrees):
    """Fit a random forest with ``ntrees`` trees and log it to a new MLflow run.

    The model predicts the ``quality`` column from every other column of the
    module-level ``wine`` frame, training on ``train`` and validating on
    ``test``.

    Args:
        ntrees: Number of trees passed to ``H2ORandomForestEstimator``.

    Returns:
        None. The run's parameter, metrics and model are recorded via MLflow.
    """
    # Every column except the response is used as a predictor.
    predictors = [name for name in wine.col_names if name != "quality"]

    with mlflow.start_run():
        rf = H2ORandomForestEstimator(ntrees=ntrees)
        rf.train(
            x=predictors,
            y="quality",
            training_frame=train,
            validation_frame=test,
        )

        mlflow.log_param("ntrees", ntrees)

        # NOTE(review): the metric accessors are called without arguments, so
        # they return whatever H2O reports by default rather than explicitly
        # requesting validation metrics -- confirm that this is intended.
        mlflow.log_metric("rmse", rf.rmse())
        mlflow.log_metric("r2", rf.r2())
        mlflow.log_metric("mae", rf.mae())

        mlflow.h2o.log_model(rf, name="model")


if __name__ == "__main__":
    for ntrees in [10, 20, 50, 100, 200]:
        train_random_forest(ntrees)