# fast_seqfunc/cli.py
"""Custom CLI for fast-seqfunc.

This module provides a command-line interface for training sequence-function models
and making predictions on new sequences.

Typer's docs can be found at:
    https://typer.tiangolo.com
"""

import random
from pathlib import Path
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import typer
from loguru import logger

from fast_seqfunc import synthetic
from fast_seqfunc.core import (
    evaluate_model,
    load_model,
    predict,
    save_model,
    train_model,
)

app = typer.Typer()

# NOTE: Typer renders each command's docstring as its ``--help`` text, so the
# command docstrings below are user-facing strings. Implementation notes live
# in ``#`` comments instead.

# Single source of truth for the synthetic tasks: name -> human description.
# Insertion order is the order shown to the user (both in the listing and in
# the "invalid task" error message).
_SYNTHETIC_TASKS: Dict[str, str] = {
    "g_count": "A simple linear task where the function value is the count of G "
    "nucleotides in the sequence.",
    "gc_content": "A simple linear task where the function value is the GC content "
    "(proportion of G and C) of the sequence.",
    "motif_position": "A nonlinear task where the function value depends on the "
    "position of a specific motif in the sequence.",
    "motif_count": "A linear task where the function value is a weighted sum of "
    "counts of multiple motifs in the sequence.",
    "length_dependent": "A task with variable-length sequences where the function "
    "value depends nonlinearly on the sequence length.",
    "nonlinear_composition": "A complex nonlinear task where the function depends "
    "on ratios between different nucleotide frequencies.",
    "interaction": "A task testing positional interactions, "
    "where specific nucleotide pairs at certain positions "
    "contribute to the function.",
    "classification": "A binary classification task where the class depends on the "
    "presence of specific patterns in the sequence.",
    "multiclass": "A multi-class classification task "
    "with multiple sequence patterns "
    "corresponding to different classes.",
}

# Weights used for the motif_count task when none (or unusable ones) are given.
# NOTE(review): this has four entries regardless of how many custom motifs the
# user passed -- confirm how the motif_count task handles a length mismatch.
_DEFAULT_MOTIF_WEIGHTS = (1.0, -0.5, 2.0, -1.5)


@app.command()
def train(
    train_data: Path = typer.Argument(..., help="Path to CSV file with training data"),
    sequence_col: str = typer.Option("sequence", help="Column name for sequences"),
    target_col: str = typer.Option("function", help="Column name for target values"),
    val_data: Optional[Path] = typer.Option(
        None, help="Optional path to validation data"
    ),
    test_data: Optional[Path] = typer.Option(None, help="Optional path to test data"),
    embedding_method: str = typer.Option(
        "one-hot", help="Embedding method: one-hot, carp, esm2, or auto"
    ),
    model_type: str = typer.Option(
        "regression", help="Model type: regression or classification"
    ),
    output_path: Path = typer.Option(
        Path("model.pkl"), help="Path to save trained model"
    ),
    cache_dir: Optional[Path] = typer.Option(
        None, help="Directory to cache embeddings"
    ),
) -> None:
    """Train a sequence-function model on protein or nucleotide sequences."""
    logger.info(f"Training model using {embedding_method} embeddings...")

    # A comma-separated value is forwarded as a list of method names; a single
    # value is forwarded unchanged as a string.
    if "," in embedding_method:
        methods = [m.strip() for m in embedding_method.split(",")]
    else:
        methods = embedding_method

    # Train the model
    model = train_model(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        sequence_col=sequence_col,
        target_col=target_col,
        embedding_method=methods,
        model_type=model_type,
        cache_dir=cache_dir,
    )

    # Save the trained model
    save_model(model, output_path)
    logger.info(f"Model saved to {output_path}")


@app.command()
def predict_cmd(
    model_path: Path = typer.Argument(..., help="Path to saved model"),
    input_data: Path = typer.Argument(
        ..., help="Path to CSV file with sequences to predict"
    ),
    sequence_col: str = typer.Option("sequence", help="Column name for sequences"),
    output_path: Path = typer.Option(
        Path("predictions.csv"), help="Path to save predictions"
    ),
) -> None:
    """Generate predictions for new sequences using a trained model."""
    logger.info(f"Loading model from {model_path}...")
    model_info = load_model(model_path)

    # Load input data
    logger.info(f"Loading sequences from {input_data}...")
    data = pd.read_csv(input_data)

    # Fail with exit code 1 (rather than a KeyError traceback) when the
    # requested sequence column is absent.
    if sequence_col not in data.columns:
        logger.error(f"Column '{sequence_col}' not found in input data")
        raise typer.Exit(1)

    # Generate predictions
    logger.info("Generating predictions...")
    predictions = predict(
        model_info=model_info,
        sequences=data[sequence_col],
    )

    # Output has two columns: the input sequences and their predictions.
    result_df = pd.DataFrame(
        {
            sequence_col: data[sequence_col],
            "prediction": predictions,
        }
    )

    # Save to CSV
    result_df.to_csv(output_path, index=False)
    logger.info(f"Predictions saved to {output_path}")


