/ examples / basic_usage.py
basic_usage.py
  1  #!/usr/bin/env python
  2  # /// script
  3  # requires-python = ">=3.11"
  4  # dependencies = [
  5  #   "fast-seqfunc",
  6  #   "pandas",
  7  #   "numpy",
  8  #   "matplotlib",
  9  #   "seaborn",
 10  #   "pycaret[full]>=3.0.0",
 11  #   "scikit-learn>=1.0.0",
 12  #   "fast-seqfunc @ git+https://github.com/ericmjl/fast-seqfunc.git@first-implementation",
 13  # ]
 14  # ///
 15  
"""
Basic usage example for fast-seqfunc.

This script demonstrates how to:
1. Generate synthetic DNA sequence-function data
2. Train a sequence-function model using one-hot encoding
3. Evaluate the model on a held-out test split
4. Make predictions on held-out sequences
5. Save the trained model to disk and load it back

Outputs are written to ``examples/output``, resolved relative to the
current working directory. They are:

- the train/test CSVs;
- the test-set predictions CSV;
- the pickled model (``model.pkl``);
- two diagnostic PNG plots.
"""
 25  
 26  import random
 27  from pathlib import Path
 28  
 29  import matplotlib.pyplot as plt
 30  import numpy as np
 31  import pandas as pd
 32  import seaborn as sns
 33  
 34  from fast_seqfunc import load_model, predict, save_model, train_model
 35  
 36  # Set random seed for reproducibility
 37  np.random.seed(42)
 38  random.seed(42)
 39  
 40  
 41  def generate_random_nucleotide(length=100):
 42      """Generate a random DNA sequence of specified length."""
 43      nucleotides = "ACGT"
 44      return "".join(random.choice(nucleotides) for _ in range(length))
 45  
 46  
 47  def generate_synthetic_data(n_samples=1000, seq_length=100):
 48      """Generate synthetic sequence-function data.
 49  
 50      Creates sequences with a simple pattern:
 51      - Higher function value if more 'A' and 'G' nucleotides
 52      - Lower function value if more 'C' and 'T' nucleotides
 53      """
 54      sequences = []
 55      functions = []
 56  
 57      for _ in range(n_samples):
 58          # Generate random DNA sequence
 59          seq = generate_random_nucleotide(seq_length)
 60          sequences.append(seq)
 61  
 62          # Calculate function value based on simple rules
 63          # More A and G -> higher function
 64          a_count = seq.count("A")
 65          g_count = seq.count("G")
 66          c_count = seq.count("C")
 67          t_count = seq.count("T")
 68  
 69          # Simple function with some noise
 70          func_value = (
 71              0.5 * (a_count + g_count) / seq_length
 72              - 0.3 * (c_count + t_count) / seq_length
 73          )
 74          # func_value += np.random.normal(0, 0.1)  # Add noise
 75          functions.append(func_value)
 76  
 77      # Create DataFrame
 78      df = pd.DataFrame(
 79          {
 80              "sequence": sequences,
 81              "function": functions,
 82          }
 83      )
 84  
 85      return df
 86  
 87  
 88  def main():
 89      """Run the example pipeline."""
 90      print("Fast-SeqFunc Basic Example")
 91      print("=========================\n")
 92  
 93      # Create directory for outputs
 94      output_dir = Path("examples/output")
 95      output_dir.mkdir(parents=True, exist_ok=True)
 96  
 97      # Generate synthetic data
 98      print("Generating synthetic data...")
 99      n_samples = 5000
