The comprehensive overview of continual learning paper states that training strategies for continual learning can be divided into five sub-categories:
- Regularisation-based approach: this approach adds constraints or penalties to the learning objective during training.
- Optimisation-based approach: this technique focuses on modifying the optimisation algorithm.
- Representation-based approach: this aims to learn a shared feature representation across different tasks, helping the model generalise better to new but related tasks.
- Replay-based approach: this involves storing some data or learned features from previous tasks and replaying them during training on new tasks to maintain performance on earlier learned tasks. In other words, mixing both the old and new datasets when training on new tasks.
- Architecture-based approach: in this approach, the network architecture is dynamically adjusted, often by growing or partitioning, delegating different parts of the network to different tasks.
Soft Masking of Parameters
The following soft-masking techniques mask and adjust the gradients of each parameter during the training process. The optimisation-based approaches coming up also use the gradients for continual learning. Remember the gradients aren’t just temporary numbers that appear and disappear during training; they’re signals that guide the evolution of the weights.
SPG
This paper proposes a technique named SPG (Soft-masking of Parameter-level Gradient flow) which aims to:
- Train the model on each task until convergence.
- After training, calculate the “importance” of each parameter for the task.
- Soft-mask parameters based on their accumulated importance, making important parameters less likely to change during the learning of new tasks.
Let’s break the approach down step by step:
1. Training the First Task
Train the model on the first task’s dataset as normal.
2. Calculate Parameter Importance for the First Task
After the training of the first task is complete, we calculate the importance of each model parameter. The intuition here is simple: we use the gradients of each parameter to compute its importance. A larger gradient implies that a small change in that parameter will result in a larger change in the loss, meaning the model’s performance could vary more significantly, hence that parameter is important.
The gradients are also normalised, because gradients in the first layer could be small, while those in the last layer could be large. If you’re calculating importance based on these raw gradient values, parameters in the last layer would seem more important because of the scale of their gradients, not necessarily because they are genuinely more crucial for the task.
Let’s translate this calculation to PyTorch-like pseudocode:
import torch

def compute_final_importance(model, loss_function, data_loader):
    # Get a single batch from the data loader
    inputs, labels = next(iter(data_loader))

    # Forward and backward pass to calculate the gradients for all parameters
    outputs = model(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()

    importances = []

    # Calculate importance based on the gradients
    for param in model.parameters():
        if param.grad is not None:  # Gradients may be None for some unused parameters
            # Normalise so layers with naturally larger gradients don't dominate
            normalized_grad = (param.grad - torch.mean(param.grad)) / torch.std(param.grad)
            importance = torch.tanh(normalized_grad)
            importances.append(importance)

    return importances  # One importance tensor per parameter tensor
3. Accumulating Importance Across Tasks
The accumulated importance of each parameter across tasks is calculated simply by taking the maximum importance value seen at any stage.
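A minimal sketch of that accumulation, assuming accumulated_importance and the per-task importances are lists of per-parameter tensors, as returned by the function above:

    # Element-wise max between the running importance and the latest task's importance
    for acc_imp, imp in zip(accumulated_importance, importances):
        acc_imp.copy_(torch.max(acc_imp, imp))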
4. Training Subsequent Tasks: Combined Loss and the Soft-Masking Mechanism
When training on new tasks, the researchers use a combined loss function consisting of two parts. One is the standard loss function used as normal on the new task and data; the second is an additional loss which involves putting the new data through the old model (the converged model checkpoint after the previous task) and summing up the logits produced. In classification networks, the logits are usually the raw, non-normalised predictions generated by the model in one of the last layers, before going through something like a softmax function. This sum of logits serves as a form of loss. The rationale is that if the summed logits are significantly affected when the model parameters change, those parameters are crucial for the performance of the previously learned task.
The gradients generated from this additional loss serve as a guide during backpropagation, nudging the shared parameters to change in a direction that is less likely to harm performance on the first task. It therefore acts as a sort of penalty term to enforce that any updates made to the model do not lead to a significant loss of information related to previous tasks.
Train the model on the next task. Use a standard training loop, but modify the gradients during backpropagation based on their accumulated importance. This is the soft-masking mechanism:
import torch
import torch.nn as nn

accumulated_importance = ...  # calculated at the end of each task

for epoch in range(num_epochs):
    for x, y in train_loader:
        # Forward pass: calculate the loss for the current task using the proper loss function
        logits = new_model(x)
        loss_current_task = nn.CrossEntropyLoss()(logits, y)

        # Forward pass: calculate the additional losses for previous tasks (CHI mechanism)
        loss_previous_tasks = 0
        for prev_task_id in range(task_id):
            logits_prev = old_model(x, prev_task_id)
            loss_previous_tasks += logits_prev.sum()

        # Combine the losses
        combined_loss = loss_current_task + loss_previous_tasks

        # Backward pass
        optimizer.zero_grad()
        combined_loss.backward()

        # Update the accumulated importance (element-wise max of the absolute gradients)
        for param, acc_imp in zip(new_model.parameters(), accumulated_importance):
            acc_imp.copy_(torch.max(acc_imp, torch.abs(param.grad)))

        # Soft-mask the gradients before taking an optimisation step
        for param, imp in zip(new_model.parameters(), accumulated_importance):
            param.grad *= (1 - imp)

        optimizer.step()
5. Soft-Masking Special Cases
- Feature Extractor: Gradients of parameters in the shared feature extractor are modified based on their specific accumulated importance.
- Classification Head: For the classification head, gradients are modified based on the average importance of the feature extractor.
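A hedged sketch of what those two cases might look like in code. Here feature_extractor and classification_head are hypothetical sub-modules I'm assuming for illustration, and accumulated_importance holds the feature extractor's per-parameter importance:

    # Average importance across the feature extractor's parameters
    avg_importance = torch.mean(
        torch.cat([imp.view(-1) for imp in accumulated_importance])
    )

    # Feature extractor: mask each parameter's gradient by its own accumulated importance
    for param, imp in zip(feature_extractor.parameters(), accumulated_importance):
        param.grad *= (1 - imp)

    # Classification head: mask by the average importance of the feature extractor
    for param in classification_head.parameters():
        param.grad *= (1 - avg_importance)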
Applying this to LLMs
Bear in mind, this paper does not experiment with a language model, but I assume that in a language model you could think of the transformer layers as analogous to the “feature extractor”, and the final classification layer (which predicts the next word or token in the sequence) as the “classification head.”
Next we’ll go into a paper which applies similar soft-masking to the pre-training stage in language modelling.
This paper introduces a technique called DAS (Continual DA-pre-training of LMs with Soft-masking) for continual learning in the pre-training stage of a large language model. It applies a soft-masking technique similar to the one just discussed, along with a couple of other techniques, in an attempt to continue pre-training an LLM without running into catastrophic forgetting.
Let’s break it down step by step:
Pre-train the LLM as normal.
Prepare New Domain Data:
A new dataset from a different domain is prepared.
Calculating the importance of each neuron
SPG used gradients to determine the importance of each parameter, and then applied the calculated importance value to mask the gradient adjustments of parameters during training. This paper tries to determine the importance of each unit/neuron, rather than parameter, and then uses this in the same way by masking the gradient during training.
This paper uses two different methods to calculate the importance of neurons, depending on the task at hand. The first is a gradient-based importance detection method (originally outlined in this paper); the second is a custom “proxy loss function”.
The first method is not used at all when continually learning the first new domain. Why? It needs data from the original training dataset to work, and the authors state that users “don’t have access to the massive original pre-training dataset”, which is a fair assumption.
They propose a Proxy Loss Function:
I found this term confusing at first. It’s called this because the original gradient-based importance detection method is itself defined as a loss function: you run the network’s outputs through it to get the gradients of each unit, which are then used to derive importance, just as in the SPG technique.
According to the paper, the importance is calculated for each “unit” in the network, where a unit could be a neuron or an attention head.
Proxy loss function (“Proxy KL-divergence loss”):
- Take a subset of the new domain we’re wanting to train on and feed it twice through the model to get two different representations. These representations will differ a bit due to the existing dropout masks in the Transformer architecture.
- Compute the KL-divergence between these two representations.
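A minimal sketch of what this could look like in PyTorch. The exact form of the KL term isn't spelled out above, so treat the symmetric, softmax-based version below (and the model/batch placeholders) as my assumptions:

    import torch.nn.functional as F

    def proxy_kl_loss(model, batch):
        model.train()          # keep dropout active so the two passes differ
        rep_1 = model(batch)   # first forward pass
        rep_2 = model(batch)   # second forward pass, different dropout mask
        p = F.log_softmax(rep_1, dim=-1)
        q = F.log_softmax(rep_2, dim=-1)
        # Symmetric KL-divergence between the two representations (an assumption on my part)
        return F.kl_div(p, q, log_target=True, reduction="batchmean") + \
               F.kl_div(q, p, log_target=True, reduction="batchmean")

Backpropagating this loss gives gradients for each unit (neuron or attention head), whose magnitudes are then used as the importance values, just like in SPG.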
Modified Backpropagation Flow with Proxy and Combined Loss
- Forward Pass: Data goes through a forward pass in the neural network.
- Backpropagation:
Apply Proxy Loss for Gradient Adjustment: The proxy loss function’s unit-level importance is used to soft-mask the original gradients (see the sketch after this list). This is expressed as:
adjusted_grad *= (1 − unit_level_importance)
Calculate Combined Loss (MLM + Contrastive Loss): Compute the combined loss using both MLM and contrastive loss.
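For a single linear layer, the gradient soft-masking in the backpropagation step above might look roughly like this, treating each output neuron as a “unit”. The names and shapes here are my assumptions, not from the paper:

    # unit_importance: one value per output neuron of this layer, shape [out_features]
    layer.weight.grad *= (1 - unit_importance).unsqueeze(1)  # mask each row of the weight gradient
    layer.bias.grad *= (1 - unit_importance)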
Further Pre-training on More Domains
- Direct Importance Calculation: For each new domain, the importance of each unit can now be directly calculated using the data from the new domain via the gradient-based method outlined in equation 3, eliminating the need for the proxy loss function, which is only used once, after the initial pre-training.
- The importance of neurons is updated incrementally as each new task is learned. This update is done using element-wise max. “Element-wise maximum (EMax) operation” refers to comparing two vectors element by element, and taking the maximum value for each corresponding element to create a new vector. E.g.: if you have two vectors A and B of the same length, the element-wise maximum will result in a new vector C where each element C(i) is the maximum between A(i) and B(i).
Next, we’ll cover the two optimisation-based techniques outlined in section 3.1 of the comprehensive survey paper.
Gradient Direction Preservation
The paper talks about manipulating the gradient-based optimisation process to make the gradient directions of new training samples close to those from old training samples. The formula
⟨ ∇θ Lₖ(θ; Dₖ), ∇θ Lₖ(θ; Mₜ) ⟩ ≥ 0
enforces that learning the new task should not increase the loss for the old tasks. Essentially, the gradients of the new task and the old tasks are encouraged to align.
Breaking down the formula: the dot product of the gradient of the loss on the new task (∇θ Lₖ(θ; Dₖ)) and the gradient of the loss on the old task (∇θ Lₖ(θ; Mₜ)) should be non-negative. In this context, a positive dot product implies that the gradients for the old task and the new task are generally pointing in the same direction, with the angle between the two vectors being less than or equal to 90 degrees.
Forward/Backward Passes:
Forward Pass:
You would run your input data Dₖ for the new task and Mₜ for the old task through the same model to calculate the loss for each.
Backward Pass:
- Compute the gradients of the loss with respect to the network parameters for both the old and new task.
- Alignment Check: Compute the dot product of the two gradients. You’d then use this information to modify the gradients for the new task in such a way that the dot product is non-negative.
- Update Weights: Update the model parameters using these “aligned” gradients.
import torch

# Forward pass for the new task
output_k = model(D_k)
loss_k = criterion(output_k, y_k)

# Forward pass for the old task
output_t = model(M_t)
loss_t = criterion(output_t, y_t)

# Compute gradients for both tasks
loss_k.backward(retain_graph=True)  # Compute gradients for the new task but keep the computation graph
grad_k = torch.cat([p.grad.view(-1) for p in model.parameters()])

optimizer.zero_grad()
loss_t.backward()  # Compute gradients for the old task
grad_t = torch.cat([p.grad.view(-1) for p in model.parameters()])

# Compute the dot product and modify the gradients if they don't align
dot_product = torch.dot(grad_k, grad_t)
if dot_product < 0:
    # I'm not sure how you modify the gradients here if they don't align,
    # I'm not sure the paper specifies it
    pass

# Use the (possibly modified) gradient to update the model parameters
index = 0
for p in model.parameters():
    num_params = p.numel()
    # Update using the flattened new-task gradient
    p.grad = grad_k[index: index + num_params].view(p.shape)
    index += num_params

optimizer.step()
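The paper I’m summarising doesn’t spell out how the gradient should be modified when the dot product is negative, but one common choice in this family of methods (A-GEM, for example) is to project the new-task gradient so that the conflicting component is removed:

    if dot_product < 0:
        # Remove the component of grad_k that conflicts with grad_t (A-GEM-style projection)
        grad_k = grad_k - (dot_product / torch.dot(grad_t, grad_t)) * grad_t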
Gradient Direction Preservation without needing old training samples
The text also highlights that gradient projection can be performed even without storing old samples. NCL (Natural Continual Learning, paper link) is the technique summarised here. Note that this can be categorised as both a regularisation-based and an optimisation-based approach.
Training process step by step:
Forward Pass:
You would run your new data through the network and calculate the loss as usual.
Backward Pass:
Objective: The aim is to minimise the task-specific loss ℓk(θ) while adhering to a distance constraint d(θ,θ+δ)≤r.
Algorithm step by step:
- As normal, compute the gradient of the loss with respect to the model parameters ∇θℓk(θ).
- The δ is calculated using the update rule. This gives you the “suggested” changes to the model parameters θ based on the new task’s requirements.
- Then, you plug this δ into the distance constraint formula: d(θ, θ+δ) = √(δ⊤ Λ_(k-1) δ). The constraint acts like a boundary around the current parameters θ, defined by the distance metric d(θ, θ+δ) and the radius r. I struggled to see why they called it a “radius”, and not just “constraint number” or something. I think it’s because the researchers are visualising the gradients and training process in a high-dimensional space. When you apply a constraint based on the distance metric, you’re essentially defining a “sphere” around your current parameter values in that high-dimensional space. The “radius” r of this sphere sets a limit on how much the parameters can move while learning a new task.
- If the proposed δ would move θ too far according to this distance metric, i.e., beyond this boundary, you scale it down so that it stays within the allowable region defined by the radius r.
Let’s look at each bit more in-depth:
Update Rule: The update rule provides a direction in which θ should move.
Breaking it down:
- ∇θ ℓk(θ) represents the gradients for all parameters (θ) calculated by the loss function.
- Parameter importance term (Λ_(k-1)⁻¹): Λ_(k-1) is a precision matrix, which is yet another way to calculate the importance of parameters in the network (more details below).
- Regularisation Term (θ − μ_(k-1)): This term pulls the updated parameters closer to the optimal parameters μ_(k-1) from the previous task. Like the earlier techniques, it acts as a regulariser to avoid deviation from what was already learned.
- Learning Rate (λ)
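Putting the pieces of this breakdown together, the update rule presumably looks something like the following (this is my reconstruction from the terms above and from the code comment further down, so treat the exact signs and the placement of λ as approximate):
δ ≈ λ ( Λ_(k-1)⁻¹ ∇θ ℓ_k(θ) − (θ − μ_(k-1)) )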
Distance Constraint: Before applying this update, you’d usually check whether this change δ would violate the distance constraint d(θ,θ+δ)≤r. If it does, you’d typically scale down δ so that it satisfies the constraint.
Precision matrix explanation: earlier, in the soft-masking methods, we saw importance calculated from the outputs of neurons or from their gradients. In this method, a precision matrix is used. This is a bit more complex, so I’ll attempt to explain it:
We first calculate the covariance matrix of the network’s gradients. The columns of the gradient matrix G correspond to the parameters (weights and biases) of the model, and each row is a gradient vector with respect to all of those parameters.
So, if you have a neural network with P parameters (this includes all the weights and biases from all layers), then each gradient vector will have P elements, one for each parameter. G is therefore a matrix of shape N × P, where N is the number of batches and each row is the gradient vector averaged over the training examples in that batch.
When you calculate the covariance matrix Σ from G, the resulting matrix will have dimensions P × P. The diagonal entries Σ_(ii) will indicate the variance of the gradient with respect to the i-th parameter, and the off-diagonal entries Σ_(ij) will indicate the covariance between the gradients with respect to the i-th and j-th parameters. This gives you an idea of how these parameters interact or co-vary during the training process. The inverse of this matrix is the precision matrix, which is what we use to determine importance.
Why the precision matrix over the covariance matrix? While the covariance matrix Σ does capture how parameters interact with each other during training, it doesn’t specifically indicate how crucial each parameter is to the task at hand when all other parameters are considered. In contrast, the precision matrix allows us to assess the conditional independence (this is a concept in probability theory, look it up) of parameters. Large values in the precision matrix indicate that knowing one parameter is highly informative about another, given all the other parameters. I’m not going to go into examples of how this works so get ChatGPT to generate some examples using a very small neural network to see how the values can be interpreted.
The previous methods we saw for calculating importance focus on individual neurons or parameters, ignoring the relationships between them. The precision matrix, on the other hand, can capture these relationships. Like everything in deep learning, whether this is a better way to calculate importance is an empirical question and could differ depending on the task and scale of the network.
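A rough sketch of how such a precision matrix could be built, assuming you have already stacked per-batch gradient vectors into a matrix G of shape N × P. For real networks P is huge, so in practice you’d use block-diagonal or otherwise factorised approximations; the small ridge term below is just there to keep the inverse numerically stable:

    import torch

    # G: N x P matrix, one (batch-averaged) gradient vector per row
    G_centered = G - G.mean(dim=0, keepdim=True)
    covariance = G_centered.T @ G_centered / (G.shape[0] - 1)  # P x P covariance matrix
    precision_matrix = torch.inverse(covariance + 1e-6 * torch.eye(G.shape[1]))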
Algorithm step by step in PyTorch:
import torch

# Constraint radius
radius = 0.1

for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()

        # Forward pass
        output = model(data)
        loss = loss_function(output, target)

        # Backward pass to get gradients for the parameters
        loss.backward()
        model_grad = torch.cat([p.grad.data.view(-1) for p in model.parameters()])
        current_params = torch.cat([p.data.view(-1) for p in model.parameters()])

        # Compute δ using the NCL method
        # δ = Λ^(-1) * grad - (θ - µ), with Λ the precision matrix and µ the previous task's parameters
        delta = torch.matmul(torch.inverse(precision_matrix), model_grad) - (current_params - params_for_prev_task)

        # Check the constraint and scale δ down if it falls outside the radius
        if torch.norm(delta) > radius:
            delta = radius * delta / torch.norm(delta)

        # Update model parameters (θ) using δ
        idx = 0
        for p in model.parameters():
            length = p.data.numel()
            p.data += delta[idx: idx + length].view(p.data.shape)
            idx += length

        # Update Λ and µ for the next task, probably going to be task-specific and non-trivial