@app.command()
def compare_embeddings(
    train_data: Path = typer.Argument(..., help="Path to CSV file with training data"),
    sequence_col: str = typer.Option("sequence", help="Column name for sequences"),
    target_col: str = typer.Option("function", help="Column name for target values"),
    val_data: Optional[Path] = typer.Option(
        None, help="Optional path to validation data"
    ),
    test_data: Optional[Path] = typer.Option(
        None, help="Optional path to test data for final evaluation"
    ),
    model_type: str = typer.Option(
        "regression", help="Model type: regression or classification"
    ),
    output_path: Path = typer.Option(
        Path("embedding_comparison.csv"), help="Path to save comparison results"
    ),
    cache_dir: Optional[Path] = typer.Option(
        None, help="Directory to cache embeddings"
    ),
) -> None:
    """Compare different embedding methods on the same dataset."""
    logger.info("Comparing embedding methods...")

    # Metrics are only computed against the test set, so without one the
    # comparison table has nothing to report. Say so instead of silently
    # writing an empty file.
    if not test_data:
        logger.warning(
            "No test data provided; models will be trained but no metrics "
            "will be recorded in the comparison results."
        )

    # List of embedding methods to compare
    embedding_methods = ["one-hot", "carp", "esm2"]
    results = []

    # Train models with each embedding method
    for method in embedding_methods:
        # Deliberately broad: a failure with one embedding method is logged
        # and must not abort the comparison of the remaining methods.
        try:
            logger.info(f"Training with {method} embeddings...")

            # Train model with this embedding method
            model_info = train_model(
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                sequence_col=sequence_col,
                target_col=target_col,
                embedding_method=method,
                model_type=model_type,
                cache_dir=cache_dir,
            )

            # Evaluate on test data if provided
            if test_data:
                test_df = pd.read_csv(test_data)

                metrics = evaluate_model(
                    model=model_info["model"],
                    X_test=test_df[sequence_col],
                    y_test=test_df[target_col],
                    embedder=model_info["embedder"],
                    model_type=model_type,
                    embed_cols=model_info["embed_cols"],
                )

                # One row per method: its name plus every metric returned.
                results.append({"embedding_method": method, **metrics})
        except Exception as e:
            logger.error(f"Error training with {method}: {e}")

    # Create DataFrame with results
    results_df = pd.DataFrame(results)

    # Save to CSV
    results_df.to_csv(output_path, index=False)
    logger.info(f"Comparison results saved to {output_path}")


@app.command()
def hello() -> None:
    """Echo the project's name."""
    typer.echo("This project's name is fast-seqfunc")


@app.command()
def describe() -> None:
    """Describe the project."""
    typer.echo("Painless sequence-function models for proteins and nucleotides.")


