Large language models are composed of billions of parameters (weights). For each token it generates, the model has to perform computationally expensive calculations over all of these parameters.
Large language models take a sentence (a sequence of tokens) as input and output a probability distribution over the next token.
Therefore, decoding n tokens (generating n words of output) usually requires running the model n times. At each iteration, the new token is appended to the input sequence and passed back to the model. This can be costly.
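To make that cost concrete, here is a minimal sketch of the autoregressive loop. `ToyModel` is a hypothetical stand-in for a real LLM (it just returns made-up scores), and `softmax` turns its raw scores into a next-token distribution; the point is that every new token costs one full forward pass.

```python
import math

def softmax(logits):
    # Convert raw scores into a probability distribution over the vocabulary
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

class ToyModel:
    """Hypothetical stand-in for an LLM: maps a token sequence to next-token logits."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.calls = 0  # number of forward passes performed

    def __call__(self, token_ids):
        self.calls += 1
        # Arbitrary scores that favour (last token + 1) mod vocab_size
        nxt = (token_ids[-1] + 1) % self.vocab_size
        return [2.0 if t == nxt else 0.0 for t in range(self.vocab_size)]

def greedy_decode(model, prompt_ids, n_new):
    ids = list(prompt_ids)
    for _ in range(n_new):
        probs = softmax(model(ids))  # one forward pass per generated token
        # Append the most likely token; the whole sequence is fed back next iteration
        ids.append(max(range(len(probs)), key=probs.__getitem__))
    return ids

model = ToyModel(vocab_size=5)
out = greedy_decode(model, [0], n_new=4)
print(out, model.calls)  # [0, 1, 2, 3, 4] 4
```

Generating 4 tokens took 4 calls; with a real model each call touches every parameter.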
Furthermore, the decoding strategy influences the quality of the generated text. Generating tokens by always taking the token with the highest probability in the output distribution (greedy decoding) can result in repetitive text, while sampling randomly from the distribution can cause the text to drift in unintended directions.
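Both failure modes can be seen on a tiny scale. The sketch below uses a hypothetical hand-written bigram table (`BIGRAMS`) in place of a language model: greedy decoding falls into a loop, while sampling with `random.choices` can wander into lower-probability branches.

```python
import random

# Hypothetical bigram "model": next-word probabilities given the previous word
BIGRAMS = {
    "the": {"cat": 0.6, "dog": 0.4},
    "cat": {"sat": 0.7, "ran": 0.3},
    "dog": {"ran": 1.0},
    "sat": {"on": 1.0},
    "ran": {"to": 1.0},
    "on": {"the": 1.0},
    "to": {"the": 1.0},
}

def greedy(start, steps):
    words = [start]
    for _ in range(steps):
        dist = BIGRAMS[words[-1]]
        words.append(max(dist, key=dist.get))  # always the single most likely word
    return words

def sample(start, steps, rng):
    words = [start]
    for _ in range(steps):
        dist = BIGRAMS[words[-1]]
        # Draw the next word in proportion to its probability
        words.append(rng.choices(list(dist), weights=list(dist.values()))[0])
    return words

print(" ".join(greedy("the", 8)))  # the cat sat on the cat sat on the
print(" ".join(sample("the", 8, random.Random(0))))
```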
Therefore, a robust decoding strategy is required to ensure both:
- High quality results
- Fast inference time
Both requirements can be addressed by combining a large and a small language model, as long as the two models are similar (e.g., the same architecture and tokenizer but different sizes):
- Expert/large model: the main LM, with a higher number of parameters (e.g., OPT-13B)
- Amateur/small model: a smaller version of the main LM, with fewer parameters (e.g., OPT-125M)
Speculative decoding and contrastive decoding both take advantage of a large and a small LLM for reliable and efficient text generation.
Contrastive decoding is a strategy that exploits the fact that the failures of large LLMs (such as repetition and incoherence) are even more pronounced in small LLMs. This strategy therefore favors the tokens with the largest probability difference between the large and the small model.
For a single prediction, contrastive decoding produces two next-token probability distributions:
- q = probabilities from the amateur model
- p = probabilities from the expert model
The next token is chosen using the following criteria:
- Discard every token that the expert model does not consider probable enough, i.e. every x with p(x) < alpha * max(p).
- From the remaining tokens, select the one with the largest difference between the log probabilities of the large and the small model: argmax over x of log p(x) - log q(x).
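A worked example helps show why both steps matter. The sketch below applies the two criteria to hypothetical four-token distributions stored as plain dicts (`contrastive_pick` is an illustrative helper, not part of any library).

```python
import math

def contrastive_pick(p, q, alpha=0.1):
    """Pick the next token by the contrastive criterion.

    p, q: dicts mapping token -> probability for the expert and amateur models.
    """
    threshold = alpha * max(p.values())
    # Keep only tokens the expert itself finds plausible: p(x) >= alpha * max(p)
    plausible = [x for x in p if p[x] >= threshold]
    # Among them, maximise log p(x) - log q(x)
    return max(plausible, key=lambda x: math.log(p[x]) - math.log(q[x]))

# Hypothetical next-token distributions
expert = {"the": 0.51, "a": 0.30, "Paris": 0.15, "zebra": 0.04}
amateur = {"the": 0.60, "a": 0.35, "Paris": 0.049, "zebra": 0.001}

print(contrastive_pick(expert, amateur))  # Paris
```

Greedy decoding on the expert alone would pick "the". Without the plausibility filter (alpha=0), "zebra" would win because its tiny amateur probability inflates the log ratio, which is exactly what the first criterion prevents.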
Contrastive decoding implementation
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load models and tokenizer (GPT-2 small as amateur, GPT-2 large as expert; same vocabulary)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
amateur_lm = AutoModelForCausalLM.from_pretrained('gpt2')
expert_lm = AutoModelForCausalLM.from_pretrained('gpt2-large')

def contrastive_decoding(prompt, max_length=50, alpha=0.1):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        while input_ids.shape[1] < max_length:
            # Amateur model: log-probabilities for the next token
            amateur_logits = amateur_lm(input_ids, return_dict=True).logits[:, -1, :]
            log_probs_amateur = torch.log_softmax(amateur_logits, dim=-1)
            # Expert model: probabilities and log-probabilities for the next token
            expert_logits = expert_lm(input_ids, return_dict=True).logits[:, -1, :]
            expert_probs = torch.softmax(expert_logits, dim=-1)
            log_probs_exp = torch.log_softmax(expert_logits, dim=-1)
            log_probs_diff = log_probs_exp - log_probs_amateur
            # Mask tokens whose expert probability is below alpha * (max expert probability)
            V_head = expert_probs < alpha * expert_probs.max(dim=-1, keepdim=True).values
            # Select the token with the largest log-probability difference, ignoring masked tokens
            token = log_probs_diff.masked_fill(V_head, -torch.inf).argmax(dim=-1, keepdim=True)
            # Append the token (shape (1, 1)) to the sequence
            input_ids = torch.cat((input_ids, token), dim=-1)
    return tokenizer.batch_decode(input_ids)

prompt = "Large Language Models are"
generated_text = contrastive_decoding(prompt, max_length=25)
print(generated_text)
```
Speculative decoding is based on the principle that the smaller model's samples should follow the same distribution as the larger model's. This strategy therefore aims to accept as many predictions from the smaller model as possible, as long as they are consistent with the larger model's distribution.
The smaller model generates n tokens in sequence as candidate guesses. All n resulting sequences are then fed into the larger expert model as a single batch, which is faster than generating the n tokens sequentially with the expert.
This yields a cache for each model, holding n probability distributions.
- q = probabilities from the amateur model
- p = probabilities from the expert model
Tokens sampled from the amateur model are then accepted or rejected according to the following rules:
- If the token is at least as probable under the expert distribution as under the amateur distribution, i.e. p(x) >= q(x), accept it.
- If the token is less probable under the expert distribution than under the amateur distribution, i.e. p(x) < q(x), reject it with probability 1 - p(x) / q(x).
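The two rules collapse into a single test: accept with probability min(1, p(x) / q(x)). A minimal sketch, where `accept_guess` is an illustrative helper and the probabilities are hypothetical numbers rather than model outputs:

```python
import random

def accept_guess(p_x, q_x, rng):
    """Speculative acceptance test for one amateur-proposed token x.

    Accept outright if p(x) >= q(x); otherwise accept with probability p(x)/q(x),
    i.e. reject with probability 1 - p(x)/q(x).
    """
    return rng.random() < min(1.0, p_x / q_x)

rng = random.Random(42)
trials = 100_000
# p(x) = 0.2, q(x) = 0.5: the empirical acceptance rate should be close to 0.2 / 0.5 = 0.4
rate = sum(accept_guess(0.2, 0.5, rng) for _ in range(trials)) / trials
print(rate)
```

Because `rng.random()` returns a value in [0, 1), a token with p(x) >= q(x) is always accepted.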
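Guesses are verified in order. If a token is rejected, the guesses after it are discarded and the next token is instead sampled from an adjusted distribution, norm(max(0, p(x) - q(x))): the positive part of the difference between the expert and amateur distributions, renormalized to sum to 1. If all n guesses are accepted, one extra token is sampled directly from the expert distribution. The amateur and expert models then reset their caches and build n new guesses and distributions p and q.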
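Why this particular adjusted distribution? Accepting with probability min(1, p/q) and otherwise resampling from norm(max(0, p - q)) makes each emitted token follow the expert distribution p exactly. The sketch below checks this by computing the output distribution of one step in closed form for hypothetical three-token distributions (helper names are illustrative).

```python
def residual_distribution(p, q):
    """Adjusted distribution used after a rejection: norm(max(0, p - q))."""
    clipped = {x: max(0.0, p[x] - q[x]) for x in p}
    total = sum(clipped.values())
    if total == 0.0:
        # p == q: no token can be rejected, so this distribution is never used
        return dict(p)
    return {x: v / total for x, v in clipped.items()}

def output_distribution(p, q):
    """Exact probability that one speculative step emits each token x."""
    # Probability that the amateur proposes x and it is accepted
    accept = {x: q[x] * min(1.0, p[x] / q[x]) for x in p}
    reject_prob = 1.0 - sum(accept.values())
    resid = residual_distribution(p, q)
    # Either the proposal is accepted, or it is rejected and x is drawn from the residual
    return {x: accept[x] + reject_prob * resid[x] for x in p}

# Hypothetical expert (p) and amateur (q) distributions over a 3-token vocabulary
p = {"a": 0.5, "b": 0.3, "c": 0.2}
q = {"a": 0.2, "b": 0.3, "c": 0.5}
print(output_distribution(p, q))
```

Here the amateur over-proposes "c", so 30% of the probability mass is rejected, and all of it is redirected to "a", the only token the amateur under-proposes; the result matches p up to floating-point rounding.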
Speculative decoding implementation
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load models and tokenizer
tokenizer = AutoTokenizer.from_pretrained('gpt2')
amateur_lm = AutoModelForCausalLM.from_pretrained('gpt2')
expert_lm = AutoModelForCausalLM.from_pretrained('gpt2-large')

# Sample the next token from a probability distribution
def sample_from_distribution(probs):
    return torch.multinomial(probs, 1)

def generate_cache(input_ids, n_tokens):
    # The amateur proposes n_tokens one at a time; store its distribution at each step
    amateur_probs_per_step = []
    generated_tokens = []
    prompt_len = input_ids.shape[1]
    with torch.no_grad():
        for _ in range(n_tokens):
            amateur_logits = amateur_lm(input_ids, return_dict=True).logits[:, -1, :]
            amateur_probs = torch.softmax(amateur_logits, dim=-1)
            amateur_probs_per_step.append(amateur_probs[0])
            next_token = sample_from_distribution(amateur_probs)  # shape (1, 1)
            generated_tokens.append(next_token)
            input_ids = torch.cat((input_ids, next_token), dim=-1)
        # A causal LM scores every prefix in one forward pass, which is equivalent to
        # feeding the n prefixes as a batch. Position prompt_len - 1 + i predicts guess i;
        # the final position gives the distribution for a bonus token.
        expert_logits = expert_lm(input_ids, return_dict=True).logits[0, prompt_len - 1:, :]
        expert_probs = torch.softmax(expert_logits, dim=-1)  # shape (n_tokens + 1, vocab)
    return amateur_probs_per_step, expert_probs, torch.cat(generated_tokens, dim=-1)

def speculative_decoding(prompt, n_tokens=5, max_length=50):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    while input_ids.shape[1] < max_length:
        amateur_probs, expert_probs, generated_ids = generate_cache(input_ids, n_tokens)
        accepted = 0
        for n in range(n_tokens):
            token = generated_ids[0, n].item()
            p_x = expert_probs[n, token].item()
            q_x = amateur_probs[n][token].item()
            # Accept with probability min(1, p(x) / q(x)); otherwise stop at this guess
            if torch.rand(1).item() < min(1.0, p_x / q_x):
                accepted += 1
            else:
                break
        input_ids = torch.cat((input_ids, generated_ids[:, :accepted]), dim=-1)
        if input_ids.shape[1] >= max_length:
            break
        if accepted < n_tokens:
            # Rejected: resample from the adjusted distribution norm(max(0, p - q))
            residual = torch.clamp(expert_probs[accepted] - amateur_probs[accepted], min=0)
            next_token = sample_from_distribution(residual / residual.sum())
        else:
            # All guesses accepted: sample a bonus token from the expert's last distribution
            next_token = sample_from_distribution(expert_probs[-1])
        input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=-1)
    return tokenizer.batch_decode(input_ids[:, :max_length])

# Example usage
prompt = "Large Language models are"
generated_text = speculative_decoding(prompt, n_tokens=3, max_length=25)
print(generated_text)
```
Evaluation
We can evaluate both decoding approaches by comparing them with a naive decoding method, in which the next token is sampled at random from the expert model's probability distribution.
```python
def sequential_sampling(prompt, max_length=50):
    """
    Perform sequential (naive) sampling with the expert model.
    """
    # Tokenize the input prompt
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        while input_ids.shape[1] < max_length:
            # Sample from the model output logits for the last token
            outputs = expert_lm(input_ids, return_dict=True)
            logits = outputs.logits[:, -1, :]
            probabilities = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1)
            input_ids = torch.cat((input_ids, next_token), dim=-1)
    return tokenizer.batch_decode(input_ids)
```
To evaluate contrastive decoding, we can use the following metrics of lexical diversity:
- N-gram entropy: measures the unpredictability or diversity of the n-grams in the generated text. High entropy indicates more diverse text, while low entropy suggests repetition or predictability.
- Distinct-n: measures the proportion of unique n-grams in the generated text. Higher distinct-n values indicate greater lexical diversity.
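Written out, for a text whose whitespace-separated tokens form N n-grams g_1, ..., g_N, with c(g) the number of occurrences of n-gram g, the two metrics are as follows (a worked illustration on a made-up four-word text, not a result from the experiments):

$$
H_n = -\sum_{g} \frac{c(g)}{N}\,\log_2 \frac{c(g)}{N},
\qquad
\text{distinct-}n = \frac{\left|\{g_1, \dots, g_N\}\right|}{N}
$$

For "the cat the cat": the unigrams give N = 4 with c(the) = c(cat) = 2, so H_1 = 1 bit and distinct-1 = 2/4 = 0.5. The bigrams give N = 3 with c(the cat) = 2 and c(cat the) = 1, so

$$
H_2 = -\left(\tfrac{2}{3}\log_2\tfrac{2}{3} + \tfrac{1}{3}\log_2\tfrac{1}{3}\right) \approx 0.918,
\qquad
\text{distinct-}2 = \tfrac{2}{3}.
$$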
```python
from collections import Counter
import math

def ngram_entropy(text, n):
    """
    Compute n-gram entropy (in bits) for a given text.
    """
    # Tokenize the text on whitespace
    tokens = text.split()
    if len(tokens) < n:
        return 0.0  # Not enough tokens to form n-grams
    # Create n-grams
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    # Count frequencies of n-grams
    ngram_counts = Counter(ngrams)
    total_ngrams = sum(ngram_counts.values())
    # Compute the Shannon entropy of the n-gram distribution
    entropy = -sum((count / total_ngrams) * math.log2(count / total_ngrams)
                   for count in ngram_counts.values())
    return entropy

def distinct_n(text, n):
    """
    Compute the distinct-n metric for a given text.
    """
    # Tokenize the text on whitespace
    tokens = text.split()
    if len(tokens) < n:
        return 0.0  # Not enough tokens to form n-grams
    # Create n-grams (a list, so len() works)
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    # Ratio of unique n-grams to total n-grams
    unique_ngrams = set(ngrams)
    total_ngrams = len(ngrams)
    return len(unique_ngrams) / total_ngrams if total_ngrams > 0 else 0.0

prompts = (
    "Large Language models are",
    "Barack Obama was",
    "Decoding strategy is important because",
    "A good recipe for Halloween is",
    "Stanford is known for",
)

# Initialize accumulators for metrics (lists, so they can be updated in place)
naive_entropy_totals = [0.0, 0.0, 0.0]  # For n=1, 2, 3
naive_distinct_totals = [0.0, 0.0]  # For n=1, 2
contrastive_entropy_totals = [0.0, 0.0, 0.0]
contrastive_distinct_totals = [0.0, 0.0]

for prompt in prompts:
    naive_generated_text = sequential_sampling(prompt, max_length=50)[0]
    for n in range(1, 4):
        naive_entropy_totals[n - 1] += ngram_entropy(naive_generated_text, n)
    for n in range(1, 3):
        naive_distinct_totals[n - 1] += distinct_n(naive_generated_text, n)
    contrastive_generated_text = contrastive_decoding(prompt, max_length=50)[0]
    for n in range(1, 4):
        contrastive_entropy_totals[n - 1] += ngram_entropy(contrastive_generated_text, n)
    for n in range(1, 3):
        contrastive_distinct_totals[n - 1] += distinct_n(contrastive_generated_text, n)

# Compute averages
naive_entropy_averages = [total / len(prompts) for total in naive_entropy_totals]
naive_distinct_averages = [total / len(prompts) for total in naive_distinct_totals]
contrastive_entropy_averages = [total / len(prompts) for total in contrastive_entropy_totals]
contrastive_distinct_averages = [total / len(prompts) for total in contrastive_distinct_totals]

# Display results
print("Naive Sampling:")
for n in range(1, 4):
    print(f"Average Entropy (n={n}): {naive_entropy_averages[n - 1]}")
for n in range(1, 3):
    print(f"Average Distinct-{n}: {naive_distinct_averages[n - 1]}")
print("\nContrastive Decoding:")
for n in range(1, 4):
    print(f"Average Entropy (n={n}): {contrastive_entropy_averages[n - 1]}")
for n in range(1, 3):
    print(f"Average Distinct-{n}: {contrastive_distinct_averages[n - 1]}")
```
The results below show that contrastive decoding scores higher than naive sampling on all of these metrics.
```
Naive Sampling:
Average Entropy (n=1): 4.990499826537679
Average Entropy (n=2): 5.174765791328267
Average Entropy (n=3): 5.14373124004409
Average Distinct-1: 0.8949694135740648
Average Distinct-2: 0.9951219512195122

Contrastive Decoding:
Average Entropy (n=1): 5.182773920916605
Average Entropy (n=2): 5.3495681172235665
Average Entropy (n=3): 5.313720275712986
Average Distinct-1: 0.9028425204970866
Average Distinct-2: 1.0
```
To evaluate speculative decoding, we can measure the average execution time over a set of prompts for different values of n.
```python
import time
import matplotlib.pyplot as plt

# Parameters
n_tokens = range(1, 11)
speculative_decoding_times = []
naive_decoding_times = []
prompts = (
    "Large Language models are",
    "Barack Obama was",
    "Decoding strategy is important because",
    "A good recipe for Halloween is",
    "Stanford is known for",
)

# Loop through n_tokens values
for n in n_tokens:
    avg_time_naive, avg_time_speculative = 0.0, 0.0
    for prompt in prompts:
        start_time = time.perf_counter()
        _ = sequential_sampling(prompt, max_length=25)
        avg_time_naive += time.perf_counter() - start_time
        start_time = time.perf_counter()
        _ = speculative_decoding(prompt, n_tokens=n, max_length=25)
        avg_time_speculative += time.perf_counter() - start_time
    naive_decoding_times.append(avg_time_naive / len(prompts))
    speculative_decoding_times.append(avg_time_speculative / len(prompts))

# Naive decoding does not depend on n, so average it over all runs
avg_time_naive = sum(naive_decoding_times) / len(naive_decoding_times)

# Plotting the results
plt.figure(figsize=(8, 6))
plt.bar(n_tokens, speculative_decoding_times, width=0.6, label='Speculative Decoding Time', alpha=0.7)
plt.axhline(y=avg_time_naive, color='red', linestyle='--', label='Naive Decoding Time')
# Labels and title
plt.xlabel('n_tokens', fontsize=12)
plt.ylabel('Average Time (s)', fontsize=12)
plt.title('Speculative Decoding Runtime vs n_tokens', fontsize=14)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
# Save before showing: once the window opened by show() is closed, the figure is gone
plt.savefig("plot.png")
plt.show()
```
We can see that the average execution time of naive decoding is much longer than that of speculative decoding for the values of n tested.
Combining large and small language models for decoding strikes a balance between quality and efficiency. Although these approaches add complexity to system design and resource management, their benefits extend to applications such as conversational AI, real-time translation, and content creation.
These approaches also require careful consideration of implementation constraints. For example, the extra memory and compute needed to run two models can limit their feasibility on edge devices, although this can be mitigated with techniques such as model quantization.
Unless otherwise noted, all images are the author's.