# examples/variable_length_sequences.py
  1  #!/usr/bin/env python
  2  # /// script
  3  # requires-python = ">=3.11"
# dependencies = [
#   "pandas",
#   "numpy",
#   "matplotlib",
#   "seaborn",
#   "scikit-learn>=1.0.0",
#   "loguru",
#   "fast-seqfunc @ git+https://github.com/ericmjl/fast-seqfunc.git",
# ]
 13  # ///
 14  
 15  """
 16  Variable Length Sequences Example for fast-seqfunc.
 17  
 18  This script demonstrates how to:
 19  1. Generate synthetic DNA sequences of variable lengths
 20  2. Use padding options to train a sequence-function model
 21  3. Compare different padding strategies
 22  4. Make predictions on new sequences of different lengths
 23  """
 24  
 25  import random
 26  from pathlib import Path
 27  from typing import List, Tuple
 28  
 29  import matplotlib.pyplot as plt
 30  import numpy as np
 31  import pandas as pd
 32  import seaborn as sns
 33  from loguru import logger
 34  
 35  from fast_seqfunc import load_model, predict, save_model, train_model
 36  from fast_seqfunc.embedders import OneHotEmbedder
 37  
# Seed both the NumPy and stdlib RNGs so data generation and the
# train/test split below are reproducible across runs
np.random.seed(42)
random.seed(42)
 41  
 42  
 43  def generate_variable_length_sequence(
 44      min_length: int = 50, max_length: int = 150
 45  ) -> str:
 46      """Generate a random DNA sequence with variable length.
 47  
 48      :param min_length: Minimum sequence length
 49      :param max_length: Maximum sequence length
 50      :return: Random DNA sequence
 51      """
 52      length = random.randint(min_length, max_length)
 53      nucleotides = "ACGT"
 54      return "".join(random.choice(nucleotides) for _ in range(length))
 55  
 56  
 57  def generate_variable_length_data(
 58      n_samples: int = 1000, min_length: int = 50, max_length: int = 150
 59  ) -> pd.DataFrame:
 60      """Generate synthetic variable-length sequence-function data.
 61  
 62      The function value depends on:
 63      1. The GC content (proportion of G and C nucleotides)
 64      2. The length of the sequence
 65  
 66      :param n_samples: Number of samples to generate
 67      :param min_length: Minimum sequence length
 68      :param max_length: Maximum sequence length
 69      :return: DataFrame with sequences and function values
 70      """
 71      sequences = []
 72      lengths = []
 73  
 74      for _ in range(n_samples):
 75          seq = generate_variable_length_sequence(min_length, max_length)
 76          sequences.append(seq)
 77          lengths.append(len(seq))
 78  
 79      # Calculate function values based on GC content and length
 80      gc_contents = [(seq.count("G") + seq.count("C")) / len(seq) for seq in sequences]
 81  
 82      # Function value = normalized GC content + normalized length + noise
 83      normalized_gc = [(gc - 0.5) * 2 for gc in gc_contents]  # -1 to 1
 84      normalized_length = [
 85          (length - min_length) / (max_length - min_length) for length in lengths
 86      ]  # 0 to 1
 87  
 88      functions = [
 89          0.6 * gc + 0.4 * length + np.random.normal(0, 0.05)
 90          for gc, length in zip(normalized_gc, normalized_length)
 91      ]
 92  
 93      # Create DataFrame
 94      df = pd.DataFrame(
 95          {
 96              "sequence": sequences,
 97              "function": functions,
 98              "length": lengths,
 99              "gc_content": gc_contents,
100          }
101      )
102  
103      return df
104  
105  
def compare_padding_strategies(
    train_data: pd.DataFrame, test_data: pd.DataFrame
) -> Tuple[dict, dict, dict]:
    """Compare different padding strategies for variable-length sequences.

    Trains three one-hot regression models that differ only in how unequal
    sequence lengths are handled: no padding, padding with the default gap
    character, and padding with a custom gap character.

    :param train_data: Training data
    :param test_data: Test data
    :return: Tuple of model info for each strategy
        (no padding, default padding, custom padding)
    """
    # (log message, embedder kwargs) for each padding strategy, in the
    # order the models are returned
    strategies = [
        (
            "Training model with padding disabled...",
            {"pad_sequences": False},
        ),
        (
            "Training model with default padding (gap character '-')...",
            {"pad_sequences": True, "gap_character": "-"},
        ),
        (
            "Training model with custom padding (gap character 'X')...",
            {"pad_sequences": True, "gap_character": "X"},
        ),
    ]

    trained = []
    for message, embedder_kwargs in strategies:
        logger.info(message)
        trained.append(
            train_model(
                train_data=train_data,
                test_data=test_data,
                sequence_col="sequence",
                target_col="function",
                embedding_method="one-hot",
                model_type="regression",
                optimization_metric="r2",
                embedder_kwargs=embedder_kwargs,
            )
        )

    return tuple(trained)
