# examples/synthetic_data_demo.py
  1  # /// script
  2  # requires-python = ">=3.11"
  3  # dependencies = [
  4  #   "fast-seqfunc",
  5  #   "matplotlib",
  6  #   "seaborn",
  7  #   "pandas",
  8  #   "numpy",
  9  # ]
 10  # ///
 11  
 12  """Demo script for generating and visualizing synthetic sequence-function data.
 13  
 14  This script demonstrates how to generate various synthetic datasets using
 15  the fast-seqfunc.synthetic module and train models on them.
 16  """
 17  
 18  import argparse
 19  import tempfile
 20  from pathlib import Path
 21  
 22  import matplotlib.pyplot as plt
 23  import numpy as np
 24  import seaborn as sns
 25  from loguru import logger
 26  
 27  from fast_seqfunc.core import predict, train_model
 28  from fast_seqfunc.synthetic import generate_dataset_by_task
 29  
 30  
 31  def parse_args():
 32      """Parse command line arguments."""
 33      parser = argparse.ArgumentParser(
 34          description="Generate and visualize synthetic sequence-function data"
 35      )
 36      parser.add_argument(
 37          "--task",
 38          type=str,
 39          default="g_count",
 40          choices=[
 41              "g_count",
 42              "gc_content",
 43              "motif_position",
 44              "motif_count",
 45              "length_dependent",
 46              "nonlinear_composition",
 47              "interaction",
 48              "classification",
 49              "multiclass",
 50          ],
 51          help="Sequence-function task to generate",
 52      )
 53      parser.add_argument(
 54          "--count", type=int, default=500, help="Number of sequences to generate"
 55      )
 56      parser.add_argument(
 57          "--noise", type=float, default=0.1, help="Noise level to add to the data"
 58      )
 59      parser.add_argument(
 60          "--output", type=str, default="synthetic_data.csv", help="Output file path"
 61      )
 62      parser.add_argument(
 63          "--plot", action="store_true", help="Generate plots of the data"
 64      )
 65      parser.add_argument(
 66          "--train", action="store_true", help="Train a model on the generated data"
 67      )
 68  
 69      return parser.parse_args()
 70  
 71  
 72  def visualize_data(df, task_name):
 73      """Create visualizations for the generated data.
 74  
 75      :param df: DataFrame with sequences and functions
 76      :param task_name: Name of the task for plot title
 77      """
 78      plt.figure(figsize=(14, 6))
 79  
 80      # For classification tasks, show class distribution
 81      if task_name in ["classification", "multiclass"]:
 82          plt.subplot(1, 2, 1)
 83          df["function"].value_counts().plot(kind="bar")
 84          plt.title(f"Class Distribution for {task_name}")
 85          plt.xlabel("Class")
 86          plt.ylabel("Count")
 87  
 88          plt.subplot(1, 2, 2)
 89          # Show sequence length distribution
 90          df["seq_length"] = df["sequence"].apply(len)
 91          sns.histplot(df["seq_length"], kde=True)
 92          plt.title("Sequence Length Distribution")
 93          plt.xlabel("Sequence Length")
 94      else:
 95          # For regression tasks, show function distribution
 96          plt.subplot(1, 2, 1)
 97          sns.histplot(df["function"], kde=True)
 98          plt.title(f"Function Distribution for {task_name}")
 99          plt.xlabel("Function Value")
