# tests/test_cli.py
  1  """Tests for fast_seqfunc.cli."""
  2  
  3  import shutil
  4  import tempfile
  5  from pathlib import Path
  6  
  7  import pandas as pd
  8  import pytest
  9  from typer.testing import CliRunner
 10  
 11  from fast_seqfunc.cli import app
 12  from fast_seqfunc.synthetic import (
 13      create_classification_task,
 14      create_g_count_task,
 15      create_multiclass_task,
 16      generate_dataset_by_task,
 17  )
 18  
 19  
 20  @pytest.fixture
 21  def temp_dir():
 22      """Create a temporary directory for test files."""
 23      tmp_dir = tempfile.mkdtemp()
 24      yield tmp_dir
 25      # Clean up after tests
 26      shutil.rmtree(tmp_dir)
 27  
 28  
 29  @pytest.fixture
 30  def g_count_data(temp_dir):
 31      """Generate a G-count dataset and save to CSV."""
 32      # Generate a simple dataset where the function is the count of G's
 33      df = create_g_count_task(count=500, length=20, noise_level=0.1)
 34  
 35      # Save to CSV in the temp directory
 36      data_path = Path(temp_dir) / "g_count_data.csv"
 37      df.to_csv(data_path, index=False)
 38  
 39      return data_path
 40  
 41  
 42  @pytest.fixture
 43  def binary_classification_data(temp_dir):
 44      """Generate a classification dataset and save to CSV."""
 45      df = create_classification_task(count=500, length=20, noise_level=0.1)
 46  
 47      # Save to CSV in the temp directory
 48      data_path = Path(temp_dir) / "classification_data.csv"
 49      df.to_csv(data_path, index=False)
 50  
 51      return data_path
 52  
 53  
 54  @pytest.fixture
 55  def multiclass_data(temp_dir):
 56      """Generate a multi-class dataset and save to CSV."""
 57      df = create_multiclass_task(count=500, length=20, noise_level=0.1)
 58  
 59      # Save to CSV in the temp directory
 60      data_path = Path(temp_dir) / "multiclass_data.csv"
 61      df.to_csv(data_path, index=False)
 62  
 63      return data_path
 64  
 65  
 66  @pytest.fixture
 67  def test_tasks():
 68      """Define a list of test tasks."""
 69      return [
 70          "g_count",
 71          "gc_content",
 72          "motif_position",
 73          "motif_count",
 74          "nonlinear_composition",
 75          "interaction",
 76      ]
 77  
 78  
 79  def test_cli_hello():
 80      """Test the hello command."""
 81      runner = CliRunner()
 82      result = runner.invoke(app, ["hello"])
 83      assert result.exit_code == 0
 84      assert "fast-seqfunc" in result.stdout
 85  
 86  
 87  def test_cli_describe():
 88      """Test the describe command."""
 89      runner = CliRunner()
 90      result = runner.invoke(app, ["describe"])
 91      assert result.exit_code == 0
 92      assert "sequence-function" in result.stdout
 93  
 94  
 95  def test_cli_g_count_regression(g_count_data, temp_dir):
 96      """Test CLI with G-count regression task."""
 97      runner = CliRunner()
 98      model_path = Path(temp_dir) / "model.pkl"
 99  