153  
154  
def demonstrate_embedder_usage() -> None:
    """Demonstrate direct usage of the OneHotEmbedder with padding options."""
    logger.info("Demonstrating direct usage of OneHotEmbedder...")

    # A small batch of DNA sequences with very different lengths
    sequences = ["ACGT", "AATT", "GCGCGCGC", "A"]
    logger.info(f"Example sequences: {sequences}")

    # 1. Default behavior: sequences are padded with '-' to a common length
    default_embedder = OneHotEmbedder(sequence_type="dna")
    default_embeddings = default_embedder.fit_transform(sequences)
    logger.info("Default embedder (padding enabled):")
    logger.info(f"  - Embeddings shape: {default_embeddings.shape}")
    logger.info(f"  - Max length detected: {default_embedder.max_length}")
    logger.info(f"  - Alphabet: {default_embedder.alphabet}")

    # 2. Fixing the padded length explicitly via max_length
    fixed_length_embedder = OneHotEmbedder(sequence_type="dna", max_length=10)
    fixed_length_embeddings = fixed_length_embedder.fit_transform(sequences)
    logger.info("Embedder with explicit max_length=10:")
    logger.info(f"  - Embeddings shape: {fixed_length_embeddings.shape}")

    # 3. Padding with a non-default gap character
    custom_gap_embedder = OneHotEmbedder(sequence_type="dna", gap_character="X")
    _ = custom_gap_embedder.fit_transform(sequences)
    logger.info("Embedder with custom gap character 'X':")
    logger.info(f"  - Alphabet: {custom_gap_embedder.alphabet}")

    # 4. Disabling padding entirely: one embedding per sequence, ragged shapes
    unpadded_embedder = OneHotEmbedder(sequence_type="dna", pad_sequences=False)
    unpadded_embeddings = unpadded_embedder.fit_transform(sequences)
    logger.info("Embedder with padding disabled:")
    logger.info(f"  - Number of embeddings: {len(unpadded_embeddings)}")
    logger.info("  - Shapes of individual embeddings:")
    for index, embedding in enumerate(unpadded_embeddings):
        logger.info(
            f"    - Sequence {index} ({len(sequences[index])} nucleotides): {embedding.shape}"
        )
193  
194  
def plot_results(
    test_data: pd.DataFrame,
    models: List[dict],
    model_names: List[str],
    output_dir: Path,
) -> None:
    """Plot comparison of different padding strategies.

    Produces three figures in ``output_dir``: predicted-vs-true values per
    model, function value vs sequence length, and function value vs GC
    content.

    :param test_data: Test data
    :param models: List of trained models
    :param model_names: Names of the models
    :param output_dir: Output directory for plots
    """
    # Figure 1: predicted vs. true function values for every model
    plt.figure(figsize=(10, 8))
    true_values = test_data["function"]

    for model, name in zip(models, model_names):
        predictions = predict(model, test_data["sequence"])

        # Coefficient of determination: 1 - SS_res / SS_tot
        residual_ss = ((predictions - true_values) ** 2).sum()
        total_ss = ((true_values - true_values.mean()) ** 2).sum()
        r2 = 1 - residual_ss / total_ss

        plt.scatter(
            true_values, predictions, alpha=0.5, label=f"{name} (R² = {r2:.4f})"
        )

    # Identity line: perfect predictions would fall on y = x
    lo, hi = min(true_values), max(true_values)
    plt.plot([lo, hi], [lo, hi], "r--")

    plt.xlabel("True Function Value")
    plt.ylabel("Predicted Function Value")
    plt.title("Comparison of Padding Strategies for Variable-Length Sequences")
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_dir / "padding_comparison.png", dpi=300)
    logger.info(f"Plot saved to {output_dir / 'padding_comparison.png'}")

    # Figure 2: function value as a function of sequence length
    plt.figure(figsize=(10, 6))
    sns.scatterplot(x="length", y="function", data=test_data, alpha=0.6)
    plt.xlabel("Sequence Length")
    plt.ylabel("Function Value")
    plt.title("Function Value vs Sequence Length")
    plt.tight_layout()
    plt.savefig(output_dir / "function_vs_length.png", dpi=300)
    logger.info(f"Plot saved to {output_dir / 'function_vs_length.png'}")

    # Figure 3: function value as a function of GC content
    plt.figure(figsize=(10, 6))
    sns.scatterplot(x="gc_content", y="function", data=test_data, alpha=0.6)
    plt.xlabel("GC Content")
    plt.ylabel("Function Value")
    plt.title("Function Value vs GC Content")
    plt.tight_layout()
    plt.savefig(output_dir / "function_vs_gc_content.png", dpi=300)
    logger.info(f"Plot saved to {output_dir / 'function_vs_gc_content.png'}")
