/ extensions / protect-paths.ts
protect-paths.ts
  1  /**
  2   * Protect Paths Extension
  3   *
  4   * Standalone directory protection hooks that complement @aliou/pi-guardrails
  5   * (which handles .env files and dangerous command confirmation)
  6   *
  7   * This extension protects:
  8   * - .git/ directory contents (prevents repository corruption)
  9   * - node_modules/ directory contents (use package manager instead)
 10   * - Homebrew install/upgrade commands (remind to use project package manager)
 11   * - Broad delete commands (rm/rmdir/unlink)
 12   * - Piped shell execution (e.g. `curl ... | sh`)
 13   *
 14   * Bash command checks are AST-backed via just-bash parsing so nested
 15   * substitutions/functions/conditionals are inspected instead of regex-only matching
 16   *
 17   * Dependency note:
 18   * - For best results, install `just-bash` >= 2 (provides the bash AST parser export)
 19   * - If unavailable, this extension falls back to best-effort regex checks
 20   */
 21  
 22  import { resolve, sep } from "node:path";
 23  import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
 24  
 25  let parseBash: ((input: string) => any) | null = null;
 26  let justBashLoadPromise: Promise<void> | null = null;
 27  let justBashLoadDone = false;
 28  
 29  async function ensureJustBashLoaded(): Promise<void> {
 30      if (justBashLoadDone) return;
 31  
 32      if (!justBashLoadPromise) {
 33          justBashLoadPromise = import("just-bash")
 34              .then((mod: any) => {
 35                  parseBash = typeof mod?.parse === "function" ? mod.parse : null;
 36              })
 37              .catch(() => {
 38                  parseBash = null;
 39              })
 40              .finally(() => {
 41                  justBashLoadDone = true;
 42              });
 43      }
 44  
 45      await justBashLoadPromise;
 46  }
 47  
 48  let warnedAstUnavailable = false;
 49  function maybeWarnAstUnavailable(ctx: any): void {
 50      if (warnedAstUnavailable) return;
 51      if (parseBash) return;
 52      if (!ctx?.hasUI) return;
 53  
 54      warnedAstUnavailable = true;
 55      ctx.ui.notify(
 56          "protect-paths: just-bash >= 2 is not available; falling back to best-effort regex command checks",
 57          "warning",
 58      );
 59  }
 60  
 61  type BashInvocation = {
 62      pipelineIndex: number;
 63      pipelineLength: number;
 64      commandNameRaw: string;
 65      commandName: string;
 66      args: string[];
 67      effectiveCommandNameRaw: string;
 68      effectiveCommandName: string;
 69      effectiveArgs: string[];
 70      redirections: Array<{ operator: string; target: string }>;
 71  };
 72  
 73  type BashAnalysis = {
 74      parseError?: string;
 75      invocations: BashInvocation[];
 76  };
 77  
 78  const WRAPPER_COMMANDS = new Set(["command", "builtin", "exec", "nohup"]);
 79  
 80  function commandBaseName(value: string): string {
 81      const normalized = value.replace(/\\+/g, "/");
 82      const idx = normalized.lastIndexOf("/");
 83      const base = idx >= 0 ? normalized.slice(idx + 1) : normalized;
 84      return base.toLowerCase();
 85  }
 86  
 87  function partToText(part: any): string {
 88      if (!part || typeof part !== "object") return "";
 89  
 90      switch (part.type) {
 91          case "Literal":
 92          case "SingleQuoted":
 93          case "Escaped":
 94              return typeof part.value === "string" ? part.value : "";
 95          case "DoubleQuoted":
 96              return Array.isArray(part.parts) ? part.parts.map(partToText).join("") : "";
 97          case "Glob":
 98              return typeof part.pattern === "string" ? part.pattern : "";
 99          case "TildeExpansion":
100              return typeof part.user === "string" && part.user.length > 0 ? `~${part.user}` : "~";
101          case "BraceExpansion":
102              return "{...}";
103          case "ParameterExpansion":
104              return typeof part.parameter === "string" && part.parameter.length > 0
105                  ? "${" + part.parameter + "}"
106                  : "${}";
107          case "CommandSubstitution":
108              return "$(...)";
109          case "ProcessSubstitution":
110              return part.direction === "output" ? ">(...)" : "<(...)";
111          case "ArithmeticExpansion":
112              return "$((...))";
113          default:
114              return "";
115      }
116  }
117  
/**
 * Renders a whole AST word by concatenating the text of its parts.
 *
 * @param word AST word node; anything without a `parts` array yields "".
 * @returns The concatenated text of every part.
 */
