baseball.py
"""
Baseball statistics application with txtai and Streamlit.

Install txtai and streamlit to run:
  pip install txtai streamlit

Requires a recent streamlit release (st.query_params, for example, was added in streamlit 1.30).
"""

import datetime
import math
import os
import random

import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

from txtai import Embeddings


class Stats:
    """
    Base stats class. Contains methods for loading, indexing and searching baseball stats.

    Subclasses provide the data specific parts: loadcolumns, load, metric and vector.
    """

    def __init__(self):
        """
        Creates a new Stats instance. Loads the raw stats, builds the player name lookup and
        indexes every player-season.
        """

        # Load columns
        self.columns = self.loadcolumns()

        # Load stats data
        self.stats = self.load()

        # Load names
        self.names = self.loadnames()

        # Build index
        self.vectors, self.data, self.maxyear, self.embeddings = self.index()

    def loadcolumns(self):
        """
        Returns a list of data columns. Each column is one dimension of a player-season vector.

        Returns:
            list of columns
        """

        raise NotImplementedError

    def load(self):
        """
        Loads and returns raw stats.

        Returns:
            stats
        """

        raise NotImplementedError

    def metric(self):
        """
        Primary metric column.

        Returns:
            metric column name
        """

        raise NotImplementedError

    def vector(self, row):
        """
        Build a vector for input row.

        Args:
            row: input row

        Returns:
            row vector
        """

        raise NotImplementedError

    def loadnames(self):
        """
        Loads a player name lookup dictionary.

        Players are visited in descending order of the primary metric. When two players share a
        display name, the first one visited keeps the plain name and later ones get a
        " (playerID)" suffix.

        Returns:
            {player name: (player id, random selection weight)}
        """

        # Unique names, best primary metric first
        rows = (
            self.stats.sort_values(by=self.metric(), ascending=False)[["nameFirst", "nameLast", "playerID"]]
            .drop_duplicates()
            .reset_index()
        )

        # Number of seasons per player. Counted once up front instead of re-filtering the full
        # stats table for every player, which made this loop quadratic.
        seasons = self.stats["playerID"].value_counts().to_dict()

        names, total = {}, len(rows)
        for x, row in rows.iterrows():
            # Name key
            key = f"{row['nameFirst']} {row['nameLast']}"
            key += f" ({row['playerID']})" if key in names else ""

            if key not in names:
                # Scale scores of the top 5% of players (rows are ordered by primary metric)
                exponent = 2 if ((total - x) / total) >= 0.95 else 1

                # score = num seasons ^ exponent
                score = math.pow(seasons[row["playerID"]], exponent)

                # Save name key - values pair
                names[key] = (row["playerID"], score)

        return names

    def index(self):
        """
        Builds an embeddings index over the stats data.

        Returns:
            vectors, data, maxyear, embeddings
              - vectors: {"{yearID}{playerID}": vector}
              - data: {"{yearID}{playerID}": stats row as a dict}
              - maxyear: most recent season in the data
              - embeddings: txtai index over vectors
        """

        # Build vector and data dictionaries in a single pass over the stats rows
        vectors, data = {}, {}
        for _, row in self.stats.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'
            vectors[uid] = self.transform(row)
            data[uid] = dict(row)

        # Most recent season, used to apply "last N seasons" windows
        maxyear = int(self.stats["yearID"].max())

        # NOTE(review): Stats.transform is passed unbound (it expects self). Indexing and search only
        # supply precomputed numpy vectors, so this relies on txtai not calling it for arrays — confirm
        embeddings = Embeddings({"transform": Stats.transform})
        embeddings.index((uid, vectors[uid], None) for uid in vectors)

        return vectors, data, maxyear, embeddings

    def metrics(self, name):
        """
        Looks up a player's active years, best statistical year and key metrics.

        Args:
            name: player name

        Returns:
            active, best, metrics. Unknown names return a default year range, 1950 and None.
        """

        if name in self.names:
            # Get player stats
            stats = self.stats[self.stats["playerID"] == self.names[name][0]]

            # Build key metrics
            metrics = stats[["yearID", self.metric()]]

            # Get best year, sort by primary metric
            best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0])

            # Get years active, best year, along with metric trends
            return metrics["yearID"].tolist(), best, metrics

        return range(1871, datetime.datetime.today().year), 1950, None

    def search(self, name=None, year=None, window=None, row=None, limit=10):
        """
        Runs an embeddings search. This method takes either a player-year or stats row as input.

        Args:
            name: player name to search
            year: year to search
            window: limit to window recent seasons
            row: row of stats to search
            limit: max results to return

        Returns:
            list of results
        """

        if row:
            # Build a query vector from raw stats
            query = self.vector(row)
        else:
            # Lookup player and the precomputed vector for that player-season
            player = self.names.get(name)
            query = self.vectors.get(f"{year}{player[0]}") if player else None

        results, ids = [], set()
        if query is not None:
            # Pull more candidates when a window is set, since many are filtered out by year
            candidates = limit * 100 if window else limit * 5
            for uid, _ in self.embeddings.search(query, candidates):
                # Ids are "{yearID}{playerID}" with a 4 digit year, so this is the player id
                playerid = uid[4:]

                # Only add unique players
                if playerid not in ids:
                    result = self.data[uid].copy()

                    # Add first player if this is a player comparison. Limit results to window, if necessary
                    if (not ids and not row) or not window or result["yearID"] > self.maxyear - window:
                        result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
                        results.append(result)
                        ids.add(playerid)

                    if len(ids) >= limit:
                        break

        return results

    def transform(self, row):
        """
        Transforms a stats row into a vector. Missing, zero and NaN values become 0.0.

        Args:
            row: stats row (arrays are returned unchanged)

        Returns:
            vector
        """

        if isinstance(row, np.ndarray):
            return row

        return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns])


