# gemini.py
1 from __future__ import annotations 2 3 """Gemini tokenizer implementation.""" 4 5 from typing import Any 6 7 from google import genai 8 from pydantic import ConfigDict 9 10 from .base_tokenizer import Tokenizer 11 12 13 class GeminiTokenizer(Tokenizer): 14 """Custom tokenizer for Docling's HybridChunker that uses Gemini's token counting API. 15 16 This tokenizer integrates with Google's Gemini API to accurately count tokens 17 for text chunks, with a fast local fallback for performance. 18 """ 19 20 model_config = ConfigDict(extra='allow', arbitrary_types_allowed=True) 21 22 def __init__( 23 self, 24 model: str = "gemini-embedding-001", 25 max_tokens: int = 2048, 26 google_api_key: str | None = None, 27 chars_per_token_ratio: float | None = None, 28 split_buffer_size: int | None = None, 29 **kwargs 30 ): 31 """Initialize the Gemini tokenizer. 32 33 Parameters 34 ---------- 35 model 36 Gemini model name to use for token counting (default: gemini-embedding-001). 37 Note: Some models like "models/embedding-001" don't support countTokens API. 38 max_tokens 39 Maximum number of tokens per chunk (default: 2048). 40 google_api_key 41 Optional Google API key. Used only if client is not provided. 42 If not provided, uses GOOGLE_API_KEY env var. 43 chars_per_token_ratio 44 Ratio of characters to tokens for threshold estimation. 45 split_buffer_size 46 Number of words to buffer before checking limits. 47 """ 48 super().__init__( 49 chars_per_token_ratio=chars_per_token_ratio, 50 split_buffer_size=split_buffer_size, 51 **kwargs 52 ) 53 self.client = genai.Client(api_key=google_api_key) 54 self.model = model 55 self.max_tokens = max_tokens 56 57 def count_tokens(self, text: str) -> int: 58 """Count tokens using Gemini's API with fallback to local estimation. 59 60 Parameters 61 ---------- 62 text 63 The text to count tokens for. 64 65 Returns 66 ------- 67 Number of tokens (from API or estimated). 
68 """ 69 response = self.client.models.count_tokens(model=self.model, contents=text) 70 return response.total_tokens 71 72 def get_max_tokens(self) -> int: 73 """Returns the maximum tokens allowed per chunk. 74 75 Returns 76 ------- 77 Maximum tokens. 78 """ 79 return self.max_tokens 80 81 def _hash_attributes(self) -> tuple: 82 """Return hashable attributes that uniquely identify this tokenizer. 83 84 Returns 85 ------- 86 Tuple containing class type, model name, and configuration attributes. 87 Note: Does not include client or google_api_key as these are implementation 88 details, not part of the tokenizer's identity. 89 """ 90 return ( 91 type(self), 92 self.model, 93 self.max_tokens, 94 self.chars_per_token_ratio, 95 self.split_buffer_size, 96 ) 97 98 def create_gemini_tokenizer(config: dict[str, Any]) -> Tokenizer: 99 """Create a Gemini tokenizer from configuration. 100 101 Parameters 102 ---------- 103 config 104 Configuration dictionary with keys: 105 - max_tokens: int (optional) - Maximum tokens per chunk (default: 2048) 106 - google_api_key: str (optional) - Google API key for token counting 107 - model: str (optional) - Gemini model name for token counting (default: gemini-embedding-001) 108 - chars_per_token_ratio: float (optional) - Char-to-token ratio for threshold (default: 1.5) 109 - split_buffer_size: int (optional) - Words to buffer before checking limits (default: 5) 110 111 Returns 112 ------- 113 Tokenizer instance. 114 115 Raises 116 ------ 117 ValueError 118 If invalid configuration values are provided. 
119 """ 120 max_tokens = config.get("max_tokens", 2048) 121 google_api_key = config.get("llm_api_key") 122 tokenizer_model = config.get("llm_model", "gemini-embedding-001") 123 chars_per_token_ratio = config.get("chars_per_token_ratio") 124 split_buffer_size = config.get("split_buffer_size") 125 126 return GeminiTokenizer( 127 model=tokenizer_model, 128 max_tokens=max_tokens, 129 google_api_key=google_api_key, 130 chars_per_token_ratio=chars_per_token_ratio, 131 split_buffer_size=split_buffer_size, 132 ) 133