100      all_data = generate_synthetic_data(n_samples=n_samples)
101  
102      # Split into train and test sets (validation handled internally)
103      train_size = int(0.8 * n_samples)
104      test_size = n_samples - train_size
105  
106      train_data = all_data[:train_size].copy()
107      test_data = all_data[train_size:].copy()
108  
109      print(f"Data split: {train_size} train, {test_size} test samples")
110  
111      # Save data files
112      train_data.to_csv(output_dir / "train_data.csv", index=False)
113      test_data.to_csv(output_dir / "test_data.csv", index=False)
114  
115      # Train and compare multiple models automatically
116      print("\nTraining and comparing sequence-function models...")
117      model_info = train_model(
118          train_data=train_data,
119          test_data=test_data,
120          sequence_col="sequence",
121          target_col="function",
122          embedding_method="one-hot",
123          model_type="regression",
124          optimization_metric="r2",  # Optimize for R-squared
125      )
126  
127      # Display test results if available
128      if model_info.get("test_results"):
129          print("\nTest metrics from training:")
130          for metric, value in model_info["test_results"].items():
131              print(f"  {metric}: {value:.4f}")
132  
133      # Save the model
134      model_path = output_dir / "model.pkl"
135      save_model(model_info, model_path)
136      print(f"Model saved to {model_path}")
137  
138      # Make predictions on test data
139      print("\nMaking predictions on test data...")
140      test_predictions = predict(model_info, test_data["sequence"])
141  
142      # Create a results DataFrame
143      results_df = test_data.copy()
144      results_df["prediction"] = test_predictions
145      results_df.to_csv(output_dir / "test_predictions.csv", index=False)
146  
147      # Calculate metrics manually
148      true_values = test_data["function"]
149      mse = ((test_predictions - true_values) ** 2).mean()
150      r2 = (
151          1
152          - ((test_predictions - true_values) ** 2).sum()
153          / ((true_values - true_values.mean()) ** 2).sum()
154      )
155  
156      print("Manual test metrics calculation:")
157      print(f"  Mean Squared Error: {mse:.4f}")
158      print(f"  R²: {r2:.4f}")
159  
160      # Create a scatter plot of true vs predicted values
161      plt.figure(figsize=(8, 6))
162      sns.scatterplot(x=true_values, y=test_predictions, alpha=0.6)
163      plt.plot(
164          [min(true_values), max(true_values)],
165          [min(true_values), max(true_values)],
166          "r--",
167      )
168      plt.xlabel("True Function Value")
169      plt.ylabel("Predicted Function Value")
170      plt.title("True vs Predicted Function Values")
171      plt.tight_layout()
172      plt.savefig(output_dir / "true_vs_predicted.png", dpi=300)
173      print(f"Plot saved to {output_dir / 'true_vs_predicted.png'}")
174  
175      # Create plots showing function score vs nucleotide counts
176      print("\nCreating nucleotide count vs function plots...")
177  
178      # Calculate nucleotide counts for all sequences
179      all_data_with_counts = all_data.copy()
180      all_data_with_counts["A_count"] = all_data["sequence"].apply(lambda x: x.count("A"))
181      all_data_with_counts["G_count"] = all_data["sequence"].apply(lambda x: x.count("G"))
182      all_data_with_counts["C_count"] = all_data["sequence"].apply(lambda x: x.count("C"))
183      all_data_with_counts["T_count"] = all_data["sequence"].apply(lambda x: x.count("T"))
184  
185      # Create a 2x2 grid of scatter plots
186      fig, axes = plt.subplots(2, 2, figsize=(12, 10))
187  
188      # Plot function vs A count
189      sns.scatterplot(
190          x="A_count", y="function", data=all_data_with_counts, alpha=0.6, ax=axes[0, 0]
191      )
192      axes[0, 0].set_title("Function vs A Count")
193      axes[0, 0].set_xlabel("Number of A's")
194      axes[0, 0].set_ylabel("Function Value")
195  
196      # Plot function vs G count
197      sns.scatterplot(
198          x="G_count", y="function", data=all_data_with_counts, alpha=0.6, ax=axes[0, 1]
199      )
200      axes[0, 1].set_title("Function vs G Count")
201      axes[0, 1].set_xlabel("Number of G's")
202      axes[0, 1].set_ylabel("Function Value")
203  
204      # Plot function vs C count
205      sns.scatterplot(
206          x="C_count", y="function", data=all_data_with_counts, alpha=0.6, ax=axes[1, 0]
207      )
208      axes[1, 0].set_title("Function vs C Count")
209      axes[1, 0].set_xlabel("Number of C's")
210      axes[1, 0].set_ylabel("Function Value")
211  
212      # Plot function vs T count
213      sns.scatterplot(
214          x="T_count", y="function", data=all_data_with_counts, alpha=0.6, ax=axes[1, 1]
215      )
216      axes[1, 1].set_title("Function vs T Count")
217      axes[1, 1].set_xlabel("Number of T's")
218      axes[1, 1].set_ylabel("Function Value")
219  
220      plt.tight_layout()
221      plt.savefig(output_dir / "nucleotide_counts_vs_function.png", dpi=300)
222      print(
223          f"Nucleotide count plots saved to "
224          f"{output_dir / 'nucleotide_counts_vs_function.png'}"
225      )
226  
227      # Test loading the model
228      print("\nTesting model loading...")
229      load_model(model_path)
230      print("Model loaded successfully")
231  
232  
233  if __name__ == "__main__":
234      main()