LLMs like GPT-4 excel at language understanding but demand large amounts of GPU memory during inference, which limits their scalability for real-time applications such as chatbots. Existing methods reduce memory by compressing the KV cache, but they ignore cross-layer dependencies and the memory consumed before generation begins (the prefill phase). Inference memory comes mainly from the model parameters and the KV cache, with the latter consuming far more: a 7-billion-parameter model, for example, needs about 14 GB for its parameters but 72 GB for the KV cache. This memory requirement is a key bottleneck for LLM inference on GPUs.
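For intuition, the KV cache grows linearly with batch size and sequence length, which is why it dwarfs the parameter footprint under heavy serving loads. Below is a rough back-of-the-envelope sketch in Python; the model shape and serving load are illustrative assumptions, not the exact configuration behind the 72 GB figure quoted above.

```python
# Rough KV cache size estimate for a LLaMA-7B-like model in fp16.
# All configuration values are illustrative assumptions, not the exact
# setting behind the 72 GB figure mentioned in the article.
num_layers = 32        # transformer layers in a typical 7B model
num_heads = 32         # attention heads per layer
head_dim = 128         # dimension per head
bytes_per_elem = 2     # fp16

# Both keys and values are cached: 2 tensors per layer per token.
bytes_per_token = 2 * num_layers * num_heads * head_dim * bytes_per_elem
print(f"KV cache per token: {bytes_per_token / 1024:.0f} KiB")   # ~512 KiB

# Assumed serving load: 32 concurrent sequences of 4,096 tokens each.
batch_size, seq_len = 32, 4096
total_gib = bytes_per_token * batch_size * seq_len / 1024**3
print(f"Total KV cache: {total_gib:.0f} GiB")   # ~64 GiB, same order as 72 GB
```

Unlike the fixed 14 GB of weights, this cost keeps climbing as batches get larger or contexts get longer, which is what motivates compressing the cache.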
Researchers from Shanghai Jiao Tong University, Xiaohongshu Inc., and South China University of Technology developed PyramidInfer, which speeds up LLM inference by compressing the KV cache. Unlike existing methods that ignore cross-layer dependencies and the memory cost of the prefill phase, PyramidInfer retains only the crucial context keys and values, layer by layer. Inspired by the observation that recent tokens attend consistently to the same context, this approach significantly reduces GPU memory usage. Experiments show that PyramidInfer improves throughput by 2.2x and reduces the KV cache by more than 54% compared with existing methods, demonstrating its effectiveness across a range of tasks and models.
Efficient strategies are essential to handle the growing volume of chatbot queries, with the goal of maximizing throughput by exploiting GPU parallelism. One approach is to expand effective GPU memory through pipeline parallelism and KV cache offloading across multiple GPUs or CPU RAM. When GPU memory is fixed, shrinking the KV cache itself is another option. Techniques such as FlashAttention 2 and PagedAttention minimize memory waste by optimizing CUDA kernels and memory management. Methods such as StreamingLLM, H2O, and Scissorhands compress the KV cache by focusing on recent context or attention statistics, but they ignore differences between layers and do not compress during the prefill phase. PyramidInfer addresses these gaps by applying layer-specific compression in both phases.
Two hypotheses inspired the design of PyramidInfer: inference context redundancy (ICR) and recent attention consistency (RAC). ICR posits that many context keys and values are redundant at inference time; they are needed during training only because the model must predict every next token. Experiments with the 40-layer LLaMA 2-13B model showed that deeper layers exhibit higher redundancy, allowing a substantial reduction of the KV cache without affecting output quality. RAC observes that recent tokens consistently attend to the same subset of keys and values, which makes it possible to select these pivotal contexts (PvCs) for efficient inference. PyramidInfer leverages both insights to compress the KV cache in the prefill and generation phases.
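A minimal sketch of the layer-wise selection idea is shown below in PyTorch-style Python. The function name, the retention schedule, and the scoring details are illustrative assumptions rather than the paper's implementation: attention weights from the most recent tokens score earlier positions, only the highest-scoring keys and values are kept, and deeper layers keep a smaller fraction, giving the cache its pyramid shape.

```python
import torch

def select_pivotal_kv(keys, values, attn_weights, recent_window, keep_ratio):
    """Keep only the context keys/values that recent tokens attend to most.

    keys, values:  [seq_len, num_heads, head_dim] cached K/V for one layer
    attn_weights:  [num_heads, seq_len, seq_len] attention weights for this layer
    recent_window: how many of the newest tokens vote on importance
    keep_ratio:    fraction of older context positions to retain at this layer
    """
    seq_len = keys.shape[0]
    if seq_len <= recent_window:
        return keys, values  # nothing older than the recent window to prune

    # Score each context position by the attention it receives from the
    # most recent tokens, averaged over heads and those recent queries.
    recent_attn = attn_weights[:, -recent_window:, :]        # [H, R, S]
    scores = recent_attn.mean(dim=(0, 1))                    # [S]

    # Always keep the recent window itself; keep only top-scoring older positions.
    num_context = seq_len - recent_window
    num_keep = max(1, int(keep_ratio * num_context))
    top_idx = torch.topk(scores[:num_context], num_keep).indices
    keep_idx = torch.cat([top_idx.sort().values,
                          torch.arange(num_context, seq_len)])
    return keys[keep_idx], values[keep_idx]

# Illustrative pyramid schedule (assumed): deeper layers retain less context.
num_layers = 32
keep_ratios = [max(0.1, 1.0 - layer / num_layers) for layer in range(num_layers)]
```

In this sketch the pruning is applied per layer with a shrinking `keep_ratio`, which mirrors the finding that deeper layers tolerate more aggressive cache reduction.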
PyramidInfer's performance was evaluated across multiple tasks and models, showing large reductions in GPU memory usage and higher throughput while preserving generation quality. The evaluation covered language modeling on WikiText-2, LLM benchmarks such as MMLU and BBH, mathematical reasoning with GSM8K, coding with HumanEval, multi-turn conversation with MT-Bench, and long-text summarization with LEval. PyramidInfer was tested on models including LLaMA 2, LLaMA 2-Chat, Vicuna 1.5-16k, and CodeLLaMA at different sizes. The results show that PyramidInfer maintains generation quality with less GPU memory than full-cache methods and significantly outperforms recency-only (local) strategies.
In conclusion, PyramidInfer offers an efficient way to compress the KV cache during both the prefill and generation phases, inspired by ICR and RAC. The approach greatly reduces GPU memory usage without compromising model performance, making it well suited to deploying large language models in resource-constrained environments. Despite its effectiveness, PyramidInfer introduces extra computation, which limits its speedup at small batch sizes. As the first method to compress the KV cache in the prefill phase, PyramidInfer is not yet lossless, leaving room for future improvement.
Check out the Paper. All credit for this research goes to the researchers of this project. Also, don't forget to follow us on Twitter. Join our Telegram Channel, Discord Channel, and LinkedIn Group.
If you like our work, you will love our Newsletter.
Don't forget to join our 42k+ ML SubReddit
Sana Hassan, a consulting intern at Marktechpost and a dual-degree student at IIT Madras, is passionate about applying technology and artificial intelligence to address real-world challenges. With a strong interest in solving practical problems, she brings a fresh perspective to the intersection of AI and real-life solutions.