/ src / chunkers / tokenizers / gemini.py
gemini.py
  1  from __future__ import annotations
  2  
  3  """Gemini tokenizer implementation."""
  4  
  5  from typing import Any
  6  
  7  from google import genai
  8  from pydantic import ConfigDict
  9  
 10  from .base_tokenizer import Tokenizer
 11  
 12  
 13  class GeminiTokenizer(Tokenizer):
 14      """Custom tokenizer for Docling's HybridChunker that uses Gemini's token counting API.
 15  
 16      This tokenizer integrates with Google's Gemini API to accurately count tokens
 17      for text chunks, with a fast local fallback for performance.
 18      """
 19  
 20      model_config = ConfigDict(extra='allow', arbitrary_types_allowed=True)
 21  
 22      def __init__(
 23          self,
 24          model: str = "gemini-embedding-001",
 25          max_tokens: int = 2048,
 26          google_api_key: str | None = None,
 27          chars_per_token_ratio: float | None = None,
 28          split_buffer_size: int | None = None,
 29          **kwargs
 30      ):
 31          """Initialize the Gemini tokenizer.
 32  
 33          Parameters
 34          ----------
 35          model
 36              Gemini model name to use for token counting (default: gemini-embedding-001).
 37              Note: Some models like "models/embedding-001" don't support countTokens API.
 38          max_tokens
 39              Maximum number of tokens per chunk (default: 2048).
 40          google_api_key
 41              Optional Google API key. Used only if client is not provided.
 42              If not provided, uses GOOGLE_API_KEY env var.
 43          chars_per_token_ratio
 44              Ratio of characters to tokens for threshold estimation.
 45          split_buffer_size
 46              Number of words to buffer before checking limits.
 47          """
 48          super().__init__(
 49              chars_per_token_ratio=chars_per_token_ratio,
 50              split_buffer_size=split_buffer_size,
 51              **kwargs
 52          )
 53          self.client = genai.Client(api_key=google_api_key)
 54          self.model = model
 55          self.max_tokens = max_tokens
 56  
 57      def count_tokens(self, text: str) -> int:
 58          """Count tokens using Gemini's API with fallback to local estimation.
 59  
 60          Parameters
 61          ----------
 62          text
 63              The text to count tokens for.
 64  
 65          Returns
 66          -------
 67          Number of tokens (from API or estimated).
 68          """
 69          response = self.client.models.count_tokens(model=self.model, contents=text)
 70          return response.total_tokens
 71  
 72      def get_max_tokens(self) -> int:
 73          """Returns the maximum tokens allowed per chunk.
 74  
 75          Returns
 76          -------
 77          Maximum tokens.
 78          """
 79          return self.max_tokens
 80  
 81      def _hash_attributes(self) -> tuple:
 82          """Return hashable attributes that uniquely identify this tokenizer.
 83  
 84          Returns
 85          -------
 86          Tuple containing class type, model name, and configuration attributes.
 87          Note: Does not include client or google_api_key as these are implementation
 88          details, not part of the tokenizer's identity.
 89          """
 90          return (
 91              type(self),
 92              self.model,
 93              self.max_tokens,
 94              self.chars_per_token_ratio,
 95              self.split_buffer_size,
 96          )
 97  
 98  def create_gemini_tokenizer(config: dict[str, Any]) -> Tokenizer:
 99      """Create a Gemini tokenizer from configuration.
100  
101      Parameters
102      ----------
103      config
104          Configuration dictionary with keys:
105          - max_tokens: int (optional) - Maximum tokens per chunk (default: 2048)
106          - google_api_key: str (optional) - Google API key for token counting
107          - model: str (optional) - Gemini model name for token counting (default: gemini-embedding-001)
108          - chars_per_token_ratio: float (optional) - Char-to-token ratio for threshold (default: 1.5)
109          - split_buffer_size: int (optional) - Words to buffer before checking limits (default: 5)
110  
111      Returns
112      -------
113      Tokenizer instance.
114  
115      Raises
116      ------
117      ValueError
118          If invalid configuration values are provided.
119      """
120      max_tokens = config.get("max_tokens", 2048)
121      google_api_key = config.get("llm_api_key")
122      tokenizer_model = config.get("llm_model", "gemini-embedding-001")
123      chars_per_token_ratio = config.get("chars_per_token_ratio")
124      split_buffer_size = config.get("split_buffer_size")
125  
126      return GeminiTokenizer(
127          model=tokenizer_model,
128          max_tokens=max_tokens,
129          google_api_key=google_api_key,
130          chars_per_token_ratio=chars_per_token_ratio,
131          split_buffer_size=split_buffer_size,
132      )
133