Language models (LMs) are the driving force behind many recent advances in natural language processing. Models like T5, LaMDA, GPT-3, and PaLM have shown impressive performance on various language tasks. While multiple factors can contribute to improving LM performance, some recent studies suggest that scaling up model size is crucial for revealing emergent capabilities. In other words, some tasks can be solved with small models, while others seem to benefit from increased scale.
Despite recent efforts that have enabled efficient training of LMs on large amounts of data, trained models can still be slow and costly for practical use. When generating text at inference time, most autoregressive LMs output content similar to how we speak and write (word after word), predicting each new word based on the preceding ones. This process cannot be parallelized, since LMs need to complete the prediction of one word before starting to compute the next. Furthermore, predicting each word requires significant computation given the model's billions of parameters.
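To make the sequential dependency concrete, here is a minimal sketch of greedy autoregressive decoding. The `model` callable and its inputs are hypothetical stand-ins, not the actual T5 implementation; each step must finish before the next can begin.

```python
# Minimal sketch of greedy autoregressive decoding (hypothetical `model` callable
# returning next-word logits). Each step depends on all previously generated tokens,
# so the steps cannot run in parallel.

def greedy_decode(model, input_ids, max_new_tokens, eos_id):
    tokens = list(input_ids)
    for _ in range(max_new_tokens):
        logits = model(tokens)  # full forward pass through all decoder layers
        next_token = max(range(len(logits)), key=lambda i: logits[i])
        tokens.append(next_token)  # the next step cannot start before this one finishes
        if next_token == eos_id:
            break
    return tokens
```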
In “Confident Adaptive Language Modeling”, presented at NeurIPS 2022, we introduce a new method to speed up LM text generation by improving efficiency at inference time. Our method, called CALM, is motivated by the intuition that some next-word predictions are easier than others. When writing a sentence, some continuations are trivial, while others may require more effort. Current LMs devote the same amount of compute to every prediction. Instead, CALM dynamically distributes the computational effort across generation timesteps. By selectively allocating more computational resources only to harder predictions, CALM generates text faster while preserving output quality.
Confident adaptive language modeling
When possible, CALM skips some computation for certain predictions. To demonstrate this, we use the popular T5 encoder-decoder architecture. The encoder reads the input text (for example, a news article to summarize) and converts it into dense representations. The decoder then outputs the summary by predicting it word by word. Both the encoder and the decoder consist of a long sequence of Transformer layers. Each layer includes attention and feedforward modules with many matrix multiplications. These layers gradually modify the hidden representation that is ultimately used to predict the next word.
Instead of waiting for all decoder layers to complete, CALM attempts to predict the next word earlier, after some intermediate layer. To decide whether to commit to an intermediate prediction or postpone it to a later layer, we measure the model's confidence in its intermediate prediction. The rest of the computation is skipped only when the model is confident enough that the prediction will not change. To quantify what is “confident enough”, we calibrate a threshold that statistically satisfies arbitrary quality guarantees over the full output sequence.
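The following sketch illustrates the per-token early-exit loop in the spirit of this description. The helpers `layers`, `predict_from_hidden`, and `confidence`, and the single scalar `threshold`, are assumptions for illustration; the actual implementation differs in details.

```python
import numpy as np

# Illustrative sketch of a per-token early-exit loop: after each decoder layer we form
# an intermediate next-word distribution and exit as soon as confidence passes the
# calibrated threshold. Helper callables are hypothetical.

def decode_next_token(hidden, layers, predict_from_hidden, confidence, threshold):
    for i, layer in enumerate(layers):
        hidden = layer(hidden)
        probs = predict_from_hidden(hidden)  # intermediate next-word distribution
        if i < len(layers) - 1 and confidence(probs, hidden) >= threshold:
            break  # confident enough: skip the remaining layers
    return int(np.argmax(probs)), i + 1  # prediction and number of layers used
```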
Text generation with a regular language model (top) and with CALM (bottom). CALM attempts early predictions. Once confident enough (darker shades of blue), it exits early and saves time.
Language models with early exits
Enabling this early-exit strategy for LMs requires minimal modifications to training and inference. During training, we encourage the model to produce meaningful representations in intermediate layers. Instead of predicting using only the top layer, our training loss is a weighted average of the predictions from all layers, assigning larger weights to higher layers. Our experiments show that this significantly improves intermediate-layer predictions while preserving the performance of the full model. In one model variant, we also include a small early-exit classifier trained to predict whether the prediction of a given intermediate layer is consistent with the top layer. We train this classifier in a fast second step in which we freeze the rest of the model.
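A minimal sketch of such a layer-weighted objective is shown below. The specific weighting (weights proportional to layer depth) is an assumption for illustration, not necessarily the exact scheme used in the paper.

```python
import numpy as np

# Sketch of a layer-weighted training objective: the loss is a weighted average of
# per-layer cross-entropy terms, with weights that grow with layer depth so higher
# layers still dominate. The exact weighting here is an illustrative assumption.

def layer_weighted_loss(per_layer_losses):
    """per_layer_losses[i] = cross-entropy of the prediction made from layer i + 1."""
    num_layers = len(per_layer_losses)
    weights = np.arange(1, num_layers + 1, dtype=np.float32)  # w_i proportional to i
    weights /= weights.sum()
    return float(np.dot(weights, per_layer_losses))
```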
Once the model is trained, we need a method to allow early exits. First, we define a local confidence measure that captures the model's confidence in its intermediate prediction. We explore three confidence measures (described in the results section below): (1) softmax response, the maximum predicted probability of the softmax distribution; (2) state propagation, the cosine distance between the current hidden representation and that of the previous layer; and (3) early-exit classifier, the output of a classifier specifically trained to predict local consistency. We find the softmax response to be statistically strong while also being simple and fast to compute. The other two alternatives are lighter in floating point operations (FLOPs).
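For concreteness, here are simple sketches of the first two measures, assuming the logits and hidden states are 1-D NumPy vectors for the current position; the early-exit classifier is just the output of the small trained classifier and is omitted here.

```python
import numpy as np

def softmax_response(logits):
    """Maximum probability of the intermediate softmax distribution (higher = more confident)."""
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return float(probs.max())

def state_propagation(hidden, prev_hidden):
    """Cosine similarity between consecutive hidden states.

    A high similarity (low cosine distance) suggests the representation has stabilized,
    so the prediction is unlikely to change in later layers.
    """
    denom = np.linalg.norm(hidden) * np.linalg.norm(prev_hidden) + 1e-9
    return float(np.dot(hidden, prev_hidden) / denom)
```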
Another challenge is that the self-attention of each layer depends on the hidden states of previous words. If we exit early for some word predictions, these hidden states may be missing. Instead, we attend back to the hidden state of the last computed layer.
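One simple way to realize this is to reuse a token's last computed state in place of the states of its skipped layers, so later tokens always have something to attend to. The cache layout below is a hypothetical simplification, not the paper's exact mechanism.

```python
# Sketch of state reuse for skipped layers. `cache[layer][pos]` is a hypothetical
# per-layer store of hidden states for already-generated positions. If the token at
# `pos` exited at `exit_layer`, higher layers reuse its last computed state so that
# later tokens' self-attention has a state to attend to.

def fill_skipped_states(cache, pos, exit_layer, num_layers):
    last_state = cache[exit_layer][pos]
    for layer in range(exit_layer + 1, num_layers):
        cache[layer][pos] = last_state  # reuse the last computed hidden state
```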
Finally, we set the local confidence threshold for exiting early. In the next section, we describe our controlled process for finding good threshold values. As a first step, we simplify this infinite search space based on a useful observation: mistakes made early in the generation process are more harmful, since they can affect all subsequent outputs. We therefore start with a higher (more conservative) threshold and gradually lower it over time, using an exponential decay with a user-defined temperature to control the decay rate. We find that this allows better control over the performance-efficiency trade-off (the speedup gained per quality level).
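One plausible form of such a decaying threshold is sketched below; the exact parameterization in the paper may differ, and the `initial`, `floor`, and `tau` values here are illustrative assumptions.

```python
import math

# Illustrative decaying confidence threshold: conservative (high) for the first
# generation steps, decaying exponentially toward a floor with temperature `tau`.

def exit_threshold(step, initial=0.95, floor=0.6, tau=4.0):
    return floor + (initial - floor) * math.exp(-step / tau)
```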
Reliably controlling the quality of the accelerated model
Early exit decisions have to be local; they are made when predicting each word. In practice, however, the final output should be globally consistent with, or comparable to, that of the original model. For example, if the original full model wrote “the concert was wonderful and long”, one would accept CALM switching the order of the adjectives and writing “the concert was long and wonderful”. However, at the local level, the word “wonderful” was replaced by “long”. Therefore, the two outputs are globally consistent but include some local inconsistencies. We build on the Learn then Test (LTT) framework to connect confidence-based local decisions with globally consistent outputs.
First, we define and formulate two types of consistency constraints to choose from:
- Textual consistency: We constrain the expected textual distance between the outputs of CALM and those of the full model. This does not require any labeled data.
- Risk consistency: We constrain the expected increase in loss that we allow CALM to incur compared to the full model. This requires reference outputs against which to compare.
For each of these constraints, we can set the tolerance we allow and calibrate the confidence threshold so that it permits early exits while reliably satisfying the defined constraint with arbitrarily high probability.
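As a rough illustration of this calibration step, the sketch below scans candidate thresholds from conservative to aggressive and keeps lowering the threshold while a simple Hoeffding-style upper confidence bound on the expected inconsistency stays below the chosen tolerance. This is only a simplified stand-in for the LTT procedure; the names `inconsistency_fn`, `delta`, and `epsilon` are assumptions for illustration.

```python
import math

# Simplified calibration sketch in the spirit of Learn then Test (LTT). It certifies
# the lowest (most aggressive) threshold whose upper confidence bound on the expected
# inconsistency (a value in [0, 1] per calibration example) stays below tolerance
# `delta`, at confidence level 1 - epsilon, using a fixed-sequence scan.

def calibrate_threshold(candidate_thresholds, inconsistency_fn, calib_set,
                        delta=0.05, epsilon=0.05):
    chosen = max(candidate_thresholds)                      # most conservative fallback
    for lam in sorted(candidate_thresholds, reverse=True):  # scan high -> low threshold
        risks = [inconsistency_fn(x, lam) for x in calib_set]
        mean = sum(risks) / len(risks)
        bound = mean + math.sqrt(math.log(1 / epsilon) / (2 * len(risks)))
        if bound <= delta:
            chosen = lam   # this threshold is certified; try a lower one
        else:
            break          # stop at the first failure (fixed-sequence test)
    return chosen
```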
CALM saves inference time
We run experiments on three popular generation datasets: CNN/DM for summarization, WMT for machine translation, and SQuAD for question answering. We evaluate each of the three confidence measures (softmax response, state propagation, and early-exit classifier) using an 8-layer encoder-decoder model. To assess performance at the global sequence level, we use the standard Rouge-L, BLEU, and Token-F1 scores, which measure distances against human-written references. We show that one can maintain full model performance while using only a third or a half of the layers on average. CALM achieves this by dynamically distributing the computational effort across generation timesteps.
As a rough upper bound, we also compute predictions using a local oracle confidence measure, which enables exiting at the first layer that leads to the same prediction as the top layer. Across all three tasks, the oracle measure preserves full model performance while using only 1.5 decoder layers on average. In contrast to CALM, a static baseline that uses the same number of layers for all predictions requires 3 to 7 layers (depending on the dataset) to preserve performance. This demonstrates why dynamic allocation of computational effort is important: only a small fraction of the predictions require most of the model's capacity, while for others much less suffices.
Performance per task vs. the average number of decoder layers used.
Finally, we also find that CALM enables practical speedups. In benchmarks on TPUs, we saved almost half of the computation time while maintaining the quality of the outputs.
Conclusion
CALM enables faster text generation with LMs without reducing the quality of the output text. It achieves this by dynamically changing the amount of computation per generation timestep, allowing the model to exit the computation sequence early when it is sufficiently confident.
As language models continue to grow in size, studying how to use them efficiently becomes crucial. CALM is orthogonal to, and can be combined with, many efficiency-related efforts, including model quantization, distillation, sparsity, effective partitioning, and distributed control flows.
Acknowledgements
It was an honor and a privilege to work on this with Adam Fisch, Ionel Gog, Seungyeon Kim, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Q. Tran, Yi Tay, and Donald Metzler. We also thank Anselm Levskaya, Hyung Won Chung, Tao Wang, Paul Barham, Michael Isard, Orhan Firat, Carlos Riquelme, Aditya Menon, Zhifeng Chen, Sanjiv Kumar, and Jeff Dean for helpful discussions and comments. Finally, thanks to Tom Small for preparing the animation in this blog post.