/ fast_seqfunc / cli.py
cli.py
  1  """Custom CLI for fast-seqfunc.
  2  
  3  This module provides a command-line interface for training sequence-function models
  4  and making predictions on new sequences.
  5  
  6  Typer's docs can be found at:
  7      https://typer.tiangolo.com
  8  """
  9  
 10  import random
 11  from pathlib import Path
 12  from typing import Any, Dict, Optional
 13  
 14  import numpy as np
 15  import pandas as pd
 16  import typer
 17  from loguru import logger
 18  
 19  from fast_seqfunc import synthetic
 20  from fast_seqfunc.core import (
 21      evaluate_model,
 22      load_model,
 23      predict,
 24      save_model,
 25      train_model,
 26  )
 27  
# Typer application object; every CLI command in this module registers itself
# on it through the ``@app.command()`` decorator.
app = typer.Typer()
 29  
 30  
 31  @app.command()
 32  def train(
 33      train_data: Path = typer.Argument(..., help="Path to CSV file with training data"),
 34      sequence_col: str = typer.Option("sequence", help="Column name for sequences"),
 35      target_col: str = typer.Option("function", help="Column name for target values"),
 36      val_data: Optional[Path] = typer.Option(
 37          None, help="Optional path to validation data"
 38      ),
 39      test_data: Optional[Path] = typer.Option(None, help="Optional path to test data"),
 40      embedding_method: str = typer.Option(
 41          "one-hot", help="Embedding method: one-hot, carp, esm2, or auto"
 42      ),
 43      model_type: str = typer.Option(
 44          "regression", help="Model type: regression or classification"
 45      ),
 46      output_path: Path = typer.Option(
 47          Path("model.pkl"), help="Path to save trained model"
 48      ),
 49      cache_dir: Optional[Path] = typer.Option(
 50          None, help="Directory to cache embeddings"
 51      ),
 52  ):
 53      """Train a sequence-function model on protein or nucleotide sequences."""
 54      logger.info(f"Training model using {embedding_method} embeddings...")
 55  
 56      # Parse embedding methods if multiple are provided
 57      if "," in embedding_method:
 58          embedding_method = [m.strip() for m in embedding_method.split(",")]
 59  
 60      # Train the model
 61      model = train_model(
 62          train_data=train_data,
 63          val_data=val_data,
 64          test_data=test_data,
 65          sequence_col=sequence_col,
 66          target_col=target_col,
 67          embedding_method=embedding_method,
 68          model_type=model_type,
 69          cache_dir=cache_dir,
 70      )
 71  
 72      # Save the trained model
 73      save_model(model, output_path)
 74      logger.info(f"Model saved to {output_path}")
 75  
 76  
 77  @app.command()
 78  def predict_cmd(
 79      model_path: Path = typer.Argument(..., help="Path to saved model"),
 80      input_data: Path = typer.Argument(
 81          ..., help="Path to CSV file with sequences to predict"
 82      ),
 83      sequence_col: str = typer.Option("sequence", help="Column name for sequences"),
 84      output_path: Path = typer.Option(
 85          Path("predictions.csv"), help="Path to save predictions"
 86      ),
 87  ):
 88      """Generate predictions for new sequences using a trained model."""
 89      logger.info(f"Loading model from {model_path}...")
 90      model_info = load_model(model_path)
 91  
 92      # Load input data
 93      logger.info(f"Loading sequences from {input_data}...")
 94      data = pd.read_csv(input_data)
 95  
 96      # Check if sequence column exists
 97      if sequence_col not in data.columns:
 98          logger.error(f"Column '{sequence_col}' not found in input data")
 99          raise typer.Exit(1)
