// accessibility.go
1 package tools 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "image" 8 "image/color" 9 "image/draw" 10 "image/png" 11 "math" 12 "os" 13 "regexp" 14 "runtime" 15 "strings" 16 17 _ "image/jpeg" 18 19 "github.com/Kocoro-lab/ShanClaw/internal/agent" 20 ) 21 22 type refEntry struct { 23 path string 24 role string 25 pid int 26 } 27 28 type appContext struct { 29 App string `json:"app"` 30 Window string `json:"window"` 31 URL string `json:"url,omitempty"` 32 FocusedElement string `json:"focused_element,omitempty"` 33 } 34 35 func formatContext(ctx *appContext) string { 36 if ctx == nil { 37 return "" 38 } 39 msg := fmt.Sprintf("\n[context: %s — %s", ctx.App, ctx.Window) 40 if ctx.URL != "" { 41 msg += fmt.Sprintf(" (%s)", ctx.URL) 42 } 43 if ctx.FocusedElement != "" { 44 msg += fmt.Sprintf(", focused: %s", ctx.FocusedElement) 45 } 46 msg += "]" 47 return msg 48 } 49 50 type AccessibilityTool struct { 51 client *AXClient 52 refs map[string]refEntry 53 lastPID int 54 } 55 56 type accessibilityArgs struct { 57 Action string `json:"action"` 58 App string `json:"app,omitempty"` 59 MaxDepth int `json:"max_depth,omitempty"` 60 Budget int `json:"semantic_budget,omitempty"` 61 Filter string `json:"filter,omitempty"` 62 Ref string `json:"ref,omitempty"` 63 Value *string `json:"value,omitempty"` 64 Query string `json:"query,omitempty"` 65 Role string `json:"role,omitempty"` 66 Identifier string `json:"identifier,omitempty"` 67 DX int `json:"dx,omitempty"` 68 DY int `json:"dy,omitempty"` 69 Roles []string `json:"roles,omitempty"` 70 MaxLabels int `json:"max_labels,omitempty"` 71 } 72 73 func (t *AccessibilityTool) Info() agent.ToolInfo { 74 return agent.ToolInfo{ 75 Name: "accessibility", 76 Description: "Interact with macOS apps via the accessibility tree. Workflow: (1) Use 'annotate' to get a labeled screenshot with numbered elements, (2) click/type by ref. Actions: read_tree, click, press, set_value, get_value, find, scroll, annotate. 
Always specify 'app' parameter with the exact app name. For web content in browsers, prefer the browser tool instead.", 77 Parameters: map[string]any{ 78 "type": "object", 79 "properties": map[string]any{ 80 "action": map[string]any{"type": "string", "description": "Action: read_tree, click, press, set_value, get_value, find, scroll, annotate"}, 81 "app": map[string]any{"type": "string", "description": "Target app name (defaults to frontmost app)"}, 82 "max_depth": map[string]any{"type": "integer", "description": "Tree depth (default: 25 semantic budget, layout containers cost 0)"}, 83 "semantic_budget": map[string]any{"type": "integer", "description": "Semantic depth budget (default: 25, layout containers cost 0 depth)"}, 84 "filter": map[string]any{"type": "string", "description": "Filter: all (default) or interactive (for read_tree)"}, 85 "ref": map[string]any{"type": "string", "description": "Element ref from read_tree (e.g. e14, for click/press/set_value/get_value/scroll)"}, 86 "value": map[string]any{"type": "string", "description": "Value to set (for set_value)"}, 87 "query": map[string]any{"type": "string", "description": "Text to search for (for find, case-insensitive substring)"}, 88 "role": map[string]any{"type": "string", "description": "AX role filter (for find, e.g. AXButton)"}, 89 "identifier": map[string]any{"type": "string", "description": "AX identifier to find (exact match, for find)"}, 90 "dx": map[string]any{"type": "integer", "description": "Horizontal scroll amount in pixels (for scroll)"}, 91 "dy": map[string]any{"type": "integer", "description": "Vertical scroll amount in pixels (for scroll, positive=down)"}, 92 "roles": map[string]any{"type": "array", "items": map[string]any{"type": "string"}, "description": "Filter by AX roles (for annotate, e.g. 
[\"AXButton\", \"AXTextField\"])"}, 93 "max_labels": map[string]any{"type": "integer", "description": "Max elements to annotate (default: 50, for annotate)"}, 94 }, 95 }, 96 Required: []string{"action"}, 97 } 98 } 99 100 func (t *AccessibilityTool) RequiresApproval() bool { return false } 101 102 func (t *AccessibilityTool) IsReadOnlyCall(argsJSON string) bool { 103 var args struct { 104 Action string `json:"action"` 105 } 106 if json.Unmarshal([]byte(argsJSON), &args) != nil { 107 return false 108 } 109 switch args.Action { 110 case "read_tree", "annotate", "find", "get_value": 111 return true 112 default: 113 return false 114 } 115 } 116 117 func (t *AccessibilityTool) Run(ctx context.Context, argsJSON string) (agent.ToolResult, error) { 118 if runtime.GOOS != "darwin" || t.client == nil { 119 return agent.ToolResult{Content: "accessibility tool is only available on macOS", IsError: true}, nil 120 } 121 122 var args accessibilityArgs 123 if err := json.Unmarshal([]byte(argsJSON), &args); err != nil { 124 return agent.ToolResult{Content: fmt.Sprintf("invalid arguments: %v", err), IsError: true}, nil 125 } 126 127 if args.Action == "" { 128 return agent.ToolResult{Content: "missing required parameter: action", IsError: true}, nil 129 } 130 131 switch args.Action { 132 case "read_tree": 133 return t.readTree(ctx, args) 134 case "click", "press": 135 return t.performAction(ctx, args.Action, args.Ref) 136 case "set_value": 137 return t.setValue(ctx, args.Ref, args.Value) 138 case "get_value": 139 return t.getValue(ctx, args.Ref) 140 case "find": 141 return t.find(ctx, args) 142 case "scroll": 143 return t.scroll(ctx, args) 144 case "annotate": 145 return t.annotate(ctx, args) 146 default: 147 return agent.ToolResult{ 148 Content: fmt.Sprintf("unknown action: %q (valid: read_tree, click, press, set_value, get_value, find, scroll, annotate)", args.Action), 149 IsError: true, 150 }, nil 151 } 152 } 153 154 // validAppName checks that an app name contains only safe 
characters. 155 var validAppNamePattern = regexp.MustCompile(`^[a-zA-Z0-9 ._\-()]+$`) 156 157 func (t *AccessibilityTool) resolvePID(ctx context.Context, appName string) (int, error) { 158 if appName == "" { 159 return 0, nil // ax_server will use frontmost 160 } 161 if !validAppNamePattern.MatchString(appName) { 162 return 0, fmt.Errorf("invalid app name %q — only letters, numbers, spaces, dots, hyphens, underscores, and parentheses allowed", appName) 163 } 164 result, err := t.client.Call(ctx, "resolve_pid", map[string]any{"app_name": appName}) 165 if err != nil { 166 return 0, fmt.Errorf("app %q not found or not running", appName) 167 } 168 var pidResult struct { 169 PID int `json:"pid"` 170 } 171 if err := json.Unmarshal(result, &pidResult); err != nil { 172 return 0, fmt.Errorf("could not parse PID for %q", appName) 173 } 174 return pidResult.PID, nil 175 } 176 177 func (t *AccessibilityTool) readTree(ctx context.Context, args accessibilityArgs) (agent.ToolResult, error) { 178 pid, err := t.resolvePID(ctx, args.App) 179 if err != nil { 180 return agent.ToolResult{Content: err.Error(), IsError: true}, nil 181 } 182 183 params := map[string]any{ 184 "filter": args.Filter, 185 } 186 if args.Filter == "" { 187 params["filter"] = "all" 188 } 189 if pid > 0 { 190 params["pid"] = pid 191 } 192 // Semantic budget takes priority over max_depth 193 if args.Budget > 0 { 194 params["semantic_budget"] = args.Budget 195 } else if args.MaxDepth > 0 { 196 params["max_depth"] = args.MaxDepth 197 } 198 199 result, err := t.client.Call(ctx, "read_tree", params) 200 if err != nil { 201 return agent.ToolResult{Content: fmt.Sprintf("accessibility error: %v", err), IsError: true}, nil 202 } 203 204 var treeResult struct { 205 App string `json:"app"` 206 PID int `json:"pid"` 207 Window string `json:"window"` 208 Elements []any `json:"elements"` 209 RefPaths map[string]map[string]string `json:"ref_paths"` 210 } 211 if err := json.Unmarshal(result, &treeResult); err != nil { 212 return 
agent.ToolResult{Content: fmt.Sprintf("parse error: %v", err), IsError: true}, nil 213 } 214 215 t.refs = make(map[string]refEntry) 216 t.lastPID = treeResult.PID 217 218 for ref, entry := range treeResult.RefPaths { 219 t.refs[ref] = refEntry{ 220 path: entry["path"], 221 role: entry["role"], 222 pid: treeResult.PID, 223 } 224 } 225 226 // Remove ref_paths from output (agent doesn't need them) 227 var outputMap map[string]any 228 json.Unmarshal(result, &outputMap) 229 delete(outputMap, "ref_paths") 230 231 outputJSON, _ := json.MarshalIndent(outputMap, "", " ") 232 content := string(outputJSON) 233 234 // Truncate if too large 235 if len(content) > 8000 { 236 if elems, ok := outputMap["elements"].([]any); ok { 237 lo, hi := 0, len(elems) 238 for lo < hi { 239 mid := (lo + hi + 1) / 2 240 outputMap["elements"] = elems[:mid] 241 trial, _ := json.MarshalIndent(outputMap, "", " ") 242 if len(trial) <= 7800 { 243 lo = mid 244 } else { 245 hi = mid - 1 246 } 247 } 248 outputMap["elements"] = elems[:lo] 249 outputMap["truncated"] = fmt.Sprintf("showing %d of %d elements — use filter='interactive' or lower semantic_budget", lo, len(elems)) 250 outputJSON, _ = json.MarshalIndent(outputMap, "", " ") 251 content = string(outputJSON) 252 } 253 } 254 255 return agent.ToolResult{Content: content}, nil 256 } 257 258 func (t *AccessibilityTool) lookupRef(ref string) (refEntry, error) { 259 if ref == "" { 260 return refEntry{}, fmt.Errorf("missing required parameter: ref") 261 } 262 if t.refs == nil || len(t.refs) == 0 { 263 return refEntry{}, fmt.Errorf("no refs available — call read_tree first") 264 } 265 entry, ok := t.refs[ref] 266 if !ok { 267 return refEntry{}, fmt.Errorf("unknown ref %q — call read_tree to get current refs", ref) 268 } 269 return entry, nil 270 } 271 272 func (t *AccessibilityTool) performAction(ctx context.Context, action string, ref string) (agent.ToolResult, error) { 273 entry, err := t.lookupRef(ref) 274 if err != nil { 275 return 
agent.ToolResult{Content: err.Error(), IsError: true}, nil 276 } 277 278 params := map[string]any{ 279 "pid": entry.pid, 280 "path": entry.path, 281 } 282 if entry.role != "" { 283 params["expected_role"] = entry.role 284 } 285 286 result, err := t.client.Call(ctx, action, params) 287 if err != nil { 288 return agent.ToolResult{Content: fmt.Sprintf("accessibility error: %v", err), IsError: true}, nil 289 } 290 291 var actionResult struct { 292 Result string `json:"result"` 293 Context *appContext `json:"context,omitempty"` 294 } 295 json.Unmarshal(result, &actionResult) 296 return agent.ToolResult{Content: actionResult.Result + formatContext(actionResult.Context)}, nil 297 } 298 299 func (t *AccessibilityTool) setValue(ctx context.Context, ref string, value *string) (agent.ToolResult, error) { 300 entry, err := t.lookupRef(ref) 301 if err != nil { 302 return agent.ToolResult{Content: err.Error(), IsError: true}, nil 303 } 304 if value == nil { 305 return agent.ToolResult{Content: "set_value requires 'value' parameter", IsError: true}, nil 306 } 307 308 params := map[string]any{ 309 "pid": entry.pid, 310 "path": entry.path, 311 "value": *value, 312 } 313 if entry.role != "" { 314 params["expected_role"] = entry.role 315 } 316 317 result, err := t.client.Call(ctx, "set_value", params) 318 if err != nil { 319 return agent.ToolResult{Content: fmt.Sprintf("accessibility error: %v", err), IsError: true}, nil 320 } 321 322 var actionResult struct { 323 Result string `json:"result"` 324 Context *appContext `json:"context,omitempty"` 325 } 326 json.Unmarshal(result, &actionResult) 327 return agent.ToolResult{Content: actionResult.Result + formatContext(actionResult.Context)}, nil 328 } 329 330 func (t *AccessibilityTool) getValue(ctx context.Context, ref string) (agent.ToolResult, error) { 331 entry, err := t.lookupRef(ref) 332 if err != nil { 333 return agent.ToolResult{Content: err.Error(), IsError: true}, nil 334 } 335 336 params := map[string]any{ 337 "pid": entry.pid, 
338 "path": entry.path, 339 } 340 341 result, err := t.client.Call(ctx, "get_value", params) 342 if err != nil { 343 return agent.ToolResult{Content: fmt.Sprintf("accessibility error: %v", err), IsError: true}, nil 344 } 345 346 var actionResult struct { 347 Result string `json:"result"` 348 Role string `json:"role"` 349 Context *appContext `json:"context,omitempty"` 350 } 351 json.Unmarshal(result, &actionResult) 352 msg := actionResult.Result 353 if actionResult.Role != "" { 354 msg = fmt.Sprintf("%s (role: %s)", msg, actionResult.Role) 355 } 356 msg += formatContext(actionResult.Context) 357 return agent.ToolResult{Content: msg}, nil 358 } 359 360 func (t *AccessibilityTool) find(ctx context.Context, args accessibilityArgs) (agent.ToolResult, error) { 361 pid, err := t.resolvePID(ctx, args.App) 362 if err != nil { 363 return agent.ToolResult{Content: err.Error(), IsError: true}, nil 364 } 365 366 params := map[string]any{} 367 if pid > 0 { 368 params["pid"] = pid 369 } 370 if args.Query != "" { 371 params["query"] = args.Query 372 } 373 if args.Role != "" { 374 params["role"] = args.Role 375 } 376 if args.Identifier != "" { 377 params["identifier"] = args.Identifier 378 } 379 380 result, err := t.client.Call(ctx, "find", params) 381 if err != nil { 382 return agent.ToolResult{Content: fmt.Sprintf("find error: %v", err), IsError: true}, nil 383 } 384 385 outputJSON, _ := json.MarshalIndent(json.RawMessage(result), "", " ") 386 content := string(outputJSON) 387 if len(content) > 8000 { 388 content = content[:7900] + "\n... 
[truncated]" 389 } 390 return agent.ToolResult{Content: content}, nil 391 } 392 393 func (t *AccessibilityTool) annotate(ctx context.Context, args accessibilityArgs) (agent.ToolResult, error) { 394 pid, err := t.resolvePID(ctx, args.App) 395 if err != nil { 396 return agent.ToolResult{Content: err.Error(), IsError: true}, nil 397 } 398 399 params := map[string]any{} 400 if pid > 0 { 401 params["pid"] = pid 402 } 403 if len(args.Roles) > 0 { 404 params["roles"] = args.Roles 405 } 406 if args.MaxLabels > 0 { 407 params["max_labels"] = args.MaxLabels 408 } 409 410 result, err := t.client.Call(ctx, "annotate", params) 411 if err != nil { 412 return agent.ToolResult{Content: fmt.Sprintf("annotate error: %v", err), IsError: true}, nil 413 } 414 415 // Parse the annotation result 416 var annotateResult struct { 417 App string `json:"app"` 418 PID int `json:"pid"` 419 Window string `json:"window"` 420 Annotations []annotationEntry `json:"annotations"` 421 RefPaths map[string]map[string]string `json:"ref_paths"` 422 } 423 if err := json.Unmarshal(result, &annotateResult); err != nil { 424 return agent.ToolResult{Content: fmt.Sprintf("parse error: %v", err), IsError: true}, nil 425 } 426 427 // Store refs so the agent can click by ref after annotating 428 t.refs = make(map[string]refEntry) 429 t.lastPID = annotateResult.PID 430 for ref, entry := range annotateResult.RefPaths { 431 t.refs[ref] = refEntry{ 432 path: entry["path"], 433 role: entry["role"], 434 pid: annotateResult.PID, 435 } 436 } 437 438 // Build text index 439 lines := make([]string, 0, len(annotateResult.Annotations)+1) 440 lines = append(lines, fmt.Sprintf("App: %s | Window: %s | %d elements", annotateResult.App, annotateResult.Window, len(annotateResult.Annotations))) 441 for _, a := range annotateResult.Annotations { 442 title := a.Title 443 if title == "" { 444 title = "(untitled)" 445 } 446 lines = append(lines, fmt.Sprintf("[%d] ref=%s %s %q (%.0f, %.0f, %.0f x %.0f)", a.Label, a.Ref, a.Role, title, 
a.X, a.Y, a.Width, a.Height)) 447 } 448 content := strings.Join(lines, "\n") 449 450 // Take a screenshot and draw annotation markers on it 451 screenshotPath, imgBlock, captureErr := CaptureAndEncode(DefaultAPIWidth) 452 var images []agent.ImageBlock 453 if captureErr == nil { 454 // Get screen dimensions for coordinate mapping 455 screenW, screenH, dimErr := GetScreenDimensions() 456 if dimErr == nil && len(annotateResult.Annotations) > 0 { 457 annotatedBlock, annotErr := drawAnnotations(screenshotPath, annotateResult.Annotations, screenW, screenH) 458 if annotErr == nil { 459 imgBlock = annotatedBlock 460 } 461 } 462 images = append(images, imgBlock) 463 // Clean up original screenshot temp file 464 os.Remove(screenshotPath) 465 } 466 467 return agent.ToolResult{ 468 Content: content, 469 Images: images, 470 }, nil 471 } 472 473 func (t *AccessibilityTool) scroll(ctx context.Context, args accessibilityArgs) (agent.ToolResult, error) { 474 pid := t.lastPID 475 var path *string 476 if args.Ref != "" { 477 entry, err := t.lookupRef(args.Ref) 478 if err != nil { 479 return agent.ToolResult{Content: err.Error(), IsError: true}, nil 480 } 481 pid = entry.pid 482 path = &entry.path 483 } 484 485 params := map[string]any{ 486 "dx": args.DX, 487 "dy": args.DY, 488 } 489 if pid > 0 { 490 params["pid"] = pid 491 } 492 if path != nil { 493 params["path"] = *path 494 } 495 496 result, err := t.client.Call(ctx, "scroll", params) 497 if err != nil { 498 return agent.ToolResult{Content: fmt.Sprintf("scroll error: %v", err), IsError: true}, nil 499 } 500 501 var actionResult struct { 502 Result string `json:"result"` 503 Context *appContext `json:"context,omitempty"` 504 } 505 json.Unmarshal(result, &actionResult) 506 return agent.ToolResult{Content: actionResult.Result + formatContext(actionResult.Context)}, nil 507 } 508 509 type annotationEntry struct { 510 Label int `json:"label"` 511 Ref string `json:"ref"` 512 Role string `json:"role"` 513 Title string 
`json:"title,omitempty"` 514 X float64 `json:"x"` 515 Y float64 `json:"y"` 516 Width float64 `json:"width"` 517 Height float64 `json:"height"` 518 } 519 520 // drawAnnotations loads a screenshot image and draws numbered markers at each 521 // annotation's center position. Returns the annotated image as an ImageBlock. 522 func drawAnnotations(imgPath string, annotations []annotationEntry, screenW, screenH int) (agent.ImageBlock, error) { 523 f, err := os.Open(imgPath) 524 if err != nil { 525 return agent.ImageBlock{}, err 526 } 527 defer f.Close() 528 529 img, _, err := image.Decode(f) 530 if err != nil { 531 return agent.ImageBlock{}, err 532 } 533 534 bounds := img.Bounds() 535 annotated := image.NewRGBA(bounds) 536 draw.Draw(annotated, bounds, img, image.Point{}, draw.Src) 537 538 // Scale: screen coordinates -> image coordinates 539 scaleX := float64(bounds.Dx()) / float64(screenW) 540 scaleY := float64(bounds.Dy()) / float64(screenH) 541 542 for _, a := range annotations { 543 // Center of element in screen coords -> image coords 544 cx := int((a.X + a.Width/2) * scaleX) 545 cy := int((a.Y + a.Height/2) * scaleY) 546 drawMarker(annotated, cx, cy, a.Label) 547 } 548 549 // Write annotated image to a temp file 550 outFile, err := os.CreateTemp("", "shannon-annotated-*.png") 551 if err != nil { 552 return agent.ImageBlock{}, err 553 } 554 defer outFile.Close() 555 556 if err := png.Encode(outFile, annotated); err != nil { 557 os.Remove(outFile.Name()) 558 return agent.ImageBlock{}, err 559 } 560 561 block, err := EncodeImage(outFile.Name()) 562 os.Remove(outFile.Name()) // clean up temp file after encoding 563 if err != nil { 564 return agent.ImageBlock{}, err 565 } 566 return block, nil 567 } 568 569 // drawMarker draws a filled circle with a contrasting border at (x, y) on the image. 
//
// The fill is a slightly translucent red composited over the existing pixels
// (source-over). The 2px outer ring is opaque white, and the label number is
// drawn in white on top via drawLabelNumber. The radius defaults to 10px and
// grows when the label needs more room (three or more digits). Pixels outside
// img's bounds are skipped.
func drawMarker(img *image.RGBA, x, y, label int) {
	// The fill is non-premultiplied (NRGBA) and blended with draw.Over.
	// img.Set with color.RGBA{R: 255, ..., A: 230} would store an invalid
	// premultiplied pixel (R > A). The PNG encoder then un-premultiplies it
	// with uint8 overflow, producing a dark, off-hue, translucent marker.
	fill := image.NewUniform(color.NRGBA{R: 255, G: 50, B: 50, A: 230})
	white := color.RGBA{R: 255, G: 255, B: 255, A: 255}

	radius := 10
	// Keep the text 4px inside the edge (2px of fill plus the 2px ring).
	// Each glyph is 5px wide, with a 1px gap between glyphs.
	textW := len(fmt.Sprintf("%d", label))*6 - 1
	if need := textW/2 + 4; need > radius {
		radius = need
	}
	ring := float64(radius - 2) // pixels farther than this from the center form the ring

	bounds := img.Bounds()
	for dy := -radius; dy <= radius; dy++ {
		for dx := -radius; dx <= radius; dx++ {
			dist := math.Sqrt(float64(dx*dx + dy*dy))
			if dist > float64(radius) {
				continue
			}
			p := image.Pt(x+dx, y+dy)
			if !p.In(bounds) {
				continue
			}
			if dist > ring {
				img.SetRGBA(p.X, p.Y, white)
			} else {
				draw.Draw(img, image.Rect(p.X, p.Y, p.X+1, p.Y+1), fill, image.Point{}, draw.Over)
			}
		}
	}

	drawLabelNumber(img, x, y, label, white)
}

