basic_usage.py
#!/usr/bin/env python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "fast-seqfunc",
#     "pandas",
#     "numpy",
#     "matplotlib",
#     "seaborn",
#     "pycaret[full]>=3.0.0",
#     "scikit-learn>=1.0.0",
#     "fast-seqfunc @ git+https://github.com/ericmjl/fast-seqfunc.git@first-implementation",
# ]
# ///

"""
Basic usage example for fast-seqfunc.

This script demonstrates how to:
1. Generate synthetic DNA sequence-function data
2. Train a sequence-function model using one-hot encoding
3. Evaluate the model
4. Make predictions on new sequences
"""

import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from fast_seqfunc import load_model, predict, save_model, train_model

# Set random seed for reproducibility
np.random.seed(42)
random.seed(42)

# Alphabet the synthetic sequences are drawn from.
NUCLEOTIDES = "ACGT"


def generate_random_nucleotide(length: int = 100) -> str:
    """Generate a random DNA sequence of specified length.

    :param length: Number of nucleotides in the returned sequence.
    :return: A string of ``length`` characters drawn from ``"ACGT"``.
    """
    # One random.choice() call per position: switching to random.choices()
    # would consume the seeded RNG differently and change the generated data.
    return "".join(random.choice(NUCLEOTIDES) for _ in range(length))


def generate_synthetic_data(
    n_samples: int = 1000,
    seq_length: int = 100,
    noise_std: float = 0.0,
) -> pd.DataFrame:
    """Generate synthetic sequence-function data.

    Creates sequences with a simple pattern:
    - Higher function value if more 'A' and 'G' nucleotides
    - Lower function value if more 'C' and 'T' nucleotides

    :param n_samples: Number of sequences to generate.
    :param seq_length: Length of every generated sequence; must be >= 1.
    :param noise_std: Standard deviation of Gaussian noise added to each
        function value. The default of 0.0 adds no noise and draws nothing
        from NumPy's RNG, so the output is a deterministic function of the
        sequences.
    :return: DataFrame with a ``"sequence"`` and a ``"function"`` column,
        one row per sample.
    :raises ValueError: If ``seq_length`` is smaller than 1 (the function
        value is a per-position fraction and would divide by zero).
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be at least 1, got {seq_length}")

    sequences = []
    functions = []

    for _ in range(n_samples):
        # Generate random DNA sequence
        seq = generate_random_nucleotide(seq_length)
        sequences.append(seq)

        # More A and G -> higher function, more C and T -> lower function
        ag_count = seq.count("A") + seq.count("G")
        ct_count = seq.count("C") + seq.count("T")
        func_value = 0.5 * ag_count / seq_length - 0.3 * ct_count / seq_length

        # Optional noise; skipped entirely when disabled so the NumPy RNG
        # stream is left untouched.
        if noise_std:
            func_value += np.random.normal(0, noise_std)
        functions.append(func_value)

    return pd.DataFrame(
        {
            "sequence": sequences,
            "function": functions,
        }
    )


def _plot_true_vs_predicted(
    true_values: np.ndarray,
    predictions: np.ndarray,
    plot_path: Path,
) -> None:
    """Save a true-vs-predicted scatter plot with a dashed identity line."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.scatterplot(x=true_values, y=predictions, alpha=0.6, ax=ax)

    # Identity line: points on it are perfect predictions.
    lowest, highest = true_values.min(), true_values.max()
    ax.plot([lowest, highest], [lowest, highest], "r--")

    ax.set_xlabel("True Function Value")
    ax.set_ylabel("Predicted Function Value")
    ax.set_title("True vs Predicted Function Values")
    fig.tight_layout()
    fig.savefig(plot_path, dpi=300)
    # Release the figure; pyplot otherwise keeps it alive until exit.
    plt.close(fig)