100  
101      # Generate predictions
102      logger.info("Generating predictions...")
103      predictions = predict(
104          model_info=model_info,
105          sequences=data[sequence_col],
106      )
107  
108      # Save predictions
109      result_df = pd.DataFrame(
110          {
111              sequence_col: data[sequence_col],
112              "prediction": predictions,
113          }
114      )
115  
116      # Save to CSV
117      result_df.to_csv(output_path, index=False)
118      logger.info(f"Predictions saved to {output_path}")
119  
120  
@app.command()
def compare_embeddings(
    train_data: Path = typer.Argument(..., help="Path to CSV file with training data"),
    sequence_col: str = typer.Option("sequence", help="Column name for sequences"),
    target_col: str = typer.Option("function", help="Column name for target values"),
    val_data: Optional[Path] = typer.Option(
        None, help="Optional path to validation data"
    ),
    test_data: Optional[Path] = typer.Option(
        None, help="Optional path to test data for final evaluation"
    ),
    model_type: str = typer.Option(
        "regression", help="Model type: regression or classification"
    ),
    output_path: Path = typer.Option(
        Path("embedding_comparison.csv"), help="Path to save comparison results"
    ),
    cache_dir: Optional[Path] = typer.Option(
        None, help="Directory to cache embeddings"
    ),
):
    """Compare different embedding methods on the same dataset."""
    logger.info("Comparing embedding methods...")

    # Metrics are only computed against the test set. Without one, every model
    # is still trained but nothing is added to the results table, so say so up
    # front instead of silently producing an empty file.
    if test_data is None:
        logger.warning(
            "No test data provided; models will be trained but no metrics "
            "will be recorded."
        )

    # List of embedding methods to compare
    embedding_methods = ["one-hot", "carp", "esm2"]
    results = []

    # Train models with each embedding method
    for method in embedding_methods:
        # Best-effort per method: a failure while training or evaluating one
        # embedding is logged and the remaining methods are still attempted.
        try:
            logger.info(f"Training with {method} embeddings...")

            # Train model with this embedding method
            model_info = train_model(
                train_data=train_data,
                val_data=val_data,
                test_data=test_data,
                sequence_col=sequence_col,
                target_col=target_col,
                embedding_method=method,
                model_type=model_type,
                cache_dir=cache_dir,
            )

            # Evaluate on test data if provided
            if test_data is not None:
                test_df = pd.read_csv(test_data)

                # Extract model components
                model = model_info["model"]
                embedder = model_info["embedder"]
                embed_cols = model_info["embed_cols"]

                metrics = evaluate_model(
                    model=model,
                    X_test=test_df[sequence_col],
                    y_test=test_df[target_col],
                    embedder=embedder,
                    model_type=model_type,
                    embed_cols=embed_cols,
                )

                # One row per method: its name followed by the metric columns
                results.append({"embedding_method": method, **metrics})
        except Exception as e:
            logger.error(f"Error training with {method}: {e}")

    # Make an empty outcome visible (no test data, or every method failed)
    if not results:
        logger.warning(
            "No comparison results were produced; the output file will be empty."
        )

    # Create DataFrame with results and save it to CSV
    results_df = pd.DataFrame(results)
    results_df.to_csv(output_path, index=False)
    logger.info(f"Comparison results saved to {output_path}")
196  
197  
@app.command()
def hello():
    """Echo the project's name."""
    # Single fixed line written through Typer's echo helper.
    message = "This project's name is fast-seqfunc"
    typer.echo(message)
202  
203  
@app.command()
def describe():
    """Describe the project."""
    # One-sentence project summary written through Typer's echo helper.
    summary = "Painless sequence-function models for proteins and nucleotides."
    typer.echo(summary)
208  
209  
def _default_motif_weights(count):
    """Return ``count`` default weights for the motif_count task.

    For four motifs this is the usual default ``[1.0, -0.5, 2.0, -1.5]``; for
    any other count the same pattern is repeated or truncated so there is
    exactly one weight per motif.
    """
    pattern = [1.0, -0.5, 2.0, -1.5]
    return [pattern[i % len(pattern)] for i in range(count)]


