/ examples / baseball.py
baseball.py
  1  """
  2  Baseball statistics application with txtai and Streamlit.
  3  
  4  Install txtai and streamlit (>= 1.23) to run:
  5    pip install txtai streamlit
  6  """
  7  
  8  import datetime
  9  import math
 10  import os
 11  import random
 12  
 13  import altair as alt
 14  import numpy as np
 15  import pandas as pd
 16  import streamlit as st
 17  
 18  from txtai import Embeddings
 19  
 20  
 21  class Stats:
 22      """
 23      Base stats class. Contains methods for loading, indexing and searching baseball stats.
 24      """
 25  
 26      def __init__(self):
 27          """
 28          Creates a new Stats instance.
 29          """
 30  
 31          # Load columns
 32          self.columns = self.loadcolumns()
 33  
 34          # Load stats data
 35          self.stats = self.load()
 36  
 37          # Load names
 38          self.names = self.loadnames()
 39  
 40          # Build index
 41          self.vectors, self.data, self.maxyear, self.embeddings = self.index()
 42  
 43      def loadcolumns(self):
 44          """
 45          Returns a list of data columns.
 46  
 47          Returns:
 48              list of columns
 49          """
 50  
 51          raise NotImplementedError
 52  
 53      def load(self):
 54          """
 55          Loads and returns raw stats.
 56  
 57          Returns:
 58              stats
 59          """
 60  
 61          raise NotImplementedError
 62  
 63      def metric(self):
 64          """
 65          Primary metric column.
 66  
 67          Returns:
 68              metric column name
 69          """
 70  
 71          raise NotImplementedError
 72  
 73      def vector(self, row):
 74          """
 75          Build a vector for input row.
 76  
 77          Args:
 78              row: input row
 79  
 80          Returns:
 81              row vector
 82          """
 83  
 84          raise NotImplementedError
 85  
 86      def loadnames(self):
 87          """
 88          Loads a name - player id dictionary.
 89  
 90          Returns:
 91              {player name: player id}
 92          """
 93  
 94          # Get unique names
 95          names = {}
 96          rows = self.stats.sort_values(by=self.metric(), ascending=False)[["nameFirst", "nameLast", "playerID"]].drop_duplicates().reset_index()
 97          for x, row in rows.iterrows():
 98              # Name key
 99              key = f"{row['nameFirst']} {row['nameLast']}"
