/ integrations / notion.py
notion.py
  1  """Notion API client for syncing scored jobs to a Notion database."""
  2  
  3  __all__ = ["NotionClientProtocol", "NotionClient"]
  4  
  5  import logging
  6  from collections.abc import Iterator
  7  from typing import Any, Protocol, cast, runtime_checkable
  8  
  9  from notion_client import (
 10      APIResponseError,
 11      collect_paginated_api,
 12      is_full_data_source,
 13      is_full_database,
 14  )
 15  from notion_client import (
 16      Client as _SDKClient,
 17  )
 18  from tenacity import (
 19      retry,
 20      retry_if_exception,
 21      stop_after_attempt,
 22      wait_exponential,
 23  )
 24  
 25  from config import NotionPropertyMapping
 26  from exceptions import APIError, ParseError
 27  from models.notion import ExistingPagesIndex, NotionDatabaseInfo
 28  
logger = logging.getLogger(__name__)

# Notion API limit for children blocks in a single pages.create call.
# The same value is used as the batch size when appending blocks to a page.
_MAX_BLOCKS_PER_PAGE = 100

# Maximum retry attempts for transient Notion API errors
# (passed to tenacity's stop_after_attempt on every retried method).
_RETRY_ATTEMPTS = 5

# Request timeout handed to the SDK client, in milliseconds (30 s default);
# the notion-client SDK uses httpx under the hood.
_DEFAULT_TIMEOUT_MS = 30_000
 39  
 40  
 41  def _is_transient(exc: BaseException) -> bool:
 42      """Return True for retriable Notion API errors (5xx, 429)."""
 43      if isinstance(exc, APIResponseError):
 44          return exc.status >= 500 or exc.status == 429
 45      return False
 46  
 47  
 48  def _build_coached_filter(coached_prop: str) -> dict[str, Any]:
 49      """Build a Notion filter for Coached == False.
 50  
 51      Args:
 52          coached_prop: Name of the Coached checkbox property.
 53  
 54      Returns:
 55          Notion filter dict.
 56      """
 57      return {"property": coached_prop, "checkbox": {"equals": False}}
 58  
 59  
 60  def _build_status_filter(status_prop: str, status_value: str) -> dict[str, Any]:
 61      """Build a Notion filter for Status == status_value.
 62  
 63      Args:
 64          status_prop: Name of the Status select property.
 65          status_value: Value to filter by.
 66  
 67      Returns:
 68          Notion filter dict.
 69      """
 70      return {"property": status_prop, "select": {"equals": status_value}}
 71  
 72  
 73  @runtime_checkable
 74  class NotionClientProtocol(Protocol):
 75      """Interface contract for any Notion client implementation."""
 76  
 77      def ping(self) -> None:
 78          """Verify the Notion token is valid."""
 79          ...
 80  
 81      def list_databases(self) -> list[NotionDatabaseInfo]:
 82          """List all Notion databases accessible to the integration."""
 83          ...
 84  
 85      def validate_database_schema(
 86          self, database_id: str, property_mapping: NotionPropertyMapping
 87      ) -> None:
 88          """Validate that the database has all required properties with correct types."""
 89          ...
 90  
 91      def get_existing_pages_index(
 92          self, database_id: str, property_mapping: NotionPropertyMapping
 93      ) -> ExistingPagesIndex:
 94          """Build a dedup index from all pages in the database."""
 95          ...
 96  
 97      def create_page(
 98          self,
 99          database_id: str,
100          properties: dict[str, Any],
101          *,
102          icon: str | None = None,
103          blocks: list[dict[str, Any]] | None = None,
104      ) -> str:
105          """Create a Notion page and return its ID."""
106          ...
107  
108      def get_uncoached_pages(
109          self,
110          database_id: str,
111          property_mapping: NotionPropertyMapping,
112          *,
113          status_filter: str | None = None,
114      ) -> dict[str, str]:
115          """Return {job_posting_id: page_id} for all uncoached pages."""
116          ...
117  
118      def append_blocks(self, page_id: str, blocks: list[dict[str, Any]]) -> list[str]:
119          """Append content blocks to an existing page, batching at _MAX_BLOCKS_PER_PAGE."""
120          ...
121  
122      def update_page_property(
123          self, page_id: str, property_name: str, value: dict[str, Any]
124      ) -> None:
125          """Update a single property on an existing page."""
126          ...
127  
128  
class NotionClient:
    """Wraps the notion-client SDK to sync job pages to a Notion database.

    The ``database_id`` arguments of the query and validation methods are used
    in ``data_sources/{id}`` request paths. ``pages.create`` needs the parent
    database ID instead, which is resolved from the data source on demand and
    cached per instance.

    Methods decorated with ``@retry`` are re-invoked on transient errors
    (HTTP 5xx / 429, see ``_is_transient``) up to ``_RETRY_ATTEMPTS`` times
    with exponential back-off. Because ``reraise=True`` is set, the original
    ``APIResponseError`` is raised if the last attempt is still failing.

    Args:
        token: Notion integration token.
    """

    def __init__(self, token: str) -> None:
        self._client = _SDKClient(auth=token, timeout_ms=_DEFAULT_TIMEOUT_MS)
        # Cache: data_source_id -> parent database_id (needed for pages.create).
        self._parent_db_ids: dict[str, str] = {}

    @retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_exponential(multiplier=1, min=2, max=60),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        reraise=True,
    )
    def ping(self) -> None:
        """Verify the Notion token is valid by calling the users.me endpoint.

        Raises:
            APIError: On non-retriable Notion API errors (e.g. an invalid token).
            APIResponseError: If a transient error (5xx / 429) persists after
                all retry attempts.
        """
        try:
            self._client.users.me()
        except APIResponseError as exc:
            if _is_transient(exc):
                # Let tenacity see the raw SDK error so it can retry.
                raise
            raise APIError("notion", exc.status, str(exc)) from exc
        logger.debug("Notion ping successful.")

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def validate_database_schema(
        self, database_id: str, property_mapping: NotionPropertyMapping
    ) -> None:
        """Validate that the database has all required properties with correct types.

        As a side effect, the parent database ID found in the response is
        cached for later ``create_page`` calls.

        Args:
            database_id: Notion database ID.
            property_mapping: Expected property name mapping.

        Raises:
            ParseError: If any properties are missing or have wrong types.
            APIResponseError: Propagated unchanged from the SDK; this method
                does not wrap it in APIError.
        """
        response = cast(
            dict[str, Any],
            self._client.request(
                path=f"data_sources/{database_id}",
                method="GET",
            ),
        )
        self._cache_parent_db_id(database_id, response)
        db_properties = response.get("properties", {})

        expected = self._expected_schema(property_mapping)
        issues: list[str] = []

        for prop_name, expected_type in expected.items():
            if prop_name not in db_properties:
                issues.append(f"missing property {prop_name!r} (expected {expected_type})")
                continue
            actual_type = db_properties[prop_name].get("type", "unknown")
            if actual_type != expected_type:
                issues.append(
                    f"property {prop_name!r}: expected {expected_type}, got {actual_type}"
                )

        if issues:
            # Report every mismatch at once so the user can fix them in one pass.
            raise ParseError("notion_database", "; ".join(issues))

        logger.info("Database schema validated: all %d properties match.", len(expected))

    @staticmethod
    def _expected_schema(pm: NotionPropertyMapping) -> dict[str, str]:
        """Build the expected property name -> Notion type mapping."""
        return {
            pm.title: "title",
            pm.position: "rich_text",
            pm.status: "select",
            pm.job_url: "url",
            pm.company_linkedin: "url",
            pm.job_id: "rich_text",
            pm.ai_reasoning: "rich_text",
            pm.company: "rich_text",
            pm.location: "rich_text",
            pm.fit_category: "select",
            pm.networking_signal: "number",
            pm.networking_rationale: "rich_text",
            pm.score: "number",
            pm.coached: "checkbox",
        }

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def list_databases(self) -> list[NotionDatabaseInfo]:
        """List all Notion databases accessible to the integration.

        Paginates through the search API, keeps only full database /
        data-source objects, and returns them sorted alphabetically by title.

        Returns:
            Sorted list of NotionDatabaseInfo.
        """
        databases: list[NotionDatabaseInfo] = []

        for result in collect_paginated_api(self._client.search):
            if not (is_full_database(result) or is_full_data_source(result)):
                continue
            # A title is an array of rich-text segments; join them all so a
            # title split into several segments is not truncated.
            title_parts = result.get("title") or []
            title = "".join(str(part.get("plain_text", "")) for part in title_parts)
            databases.append(NotionDatabaseInfo(
                database_id=result["id"],
                title=title,
            ))

        databases.sort(key=lambda db: db.title)
        logger.info("Found %d accessible databases.", len(databases))
        return databases

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def get_existing_pages_index(
        self, database_id: str, property_mapping: NotionPropertyMapping
    ) -> ExistingPagesIndex:
        """Build a dedup index from all pages in the database.

        Paginates through all pages and extracts job IDs and company names.
        Pagination is consumed inside this retried method, so a transient
        error part-way through restarts the scan from the first page with
        fresh, empty sets.

        Args:
            database_id: Notion database ID.
            property_mapping: Property name mapping.

        Returns:
            ExistingPagesIndex with job_ids and company_names sets.
        """
        job_ids: set[str] = set()
        company_names: set[str] = set()

        for page in self._paginate_query(database_id, {}):
            props = page.get("properties", {})
            job_id = self._extract_rich_text(props, property_mapping.job_id)
            if job_id:
                job_ids.add(job_id)
            company = self._extract_rich_text(props, property_mapping.company)
            if company:
                company_names.add(company)

        logger.info(
            "Fetched existing pages index: %d job IDs, %d companies.",
            len(job_ids),
            len(company_names),
        )
        return ExistingPagesIndex(
            job_ids=frozenset(job_ids), company_names=frozenset(company_names)
        )

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def create_page(
        self,
        database_id: str,
        properties: dict[str, Any],
        *,
        icon: str | None = None,
        blocks: list[dict[str, Any]] | None = None,
    ) -> str:
        """Create a Notion page with the given properties and optional content blocks.

        Args:
            database_id: Parent database ID.
            properties: Pre-built Notion properties dict.
            icon: Optional emoji icon for the page.
            blocks: Optional initial content blocks. Only the first
                ``_MAX_BLOCKS_PER_PAGE`` are sent; any surplus is dropped and
                a warning is logged. Use ``append_blocks`` for the remainder.

        Returns:
            The created page ID.

        Raises:
            APIError: On non-retriable Notion API errors (4xx).
            ParseError: If the parent database ID cannot be determined.
        """
        parent_db_id = self._resolve_parent_db_id(database_id)

        all_blocks = blocks or []
        truncated = all_blocks[:_MAX_BLOCKS_PER_PAGE]
        if len(all_blocks) > len(truncated):
            # Make the truncation visible instead of losing content silently.
            logger.warning(
                "create_page received %d blocks; only the first %d are sent.",
                len(all_blocks),
                _MAX_BLOCKS_PER_PAGE,
            )

        kwargs: dict[str, Any] = {
            "parent": {"database_id": parent_db_id},
            "properties": properties,
        }
        if icon is not None:
            kwargs["icon"] = {"type": "emoji", "emoji": icon}
        if truncated:
            kwargs["children"] = truncated

        try:
            logger.debug("Creating page in database %s.", database_id)
            result = cast(dict[str, Any], self._client.pages.create(**kwargs))
        except APIResponseError as exc:
            if _is_transient(exc):
                raise
            raise APIError("notion", exc.status, str(exc)) from exc

        page_id: str = result["id"]
        logger.info("Created page %s.", page_id)
        return page_id

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def get_uncoached_pages(
        self,
        database_id: str,
        property_mapping: NotionPropertyMapping,
        *,
        status_filter: str | None = None,
    ) -> dict[str, str]:
        """Return {job_posting_id: page_id} for pages where Coached is False.

        Pages without a job ID are skipped. If several pages share a job ID,
        the last one returned by the query wins.

        Args:
            database_id: Notion database ID.
            property_mapping: Property name mapping.
            status_filter: If set, also filter by Status == this value.

        Returns:
            Dict mapping job_posting_id to page_id.
        """
        pm = property_mapping
        filters: list[dict[str, Any]] = [_build_coached_filter(pm.coached)]
        if status_filter is not None:
            filters.append(_build_status_filter(pm.status, status_filter))
        filter_body: dict[str, Any] = {"and": filters}

        result: dict[str, str] = {}

        for page in self._paginate_query(database_id, {"filter": filter_body}):
            page_id: str = page.get("id", "")
            props = page.get("properties", {})
            job_id = self._extract_rich_text(props, pm.job_id)
            if job_id:
                result[job_id] = page_id

        logger.info("Found %d uncoached pages.", len(result))
        return result

    def append_blocks(self, page_id: str, blocks: list[dict[str, Any]]) -> list[str]:
        """Append content blocks to an existing Notion page.

        Batches at _MAX_BLOCKS_PER_PAGE to respect the Notion API limit.
        Each batch retries independently on transient errors, so batches that
        already succeeded are not re-sent.

        Args:
            page_id: Target page ID.
            blocks: List of Notion block dicts to append.

        Returns:
            List of created block IDs as reported by the API, in batch order.

        Raises:
            APIError: On non-retriable Notion API errors (4xx).
        """
        created_ids: list[str] = []
        for start in range(0, len(blocks), _MAX_BLOCKS_PER_PAGE):
            batch = blocks[start : start + _MAX_BLOCKS_PER_PAGE]
            created_ids.extend(self._append_batch(page_id, batch))
        return created_ids

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def _append_batch(
        self, page_id: str, batch: list[dict[str, Any]]
    ) -> list[str]:
        """Append a single batch of blocks to a page, with retry on transient errors.

        Args:
            page_id: Target page ID.
            batch: List of Notion block dicts (max _MAX_BLOCKS_PER_PAGE).

        Returns:
            List of created block IDs from this batch.

        Raises:
            APIError: On non-retriable Notion API errors (4xx).
        """
        try:
            response = cast(
                dict[str, Any],
                self._client.blocks.children.append(
                    block_id=page_id, children=batch
                ),
            )
        except APIResponseError as exc:
            if _is_transient(exc):
                raise
            raise APIError("notion", exc.status, str(exc)) from exc
        return [str(block["id"]) for block in response.get("results", [])]

    @retry(
        retry=retry_if_exception(_is_transient),
        stop=stop_after_attempt(_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=1, min=1, max=30),
        reraise=True,
    )
    def update_page_property(
        self, page_id: str, property_name: str, value: dict[str, Any]
    ) -> None:
        """Update a single property on an existing Notion page.

        Args:
            page_id: Target page ID.
            property_name: Name of the property to update.
            value: Notion property value dict (e.g. {"checkbox": True}).

        Raises:
            APIError: On non-retriable Notion API errors (4xx).
        """
        try:
            self._client.pages.update(
                page_id=page_id, properties={property_name: value}
            )
        except APIResponseError as exc:
            if _is_transient(exc):
                raise
            raise APIError("notion", exc.status, str(exc)) from exc

    # ── Private helpers ──────────────────────────────────────────────────────

    def _cache_parent_db_id(
        self, data_source_id: str, response: dict[str, Any]
    ) -> None:
        """Extract and cache the parent database ID from a data_source response.

        Nothing is cached when the response carries no parent database ID.
        """
        # ``or {}`` also covers a "parent" key that is present but null.
        parent = response.get("parent") or {}
        parent_db_id = parent.get("database_id")
        if parent_db_id:
            self._parent_db_ids[data_source_id] = parent_db_id

    def _resolve_parent_db_id(self, data_source_id: str) -> str:
        """Get the parent database ID for a data_source, fetching if needed.

        The Notion v3 API requires the parent database_id for pages.create,
        which differs from the data_source ID used for queries.

        Args:
            data_source_id: ID of the data source whose parent is wanted.

        Returns:
            The parent database ID (served from the cache when available).

        Raises:
            APIError: On non-retriable Notion API errors while fetching.
            ParseError: If the response contains no parent database ID.
        """
        cached = self._parent_db_ids.get(data_source_id)
        if cached is not None:
            return cached
        try:
            response = cast(
                dict[str, Any],
                self._client.request(
                    path=f"data_sources/{data_source_id}",
                    method="GET",
                ),
            )
        except APIResponseError as exc:
            if _is_transient(exc):
                # Not retried here; the retried caller (create_page) handles it.
                raise
            raise APIError(
                "notion",
                exc.status,
                f"Database {data_source_id!r} not found or not shared "
                f"with this integration: {exc}",
            ) from exc
        self._cache_parent_db_id(data_source_id, response)
        parent_db_id = self._parent_db_ids.get(data_source_id)
        if parent_db_id is None:
            raise ParseError(
                "notion_data_source",
                f"No parent database ID found for {data_source_id!r}",
            )
        return parent_db_id

    def _paginate_query(
        self, database_id: str, body: dict[str, Any]
    ) -> Iterator[dict[str, Any]]:
        """Yield all pages from a paginated data_source query.

        Args:
            database_id: Notion database ID to query.
            body: Initial request body (e.g. filter dict). start_cursor is injected
                automatically on each subsequent page; ``body`` itself is not mutated.

        Yields:
            Each page dict from the results array.
        """
        start_cursor: str | None = None
        while True:
            request_body = {**body}
            if start_cursor is not None:
                request_body["start_cursor"] = start_cursor
            logger.debug("Querying database %s (cursor=%s).", database_id, start_cursor)
            response = cast(
                dict[str, Any],
                self._client.request(
                    path=f"data_sources/{database_id}/query",
                    method="POST",
                    body=request_body,
                ),
            )
            yield from response.get("results", [])
            if not response.get("has_more", False):
                break
            next_cursor = response.get("next_cursor")
            if not next_cursor:
                # Without a cursor the next request would fetch the first page
                # again and loop forever; stop instead.
                logger.warning(
                    "Query of database %s reported more results but no next_cursor; "
                    "stopping pagination.",
                    database_id,
                )
                break
            start_cursor = next_cursor

    @staticmethod
    def _extract_rich_text(properties: dict[str, Any], property_name: str) -> str:
        """Extract plain text from a rich_text property.

        A rich_text value is a list of segments; the plain text of all
        segments is concatenated so multi-segment values are not truncated.

        Args:
            properties: The ``properties`` dict of a Notion page.
            property_name: Name of the rich_text property to read.

        Returns:
            The concatenated plain text, or "" if the property is missing or empty.
        """
        prop = properties.get(property_name) or {}
        segments = prop.get("rich_text") or []
        return "".join(str(segment.get("plain_text", "")) for segment in segments)
559