The Mamba in the Llama: Distilling and Accelerating Hybrid Models
Introduction
The evolution of large language models (LLMs) has been largely driven by the success of Transformer architectures. However, despite their impressive capabilities, Transformers suffer from significant inefficiencies, particularly in scenarios involving long sequences due to their quadratic complexity and heavy memory requirements. These challenges have spurred interest in exploring alternative architectures that can offer similar or even better performance with greater efficiency.
One such promising direction is the use of linear Recurrent Neural Networks (linear RNNs), specifically the Mamba and Mamba2 architectures. Mamba and its variants have demonstrated performance competitive with Transformers while offering significant advantages in inference speed. Mamba supports parallel training and has constant memory requirements during inference.
But can we bridge the gap between these architectures and harness the strengths of both? The answer lies in distilling large-scale Transformer models into hybrid linear RNNs and accelerating inference, combining the best of both worlds.
From Transformer to Mamba
The self-attention mechanism is vital to transformers, enabling models to weigh the importance of different tokens in a sequence. However, this comes at the cost of computational and memory inefficiencies, particularly for long sequences. For example, during inference, transformers store the key and value vectors for every token they encounter in the KV-cache. For large models and long sequences, the KV-cache becomes a major memory overhead: it slows inference down and occupies a lot of GPU memory.
In contrast, linear RNNs like Mamba enjoy linear-time scaling during training, and constant memory cost during inference as the entire state is summarized in a fixed size tensor. As a consequence, Mamba offers up to 5× higher throughput in inference tasks.
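To make the contrast concrete, the following back-of-the-envelope sketch compares how the two memory footprints scale. The configuration (32 layers, 8 key/value heads of dimension 128, 16-bit values) and the per-layer state size are hypothetical numbers chosen for illustration, not measurements from any model discussed here.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor 2: one tensor for keys and one for values in every layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem


def rnn_state_bytes(n_layers, state_elems_per_layer, bytes_per_elem=2):
    """Bytes for a fixed-size recurrent state; there is no seq_len term."""
    return n_layers * state_elems_per_layer * bytes_per_elem


# Hypothetical configuration, for illustration only.
cfg = dict(n_layers=32, n_kv_heads=8, head_dim=128)
for seq_len in (1_024, 8_192, 65_536):
    gib = kv_cache_bytes(seq_len=seq_len, **cfg) / 2**30
    print(f"{seq_len:>6} tokens -> {gib:.3f} GiB of KV cache")

# The recurrent state stays the same size no matter how long the sequence is.
state_mib = rnn_state_bytes(n_layers=32, state_elems_per_layer=16 * 4096) / 2**20
print(f"fixed RNN state -> {state_mib:.1f} MiB at any sequence length")
```

The KV-cache grows linearly with the number of tokens, while the recurrent state is independent of it; that difference is what the distillation below tries to exploit.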
Therefore, it would make sense to take a pretrained transformer and distill its capabilities into a Mamba model, so as to exploit the inference efficiency of linear RNNs while preserving the generation quality of transformer LLMs.
Despite their differences, a natural relationship exists between attention mechanisms in Transformers and the operations in linear RNNs. By linearizing the attention mechanism—essentially removing the softmax non-linearity—we can approximate the behavior of attention using linear RNNs. Therefore, we begin our distillation process by initializing the parameters in Mamba so as to mimic a linearized version of the transformer attention.
Linear RNNs take the following form
By linearizing transformer attention, we get
Thus the relationship:
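Spelled out in one common notation, the three steps above (the linear RNN, linearized causal attention, and the resulting correspondence) can be written as follows. The symbols are assumptions made here for illustration: $o_t$ is the layer input at position $t$, $m_{s,t}$ is the causal mask (equal to 1 when $s \le t$), $D$ is the head dimension, and $W^Q$, $W^K$, $W^V$ are the attention projections. The correspondence holds up to the constant $1/\sqrt{D}$ factor.

```latex
% Linear RNN
h_t = A_t h_{t-1} + B_t x_t, \qquad y_t = C_t h_t

% Causal attention with the softmax removed, and its recurrent form
y_t = \frac{1}{\sqrt{D}} \, Q_t \sum_{s=1}^{t} m_{s,t} \, K_s^{\top} V_s
\quad\Longrightarrow\quad
h_t = m_{t-1,t} \, h_{t-1} + K_t^{\top} V_t, \qquad y_t = \frac{1}{\sqrt{D}} \, Q_t h_t

% Matching terms
A_t = m_{t-1,t}, \qquad B_t = W^{K} o_t, \qquad C_t = W^{Q} o_t, \qquad x_t = W^{V} o_t
```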
We now have a clear relationship between the matrices of a Mamba block and those of a (linear) attention layer! This relationship, and the way we use it for distillation, is shown in Figure 1.
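The correspondence can be checked numerically. The sketch below is a toy single-head example, not the actual Mamba parameterization: it computes causal attention without the softmax in two ways, once as a masked matrix product and once as a recurrence over a fixed-size state, and confirms that they agree. The function names and tensor sizes are made up for the example.

```python
import numpy as np


def linear_attention(Q, K, V):
    """Causal attention with the softmax removed, computed in parallel."""
    T, D = Q.shape
    scores = (Q @ K.T) / np.sqrt(D)       # (T, T) unnormalized scores
    mask = np.tril(np.ones((T, T)))       # causal mask: position t sees s <= t
    return (scores * mask) @ V


def linear_rnn(Q, K, V):
    """Same result as a recurrence with A_t = 1, B_t = K_t, C_t = Q_t, x_t = V_t."""
    T, D = Q.shape
    h = np.zeros((D, V.shape[1]))         # fixed-size hidden state
    ys = []
    for t in range(T):
        h = h + np.outer(K[t], V[t])      # h_t = A_t h_{t-1} + B_t x_t
        ys.append(Q[t] @ h / np.sqrt(D))  # y_t = C_t h_t (with the 1/sqrt(D) scale)
    return np.stack(ys)


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(linear_attention(Q, K, V), linear_rnn(Q, K, V)))
```

The recurrent version never forms the (T, T) score matrix, which is why its inference memory does not grow with sequence length.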
Distilling to an Expanded Linear RNN
Initialization
The initialization procedure is shown in Figure 2.
Figure 2: Transferring Transformer to Mamba. Weights shown in the same color are initialized from the transformer: the linear projections for C, B, and X are initialized from the linear projections for Q, K, and V, respectively. We replace individual attention blocks with Mamba blocks, and then fine-tune the Mamba blocks while freezing the MLP blocks. Shapes are mostly kept the same. New parameters are introduced for the learned A and ∆ parameters.
Distillation
The goal of our distillation is not to exactly replicate the performance of the original model but rather to provide a good starting point for the next distillation steps.
Figure 2 shows the resulting architecture. Our version directly replaces Transformer attention heads with fine-tuned linear RNN layers. During this phase, the Mamba layers are trained while the original Transformer's MLP layers are kept frozen to preserve their learned knowledge. This approach also requires processing additional components, such as grouped query attention that shares keys and values across heads. We note that this architecture differs from the architecture used in many Mamba systems, which combines MLP-SSM layers and uses a single head.
This initialization allows us to replace any attention block with a linear RNN block. We experiment with hybrid models where we keep every 𝑛 attention layers. Empirically, we found that replacing layers in a stepwise manner was the most effective strategy; i.e., we first keep every 2 layers, distill, then every 4, and continue distillation.
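As a rough illustration of the stepwise schedule, the sketch below lists which layers would remain attention at each stage. The helper name, the 32-layer depth, and the choice to keep layers 0, 𝑛, 2𝑛, ... (rather than some other offset) are assumptions made for the example.

```python
def attention_layers_to_keep(n_layers, keep_every):
    """Indices of layers that stay attention; the rest become linear RNN blocks."""
    # Assumption: keep layers 0, keep_every, 2 * keep_every, ... (offset is arbitrary).
    return [i for i in range(n_layers) if i % keep_every == 0]


n_layers = 32  # hypothetical depth
for keep_every in (2, 4, 8):  # stepwise: distill at each stage before moving on
    kept = attention_layers_to_keep(n_layers, keep_every)
    share = 100 * len(kept) / n_layers
    print(f"keep every {keep_every}: {len(kept)} attention layers ({share:g}%)")
```

Under these assumptions, the three stages leave 50%, 25%, and 12.5% of the layers as attention.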
We use the pseudo-label distillation approach with the following loss functions.
1. Word-level KL-Divergence: The student model's probability distribution is trained to match the teacher model's distribution by minimizing KL divergence over all possible next tokens.
2. Sequence-level Knowledge Distillation (SeqKD): This method involves replacing the ground truth text with the teacher's generated output, known as pseudo-labels.
The overall loss function for SFT combines sequence and word-level loss:
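One way to write such a combined objective is sketched here. The notation is assumed for illustration: $\hat{y}$ denotes the teacher's pseudo-labels, $x$ the prompt, and $\theta_T$ the teacher's parameters.

```latex
\mathcal{L}(\theta) = \sum_{t=1}^{T} \Big(
  -\alpha \, \log p_{\theta}(\hat{y}_{t+1} \mid \hat{y}_{1:t}, x)
  \;+\; \beta \, \mathrm{KL}\big[\, p_{\theta_T}(\cdot \mid \hat{y}_{1:t}, x) \;\big\|\; p_{\theta}(\cdot \mid \hat{y}_{1:t}, x) \,\big]
\Big)
```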
where 𝜃 represents the trainable parameters of the student model, and 𝛼 and 𝛽 control the weights of the sequence and word loss terms, respectively.
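A minimal NumPy sketch of a loss with these two terms follows. It assumes the KL term is taken from teacher to student and that both terms are summed over positions; the function names, shapes, and weights are illustrative, not the training code behind the results in this post.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, pseudo_labels, alpha=1.0, beta=1.0):
    """alpha * cross-entropy on pseudo-labels + beta * KL(teacher || student)."""
    log_p_student = log_softmax(student_logits)   # (T, vocab)
    log_p_teacher = log_softmax(teacher_logits)   # (T, vocab)
    positions = np.arange(len(pseudo_labels))
    # Sequence-level term: negative log-likelihood of the teacher-generated tokens.
    seq_loss = -log_p_student[positions, pseudo_labels].sum()
    # Word-level term: KL divergence over the full next-token distribution.
    kl = (np.exp(log_p_teacher) * (log_p_teacher - log_p_student)).sum()
    return alpha * seq_loss + beta * kl


rng = np.random.default_rng(0)
teacher_logits = rng.standard_normal((5, 11))    # 5 positions, toy vocabulary of 11
student_logits = rng.standard_normal((5, 11))
pseudo_labels = teacher_logits.argmax(axis=-1)   # greedy teacher output as pseudo-labels
print(distillation_loss(student_logits, teacher_logits, pseudo_labels, alpha=1.0, beta=0.5))
```

When the student's logits equal the teacher's, the KL term vanishes and only the cross-entropy on the pseudo-labels remains.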
Supervised Fine-Tuning (SFT)
Then, we perform supervised fine-tuning with the distilled hybrid model on datasets generated by GPT-4, such as OpenHermes 2.5, for one epoch.
Direct Preference Optimization (DPO)
The third stage of instruction-tuning aligns the LLM with user preferences, aiming to maximize the reward assigned by a reward model 𝑟 while staying close to a reference model, typically the supervised fine-tuned model. The optimization objective is:
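In the usual KL-regularized form, this objective can be sketched as below. The symbols are assumptions made here: $r_\phi$ is the reward model, $\pi$ the reference model, and $\beta$ the strength of the KL penalty, which is a different quantity from the loss weight 𝛽 used earlier.

```latex
\max_{\theta} \;\; \mathbb{E}_{x \sim \mathcal{D},\; y \sim p_{\theta}(y \mid x)} \big[ r_{\phi}(x, y) \big]
\;-\; \beta \, \mathrm{KL}\big[\, p_{\theta}(y \mid x) \;\big\|\; \pi(y \mid x) \,\big]
```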
Recent methods, such as Direct Preference Optimization (DPO), have proven effective for this purpose. If preferred (𝑦𝑤) and dispreferred (𝑦𝑙) outputs are available for a given prompt 𝑥, the optimization problem can be reformulated as:
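A sketch of the standard DPO form, taking the teacher (parameters $\theta_T$, notation assumed here) as the reference model, is:

```latex
\max_{\theta} \;\; \mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}} \left[ \log \sigma\!\left(
  \beta \log \frac{p_{\theta}(y_w \mid x)}{p_{\theta_T}(y_w \mid x)}
  \;-\; \beta \log \frac{p_{\theta}(y_l \mid x)}{p_{\theta_T}(y_l \mid x)}
\right) \right]
```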
This optimization is performed at the sequence level by scoring the preferred and dispreferred outputs of the model with the teacher and then backpropagating to the student. This work introduces DPO as a novel distillation objective at the sequence level.
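For a single preference pair, a DPO-style loss can be sketched in a few lines. The inputs are sequence log-probabilities (sums of token log-probabilities) under the student and under the teacher acting as the reference. The function name, the default `beta=0.1`, and the toy values are assumptions for illustration.

```python
import math


def dpo_loss(student_logp_w, student_logp_l, teacher_logp_w, teacher_logp_l, beta=0.1):
    """Negative log-sigmoid of the scaled log-ratio margin for one preference pair."""
    margin = beta * ((student_logp_w - teacher_logp_w) - (student_logp_l - teacher_logp_l))
    # -log(sigmoid(margin)) written as softplus(-margin) for numerical stability.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Toy sequence log-probabilities; the values are made up.
print(dpo_loss(student_logp_w=-12.0, student_logp_l=-15.0,
               teacher_logp_w=-11.0, teacher_logp_l=-16.0))
```

The loss shrinks as the student raises the preferred output's likelihood relative to the teacher's by more than it raises the dispreferred one's. The gradient flows only through the student's log-probabilities.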
Speculative decoding for Mamba and Hybrid Models
Figure 3: Multi-Step RNN Speculative Decoding. Left (top): The draft model generates the set of blue draft tokens sequentially. The draft tokens are then verified. Right (top): Verification uses the multi-step kernel, without materializing the intermediate states. The last token is rejected and replaced with the true best token. Note that even though more tokens are generated, we cannot advance the hidden state cache. Left (bottom): The draft model can now generate more blue draft tokens from the current tokens, resulting in six total. Right (bottom): When the new draft is verified, the multi-step kernel returns both the hidden state after the yellow token and the final hidden state, since verification will fall between those positions.
One of the most significant challenges in deploying large language models (LLMs) like Transformers is their slow inference speed, particularly when generating long sequences. This limitation arises because these models generate text in an autoregressive manner—each token is generated sequentially, depending on the previous one. This inherently serial process becomes a bottleneck, preventing efficient utilization of computational resources.
To address this, speculative decoding has been introduced as a means to accelerate the inference process. Speculative decoding leverages the idea of using a smaller, faster draft model to predict several tokens ahead, which are then validated by a more accurate but slower verifier model.
The Basics of Speculative Decoding
In speculative decoding, the process begins with a draft model that generates a sequence of candidate tokens (let's say 𝐾 tokens). This draft model is designed to be lightweight and fast, allowing it to produce these tokens quickly. However, because the draft model is not as accurate as the main model, these tokens need to be verified.
The verifier model, which is typically the main model (in our case, a hybrid model), checks the sequence of tokens generated by the draft model in parallel. If the tokens pass the verification check, they are accepted and included in the final output. If a token fails the verification, the process stops at that point, discards the remaining speculative tokens, and the main model takes over to generate the correct token.
This process then repeats, with the draft model generating another sequence of 𝐾 tokens from the new point, followed by verification. The key advantage here is that when speculations are correct, multiple tokens are generated and verified in a single step, significantly reducing the number of sequential operations required.
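The loop described above can be sketched with toy next-token functions standing in for the two models. This is a greedy variant (a draft token is accepted only if it equals the verifier's own greedy choice), and all names are hypothetical. A real verifier would score all draft positions in one parallel forward pass rather than in a Python loop.

```python
def speculative_decode(draft_next, verifier_next, prompt, k, max_new_tokens):
    """Greedy speculative decoding with toy next-token functions."""
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # 1) The draft model proposes k tokens sequentially.
        draft = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2) The verifier predicts the next token at every draft position.
        verified = [verifier_next(tokens + draft[:i]) for i in range(k + 1)]
        # 3) Accept the longest prefix on which draft and verifier agree.
        n_accept = 0
        while n_accept < k and draft[n_accept] == verified[n_accept]:
            n_accept += 1
        # 4) Always gain one verifier token: the correction, or a bonus if all k matched.
        tokens += draft[:n_accept] + [verified[n_accept]]
    return tokens[:target_len]


def verifier(seq):
    """Toy 'main model': counts upward modulo 10."""
    return (seq[-1] + 1) % 10


def draft(seq):
    """Toy 'draft model': agrees with the verifier except right after a 4."""
    return 0 if seq[-1] == 4 else (seq[-1] + 1) % 10


print(speculative_decode(draft, verifier, prompt=[0], k=3, max_new_tokens=12))
```

With greedy acceptance, the output is identical to what the verifier would generate on its own; speculation changes only how many verifier passes are needed.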
Speculative decoding for Mamba and Hybrid Models
While speculative decoding has been effectively used in Transformer models, applying it to Mamba and hybrid models presents unique challenges.
Attention-based models are particularly amenable to speculation, as they are slow at generation due to their sequential nature, but fast at verification due to their ability to check multiple tokens in parallel. Linear RNN models like Mamba have significantly different performance characteristics that make them less amenable to speculative decoding. Sequential decoding using recurrent-style sampling is already significantly faster than attention. Like attention, there are parallel modes for models like Mamba which are used at training.
These are efficient, but are tuned for extremely long sequences. In addition, they rely on hardware-aware optimizations, such as avoiding materializing intermediate states. These properties make them difficult to use for speculation over relatively short chains, where it is unknown when a conflict will occur.
An additional challenge arises from caching states in RNN models. The state of an attention model is represented by the key-value cache, $K_{1:t}, V_{1:t}$; whereas the state of an RNN model is simply $h_t$. To be competitive with attention, this single RNN state needs to be very large. During speculation, we need to rewind to a previous state at time step $t'$. For attention, this is simply $K_{1:t'}, V_{1:t'}$; however, for RNNs this would require caching all $h_{1:t}$, which would require a large memory overhead.
We propose a new algorithm for linear RNN speculative decoding using hardware-aware multi-step generation. The core of the approach is a generation kernel that computes:
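A signature consistent with the description that follows can be sketched as below. The name MultiStep, the argument order, and the symbols $y_{1:n}$ (the token sequence) and $A, B, C, \Delta$ (the linear RNN parameters) are assumptions made for illustration.

```latex
y_{j:k},\; h_j,\; h_k \;\leftarrow\; \mathrm{MultiStep}\big(h_i,\; y_{1:n},\; i,\; j,\; k;\; A, B, C, \Delta\big)
```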
where $i$ indexes the starting hidden state, $i \le j \le k$, and $j \ldots k$ is the range of $y$ outputs needed. The kernel is hardware-aware because it avoids materializing key terms outside of fast GPU memory. Specifically, it avoids instantiating most of $h_{1:n}$ as well as the discrete-time linear RNN parameters. This kernel targets the issues presented above. It can save a snapshot of the state $h_j$ before evaluating the draft tokens, which allows recomputing the correct state on the fly after a token is rejected. The assumption is that decoding is bottlenecked by memory and not by compute, so multiple steps of decoding can be computed with very little overhead over single-step decoding.
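The interface of such a kernel can be mimicked in plain NumPy. The sketch below is a toy with a diagonal, input-independent transition and scalar inputs. It is not the fused GPU kernel and it omits Mamba's input-dependent parameters. The function name and indexing convention (`h_t` is the state after consuming `t` inputs) are assumptions.

```python
import numpy as np


def multi_step(h_i, xs, i, j, k, A, B, C):
    """Advance a toy linear RNN from state h_i to h_k, keeping one running state.

    Returns the outputs y_j..y_k together with the snapshots h_j and h_k.
    """
    assert i <= j <= k <= len(xs)
    h, h_j, ys = h_i, None, []
    for t in range(i, k + 1):
        if t > i:
            h = A * h + B * xs[t - 1]   # consume input t-1 to reach state h_t
        if t == j:
            h_j = h                     # snapshot kept for rewinding after a rejection
        if t >= j:
            ys.append(C @ h)            # y_t = C h_t
    return np.array(ys), h_j, h


rng = np.random.default_rng(0)
n_state = 4
A = rng.uniform(0.5, 0.9, n_state)      # diagonal transition
B = rng.standard_normal(n_state)
C = rng.standard_normal(n_state)
xs = rng.standard_normal(8)             # stand-in for embedded tokens
h0 = np.zeros(n_state)

# Verify three draft inputs (positions 3..5) starting from the cached state h_3.
_, _, h3 = multi_step(h0, xs, 0, 3, 3, A, B, C)
ys, h3_snapshot, h6 = multi_step(h3, xs, 3, 3, 6, A, B, C)

# If the draft is rejected after position 5, recompute h_5 from the cached h_3.
_, _, h5 = multi_step(h3, xs, 3, 5, 5, A, B, C)
print(ys.shape, np.allclose(h3, h3_snapshot))
```

Only the starting state and the requested snapshots are kept; the intermediate states exist just long enough to produce the next one. That mirrors the idea of not materializing every hidden state.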
Figure 3 shows the algorithm. The approach maintains only one RNN hidden state in cache for verification and advances it lazily based on the success of the multi-step kernel. Since the distilled models contain transformer layers, we also extend speculative decoding to Attention/RNN hybrid architectures. In this setting, the RNN layers perform verification according to the aforementioned method, while the transformer layers simply perform parallel verification.
Experimental Results
Distillation Experiments
- Target Models: We conducted experiments using two LLM chat models: Zephyr-7B and Llama-3 Instruct 8B. Our goal is to distill transformer models into hybrid Mamba and Mamba2 models that retain varying fractions of the attention layers (50%, 25%, 12.5%, and 0%). Mamba2 is optimized for recent GPU architectures. The hybrid models, such as Zephyr-Mamba and Llama3-Mamba, represent distillations from their respective teacher models.
- Distillation Process: We employed a three-stage process:
  - Stage 1: Using UltraChat and UltraFeedback as seed prompts, the teacher model generates pseudo-labels. The student model is then trained with a loss function combining KL loss and cross-entropy loss.
  - Stage 2: Supervised fine-tuning is applied to instruction tuning datasets like OpenHermes 2.5.
  - Stage 3: For models distilled from Zephyr, distilled alignment is performed using DPO on UltraFeedback. Models distilled from Llama-3 Instruct 8B use datasets from SimPO and Zephyr.
- The entire distillation process for each hybrid model takes less than five days on 8 × 80 GB A100 GPUs.
- Baselines: We compared our models with other large-scale linear RNN models, including pure SSM architectures like TRI Mamba 7B and Falcon Mamba 7B, hybrid architectures like Nvidia Hybrid Mamba 2, and other models like Recurrent Gemma-9B Instruct.
Evaluation on the Chat Benchmark
Our models were evaluated using single-turn (AlpacaEval) and multi-turn chat benchmarks (MT-Bench).
The distilled hybrid Mamba model (50% attention) achieves an MT-Bench score similar to that of the teacher model and is slightly better on AlpacaEval in both length-controlled (LC) win rate and overall win rate. The distilled hybrid Mamba models with 25% and 12.5% attention perform slightly worse than the teacher model on MT-Bench but still surpass some larger transformers with more parameters on AlpacaEval. Notably, the distilled hybrid model performs better than Falcon Mamba, which was trained from scratch on over 5 trillion tokens.
Evaluation on the General Benchmark
Zero Shot Evaluation
Using the LM Evaluation Harness library, we evaluate our models on 10 tasks.
Both the hybrid Mamba-Llama3 and Mamba2-Llama3 models, distilled from Llama-3 Instruct 8B, outperform the open-source TRI Mamba and Nvidia Mamba models, which were trained from scratch.
Benchmark Evaluation
Few-shot evaluations are conducted on the OpenLLM Leaderboard across multiple benchmarks such as ARC-Challenge, HellaSwag, MMLU, and Winogrande. We also evaluated GSM8K and CRUX using ZeroEval.
The results indicate that the performance of our distilled hybrid models matches that of the best open-source linear RNN models on the OpenLLM Leaderboard, while outperforming the corresponding open-source instruct models in GSM8K and CRUX.
Speculative Decoding
We test our novel speculative decoding algorithm for Mamba on the Mamba 2.8B and Mamba 7B models. For Mamba 7B, we train a draft model on the RefinedWeb dataset.
Extending speculative decoding to Mamba/RNNs means that we can now mix transformer, hybrid and full Mamba models when doing speculation. As an example, the Mamba 7B model uses a transformer speculator.
# Gen. Tokens refers to the average number of tokens generated at each iteration. It includes an additional token generated directly from the Verifier logits.
We additionally run speculation with our newly distilled hybrid models. For both models, we train transformer speculators.
We test on both the Zephyr and Llama hybrid models with different configurations. For both the 50% and 25% distilled models, we achieve speedups of over 1.8x on the Zephyr-Hybrid compared to the non-speculative baseline. We also show that the 4-layer draft model we trained achieves a higher acceptance rate, but it adds some additional overhead due to the increased draft model size. For the Llama-hybrid models, the speedups are more modest since the draft model is larger due to the large embedding table of Llama 3. In subsequent work, we will focus on making these draft models smaller.
Conclusion
The research outlined in this post presents a compelling case for distilling Transformer models into hybrid linear RNNs like Mamba. This approach not only preserves the impressive generative capabilities of Transformers but also significantly enhances their efficiency, making them more suitable for deployment.
As the demand for more efficient and scalable NLP models grows, techniques like the one discussed here will be crucial in pushing the boundaries of what is possible with large language models. The hybrid approach offers a promising path forward, combining the strengths of both Transformers and linear RNNs to create models that are not only powerful but also highly efficient.