# synthetic_data_demo.py
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "fast-seqfunc",
#   "matplotlib",
#   "seaborn",
#   "pandas",
#   "numpy",
# ]
# ///

"""Demo script for generating and visualizing synthetic sequence-function data.

This script demonstrates how to generate various synthetic datasets using
the fast-seqfunc.synthetic module and train models on them.
"""

import argparse
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from loguru import logger

from fast_seqfunc.core import predict, train_model
from fast_seqfunc.synthetic import generate_dataset_by_task


def parse_args():
    """Parse command line arguments.

    :return: Parsed argparse namespace with task, count, noise, output,
        plot, and train attributes.
    """
    parser = argparse.ArgumentParser(
        description="Generate and visualize synthetic sequence-function data"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="g_count",
        choices=[
            "g_count",
            "gc_content",
            "motif_position",
            "motif_count",
            "length_dependent",
            "nonlinear_composition",
            "interaction",
            "classification",
            "multiclass",
        ],
        help="Sequence-function task to generate",
    )
    parser.add_argument(
        "--count", type=int, default=500, help="Number of sequences to generate"
    )
    parser.add_argument(
        "--noise", type=float, default=0.1, help="Noise level to add to the data"
    )
    parser.add_argument(
        "--output", type=str, default="synthetic_data.csv", help="Output file path"
    )
    parser.add_argument(
        "--plot", action="store_true", help="Generate plots of the data"
    )
    parser.add_argument(
        "--train", action="store_true", help="Train a model on the generated data"
    )

    return parser.parse_args()


def visualize_data(df, task_name):
    """Create visualizations for the generated data.

    Saves a PNG named ``{task_name}_visualization.png`` in the current
    working directory. Note: mutates *df* by adding derived columns
    (``seq_length`` or ``gc_content``) for some tasks.

    :param df: DataFrame with sequences and functions
    :param task_name: Name of the task for plot title
    """
    plt.figure(figsize=(14, 6))

    # For classification tasks, show class distribution
    if task_name in ["classification", "multiclass"]:
        plt.subplot(1, 2, 1)
        df["function"].value_counts().plot(kind="bar")
        plt.title(f"Class Distribution for {task_name}")
        plt.xlabel("Class")
        plt.ylabel("Count")

        plt.subplot(1, 2, 2)
        # Show sequence length distribution
        df["seq_length"] = df["sequence"].apply(len)
        sns.histplot(df["seq_length"], kde=True)
        plt.title("Sequence Length Distribution")
        plt.xlabel("Sequence Length")
    else:
        # For regression tasks, show function distribution
        plt.subplot(1, 2, 1)
        sns.histplot(df["function"], kde=True)
        plt.title(f"Function Distribution for {task_name}")
        plt.xlabel("Function Value")

        plt.subplot(1, 2, 2)
        # For tasks with variable length, plot function vs length
        if task_name == "length_dependent":
            df["seq_length"] = df["sequence"].apply(len)
            sns.scatterplot(x="seq_length", y="function", data=df)
            plt.title("Function vs Sequence Length")
            plt.xlabel("Sequence Length")
            plt.ylabel("Function Value")
        # For GC content, show relationship with function
        elif task_name in ["g_count", "gc_content"]:
            df["gc_content"] = df["sequence"].apply(
                lambda s: (s.count("G") + s.count("C")) / len(s)
            )
            sns.scatterplot(x="gc_content", y="function", data=df)
            plt.title("Function vs GC Content")
            plt.xlabel("GC Content")
            plt.ylabel("Function Value")
        # For other tasks, show example sequences
        else:
            # Sample 10 random sequences to display.
            examples = df.sample(min(10, len(df)))
            # Close (not just clear) the 14x6 figure created above before
            # replacing it, so we do not leak an open figure per call.
            plt.close()
            plt.figure(figsize=(12, 6))
            plt.bar(range(len(examples)), examples["function"])
            plt.xticks(range(len(examples)), examples["sequence"], rotation=45)
            plt.title(f"Example Sequences for {task_name}")
            plt.xlabel("Sequence")
            plt.ylabel("Function Value")

    plt.tight_layout()
    plt.savefig(f"{task_name}_visualization.png")
    logger.info(f"Visualization saved to {task_name}_visualization.png")
    plt.close()