100      # Train model
101      result = runner.invoke(
102          app,
103          [
104              "train",
105              str(g_count_data),
106              "--sequence-col",
107              "sequence",
108              "--target-col",
109              "function",
110              "--embedding-method",
111              "one-hot",
112              "--model-type",
113              "regression",
114              "--output-path",
115              str(model_path),
116          ],
117      )
118  
119      assert result.exit_code == 0
120      assert model_path.exists()
121  
122      # Make predictions
123      predictions_path = Path(temp_dir) / "predictions.csv"
124      result = runner.invoke(
125          app,
126          [
127              "predict-cmd",
128              str(model_path),
129              str(g_count_data),
130              "--sequence-col",
131              "sequence",
132              "--output-path",
133              str(predictions_path),
134          ],
135      )
136  
137      assert result.exit_code == 0
138      assert predictions_path.exists()
139  
140      # Verify predictions file has expected columns
141      predictions_df = pd.read_csv(predictions_path)
142      assert "sequence" in predictions_df.columns
143      assert "prediction" in predictions_df.columns
144  
145  
def test_cli_classification(binary_classification_data, temp_dir):
    """Train on the binary classification task via the CLI, then predict."""
    runner = CliRunner()
    model_path = Path(temp_dir) / "model_classification.pkl"

    # Train a classification model on the synthetic dataset.
    train_args = [
        "train",
        str(binary_classification_data),
        "--sequence-col",
        "sequence",
        "--target-col",
        "function",
        "--embedding-method",
        "one-hot",
        "--model-type",
        "classification",
        "--output-path",
        str(model_path),
    ]
    train_result = runner.invoke(app, train_args)
    assert train_result.exit_code == 0
    assert model_path.exists()

    # Predict on the same data using the saved model.
    predictions_path = Path(temp_dir) / "predictions_classification.csv"
    predict_args = [
        "predict-cmd",
        str(model_path),
        str(binary_classification_data),
        "--sequence-col",
        "sequence",
        "--output-path",
        str(predictions_path),
    ]
    predict_result = runner.invoke(app, predict_args)
    assert predict_result.exit_code == 0
    assert predictions_path.exists()

    # The predictions CSV must carry both the inputs and the outputs.
    output_df = pd.read_csv(predictions_path)
    assert "sequence" in output_df.columns
    assert "prediction" in output_df.columns
195  
196  
def test_cli_multiclass(multiclass_data, temp_dir):
    """Train on the multi-class classification task via the CLI, then predict."""
    runner = CliRunner()
    model_path = Path(temp_dir) / "model_multiclass.pkl"

    # Train a classification model on the multi-class dataset.
    train_args = [
        "train",
        str(multiclass_data),
        "--sequence-col",
        "sequence",
        "--target-col",
        "function",
        "--embedding-method",
        "one-hot",
        "--model-type",
        "classification",
        "--output-path",
        str(model_path),
    ]
    train_result = runner.invoke(app, train_args)
    assert train_result.exit_code == 0
    assert model_path.exists()

    # Predict on the same data using the saved model.
    predictions_path = Path(temp_dir) / "predictions_multiclass.csv"
    predict_args = [
        "predict-cmd",
        str(model_path),
        str(multiclass_data),
        "--sequence-col",
        "sequence",
        "--output-path",
        str(predictions_path),
    ]
    predict_result = runner.invoke(app, predict_args)
    assert predict_result.exit_code == 0
    assert predictions_path.exists()

    # The predictions CSV must carry both the inputs and the outputs.
    output_df = pd.read_csv(predictions_path)
    assert "sequence" in output_df.columns
    assert "prediction" in output_df.columns
246  
247  
def test_cli_compare_embeddings(g_count_data, temp_dir):
    """Run the embedding-comparison command and check it produces output."""
    comparison_path = Path(temp_dir) / "embedding_comparison.csv"

    cli_result = CliRunner().invoke(
        app,
        [
            "compare-embeddings",
            str(g_count_data),
            "--output-path",
            str(comparison_path),
        ],
    )

    # NOTE: comparing several embedding methods can take a while; we only
    # verify that the command completes without error.
    assert cli_result.exit_code == 0

    # Even if some embedding methods are unavailable, the output file
    # should still have been created.
    assert comparison_path.exists()
271  
272  
@pytest.mark.parametrize(
    "task",
    ["g_count", "gc_content", "motif_position"],
)
def test_cli_with_different_tasks(task, temp_dir):
    """End-to-end train/predict round trip for several synthetic tasks."""
    runner = CliRunner()

    # Build a dataset for the requested task and persist it for the CLI.
    task_df = generate_dataset_by_task(task=task, count=500, noise_level=0.1)
    data_path = Path(temp_dir) / f"{task}_data.csv"
    task_df.to_csv(data_path, index=False)

    # Train with default options, writing the model to a task-specific path.
    model_path = Path(temp_dir) / f"{task}_model.pkl"
    train_result = runner.invoke(
        app, ["train", str(data_path), "--output-path", str(model_path)]
    )
    assert train_result.exit_code == 0
    assert model_path.exists()

    # Predict on the same data using the saved model.
    predictions_path = Path(temp_dir) / f"{task}_predictions.csv"
    predict_result = runner.invoke(
        app,
        [
            "predict-cmd",
            str(model_path),
            str(data_path),
            "--output-path",
            str(predictions_path),
        ],
    )
    assert predict_result.exit_code == 0
    assert predictions_path.exists()