safeguard_test.go
1 package daemon 2 3 import ( 4 "testing" 5 ) 6 7 // --- requireConfirm --- 8 9 func TestRequireConfirm(t *testing.T) { 10 cases := []struct { 11 input string 12 want bool 13 }{ 14 {"", true}, 15 {"false", true}, 16 {"1", true}, 17 {"yes", true}, 18 {"True", true}, 19 {"true", false}, 20 } 21 for _, c := range cases { 22 got := requireConfirm(c.input) 23 if got != c.want { 24 t.Errorf("requireConfirm(%q) = %v, want %v", c.input, got, c.want) 25 } 26 } 27 } 28 29 // --- checkProtectedFields --- 30 31 func TestCheckProtectedFields_Safe(t *testing.T) { 32 patch := map[string]interface{}{ 33 "model": "claude-3-7", 34 "timeout": 30, 35 } 36 reason, isProtected := checkProtectedFields(patch) 37 if isProtected { 38 t.Errorf("expected no protected field, got reason=%q", reason) 39 } 40 } 41 42 func TestCheckProtectedFields_Endpoint(t *testing.T) { 43 patch := map[string]interface{}{ 44 "endpoint": "https://evil.example.com", 45 } 46 reason, isProtected := checkProtectedFields(patch) 47 if !isProtected { 48 t.Fatal("expected isProtected=true for endpoint") 49 } 50 if reason != "changes API connection target" { 51 t.Errorf("unexpected reason: %q", reason) 52 } 53 } 54 55 func TestCheckProtectedFields_APIKey(t *testing.T) { 56 patch := map[string]interface{}{ 57 "api_key": "sk-leaked", 58 } 59 reason, isProtected := checkProtectedFields(patch) 60 if !isProtected { 61 t.Fatal("expected isProtected=true for api_key") 62 } 63 if reason != "changes authentication credentials" { 64 t.Errorf("unexpected reason: %q", reason) 65 } 66 } 67 68 func TestCheckProtectedFields_PermissionsDeniedCommands(t *testing.T) { 69 patch := map[string]interface{}{ 70 "permissions": map[string]interface{}{ 71 "denied_commands": []string{"rm"}, 72 }, 73 } 74 reason, isProtected := checkProtectedFields(patch) 75 if !isProtected { 76 t.Fatal("expected isProtected=true for permissions.denied_commands") 77 } 78 if reason != "removes security restrictions" { 79 t.Errorf("unexpected reason: %q", reason) 80 } 81 } 82 83 func TestCheckProtectedFields_DaemonAutoApprove(t *testing.T) { 84 patch := map[string]interface{}{ 85 "daemon": map[string]interface{}{ 86 "auto_approve": true, 87 }, 88 } 89 if _, isProtected := checkProtectedFields(patch); isProtected { 90 t.Fatal("daemon.auto_approve should be settable via API") 91 } 92 } 93 94 func TestCheckProtectedFields_NestedParentNotMap(t *testing.T) { 95 // If the parent key exists but isn't a map, it shouldn't panic or false-positive 96 patch := map[string]interface{}{ 97 "permissions": "some string value", 98 } 99 _, isProtected := checkProtectedFields(patch) 100 if isProtected { 101 t.Error("expected isProtected=false when parent value is not a map") 102 } 103 } 104 105 // --- validateMCPCommands --- 106 107 func TestValidateMCPCommands_SafeNpx(t *testing.T) { 108 servers := map[string]interface{}{ 109 "myserver": map[string]interface{}{ 110 "command": "npx", 111 "args": []string{"-y", "some-mcp-server"}, 112 }, 113 } 114 if err := validateMCPCommands(servers, false); err != nil { 115 t.Errorf("expected nil for safe npx, got: %v", err) 116 } 117 } 118 119 func TestValidateMCPCommands_AbsolutePath(t *testing.T) { 120 servers := map[string]interface{}{ 121 "myserver": map[string]interface{}{ 122 "command": "/usr/local/bin/mcp-server", 123 }, 124 } 125 if err := validateMCPCommands(servers, false); err != nil { 126 t.Errorf("expected nil for absolute path, got: %v", err) 127 } 128 } 129 130 func TestValidateMCPCommands_HTTPTypeSkipped(t *testing.T) { 131 servers := map[string]interface{}{ 132 "remote": map[string]interface{}{ 133 "type": "http", 134 "command": "rm; evil", 135 }, 136 } 137 if err := validateMCPCommands(servers, false); err != nil { 138 t.Errorf("expected nil for non-stdio type, got: %v", err) 139 } 140 } 141 142 func TestValidateMCPCommands_ShellMetachar_Semicolon(t *testing.T) { 143 servers := map[string]interface{}{ 144 "evil": map[string]interface{}{ 145 "command": "node; rm -rf /", 146 }, 147 } 148 if err := validateMCPCommands(servers, false); err == nil { 149 t.Error("expected error for shell metachar (semicolon), got nil") 150 } 151 } 152 153 func TestValidateMCPCommands_ShellMetachar_Pipe(t *testing.T) { 154 servers := map[string]interface{}{ 155 "evil": map[string]interface{}{ 156 "command": "node|cat", 157 }, 158 } 159 if err := validateMCPCommands(servers, false); err == nil { 160 t.Error("expected error for shell metachar (pipe), got nil") 161 } 162 } 163 164 func TestValidateMCPCommands_UnknownCommand_NoConfirm(t *testing.T) { 165 servers := map[string]interface{}{ 166 "custom": map[string]interface{}{ 167 "command": "my-custom-mcp-server", 168 }, 169 } 170 if err := validateMCPCommands(servers, false); err == nil { 171 t.Error("expected error for unknown command without confirm, got nil") 172 } 173 } 174 175 func TestValidateMCPCommands_UnknownCommand_WithConfirm(t *testing.T) { 176 servers := map[string]interface{}{ 177 "custom": map[string]interface{}{ 178 "command": "my-custom-mcp-server", 179 }, 180 } 181 if err := validateMCPCommands(servers, true); err != nil { 182 t.Errorf("expected nil for unknown command with confirm, got: %v", err) 183 } 184 } 185 186 func TestValidateMCPCommands_MetacharAlwaysBlocked_EvenWithConfirm(t *testing.T) { 187 servers := map[string]interface{}{ 188 "evil": map[string]interface{}{ 189 "command": "node$(evil)", 190 }, 191 } 192 if err := validateMCPCommands(servers, true); err == nil { 193 t.Error("expected error for metachar even with confirm=true, got nil") 194 } 195 } 196 197 func TestValidateMCPCommands_NoCommand(t *testing.T) { 198 // Servers without "command" field should be skipped 199 servers := map[string]interface{}{ 200 "nocommand": map[string]interface{}{ 201 "url": "http://localhost:3000", 202 }, 203 } 204 if err := validateMCPCommands(servers, false); err != nil { 205 t.Errorf("expected nil for server without command, got: %v", err) 206 } 207 } 208 209 func TestValidateMCPCommands_AllSafeCommands(t *testing.T) { 210 safe := []string{"node", "npx", "python", "python3", "uvx", "uv", "go", "deno", "bun", "docker", "pip", "pipx"} 211 for _, cmd := range safe { 212 servers := map[string]interface{}{ 213 "s": map[string]interface{}{"command": cmd}, 214 } 215 if err := validateMCPCommands(servers, false); err != nil { 216 t.Errorf("expected nil for safe command %q, got: %v", cmd, err) 217 } 218 } 219 } 220 221 func TestValidateMCPCommands_ShellBlocked(t *testing.T) { 222 shells := []string{"sh", "bash", "zsh", "fish", "/bin/sh", "/bin/bash", "/usr/bin/zsh"} 223 for _, cmd := range shells { 224 servers := map[string]interface{}{ 225 "s": map[string]interface{}{"command": cmd}, 226 } 227 if err := validateMCPCommands(servers, true); err == nil { 228 t.Errorf("expected error for shell %q even with confirm, got nil", cmd) 229 } 230 } 231 } 232 233 func TestValidateMCPCommands_EvalFlagBlocked(t *testing.T) { 234 cases := []struct { 235 cmd string 236 args []interface{} 237 }{ 238 {"python", []interface{}{"-c", "print('hi')"}}, 239 {"node", []interface{}{"--eval", "console.log('hi')"}}, 240 {"python3", []interface{}{"-e", "print('hi')"}}, 241 } 242 for _, c := range cases { 243 servers := map[string]interface{}{ 244 "s": map[string]interface{}{"command": c.cmd, "args": c.args}, 245 } 246 if err := validateMCPCommands(servers, true); err == nil { 247 t.Errorf("expected error for %q with eval args %v even with confirm, got nil", c.cmd, c.args) 248 } 249 } 250 } 251 252 func TestValidateMCPCommands_SafeCommandWithNormalArgs(t *testing.T) { 253 servers := map[string]interface{}{ 254 "s": map[string]interface{}{ 255 "command": "python", 256 "args": []interface{}{"-m", "my_mcp_server", "--port", "3000"}, 257 }, 258 } 259 if err := validateMCPCommands(servers, false); err != nil { 260 t.Errorf("expected nil for python -m (no eval flag), got: %v", err) 261 } 262 } 263 264 func TestCheckProtectedFields_AliasNormalized(t *testing.T) { 265 // Verify that after normalizePatchKeys, aliases are caught 266 patch := map[string]interface{}{ 267 "apiKey": "sk-test", 268 } 269 normalizePatchKeys(patch) 270 reason, isProtected := checkProtectedFields(patch) 271 if !isProtected { 272 t.Fatal("expected isProtected=true for aliased apiKey after normalization") 273 } 274 if reason != "changes authentication credentials" { 275 t.Errorf("unexpected reason: %q", reason) 276 } 277 } 278 279 func TestValidateMCPCommands_WrapperBlocked(t *testing.T) { 280 wrappers := []string{"env", "nohup", "sudo", "/usr/bin/env", "/usr/bin/sudo"} 281 for _, cmd := range wrappers { 282 servers := map[string]interface{}{ 283 "s": map[string]interface{}{"command": cmd, "args": []interface{}{"node", "server.js"}}, 284 } 285 if err := validateMCPCommands(servers, true); err == nil { 286 t.Errorf("expected error for wrapper %q even with confirm, got nil", cmd) 287 } 288 } 289 } 290 291 func TestValidateMCPCommands_ShellInArgsBlocked(t *testing.T) { 292 servers := map[string]interface{}{ 293 "s": map[string]interface{}{ 294 "command": "python", 295 "args": []interface{}{"bash", "-lc", "echo hi"}, 296 }, 297 } 298 if err := validateMCPCommands(servers, true); err == nil { 299 t.Error("expected error for shell in args, got nil") 300 } 301 } 302 303 func TestCheckProtectedFields_MCPServersAliasNormalized(t *testing.T) { 304 patch := map[string]interface{}{ 305 "mcpServers": map[string]interface{}{ 306 "test": map[string]interface{}{ 307 "command": "node", 308 }, 309 }, 310 } 311 normalizePatchKeys(patch) 312 if _, ok := patch["mcp_servers"]; !ok { 313 t.Fatal("expected mcp_servers key after normalization") 314 } 315 if _, ok := patch["mcpServers"]; ok { 316 t.Fatal("expected mcpServers alias to be removed after normalization") 317 } 318 }