function wordToText(word: any): string {
    const parts = word && typeof word === "object" ? word.parts : undefined;
    if (!Array.isArray(parts)) return "";

    let text = "";
    for (const part of parts) {
        text += partToText(part);
    }
    return text;
}
122  
/** Matches a `NAME=value` environment assignment token. */
const ENV_ASSIGNMENT_PATTERN = /^[A-Za-z_][A-Za-z0-9_]*=.*/;

/** Matches a `command` option cluster containing `-v`/`-V` (lookup only, nothing is executed). */
const COMMAND_LOOKUP_FLAG_PATTERN = /^-[A-Za-z]*[vV]/;

/** Describes how a wrapper command's own options are laid out before the wrapped command. */
type WrapperOptionSpec = {
    /** Short option letters that consume a value (`-u root` or `-uroot`). */
    shortWithValue: string;
    /** Long options that consume the next token when written without `=`. */
    longWithValue: ReadonlySet<string>;
    /** Whether `NAME=value` tokens may precede the wrapped command. */
    allowsAssignments: boolean;
};

/** Spec for wrappers whose options never take a value. */
const PLAIN_WRAPPER_SPEC: WrapperOptionSpec = {
    shortWithValue: "",
    longWithValue: new Set<string>(),
    allowsAssignments: false,
};

/** Wrappers whose options may take values; other `WRAPPER_COMMANDS` use `PLAIN_WRAPPER_SPEC`. */
const WRAPPER_OPTION_SPECS = new Map<string, WrapperOptionSpec>([
    [
        "sudo",
        {
            // -C num, -D dir, -R dir, -T timeout, -U user, -g group, -h host,
            // -p prompt, -r role, -t type, -u user
            shortWithValue: "CDRTUghprtu",
            longWithValue: new Set([
                "--close-from",
                "--chdir",
                "--chroot",
                "--command-timeout",
                "--other-user",
                "--group",
                "--host",
                "--prompt",
                "--role",
                "--type",
                "--user",
            ]),
            allowsAssignments: true,
        },
    ],
    [
        "env",
        {
            // -u NAME, -C DIR (GNU), -P altpath (BSD). `-S` is deliberately left as a
            // plain flag: its value is itself a command string, not something to skip.
            shortWithValue: "uCP",
            longWithValue: new Set(["--unset", "--chdir"]),
            allowsAssignments: true,
        },
    ],
    // exec [-cl] [-a name] command
    ["exec", { shortWithValue: "a", longWithValue: new Set<string>(), allowsAssignments: false }],
]);

/**
 * Finds where a wrapper's own options end.
 *
 * @returns Index in `args` of the wrapped command word (may be `>= args.length` when absent).
 */
function skipWrapperOptions(args: readonly string[], spec: WrapperOptionSpec): number {
    let idx = 0;

    while (idx < args.length) {
        const token = args[idx] ?? "";

        if (token === "--") return idx + 1;

        if (spec.allowsAssignments && ENV_ASSIGNMENT_PATTERN.test(token)) {
            idx += 1;
            continue;
        }

        if (!token.startsWith("-")) break;

        idx += 1;

        if (token.startsWith("--")) {
            // `--name=value` is a single token; `--name value` also consumes the next one.
            if (spec.longWithValue.has(token)) idx += 1;
            continue;
        }

        // Short option cluster: the first value-taking letter ends the cluster. Its
        // value is either the remainder of this token or, when the letter is last,
        // the following token.
        for (let i = 1; i < token.length; i += 1) {
            if (!spec.shortWithValue.includes(token.charAt(i))) continue;
            if (i === token.length - 1) idx += 1;
            break;
        }
    }

    return idx;
}

