On Hugging Face, there are 20 models tagged “time series” at the time of writing. While certainly not a lot (the “text-generation-inference” tag yields 125,950 results), time series forecasting with foundation models is an interesting enough niche for big companies like Amazon, IBM and Salesforce to have developed their own models: Chronos, TinyTimeMixer and Moirai, respectively. At the time of writing, one of the most popular on Hugging Face by number of likes is Lag-Llama, a univariate probabilistic model. Developed by Kashif Rasul, Arjun Ashok and co-authors (1), Lag-Llama was open sourced in February 2024. The authors of the model claim “strong zero-shot generalization capabilities” on a variety of datasets across different domains. They also claim that, once fine-tuned for specific tasks, it is the best general-purpose model of its kind. Big words!
In this blog, I showcase my experience fine-tuning Lag-Llama, and test its capabilities against a more classical machine learning approach. In particular, I benchmark it against an XGBoost model designed to handle univariate time series data. Gradient boosting algorithms such as XGBoost are widely considered the epitome of “classical” machine learning (as opposed to deep learning), and have been shown to perform extremely well with tabular data (2). Therefore, it seems fitting to use XGBoost to test if Lag-Llama lives up to its promises. Will the foundation model do better? Spoiler alert: it is not that simple.
By the way, I will not go into the details of the model architecture, but the paper is worth a read, as is this nice walk-through by Marco Peixeiro.
The data that I use for this exercise is a 4-year-long series of hourly wave heights off the coast of Ribadesella, a town in the Spanish region of Asturias. The series is available at the Spanish ports authority data portal. The measurements were taken at a station located in the coordinates (43.5, -5.083), from 18/06/2020 00:00 to 18/06/2024 23:00 (3). I have decided to aggregate the series to a daily level, taking the max over the 24 observations in each day. The reason is that the concepts that we go through in this post are better illustrated from a slightly less granular point of view. Otherwise, the results become very volatile very quickly. Therefore, our target variable is the maximum height of the waves recorded in a day, measured in meters.
There are several reasons why I chose this series. The first one is that the Lag-Llama model was trained on some weather-related data, although relatively little of it. I would expect the model to find this type of data slightly challenging, but still manageable. The second one is that, while meteorological forecasts are typically produced using numerical weather models, statistical models can still complement these forecasts, especially for long-range predictions. At the very least, in the era of climate change, I think statistical models can tell us what we would typically expect, and how far off it is from what is actually happening.
The dataset is pretty standard and does not require much preprocessing other than imputing a few missing values. The plot below shows what it looks like after we split it into train, validation and test sets. The last two sets have a length of 5 months. To know more about how we preprocess the data, have a look at this notebook.
We are going to benchmark Lag-Llama against XGBoost on two univariate forecasting tasks: point forecasting and probabilistic forecasting. The two tasks complement each other: point forecasting gives us a specific, single-number prediction, whereas probabilistic forecasting gives us a confidence region around it. One could say that Lag-Llama was only trained for the latter, so we should focus on that one. While that is true, I believe that humans find it easier to understand a single number than a confidence interval, so I think the point forecast is still useful, even if just for illustrative purposes.
There are many factors that we need to consider when producing a forecast. Some of the most important include the forecast horizon, the last observation(s) that we feed the model, or how often we update the model (if at all). Different combinations of factors yield their own types of forecast with their own interpretations. In our case, we are going to do a recursive multi-step forecast without updating the model, with a step size of 7 days. This means that we are going to use one single model to produce batches of 7 forecasts at a time. After producing one batch, the model sees 7 more data points, corresponding to the dates that it just predicted, and it produces 7 more forecasts. The model, however, is not retrained as new data is available. In terms of our dataset, this means that we will produce a forecast of maximum wave heights for each day of the next week.
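To make the scheme concrete, here is a minimal, library-free sketch of a recursive multi-step backtest with a fixed model. The function `recursive_backtest` and the `predict_one_step` callable are hypothetical stand-ins for any one-step-ahead model (such as the XGBoost regressor used below); this is an illustration of the idea, not how Skforecast implements it internally.

```python
import numpy as np


def recursive_backtest(series, predict_one_step, n_lags, start, step=7):
    """Recursive multi-step backtest of a fixed (never refitted) one-step model.

    From position `start` onwards, forecast `step` values at a time. Within a
    batch, each prediction is appended to the history and reused as a lag for
    the next one. After the batch, the history is rebuilt from the actual
    observations, so the next batch starts from real data.
    """
    series = np.asarray(series, dtype=float)
    forecasts = []
    t = start
    while t < len(series):
        horizon = min(step, len(series) - t)
        history = list(series[:t])  # only data observed before t
        for _ in range(horizon):
            y_hat = float(predict_one_step(np.array(history[-n_lags:])))
            forecasts.append(y_hat)
            history.append(y_hat)  # the prediction becomes a lag
        t += horizon  # the model now "sees" the actual values of this batch
    return np.array(forecasts)


# Toy example: a hypothetical "model" that averages its last 3 lags.
series = np.arange(20.0)
preds = recursive_backtest(series, lambda lags: lags.mean(), n_lags=3, start=10)
print(preds.round(2))
```

Note that the model is never refitted: only its inputs change as new observations arrive at the end of each 7-day batch.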
For point forecasting, we are going to use the Mean Absolute Error (MAE) as performance metric. In the case of probabilistic forecasting, we will aim for an empirical coverage (or coverage probability) of 80%, that is, we want 80% of the actual values to fall within our prediction intervals.
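To pin down the two metrics, here is a minimal numpy sketch. The helper names are hypothetical, and counting values that sit exactly on an interval bound as “inside” is my assumption.

```python
import numpy as np


def mean_absolute_error(y_true, y_pred):
    """Average absolute difference between observations and point forecasts."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return float(np.mean(np.abs(y_true - y_pred)))


def empirical_coverage(y_true, lower, upper):
    """Share of observations that fall inside their [lower, upper] interval."""
    y_true = np.asarray(y_true, float)
    inside = (y_true >= np.asarray(lower, float)) & (y_true <= np.asarray(upper, float))
    return float(np.mean(inside))


y = [1.0, 2.0, 3.0, 4.0, 5.0]
mae = mean_absolute_error(y, [1.5, 2.0, 2.0, 4.0, 6.0])
cov = empirical_coverage(y, [0.5, 1.5, 3.5, 3.0, 4.0], [1.5, 2.5, 4.0, 5.0, 6.0])
print(mae, cov)  # 0.5 0.8
```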
The scene is set. Let’s get our hands dirty with the experiments!
While originally not designed for time series forecasting, gradient boosting algorithms in general, and XGBoost in particular, can be great predictors. We just need to feed the algorithm the data in the right format. For instance, if we want to use three lags of our target series, we can simply create three columns (say, in a pandas dataframe) with the lagged values and voilà! An XGBoost forecaster. However, this process can quickly become onerous, especially if we intend to use many lags. Luckily for us, the library Skforecast (4) can do this. In fact, Skforecast is the one-stop shop for developing and testing all sorts of forecasters. I honestly can’t recommend it enough!
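For illustration, here is a minimal sketch of that manual approach with pandas, using a hypothetical helper `make_lag_features` and made-up wave heights:

```python
import pandas as pd


def make_lag_features(y: pd.Series, n_lags: int) -> pd.DataFrame:
    """Turn a univariate series into a supervised-learning table:
    one column per lag plus the target column."""
    frame = pd.DataFrame({f"lag_{k}": y.shift(k) for k in range(1, n_lags + 1)})
    frame["target"] = y
    # The first n_lags rows lack a full set of lags, so drop them.
    return frame.dropna()


# Made-up daily maximum wave heights (metres), for illustration only.
heights = pd.Series([1.2, 1.5, 2.1, 1.8, 1.4, 1.9],
                    index=pd.date_range("2024-01-01", periods=6, freq="D"))
table = make_lag_features(heights, n_lags=3)
print(table)
```

The lag columns would be the features and `target` the label of an XGBoost regressor. Skforecast essentially builds this kind of lag matrix for us, which is what makes it convenient when the number of lags grows.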
Creating a forecaster with Skforecast is pretty straightforward. We just need to create a ForecasterAutoreg object with an XGBoost regressor, which we can then fine-tune. On top of the XGBoost hyperparameters that we would typically optimise, we also need to search for the best number of lags to include in our model. To do that, Skforecast provides a Bayesian optimisation method that runs Optuna in the background, bayesian_search_forecaster.
The search yields an optimised XGBoost forecaster which, among other hyperparameters, uses 21 lags of the target variable, i.e. 21 days of maximum wave heights to predict the next day’s:
Lags: ( 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21)
Parameters: {'n_estimators': 900,
'max_depth': 12,
'learning_rate': 0.30394338985367425,
'reg_alpha': 0.5,
'reg_lambda': 0.0,
'subsample': 1.0,
'colsample_bytree': 0.2}
But is the model any good? Let’s find out!
Point forecasting
First, let’s look at how well the XGBoost forecaster does at predicting the next 7 days of maximum wave heights. The chart below plots the predictions against the actual values of our test set. We can see that the prediction tends to follow the general trend of the actual data, but it is far from perfect.
To create the predictions depicted above, we have used Skforecast’s backtesting_forecaster function, which allows us to evaluate the model on a test set, as shown in the following code snippet. On top of the predictions, we also get a performance metric, which in our case is the MAE.
Our model’s MAE is 0.64. This means that, on average, our predictions are 64 cm off the actual measurement. To put this value in context, the standard deviation of the target variable is 0.86. Therefore, our model’s average error is about 0.74 standard deviations. Furthermore, if we were to simply use the previous equivalent observation as a dummy best guess for our forecast, we would get an MAE of 0.84 (see point 1 of this notebook). All things considered, it seems that, so far, our model is better than a simple logical rule, which is a relief!
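Both comparisons are easy to reproduce. The sketch below assumes that “previous equivalent observation” means the value observed seven days earlier (the helper `seasonal_naive_mae` and the toy data are hypothetical); in practice you would apply it to the test period plus the preceding week. It also shows the standard-deviation ratio computed from the figures above.

```python
import numpy as np


def seasonal_naive_mae(y, season=7):
    """MAE of the rule "each day looks like the same day one week earlier"."""
    y = np.asarray(y, dtype=float)
    actuals = y[season:]
    forecasts = y[:-season]  # value observed `season` days before each actual
    return float(np.mean(np.abs(actuals - forecasts)))


toy = [1, 2, 3, 4, 5, 6, 7, 2, 2, 3, 4, 5, 6, 9]
print(seasonal_naive_mae(toy))  # 3/7 ≈ 0.43

# Putting an MAE in context: the error expressed in standard deviations.
mae, std = 0.64, 0.86
print(round(mae / std, 2))  # 0.74
```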
Probabilistic forecasting
Skforecast allows us to calculate prediction intervals where the future outcome is likely to fall. The library provides two methods: using either bootstrapped residuals or quantile regression. The results are not very different, so I am going to focus here on the bootstrapped residuals method. You can see more results in part 3 of this notebook.
The idea behind constructing prediction intervals from bootstrapped residuals is that we can randomly take a model’s forecast errors (residuals) and add them to the same model’s forecasts. By repeating the process a number of times, we can construct an equal number of alternative forecasts. These alternative forecasts follow a distribution from which we can extract prediction intervals. In other words, if we assume that the forecast errors are independent and identically distributed over time, adding these errors creates a universe of equally plausible forecasts. In this universe, we would expect a given percentage of the actual values of the forecasted series to fall within the intervals. In our case, we will aim for 80% of the values (that is, a coverage of 80%).
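Here is a simplified numpy sketch of the idea, assuming we already have point forecasts for the test set and residuals from the validation set. One simplification is worth stating plainly: here the residuals are added independently to a fixed point forecast, whereas a recursive forecaster can instead feed each perturbed prediction back in as a lag for the next step. The function name, the number of bootstrap samples and the toy values are illustrative, not Skforecast’s.

```python
import numpy as np


def bootstrap_intervals(point_forecast, residuals, n_boot=1000, coverage=0.8, seed=123):
    """Prediction intervals from bootstrapped residuals (simplified sketch)."""
    rng = np.random.default_rng(seed)
    point_forecast = np.asarray(point_forecast, dtype=float)
    residuals = np.asarray(residuals, dtype=float)
    # Each row is one alternative future: the point forecast plus residuals
    # resampled with replacement from past forecast errors.
    draws = rng.choice(residuals, size=(n_boot, point_forecast.size), replace=True)
    simulated = point_forecast + draws
    # Central interval: e.g. the 10th and 90th percentiles for 80% coverage.
    tail = (1 - coverage) / 2
    lower = np.quantile(simulated, tail, axis=0)
    upper = np.quantile(simulated, 1 - tail, axis=0)
    return lower, upper


forecast = np.array([2.0, 2.5, 3.0])
val_residuals = np.array([-0.6, -0.3, -0.1, 0.0, 0.2, 0.4, 0.7])
lower, upper = bootstrap_intervals(forecast, val_residuals)
print(np.round(lower, 2), np.round(upper, 2))
```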
To construct the prediction intervals with Skforecast, we follow a 3-step process: first, we generate forecasts for our validation set; second, we compute the residuals from those forecasts and store them in our forecaster class; third, we get the probabilistic forecasts for our test set. The second and third steps are illustrated in the snippet below (the first one corresponds to the code snippet in the previous section). Lines 14-17 are the parameters that govern our bootstrap calculation.
The resulting prediction intervals are depicted in the chart below.
84.67% of the values in the test set fall within our prediction intervals, which is just above our target of 80%. While this is not bad, it may also mean that we are overshooting and our intervals are too wide. Think of it this way: if we said that tomorrow’s waves would be between 0 and infinity meters high, we would always be right, but the forecast would be useless! To get an idea of how wide our intervals are, Skforecast’s docs suggest computing the area of the intervals by taking the sum of the differences between their upper and lower boundaries. This is not an absolute measure, but it can help us compare forecasters. In our case, the area is 348.28.
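Since the area is just the sum of the interval widths, dividing it by the number of forecast days gives the average interval width, which is easier to read in metres. A minimal sketch with made-up bounds (the helper name is hypothetical):

```python
import numpy as np


def interval_area(lower, upper):
    """Sum of interval widths over the forecast period (metres x days here)."""
    widths = np.asarray(upper, dtype=float) - np.asarray(lower, dtype=float)
    return float(np.sum(widths))


lower = [0.5, 1.0, 1.5]
upper = [1.5, 2.5, 2.0]
area = interval_area(lower, upper)
print(area, area / len(lower))  # total area and average width per day
```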
These are our XGBoost results. How about Lag-Llama?
The authors of Lag-Llama provide a demo notebook to start forecasting with the model without fine-tuning it. The code is ready to produce probabilistic forecasts given a set horizon, or prediction length, and a context length, or the number of previous data points to consider in the forecast. We just need to call the get_llama_predictions function below:
The core of the function is a LagLlamaEstimator class (lines 19–47), which is a PyTorch Lightning Estimator based on the GluonTS (5) package for probabilistic forecasting. I suggest you go through the GluonTS docs to get familiar with the package.
We can leverage the get_llama_predictions function to produce recursive multi-step forecasts. We simply need to produce batches of predictions over consecutive windows of the series. This is what we do in the function below, recursive_forecast:
In lines 37 to 39 of the code snippet above, we extract the 10th and 90th percentiles to produce an 80% prediction interval, as well as the median of the probabilistic prediction to get a point forecast. If you want to learn more about the output of the model, I suggest you have a look at the authors’ tutorial mentioned above.
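As an illustration of that step, here is a sketch assuming the sample paths have been stacked into an array of shape (num_samples, prediction_length), which is the layout of the `samples` attribute of GluonTS `SampleForecast` objects. The helper `summarise_samples` and the toy samples are hypothetical; the mean is included because it is the alternative point forecast discussed below.

```python
import numpy as np


def summarise_samples(samples):
    """Turn sample paths of shape (num_samples, prediction_length) into an
    80% interval (10th to 90th percentile) plus median and mean forecasts."""
    samples = np.asarray(samples, dtype=float)
    lower, median, upper = np.percentile(samples, [10, 50, 90], axis=0)
    mean = samples.mean(axis=0)  # alternative point forecast, sensitive to outliers
    return lower, median, upper, mean


# Toy samples: 100 draws for each of 7 forecast days (values 1..100 per day).
samples = np.tile(np.arange(1.0, 101.0)[:, None], (1, 7))
lower, median, upper, mean = summarise_samples(samples)
print(lower[0], median[0], upper[0], mean[0])
```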
The authors of the model advise that different datasets and forecasting tasks may require different context lengths. In our case, we try context lengths of 32, 64 and 128 tokens (lags). The chart below shows the results of the 64-token model.
Point forecasting
As we said above, Lag-Llama is not meant to calculate point forecasts, but we can get one by taking the median of the probabilistic interval that it returns. Another potential point forecast would be the mean, although it would be more sensitive to outlying samples. In any case, for our particular dataset, both options yield similar results.
The MAE of the 32-token model was 0.75. That of the 64-token model was 0.77, while the MAE of the 128-token model was 0.77 as well. These are all higher than the XGBoost forecaster’s, which went down to 0.64. In fact, they are very close to the baseline, dummy model that used the previous week’s value as today’s forecast (MAE 0.84).
Probabilistic forecasting
With an empirical interval coverage of 68.67% and an interval area of 280.05, the 32-token forecast does not meet our required standard. The 64-token model reaches 74.0% coverage, which gets closer to the 80% region that we are looking for, at the cost of an interval area of 343.74. The 128-token model overshoots but is closer to the mark, with 84.67% coverage and an area of 399.25. We can see an interesting trend here: more coverage implies a larger interval area. This need not always be the case: a model whose narrow intervals are always right would get high coverage with a small area. However, in practice this trade-off is very much present in all the models I have trained.
Notice the periodic bulges in the chart (around March 10 or April 7, for instance). Since we are producing a 7-day forecast, the bulges represent the increased uncertainty as we move away from the last observation that the model saw. In other words, a forecast for the next day will be less uncertain than a forecast for the day after next, and so on.
The 128-token model yields very similar results to the XGBoost forecaster, which had an area of 348.28 and a coverage of 84.67%. Based on these results, we can say that, without any fine-tuning, Lag-Llama’s performance is rather solid and on par with an optimised traditional forecaster.
Lag-Llama’s GitHub repo comes with a “best practices” section with tips to use and fine-tune the model. The authors especially recommend tuning the context length and the learning rate. We are going to explore some of the suggested values for these hyperparameters. The code snippet below, which I have taken and modified from the authors’ fine-tuning tutorial notebook, shows how we can conduct a small grid search:
In the code above, we loop over context lengths of 32, 64, and 128 tokens, as well as learning rates of 0.001, 0.001, and 0.005. Within the loop, we also calculate some test metrics: Coverage(0.8), Coverage(0.9) and MAE Coverage. Coverage(0.x) measures how many of the actual values fall within their prediction interval. For instance, a good model should have a Coverage(0.8) of around 80%. MAE Coverage, on the other hand, measures the average absolute deviation of the actual coverage probabilities from the nominal coverage levels. Therefore, a good model in our case should be one with a small MAE Coverage and Coverage(0.8) and Coverage(0.9) values of around 80% and 90%, respectively.
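The idea behind MAE Coverage, as described above, can be sketched in a few lines. This is a hypothetical helper working on central prediction intervals; the evaluator used in the fine-tuning notebook may define these metrics differently (GluonTS, for instance, computes its coverage metrics per quantile), so treat it as an illustration of the concept rather than a reproduction of the reported numbers.

```python
import numpy as np


def coverage_report(y_true, intervals):
    """Empirical coverage per nominal level, plus the mean absolute deviation
    between empirical and nominal coverage ("MAE Coverage" as described above).

    intervals: dict mapping a nominal level (e.g. 0.8) to (lower, upper) arrays.
    """
    y_true = np.asarray(y_true, dtype=float)
    coverage = {}
    for level, (lower, upper) in intervals.items():
        lower, upper = np.asarray(lower, dtype=float), np.asarray(upper, dtype=float)
        inside = (y_true >= lower) & (y_true <= upper)
        coverage[level] = float(np.mean(inside))
    mae_coverage = float(np.mean([abs(cov - level) for level, cov in coverage.items()]))
    return coverage, mae_coverage


y = [1.0, 2.0, 3.0, 4.0, 5.0]
intervals = {
    0.8: ([0.5, 1.5, 3.5, 3.0, 4.0], [1.5, 2.5, 4.0, 5.0, 6.0]),  # misses 3.0
    0.9: ([0.0, 1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0, 6.0]),  # covers all
}
coverage, mae_coverage = coverage_report(y, intervals)
print(coverage, round(mae_coverage, 3))
```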
One of the main differences with respect to the authors’ original fine-tuning code is line 46. In that line, the original code does not include a validation set. In my experience, leaving it out meant that all the models I trained ended up overfitting the training data. With a validation set, on the other hand, most models reached their best validation loss in epoch 0 and did not improve on it thereafter. With more data, we might see less extreme outcomes.
Once trained, most of the models in the loop yield an MAE Coverage of 0.5 and coverages of 1 on the test set. This means that the models have very broad prediction intervals, but the prediction is not very precise. The model that strikes a better balance is model 6 (counting from 0 to 8 in the loop), with the following hyperparameters and metrics:
{'context_length': 128,
'lr': 0.001,
'Coverage(0.8)': 0.7142857142857143,
'Coverage(0.9)': 0.8571428571428571,
'MAE_Coverage': 0.36666666666666664}
Since this is the most promising model, we are going to run it through the tests that we have with the other forecasters.
The chart below shows the predictions from the fine-tuned model.
Something that catches the eye very quickly is that the prediction intervals are substantially narrower than those from the zero-shot version. In fact, the interval area is 188.69. With these prediction intervals, the model reaches a coverage of 56.67% over the 7-day recursive forecast. Remember that our best zero-shot predictions, with a 128-token context, had an area of 399.25, reaching a coverage of 84.67%. This means a reduction of roughly 53% in the interval area, with only a 33% decrease in coverage. However, the fine-tuned model is too far from the 80% coverage that we are aiming for, whereas the zero-shot model with 128 tokens wasn’t.
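Both percentages are relative changes with respect to the zero-shot 128-token model. A quick way to check them (the helper is illustrative):

```python
def relative_change(before, after):
    """Change from `before` to `after` as a fraction of `before`."""
    return (after - before) / before


area_change = relative_change(399.25, 188.69)     # interval area, zero-shot -> fine-tuned
coverage_change = relative_change(84.67, 56.67)   # coverage (in percent), same comparison
print(f"area: {area_change:.1%}, coverage: {coverage_change:.1%}")
# area: -52.7%, coverage: -33.1%
```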
When it comes to point forecasting, the MAE of the model is 0.77, which is not an improvement over the zero-shot forecasts and worse than the XGBoost forecaster.
Overall, the fine-tuned model doesn’t leave us with a good picture: it doesn’t beat the zero-shot model at either point or probabilistic forecasting. The authors do suggest that the model can improve if fine-tuned with more data, so it may be that our training set was not large enough.
To recap, let’s ask again the question that we set out at the beginning of this blog: Is Lag-Llama better at forecasting than XGBoost? For our dataset, the short answer is no: they are similar. The long answer is more complicated, though. Zero-shot forecasts with a 128-token context length were at the same level as XGBoost in terms of probabilistic forecasting. Fine-tuning Lag-Llama further reduced the interval area, making the model’s correct forecasts more precise, albeit at a substantial cost in terms of probabilistic coverage. This raises the question of where the model could get with more training data. But more data we did not have, so we can’t say that Lag-Llama beat XGBoost.
These results inevitably open a broader debate: since neither is clearly better than the other in terms of performance, which one should we use? In this case, we’d need to consider other variables such as ease of use, deployment and maintenance, and inference costs. While I haven’t formally tested the two options in any of those aspects, I suspect XGBoost would come out on top. Being less data- and resource-hungry, pretty robust to overfitting and time-tested are hard-to-beat characteristics, and XGBoost has them all.
But do not take my word for it! The code that I used is publicly available in this GitHub repo, so go have a look and run it yourself.