src/chunkers/tokenizers/base_tokenizer.py
"""Base tokenizer class that extends BaseTokenizer from docling."""

import logging
from abc import ABC, abstractmethod

from docling_core.transforms.chunker.tokenizer.base import BaseTokenizer

from ..constants import CHARS_PER_TOKEN_RATIO, SPLIT_BUFFER_SIZE


class Tokenizer(BaseTokenizer, ABC):
    """Base class for tokenizer implementations that extends BaseTokenizer.

    Tokenizers are used by hybrid chunkers to count tokens in text
    and split text into chunks based on token limits.
    """

    def __init__(
        self,
        chars_per_token_ratio: float | None = None,
        split_buffer_size: int | None = None,
        **kwargs,
    ):
        """Initialize the tokenizer.

        Parameters
        ----------
        chars_per_token_ratio
            Ratio of characters to tokens for threshold estimation.
            Lower values are more conservative. Default from constants.
        split_buffer_size
            Number of words to buffer before checking limits. Default from constants.
        """
        super().__init__(**kwargs)
        self.chars_per_token_ratio = (
            CHARS_PER_TOKEN_RATIO if chars_per_token_ratio is None else chars_per_token_ratio
        )
        self.split_buffer_size = (
            SPLIT_BUFFER_SIZE if split_buffer_size is None else split_buffer_size
        )

    def __call__(self, text: str) -> int:
        """Make tokenizer callable for semchunk compatibility.

        Parameters
        ----------
        text
            The text to count tokens for.

        Returns
        -------
        Number of tokens in the text.
        """
        return self.count_tokens(text)

    @abstractmethod
    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in the given text.

        Parameters
        ----------
        text
            The text to count tokens for.

        Returns
        -------
        Number of tokens in the text.
        """
        ...

    @abstractmethod
    def get_max_tokens(self) -> int:
        """Get the maximum number of tokens allowed per chunk.

        Returns
        -------
        Maximum tokens per chunk.
        """
        ...

    @abstractmethod
    def _hash_attributes(self) -> tuple:
        """Return a tuple of hashable attributes that uniquely identify this tokenizer.

        This method must be implemented by subclasses to return a tuple containing
        all immutable configuration attributes that should be included in the hash
        and equality comparison. The tuple should not include mutable objects like
        API clients.

        Returns
        -------
        Tuple of hashable attributes.
        """
        ...

    def __hash__(self) -> int:
        """Compute hash from hashable attributes.

        Returns
        -------
        Hash value based on _hash_attributes().
        """
        return hash(self._hash_attributes())

    def __eq__(self, other: object) -> bool:
        """Compare tokenizers for equality based on hashable attributes.

        Parameters
        ----------
        other
            Object to compare with.

        Returns
        -------
        True if both are Tokenizer instances with the same hash attributes, False otherwise.
        """
        if not isinstance(other, Tokenizer):
            return False
        return self._hash_attributes() == other._hash_attributes()

    def get_tokenizer(self):
        """Returns the tokenizer instance (required by BaseTokenizer).

        Returns
        -------
        The tokenizer instance (self).
        """
        return self

    def _join(self, base: str, addition: str) -> str:
        """Join two text parts with a space."""
        return f"{base} {addition}" if base else addition

    def split_text(self, text: str) -> list[str]:
        """Split text into chunks based on token limits.

        Parameters
        ----------
        text
            Text to split.

        Returns
        -------
        List of text chunks, each within the token limit.
        """
        if not text:
            return []

        max_tokens = self.get_max_tokens()
        char_threshold = int(max_tokens * self.chars_per_token_ratio)
        buffer_size = self.split_buffer_size
        logger = logging.getLogger(__name__)
        logger.info(
            "split_text called: %d chars, max_tokens=%d, char_threshold=%d",
            len(text),
            max_tokens,
            char_threshold,
        )

        chunks: list[str] = []
        words = text.split()
        current_chunk = ""
        buffer: list[str] = []

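        # Accumulate words into a small buffer and only run the (potentially
        # expensive) token count once the assembled chunk is near the estimated
        # character threshold; otherwise grow the chunk on character length alone.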
        for word in words:
            buffer.append(word)
            test_chunk = self._join(current_chunk, " ".join(buffer))
            near_limit = len(test_chunk) >= char_threshold
            buffer_full = len(buffer) >= buffer_size

            if not buffer_full:
                continue

            if near_limit:
                logger.info("Near char threshold, checking token count")
                if self.count_tokens(test_chunk) <= max_tokens:
                    current_chunk = test_chunk
                else:
                    logger.info("Exceeded token limit, starting new chunk")
                    if current_chunk:
                        chunks.append(current_chunk)
                    current_chunk = " ".join(buffer)
            else:
                current_chunk = test_chunk

            buffer = []

        if buffer:
            tail = self._join(current_chunk, " ".join(buffer))
            # Guard: leftover words may push the final chunk past the token limit.
            if current_chunk and self.count_tokens(tail) > max_tokens:
                chunks.append(current_chunk)
                tail = " ".join(buffer)
            current_chunk = tail
        if current_chunk:
            chunks.append(current_chunk)

        return chunks
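

if __name__ == "__main__":
    # Illustrative sketch only -- an editorial example, not part of the original
    # module.  It shows one way the abstract contract above (count_tokens,
    # get_max_tokens, _hash_attributes) might be satisfied; the subclass name,
    # its one-token-per-word heuristic, and the sample limits are assumptions.
    class _WhitespaceTokenizer(Tokenizer):
        """Toy tokenizer that treats each whitespace-separated word as one token."""

        def __init__(self, max_tokens: int = 64, **kwargs):
            super().__init__(**kwargs)
            self.max_tokens = max_tokens

        def count_tokens(self, text: str) -> int:
            return len(text.split())

        def get_max_tokens(self) -> int:
            return self.max_tokens

        def _hash_attributes(self) -> tuple:
            # Only immutable configuration participates in hashing/equality.
            return (type(self).__name__, self.max_tokens)

    tok = _WhitespaceTokenizer(max_tokens=8)
    sample = "one two three four five six seven eight nine ten eleven twelve"
    # Because the instance is itself a token-counting callable, it could also be
    # handed to chunkers (e.g. semchunk) that accept a token counter (assumption
    # about that library's interface; not exercised here).
    print(tok(sample))
    for piece in tok.split_text(sample):
        print(repr(piece), "->", tok.count_tokens(piece), "tokens")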