/**
 * Determines which command an invocation really runs once wrappers are removed.
 *
 * `sudo`, `env` and the commands in `WRAPPER_COMMANDS` run another command
 * passed as an argument. This function skips the wrapper's own options
 * (including options that take a value, e.g. `sudo -u root`) and any leading
 * `NAME=value` assignments, and repeats the process for nested wrappers such
 * as `sudo env FOO=1 nohup rm -rf x`.
 *
 * `command -v` / `command -V` only look a name up, so they are reported as
 * `command` itself rather than as the looked-up name.
 *
 * @param commandNameRaw Command word as written (may include a path).
 * @param args Textual arguments of the invocation.
 * @returns The effective command word, its lower-cased base name and the
 *          arguments that belong to it. When a wrapper has no wrapped command
 *          the name fields are "" and the arguments are empty.
 */
function resolveEffectiveCommand(commandNameRaw: string, args: string[]): {
    effectiveCommandNameRaw: string;
    effectiveCommandName: string;
    effectiveArgs: string[];
} {
    let currentRaw = commandNameRaw.trim();
    let currentArgs = args;

    // Terminates: every pass either returns or strictly shortens `currentArgs`;
    // with no arguments left the next name is "" which is not a wrapper.
    for (;;) {
        const currentBase = commandBaseName(currentRaw);
        const spec = WRAPPER_OPTION_SPECS.get(currentBase)
            ?? (WRAPPER_COMMANDS.has(currentBase) ? PLAIN_WRAPPER_SPEC : undefined);

        if (!spec) {
            return {
                effectiveCommandNameRaw: currentRaw,
                effectiveCommandName: currentBase,
                effectiveArgs: currentArgs,
            };
        }

        const commandIdx = skipWrapperOptions(currentArgs, spec);

        if (
            currentBase === "command"
            && currentArgs.slice(0, commandIdx).some((token) => COMMAND_LOOKUP_FLAG_PATTERN.test(token))
        ) {
            return {
                effectiveCommandNameRaw: currentRaw,
                effectiveCommandName: currentBase,
                effectiveArgs: currentArgs,
            };
        }

        currentRaw = currentArgs[commandIdx] ?? "";
        currentArgs = currentArgs.slice(commandIdx + 1);
    }
}
192  
/**
 * Reports every script embedded in a word via command or process substitution.
 *
 * Double-quoted parts are searched recursively; substitutions without a
 * `body` are ignored.
 *
 * @param word AST word (or any node with a `parts` array); other values are ignored.
 * @param collect Callback invoked with each substitution body, in source order.
 */
function collectNestedScriptsFromWord(word: any, collect: (script: any) => void): void {
    const parts = word && typeof word === "object" ? word.parts : undefined;
    if (!Array.isArray(parts)) return;

    for (const part of parts) {
        if (!part || typeof part !== "object") continue;

        switch (part.type) {
            case "DoubleQuoted":
                collectNestedScriptsFromWord(part, collect);
                break;
            case "CommandSubstitution":
            case "ProcessSubstitution":
                if (part.body) collect(part.body);
                break;
            default:
                break;
        }
    }
}
209  
/**
 * Parses a command line with just-bash and flattens it into the simple-command
 * invocations it contains.
 *
 * The walk covers pipelines, compound-command bodies/conditions/clauses and
 * every script embedded through command or process substitution in command
 * names, arguments, redirection targets and loop/case words.
 *
 * @param command Raw command line.
 * @returns The invocations found, or a non-empty `parseError` (with no
 *          invocations) when the parser is unavailable or throws. Callers
 *          treat a set `parseError` as "use the regex fallback".
 */