262  
263  
def main() -> None:
    """Run the example pipeline end to end.

    Generates synthetic variable-length data, demonstrates the embedder
    API, trains models under three padding strategies, saves artifacts,
    plots comparisons, and predicts on fresh sequences.
    """
    logger.info("Fast-SeqFunc Variable Length Sequences Example")
    logger.info("============================================")

    # All artifacts (CSVs, model pickle, plots) land in this directory
    output_dir = Path("examples/output/variable_length")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Synthetic dataset parameters
    logger.info("Generating synthetic data with variable-length sequences...")
    n_samples, min_length, max_length = 2000, 50, 150
    all_data = generate_variable_length_data(
        n_samples=n_samples, min_length=min_length, max_length=max_length
    )

    # Report basic statistics about the generated sequences
    logger.info(
        f"Generated {n_samples} sequences "
        f"with lengths from {min_length} to {max_length}"
    )
    logger.info("Sequence length statistics:")
    logger.info(f"  - Mean: {all_data['length'].mean():.1f}")
    logger.info(f"  - Min: {all_data['length'].min()}")
    logger.info(f"  - Max: {all_data['length'].max()}")

    # 80/20 train/test split; rows are already in random order, so a
    # positional split is fine here
    train_size = int(0.8 * n_samples)
    train_data = all_data[:train_size].copy()
    test_data = all_data[train_size:].copy()

    logger.info(
        f"Data split: {train_size} train, {n_samples - train_size} test samples"
    )

    # Persist the split for inspection / reuse
    train_data.to_csv(output_dir / "train_data.csv", index=False)
    test_data.to_csv(output_dir / "test_data.csv", index=False)

    # Show the OneHotEmbedder API directly before the model comparison
    demonstrate_embedder_usage()

    logger.info("\nComparing different padding strategies...")
    model_no_padding, model_default_padding, model_custom_padding = (
        compare_padding_strategies(train_data, test_data)
    )
    named_models = [
        ("No Padding", model_no_padding),
        ("Default Padding", model_default_padding),
        ("Custom Padding", model_custom_padding),
    ]

    # Report held-out metrics for each strategy, when available
    for name, model in named_models:
        if model.get("test_results"):
            logger.info(f"\nTest metrics for {name}:")
            for metric, value in model["test_results"].items():
                logger.info(f"  {metric}: {value:.4f}")

    # Persist only the default-padding model for the prediction demo below
    save_model(model_default_padding, output_dir / "model_default_padding.pkl")
    logger.info(
        f"Default padding model saved to {output_dir / 'model_default_padding.pkl'}"
    )

    logger.info("\nCreating comparison plots...")
    plot_results(
        test_data,
        [model for _, model in named_models],
        ["No Padding", "Default Padding (-)", "Custom Padding (X)"],
        output_dir,
    )

    # Predict on brand-new sequences whose lengths extend beyond the
    # training range (30-200 vs. 50-150)
    logger.info("\nTesting prediction on new sequences of different lengths...")
    new_sequences = [generate_variable_length_sequence(30, 200) for _ in range(5)]

    # Round-trip through disk to demonstrate load_model as well
    loaded_model = load_model(output_dir / "model_default_padding.pkl")
    predictions = predict(loaded_model, new_sequences)

    for seq, pred in zip(new_sequences, predictions):
        gc_content = (seq.count("G") + seq.count("C")) / len(seq)
        logger.info(
            f"Sequence length: {len(seq)}, GC content: {gc_content:.2f}, "
            f"Predicted function: {pred:.4f}"
        )
355  
356  
# Script entry point: run the full example when executed directly
if __name__ == "__main__":
    main()