# tests/pytorch/test_forecasting_model.py
 1  import os
 2  
 3  import numpy as np
 4  import pytest
 5  import torch
 6  from lightning.pytorch import Trainer
 7  from pytorch_forecasting import DeepAR, TimeSeriesDataSet
 8  from pytorch_forecasting.data.examples import generate_ar_data
 9  
10  import mlflow
11  
12  
@pytest.fixture
def model_path(tmp_path):
    """Return the per-test location, as a ``str``, where the MLflow model is saved.

    The path sits inside pytest's ``tmp_path`` directory and is not created here.
    """
    target = tmp_path / "model"
    return str(target)
16  
17  
def _gen_forecasting_model_and_data(
    n_series, timesteps, max_prediction_length, max_encoder_length=30
):
    """Fit a small DeepAR model on synthetic autoregressive data.

    Args:
        n_series: Number of series requested from ``generate_ar_data``.
        timesteps: Number of time steps requested from ``generate_ar_data``.
        max_prediction_length: Forecast horizon of the training ``TimeSeriesDataSet``.
        max_encoder_length: Encoder (look-back) length of the training
            ``TimeSeriesDataSet``. Defaults to 30.

    Returns:
        A ``(deepar, data)`` tuple:
        - ``deepar`` is the fitted DeepAR model.
        - ``data`` is the full generated DataFrame, including the rows that
          were excluded from training.
    """
    data = generate_ar_data(seasonality=10.0, timesteps=timesteps, n_series=n_series)

    # Train only on rows up to this time index; the full frame is returned for prediction.
    training_cutoff = timesteps - max_prediction_length
    time_series_dataset = TimeSeriesDataSet(
        data[lambda x: x.time_idx <= training_cutoff],
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        time_varying_unknown_reals=["value"],
    )
    deepar = DeepAR.from_dataset(
        time_series_dataset,
        learning_rate=1e-3,
        hidden_size=16,
        rnn_layers=2,
    )
    dataloader = time_series_dataset.to_dataloader(train=True, batch_size=32)

    # The test only needs a fitted model. Lightning's default logger and
    # checkpoint callback would write lightning_logs/ and checkpoint files
    # under the current working directory, so both are turned off.
    trainer = Trainer(
        max_epochs=2,
        gradient_clip_val=0.1,
        accelerator="auto",
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(deepar, train_dataloaders=dataloader)

    return deepar, data
42  
43  
def test_forecasting_model_pyfunc_loader(model_path: str):
    """Round-trip a fitted DeepAR model through MLflow's pyfunc loader.

    The test covers three things:
    - The native prediction has shape ``(n_series, max_prediction_length)``.
    - The pyfunc-loaded model reproduces that prediction to 4 decimals.
    - Passing a raw ``numpy.ndarray`` to the pyfunc model raises ``TypeError``.
    """
    n_series = 10
    max_prediction_length = 20
    deepar, data = _gen_forecasting_model_and_data(
        n_series=n_series,
        timesteps=100,
        max_prediction_length=max_prediction_length,
    )

    # Re-seed before each predict call so that any random sampling inside the
    # model draws the same numbers on the native and the pyfunc path.
    torch.manual_seed(42)
    predicted = deepar.predict(data).numpy()
    assert predicted.shape == (n_series, max_prediction_length)

    mlflow.pytorch.save_model(deepar, model_path)

    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    torch.manual_seed(42)
    np.testing.assert_array_almost_equal(pyfunc_loaded.predict(data), predicted, decimal=4)

    # `match` is applied with re.search, so the literal dot must be escaped.
    with pytest.raises(
        TypeError,
        match=r"The pytorch forecasting model does not support numpy\.ndarray",
    ):
        pyfunc_loaded.predict(np.array([1.0, 2.0]))