function analyzeBashScript(command: string): BashAnalysis {
    const parse = parseBash;
    if (!parse) {
        return { parseError: "just-bash parse unavailable", invocations: [] };
    }

    try {
        const invocations: BashInvocation[] = [];

        const visitScript = (script: any): void => {
            if (!script || typeof script !== "object" || !Array.isArray(script.statements)) return;

            for (const statement of script.statements) {
                if (!statement || typeof statement !== "object" || !Array.isArray(statement.pipelines)) continue;

                for (const pipeline of statement.pipelines) {
                    if (!pipeline || typeof pipeline !== "object" || !Array.isArray(pipeline.commands)) continue;

                    const pipelineLength = pipeline.commands.length;
                    for (const [pipelineIndex, commandNode] of pipeline.commands.entries()) {
                        visitCommand(commandNode, pipelineIndex, pipelineLength);
                    }
                }
            }
        };

        // Visits a statement list (compound-command body, condition, ...) as a script.
        const visitStatements = (statements: unknown): void => {
            if (Array.isArray(statements)) visitScript({ statements });
        };

        // Visits every script embedded in one word.
        const visitWord = (word: unknown): void => {
            collectNestedScriptsFromWord(word, visitScript);
        };

        const visitWords = (words: unknown): void => {
            if (!Array.isArray(words)) return;
            for (const word of words) visitWord(word);
        };

        // Redirection targets can carry substitutions too, e.g. `cat < <(cmd)` or `> "$(cmd)"`.
        const visitRedirections = (redirections: unknown): void => {
            if (!Array.isArray(redirections)) return;
            for (const redirection of redirections) {
                const target = redirection?.target;
                // NOTE(review): assumes a HereDoc target keeps its body as a word under
                // `content` — confirm against the just-bash AST. A missing field is ignored.
                visitWord(target?.type === "HereDoc" ? target.content : target);
            }
        };

        const visitCommand = (commandNode: any, pipelineIndex: number, pipelineLength: number): void => {
            if (!commandNode || typeof commandNode !== "object") return;

            if (commandNode.type === "SimpleCommand") {
                const commandNameRaw = wordToText(commandNode.name).trim();
                const args: string[] = Array.isArray(commandNode.args)
                    ? commandNode.args.map((arg: any) => wordToText(arg)).filter(Boolean)
                    : [];
                const redirections: BashInvocation["redirections"] = Array.isArray(commandNode.redirections)
                    ? commandNode.redirections.map((r: any) => ({
                        operator: typeof r?.operator === "string" ? r.operator : "",
                        target: r?.target?.type === "HereDoc" ? "heredoc" : wordToText(r?.target),
                    }))
                    : [];

                const effective = resolveEffectiveCommand(commandNameRaw, args);
                invocations.push({
                    pipelineIndex,
                    pipelineLength,
                    commandNameRaw,
                    commandName: commandBaseName(commandNameRaw),
                    args,
                    effectiveCommandNameRaw: effective.effectiveCommandNameRaw,
                    effectiveCommandName: effective.effectiveCommandName,
                    effectiveArgs: effective.effectiveArgs,
                    redirections,
                });

                visitWord(commandNode.name);
                visitWords(commandNode.args);
                visitRedirections(commandNode.redirections);

                // NOTE(review): assumes variable assignments (`FOO=$(cmd) bar`) are exposed
                // as `assignments[].value` / `assignments[].array` words — confirm against
                // the just-bash AST. Missing fields are ignored.
                if (Array.isArray(commandNode.assignments)) {
                    for (const assignment of commandNode.assignments) {
                        visitWord(assignment?.value);
                        visitWords(assignment?.array);
                    }
                }
                return;
            }

            // Compound commands: walk every nested statement list and word.
            const body = commandNode.body;
            if (Array.isArray(body)) {
                visitStatements(body);
            } else if (body && typeof body === "object") {
                // NOTE(review): presumably function definitions hold a single compound
                // command as their body rather than a statement list — confirm.
                visitCommand(body, 0, 1);
            }
            visitStatements(commandNode.condition);
            if (Array.isArray(commandNode.clauses)) {
                for (const clause of commandNode.clauses) {
                    visitStatements(clause?.condition);
                    visitStatements(clause?.body);
                }
            }
            visitStatements(commandNode.elseBody);
            if (Array.isArray(commandNode.items)) {
                for (const item of commandNode.items) {
                    visitStatements(item?.body);
                }
            }
            visitWord(commandNode.word);
            visitWords(commandNode.words);
            visitRedirections(commandNode.redirections);
        };

        visitScript(parse(command));
        return { invocations };
    } catch (error: unknown) {
        const message = (error as { message?: unknown } | null | undefined)?.message;
        // `parseError` must never be empty: callers test it for truthiness, so an
        // empty message would make a failed parse look like "nothing found".
        const parseError = typeof message === "string" && message.length > 0
            ? message
            : (String(error) || "unknown parse error");
        return { parseError, invocations: [] };
    }
}
301  
// ============================================================================
// Configuration
// ============================================================================