100              key += f" ({row['playerID']})" if key in names else ""
101  
102              if key not in names:
103                  # Scale scores of top n players
104                  exponent = 2 if ((len(rows) - x) / len(rows)) >= 0.95 else 1
105  
106                  # score = num seasons ^ exponent
107                  score = math.pow(len(self.stats[self.stats["playerID"] == row["playerID"]]), exponent)
108  
109                  # Save name key - values pair
110                  names[key] = (row["playerID"], score)
111  
112          return names
113  
114      def index(self):
115          """
116          Builds an embeddings index to stats data. Returns vectors, input data and embeddings index.
117  
118          Returns:
119              vectors, data, embeddings
120          """
121  
122          # Build data dictionary
123          vectors = {f'{row["yearID"]}{row["playerID"]}': self.transform(row) for _, row in self.stats.iterrows()}
124          data = {f'{row["yearID"]}{row["playerID"]}': dict(row) for _, row in self.stats.iterrows()}
125          maxyear = max(row["yearID"] for _, row in self.stats.iterrows())
126  
127          embeddings = Embeddings({"transform": Stats.transform})
128          embeddings.index((uid, vectors[uid], None) for uid in vectors)
129  
130          return vectors, data, maxyear, embeddings
131  
132      def metrics(self, name):
133          """
134          Looks up a player's active years, best statistical year and key metrics.
135  
136          Args:
137              name: player name
138  
139          Returns:
140              active, best, metrics
141          """
142  
143          if name in self.names:
144              # Get player stats
145              stats = self.stats[self.stats["playerID"] == self.names[name][0]]
146  
147              # Build key metrics
148              metrics = stats[["yearID", self.metric()]]
149  
150              # Get best year, sort by primary metric
151              best = int(stats.sort_values(by=self.metric(), ascending=False)["yearID"].iloc[0])
152  
153              # Get years active, best year, along with metric trends
154              return metrics["yearID"].tolist(), best, metrics
155  
156          return range(1871, datetime.datetime.today().year), 1950, None
157  
158      def search(self, name=None, year=None, window=None, row=None, limit=10):
159          """
160          Runs an embeddings search. This method takes either a player-year or stats row as input.
161  
162          Args:
163              name: player name to search
164              year: year to search
165              window: limit to window recent seasons
166              row: row of stats to search
167              limit: max results to return
168  
169          Returns:
170              list of results
171          """
172  
173          if row:
174              query = self.vector(row)
175          else:
176              # Lookup player key and build vector id
177              name = self.names.get(name)
178              query = f"{year}{name[0] if name else name}"
179              query = self.vectors.get(query)
180  
181          results, ids = [], set()
182          if query is not None:
183              candidates = limit * 100 if window else limit * 5
184              for uid, _ in self.embeddings.search(query, candidates):
185                  # Only add unique players
186                  if uid[4:] not in ids:
187                      result = self.data[uid].copy()
188  
189                      # Add first player if this is a player comparison. Limit results to window, if necessary
190                      if (not ids and not row) or not window or result["yearID"] > self.maxyear - window:
191                          result["link"] = f'https://www.baseball-reference.com/players/{result["nameLast"].lower()[0]}/{result["bbrefID"]}.shtml'
192                          results.append(result)
193                          ids.add(uid[4:])
194  
195                      if len(ids) >= limit:
196                          break
197  
198          return results
199  
200      def transform(self, row):
201          """
202          Transforms a stats row into a vector.
203  
204          Args:
205              row: stats row
206  
207          Returns:
208              vector
209          """
210  
211          if isinstance(row, np.ndarray):
212              return row
213  
214          return np.array([0.0 if not row[x] or np.isnan(row[x]) else row[x] for x in self.columns])
215  
216  
class Batting(Stats):
    """
    Batting stats.
    """

    def loadcolumns(self):
        """
        Returns the batting columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "G",
            "AB",
            "R",
            "H",
            "1B",
            "2B",
            "3B",
            "HR",
            "RBI",
            "SB",
            "CS",
            "BB",
            "SO",
            "IBB",
            "HBP",
            "SH",
            "SF",
            "GIDP",
            "POS",
            "AVG",
            "OBP",
            "TB",
            "SLG",
            "OPS",
            "OPS+",
        ]

    def load(self):
        """
        Loads batting seasons merged with player data, plus calculated columns.

        Returns:
            batting stats DataFrame
        """

        # Retrieve raw data
        players = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/People.csv")
        batting = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Batting.csv")
        fielding = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Fielding.csv")

        # Merge player data in
        batting = pd.merge(players, batting, how="inner", on=["playerID"])

        # Require player to have at least 350 plate appearances (approximated as AB + BB), first stint only.
        # Copy so the calculated columns below are added to an independent frame, not a filtered view.
        batting = batting[((batting["AB"] + batting["BB"]) >= 350) & (batting["stint"] == 1)].copy()

        # Derive primary player positions
        positions = self.positions(fielding)

        # Calculated columns
        batting["age"] = batting["yearID"] - batting["birthYear"]
        batting["POS"] = batting.apply(lambda row: self.position(positions, row), axis=1)
        batting["AVG"] = batting["H"] / batting["AB"]
        # Simplified on-base percentage: (H + BB) / (AB + BB)
        batting["OBP"] = (batting["H"] + batting["BB"]) / (batting["AB"] + batting["BB"])
        batting["1B"] = batting["H"] - batting["2B"] - batting["3B"] - batting["HR"]
        batting["TB"] = batting["1B"] + 2 * batting["2B"] + 3 * batting["3B"] + 4 * batting["HR"]
        batting["SLG"] = batting["TB"] / batting["AB"]
        batting["OPS"] = batting["OBP"] + batting["SLG"]
        # Relative OPS: 100 + (OPS - mean OPS of all loaded rows) * 100
        batting["OPS+"] = 100 + (batting["OPS"] - batting["OPS"].mean()) * 100

        return batting

    def metric(self):
        """
        Returns:
            primary batting metric column
        """

        return "OPS+"

    def vector(self, row):
        """
        Adds calculated batting columns to row (in place) and transforms it into a vector.

        Uses the same formulas as load. A ratio whose denominator is zero is set to None (transform maps it to
        0.0) instead of raising ZeroDivisionError. This matches how Pitching.vector guards its ratios.

        Args:
            row: dict of raw batting stats

        Returns:
            row vector
        """

        row["TB"] = row["1B"] + 2 * row["2B"] + 3 * row["3B"] + 4 * row["HR"]

        atbats, appearances = row["AB"], row["AB"] + row["BB"]
        row["AVG"] = row["H"] / atbats if atbats else None
        row["OBP"] = (row["H"] + row["BB"]) / appearances if appearances else None
        row["SLG"] = row["TB"] / atbats if atbats else None
        row["OPS"] = row["OBP"] + row["SLG"] if row["OBP"] is not None and row["SLG"] is not None else None
        row["OPS+"] = 100 + (row["OPS"] - self.stats["OPS"].mean()) * 100 if row["OPS"] is not None else None

        return self.transform(row)

    def positions(self, fielding):
        """
        Derives primary positions for player seasons.

        Position codes: P=1, C=2, 1B=3, 2B=4, 3B=5, SS=6, OF=7. Any other or missing value maps to 0. For each
        player season, the position with the most games is kept (the first one seen wins ties).

        Args:
            fielding: fielding data

        Returns:
            {"{yearID}{playerID}": (position code, number of games)}
        """

        codes = {"P": 1, "C": 2, "1B": 3, "2B": 4, "3B": 5, "SS": 6, "OF": 7}

        positions = {}
        for _, row in fielding.iterrows():
            uid = f'{row["yearID"]}{row["playerID"]}'

            # Unknown or missing codes map to 0 so POS always stays numeric (transform calls np.isnan on it)
            position = codes.get(row["POS"], 0)

            # Save position if not set or player played more at this position
            if uid not in positions or positions[uid][1] < row["G"]:
                positions[uid] = (position, row["G"])

        return positions

    def position(self, positions, row):
        """
        Looks up primary position for player row.

        Args:
            positions: all player positions, as returned by positions()
            row: player row

        Returns:
            primary position code, 0 if not found
        """

        uid = f'{row["yearID"]}{row["playerID"]}'
        return positions[uid][0] if uid in positions else 0
347  
348  
class Pitching(Stats):
    """
    Pitching stats.
    """

    def loadcolumns(self):
        """
        Returns the pitching columns used to build vectors.

        Returns:
            list of columns
        """

        return [
            "birthMonth",
            "yearID",
            "age",
            "height",
            "weight",
            "W",
            "L",
            "G",
            "GS",
            "CG",
            "SHO",
            "SV",
            "IPouts",
            "H",
            "ER",
            "HR",
            "BB",
            "SO",
            "BAOpp",
            "ERA",
            "IBB",
            "WP",
            "HBP",
            "BK",
            "BFP",
            "GF",
            "R",
            "SH",
            "SF",
            "GIDP",
            "WHIP",
            "WADJ",
        ]

    def load(self):
        """
        Loads pitching seasons merged with player data, plus calculated columns.

        Returns:
            pitching stats DataFrame
        """

        # Retrieve raw data
        players = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/People.csv")
        pitching = pd.read_csv("https://hf.co/datasets/neuml/baseballdata/resolve/main/Pitching.csv")

        # Merge player data in
        pitching = pd.merge(players, pitching, how="inner", on=["playerID"])

        # Require player to have 20 appearances, first stint only. Copy so the calculated columns below are
        # added to an independent frame, not a filtered view.
        pitching = pitching[(pitching["G"] >= 20) & (pitching["stint"] == 1)].copy()

        # Calculated columns
        pitching["age"] = pitching["yearID"] - pitching["birthYear"]
        # WHIP = (BB + H) / innings pitched, where innings = IPouts / 3
        pitching["WHIP"] = (pitching["BB"] + pitching["H"]) / (pitching["IPouts"] / 3)
        # Wins adjusted = (W + SV) / (ERA + WHIP)
        pitching["WADJ"] = (pitching["W"] + pitching["SV"]) / (pitching["ERA"] + pitching["WHIP"])

        return pitching

    def metric(self):
        """
        Returns:
            primary pitching metric column
        """

        return "WADJ"

    def vector(self, row):
        """
        Adds calculated pitching columns to row (in place) and transforms it into a vector.

        Uses the same formulas as load. A ratio whose denominator is zero or missing is set to None (transform
        maps it to 0.0). WADJ is guarded on ERA + WHIP, so a 0.00 ERA with a non-zero WHIP still produces a
        value, matching load.

        Args:
            row: dict of raw pitching stats

        Returns:
            row vector
        """

        row["WHIP"] = (row["BB"] + row["H"]) / (row["IPouts"] / 3) if row["IPouts"] else None

        denominator = row["ERA"] + row["WHIP"] if row["ERA"] is not None and row["WHIP"] is not None else None
        row["WADJ"] = (row["W"] + row["SV"]) / denominator if denominator else None

        return self.transform(row)
416  
417  
class Application:
    """
    Main application.
    """

    def __init__(self):
        """
        Creates a new application. Loads and indexes batting and pitching stats.
        """

        # Batting stats
        self.batting = Batting()

        # Pitching stats
        self.pitching = Pitching()

    def run(self):
        """
        Runs a Streamlit application.
        """

        st.set_page_config(layout="wide", page_title="Baseball Stats", page_icon="⚾")
        st.title("⚾ Baseball Stats")
        st.markdown(
            """
            This application finds the best matching players using vector search with [txtai](https://github.com/neuml/txtai).
            Raw data is from the [Lahman Baseball Database](https://sabr.org/lahman-database/). Read [this
            article](https://medium.com/neuml/explore-baseball-history-with-vector-search-5778d98d6846) for more details.
            """
        )

        player, search = st.tabs(["Player", "Search"])

        # Player tab
        with player:
            self.player()

        # Search
        with search:
            self.search()

    def player(self):
        """
        Player tab. Finds player-seasons similar to a selected player and year, then saves the selection to the
        URL query parameters.
        """

        st.markdown("Match by player-season. Each player search defaults to the best season sorted by OPS or Wins Adjusted.")

        # Get parameters
        params = self.params()

        # Category and stats
        category = self.category(params.get("category"), "category")
        stats = self.batting if category == "Batting" else self.pitching

        # Player name
        name = self.name(stats.names, params.get("name"))

        # Limit player-year comparisons using this window
        window = self.window(params.get("window"), "window")

        # Player metrics
        active, best, metrics = stats.metrics(name)

        # Player year
        year = self.year(active, params.get("year"), best)

        # Display metrics chart
        if len(active) > 1:
            self.chart(category, metrics)

        # Run search
        results = stats.search(name, year, window)

        # Display results
        self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])

        # Save query parameters
        st.query_params.clear()
        for key, value in [("category", category), ("name", name), ("window", window), ("year", year)]:
            if value:
                st.query_params[key] = value

    def search(self):
        """
        Stats search tab. Finds players whose stats are similar to values entered in a data editor.
        """

        st.markdown("Find players with similar statistics.")

        stats, category = None, self.category("Batting", "searchcategory")

        # Limit player-year comparisons using this window
        window = self.window(None, "searchwindow")

        with st.form("search"):
            # Only raw columns are editable. The trailing calculated columns are derived by vector().
            if category == "Batting":
                stats, columns = self.batting, self.batting.columns[:-6]
            elif category == "Pitching":
                stats, columns = self.pitching, self.pitching.columns[:-2]

            # Enter stats with data editor
            inputs = st.data_editor(pd.DataFrame([dict((column, None) for column in columns)]), hide_index=True).astype(float)

            submitted = st.form_submit_button("Search")
            if submitted:
                # Run search
                results = stats.search(window=window, row=inputs.to_dict(orient="records")[0])

                # Display table
                self.table(results, ["link", "nameFirst", "nameLast", "teamID"] + stats.columns[1:])

    def params(self):
        """
        Get application parameters. This method combines URL parameters with session parameters.

        URL query parameters are used until the category, name and year widgets have session state. After that,
        category, name and window come from session state. The session year is kept only when the URL category
        and name still match the session (otherwise year is reset to None).

        Returns:
            parameters
        """

        # Get query parameters
        params = {x: st.query_params.get(x) for x in ["category", "name", "window", "year"]}

        # Sync parameters with session state
        if all(x in st.session_state for x in ["category", "name", "year"]):
            # Copy session year if category and name are unchanged
            params["year"] = str(st.session_state["year"]) if all(params.get(x) == st.session_state[x] for x in ["category", "name"]) else None

            # Copy category, name and window from session state
            params["category"] = st.session_state["category"]
            params["name"] = st.session_state["name"]
            params["window"] = st.session_state["window"]

        return params

    def category(self, category, key):
        """
        Builds category input widget.

        Args:
            category: category parameter
            key: widget key

        Returns:
            category component
        """

        # List of stat categories
        categories = ["Batting", "Pitching"]

        # Get category parameter, default if not available or valid
        default = categories.index(category) if category and category in categories else 0

        # Radio box component
        return st.radio("Stat", categories, index=default, horizontal=True, key=key)

    def window(self, window, key):
        """
        Limit results to last N seasons.

        Args:
            window: limit to window seasons
            key: widget key

        Returns:
            window as an int, or None when empty or invalid
        """

        # Get window size
        window = st.text_input("Limit to last N seasons", value=window, key=key)

        # Clear invalid input. isdecimal (unlike isnumeric/isdigit) only accepts characters int() can parse,
        # so input such as "²" shows the error instead of raising ValueError.
        if window and (not window.isdecimal() or int(window) < 1):
            st.error("Window must be a number greater or equal to 1")
            return None

        # Convert to int
        return int(window) if window else window

    def name(self, names, name):
        """
        Builds name input widget.

        Args:
            names: {player name: (player id, score)} of all allowable names
            name: name parameter; when missing or not in names, a name is chosen at random weighted by score

        Returns:
            name component
        """

        # Get name parameter, default to random weighted value if not valid
        name = name if name and name in names else random.choices(list(names.keys()), weights=[names[x][1] for x in names])[0]

        # Sort names for display
        names = sorted(names)

        # Select box component
        return st.selectbox("Name", names, names.index(name), key="name")

    def year(self, years, year, best):
        """
        Builds year input widget.

        Args:
            years: active years for a player
            year: year parameter (string, possibly from the URL)
            best: default to best year if year is invalid

        Returns:
            selected year as an int (the only active year when a player has a single season)
        """

        # Get year parameter, default if not available or valid. isdecimal guarantees int() can parse it.
        year = int(year) if year and year.isdecimal() and int(year) in years else best

        # Slider component
        return int(st.select_slider("Year", years, year, key="year") if len(years) > 1 else years[0])

    def chart(self, category, metrics):
        """
        Displays a metric chart.

        Args:
            category: Batting or Pitching
            metrics: player metrics to plot (DataFrame with yearID and metric columns; not modified)
        """

        # Key metric
        metric = self.batting.metric() if category == "Batting" else self.pitching.metric()

        # Cast year to string on a new frame so the caller's metrics are left unchanged
        metrics = metrics.assign(yearID=metrics["yearID"].astype(str))

        # Metric over years
        chart = (
            alt.Chart(metrics)
            .mark_line(interpolate="monotone", point=True, strokeWidth=2.5, opacity=0.75)
            .encode(x=alt.X("yearID", title=""), y=alt.Y(metric, scale=alt.Scale(zero=False)))
        )

        # Create metric median rule line
        rule = alt.Chart(metrics).mark_rule(color="gray", strokeDash=[3, 5], opacity=0.5).encode(y=f"median({metric})")

        # Layered chart configuration
        chart = (chart + rule).encode(y=alt.Y(title=metric)).properties(height=200).configure_axis(grid=False)

        # Draw chart. The median rule is already a layer, so it isn't added a second time.
        st.altair_chart(chart, theme="streamlit", width="stretch")

    def table(self, results, columns):
        """
        Displays a list of results as a table.

        Args:
            results: list of results
            columns: column names
        """

        if results:
            st.dataframe(
                results,
                column_order=columns,
                column_config={
                    "link": st.column_config.LinkColumn("Link", width="small", display_text=":material/open_in_new:"),
                    "yearID": st.column_config.NumberColumn("Year", format="%d"),
                    "nameFirst": "First",
                    "nameLast": "Last",
                    "teamID": "Team",
                    "age": "Age",
                    "weight": "Weight",
                    "height": "Height",
                },
            )
        else:
            st.write("Player-Year not found")
693  
694  
@st.cache_resource(show_spinner=False)
def create():
    """
    Builds the Application.

    Decorated with st.cache_resource (spinner hidden), so the application and the stats indexes it loads
    are cached rather than rebuilt on every Streamlit rerun.

    Returns:
        Application
    """

    application = Application()
    return application
705  
706  
if __name__ == "__main__":
    # NOTE(review): presumably disables Hugging Face tokenizers parallelism to avoid its fork warnings
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Build (or fetch the cached) application, then render it
    create().run()