# test_forecasting_model.py
"""Tests for saving ``pytorch_forecasting`` models with ``mlflow.pytorch`` and
loading them back through the ``pyfunc`` interface."""

import os

import numpy as np
import pytest
import torch
from lightning.pytorch import Trainer
from pytorch_forecasting import DeepAR, TimeSeriesDataSet
from pytorch_forecasting.data.examples import generate_ar_data

import mlflow


@pytest.fixture
def model_path(tmp_path):
    """Return a path (as ``str``) named ``model`` inside pytest's ``tmp_path``."""
    return os.path.join(tmp_path, "model")


def _gen_forecasting_model_and_data(
    n_series, timesteps, max_prediction_length, max_encoder_length=30
):
    """Train a small DeepAR model on synthetic autoregressive data.

    Args:
        n_series: Number of series passed to ``generate_ar_data``.
        timesteps: Number of time steps passed to ``generate_ar_data``.
        max_prediction_length: Forecast horizon used for the dataset/model.
        max_encoder_length: Maximum encoder (history) length for the dataset.

    Returns:
        A ``(model, data)`` tuple: the fitted ``DeepAR`` model and the full
        generated DataFrame, including the tail rows that were excluded from
        training.
    """
    data = generate_ar_data(seasonality=10.0, timesteps=timesteps, n_series=n_series)

    # Train only on rows up to `timesteps - max_prediction_length`, holding out
    # the tail of every series.
    training_data = data[lambda x: x.time_idx <= timesteps - max_prediction_length]

    time_series_dataset = TimeSeriesDataSet(
        training_data,
        time_idx="time_idx",
        target="value",
        group_ids=["series"],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        time_varying_unknown_reals=["value"],
    )
    deepar = DeepAR.from_dataset(
        time_series_dataset,
        learning_rate=1e-3,
        hidden_size=16,
        rnn_layers=2,
    )
    dataloader = time_series_dataset.to_dataloader(train=True, batch_size=32)
    trainer = Trainer(
        max_epochs=2,
        gradient_clip_val=0.1,
        accelerator="auto",
        # Keep the test hermetic: with the defaults, Lightning writes
        # `lightning_logs/` and checkpoint files into the current working
        # directory on every run.
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(deepar, train_dataloaders=dataloader)

    return deepar, data


def test_forecasting_model_pyfunc_loader(model_path: str):
    """A DeepAR model saved with ``mlflow.pytorch`` must reload via ``pyfunc``,
    reproduce the original predictions under the same seed, and reject raw
    ``numpy.ndarray`` input with a ``TypeError``."""
    n_series = 10
    max_prediction_length = 20
    deepar, data = _gen_forecasting_model_and_data(
        n_series=n_series,
        timesteps=100,
        max_prediction_length=max_prediction_length,
    )

    # Seed before each predict call so any randomness in prediction is
    # identical for the original and the reloaded model.
    torch.manual_seed(42)
    predicted = deepar.predict(data).numpy()
    assert predicted.shape == (n_series, max_prediction_length)

    mlflow.pytorch.save_model(deepar, model_path)

    pyfunc_loaded = mlflow.pyfunc.load_model(model_path)
    torch.manual_seed(42)
    np.testing.assert_array_almost_equal(pyfunc_loaded.predict(data), predicted, decimal=4)

    with pytest.raises(
        TypeError,
        match="The pytorch forecasting model does not support numpy.ndarray",
    ):
        pyfunc_loaded.predict(np.array([1.0, 2.0]))