base_tokenizer.py
1 """Base tokenizer class that extends BaseTokenizer from docling.""" 2 3 import logging 4 from abc import ABC, abstractmethod 5 6 from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer 7 8 from ..constants import CHARS_PER_TOKEN_RATIO, SPLIT_BUFFER_SIZE 9 10 11 class Tokenizer(BaseTokenizer, ABC): 12 """Base class for tokenizer implementations that extends BaseTokenizer. 13 14 Tokenizers are used by hybrid chunkers to count tokens in text 15 and split text into chunks based on token limits. 16 """ 17 18 def __init__( 19 self, 20 chars_per_token_ratio: float | None = None, 21 split_buffer_size: int | None = None, 22 **kwargs, 23 ): 24 """Initialize the tokenizer. 25 26 Parameters 27 ---------- 28 chars_per_token_ratio 29 Ratio of characters to tokens for threshold estimation. 30 Lower values are more conservative. Default from constants. 31 split_buffer_size 32 Number of words to buffer before checking limits. Default from constants. 33 """ 34 super().__init__(**kwargs) 35 self.chars_per_token_ratio = chars_per_token_ratio or CHARS_PER_TOKEN_RATIO 36 self.split_buffer_size = split_buffer_size or SPLIT_BUFFER_SIZE 37 38 def __call__(self, text: str) -> int: 39 """Make tokenizer callable for semchunk compatibility. 40 41 Parameters 42 ---------- 43 text 44 The text to count tokens for. 45 46 Returns 47 ------- 48 Number of tokens in the text. 49 """ 50 return self.count_tokens(text) 51 52 @abstractmethod 53 def count_tokens(self, text: str) -> int: 54 """Count the number of tokens in the given text. 55 56 Parameters 57 ---------- 58 text 59 The text to count tokens for. 60 61 Returns 62 ------- 63 Number of tokens in the text. 64 """ 65 ... 66 67 @abstractmethod 68 def get_max_tokens(self) -> int: 69 """Get the maximum number of tokens allowed per chunk. 70 71 Returns 72 ------- 73 Maximum tokens per chunk. 74 """ 75 ... 76 77 @abstractmethod 78 def _hash_attributes(self) -> tuple: 79 """Return a tuple of hashable attributes that uniquely identify this tokenizer. 80 81 This method must be implemented by subclasses to return a tuple containing 82 all immutable configuration attributes that should be included in the hash 83 and equality comparison. The tuple should not include mutable objects like 84 API clients. 85 86 Returns 87 ------- 88 Tuple of hashable attributes. 89 """ 90 ... 91 92 def __hash__(self) -> int: 93 """Compute hash from hashable attributes. 94 95 Returns 96 ------- 97 Hash value based on _hash_attributes(). 98 """ 99 return hash(self._hash_attributes()) 100 101 def __eq__(self, other: object) -> bool: 102 """Compare tokenizers for equality based on hashable attributes. 103 104 Parameters 105 ---------- 106 other 107 Object to compare with. 108 109 Returns 110 ------- 111 True if both are Tokenizer instances with same hash attributes, False otherwise. 112 """ 113 if not isinstance(other, Tokenizer): 114 return False 115 return self._hash_attributes() == other._hash_attributes() 116 117 def get_tokenizer(self): 118 """Returns the tokenizer instance (required by BaseTokenizer). 119 120 Returns 121 ------- 122 The tokenizer instance (self). 123 """ 124 return self 125 126 def _join(self, base: str, addition: str) -> str: 127 """Join two text parts with a space.""" 128 return f"{base} {addition}" if base else addition 129 130 def split_text(self, text: str) -> list[str]: 131 """Split text into chunks based on token limits. 132 133 Parameters 134 ---------- 135 text 136 Text to split. 
137 138 Returns 139 ------- 140 List of text chunks, each within the token limit. 141 """ 142 if not text: 143 return [] 144 145 max_tokens = self.get_max_tokens() 146 char_threshold = int(max_tokens * self.chars_per_token_ratio) 147 buffer_size = self.split_buffer_size 148 logger = logging.getLogger(__name__) 149 logger.info(f"split_text called: {len(text)} chars, max_tokens={max_tokens}, char_threshold={char_threshold}") 150 151 chunks: list[str] = [] 152 words = text.split() 153 current_chunk = "" 154 buffer: list[str] = [] 155 156 157 for word in words: 158 buffer.append(word) 159 test_chunk = self._join(current_chunk, " ".join(buffer)) 160 near_limit = len(test_chunk) >= char_threshold 161 buffer_full = len(buffer) >= buffer_size 162 163 if not buffer_full: 164 continue 165 166 if near_limit: 167 logger.info("Near char threshold, checking token count") 168 if self.count_tokens(test_chunk) <= max_tokens: 169 current_chunk = test_chunk 170 else: 171 logger.info("Exceeded token limit, starting new chunk") 172 if current_chunk: 173 chunks.append(current_chunk) 174 current_chunk = " ".join(buffer) 175 else: 176 current_chunk = test_chunk 177 178 buffer = [] 179 180 if buffer: 181 current_chunk = self._join(current_chunk, " ".join(buffer)) 182 if current_chunk: 183 chunks.append(current_chunk) 184 185 return chunks 186
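
For reference, a minimal sketch of how a concrete subclass might look and how `split_text` would be driven. It is illustrative only: `WordCountTokenizer`, the 512-token budget, the constructor arguments, and the import path are assumptions rather than part of the module above (the module itself depends on a sibling `constants.py` through a relative import, so the actual import path must match the package layout).

example_usage.py

# Hypothetical sketch: names, values, and the import path are assumptions,
# not part of the package above.
from mypackage.tokenizers.base_tokenizer import Tokenizer  # adjust to the real path


class WordCountTokenizer(Tokenizer):
    """Toy tokenizer that treats every whitespace-separated word as one token."""

    def count_tokens(self, text: str) -> int:
        # Crude approximation; a real subclass would call an actual tokenizer here.
        return len(text.split())

    def get_max_tokens(self) -> int:
        # Assumed per-chunk budget for this sketch.
        return 512

    def _hash_attributes(self) -> tuple:
        # Only immutable configuration goes into the hash, never clients or handles.
        return (self.chars_per_token_ratio, self.split_buffer_size)


tokenizer = WordCountTokenizer(chars_per_token_ratio=4.0, split_buffer_size=20)
chunks = tokenizer.split_text("lorem ipsum dolor sit amet " * 400)
print(len(chunks), [tokenizer.count_tokens(chunk) for chunk in chunks])

The buffering and character threshold exist so that `count_tokens` is called only when a chunk is plausibly close to the limit, rather than once per word, which matters when counting tokens is expensive (for example, API-backed tokenizers).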