class Batting(Stats):
    """
    Batting stats.
    """

    # Fielding position abbreviation -> numeric position code. Other values pass through unchanged.
    POSITIONS = {"P": 1, "C": 2, "1B": 3, "2B": 4, "3B": 5, "SS": 6, "OF": 7}

    def loadcolumns(self):
        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "G",
            "AB",
            "R",
            "H",
            "1B",
            "2B",
            "3B",
            "HR",
            "RBI",
            "SB",
            "CS",
            "BB",
            "SO",
            "IBB",
            "HBP",
            "SH",
            "SF",
            "GIDP",
            "POS",
            "AVG",
            "OBP",
            "TB",
            "SLG",
            "OPS",
            "OPS+",
        ]

    def load(self):
        # Retrieve raw data
        players = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/People.csv")
        batting = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Batting.csv")
        fielding = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Fielding.csv")

        # Merge player data in
        batting = pd.merge(players, batting, how="inner", on=["playerID"])

        # Require at least 350 AB + BB (a plate appearance approximation) and keep first stints only.
        # Copy so the calculated columns below are added to an independent frame, not a slice.
        batting = batting[((batting["AB"] + batting["BB"]) >= 350) & (batting["stint"] == 1)].copy()

        # Derive primary player positions
        positions = self.positions(fielding)

        # Calculated columns
        batting["age"] = batting["yearID"] - batting["birthYear"]
        batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1)
        batting["AVG"] = batting["H"] / batting["AB"]
        batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
        batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
        batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
        batting["SLG"] = batting["TB"] / batting["AB"]
        batting["OPS"] = batting["OBP"] + batting["SLG"]
        batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100

        return batting

    def metric(self):
        return "OPS+"

    def vector(self, row):
        """
        Builds a vector from user-entered batting stats, deriving the calculated columns first.

        Zero or missing denominators yield NaN (mapped to 0.0 by transform) instead of raising
        ZeroDivisionError.

        Args:
            row: dict of batting stats, updated in place with calculated columns

        Returns:
            row vector
        """

        row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]
        row["AVG"] = self._divide(row["H"], row["AB"])
        row["OBP"] = self._divide(row["H"] + row["BB"], row["AB"] + row["BB"])
        row["SLG"] = self._divide(row["TB"], row["AB"])
        row["OPS"] = row["OBP"] + row["SLG"]
        row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100

        return self.transform(row)

    @staticmethod
    def _divide(numerator, denominator):
        """
        Divides numerator by denominator, returning NaN when the denominator is zero.
        """

        return numerator / denominator if denominator else float("nan")

    def positions(self, fielding):
        """
        Derives primary positions for players.

        Args:
            fielding: fielding data

        Returns:
            {"{yearID}{playerID}": (position, number of games)}, keeping the position with the most games
        """

        positions = {}
        columns = fielding[["yearID", "playerID", "POS", "G"]]
        for yearid, playerid, pos, games in columns.itertuples(index=False, name=None):
            uid = f"{yearid}{playerid}"

            # Map position abbreviation to a numeric code, empty values become 0
            position = pos if pos else 0
            position = Batting.POSITIONS.get(position, position)

            # Save position if not set or player played more at this position
            if uid not in positions or positions[uid][1] < games:
                positions[uid] = (position, games)

        return positions

    def position(self, positions, row):
        """
        Looks up primary position for player row.

        Args:
            positions: all player positions
            row: player row

        Returns:
            primary player position, 0 if unknown
        """

        uid = f'{row["yearID"]}{row["playerID"]}'
        return positions[uid][0] if uid in positions else 0


