# ocr_engine.py
  1  """brane - Local VLM-based OCR using Qwen3-VL via Ollama."""
  2  
  3  import sys
  4  from pathlib import Path
  5  
  6  import click
  7  import ollama
  8  
  9  from prompts import MARKDOWN_PROMPT, PLAIN_PROMPT
 10  
 11  SUPPORTED_FORMATS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif", ".gif"}
 12  MODELS = {
 13      "8b": "qwen3-vl:8b",
 14      "30b": "qwen3-vl:30b",
 15  }
 16  DEFAULT_MODEL = "8b"
 17  
 18  
 19  def validate_image(path: Path) -> Path:
 20      if not path.exists():
 21          raise click.BadParameter(f"File not found: {path}")
 22      if path.suffix.lower() not in SUPPORTED_FORMATS:
 23          raise click.BadParameter(
 24              f"Unsupported format: {path.suffix} (supported: {', '.join(sorted(SUPPORTED_FORMATS))})"
 25          )
 26      return path
 27  
 28  
 29  def ocr_image(image_path: Path, model: str, prompt: str, stream: bool = True, keep_alive: int = 0):
 30      """Run OCR on a single image and yield text chunks."""
 31      response = ollama.chat(
 32          model=model,
 33          messages=[
 34              {
 35                  "role": "user",
 36                  "content": prompt,
 37                  "images": [str(image_path)],
 38              }
 39          ],
 40          stream=stream,
 41          keep_alive=keep_alive,
 42      )
 43  
 44      if stream:
 45          for chunk in response:
 46              text = chunk["message"]["content"]
 47              if text:
 48                  yield text
 49      else:
 50          yield response["message"]["content"]
 51  
 52  
 53  @click.command("ocr")
 54  @click.argument("images", nargs=-1, required=True, type=click.Path(exists=True, path_type=Path))
 55  @click.option("-o", "--output", type=click.Path(path_type=Path), help="Write output to file.")
 56  @click.option(
 57      "-f", "--format", "fmt", type=click.Choice(["markdown", "plain"]), default="markdown",
 58      help="Output format (default: markdown).",
 59  )
 60  @click.option(
 61      "-m", "--model", "model_size", type=click.Choice(list(MODELS.keys())), default=DEFAULT_MODEL,
 62      help=f"Model size (default: {DEFAULT_MODEL}).",
 63  )
 64  @click.option("-p", "--prompt", "custom_prompt", default=None, help="Custom prompt override.")
 65  @click.option("--no-stream", is_flag=True, help="Disable streaming output.")
 66  @click.option("--persist", is_flag=True, help="Keep model loaded in VRAM after completion.")
 67  def ocr(images, output, fmt, model_size, custom_prompt, no_stream, persist):
 68      """Extract text from images using local VLM OCR.
 69  
 70      Examples:
 71  
 72          brane screenshot.png
 73  
 74          brane document.png -o result.md
 75  
 76          brane *.png -m 30b -o combined.md
 77  
 78          brane photo.jpg --format plain
 79  
 80          brane table.png --prompt "Extract this table as CSV"
 81      """
 82      model = MODELS[model_size]
 83      prompt = custom_prompt or (MARKDOWN_PROMPT if fmt == "markdown" else PLAIN_PROMPT)
 84  
 85      validated = []
 86      for img in images:
 87          try:
 88              validated.append(validate_image(img))
 89          except click.BadParameter as e:
 90              click.echo(f"Error: {e}", err=True)
 91              sys.exit(1)
 92  
 93      try:
 94          ollama.list()
 95      except Exception:
 96          click.echo(
 97              "Error: Cannot connect to Ollama. Is it running?\n"
 98              "  Start it with: ollama serve",
 99              err=True,
100          )
101          sys.exit(1)
102  
103      output_parts = []
104      for i, img_path in enumerate(validated):
105          if len(validated) > 1:
106              header = f"\n--- {img_path.name} ---\n\n" if i > 0 else f"--- {img_path.name} ---\n\n"
107              if output:
108                  output_parts.append(header)
109              else:
110                  click.echo(header, nl=False)
111  
112          keep_alive = -1 if persist else 0
113          if output:
114              for chunk in ocr_image(img_path, model, prompt, stream=False, keep_alive=keep_alive):
115                  output_parts.append(chunk)
116              if i < len(validated) - 1:
117                  output_parts.append("\n\n")
118          else:
119              for chunk in ocr_image(img_path, model, prompt, stream=not no_stream, keep_alive=keep_alive):
120                  click.echo(chunk, nl=False)
121              click.echo()
122  
123      if output:
124          output.write_text("".join(output_parts))
125          click.echo(f"Written to {output}", err=True)
126  
127  
# Allow running this module directly as a script: `python ocr_engine.py ...`
if __name__ == "__main__":
    ocr()