@app.command()
def generate_synthetic(
    task: str = typer.Argument(
        ...,
        help="Type of synthetic data task to generate. Options: g_count, gc_content, "
        "motif_position, motif_count, length_dependent, nonlinear_composition, "
        "interaction, classification, multiclass",
    ),
    output_dir: Path = typer.Option(
        Path("synthetic_data"), help="Directory to save generated datasets"
    ),
    total_count: int = typer.Option(1000, help="Total number of sequences to generate"),
    train_ratio: float = typer.Option(
        0.7, help="Proportion of data to use for training set"
    ),
    val_ratio: float = typer.Option(
        0.15, help="Proportion of data to use for validation set"
    ),
    test_ratio: float = typer.Option(
        0.15, help="Proportion of data to use for test set"
    ),
    split_data: bool = typer.Option(
        True, help="Whether to split data into train/val/test sets"
    ),
    sequence_length: int = typer.Option(
        30, help="Length of each sequence (for fixed-length tasks)"
    ),
    min_length: int = typer.Option(
        20, help="Minimum sequence length (for variable-length tasks)"
    ),
    max_length: int = typer.Option(
        50, help="Maximum sequence length (for variable-length tasks)"
    ),
    noise_level: float = typer.Option(0.1, help="Level of noise to add to the data"),
    sequence_type: str = typer.Option(
        "dna", help="Type of sequences to generate: dna, rna, or protein"
    ),
    alphabet: Optional[str] = typer.Option(
        None, help="Custom alphabet for sequences. Overrides sequence_type if provided."
    ),
    motif: Optional[str] = typer.Option(
        None, help="Custom motif for motif-based tasks"
    ),
    motifs: Optional[str] = typer.Option(
        None, help="Comma-separated list of motifs for motif_count task"
    ),
    weights: Optional[str] = typer.Option(
        None, help="Comma-separated list of weights for motif_count task"
    ),
    prefix: str = typer.Option("", help="Prefix for output filenames"),
    random_seed: Optional[int] = typer.Option(
        None, help="Random seed for reproducibility"
    ),
):
    """Generate synthetic sequence-function data for testing and benchmarking.

    This command creates synthetic datasets with controllable properties and
    complexity to test sequence-function models. Data can be split into
    train/validation/test sets.

    Each task produces a different type of sequence-function relationship:

    - g_count: Linear relationship based on count of G nucleotides
    - gc_content: Linear relationship based on GC content
    - motif_position: Function depends on the position of a motif (nonlinear)
    - motif_count: Function depends on counts of multiple motifs (linear)
    - length_dependent: Function depends on sequence length (nonlinear)
    - nonlinear_composition: Nonlinear function of base composition
    - interaction: Function depends on interactions between positions
    - classification: Binary classification based on presence of motifs
    - multiclass: Multi-class classification based on different patterns

    Example usage:

    $ fast-seqfunc generate-synthetic gc_content --output-dir data/gc_task

    $ fast-seqfunc generate-synthetic motif_position --motif ATCG --noise-level 0.2

    $ fast-seqfunc generate-synthetic classification \
        --sequence-type protein \
        --no-split-data
    """
    # Reject bad input before anything is touched: no directory is created, no
    # random numbers are drawn and nothing is monkey-patched for an invalid call.
    valid_tasks = [
        "g_count",
        "gc_content",
        "motif_position",
        "motif_count",
        "length_dependent",
        "nonlinear_composition",
        "interaction",
        "classification",
        "multiclass",
    ]
    if task not in valid_tasks:
        logger.error(
            f"Invalid task: {task}. Valid options are: {', '.join(valid_tasks)}"
        )
        raise typer.Exit(1)

    # Ratios only matter when splitting. A zero sum cannot be normalized and a
    # negative ratio would yield nonsensical slice boundaries.
    if split_data and (
        min(train_ratio, val_ratio, test_ratio) < 0
        or train_ratio + val_ratio + test_ratio <= 0
    ):
        logger.error(
            "Split ratios must be non-negative and sum to a positive value."
        )
        raise typer.Exit(1)

    # Set random seed if provided
    if random_seed is not None:
        random.seed(random_seed)
        np.random.seed(random_seed)

    logger.info(f"Generating synthetic data for task: {task}")

    # Create output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)

    # Derive the alphabet from the sequence type unless one was given explicitly.
    # The lower-cased sequence type is what later lands in the metadata.
    if alphabet is None:
        sequence_type = sequence_type.lower()
        known_alphabets = {
            "dna": "ACGT",
            "rna": "ACGU",
            "protein": "ACDEFGHIKLMNPQRSTVWY",
        }
        if sequence_type not in known_alphabets:
            logger.warning(
                f"Unknown sequence type: {sequence_type}. Using DNA alphabet."
            )
        alphabet = known_alphabets.get(sequence_type, "ACGT")

    logger.info(f"Using alphabet: {alphabet}")

    # Task-specific parameters forwarded to synthetic.generate_dataset_by_task
    task_params: Dict[str, Any] = {}

    # Every task except the variable-length one takes a fixed sequence length
    if task != "length_dependent":
        task_params["length"] = sequence_length

    if task == "motif_position":
        if motif:
            task_params["motif"] = motif
        else:
            # Default motif: up to four distinct symbols drawn from the alphabet
            # (capped so that alphabets shorter than four symbols do not raise).
            task_params["motif"] = "".join(
                random.sample(alphabet, min(4, len(alphabet)))
            )
            logger.info(f"Using default motif: {task_params['motif']}")

    elif task == "motif_count":
        if motifs:
            task_params["motifs"] = [m.strip() for m in motifs.split(",")]
        else:
            # Four default motifs: length 2 for small (DNA/RNA-like) alphabets,
            # length 3 otherwise, capped at the alphabet size.
            motif_length = 2 if len(alphabet) <= 8 else 3
            motif_length = min(motif_length, len(alphabet))
            task_params["motifs"] = [
                "".join(random.sample(alphabet, motif_length)) for _ in range(4)
            ]
            logger.info(f"Using default motifs: {task_params['motifs']}")

        # Start from defaults sized to the motif list, then override them with
        # the user's weights only when those parse and match the motif count.
        motif_total = len(task_params["motifs"])
        task_params["weights"] = _default_motif_weights(motif_total)
        if weights:
            try:
                weight_values = [float(w.strip()) for w in weights.split(",")]
            except ValueError:
                logger.warning("Invalid weight values. Using default weights.")
            else:
                if len(weight_values) == motif_total:
                    task_params["weights"] = weight_values
                else:
                    logger.warning(
                        "Number of weights doesn't match number of motifs. "
                        "Using default weights."
                    )

    elif task == "length_dependent":
        task_params["min_length"] = min_length
        task_params["max_length"] = max_length

    # The task functions don't accept an alphabet parameter, so
    # synthetic.generate_random_sequences is temporarily replaced by a wrapper
    # that forces our alphabet and forwards everything else unchanged.
    original_generate_random_sequences = synthetic.generate_random_sequences

    def patched_generate_random_sequences(*args, **kwargs):
        # Override only the alphabet; keep all other arguments as passed.
        kwargs["alphabet"] = alphabet
        return original_generate_random_sequences(*args, **kwargs)

    # The patch goes in immediately before the try so the finally clause below
    # covers every exit path (success, early return, failure).
    synthetic.generate_random_sequences = patched_generate_random_sequences
    try:
        df = synthetic.generate_dataset_by_task(
            task=task, count=total_count, noise_level=noise_level, **task_params
        )

        logger.info(f"Generated {len(df)} sequences for task: {task}")

        # Create filename prefix if provided
        file_prefix = f"{prefix}_" if prefix else ""

        # Save the full dataset if not splitting
        if not split_data:
            output_path = output_dir / f"{file_prefix}{task}_data.csv"
            df.to_csv(output_path, index=False)
            logger.info(f"Saved full dataset to {output_path}")
            return

        # Normalize the split ratios when they don't sum to 1.0
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-10:
            logger.warning("Split ratios don't sum to 1.0. Normalizing.")
            total = train_ratio + val_ratio + test_ratio
            train_ratio /= total
            val_ratio /= total
            test_ratio /= total

        # Shuffle the data
        df = df.sample(frac=1.0, random_state=random_seed)

        # Calculate split indices; the test set takes whatever remains after
        # the train and validation slices.
        n = len(df)
        train_idx = int(n * train_ratio)
        val_idx = train_idx + int(n * val_ratio)

        # Split the data
        train_df = df.iloc[:train_idx]
        val_df = df.iloc[train_idx:val_idx]
        test_df = df.iloc[val_idx:]

        # Save the splits
        train_path = output_dir / f"{file_prefix}train.csv"
        val_path = output_dir / f"{file_prefix}val.csv"
        test_path = output_dir / f"{file_prefix}test.csv"

        train_df.to_csv(train_path, index=False)
        val_df.to_csv(val_path, index=False)
        test_df.to_csv(test_path, index=False)

        logger.info(f"Saved train set ({len(train_df)} samples) to {train_path}")
        logger.info(f"Saved validation set ({len(val_df)} samples) to {val_path}")
        logger.info(f"Saved test set ({len(test_df)} samples) to {test_path}")

        # Save task metadata as a one-row CSV
        metadata = {
            "task": task,
            "sequence_type": sequence_type,
            "alphabet": alphabet,
            "total_count": total_count,
            "train_count": len(train_df),
            "val_count": len(val_df),
            "test_count": len(test_df),
            "noise_level": noise_level,
            **task_params,
        }

        metadata_path = output_dir / f"{file_prefix}metadata.csv"
        pd.DataFrame([metadata]).to_csv(metadata_path, index=False)
        logger.info(f"Saved metadata to {metadata_path}")

    except Exception as e:
        logger.error(f"Error generating synthetic data: {e}")
        raise typer.Exit(1) from e
    finally:
        # Always restore the original function, whatever happened above
        synthetic.generate_random_sequences = original_generate_random_sequences