// Allow reading Pi's own node_modules when installed via Homebrew
const ALLOWED_NODE_MODULES_PREFIXES = [
    resolve("/opt/homebrew/lib/node_modules/@mariozechner/pi-coding-agent"),
];

// Shell interpreters flagged when they appear after the first stage of a pipeline
const SHELL_EXECUTABLES = new Set(["sh", "bash", "zsh", "dash", "ksh", "fish"]);
// Commands that require confirmation because they delete files or directories
const DELETE_EXECUTABLES = new Set(["rm", "rmdir", "unlink"]);
// `brew <action>` subcommands that are blocked
const BREW_ACTIONS = new Set(["install", "bundle", "upgrade", "reinstall"]);

// Find `.git` / `node_modules` path references inside a single command token.
// Group 2 is the reference, starting at the directory name. The negative
// lookahead requires the directory name to end there, so look-alikes such as
// `.gitignore`, `.github/` or `node_modules_old` are not reported. Matching is
// case-insensitive because the directories are reachable under any casing on
// case-insensitive filesystems.
const GIT_REF_REGEX = /(^|[^A-Za-z0-9._-])(\.git(?![A-Za-z0-9._-])(?:[\\/][^\s]*)?)/gi;
const NODE_MODULES_REF_REGEX = /(^|[^A-Za-z0-9._-])(node_modules(?![A-Za-z0-9._-])(?:[\\/][^\s]*)?)/gi;

// Regex fallback for parse failures
const BREW_INSTALL_PATTERNS = [
    /\bbrew\s+install\b/,
    /\bbrew\s+cask\s+install\b/,
    /\bbrew\s+bundle\b/,
    /\bbrew\s+upgrade\b/,
    /\bbrew\s+reinstall\b/,
];

// Tools that can read files (allowed to read from allowlisted node_modules)
const READ_TOOLS = ["read", "grep", "find", "ls"];

// Tools that can write/modify files (strict: no node_modules allowlist)
const WRITE_TOOLS = ["write", "edit"];

// ============================================================================
// Path checking
// ============================================================================

// Match a path that has `.git` / `node_modules` as a complete path segment.
// Case-insensitive for the same reason as the reference regexes above.
const GIT_DIR_PATTERN = /(?:^|[/\\])\.git(?:[/\\]|$)/i;
const NODE_MODULES_PATTERN = /(?:^|[/\\])node_modules(?:[/\\]|$)/i;
339  
/**
 * Tells whether a path is one of the allowlisted node_modules locations or lies inside one.
 *
 * @param filePath Path to check; it is resolved to an absolute path first.
 * @returns `true` when the resolved path equals an allowlisted prefix or is nested under it.
 */
function isAllowedNodeModulesPath(filePath: string): boolean {
    const absolute = resolve(filePath);

    for (const prefix of ALLOWED_NODE_MODULES_PREFIXES) {
        if (absolute === prefix) return true;
        // Require a separator after the prefix so sibling directories sharing the
        // same leading characters do not qualify.
        if (absolute.startsWith(`${prefix}${sep}`)) return true;
    }

    return false;
}
346  
/**
 * Decides whether access to a path must be blocked.
 *
 * Anything inside a `.git` directory is always protected. Anything inside a
 * `node_modules` directory is protected unless the caller permits reads and
 * the path is allowlisted.
 *
 * @param filePath Path to check; it is resolved to an absolute path first.
 * @param allowNodeModulesRead Whether allowlisted node_modules paths may be accessed.
 * @returns `true` when the path is protected.
 */