// digitPatterns is a 5x7 bitmap font for the digits 0-9. Each glyph is 7 row
// bytes, top to bottom. In each byte, bit 4 is the leftmost column and bit 0
// the rightmost.
var digitPatterns = [10][7]byte{
	{0x0E, 0x11, 0x13, 0x15, 0x19, 0x11, 0x0E}, // 0
	{0x04, 0x0C, 0x04, 0x04, 0x04, 0x04, 0x0E}, // 1
	{0x0E, 0x11, 0x01, 0x06, 0x08, 0x10, 0x1F}, // 2
	{0x0E, 0x11, 0x01, 0x06, 0x01, 0x11, 0x0E}, // 3
	{0x02, 0x06, 0x0A, 0x12, 0x1F, 0x02, 0x02}, // 4
	{0x1F, 0x10, 0x1E, 0x01, 0x01, 0x11, 0x0E}, // 5
	{0x06, 0x08, 0x10, 0x1E, 0x11, 0x11, 0x0E}, // 6
	{0x1F, 0x01, 0x02, 0x04, 0x08, 0x08, 0x08}, // 7
	{0x0E, 0x11, 0x11, 0x0E, 0x11, 0x11, 0x0E}, // 8
	{0x0E, 0x11, 0x11, 0x0F, 0x01, 0x02, 0x0C}, // 9
}