503  
504  
@app.command()
def list_synthetic_tasks():
    """List all available synthetic sequence-function data tasks with descriptions."""
    # (task name, description) pairs, in the order they are printed.
    catalogue = (
        (
            "g_count",
            "A simple linear task where the function value is the count of G "
            "nucleotides in the sequence.",
        ),
        (
            "gc_content",
            "A simple linear task where the function value is the GC content "
            "(proportion of G and C) of the sequence.",
        ),
        (
            "motif_position",
            "A nonlinear task where the function value depends on the "
            "position of a specific motif in the sequence.",
        ),
        (
            "motif_count",
            "A linear task where the function value is a weighted sum of "
            "counts of multiple motifs in the sequence.",
        ),
        (
            "length_dependent",
            "A task with variable-length sequences where the function "
            "value depends nonlinearly on the sequence length.",
        ),
        (
            "nonlinear_composition",
            "A complex nonlinear task where the function depends "
            "on ratios between different nucleotide frequencies.",
        ),
        (
            "interaction",
            "A task testing positional interactions, "
            "where specific nucleotide pairs at certain positions "
            "contribute to the function.",
        ),
        (
            "classification",
            "A binary classification task where the class depends on the "
            "presence of specific patterns in the sequence.",
        ),
        (
            "multiclass",
            "A multi-class classification task "
            "with multiple sequence patterns "
            "corresponding to different classes.",
        ),
    )

    # Assemble the whole listing first, then emit it one line per echo call.
    lines = ["Available synthetic sequence-function data tasks:", ""]
    for name, blurb in catalogue:
        lines.extend([f"{name}:", f"  {blurb}", ""])
    lines.extend(
        [
            "Usage:",
            "  fast-seqfunc generate-synthetic TASK [OPTIONS]",
            "",
            "For detailed options:",
            "  fast-seqfunc generate-synthetic --help",
        ]
    )

    for line in lines:
        typer.echo(line)
544  
545  
# Run the Typer application when this module is executed directly as a script.
if __name__ == "__main__":
    app()