function isProtectedDirectory(filePath: string, allowNodeModulesRead: boolean): boolean {
    const absolute = resolve(filePath);

    if (GIT_DIR_PATTERN.test(absolute)) return true;
    if (!NODE_MODULES_PATTERN.test(absolute)) return false;

    const readPermitted = allowNodeModulesRead && isAllowedNodeModulesPath(absolute);
    return !readPermitted;
}
363  
/**
 * Builds the explanation returned when access to a protected path is blocked.
 *
 * The path is classified both as written and after resolution, because the
 * protection check itself works on the resolved path: a relative path such as
 * `HEAD` used from inside `.git` contains no `.git` segment as written but
 * still deserves the specific explanation.
 *
 * @param filePath Path as supplied by the tool call; shown verbatim in the message.
 * @returns A reason naming `.git` or `node_modules`, or a generic message otherwise.
 */
function getProtectionReason(filePath: string): string {
    const resolved = resolve(filePath);
    const matches = (pattern: RegExp): boolean => pattern.test(filePath) || pattern.test(resolved);

    if (matches(GIT_DIR_PATTERN)) {
        return `Accessing ${filePath} is not allowed. The .git directory is protected to prevent repository corruption.`;
    }
    if (matches(NODE_MODULES_PATTERN)) {
        return `Accessing ${filePath} is not allowed. The node_modules directory is protected. Use package manager commands to manage dependencies.`;
    }
    return `Path "${filePath}" is protected.`;
}
373  
/**
 * Reads the target path from a tool call's input.
 *
 * `file_path` takes precedence over `path`; the chosen value is converted to a string.
 *
 * @param input Tool input object.
 * @returns The path as a string, or "" when neither field is set.
 */
function extractPathFromInput(input: Record<string, unknown>): string {
    const candidate = input.file_path ?? input.path;
    if (candidate === undefined || candidate === null) return "";
    return String(candidate);
}
378  
/**
 * Adds the second capture group of every match of `regex` in `token` to `refs`.
 *
 * Captures are trimmed; empty or unmatched captures are skipped.
 *
 * @param refs Set that receives the captured references.
 * @param token Text to search.
 * @param regex Global regular expression whose group 2 holds the reference.
 */
function appendMatches(refs: Set<string>, token: string, regex: RegExp): void {
    regex.lastIndex = 0;

    const hits = Array.from(token.matchAll(regex));
    hits.forEach((hit) => {
        const group = hit[2];
        if (typeof group !== "string") return;

        const reference = group.trim();
        if (reference) refs.add(reference);
    });
}
387  
/**
 * Lists the `.git` / `node_modules` path references that appear in a command.
 *
 * With a parsed script, every invocation's command word, arguments and
 * redirection targets are scanned. When parsing is unavailable or fails, the
 * raw command text is scanned with the legacy regular expressions instead.
 *
 * @param command Raw command line.
 * @returns Distinct references in the order they were found.
 */
function extractProtectedDirRefsFromCommand(command: string): string[] {
    const found = new Set<string>();
    const { parseError, invocations } = analyzeBashScript(command);

    if (parseError) {
        // Fallback: keep prior regex behavior if parser fails
        const fallbackPatterns = [
            /(?:^|\s|[<>|;&"'`])([^\s<>|;&"'`]*\.git[/\\][^\s<>|;&"'`]*)((?:\s|$|[<>|;&"'`]))/gi,
            /(?:^|\s|[<>|;&"'`])([^\s<>|;&"'`]*node_modules[/\\][^\s<>|;&"'`]*)((?:\s|$|[<>|;&"'`]))/gi,
        ];

        for (const pattern of fallbackPatterns) {
            for (const hit of command.matchAll(pattern)) {
                if (hit[1]) found.add(hit[1]);
            }
        }

        return Array.from(found);
    }

    for (const invocation of invocations) {
        const candidates: string[] = [
            invocation.commandNameRaw,
            invocation.effectiveCommandNameRaw,
            ...invocation.args,
            ...invocation.effectiveArgs,
        ];
        for (const redirection of invocation.redirections) {
            candidates.push(redirection.target);
        }

        for (const token of candidates) {
            if (typeof token !== "string" || token.length === 0) continue;
            appendMatches(found, token, GIT_REF_REGEX);
            appendMatches(found, token, NODE_MODULES_REF_REGEX);
        }
    }

    return Array.from(found);
}
424  
/**
 * Detects Homebrew commands that install or upgrade packages.
 *
 * With a parsed script, an invocation counts when its effective command is
 * `brew` and its first argument is a blocked action, or when it is
 * `brew cask install`. Otherwise the raw text is matched against the fallback
 * patterns.
 *
 * @param command Raw command line.
 * @returns `true` when a blocked brew command is present.
 */