def _plot_nucleotide_counts(data: pd.DataFrame, plot_path: Path) -> None:
    """Save a 2x2 grid of function value vs per-nucleotide count plots."""
    with_counts = data.copy()
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Panel order (row-major): A, G on the top row; C, T on the bottom row.
    for nucleotide, ax in zip("AGCT", axes.flat):
        column = f"{nucleotide}_count"
        with_counts[column] = data["sequence"].map(
            lambda seq: seq.count(nucleotide)
        )
        sns.scatterplot(x=column, y="function", data=with_counts, alpha=0.6, ax=ax)
        ax.set_title(f"Function vs {nucleotide} Count")
        ax.set_xlabel(f"Number of {nucleotide}'s")
        ax.set_ylabel("Function Value")

    fig.tight_layout()
    fig.savefig(plot_path, dpi=300)
    plt.close(fig)


def main() -> None:
    """Run the example pipeline.

    Generates data, trains a model, writes CSVs, the pickled model and two
    plots under ``examples/output``, and finally reloads the saved model.
    """
    print("Fast-SeqFunc Basic Example")
    print("=========================\n")

    # Create directory for outputs
    output_dir = Path("examples/output")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate synthetic data
    print("Generating synthetic data...")
    n_samples = 5000
    all_data = generate_synthetic_data(n_samples=n_samples)

    # Split into train and test sets (validation handled internally)
    train_size = int(0.8 * n_samples)
    test_size = n_samples - train_size

    train_data = all_data.iloc[:train_size].copy()
    test_data = all_data.iloc[train_size:].copy()

    print(f"Data split: {train_size} train, {test_size} test samples")

    # Save data files
    train_data.to_csv(output_dir / "train_data.csv", index=False)
    test_data.to_csv(output_dir / "test_data.csv", index=False)

    # Train and compare multiple models automatically
    print("\nTraining and comparing sequence-function models...")
    model_info = train_model(
        train_data=train_data,
        test_data=test_data,
        sequence_col="sequence",
        target_col="function",
        embedding_method="one-hot",
        model_type="regression",
        optimization_metric="r2",  # Optimize for R-squared
    )

    # Display test results if available
    if model_info.get("test_results"):
        print("\nTest metrics from training:")
        for metric, value in model_info["test_results"].items():
            print(f" {metric}: {value:.4f}")

    # Save the model
    model_path = output_dir / "model.pkl"
    save_model(model_info, model_path)
    print(f"Model saved to {model_path}")

    # Make predictions on test data
    print("\nMaking predictions on test data...")
    # test_data keeps the row labels train_size..n_samples-1. Coerce both
    # sides to flat NumPy arrays so the column assignment and the metric
    # arithmetic below are positional; label-aligned pandas arithmetic would
    # silently yield NaN if predict() returned a freshly indexed Series.
    test_predictions = np.asarray(predict(model_info, test_data["sequence"])).ravel()
    true_values = test_data["function"].to_numpy()

    # Create a results DataFrame
    results_df = test_data.copy()
    results_df["prediction"] = test_predictions
    results_df.to_csv(output_dir / "test_predictions.csv", index=False)

    # Calculate metrics manually
    squared_errors = (test_predictions - true_values) ** 2
    mse = squared_errors.mean()
    r2 = 1 - squared_errors.sum() / ((true_values - true_values.mean()) ** 2).sum()

    print("Manual test metrics calculation:")
    print(f" Mean Squared Error: {mse:.4f}")
    print(f" R²: {r2:.4f}")

    # Create a scatter plot of true vs predicted values
    scatter_path = output_dir / "true_vs_predicted.png"
    _plot_true_vs_predicted(true_values, test_predictions, scatter_path)
    print(f"Plot saved to {output_dir / 'true_vs_predicted.png'}")

    # Create plots showing function score vs nucleotide counts
    print("\nCreating nucleotide count vs function plots...")
    counts_path = output_dir / "nucleotide_counts_vs_function.png"
    _plot_nucleotide_counts(all_data, counts_path)
    print(
        f"Nucleotide count plots saved to "
        f"{output_dir / 'nucleotide_counts_vs_function.png'}"
    )

    # Test loading the model
    print("\nTesting model loading...")
    load_model(model_path)
    print("Model loaded successfully")


if __name__ == "__main__":
    main()