/ internal / daemon / safeguard_test.go
safeguard_test.go
  1  package daemon
  2  
  3  import (
  4  	"testing"
  5  )
  6  
  7  // --- requireConfirm ---
  8  
  9  func TestRequireConfirm(t *testing.T) {
 10  	cases := []struct {
 11  		input string
 12  		want  bool
 13  	}{
 14  		{"", true},
 15  		{"false", true},
 16  		{"1", true},
 17  		{"yes", true},
 18  		{"True", true},
 19  		{"true", false},
 20  	}
 21  	for _, c := range cases {
 22  		got := requireConfirm(c.input)
 23  		if got != c.want {
 24  			t.Errorf("requireConfirm(%q) = %v, want %v", c.input, got, c.want)
 25  		}
 26  	}
 27  }
 28  
 29  // --- checkProtectedFields ---
 30  
 31  func TestCheckProtectedFields_Safe(t *testing.T) {
 32  	patch := map[string]interface{}{
 33  		"model":   "claude-3-7",
 34  		"timeout": 30,
 35  	}
 36  	reason, isProtected := checkProtectedFields(patch)
 37  	if isProtected {
 38  		t.Errorf("expected no protected field, got reason=%q", reason)
 39  	}
 40  }
 41  
 42  func TestCheckProtectedFields_Endpoint(t *testing.T) {
 43  	patch := map[string]interface{}{
 44  		"endpoint": "https://evil.example.com",
 45  	}
 46  	reason, isProtected := checkProtectedFields(patch)
 47  	if !isProtected {
 48  		t.Fatal("expected isProtected=true for endpoint")
 49  	}
 50  	if reason != "changes API connection target" {
 51  		t.Errorf("unexpected reason: %q", reason)
 52  	}
 53  }
 54  
 55  func TestCheckProtectedFields_APIKey(t *testing.T) {
 56  	patch := map[string]interface{}{
 57  		"api_key": "sk-leaked",
 58  	}
 59  	reason, isProtected := checkProtectedFields(patch)
 60  	if !isProtected {
 61  		t.Fatal("expected isProtected=true for api_key")
 62  	}
 63  	if reason != "changes authentication credentials" {
 64  		t.Errorf("unexpected reason: %q", reason)
 65  	}
 66  }
 67  
 68  func TestCheckProtectedFields_PermissionsDeniedCommands(t *testing.T) {
 69  	patch := map[string]interface{}{
 70  		"permissions": map[string]interface{}{
 71  			"denied_commands": []string{"rm"},
 72  		},
 73  	}
 74  	reason, isProtected := checkProtectedFields(patch)
 75  	if !isProtected {
 76  		t.Fatal("expected isProtected=true for permissions.denied_commands")
 77  	}
 78  	if reason != "removes security restrictions" {
 79  		t.Errorf("unexpected reason: %q", reason)
 80  	}
 81  }
 82  
 83  func TestCheckProtectedFields_DaemonAutoApprove(t *testing.T) {
 84  	patch := map[string]interface{}{
 85  		"daemon": map[string]interface{}{
 86  			"auto_approve": true,
 87  		},
 88  	}
 89  	if _, isProtected := checkProtectedFields(patch); isProtected {
 90  		t.Fatal("daemon.auto_approve should be settable via API")
 91  	}
 92  }
 93  
 94  func TestCheckProtectedFields_NestedParentNotMap(t *testing.T) {
 95  	// If the parent key exists but isn't a map, it shouldn't panic or false-positive
 96  	patch := map[string]interface{}{
 97  		"permissions": "some string value",
 98  	}
 99  	_, isProtected := checkProtectedFields(patch)
100  	if isProtected {
101  		t.Error("expected isProtected=false when parent value is not a map")
102  	}
103  }
104  
105  // --- validateMCPCommands ---
106  
107  func TestValidateMCPCommands_SafeNpx(t *testing.T) {
108  	servers := map[string]interface{}{
109  		"myserver": map[string]interface{}{
110  			"command": "npx",
111  			"args":    []string{"-y", "some-mcp-server"},
112  		},
113  	}
114  	if err := validateMCPCommands(servers, false); err != nil {
115  		t.Errorf("expected nil for safe npx, got: %v", err)
116  	}
117  }
118  
119  func TestValidateMCPCommands_AbsolutePath(t *testing.T) {
120  	servers := map[string]interface{}{
121  		"myserver": map[string]interface{}{
122  			"command": "/usr/local/bin/mcp-server",
123  		},
124  	}
125  	if err := validateMCPCommands(servers, false); err != nil {
126  		t.Errorf("expected nil for absolute path, got: %v", err)
127  	}
128  }
129  
130  func TestValidateMCPCommands_HTTPTypeSkipped(t *testing.T) {
131  	servers := map[string]interface{}{
132  		"remote": map[string]interface{}{
133  			"type":    "http",
134  			"command": "rm; evil",
135  		},
136  	}
137  	if err := validateMCPCommands(servers, false); err != nil {
138  		t.Errorf("expected nil for non-stdio type, got: %v", err)
139  	}
140  }
141  
142  func TestValidateMCPCommands_ShellMetachar_Semicolon(t *testing.T) {
143  	servers := map[string]interface{}{
144  		"evil": map[string]interface{}{
145  			"command": "node; rm -rf /",
146  		},
147  	}
148  	if err := validateMCPCommands(servers, false); err == nil {
149  		t.Error("expected error for shell metachar (semicolon), got nil")
150  	}
151  }
152  
153  func TestValidateMCPCommands_ShellMetachar_Pipe(t *testing.T) {
154  	servers := map[string]interface{}{
155  		"evil": map[string]interface{}{
156  			"command": "node|cat",
157  		},
158  	}
159  	if err := validateMCPCommands(servers, false); err == nil {
160  		t.Error("expected error for shell metachar (pipe), got nil")
161  	}
162  }
163  
164  func TestValidateMCPCommands_UnknownCommand_NoConfirm(t *testing.T) {
165  	servers := map[string]interface{}{
166  		"custom": map[string]interface{}{
167  			"command": "my-custom-mcp-server",
168  		},
169  	}
170  	if err := validateMCPCommands(servers, false); err == nil {
171  		t.Error("expected error for unknown command without confirm, got nil")
172  	}
173  }
174  
175  func TestValidateMCPCommands_UnknownCommand_WithConfirm(t *testing.T) {
176  	servers := map[string]interface{}{
177  		"custom": map[string]interface{}{
178  			"command": "my-custom-mcp-server",
179  		},
180  	}
181  	if err := validateMCPCommands(servers, true); err != nil {
182  		t.Errorf("expected nil for unknown command with confirm, got: %v", err)
183  	}
184  }
185  
186  func TestValidateMCPCommands_MetacharAlwaysBlocked_EvenWithConfirm(t *testing.T) {
187  	servers := map[string]interface{}{
188  		"evil": map[string]interface{}{
189  			"command": "node$(evil)",
190  		},
191  	}
192  	if err := validateMCPCommands(servers, true); err == nil {
193  		t.Error("expected error for metachar even with confirm=true, got nil")
194  	}
195  }
196  
197  func TestValidateMCPCommands_NoCommand(t *testing.T) {
198  	// Servers without "command" field should be skipped
199  	servers := map[string]interface{}{
200  		"nocommand": map[string]interface{}{
201  			"url": "http://localhost:3000",
202  		},
203  	}
204  	if err := validateMCPCommands(servers, false); err != nil {
205  		t.Errorf("expected nil for server without command, got: %v", err)
206  	}
207  }
208  
209  func TestValidateMCPCommands_AllSafeCommands(t *testing.T) {
210  	safe := []string{"node", "npx", "python", "python3", "uvx", "uv", "go", "deno", "bun", "docker", "pip", "pipx"}
211  	for _, cmd := range safe {
212  		servers := map[string]interface{}{
213  			"s": map[string]interface{}{"command": cmd},
214  		}
215  		if err := validateMCPCommands(servers, false); err != nil {
216  			t.Errorf("expected nil for safe command %q, got: %v", cmd, err)
217  		}
218  	}
219  }
220  
221  func TestValidateMCPCommands_ShellBlocked(t *testing.T) {
222  	shells := []string{"sh", "bash", "zsh", "fish", "/bin/sh", "/bin/bash", "/usr/bin/zsh"}
223  	for _, cmd := range shells {
224  		servers := map[string]interface{}{
225  			"s": map[string]interface{}{"command": cmd},
226  		}
227  		if err := validateMCPCommands(servers, true); err == nil {
228  			t.Errorf("expected error for shell %q even with confirm, got nil", cmd)
229  		}
230  	}
231  }
232  
233  func TestValidateMCPCommands_EvalFlagBlocked(t *testing.T) {
234  	cases := []struct {
235  		cmd  string
236  		args []interface{}
237  	}{
238  		{"python", []interface{}{"-c", "print('hi')"}},
239  		{"node", []interface{}{"--eval", "console.log('hi')"}},
240  		{"python3", []interface{}{"-e", "print('hi')"}},
241  	}
242  	for _, c := range cases {
243  		servers := map[string]interface{}{
244  			"s": map[string]interface{}{"command": c.cmd, "args": c.args},
245  		}
246  		if err := validateMCPCommands(servers, true); err == nil {
247  			t.Errorf("expected error for %q with eval args %v even with confirm, got nil", c.cmd, c.args)
248  		}
249  	}
250  }
251  
252  func TestValidateMCPCommands_SafeCommandWithNormalArgs(t *testing.T) {
253  	servers := map[string]interface{}{
254  		"s": map[string]interface{}{
255  			"command": "python",
256  			"args":    []interface{}{"-m", "my_mcp_server", "--port", "3000"},
257  		},
258  	}
259  	if err := validateMCPCommands(servers, false); err != nil {
260  		t.Errorf("expected nil for python -m (no eval flag), got: %v", err)
261  	}
262  }
263  
264  func TestCheckProtectedFields_AliasNormalized(t *testing.T) {
265  	// Verify that after normalizePatchKeys, aliases are caught
266  	patch := map[string]interface{}{
267  		"apiKey": "sk-test",
268  	}
269  	normalizePatchKeys(patch)
270  	reason, isProtected := checkProtectedFields(patch)
271  	if !isProtected {
272  		t.Fatal("expected isProtected=true for aliased apiKey after normalization")
273  	}
274  	if reason != "changes authentication credentials" {
275  		t.Errorf("unexpected reason: %q", reason)
276  	}
277  }
278  
279  func TestValidateMCPCommands_WrapperBlocked(t *testing.T) {
280  	wrappers := []string{"env", "nohup", "sudo", "/usr/bin/env", "/usr/bin/sudo"}
281  	for _, cmd := range wrappers {
282  		servers := map[string]interface{}{
283  			"s": map[string]interface{}{"command": cmd, "args": []interface{}{"node", "server.js"}},
284  		}
285  		if err := validateMCPCommands(servers, true); err == nil {
286  			t.Errorf("expected error for wrapper %q even with confirm, got nil", cmd)
287  		}
288  	}
289  }
290  
291  func TestValidateMCPCommands_ShellInArgsBlocked(t *testing.T) {
292  	servers := map[string]interface{}{
293  		"s": map[string]interface{}{
294  			"command": "python",
295  			"args":    []interface{}{"bash", "-lc", "echo hi"},
296  		},
297  	}
298  	if err := validateMCPCommands(servers, true); err == nil {
299  		t.Error("expected error for shell in args, got nil")
300  	}
301  }
302  
303  func TestCheckProtectedFields_MCPServersAliasNormalized(t *testing.T) {
304  	patch := map[string]interface{}{
305  		"mcpServers": map[string]interface{}{
306  			"test": map[string]interface{}{
307  				"command": "node",
308  			},
309  		},
310  	}
311  	normalizePatchKeys(patch)
312  	if _, ok := patch["mcp_servers"]; !ok {
313  		t.Fatal("expected mcp_servers key after normalization")
314  	}
315  	if _, ok := patch["mcpServers"]; ok {
316  		t.Fatal("expected mcpServers alias to be removed after normalization")
317  	}
318  }