class Pitching(Stats):
    """
    Pitching stats.
    """

    def loadcolumns(self):
        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "W",
            "L",
            "G",
            "GS",
            "CG",
            "SHO",
            "SV",
            "IPouts",
            "H",
            "ER",
            "HR",
            "BB",
            "SO",
            "BAOpp",
            "ERA",
            "IBB",
            "WP",
            "HBP",
            "BK",
            "BFP",
            "GF",
            "R",
            "SH",
            "SF",
            "GIDP",
            "WHIP",
            "WADJ",
        ]

    def load(self):
        # Retrieve raw data
        players = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/People.csv")
        pitching = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Pitching.csv")

        # Merge player data in
        pitching = pd.merge(players, pitching, how="inner", on=["playerID"])

        # Require player to have 20 appearances, first stints only. Copy so the calculated
        # columns below are added to an independent frame, not a slice.
        pitching = pitching[(pitching["G"] >= 20) & (pitching["stint"] == 1)].copy()

        # Calculated columns
        pitching["age"] = pitching["yearID"] - pitching["birthYear"]
        pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
        pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])

        return pitching

    def metric(self):
        return "WADJ"

    def vector(self, row):
        # Calculated columns, None when inputs are missing or zero (transform maps None to 0.0)
        row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None
        row["WADJ"] = (row["W"] + row["SV"]) / (row["ERA"] + row["WHIP"]) if row["ERA"] and row["WHIP"] else None

        return self.transform(row)


