init.sh
#!/usr/bin/env bash
# init.sh — cc-ml init implementation
#
# Interactive scaffolding that auto-detects ML framework, metrics,
# checkpoints, datasets, and environment, then writes:
# - autoresearch.yaml (standard schema + ml advisory fields)
# - bench/benchmark.sh (training wrapper)
# - .gitignore additions

INIT_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
TEMPLATE_DIR="$INIT_SCRIPT_DIR/../templates"

# --- ML metric defaults ---

# Known metric patterns: name -> (regex, unit, criteria)
# These cover common logging formats across PyTorch, JAX, TF, and
# generic METRIC lines.

declare -A METRIC_REGEXES
declare -A METRIC_UNITS
declare -A METRIC_CRITERIA

# Loss variants (lower is better)
METRIC_REGEXES[train_loss]='(?:train[_ ]?loss|training[_ ]?loss)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
METRIC_UNITS[train_loss]="loss"
METRIC_CRITERIA[train_loss]="lower_is_better"

METRIC_REGEXES[val_loss]='(?:val[_ ]?loss|validation[_ ]?loss|eval[_ ]?loss)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
METRIC_UNITS[val_loss]="loss"
METRIC_CRITERIA[val_loss]="lower_is_better"

METRIC_REGEXES[loss]='(?:^|\s)loss\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
METRIC_UNITS[loss]="loss"
METRIC_CRITERIA[loss]="lower_is_better"

METRIC_REGEXES[perplexity]='(?:perplexity|ppl)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
METRIC_UNITS[perplexity]="ppl"
METRIC_CRITERIA[perplexity]="lower_is_better"

# Accuracy variants (higher is better)
METRIC_REGEXES[accuracy]='(?:accuracy|acc)\s*[:=]\s*([\d.]+)'
METRIC_UNITS[accuracy]="%"
METRIC_CRITERIA[accuracy]="higher_is_better"

METRIC_REGEXES[val_accuracy]='(?:val[_ ]?acc(?:uracy)?|validation[_ ]?acc(?:uracy)?)\s*[:=]\s*([\d.]+)'
METRIC_UNITS[val_accuracy]="%"
METRIC_CRITERIA[val_accuracy]="higher_is_better"

METRIC_REGEXES[f1]='(?:f1[_ ]?score|f1)\s*[:=]\s*([\d.]+)'
METRIC_UNITS[f1]="score"
METRIC_CRITERIA[f1]="higher_is_better"

METRIC_REGEXES[bleu]='(?:bleu)\s*[:=]\s*([\d.]+)'
METRIC_UNITS[bleu]="score"
METRIC_CRITERIA[bleu]="higher_is_better"

METRIC_REGEXES[reward]='(?:reward|mean[_ ]?reward)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
METRIC_UNITS[reward]="reward"
METRIC_CRITERIA[reward]="higher_is_better"

# Learning rate (tracked as secondary, lower is typically not "better")
METRIC_REGEXES[learning_rate]='(?:lr|learning[_ ]?rate)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
METRIC_UNITS[learning_rate]="lr"
METRIC_CRITERIA[learning_rate]="lower_is_better"

# --- Auto-detection ---

detect_framework() {
  local dir="$1"
  local framework="generic"

  # Scan Python files for framework imports.
  local py_files
  py_files=$(find "$dir" -maxdepth 3 -name "*.py" -not -path "*/\.*" 2>/dev/null | head -20)

  for f in $py_files; do
    if grep -qE '^\s*(import torch|from torch)' "$f" 2>/dev/null; then
      framework="pytorch"
      break
    elif grep -qE '^\s*(import jax|from jax|import flax|from flax)' "$f" 2>/dev/null; then
      framework="jax"
      break
    elif grep -qE '^\s*(import tensorflow|from tensorflow|import keras|from keras)' "$f" 2>/dev/null; then
      framework="tensorflow"
      # Don't break — pytorch takes priority if both are found.
    fi
  done

  echo "$framework"
}

detect_training_script() {
  local dir="$1"

  # Look for common training script names.
  for name in train.py main.py run.py run_train.py train_model.py; do
    if [ -f "$dir/$name" ]; then
      echo "$name"
      return
    fi
  done

  # Look for files with training-related content.
  local py_files
  py_files=$(find "$dir" -maxdepth 2 -name "*.py" -not -path "*/\.*" 2>/dev/null)
  for f in $py_files; do
    if grep -qE '(\.backward\(\)|\.fit\(|train_step|training_loop|trainer\.train)' "$f" 2>/dev/null; then
      # Return relative path.
      echo "${f#"$dir/"}"
      return
    fi
  done

  echo ""
}

detect_checkpoint_dir() {
  local dir="$1"

  # Check existing directories.
  for name in checkpoints checkpoint ckpt models output runs saved_models; do
    if [ -d "$dir/$name" ]; then
      echo "./$name"
      return
    fi
  done

  # Parse from training scripts.
  local py_files
  py_files=$(find "$dir" -maxdepth 2 -name "*.py" -not -path "*/\.*" 2>/dev/null)
  for f in $py_files; do
    local match
    # -P: the \x27 escape for a single quote is PCRE-only; with -E the bracket
    # expression would never match single-quoted paths.
    match=$(grep -oP '(--output_dir|--checkpoint_dir|--save_dir|checkpoint_dir|output_dir)\s*[=,]\s*["\x27]([^"\x27]+)["\x27]' "$f" 2>/dev/null | head -1 | sed 's/.*["'"'"']\(.*\)["'"'"'].*/\1/')
    if [ -n "$match" ]; then
      echo "./$match"
      return
    fi
  done

  echo "./checkpoints"
}

detect_data_dirs() {
  local dir="$1"
  local found=()

  # Check existing directories.
  for name in data dataset datasets train_data; do
    if [ -d "$dir/$name" ]; then
      found+=("./$name")
    fi
  done

  # If nothing found, check training script args.
  if [ ${#found[@]} -eq 0 ]; then
    local py_files
    py_files=$(find "$dir" -maxdepth 2 -name "*.py" -not -path "*/\.*" 2>/dev/null)
    for f in $py_files; do
      local match
      match=$(grep -oP '(--data_dir|--data_path|--train_file|--dataset_path)\s*[=,]\s*["\x27]([^"\x27]+)["\x27]' "$f" 2>/dev/null | head -1 | sed 's/.*["'"'"']\(.*\)["'"'"'].*/\1/')
      if [ -n "$match" ] && [ -e "$dir/$match" ]; then
        found+=("./$match")
      fi
    done
  fi

  # Return as newline-separated list.
  printf '%s\n' "${found[@]}"
}

detect_build_cmd() {
  local dir="$1"

  if [ -f "$dir/requirements.txt" ]; then
    if command -v uv >/dev/null 2>&1; then
      echo "uv pip install -r requirements.txt"
    else
      echo "pip install -r requirements.txt"
    fi
  elif [ -f "$dir/pyproject.toml" ]; then
    if command -v uv >/dev/null 2>&1; then
      echo "uv pip install -e ."
    else
      echo "pip install -e ."
    fi
  elif [ -f "$dir/setup.py" ]; then
    echo "pip install -e ."
  elif [ -f "$dir/environment.yml" ]; then
    echo "conda env update -f environment.yml"
  elif [ -f "$dir/Pipfile" ]; then
    echo "pipenv install"
  else
    echo ""
  fi
}

suggest_metrics() {
  local framework="$1"

  case "$framework" in
    pytorch)
      echo "val_loss train_loss accuracy learning_rate"
      ;;
    jax)
      echo "val_loss train_loss accuracy learning_rate"
      ;;
    tensorflow)
      echo "val_loss loss accuracy val_accuracy"
      ;;
    *)
      echo "loss accuracy"
      ;;
  esac
}

# --- Interactive prompts ---

prompt_value() {
  local prompt="$1"
  local default="$2"
  local value

  if [ -n "$default" ]; then
    printf "%s [%s]: " "$prompt" "$default" >&2
  else
    printf "%s: " "$prompt" >&2
  fi

  read -r value
  if [ -z "$value" ]; then
    echo "$default"
  else
    echo "$value"
  fi
}

prompt_yes_no() {
  local prompt="$1"
  local default="${2:-y}"
  local value

  if [ "$default" = "y" ]; then
    printf "%s [Y/n]: " "$prompt" >&2
  else
    printf "%s [y/N]: " "$prompt" >&2
  fi

  read -r value
  value=${value:-$default}
  case "$value" in
    [yY]*) return 0 ;;
    *) return 1 ;;
  esac
}

prompt_choice() {
  local prompt="$1"
  shift
  local options=("$@")

  echo "$prompt" >&2
  for i in "${!options[@]}"; do
    echo "  $((i + 1))) ${options[$i]}" >&2
  done
  printf "Choice [1]: " >&2

  local choice
  read -r choice
  choice=${choice:-1}

  # Validate numerically before the arithmetic tests so non-numeric input
  # falls back to the first option instead of spewing test errors.
  if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#options[@]}" ]; then
    echo "${options[$((choice - 1))]}"
  else
    echo "${options[0]}"
  fi
}

# --- Benchmark script generation ---

generate_benchmark_script() {
  local train_cmd="$1"
  local primary_metric="$2"
  local primary_regex="$3"
  local secondary_metrics="$4"  # space-separated

  local extraction=""

  # Primary metric: extract last occurrence.
  extraction+="# Extract last occurrence of each metric (final epoch values).\n"
  extraction+="# The regex patterns here match what autoresearch.yaml expects.\n\n"

  extraction+="# Primary metric\n"
  extraction+="PRIMARY_VAL=\$(grep -oP '${primary_regex}' \"\$OUTPUT_FILE\" | tail -1 | grep -oP '[\\d.]+(?:e[+-]?\\d+)?' | tail -1)\n"
  extraction+="if [ -n \"\$PRIMARY_VAL\" ]; then\n"
  extraction+="  echo \"${primary_metric} : \$PRIMARY_VAL\"\n"
  extraction+="else\n"
  extraction+="  echo \"WARNING: could not extract ${primary_metric} from output\" >&2\n"
  extraction+="  exit 1\n"
  extraction+="fi\n"

  # Secondary metrics.
  for metric in $secondary_metrics; do
    if [ "$metric" = "$primary_metric" ]; then
      continue
    fi
    local regex="${METRIC_REGEXES[$metric]:-}"
    if [ -n "$regex" ]; then
      extraction+="\n# Secondary: $metric\n"
      extraction+="${metric^^}_VAL=\$(grep -oP '${regex}' \"\$OUTPUT_FILE\" | tail -1 | grep -oP '[\\d.]+(?:e[+-]?\\d+)?' | tail -1)\n"
      extraction+="if [ -n \"\${${metric^^}_VAL}\" ]; then\n"
      extraction+="  echo \"${metric} : \${${metric^^}_VAL}\"\n"
      extraction+="fi\n"
    fi
  done

  echo -e "$extraction"
}

# --- Main init command ---

cmd_init() {
  local dir
  dir=$(pwd)

  # Parse flags.
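  # Typical invocations (assuming the cc-ml wrapper dispatches `init` here):
  #   cc-ml init              # interactive, with auto-detected defaults
  #   cc-ml init --no-input   # accept all defaults, e.g. in CI or scripts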
  local non_interactive=false
  while [ $# -gt 0 ]; do
    case "$1" in
      --no-input) non_interactive=true; shift ;;
      -h|--help)
        echo "Usage: cc-ml init [--no-input]"
        echo ""
        echo "Scaffold autoresearch.yaml + bench/benchmark.sh for ML training."
        echo "Auto-detects framework, metrics, checkpoints, and datasets."
        echo ""
        echo "Options:"
        echo "  --no-input   Accept all defaults without prompting"
        exit 0
        ;;
      *) die "unknown flag: $1" ;;
    esac
  done

  if [ -f "$dir/autoresearch.yaml" ]; then
    if ! $non_interactive && ! prompt_yes_no "autoresearch.yaml already exists. Overwrite?"; then
      info "aborting"
      exit 0
    fi
  fi

  # python3 is needed below for YAML generation.
  require_cmds git jq python3 rad-artifact

  info "detecting ML project configuration..."

  # 1. Framework.
  local framework
  framework=$(detect_framework "$dir")
  if ! $non_interactive; then
    framework=$(prompt_value "Framework detected" "$framework")
  fi
  info "framework: $framework"

  # 2. Training script.
  local train_script
  train_script=$(detect_training_script "$dir")
  if ! $non_interactive; then
    train_script=$(prompt_value "Training script" "$train_script")
  fi
  if [ -z "$train_script" ]; then
    die "no training script found — create one first or specify the path"
  fi
  info "training script: $train_script"

  # 3. Training command.
  local train_cmd="python $train_script"
  if ! $non_interactive; then
    train_cmd=$(prompt_value "Training command" "$train_cmd")
  fi

  # 4. Primary metric.
  local suggested
  suggested=$(suggest_metrics "$framework")
  local primary_metric
  primary_metric=$(echo "$suggested" | cut -d' ' -f1)
  if ! $non_interactive; then
    echo "Suggested metrics for $framework: $suggested" >&2
    primary_metric=$(prompt_value "Primary metric (optimization target)" "$primary_metric")
  fi

  local primary_regex="${METRIC_REGEXES[$primary_metric]:-}"
  local primary_unit="${METRIC_UNITS[$primary_metric]:-}"
  local primary_criteria="${METRIC_CRITERIA[$primary_metric]:-}"

  # If unknown metric, ask for details.
  if [ -z "$primary_regex" ]; then
    if ! $non_interactive; then
      primary_regex=$(prompt_value "Regex to extract $primary_metric (with capture group)" "")
      primary_unit=$(prompt_value "Unit for $primary_metric" "")
      # Infer criteria from name.
      if echo "$primary_metric" | grep -qiE 'loss|error|perplexity|mse|mae'; then
        primary_criteria="lower_is_better"
      elif echo "$primary_metric" | grep -qiE 'accuracy|acc|f1|bleu|reward|score'; then
        primary_criteria="higher_is_better"
      else
        primary_criteria=$(prompt_choice "Criteria for $primary_metric" "lower_is_better" "higher_is_better")
      fi
    else
      die "unknown metric '$primary_metric' — run without --no-input to configure"
    fi
  fi

  if ! $non_interactive; then
    primary_criteria=$(prompt_value "Criteria" "$primary_criteria")
  fi
  info "primary metric: $primary_metric ($primary_unit, $primary_criteria)"

  # 5. Secondary metrics.
  local secondary_metrics=""
  local remaining
  remaining=$(echo "$suggested" | cut -d' ' -f2-)
  if ! $non_interactive && [ -n "$remaining" ]; then
    echo "Suggested secondary metrics: $remaining" >&2
    secondary_metrics=$(prompt_value "Secondary metrics (space-separated, or empty)" "$remaining")
  else
    secondary_metrics="$remaining"
  fi

  # 6. Checkpoint directory.
  local checkpoint_dir
  checkpoint_dir=$(detect_checkpoint_dir "$dir")
  if ! $non_interactive; then
    checkpoint_dir=$(prompt_value "Checkpoint directory" "$checkpoint_dir")
  fi
  info "checkpoint dir: $checkpoint_dir"

  # 7. Datasets.
  local data_dirs_raw
  data_dirs_raw=$(detect_data_dirs "$dir")
  local datasets_json="[]"

  if [ -n "$data_dirs_raw" ]; then
    info "detected data directories:"
    local datasets_arr=()
    while IFS= read -r dpath; do
      [ -z "$dpath" ] && continue
      echo "  $dpath" >&2
      local dname
      dname=$(basename "$dpath")
      local dcid=""
      if [ -e "$dir/$dpath" ]; then
        info "computing CID for $dpath (this may take a while for large datasets)..."
        dcid=$(rad-artifact cid "$dir/$dpath" 2>/dev/null || echo "")
      fi
      if [ -n "$dcid" ]; then
        datasets_arr+=("{\"name\":\"$dname\",\"path\":\"$dpath\",\"cid\":\"$dcid\"}")
        info "  $dname: $dcid"
      else
        datasets_arr+=("{\"name\":\"$dname\",\"path\":\"$dpath\",\"cid\":null}")
        warn "  could not compute CID for $dpath"
      fi
    done <<< "$data_dirs_raw"

    if [ ${#datasets_arr[@]} -gt 0 ]; then
      datasets_json=$(printf '%s\n' "${datasets_arr[@]}" | jq -s '.')
    fi
  fi

  if ! $non_interactive; then
    if [ "$datasets_json" = "[]" ]; then
      local custom_data
      custom_data=$(prompt_value "Data directory (or empty to skip)" "")
      if [ -n "$custom_data" ]; then
        local dname
        dname=$(basename "$custom_data")
        local dcid=""
        if [ -e "$dir/$custom_data" ]; then
          info "computing CID for $custom_data..."
          dcid=$(rad-artifact cid "$dir/$custom_data" 2>/dev/null || echo "")
        fi
        if [ -n "$dcid" ]; then
          datasets_json="[{\"name\":\"$dname\",\"path\":\"$custom_data\",\"cid\":\"$dcid\"}]"
        else
          datasets_json="[{\"name\":\"$dname\",\"path\":\"$custom_data\",\"cid\":null}]"
        fi
      fi
    fi
  fi

  # 8. Build command.
  local build_cmd
  build_cmd=$(detect_build_cmd "$dir")
  if ! $non_interactive; then
    build_cmd=$(prompt_value "Build/install command" "$build_cmd")
  fi

  # 9. Hot files (training script + config files).
  local hot_files_list="[\"$train_script\"]"
  for cfg in config.yaml config.json hparams.yaml; do
    if [ -f "$dir/$cfg" ]; then
      hot_files_list=$(echo "$hot_files_list" | jq --arg f "$cfg" '. + [$f]')
    fi
  done

  # --- Write files ---

  info "writing autoresearch.yaml..."

  # Use python for clean YAML generation.
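  # Each shell value is embedded into the Python source as a JSON string
  # literal via `jq -Rs .` (raw input, slurped), so quoting and the
  # backslashes in the metric regexes survive intact. For example:
  #   printf '%s' "pip install -e ." | jq -Rs .   # -> "pip install -e ."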
  python3 -c "
import yaml, json, sys

class Dumper(yaml.SafeDumper):
    pass

def dict_representer(dumper, data):
    return dumper.represent_mapping('tag:yaml.org,2002:map', data.items())
Dumper.add_representer(dict, dict_representer)

config = {
    'build_cmd': $(printf '%s' "$build_cmd" | jq -Rs .),
    'bench_cmd': './bench/benchmark.sh',
    'bench_dir': 'bench',
    'test_cmd': None,
    'metrics': [
        {
            'name': $(printf '%s' "$primary_metric" | jq -Rs .),
            'unit': $(printf '%s' "$primary_unit" | jq -Rs .),
            'regex': $(printf '%s' "$primary_regex" | jq -Rs .),
            'criteria': $(printf '%s' "$primary_criteria" | jq -Rs .),
        }
    ],
    'forbidden_paths': [
        'autoresearch.yaml',
        'bench/',
        $(printf '%s' "$checkpoint_dir" | jq -Rs .),
    ],
    'hot_files': json.loads($(echo "$hot_files_list" | jq -Rs .)),
    'timeout_secs': 1800,
    'ml': {
        'framework': $(printf '%s' "$framework" | jq -Rs .),
        'checkpoint_dir': $(printf '%s' "$checkpoint_dir" | jq -Rs .),
        'datasets': json.loads($(echo "$datasets_json" | jq -Rs .)),
    },
}

with open('autoresearch.yaml', 'w') as f:
    f.write('# autoresearch.yaml \u2014 ML training config\\n')
    f.write('# Generated by cc-ml init. Consumed by rad-experiment benchmark.\\n\\n')
    yaml.dump(config, f, Dumper=Dumper, default_flow_style=False, sort_keys=False)
" || die "failed to write autoresearch.yaml"

  # Write benchmark script.
  info "writing bench/benchmark.sh..."
  mkdir -p "$dir/bench"

  local metric_extraction
  metric_extraction=$(generate_benchmark_script "$train_cmd" "$primary_metric" "$primary_regex" "$secondary_metrics")

  # The extraction block is multi-line and contains '|' characters, so a plain
  # sed substitution would break; fill the template placeholders with literal
  # bash parameter expansion instead.
  local bench_body
  bench_body=$(cat "$TEMPLATE_DIR/benchmark.sh")
  bench_body=${bench_body//__TRAIN_CMD__/$train_cmd}
  bench_body=${bench_body//__METRIC_EXTRACTION__/$metric_extraction}
  printf '%s\n' "$bench_body" > "$dir/bench/benchmark.sh"

  chmod +x "$dir/bench/benchmark.sh"

  # Update .gitignore.
  info "updating .gitignore..."
  local gitignore="$dir/.gitignore"
  touch "$gitignore"

  for entry in ".community-computer/" "$checkpoint_dir"; do
    # Strip leading ./ for gitignore.
    local clean_entry="${entry#./}"
    if ! grep -qF "$clean_entry" "$gitignore" 2>/dev/null; then
      echo "$clean_entry" >> "$gitignore"
    fi
  done

  # Summary.
  echo "" >&2
  info "=== cc-ml init complete ==="
  info "  autoresearch.yaml   — training config (edit metrics/hot_files as needed)"
  info "  bench/benchmark.sh  — training wrapper (edit if output format differs)"
  info "  .gitignore          — updated"
  echo "" >&2
  info "Next steps:"
  info "  1. Review autoresearch.yaml and bench/benchmark.sh"
  info "  2. Run a test: ./bench/benchmark.sh"
  info "  3. Start optimizing: use the cc-ml-train skill or rad-experiment"

  # Note about dataset publishing.
  local ds_with_cids
  ds_with_cids=$(echo "$datasets_json" | jq '[.[] | select(.cid != null)] | length')
  if [ "$ds_with_cids" -gt 0 ]; then
    echo "" >&2
    info "Dataset CIDs computed. They will be published as artifacts at"
    info "CHOSEN_BASE when the first experiment baseline is recorded."
    info "To register a location (e.g., HuggingFace, S3), run:"
    info "  rad-artifact location add <commit> --cid <CID> <URL>"
  fi
}
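For reference, a minimal sketch of what the wrapped template could look like. templates/benchmark.sh itself is not shown above; only the two placeholder names (__TRAIN_CMD__, __METRIC_EXTRACTION__) and the $OUTPUT_FILE variable consumed by generate_benchmark_script come from init.sh, so everything else here is an assumption about the wrapper's shape, not the actual file.

templates/benchmark.sh (illustrative sketch)
#!/usr/bin/env bash
# Training wrapper filled in by cc-ml init.
# No -e: the extraction block below handles a missing metric itself.
set -uo pipefail

# Log file that the metric extraction greps over.
OUTPUT_FILE="$(mktemp)"
trap 'rm -f "$OUTPUT_FILE"' EXIT

# __TRAIN_CMD__ is replaced with the configured training command.
__TRAIN_CMD__ 2>&1 | tee "$OUTPUT_FILE"

# __METRIC_EXTRACTION__ is replaced with the block built by
# generate_benchmark_script: it echoes "metric : value" lines and
# exits 1 if the primary metric cannot be found.
__METRIC_EXTRACTION__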