/ lib / init.sh
init.sh
  1  #!/usr/bin/env bash
  2  # init.sh — cc-ml init implementation
  3  #
  4  # Interactive scaffolding that auto-detects ML framework, metrics,
  5  # checkpoints, datasets, and environment, then writes:
  6  #   - autoresearch.yaml (standard schema + ml advisory fields)
  7  #   - bench/benchmark.sh (training wrapper)
  8  #   - .gitignore additions
  9  
 10  INIT_SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 11  TEMPLATE_DIR="$INIT_SCRIPT_DIR/../templates"
 12  
 13  # --- ML metric defaults ---
 14  
 15  # Known metric patterns: name -> (regex, unit, criteria)
 16  # These cover common logging formats across PyTorch, JAX, TF, and
 17  # generic METRIC lines.
 18  
 19  declare -A METRIC_REGEXES
 20  declare -A METRIC_UNITS
 21  declare -A METRIC_CRITERIA
 22  
 23  # Loss variants (lower is better)
 24  METRIC_REGEXES[train_loss]='(?:train[_ ]?loss|training[_ ]?loss)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
 25  METRIC_UNITS[train_loss]="loss"
 26  METRIC_CRITERIA[train_loss]="lower_is_better"
 27  
 28  METRIC_REGEXES[val_loss]='(?:val[_ ]?loss|validation[_ ]?loss|eval[_ ]?loss)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
 29  METRIC_UNITS[val_loss]="loss"
 30  METRIC_CRITERIA[val_loss]="lower_is_better"
 31  
 32  METRIC_REGEXES[loss]='(?:^|\s)loss\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
 33  METRIC_UNITS[loss]="loss"
 34  METRIC_CRITERIA[loss]="lower_is_better"
 35  
 36  METRIC_REGEXES[perplexity]='(?:perplexity|ppl)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
 37  METRIC_UNITS[perplexity]="ppl"
 38  METRIC_CRITERIA[perplexity]="lower_is_better"
 39  
 40  # Accuracy variants (higher is better)
 41  METRIC_REGEXES[accuracy]='(?:accuracy|acc)\s*[:=]\s*([\d.]+)'
 42  METRIC_UNITS[accuracy]="%"
 43  METRIC_CRITERIA[accuracy]="higher_is_better"
 44  
 45  METRIC_REGEXES[val_accuracy]='(?:val[_ ]?acc(?:uracy)?|validation[_ ]?acc(?:uracy)?)\s*[:=]\s*([\d.]+)'
 46  METRIC_UNITS[val_accuracy]="%"
 47  METRIC_CRITERIA[val_accuracy]="higher_is_better"
 48  
 49  METRIC_REGEXES[f1]='(?:f1[_ ]?score|f1)\s*[:=]\s*([\d.]+)'
 50  METRIC_UNITS[f1]="score"
 51  METRIC_CRITERIA[f1]="higher_is_better"
 52  
 53  METRIC_REGEXES[bleu]='(?:bleu)\s*[:=]\s*([\d.]+)'
 54  METRIC_UNITS[bleu]="score"
 55  METRIC_CRITERIA[bleu]="higher_is_better"
 56  
 57  METRIC_REGEXES[reward]='(?:reward|mean[_ ]?reward)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
 58  METRIC_UNITS[reward]="reward"
 59  METRIC_CRITERIA[reward]="higher_is_better"
 60  
 61  # Learning rate (tracked as secondary, lower is typically not "better")
 62  METRIC_REGEXES[learning_rate]='(?:lr|learning[_ ]?rate)\s*[:=]\s*([\d.]+(?:e[+-]?\d+)?)'
 63  METRIC_UNITS[learning_rate]="lr"
 64  METRIC_CRITERIA[learning_rate]="lower_is_better"
 65  
 66  # --- Auto-detection ---
 67  
 68  detect_framework() {
 69    local dir="$1"
 70    local framework="generic"
 71  
 72    # Scan Python files for framework imports.
 73    local py_files
 74    py_files=$(find "$dir" -maxdepth 3 -name "*.py" -not -path "*/\.*" 2>/dev/null | head -20)
 75  
 76    for f in $py_files; do
 77      if grep -qE '^\s*(import torch|from torch)' "$f" 2>/dev/null; then
 78        framework="pytorch"
 79        break
 80      elif grep -qE '^\s*(import jax|from jax|import flax|from flax)' "$f" 2>/dev/null; then
 81        framework="jax"
 82        break
 83      elif grep -qE '^\s*(import tensorflow|from tensorflow|import keras|from keras)' "$f" 2>/dev/null; then
 84        framework="tensorflow"
 85        # Don't break — pytorch takes priority if both are found.
 86      fi
 87    done
 88  
 89    echo "$framework"
 90  }
 91  
 92  detect_training_script() {
 93    local dir="$1"
 94  
 95    # Look for common training script names.
 96    for name in train.py main.py run.py run_train.py train_model.py; do
 97      if [ -f "$dir/$name" ]; then
 98        echo "$name"
 99        return
100      fi
101    done
102  
103    # Look for files with training-related content.
104    local py_files
105    py_files=$(find "$dir" -maxdepth 2 -name "*.py" -not -path "*/\.*" 2>/dev/null)
106    for f in $py_files; do
107      if grep -qE '(\.backward\(\)|\.fit\(|train_step|training_loop|trainer\.train)' "$f" 2>/dev/null; then
108        # Return relative path.
109        echo "${f#"$dir/"}"
110        return
111      fi
112    done
113  
114    echo ""
115  }
116  
#######################################
# Detect where training checkpoints are written.
# Arguments: $1 - project root directory
# Outputs:   "./<dir>" on stdout; defaults to "./checkpoints"
#######################################
detect_checkpoint_dir() {
  local dir="$1"

  # 1) Conventionally-named existing directories.
  local name
  for name in checkpoints checkpoint ckpt models output runs saved_models; do
    if [ -d "$dir/$name" ]; then
      echo "./$name"
      return
    fi
  done

  # 2) Quoted values of checkpoint/output args in training scripts.
  #    PCRE (-P) is required: \x27 (single quote) is not a valid ERE escape,
  #    so the previous ERE pattern silently missed single-quoted paths.
  #    \K drops the matched prefix and the lookahead drops the closing quote,
  #    leaving just the path — no sed post-processing needed.
  local f match
  while IFS= read -r f; do
    match=$(grep -oP '(--output_dir|--checkpoint_dir|--save_dir|checkpoint_dir|output_dir)\s*[=,]\s*["\x27]\K[^"\x27]+(?=["\x27])' "$f" 2>/dev/null | head -1)
    if [ -n "$match" ]; then
      echo "./$match"
      return
    fi
  done < <(find "$dir" -maxdepth 2 -name "*.py" -not -path "*/\.*" 2>/dev/null)

  # 3) Fallback default.
  echo "./checkpoints"
}
142  
143  detect_data_dirs() {
144    local dir="$1"
145    local found=()
146  
147    # Check existing directories.
148    for name in data dataset datasets train_data; do
149      if [ -d "$dir/$name" ]; then
150        found+=("./$name")
151      fi
152    done
153  
154    # If nothing found, check training script args.
155    if [ ${#found[@]} -eq 0 ]; then
156      local py_files
157      py_files=$(find "$dir" -maxdepth 2 -name "*.py" -not -path "*/\.*" 2>/dev/null)
158      for f in $py_files; do
159        local match
160        match=$(grep -oE '(--data_dir|--data_path|--train_file|--dataset_path)\s*[=,]\s*["\x27]([^"\x27]+)["\x27]' "$f" 2>/dev/null | head -1 | sed 's/.*["'"'"']\(.*\)["'"'"'].*/\1/')
161        if [ -n "$match" ] && [ -e "$dir/$match" ]; then
162          found+=("./$match")
163        fi
164      done
165    fi
166  
167    # Return as newline-separated list.
168    printf '%s\n' "${found[@]}"
169  }
170  
#######################################
# Suggest an install/build command from the dependency files present.
# Arguments: $1 - project root directory
# Outputs:   command string on stdout (empty when nothing recognized)
#######################################
detect_build_cmd() {
  local project_root="$1"

  # Prefer uv over pip when available (requirements.txt / pyproject only).
  local installer="pip"
  if command -v uv >/dev/null 2>&1; then
    installer="uv pip"
  fi

  if [ -f "$project_root/requirements.txt" ]; then
    echo "$installer install -r requirements.txt"
  elif [ -f "$project_root/pyproject.toml" ]; then
    echo "$installer install -e ."
  elif [ -f "$project_root/setup.py" ]; then
    echo "pip install -e ."
  elif [ -f "$project_root/environment.yml" ]; then
    echo "conda env update -f environment.yml"
  elif [ -f "$project_root/Pipfile" ]; then
    echo "pipenv install"
  else
    echo ""
  fi
}
196  
#######################################
# Suggest a metric list for a framework.
# Arguments: $1 - framework name (pytorch/jax/tensorflow/anything else)
# Outputs:   space-separated metric names on stdout; the first entry is
#            the suggested primary (optimization) metric
#######################################
suggest_metrics() {
  case "$1" in
    pytorch|jax) echo "val_loss train_loss accuracy learning_rate" ;;
    tensorflow)  echo "val_loss loss accuracy val_accuracy" ;;
    *)           echo "loss accuracy" ;;
  esac
}
215  
216  # --- Interactive prompts ---
217  
#######################################
# Ask the user for a value, with an optional default.
# Arguments: $1 - prompt text, $2 - default (may be empty)
# Outputs:   prompt on stderr; the answer (or default on empty input)
#            on stdout
#######################################
prompt_value() {
  local question="$1"
  local fallback="$2"
  local answer

  # Show the default in brackets only when one exists.
  if [ -n "$fallback" ]; then
    printf "%s [%s]: " "$question" "$fallback" >&2
  else
    printf "%s: " "$question" >&2
  fi

  read -r answer
  echo "${answer:-$fallback}"
}
236  
#######################################
# Ask a yes/no question.
# Arguments: $1 - prompt text, $2 - default answer "y" or "n" (default "y");
#            it is used when the user just presses Enter and selects the
#            [Y/n] vs [y/N] hint
# Returns:   0 for yes, 1 for no
#######################################
prompt_yes_no() {
  local question="$1"
  local fallback="${2:-y}"
  local answer

  local hint="[y/N]"
  [ "$fallback" = "y" ] && hint="[Y/n]"
  printf "%s %s: " "$question" "$hint" >&2

  read -r answer
  case "${answer:-$fallback}" in
    [yY]*) return 0 ;;
    *) return 1 ;;
  esac
}
255  
#######################################
# Present a numbered menu and return the chosen option.
# Arguments: $1 - prompt text; $2.. - options
# Outputs:   menu on stderr; selected option on stdout. Empty, non-numeric,
#            or out-of-range input falls back to the first option.
#######################################
prompt_choice() {
  local prompt="$1"
  shift
  local options=("$@")

  echo "$prompt" >&2
  local i
  for i in "${!options[@]}"; do
    echo "  $((i + 1))) ${options[$i]}" >&2
  done
  printf "Choice [1]: " >&2

  local choice
  read -r choice
  choice=${choice:-1}

  # Validate the shape before the numeric comparisons: `[ x -ge 1 ]` on
  # non-numeric input writes "integer expression expected" to stderr (the
  # old trailing 2>/dev/null only silenced the second test).
  if [[ "$choice" =~ ^[0-9]+$ ]] && [ "$choice" -ge 1 ] && [ "$choice" -le "${#options[@]}" ]; then
    echo "${options[$((choice - 1))]}"
  else
    echo "${options[0]}"
  fi
}
277  
278  # --- Benchmark script generation ---
279  
#######################################
# Emit the metric-extraction snippet spliced into bench/benchmark.sh
# (replacing __METRIC_EXTRACTION__ in the template). The snippet is built
# as a string with literal \n sequences and expanded by the final echo -e,
# so the output is real multi-line shell code.
# Arguments:
#   $1 - train_cmd. NOTE(review): accepted but never used here; the caller
#        substitutes it into the template separately.
#   $2 - primary metric name; extraction failure in the generated script
#        is fatal (exit 1)
#   $3 - PCRE with a capture group for the primary metric's value
#   $4 - space-separated secondary metric names, looked up in
#        METRIC_REGEXES; unknown names and the primary are skipped silently
# Outputs: shell snippet on stdout; expects $OUTPUT_FILE in the generated
#          script's environment.
#######################################
generate_benchmark_script() {
  local train_cmd="$1"
  local primary_metric="$2"
  local primary_regex="$3"
  local secondary_metrics="$4"  # space-separated

  local extraction=""

  # Primary metric: extract last occurrence (final epoch value). The second
  # grep -oP isolates the numeric part of the match; tail -1 keeps the last.
  # NOTE(review): the regex is embedded inside single quotes in the generated
  # code — a single quote in a user-supplied regex would break it (none of
  # the built-in METRIC_REGEXES contain one).
  extraction+="# Extract last occurrence of each metric (final epoch values).\n"
  extraction+="# The regex patterns here match what autoresearch.yaml expects.\n\n"

  extraction+="# Primary metric\n"
  extraction+="PRIMARY_VAL=\$(grep -oP '${primary_regex}' \"\$OUTPUT_FILE\" | tail -1 | grep -oP '[\\d.]+(?:e[+-]?\\d+)?' | tail -1)\n"
  extraction+="if [ -n \"\$PRIMARY_VAL\" ]; then\n"
  extraction+="  echo \"${primary_metric} : \$PRIMARY_VAL\"\n"
  extraction+="else\n"
  extraction+="  echo \"WARNING: could not extract ${primary_metric} from output\" >&2\n"
  extraction+="  exit 1\n"
  extraction+="fi\n"

  # Secondary metrics: best-effort, no error when absent from the output.
  for metric in $secondary_metrics; do
    if [ "$metric" = "$primary_metric" ]; then
      continue
    fi
    local regex="${METRIC_REGEXES[$metric]:-}"
    if [ -n "$regex" ]; then
      extraction+="\n# Secondary: $metric\n"
      # ${metric^^} upper-cases the name to form the variable (VAL_LOSS_VAL);
      # the escaped \${${metric^^}_VAL} emits a literal ${VAL_LOSS_VAL}
      # reference into the generated script.
      extraction+="${metric^^}_VAL=\$(grep -oP '${regex}' \"\$OUTPUT_FILE\" | tail -1 | grep -oP '[\\d.]+(?:e[+-]?\\d+)?' | tail -1)\n"
      extraction+="if [ -n \"\${${metric^^}_VAL}\" ]; then\n"
      extraction+="  echo \"${metric} : \${${metric^^}_VAL}\"\n"
      extraction+="fi\n"
    fi
  done

  # echo -e converts the accumulated \n sequences into real newlines.
  echo -e "$extraction"
}
318  
319  # --- Main init command ---
320  
#######################################
# cc-ml init: detect project configuration (interactively unless
# --no-input) and write autoresearch.yaml, bench/benchmark.sh, and
# .gitignore entries into the current working directory.
# Globals:   TEMPLATE_DIR (read), METRIC_* tables (read)
# Arguments: CLI flags: --no-input, -h/--help
# Requires:  die/info/warn/require_cmds from the sourcing entry point,
#            plus git, jq, rad-artifact, python3 (with PyYAML) on PATH.
#######################################
cmd_init() {
  local dir
  dir=$(pwd)

  # Parse flags.
  local non_interactive=false
  while [ $# -gt 0 ]; do
    case "$1" in
      --no-input) non_interactive=true; shift ;;
      -h|--help)
        echo "Usage: cc-ml init [--no-input]"
        echo ""
        echo "Scaffold autoresearch.yaml + bench/benchmark.sh for ML training."
        echo "Auto-detects framework, metrics, checkpoints, and datasets."
        echo ""
        echo "Options:"
        echo "  --no-input  Accept all defaults without prompting"
        exit 0
        ;;
      *) die "unknown flag: $1" ;;
    esac
  done

  # Refuse to clobber an existing config without confirmation
  # (in --no-input mode the overwrite happens silently).
  if [ -f "$dir/autoresearch.yaml" ]; then
    if ! $non_interactive && ! prompt_yes_no "autoresearch.yaml already exists. Overwrite?"; then
      info "aborting"
      exit 0
    fi
  fi

  require_cmds git jq rad-artifact

  info "detecting ML project configuration..."

  # 1. Framework.
  local framework
  framework=$(detect_framework "$dir")
  if ! $non_interactive; then
    framework=$(prompt_value "Framework detected" "$framework")
  fi
  info "framework: $framework"

  # 2. Training script. The only hard requirement — abort without one.
  local train_script
  train_script=$(detect_training_script "$dir")
  if ! $non_interactive; then
    train_script=$(prompt_value "Training script" "$train_script")
  fi
  if [ -z "$train_script" ]; then
    die "no training script found — create one first or specify the path"
  fi
  info "training script: $train_script"

  # 3. Training command.
  local train_cmd="python $train_script"
  if ! $non_interactive; then
    train_cmd=$(prompt_value "Training command" "$train_cmd")
  fi

  # 4. Primary metric (first entry of the framework's suggestion list).
  local suggested
  suggested=$(suggest_metrics "$framework")
  local primary_metric
  primary_metric=$(echo "$suggested" | cut -d' ' -f1)
  if ! $non_interactive; then
    echo "Suggested metrics for $framework: $suggested" >&2
    primary_metric=$(prompt_value "Primary metric (optimization target)" "$primary_metric")
  fi

  local primary_regex="${METRIC_REGEXES[$primary_metric]:-}"
  local primary_unit="${METRIC_UNITS[$primary_metric]:-}"
  local primary_criteria="${METRIC_CRITERIA[$primary_metric]:-}"

  # If unknown metric, ask for details.
  if [ -z "$primary_regex" ]; then
    if ! $non_interactive; then
      primary_regex=$(prompt_value "Regex to extract $primary_metric (with capture group)" "")
      primary_unit=$(prompt_value "Unit for $primary_metric" "")
      # Infer criteria from name.
      if echo "$primary_metric" | grep -qiE 'loss|error|perplexity|mse|mae'; then
        primary_criteria="lower_is_better"
      elif echo "$primary_metric" | grep -qiE 'accuracy|acc|f1|bleu|reward|score'; then
        primary_criteria="higher_is_better"
      else
        primary_criteria=$(prompt_choice "Criteria for $primary_metric" "lower_is_better" "higher_is_better")
      fi
    else
      die "unknown metric '$primary_metric' — run without --no-input to configure"
    fi
  fi

  if ! $non_interactive; then
    primary_criteria=$(prompt_value "Criteria" "$primary_criteria")
  fi
  info "primary metric: $primary_metric ($primary_unit, $primary_criteria)"

  # 5. Secondary metrics (everything after the first suggested entry).
  local secondary_metrics=""
  local remaining
  remaining=$(echo "$suggested" | cut -d' ' -f2-)
  if ! $non_interactive && [ -n "$remaining" ]; then
    echo "Suggested secondary metrics: $remaining" >&2
    secondary_metrics=$(prompt_value "Secondary metrics (space-separated, or empty)" "$remaining")
  else
    secondary_metrics="$remaining"
  fi

  # 6. Checkpoint directory.
  local checkpoint_dir
  checkpoint_dir=$(detect_checkpoint_dir "$dir")
  if ! $non_interactive; then
    checkpoint_dir=$(prompt_value "Checkpoint directory" "$checkpoint_dir")
  fi
  info "checkpoint dir: $checkpoint_dir"

  # 7. Datasets: compute a CID per detected data dir via rad-artifact and
  # accumulate a JSON array of {name, path, cid} objects.
  local data_dirs_raw
  data_dirs_raw=$(detect_data_dirs "$dir")
  local datasets_json="[]"
  # NOTE(review): datasets_yaml is assigned here but never read anywhere in
  # this function — likely leftover; candidate for removal.
  local datasets_yaml="[]"

  if [ -n "$data_dirs_raw" ]; then
    info "detected data directories:"
    local datasets_arr=()
    while IFS= read -r dpath; do
      [ -z "$dpath" ] && continue
      echo "  $dpath" >&2
      local dname
      dname=$(basename "$dpath")
      local dcid=""
      if [ -e "$dir/$dpath" ]; then
        info "computing CID for $dpath (this may take a while for large datasets)..."
        # Best effort: a failed CID computation degrades to cid:null below.
        dcid=$(rad-artifact cid "$dir/$dpath" 2>/dev/null || echo "")
      fi
      if [ -n "$dcid" ]; then
        datasets_arr+=("{\"name\":\"$dname\",\"path\":\"$dpath\",\"cid\":\"$dcid\"}")
        info "  $dname: $dcid"
      else
        datasets_arr+=("{\"name\":\"$dname\",\"path\":\"$dpath\",\"cid\":null}")
        warn "  could not compute CID for $dpath"
      fi
    done <<< "$data_dirs_raw"

    if [ ${#datasets_arr[@]} -gt 0 ]; then
      # jq -s slurps the one-object-per-line stream into a JSON array.
      datasets_json=$(printf '%s\n' "${datasets_arr[@]}" | jq -s '.')
    fi
  fi

  # Interactive fallback: offer a manual data directory when auto-detection
  # found nothing.
  if ! $non_interactive; then
    if [ "$datasets_json" = "[]" ]; then
      local custom_data
      custom_data=$(prompt_value "Data directory (or empty to skip)" "")
      if [ -n "$custom_data" ]; then
        local dname
        dname=$(basename "$custom_data")
        local dcid=""
        if [ -e "$dir/$custom_data" ]; then
          info "computing CID for $custom_data..."
          dcid=$(rad-artifact cid "$dir/$custom_data" 2>/dev/null || echo "")
        fi
        if [ -n "$dcid" ]; then
          datasets_json="[{\"name\":\"$dname\",\"path\":\"$custom_data\",\"cid\":\"$dcid\"}]"
        else
          datasets_json="[{\"name\":\"$dname\",\"path\":\"$custom_data\",\"cid\":null}]"
        fi
      fi
    fi
  fi

  # 8. Build command.
  local build_cmd
  build_cmd=$(detect_build_cmd "$dir")
  if ! $non_interactive; then
    build_cmd=$(prompt_value "Build/install command" "$build_cmd")
  fi

  # 9. Hot files (training script + config files), built up as a JSON array.
  local hot_files_list="[\"$train_script\"]"
  for cfg in config.yaml config.json hparams.yaml; do
    if [ -f "$dir/$cfg" ]; then
      hot_files_list=$(echo "$hot_files_list" | jq --arg f "$cfg" '. + [$f]')
    fi
  done

  # --- Write files ---

  info "writing autoresearch.yaml..."

  # Use python for clean YAML generation. Shell values are injected as
  # Python literals via `jq -Rs .` (JSON string encoding is valid Python),
  # and pre-built JSON arrays pass through json.loads. Requires PyYAML.
  python3 -c "
import yaml, json, sys

class Dumper(yaml.SafeDumper):
    pass

def dict_representer(dumper, data):
    return dumper.represent_mapping('tag:yaml.org,2002:map', data.items())
Dumper.add_representer(dict, dict_representer)

config = {
    'build_cmd': $(printf '%s' "$build_cmd" | jq -Rs .),
    'bench_cmd': './bench/benchmark.sh',
    'bench_dir': 'bench',
    'test_cmd': None,
    'metrics': [
        {
            'name': $(printf '%s' "$primary_metric" | jq -Rs .),
            'unit': $(printf '%s' "$primary_unit" | jq -Rs .),
            'regex': $(printf '%s' "$primary_regex" | jq -Rs .),
            'criteria': $(printf '%s' "$primary_criteria" | jq -Rs .),
        }
    ],
    'forbidden_paths': [
        'autoresearch.yaml',
        'bench/',
        $(printf '%s' "$checkpoint_dir" | jq -Rs .),
    ],
    'hot_files': json.loads($(echo "$hot_files_list" | jq -Rs .)),
    'timeout_secs': 1800,
    'ml': {
        'framework': $(printf '%s' "$framework" | jq -Rs .),
        'checkpoint_dir': $(printf '%s' "$checkpoint_dir" | jq -Rs .),
        'datasets': json.loads($(echo "$datasets_json" | jq -Rs .)),
    },
}

with open('autoresearch.yaml', 'w') as f:
    f.write('# autoresearch.yaml \u2014 ML training config\\n')
    f.write('# Generated by cc-ml init. Consumed by rad-experiment benchmark.\\n\\n')
    yaml.dump(config, f, Dumper=Dumper, default_flow_style=False, sort_keys=False)
" || die "failed to write autoresearch.yaml"

  # Write benchmark script from the template (expected to contain the
  # __TRAIN_CMD__ and __METRIC_EXTRACTION__ placeholders).
  info "writing bench/benchmark.sh..."
  mkdir -p "$dir/bench"

  local metric_extraction
  metric_extraction=$(generate_benchmark_script "$train_cmd" "$primary_metric" "$primary_regex" "$secondary_metrics")

  # NOTE(review): $metric_extraction is multi-line, but a raw newline is not
  # legal inside a sed `s|…|…|` replacement — this substitution will fail
  # with "unterminated `s' command" whenever the snippet spans lines (it
  # always does). Also, any '|', '&', or '\' in $train_cmd corrupts the sed
  # program. Consider an awk/python templating step instead of sed here.
  sed \
    -e "s|__TRAIN_CMD__|$train_cmd|g" \
    -e "s|__METRIC_EXTRACTION__|$metric_extraction|g" \
    "$TEMPLATE_DIR/benchmark.sh" > "$dir/bench/benchmark.sh"

  chmod +x "$dir/bench/benchmark.sh"

  # Update .gitignore (append-only; substring match via grep -F skips
  # entries that already appear anywhere in the file).
  info "updating .gitignore..."
  local gitignore="$dir/.gitignore"
  touch "$gitignore"

  for entry in ".community-computer/" "$checkpoint_dir"; do
    # Strip leading ./ for gitignore.
    local clean_entry="${entry#./}"
    if ! grep -qF "$clean_entry" "$gitignore" 2>/dev/null; then
      echo "$clean_entry" >> "$gitignore"
    fi
  done

  # Summary.
  echo "" >&2
  info "=== cc-ml init complete ==="
  info "  autoresearch.yaml  — training config (edit metrics/hot_files as needed)"
  info "  bench/benchmark.sh — training wrapper (edit if output format differs)"
  info "  .gitignore         — updated"
  echo "" >&2
  info "Next steps:"
  info "  1. Review autoresearch.yaml and bench/benchmark.sh"
  info "  2. Run a test: ./bench/benchmark.sh"
  info "  3. Start optimizing: use the cc-ml-train skill or rad-experiment"

  # Note about dataset publishing, shown only when at least one dataset got
  # a non-null CID.
  local ds_with_cids
  ds_with_cids=$(echo "$datasets_json" | jq '[.[] | select(.cid != null)] | length')
  if [ "$ds_with_cids" -gt 0 ]; then
    echo "" >&2
    info "Dataset CIDs computed. They will be published as artifacts at"
    info "CHOSEN_BASE when the first experiment baseline is recorded."
    info "To register a location (e.g., HuggingFace, S3), run:"
    info "  rad-artifact location add <commit> --cid <CID> <URL>"
  fi
}