function isBrewInstallOrUpgrade(command: string): boolean {
    const { parseError, invocations } = analyzeBashScript(command);

    if (parseError) {
        return BREW_INSTALL_PATTERNS.some((pattern) => pattern.test(command));
    }

    return invocations.some((invocation) => {
        if (invocation.effectiveCommandName !== "brew") return false;

        const [rawAction = "", rawTarget = ""] = invocation.effectiveArgs;
        const action = rawAction.toLowerCase();

        if (BREW_ACTIONS.has(action)) return true;
        return action === "cask" && rawTarget.toLowerCase() === "install";
    });
}
450  
/**
 * Detects commands that should be confirmed by the user before running.
 *
 * Two kinds are reported: delete commands (`rm`, `rmdir`, `unlink`) and a
 * shell interpreter receiving piped input (e.g. `curl ... | sh`). With a
 * parsed script the check uses each invocation's effective command; when
 * parsing is unavailable or fails, best-effort regular expressions are applied
 * to the raw text.
 *
 * @param command Raw command line.
 * @returns The kind of danger and, when known, the offending command word;
 *          `null` when nothing dangerous was found.
 */
function detectDangerousCommand(command: string): { kind: "delete" | "piped shell"; commandName?: string } | null {
    const analysis = analyzeBashScript(command);

    if (!analysis.parseError) {
        const deleteMatch = analysis.invocations.find((invocation) => DELETE_EXECUTABLES.has(invocation.effectiveCommandName));
        if (deleteMatch) {
            return {
                kind: "delete",
                commandName: deleteMatch.effectiveCommandNameRaw || deleteMatch.commandNameRaw,
            };
        }

        const pipedShellMatch = analysis.invocations.find(
            (invocation) =>
                invocation.pipelineLength > 1
                && invocation.pipelineIndex > 0
                && SHELL_EXECUTABLES.has(invocation.effectiveCommandName),
        );
        if (pipedShellMatch) {
            return {
                kind: "piped shell",
                commandName: pipedShellMatch.effectiveCommandNameRaw || pipedShellMatch.commandNameRaw,
            };
        }

        return null;
    }

    // Fallback for parser failures.
    // Covers the same executables as the AST path. The lookbehind rejects a
    // preceding word character or dash so flags such as `--rm` do not count.
    const deleteFallback = /(?<![\w-])(rmdir|rm|unlink)\s+/.exec(command);
    if (deleteFallback) {
        return { kind: "delete", commandName: deleteFallback[1] ?? "rm" };
    }

    // A single `|` (or `|&`) followed by a shell, optionally behind `sudo` and/or
    // a path prefix. The lookbehind keeps the logical OR `||` from being read as a pipe.
    const pipeFallback = /(?<!\|)\|&?\s*(?:sudo\s+)?(?:[^\s|;&]*\/)?(sh|bash|zsh|dash|ksh|fish)\b/.exec(command);
    if (pipeFallback) {
        return { kind: "piped shell", commandName: pipeFallback[1] ?? "sh" };
    }

    return null;
}
490  
491  // ============================================================================
492  // Extension
493  // ============================================================================
494  
/**
 * Registers the protect-paths `tool_call` hooks on the extension API.
 *
 * Four independent handlers are installed:
 * 1. file tools (read/grep/find/ls/write/edit) are blocked on protected paths;
 * 2. bash commands referencing protected paths are blocked;
 * 3. Homebrew install/upgrade commands are blocked;
 * 4. delete commands and piped shell execution require user confirmation.
 *
 * UI calls are only made when the context reports a UI, mirroring the check
 * used for the parser-unavailable warning.
 *
 * @param pi Extension API used to subscribe to tool calls.
 */