def train_and_evaluate(df, task_name):
    """Train a model on the generated data and evaluate it.

    Performs an 80/20 random split, trains via fast_seqfunc, logs metrics,
    and saves a predictions scatter plot (regression) or confusion matrix
    (classification) PNG in the current working directory.

    :param df: DataFrame with sequences and functions
    :param task_name: Name of the task
    """
    # Split data into train/test (fixed seed for reproducibility).
    np.random.seed(42)
    msk = np.random.rand(len(df)) < 0.8
    train_df = df[msk].reset_index(drop=True)
    test_df = df[~msk].reset_index(drop=True)

    # Save train/test data to temp files; train_model consumes file paths.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        train_path = tmp_dir / "train_data.csv"
        test_path = tmp_dir / "test_data.csv"

        train_df.to_csv(train_path, index=False)
        test_df.to_csv(test_path, index=False)

        # Determine model type based on task
        if task_name == "classification":
            model_type = "classification"
        elif task_name == "multiclass":
            model_type = "multi-class"
        else:
            model_type = "regression"

        logger.info(f"Training {model_type} model for {task_name} task")

        # Train model
        model = train_model(
            train_data=train_path,
            test_data=test_path,
            sequence_col="sequence",
            target_col="function",
            embedding_method="one-hot",
            model_type=model_type,
        )

        # Make predictions on test data
        predictions = predict(model, test_df["sequence"])

        # Calculate and print metrics. sklearn is imported lazily here so the
        # data-generation path of the script works without it installed.
        if model_type == "regression":
            from sklearn.metrics import (
                mean_absolute_error,
                mean_squared_error,
                r2_score,
            )

            mae = mean_absolute_error(test_df["function"], predictions)
            rmse = np.sqrt(mean_squared_error(test_df["function"], predictions))
            r2 = r2_score(test_df["function"], predictions)

            logger.info(f"Test MAE: {mae:.4f}")
            logger.info(f"Test RMSE: {rmse:.4f}")
            logger.info(f"Test R²: {r2:.4f}")

            # Scatter plot of actual vs predicted values, with the y=x
            # identity line for reference.
            plt.figure(figsize=(8, 8))
            plt.scatter(test_df["function"], predictions, alpha=0.5)
            plt.plot(
                [test_df["function"].min(), test_df["function"].max()],
                [test_df["function"].min(), test_df["function"].max()],
                "k--",
                lw=2,
            )
            plt.xlabel("Actual Values")
            plt.ylabel("Predicted Values")
            plt.title(f"Actual vs Predicted for {task_name}")
            plt.savefig(f"{task_name}_predictions.png")
            plt.close()

        else:  # Classification
            from sklearn.metrics import accuracy_score, classification_report

            # NOTE(review): .round() assumes predict() returns numeric scores
            # rather than class labels — confirm against fast_seqfunc.core.
            accuracy = accuracy_score(test_df["function"], predictions.round())
            logger.info(f"Test Accuracy: {accuracy:.4f}")
            logger.info("\nClassification Report:")
            report = classification_report(test_df["function"], predictions.round())
            logger.info(report)

            # Confusion matrix (seaborn already imported at module level).
            from sklearn.metrics import confusion_matrix

            cm = confusion_matrix(test_df["function"], predictions.round())
            plt.figure(figsize=(8, 8))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
            plt.xlabel("Predicted")
            plt.ylabel("Actual")
            plt.title(f"Confusion Matrix for {task_name}")
            plt.savefig(f"{task_name}_confusion_matrix.png")
            plt.close()


def main():
    """Run the demo: generate data, save it, and optionally plot/train."""
    args = parse_args()

    logger.info(f"Generating {args.count} sequences for {args.task} task")
    df = generate_dataset_by_task(
        task=args.task,
        count=args.count,
        noise_level=args.noise,
    )

    # Save data to CSV
    df.to_csv(args.output, index=False)
    logger.info(f"Data saved to {args.output}")

    # Generate plots if requested
    if args.plot:
        logger.info("Generating visualizations")
        visualize_data(df, args.task)

    # Train model if requested
    if args.train:
        logger.info("Training model on generated data")
        train_and_evaluate(df, args.task)


if __name__ == "__main__":
    main()