random_forest.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train Random Forest Estimator with H2O\n",
    "\n",
    "Trains H2O random forests with several values of `ntrees` on `wine-quality.csv` and logs the parameters, metrics and fitted model of each run to MLflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checking whether there is an H2O instance running at http://localhost:54321..... not found.\n",
      "Attempting to start a local H2O server...\n",
      " Java Version: openjdk version \"1.8.0_181\"; OpenJDK Runtime Environment (build 1.8.0_181-8u181-b13-2~deb9u1-b13); OpenJDK 64-Bit Server VM (build 25.181-b13, mixed mode)\n",
      " Starting server from /opt/conda/lib/python2.7/site-packages/h2o/backend/bin/h2o.jar\n",
      " Ice root: /tmp/tmpz8qTmm\n",
      " JVM stdout: /tmp/tmpz8qTmm/h2o_unknownUser_started_from_python.out\n",
      " JVM stderr: /tmp/tmpz8qTmm/h2o_unknownUser_started_from_python.err\n",
      " Server is running at http://127.0.0.1:54321\n",
      "Connecting to H2O server at http://127.0.0.1:54321... successful.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div style=\"overflow:auto\"><table style=\"width:50%\"><tr><td>H2O cluster uptime:</td>\n",
       "<td>01 secs</td></tr>\n",
       "<tr><td>H2O cluster timezone:</td>\n",
       "<td>Etc/UTC</td></tr>\n",
       "<tr><td>H2O data parsing timezone:</td>\n",
       "<td>UTC</td></tr>\n",
       "<tr><td>H2O cluster version:</td>\n",
       "<td>3.22.1.1</td></tr>\n",
       "<tr><td>H2O cluster version age:</td>\n",
       "<td>23 days </td></tr>\n",
       "<tr><td>H2O cluster name:</td>\n",
       "<td>H2O_from_python_unknownUser_ukj9f9</td></tr>\n",
       "<tr><td>H2O cluster total nodes:</td>\n",
       "<td>1</td></tr>\n",
       "<tr><td>H2O cluster free memory:</td>\n",
       "<td>3.042 Gb</td></tr>\n",
       "<tr><td>H2O cluster total cores:</td>\n",
       "<td>7</td></tr>\n",
       "<tr><td>H2O cluster allowed cores:</td>\n",
       "<td>7</td></tr>\n",
       "<tr><td>H2O cluster status:</td>\n",
       "<td>accepting new members, healthy</td></tr>\n",
       "<tr><td>H2O connection url:</td>\n",
       "<td>http://127.0.0.1:54321</td></tr>\n",
       "<tr><td>H2O connection proxy:</td>\n",
       "<td>None</td></tr>\n",
       "<tr><td>H2O internal security:</td>\n",
       "<td>False</td></tr>\n",
       "<tr><td>H2O API Extensions:</td>\n",
       "<td>XGBoost, Algos, AutoML, Core V3, Core V4</td></tr>\n",
       "<tr><td>Python version:</td>\n",
       "<td>2.7.15 final</td></tr></table></div>"
      ],
      "text/plain": [
       "-------------------------- ----------------------------------------\n",
       "H2O cluster uptime: 01 secs\n",
       "H2O cluster timezone: Etc/UTC\n",
       "H2O data parsing timezone: UTC\n",
       "H2O cluster version: 3.22.1.1\n",
       "H2O cluster version age: 23 days\n",
       "H2O cluster name: H2O_from_python_unknownUser_ukj9f9\n",
       "H2O cluster total nodes: 1\n",
       "H2O cluster free memory: 3.042 Gb\n",
       "H2O cluster total cores: 7\n",
       "H2O cluster allowed cores: 7\n",
       "H2O cluster status: accepting new members, healthy\n",
       "H2O connection url: http://127.0.0.1:54321\n",
       "H2O connection proxy:\n",
       "H2O internal security: False\n",
       "H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4\n",
       "Python version: 2.7.15 final\n",
       "-------------------------- ----------------------------------------"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parse progress: |█████████████████████████████████████████████████████████| 100%\n"
     ]
    }
   ],
   "source": [
    "import h2o\n",
    "from h2o.estimators.random_forest import H2ORandomForestEstimator\n",
    "\n",
    "import mlflow\n",
    "import mlflow.h2o\n",
    "\n",
    "# Fixed seed so the train/test split and the forests are the same on every re-run.\n",
    "SEED = 42\n",
    "\n",
    "h2o.init()\n",
    "\n",
    "wine = h2o.import_file(path=\"wine-quality.csv\")\n",
    "# One random number per row decides which partition the row goes to.\n",
    "r = wine[\"quality\"].runif(seed=SEED)\n",
    "train = wine[r < 0.7]\n",
    "# Exact complement of the training mask, so no row ends up in both partitions.\n",
    "test = wine[r >= 0.7]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_random_forest(ntrees, seed=SEED):\n",
    "    \"\"\"Train an H2O random forest on the wine data and log the run to MLflow.\n",
    "\n",
    "    Predicts ``quality`` from every other column of ``wine``, using ``train`` as\n",
    "    the training frame and ``test`` as the validation frame. Everything is\n",
    "    recorded inside a single new MLflow run.\n",
    "\n",
    "    :param ntrees: number of trees passed to ``H2ORandomForestEstimator``.\n",
    "    :param seed: random seed passed to the estimator (defaults to ``SEED``).\n",
    "    \"\"\"\n",
    "    with mlflow.start_run():\n",
    "        rf = H2ORandomForestEstimator(ntrees=ntrees, seed=seed)\n",
    "        train_cols = [n for n in wine.col_names if n != \"quality\"]\n",
    "        rf.train(train_cols, \"quality\", training_frame=train, validation_frame=test)\n",
    "\n",
    "        mlflow.log_param(\"ntrees\", ntrees)\n",
    "        mlflow.log_param(\"seed\", seed)\n",
    "\n",
    "        # Called without arguments, these accessors return the training metrics.\n",
    "        mlflow.log_metric(\"rmse\", rf.rmse())\n",
    "        mlflow.log_metric(\"r2\", rf.r2())\n",
    "        mlflow.log_metric(\"mae\", rf.mae())\n",
    "\n",
    "        # Also record the same metrics measured on the validation frame.\n",
    "        mlflow.log_metric(\"val_rmse\", rf.rmse(valid=True))\n",
    "        mlflow.log_metric(\"val_r2\", rf.r2(valid=True))\n",
    "        mlflow.log_metric(\"val_mae\", rf.mae(valid=True))\n",
    "\n",
    "        mlflow.h2o.log_model(rf, name=\"model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "drf Model Build progress: |███████████████████████████████████████████████| 100%\n",
      "drf Model Build progress: |███████████████████████████████████████████████| 100%\n",
      "drf Model Build progress: |███████████████████████████████████████████████| 100%\n",
      "drf Model Build progress: |███████████████████████████████████████████████| 100%\n",
      "drf Model Build progress: |███████████████████████████████████████████████| 100%\n"
     ]
    }
   ],
   "source": [
    "# One MLflow run per forest size, so the runs can be compared side by side.\n",
    "for ntrees in [10, 20, 50, 100, 200]:\n",
    "    train_random_forest(ntrees)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function yaml.safe_dump>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "yaml.safe_dump"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}