100  
101          plt.subplot(1, 2, 2)
102          # For tasks with variable length, plot function vs length
103          if task_name == "length_dependent":
104              df["seq_length"] = df["sequence"].apply(len)
105              sns.scatterplot(x="seq_length", y="function", data=df)
106              plt.title("Function vs Sequence Length")
107              plt.xlabel("Sequence Length")
108              plt.ylabel("Function Value")
109          # For GC content, show relationship with function
110          elif task_name in ["g_count", "gc_content"]:
111              df["gc_content"] = df["sequence"].apply(
112                  lambda s: (s.count("G") + s.count("C")) / len(s)
113              )
114              sns.scatterplot(x="gc_content", y="function", data=df)
115              plt.title("Function vs GC Content")
116              plt.xlabel("GC Content")
117              plt.ylabel("Function Value")
118          # For other tasks, show example sequences
119          else:
120              # Sample 10 random sequences to display
121              examples = df.sample(min(10, len(df)))
122              plt.clf()
123              plt.figure(figsize=(12, 6))
124              plt.bar(range(len(examples)), examples["function"])
125              plt.xticks(range(len(examples)), examples["sequence"], rotation=45)
126              plt.title(f"Example Sequences for {task_name}")
127              plt.xlabel("Sequence")
128              plt.ylabel("Function Value")
129  
130      plt.tight_layout()
131      plt.savefig(f"{task_name}_visualization.png")
132      logger.info(f"Visualization saved to {task_name}_visualization.png")
133      plt.close()
134  
135  
def train_and_evaluate(df, task_name):
    """Train a model on the generated data and evaluate it.

    Performs a random 80/20 train/test split (seeded for reproducibility),
    trains a fast-seqfunc model with one-hot embeddings, then logs metrics
    and saves a diagnostic plot (``{task_name}_predictions.png`` for
    regression, ``{task_name}_confusion_matrix.png`` for classification).

    :param df: DataFrame with "sequence" and "function" columns
    :param task_name: Name of the task; "classification" / "multiclass"
        select classification model types, anything else uses regression
    """
    # Split data into train/test (fixed seed so the split is reproducible)
    np.random.seed(42)
    msk = np.random.rand(len(df)) < 0.8
    train_df = df[msk].reset_index(drop=True)
    test_df = df[~msk].reset_index(drop=True)

    # Save train/test data to temp files; train_model reads from file paths
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Use a distinct name rather than rebinding tmp_dir (str -> Path)
        tmp_path = Path(tmp_dir)
        train_path = tmp_path / "train_data.csv"
        test_path = tmp_path / "test_data.csv"

        train_df.to_csv(train_path, index=False)
        test_df.to_csv(test_path, index=False)

        # Determine model type based on task
        if task_name == "classification":
            model_type = "classification"
        elif task_name == "multiclass":
            model_type = "multi-class"
        else:
            model_type = "regression"

        logger.info(f"Training {model_type} model for {task_name} task")

        # Train model
        model = train_model(
            train_data=train_path,
            test_data=test_path,
            sequence_col="sequence",
            target_col="function",
            embedding_method="one-hot",
            model_type=model_type,
        )

        # Make predictions on test data
        predictions = predict(model, test_df["sequence"])

        # Calculate and print metrics
        if model_type == "regression":
            # sklearn imported lazily so the module loads without it when
            # --train is not requested
            from sklearn.metrics import (
                mean_absolute_error,
                mean_squared_error,
                r2_score,
            )

            mae = mean_absolute_error(test_df["function"], predictions)
            rmse = np.sqrt(mean_squared_error(test_df["function"], predictions))
            r2 = r2_score(test_df["function"], predictions)

            logger.info(f"Test MAE: {mae:.4f}")
            logger.info(f"Test RMSE: {rmse:.4f}")
            logger.info(f"Test R²: {r2:.4f}")

            # Scatter plot of actual vs predicted values with y=x reference
            plt.figure(figsize=(8, 8))
            plt.scatter(test_df["function"], predictions, alpha=0.5)
            plt.plot(
                [test_df["function"].min(), test_df["function"].max()],
                [test_df["function"].min(), test_df["function"].max()],
                "k--",
                lw=2,
            )
            plt.xlabel("Actual Values")
            plt.ylabel("Predicted Values")
            plt.title(f"Actual vs Predicted for {task_name}")
            plt.savefig(f"{task_name}_predictions.png")
            plt.close()

        else:  # Classification
            from sklearn.metrics import (
                accuracy_score,
                classification_report,
                confusion_matrix,
            )

            # NOTE(review): predictions are rounded to class labels here —
            # assumes predict() returns numeric scores; confirm against
            # fast_seqfunc.core.predict
            accuracy = accuracy_score(test_df["function"], predictions.round())
            logger.info(f"Test Accuracy: {accuracy:.4f}")
            logger.info("\nClassification Report:")
            report = classification_report(test_df["function"], predictions.round())
            logger.info(report)

            # Confusion matrix heatmap (sns is already imported at module
            # level; the previous local re-import was redundant)
            cm = confusion_matrix(test_df["function"], predictions.round())
            plt.figure(figsize=(8, 8))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
            plt.xlabel("Predicted")
            plt.ylabel("Actual")
            plt.title(f"Confusion Matrix for {task_name}")
            plt.savefig(f"{task_name}_confusion_matrix.png")
            plt.close()
232  
233  
def main():
    """Run the demo."""
    args = parse_args()

    # Generate the synthetic dataset for the requested task
    logger.info(f"Generating {args.count} sequences for {args.task} task")
    dataset = generate_dataset_by_task(
        task=args.task,
        count=args.count,
        noise_level=args.noise,
    )

    # Persist the generated data as CSV
    dataset.to_csv(args.output, index=False)
    logger.info(f"Data saved to {args.output}")

    # Optional extras, each gated by its own CLI flag
    if args.plot:
        logger.info("Generating visualizations")
        visualize_data(dataset, args.task)

    if args.train:
        logger.info("Training model on generated data")
        train_and_evaluate(dataset, args.task)
258  
259  
# Script entry point: only run the demo when executed directly, not on import
if __name__ == "__main__":
    main()