Large language models (LLMs) have revolutionized artificial intelligence and have impacted various scientific and engineering disciplines. The Transformer architecture, initially designed for machine translation, has become the basis for GPT models, significantly advancing the field. However, current LLMs rely on a training approach that mainly focuses on predicting the next token from the previous context while maintaining causality. This simple method has been applied in a variety of areas, including robotics, protein sequencing, audio processing, and video analysis. As LLMs continue to grow in scale, reaching hundreds of billions and even trillions of parameters, concerns are arising about the accessibility of AI research, with some fearing that it may become confined to industry researchers. The central problem researchers are addressing is how to improve a model's capabilities to match those of much larger architectures, or achieve comparable performance with fewer training steps, ultimately addressing scale and efficiency challenges in LLM development.
Researchers have explored several approaches to improving LLM performance by manipulating intermediate embeddings. One method applied hand-tuned filters to the Discrete Cosine Transform of the latent space for tasks like named entity recognition and topic modeling on non-causal architectures such as BERT. However, this approach transforms the entire context length at once and is therefore not suitable for causal language modeling.
Two notable techniques, FNet and WavSPA, attempted to improve the attention blocks in BERT-like architectures. FNet replaced the attention mechanism with a 2-D FFT block, but this operation is not causal, since it looks at future tokens. WavSPA computed attention in wavelet space, using multi-resolution transforms to capture long-term dependencies. However, it also relied on non-causal operations that examine the entire sequence length.
These existing methods, while innovative, face limitations in their applicability to causal decoder-only architectures such as GPT. They often violate the causality assumption crucial for next token prediction tasks, making them unsuitable for direct adaptation to GPT-like models. The challenge remains to develop techniques that can improve model performance while maintaining the causal nature of decoder-only architectures.
Stanford researchers propose WaveletGPT, believed to be the first technique to incorporate wavelets into LLMs. The method adds multi-scale filters to the intermediate embeddings of the Transformer decoder layers using Haar wavelets. This allows each next-token prediction to access multi-scale representations at every layer, rather than relying on fixed-resolution representations.
Surprisingly, this method speeds up the pre-training of Transformer-based LLMs by 40% to 60% without adding any parameters, a significant advance given the widespread use of Transformer decoder-based architectures across modalities. With the same number of training steps, the approach also delivers substantial performance improvements, comparable to adding multiple layers or parameters.
The wavelet-based operation shows performance improvements across three different modalities: language (text8), raw audio (YouTubeMix), and symbolic music (MAESTRO), highlighting its versatility for structured datasets. Additionally, making the wavelet kernels learnable, which adds only a small fraction of parameters, yields even greater performance gains, allowing the model to learn multi-scale filters on the intermediate embeddings from scratch.
The proposed method incorporates wavelets into Transformer-based large language models while maintaining the causality assumption. The idea can be applied to other architectures as well, including non-Transformer setups. The technique focuses on manipulating the intermediate embeddings of each decoder layer.
For a given signal x_l(i), the output of the l-th decoder layer along the i-th embedding coordinate, the method applies a discrete wavelet transform. With N+1 layers and an embedding dimension E, this process yields N*E signals of length L (the context length) from the intermediate embeddings between decoder blocks.
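As a concrete illustration, here is a minimal PyTorch sketch (with illustrative dimensions, not values from the paper) of how one layer's intermediate embeddings can be read as E separate length-L signals, one per embedding coordinate.

```python
import torch

# Illustrative sizes only: L is the context length, E the embedding dimension.
L, E = 512, 768
hidden_states = torch.randn(L, E)   # output of the l-th decoder layer for one sequence

# x_l(i): the signal along the i-th embedding coordinate, of length L
i = 0
x_l_i = hidden_states[:, i]
assert x_l_i.shape == (L,)          # one of the N * E signals the method operates on
```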
The wavelet transform, implemented here with Haar wavelets, involves passing the signal through filters at different resolutions. Haar wavelets are square-shaped functions derived from a mother wavelet through scaling and shifting operations. This process creates child wavelets that capture signal information at various time scales.
The discrete wavelet transform is implemented by passing the signal through low-pass and high-pass filters, followed by downsampling. For Haar wavelets, this is equivalent to averaging and differencing operations. The process generates approximation coefficients (y_approx) and detail coefficients (y_detail) through convolution and downsampling. The operation is applied recursively to the approximation coefficients to obtain multi-scale representations, allowing each next-token prediction to access these multi-resolution views of the intermediate embeddings.
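A minimal sketch of this decomposition, assuming PyTorch and the simple averaging/differencing convention described above (normalization conventions for Haar filters vary):

```python
import torch

def haar_dwt_level(x: torch.Tensor):
    """One level of the Haar DWT: pairwise average (low-pass) and pairwise
    difference (high-pass), each of which downsamples the signal by 2."""
    x = x[: (x.shape[0] // 2) * 2]      # drop a trailing sample if length is odd
    even, odd = x[0::2], x[1::2]
    y_approx = (even + odd) / 2.0       # approximation coefficients
    y_detail = (even - odd) / 2.0       # detail coefficients
    return y_approx, y_detail

def haar_dwt(x: torch.Tensor, levels: int):
    """Apply the transform recursively to the approximation coefficients
    to obtain multi-scale representations of the signal."""
    details = []
    approx = x
    for _ in range(levels):
        approx, detail = haar_dwt_level(approx)
        details.append(detail)
    return approx, details
```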
This method connects wavelets and LLM embeddings by focusing on the approximation coefficients, which capture structure in the data at multiple levels. For text, this structure ranges from individual words to thematic patterns, while for symbolic music it ranges from notes to entire pieces. Using Haar wavelets simplifies the operation to a moving average. To maintain causality and preserve the original sequence length, the method computes, for each token dimension, a moving average over the previous samples within a given kernel length. This creates multi-scale representations of the input signal, allowing the model to capture information at different resolutions across the embedding dimensions without altering the structure of the intermediate Transformer embeddings.
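A minimal sketch of such a causal, length-preserving moving average, assuming PyTorch (the function name and the zero-padding choice are illustrative assumptions, not taken from the paper):

```python
import torch
import torch.nn.functional as F

def causal_moving_average(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Average each position with its previous `kernel_size - 1` samples only,
    so no future-token information leaks and the sequence length is preserved."""
    # Left-pad with zeros so the convolution never looks ahead; the earliest
    # positions are therefore averaged against zeros.
    x_padded = F.pad(x.view(1, 1, -1), (kernel_size - 1, 0))
    weight = torch.full((1, 1, kernel_size), 1.0 / kernel_size)
    return F.conv1d(x_padded, weight).view(-1)

# Example: smooth one embedding coordinate over the last 8 tokens.
signal = torch.randn(512)
smoothed = causal_moving_average(signal, kernel_size=8)
assert smoothed.shape == signal.shape
```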
The method introduces a simple way to incorporate multi-scale representations without increasing architectural complexity. Instead of computing every approximation level for each embedding dimension, the level is parameterized by the index of the embedding dimension itself. Half of the intermediate embedding signals are left unchanged, while the other half are processed according to their index. For the processed half, a simple mapping function f determines the kernel size for each coordinate, covering approximation levels 1 through 9. The modified signal is computed with a causal moving-average filter whose kernel size is given by f(i). This operation preserves the causality assumption that is critical in LLMs and prevents any leakage of information from future tokens. The technique creates a structure in which different embedding dimensions move at different rates, letting the model capture information at multiple scales. This multi-rate structure allows the attention mechanism to use multi-scale features at every layer and token, potentially improving the model's ability to capture complex patterns in the data.
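Putting the pieces together, here is a hypothetical PyTorch sketch of the multi-rate operation on one layer's embeddings; the mapping f(i) from dimension index to kernel size is an illustrative choice, not necessarily the paper's exact formula.

```python
import torch
import torch.nn.functional as F

def multiscale_embedding_op(h: torch.Tensor, max_level: int = 9) -> torch.Tensor:
    """Sketch of the multi-rate operation on intermediate embeddings h of shape
    (L, E): the first half of the embedding dimensions passes through unchanged,
    while each dimension i in the second half is replaced by a causal moving
    average whose kernel size grows with an assumed mapping f(i)."""
    L, E = h.shape
    out = h.clone()
    half = E // 2
    for i in range(half, E):
        # f(i): map the dimension index to an approximation level 1..max_level,
        # then to a Haar-style kernel size of 2**level (illustrative assumption).
        level = 1 + ((i - half) * max_level) // max(half, 1)
        level = min(level, max_level)
        kernel_size = 2 ** level
        x = F.pad(h[:, i].view(1, 1, -1), (kernel_size - 1, 0))  # causal left padding
        w = torch.full((1, 1, kernel_size), 1.0 / kernel_size)
        out[:, i] = F.conv1d(x, w).view(-1)
    return out
```

Because only half of the coordinates are filtered, the other half retains full-resolution information, and the attention mechanism sees both views at every layer and token.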
Results across three modalities (text, symbolic music, and audio waveforms) demonstrate substantial performance improvements from the wavelet-based intermediate operation. For natural language, the decrease in validation loss is equivalent to expanding from a 16-layer model to a 64-layer model on the text8 dataset. The modified architecture reaches the same loss almost twice as fast as the original in terms of training steps. This speedup is even more pronounced for raw audio, possibly due to the quasi-stationary nature of audio signals over short time scales: the raw-waveform configuration converges almost twice as fast as the text8 and symbolic music setups.
A comparison of wall-clock execution times shows that the modified architecture remains computationally efficient in both the learnable and non-learnable configurations. The time required to complete an epoch is reported relative to the baseline architecture. The method is computationally economical, since the core operation is a simple moving average for Haar wavelets, or learning a single-filter convolutional kernel with varying context lengths across embedding dimensions. This efficiency, combined with the performance improvements, underscores the effectiveness of the wavelet-based approach for improving LLM training across modalities without significant computational overhead.
This study presents WaveletGPT, which integrates wavelets, a core signal processing technique, into the pre-training of large language models. By imposing a multi-scale structure on the intermediate embeddings, pre-training is accelerated by 40-60% without adding any parameters. The technique is effective across three different modalities: text, symbolic music, and raw audio. When trained for the same duration, it also delivers substantial performance improvements. Possible future directions include incorporating more advanced wavelet concepts and multi-resolution signal processing to further optimize large language models.
Check out the Paper. All credit for this research goes to the researchers of this project.
Asjad is an intern consultant at Marktechpost. He is pursuing a B.Tech in Mechanical Engineering at the Indian Institute of Technology, Kharagpur. Asjad is a machine learning and deep learning enthusiast who is always researching applications of machine learning in healthcare.