export default function (pi: ExtensionAPI) {
    // --- Directory protection for file-oriented tools ---
    pi.on("tool_call", async (event, ctx) => {
        const isReadTool = READ_TOOLS.includes(event.toolName);
        const isWriteTool = WRITE_TOOLS.includes(event.toolName);
        if (!isReadTool && !isWriteTool) return;

        const filePath = extractPathFromInput(event.input);
        if (!filePath) return;

        // Only read tools may use the node_modules allowlist.
        const allowNodeModulesRead = isReadTool;
        if (isProtectedDirectory(filePath, allowNodeModulesRead)) {
            if (ctx.hasUI) {
                ctx.ui.notify(`Blocked access to protected path: ${filePath}`, "warning");
            }
            return {
                block: true,
                reason: getProtectionReason(filePath),
            };
        }
        return;
    });

    // --- Directory protection for bash commands ---
    pi.on("tool_call", async (event, ctx) => {
        if (event.toolName !== "bash") return;

        await ensureJustBashLoaded();
        maybeWarnAstUnavailable(ctx);

        const command = String(event.input.command ?? "");
        const refs = extractProtectedDirRefsFromCommand(command);

        for (const ref of refs) {
            if (isProtectedDirectory(ref, false)) {
                if (ctx.hasUI) {
                    ctx.ui.notify(`Blocked access to protected path: ${ref}`, "warning");
                }
                return {
                    block: true,
                    reason: `Command references protected path ${ref}. ${getProtectionReason(ref)}`,
                };
            }
        }
        return;
    });

    // --- Prevent Homebrew install/upgrade ---
    pi.on("tool_call", async (event, ctx) => {
        if (event.toolName !== "bash") return;

        await ensureJustBashLoaded();
        maybeWarnAstUnavailable(ctx);

        const command = String(event.input.command ?? "");

        if (isBrewInstallOrUpgrade(command)) {
            if (ctx.hasUI) {
                ctx.ui.notify("Blocked brew command. Use the project's package manager instead.", "warning");
            }
            return {
                block: true,
                reason: "Homebrew install/upgrade commands are blocked. Please use the project's package manager (npm, pnpm, bun, nix, etc.) instead.",
            };
        }

        return;
    });

    // --- Extra permission gates (confirm, not hard block) ---
    // These complement upstream @aliou/pi-guardrails which covers rm -rf, sudo,
    // dd, mkfs, chmod -R 777, chown -R via AST structural matching.
    pi.on("tool_call", async (event, ctx) => {
        if (event.toolName !== "bash") return;

        await ensureJustBashLoaded();
        maybeWarnAstUnavailable(ctx);

        const command = String(event.input.command ?? "");
        const danger = detectDangerousCommand(command);
        if (!danger) return;

        // Without a UI nobody can answer the prompt, so refuse with an explicit
        // reason instead of reporting a denial the user never made.
        if (!ctx.hasUI) {
            return {
                block: true,
                reason: `Dangerous command blocked (${danger.kind}): no UI available to ask for confirmation`,
            };
        }

        const truncatedCmd = command.length > 80
            ? `${command.substring(0, 80)}...`
            : command;

        const proceed = await ctx.ui.confirm(
            "Dangerous Command Detected",
            `This command contains ${danger.kind}${danger.commandName ? ` (${danger.commandName})` : ""}:\n\n${truncatedCmd}\n\nAllow execution?`,
        );

        if (!proceed) {
            return { block: true, reason: "User denied dangerous command" };
        }

        return;
    });
}