// drawLabelNumber renders num in col using the digitPatterns bitmap font,
// centered on (cx, cy). Glyphs are 5px wide with a 1px gap between them.
// Characters other than '0'-'9' (a leading '-') are not drawn but still
// occupy a glyph slot. Pixels outside img's bounds are clipped.
func drawLabelNumber(img *image.RGBA, cx, cy, num int, col color.RGBA) {
	s := fmt.Sprintf("%d", num)
	// len(s) glyphs of 5px plus len(s)-1 gaps of 1px. Counting only the inner
	// gaps keeps the text horizontally centered on cx.
	textW := len(s)*6 - 1
	startX := cx - textW/2
	startY := cy - 3 // 7 rows tall: cy-3 .. cy+3
	bounds := img.Bounds()

	for i := 0; i < len(s); i++ {
		ch := s[i]
		if ch < '0' || ch > '9' {
			continue
		}
		glyph := &digitPatterns[ch-'0']
		ox := startX + i*6
		for row, bits := range glyph {
			for colIdx := 0; colIdx < 5; colIdx++ {
				if bits&(0x10>>uint(colIdx)) == 0 {
					continue
				}
				p := image.Pt(ox+colIdx, startY+row)
				if p.In(bounds) {
					img.SetRGBA(p.X, p.Y, col)
				}
			}
		}
	}
}