@app.command()
def generate_synthetic(
    task: str = typer.Argument(
        ...,
        help="Type of synthetic data task to generate. Options: g_count, gc_content, "
        "motif_position, motif_count, length_dependent, nonlinear_composition, "
        "interaction, classification, multiclass",
    ),
    output_dir: Path = typer.Option(
        Path("synthetic_data"), help="Directory to save generated datasets"
    ),
    total_count: int = typer.Option(1000, help="Total number of sequences to generate"),
    train_ratio: float = typer.Option(
        0.7, help="Proportion of data to use for training set"
    ),
    val_ratio: float = typer.Option(
        0.15, help="Proportion of data to use for validation set"
    ),
    test_ratio: float = typer.Option(
        0.15, help="Proportion of data to use for test set"
    ),
    split_data: bool = typer.Option(
        True, help="Whether to split data into train/val/test sets"
    ),
    sequence_length: int = typer.Option(
        30, help="Length of each sequence (for fixed-length tasks)"
    ),
    min_length: int = typer.Option(
        20, help="Minimum sequence length (for variable-length tasks)"
    ),
    max_length: int = typer.Option(
        50, help="Maximum sequence length (for variable-length tasks)"
    ),
    noise_level: float = typer.Option(0.1, help="Level of noise to add to the data"),
    sequence_type: str = typer.Option(
        "dna", help="Type of sequences to generate: dna, rna, or protein"
    ),
    alphabet: Optional[str] = typer.Option(
        None, help="Custom alphabet for sequences. Overrides sequence_type if provided."
    ),
    motif: Optional[str] = typer.Option(
        None, help="Custom motif for motif-based tasks"
    ),
    motifs: Optional[str] = typer.Option(
        None, help="Comma-separated list of motifs for motif_count task"
    ),
    weights: Optional[str] = typer.Option(
        None, help="Comma-separated list of weights for motif_count task"
    ),
    prefix: str = typer.Option("", help="Prefix for output filenames"),
    random_seed: Optional[int] = typer.Option(
        None, help="Random seed for reproducibility"
    ),
) -> None:
    """Generate synthetic sequence-function data for testing and benchmarking.

    This command creates synthetic datasets with controllable properties and
    complexity to test sequence-function models. Data can be split into
    train/validation/test sets.

    Each task produces a different type of sequence-function relationship:

    - g_count: Linear relationship based on count of G nucleotides
    - gc_content: Linear relationship based on GC content
    - motif_position: Function depends on the position of a motif (nonlinear)
    - motif_count: Function depends on counts of multiple motifs (linear)
    - length_dependent: Function depends on sequence length (nonlinear)
    - nonlinear_composition: Nonlinear function of base composition
    - interaction: Function depends on interactions between positions
    - classification: Binary classification based on presence of motifs
    - multiclass: Multi-class classification based on different patterns

    Example usage:

        $ fast-seqfunc generate-synthetic gc_content --output-dir data/gc_task

        $ fast-seqfunc generate-synthetic motif_position --motif ATCG --noise-level 0.2

        $ fast-seqfunc generate-synthetic classification \
            --sequence-type protein \
            --no-split-data
    """
    # Validate the task before doing anything with side effects (creating
    # directories, consuming random state, patching the synthetic module).
    if task not in _SYNTHETIC_TASKS:
        logger.error(
            f"Invalid task: {task}. Valid options are: {', '.join(_SYNTHETIC_TASKS)}"
        )
        raise typer.Exit(1)

    # Set random seed if provided
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)

    logger.info(f"Generating synthetic data for task: {task}")

    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)

    # Set alphabet based on sequence type (a custom alphabet takes precedence)
    if alphabet is None:
        sequence_type = sequence_type.lower()
        if sequence_type == "dna":
            alphabet = "ACGT"
        elif sequence_type == "rna":
            alphabet = "ACGU"
        elif sequence_type == "protein":
            alphabet = "ACDEFGHIKLMNPQRSTVWY"
        else:
            logger.warning(
                f"Unknown sequence type: {sequence_type}. Using DNA alphabet."
            )
            alphabet = "ACGT"

    logger.info(f"Using alphabet: {alphabet}")

    # Task-specific parameters
    task_params: Dict[str, Any] = {}

    # Every task except length_dependent takes a fixed sequence length.
    if task != "length_dependent":
        task_params["length"] = sequence_length

    # Add task-specific parameters based on the task type
    if task == "motif_position":
        if motif:
            # Use custom motif if provided
            task_params["motif"] = motif
        else:
            # Default motif: up to four distinct symbols drawn from the alphabet.
            task_params["motif"] = "".join(
                random.sample(alphabet, min(4, len(alphabet)))
            )
            logger.info(f"Using default motif: {task_params['motif']}")

    elif task == "motif_count":
        if motifs:
            # Parse custom motifs if provided
            task_params["motifs"] = [m.strip() for m in motifs.split(",")]
        else:
            # Default motifs: four of them, 2 symbols long for small (DNA/RNA
            # sized) alphabets and 3 symbols long for larger (protein) ones.
            motif_length = 2 if len(alphabet) <= 8 else 3
            task_params["motifs"] = [
                "".join(random.sample(alphabet, motif_length)) for _ in range(4)
            ]
            logger.info(f"Using default motifs: {task_params['motifs']}")

        # Start from the defaults; only replace them with user weights that
        # parse cleanly and match the number of motifs.
        task_params["weights"] = list(_DEFAULT_MOTIF_WEIGHTS)
        if weights:
            try:
                weight_values = [float(w.strip()) for w in weights.split(",")]
            except ValueError:
                logger.warning("Invalid weight values. Using default weights.")
            else:
                if len(weight_values) != len(task_params["motifs"]):
                    logger.warning(
                        "Number of weights doesn't match number of motifs. "
                        "Using default weights."
                    )
                else:
                    task_params["weights"] = weight_values

    elif task == "length_dependent":
        task_params["min_length"] = min_length
        task_params["max_length"] = max_length

    # The task functions don't accept an alphabet parameter, so the alphabet is
    # injected by temporarily monkey patching
    # ``synthetic.generate_random_sequences``. The patch is installed
    # immediately before the ``try`` whose ``finally`` removes it, so no code
    # path can leave the patched function in place.
    original_generate_random_sequences = synthetic.generate_random_sequences

    def patched_generate_random_sequences(*args, **kwargs):
        """Call the original generator with ``alphabet`` forced to ours.

        :param args: Positional arguments forwarded unchanged
        :param kwargs: Keyword arguments forwarded, except ``alphabet`` which
            is overridden
        :return: Whatever the original ``generate_random_sequences`` returns
        """
        kwargs["alphabet"] = alphabet
        return original_generate_random_sequences(*args, **kwargs)

    synthetic.generate_random_sequences = patched_generate_random_sequences

    # Generate the dataset
    try:
        df = synthetic.generate_dataset_by_task(
            task=task, count=total_count, noise_level=noise_level, **task_params
        )

        logger.info(f"Generated {len(df)} sequences for task: {task}")

        # Create filename prefix if provided
        file_prefix = f"{prefix}_" if prefix else ""

        # Save the full dataset if not splitting
        if not split_data:
            output_path = output_dir / f"{file_prefix}{task}_data.csv"
            df.to_csv(output_path, index=False)
            logger.info(f"Saved full dataset to {output_path}")
            return  # the ``finally`` below restores the original function

        # Validate split ratios
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-10:
            logger.warning("Split ratios don't sum to 1.0. Normalizing.")
            total = train_ratio + val_ratio + test_ratio
            train_ratio /= total
            val_ratio /= total
            test_ratio /= total

        # Shuffle the data
        df = df.sample(frac=1.0, random_state=random_seed)

        # Calculate split indices; the test set takes whatever remains after
        # the (floored) train and validation counts.
        n = len(df)
        train_idx = int(n * train_ratio)
        val_idx = train_idx + int(n * val_ratio)

        # Split the data
        train_df = df.iloc[:train_idx]
        val_df = df.iloc[train_idx:val_idx]
        test_df = df.iloc[val_idx:]

        # Save the splits
        train_path = output_dir / f"{file_prefix}train.csv"
        val_path = output_dir / f"{file_prefix}val.csv"
        test_path = output_dir / f"{file_prefix}test.csv"

        train_df.to_csv(train_path, index=False)
        val_df.to_csv(val_path, index=False)
        test_df.to_csv(test_path, index=False)

        logger.info(f"Saved train set ({len(train_df)} samples) to {train_path}")
        logger.info(f"Saved validation set ({len(val_df)} samples) to {val_path}")
        logger.info(f"Saved test set ({len(test_df)} samples) to {test_path}")

        # Save task metadata (a single-row CSV)
        metadata = {
            "task": task,
            "sequence_type": sequence_type,
            "alphabet": alphabet,
            "total_count": total_count,
            "train_count": len(train_df),
            "val_count": len(val_df),
            "test_count": len(test_df),
            "noise_level": noise_level,
            **task_params,
        }

        metadata_path = output_dir / f"{file_prefix}metadata.csv"
        pd.DataFrame([metadata]).to_csv(metadata_path, index=False)
        logger.info(f"Saved metadata to {metadata_path}")

    except Exception as e:
        logger.error(f"Error generating synthetic data: {e}")
        raise typer.Exit(1) from e
    finally:
        # Always restore the original function, on success, early return, or
        # error.
        synthetic.generate_random_sequences = original_generate_random_sequences


@app.command()
def list_synthetic_tasks() -> None:
    """List all available synthetic sequence-function data tasks with descriptions."""
    typer.echo("Available synthetic sequence-function data tasks:")
    typer.echo("")

    for task, description in _SYNTHETIC_TASKS.items():
        typer.echo(f"{task}:")
        typer.echo(f" {description}")
        typer.echo("")

    typer.echo("Usage:")
    typer.echo(" fast-seqfunc generate-synthetic TASK [OPTIONS]")
    typer.echo("")
    typer.echo("For detailed options:")
    typer.echo(" fast-seqfunc generate-synthetic --help")


if __name__ == "__main__":
    app()