class Application:
    """
    Main application.
    """

    def __init__(self):
        """
        Creates a new application.
        """

        # Batting stats
        self.batting = Batting()

        # Pitching stats
        self.pitching = Pitching()

    def run(self):
        """
        Runs a Streamlit application.
        """

        st.set_page_config(layout="wide", page_title="Baseball Stats", page_icon="⚾")
        st.title("⚾ Baseball Stats")
        st.markdown(
            """
        This application finds the best matching players using vector search with [txtai](https://github.com/neuml/txtai).
        Raw data is from the [Lahman Baseball Database](https://sabr.org/lahman-database/). Read [this
        article](https://medium.com/neuml/explore-baseball-history-with-vector-search-5778d98d6846) for more details.
        """
        )

        player, search = st.tabs(["Player", "Search"])

        # Player tab
        with player:
            self.player()

        # Search
        with search:
            self.search()

    def player(self):
        """
        Player tab.
        """

        st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")

        # Get parameters
        params = self.params()

        # Category and stats
        category = self.category(params.get("category"), "category")
        stats = self.batting if category == "Batting" else self.pitching

        # Player name
        name = self.name(stats.names, params.get("name"))

        # Limit player-year comparisons using this window
        window = self.window(params.get("window"), "window")

        # Player metrics
        active, best, metrics = stats.metrics(name)

        # Player year
        year = self.year(active, params.get("year"), best)

        # Display metrics chart. metrics is None when the player wasn't found.
        if len(active) > 1 and metrics is not None:
            self.chart(category, metrics)

        # Run search
        results = stats.search(name, year, window)

        # Display results
        self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])

        # Save query parameters
        st.query_params.clear()
        for key, value in [("category", category), ("name", name), ("window", window), ("year", year)]:
            if value:
                st.query_params[key] = value

    def search(self):
        """
        Stats search tab.
        """

        st.markdown("Find players with similar statistics.")

        stats, category = None, self.category("Batting", "searchcategory")

        # Limit player-year comparisons using this window
        window = self.window(None, "searchwindow")

        with st.form("search"):
            # Input columns exclude the calculated columns at the end of each column list
            if category == "Batting":
                stats, columns = self.batting, self.batting.columns[:-6]
            elif category == "Pitching":
                stats, columns = self.pitching, self.pitching.columns[:-2]

            # Enter stats with data editor
            inputs = st.data_editor(pd.DataFrame([dict((column, None) for column in columns)]), hide_index=True).astype(float)

            submitted = st.form_submit_button("Search")
            if submitted:
                # Run search
                results = stats.search(window=window, row=inputs.to_dict(orient="records")[0])

                # Display table
                self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])

    def params(self):
        """
        Get application parameters. This method combines URL parameters with session parameters.

        Returns:
            parameters
        """

        # Get query parameters
        params = {x: st.query_params.get(x) for x in ["category", "name", "window", "year"]}

        # Sync parameters with session state
        if all(x in st.session_state for x in ["category", "name", "year"]):
            # Copy session year if category and name are unchanged
            params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None

            # Copy category, name and window from session state
            params["category"] = st.session_state["category"]
            params["name"] = st.session_state["name"]
            params["window"] = st.session_state["window"]

        return params

    def category(self, category, key):
        """
        Builds category input widget.

        Args:
            category: category parameter
            key: widget key

        Returns:
            category component
        """

        # List of stat categories
        categories = ["Batting", "Pitching"]

        # Get category parameter, default if not available or valid
        default = categories.index(category) if category and category in categories else 0

        # Radio box component
        return st.radio("Stat", categories, index=default, horizontal=True, key=key)

    def window(self, window, key):
        """
        Limit results to last N seasons.

        Args:
            window: limit to window seasons
            key: widget key

        Returns:
            window component, None when empty or invalid
        """

        # Get window size
        window = st.text_input("Limit to last N seasons", value=window, key=key)

        # Clear invalid input. isdecimal only accepts characters that int() can parse.
        if window and (not window.isdecimal() or int(window) < 1):
            st.error("Window must be a number greater or equal to 1")
            return None

        # Convert to int
        return int(window) if window else window

    def name(self, names, name):
        """
        Builds name input widget.

        Args:
            names: all allowable names, {name: (player id, selection weight)}
            name: name parameter, a weighted random name is used when missing or invalid

        Returns:
            name component
        """

        # Get name parameter, default to random weighted value if not valid
        name = name if name and name in names else random.choices(list(names.keys()), weights=[names[x][1] for x in names])[0]

        # Sort names for display
        names = sorted(names)

        # Select box component
        return st.selectbox("Name", names, names.index(name), key="name")

    def year(self, years, year, best):
        """
        Builds year input widget.

        Args:
            years: active years for a player
            year: year parameter
            best: default to best year if year is invalid

        Returns:
            year component
        """

        # Get year parameter, default if not available or valid. isdecimal only accepts
        # characters that int() can parse (isdigit also accepts e.g. superscripts).
        year = int(year) if year and year.isdecimal() and int(year) in years else best

        # Slider component
        return int(st.select_slider("Year", years, year, key="year") if len(years) > 1 else years[0])

    def chart(self, category, metrics):
        """
        Displays a metric chart.

        Args:
            category: Batting or Pitching
            metrics: player metrics to plot, not modified
        """

        # Key metric
        metric = self.batting.metric() if category == "Batting" else self.pitching.metric()

        # Cast year to string so years plot as labels. assign builds a new frame instead of
        # writing into the caller's (sliced) DataFrame.
        metrics = metrics.assign(yearID=metrics["yearID"].astype(str))

        # Metric over years
        chart = (
            alt.Chart(metrics)
            .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
            .encode(x=alt.X("yearID", title=""), y=alt.Y(metric, scale=alt.Scale(zero=False)))
        )

        # Create metric median rule line
        rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})")

        # Layered chart configuration
        chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False)

        # Draw chart. The rule is already layered in above, so it isn't added a second time.
        st.altair_chart(chart, theme="streamlit", width="stretch")

    def table(self, results, columns):
        """
        Displays a list of results as a table.

        Args:
            results: list of results
            columns: column names
        """

        if results:
            st.dataframe(
                results,
                column_order=columns,
                column_config={
                    "link": st.column_config.LinkColumn("Link", width="small", display_text=":material/open_in_new:"),
                    "yearID": st.column_config.NumberColumn("Year", format="%d"),
                    "nameFirst": "First",
                    "nameLast": "Last",
                    "teamID": "Team",
                    "age": "Age",
                    "weight": "Weight",
                    "height": "Height",
                },
            )
        else:
            st.write("Player-Year not found")


@st.cache_resource(show_spinner=False)
def create():
    """
    Creates and caches a Streamlit application.

    Returns:
        Application
    """

    return Application()


if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Create and run application
    app = create()
    app.run()