Inspired by Andrej Karpathy's recent YouTube video "Let's reproduce GPT-2 (124M)", I would like to reconstruct it in JAX with most of the training optimizations. JAX is designed for highly efficient computation. It is interesting to compare PyTorch and its recent training optimizations with JAX and its related libraries, such as Flax (a neural network layer API for JAX) and Optax (a gradient processing and optimization library for JAX). We will quickly learn what JAX is and then reconstruct the GPT in JAX. At the end, we will compare the tokens/sec of multi-GPU training between PyTorch and JAX.
What is Jax?
According to the documentation, JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. I would like to introduce JAX through its name. Some read it as Just Another XLA (Accelerated Linear Algebra). I prefer to read it as J(it) A(utograd) X(LA), because that highlights where its efficiency comes from.
J: Just-in-time (JIT) compilation. When you run a jitted Python function, JAX traces it into a small set of primitive operations called a jaxpr. The jaxpr is then lowered to StableHLO and handed to XLA, which compiles it into an executable optimized for the target device (CPU, GPU, or TPU).
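The sketch below makes the tracing step visible. It prints the jaxpr for a small element-wise function and then runs the jitted version. The function `leaky_exp` is a hypothetical example chosen only because its jaxpr is short.

```python
import jax
import jax.numpy as jnp

def leaky_exp(x):
    # hypothetical element-wise function, used only to show tracing
    return jnp.where(x > 0, x, 0.1 * (jnp.exp(x) - 1.0))

x = jnp.arange(-2.0, 3.0)  # [-2., -1., 0., 1., 2.]

# The jaxpr JAX traces from the Python function (primitives such as gt, exp, select_n)
jaxpr = jax.make_jaxpr(leaky_exp)(x)
print(jaxpr)

# jax.jit traces once per input shape/dtype, compiles with XLA, and caches the executable
fast_fn = jax.jit(leaky_exp)
y = fast_fn(x)
print(y)
```

The first call to `fast_fn` pays the tracing and compilation cost. Later calls with the same input shape and dtype reuse the cached executable.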
A: Autograd. Computing gradients is a fundamental part of modern machine learning methods, and you can simply call jax.grad() to obtain the gradients needed to optimize a model.
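As a minimal illustration, `jax.grad` turns a scalar loss function into a function that returns its gradient. The squared-error loss of a 1-D linear model below is a hypothetical example. For w = 1, the gradient is mean(2x(wx - y)) = -28/3.

```python
import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    # mean squared error of a 1-D linear model y ≈ w * x
    return jnp.mean((w * x - y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

grad_fn = jax.grad(mse_loss)  # differentiates w.r.t. the first argument by default
g = grad_fn(1.0, x, y)
print(g)
```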
x: XLA. This is an open-source machine learning compiler for CPUs, GPUs, and ML accelerators. In general, XLA first performs several built-in, target-independent optimization and analysis passes on the StableHLO graph. It then sends the HLO computation to a backend for further HLO-level optimizations, and the backend performs target-specific code generation.
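You can inspect what XLA receives by using JAX's ahead-of-time API. `jax.jit(f).lower(...)` produces the lowered module, and `.compile()` runs the XLA optimization and code-generation stages. The function `affine` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def affine(w, x, b):
    return jnp.dot(w, x) + b

w = jnp.ones((4, 3))
x = jnp.ones((3,))
b = jnp.zeros((4,))

lowered = jax.jit(affine).lower(w, x, b)  # trace + lower, no compilation yet
hlo_text = lowered.as_text()              # textual form of the module handed to XLA
compiled = lowered.compile()              # XLA optimizes and generates device code
out = compiled(w, x, b)
print(hlo_text)
```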
Those are just a few key features of JAX. It also offers many easy-to-use APIs: NumPy-like array operations in jax.numpy, automatic vectorization with jax.vmap, and parallelization of your code across multiple devices with jax.pmap.
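Here is a small sketch of the vectorization and parallelization APIs. `jax.vmap` turns a function written for a single key vector into one that scores a whole batch of keys. `jax.pmap` maps over the leading axis across devices, and that axis must not exceed the number of local devices, so this example sizes it with `jax.local_device_count()`. The function `score` is hypothetical.

```python
import jax
import jax.numpy as jnp

def score(query, key):
    # dot-product score for a single (query, key) pair
    return jnp.dot(query, key)

query = jnp.array([1.0, 0.0, 2.0])
keys = jnp.array([[1.0, 1.0, 1.0],
                  [0.0, 2.0, 0.0],
                  [3.0, 0.0, -1.0]])

# in_axes=(None, 0): broadcast `query`, map over rows of `keys`
scores = jax.vmap(score, in_axes=(None, 0))(query, keys)
print(scores)

# pmap: one slice of the leading axis per local device (works with a single device too)
n_dev = jax.local_device_count()
doubled = jax.pmap(lambda v: v * 2.0)(jnp.ones((n_dev, 3)))
```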
We'll cover more JAX concepts and applications in upcoming blogs, but for now, let's build NanoGPT with JAX!
From attention to the transformer
GPT is a decoder-only transformer model, and its key building block is the attention module. We first define a model configuration data class to hold the hyperparameters, so that the model modules can consume them to build the architecture. To match the 124M GPT-2 model, we use a 12-layer transformer decoder with 12 heads, a vocabulary of 50257 tokens, and an embedding dimension of 768. The block size (the maximum context length for attention) is 1024.
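```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    vocab_size: int = 50257   # GPT-2 BPE vocabulary size
    n_head: int = 12          # attention heads per layer
    n_embd: int = 768         # embedding (model) dimension
    block_size: int = 1024    # maximum context length
    n_layer: int = 12         # number of transformer blocks
    dropout_rate: float = 0.1
```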
Next comes the key component of the transformer model: attention. The idea is to project the inputs with three weight matrices into queries, keys, and values. We rely on Flax, a layer and training API library for JAX, to create the three weight matrices by simply calling flax.linen.Dense. JAX mirrors many NumPy APIs, so we reshape each projection with jax.numpy.reshape from (batch_size, sequence_length, embedding_dim) to (batch_size, sequence_length, num_head, embedding_dim / num_head). We then move the head axis forward with jax.numpy.transpose so that each head attends independently. To multiply the query and key arrays, JAX provides jax.numpy.matmul, and jax.numpy.transpose also transposes the key matrix for that multiplication.
Note that we need to put a mask on the attention matrix to prevent information leakage, so that earlier tokens cannot attend to later tokens. jax.numpy.tril helps build a lower-triangular matrix. jax.numpy.where fills the masked positions with negative infinity, so those positions become 0 after jax.nn.softmax.
The full multi-head attention code is below.
```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class CausalSelfAttention(nn.Module):
    config: ModelConfig

    @nn.compact
    def __call__(self, x, deterministic=True):
        assert len(x.shape) == 3
        b, l, d = x.shape
        n_head = self.config.n_head
        head_dim = d // n_head
        q = nn.Dense(self.config.n_embd)(x)
        k = nn.Dense(self.config.n_embd)(x)
        v = nn.Dense(self.config.n_embd)(x)
        # (b, l, d) -> (b, l, n_head, head_dim) -> (b, n_head, l, head_dim)
        q = jnp.transpose(jnp.reshape(q, (b, l, n_head, head_dim)), (0, 2, 1, 3))
        k = jnp.transpose(jnp.reshape(k, (b, l, n_head, head_dim)), (0, 2, 1, 3))
        v = jnp.transpose(jnp.reshape(v, (b, l, n_head, head_dim)), (0, 2, 1, 3))
        # q @ k^T / sqrt(head_dim) -> causal mask -> softmax -> @ v
        norm = jnp.sqrt(head_dim)
        attn = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / norm  # (b, n_head, l, l)
        mask = jnp.tril(jnp.ones((l, l), dtype=bool))  # True on and below the diagonal
        attn = jnp.where(mask, attn, float("-inf"))
        probs = jax.nn.softmax(attn, axis=-1)
        y = jnp.matmul(probs, v)  # (b, n_head, l, head_dim)
        # merge the heads back: (b, n_head, l, head_dim) -> (b, l, d)
        y = jnp.reshape(jnp.transpose(y, (0, 2, 1, 3)), (b, l, d))
        y = nn.Dense(self.config.n_embd)(y)
        return y
```
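To see what the causal mask does in isolation, this small sketch applies the same tril/where/softmax steps to a 4×4 matrix of equal scores. The equal scores are an illustrative assumption. Row i ends up spreading its probability uniformly over positions 0..i and assigns exactly 0 to future positions.

```python
import jax
import jax.numpy as jnp

l = 4
scores = jnp.zeros((l, l))  # equal scores, so the effect of the mask is easy to read
mask = jnp.tril(jnp.ones((l, l), dtype=bool))  # row i may attend to columns 0..i
masked = jnp.where(mask, scores, float("-inf"))
probs = jax.nn.softmax(masked, axis=-1)
print(probs)
```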
You may notice that there are no __init__ or forward methods like the ones you see in PyTorch. This is special to Flax. You can either define layers explicitly in a setup method, or define them inline within the forward pass by adding the nn.compact decorator above the __call__ method.
Next, let's build the MLP layer and the Block layer, which use dense layers, the GELU activation function, LayerNorm, and Dropout. Again, flax.linen provides the layer APIs we need. Note that we pass a deterministic boolean flag to control layers like Dropout, which behave differently during training and evaluation.
```python
class MLP(nn.Module):
    config: ModelConfig

    @nn.compact
    def __call__(self, x, deterministic=True):
        x = nn.Dense(self.config.n_embd * 4)(x)
        x = nn.gelu(x, approximate=True)
        x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=deterministic)
        x = nn.Dense(self.config.n_embd)(x)
        x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=deterministic)
        return x

class Block(nn.Module):
    config: ModelConfig

    @nn.compact
    def __call__(self, x, deterministic=True):
        # pre-LayerNorm residual connections: x + f(LN(x)) keeps the residual stream intact
        x = x + CausalSelfAttention(self.config)(nn.LayerNorm()(x), deterministic)
        x = x + MLP(self.config)(nn.LayerNorm()(x), deterministic)
        return x
```
Now let's use the blocks above to build NanoGPT.
Given an input sequence of token ids, we use flax.linen.Embed layers to obtain the position embeddings and the token embeddings. We then pass their sum through the Block module N times, where N is the number of layers defined in the model configuration. Finally, we map the outputs of the last Block to logits over the vocabulary to predict the next token. In addition to the forward __call__ method, we also create an init method that feeds dummy inputs through the model to obtain the model parameters.
```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class GPT(nn.Module):
    config: ModelConfig

    @nn.compact
    def __call__(self, x, deterministic=False):
        B, T = x.shape
        assert T <= self.config.block_size
        pos = jnp.arange(0, T)[None]  # (1, T), broadcasts over the batch
        pos_emb = nn.Embed(self.config.block_size, self.config.n_embd)(pos)
        wte = nn.Embed(self.config.vocab_size, self.config.n_embd)
        tok_emb = wte(x)  # (B, T, n_embd)
        x = tok_emb + pos_emb
        for _ in range(self.config.n_layer):
            x = Block(self.config)(x, deterministic)
        x = nn.LayerNorm()(x)
        # parameter sharing: reuse the token-embedding matrix as the output head
        logits = wte.attend(x)  # (B, T, vocab_size)
        # untied alternative: logits = nn.Dense(self.config.vocab_size, use_bias=False)(x)
        return logits

    def init(self, rng):
        # dummy tokens are only used to trace shapes and create the parameters
        tokens = jnp.zeros((1, self.config.block_size), dtype=jnp.uint16)
        params = jax.jit(super().init, static_argnums=(2,))(rng, tokens, True)
        return params
```
Now, let's verify the number of parameters. First we initialize the model configuration data class and a random key. Then we create dummy inputs and feed them into the GPT model, and we use the jax.tree_util.tree_map API to build a parameter-counting function. We got 124439808 (124M) parameters, the same number as Hugging Face's GPT-2, BOOM! This count assumes the output head shares its weights with the token embedding (wte.attend).
Data loader and training loop
Now we are going to overfit a small dataset. To stay comparable with Andrej's PyTorch NanoGPT video, let's use the toy dataset he shared in the video. We use the GPT-2 tokenizer from the tiktoken library to tokenize all text in the input file, and we convert the tokens into a jax.numpy.array for training the JAX model.
```python
import jax.numpy as jnp
import tiktoken

class DataLoader:
    def __init__(self, B, T):
        self.current_position = 0
        self.B = B
        self.T = T
        with open("input.txt", "r") as f:
            text = f.read()
        enc = tiktoken.get_encoding("gpt2")
        self.tokens = jnp.array(enc.encode(text))
        print(f"loaded {len(self.tokens)} tokens in the dataset")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

    def next_batch(self):
        B, T = self.B, self.T
        # take B*T + 1 tokens so the targets can be shifted by one position
        buf = self.tokens[self.current_position:self.current_position + B * T + 1]
        x, y = jnp.reshape(buf[:-1], (B, T)), jnp.reshape(buf[1:], (B, T))
        self.current_position += B * T
        # wrap around when the next batch would run past the end of the data
        if self.current_position + B * T + 1 > len(self.tokens):
            self.current_position = 0
        return x, y
```
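The heart of `next_batch` is the one-token shift between inputs and targets. This sketch isolates that step in a hypothetical helper `make_batch`. It uses `jnp.arange` as a stand-in for real tokenized text, so the shift is easy to read off: each target row is the input row moved one position to the left.

```python
import jax.numpy as jnp

def make_batch(tokens, position, B, T):
    """Hypothetical helper mirroring next_batch: targets are inputs shifted by one."""
    buf = tokens[position:position + B * T + 1]
    x = jnp.reshape(buf[:-1], (B, T))
    y = jnp.reshape(buf[1:], (B, T))
    return x, y

tokens = jnp.arange(20)  # stand-in for tokenized text
x, y = make_batch(tokens, 0, B=2, T=4)
print(x)  # [[0 1 2 3] [4 5 6 7]]
print(y)  # [[1 2 3 4] [5 6 7 8]]
```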
Next, let's set aside distributed training and optimization for now and create a simple training loop as a sanity check. The first thing to do after initializing the model is to create a TrainState, a model-state object that holds the parameters and the optimizer state and applies updates to them. TrainState.create takes three important inputs: apply_fn (the model's forward function), params (the model parameters from the init method), and tx (an Optax gradient transformation).
We then use a train_step function to update the model state at each step. Optax
provides softmax cross-entropy as the loss function for the next-token prediction task, and jax.value_and_grad
computes both the loss value and the gradients. Finally, we update the model state with the new parameters using the apply_gradients
API. Don't forget to decorate the train_step function with jax.jit to reduce the computational overhead!
```python
from typing import Tuple

import jax
import jax.numpy as jnp
import optax
from flax.core import FrozenDict
from flax.training.train_state import TrainState

def init_train_state(key, config) -> TrainState:
    model = GPT(config)
    params = model.init(key)
    optimizer = optax.adamw(3e-4, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1)
    train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer)
    return train_state

@jax.jit
def train_step(state: TrainState, x: jnp.ndarray, y: jnp.ndarray,
               dropout_key: jax.Array) -> Tuple[jnp.ndarray, TrainState]:
    def loss_fn(params: FrozenDict) -> jnp.ndarray:
        # deterministic=False enables Dropout, which needs its own 'dropout' PRNG stream
        logits = state.apply_fn(params, x, False, rngs={"dropout": dropout_key})
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        return loss
    loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return loss, new_state
```
Now everything is ready for the poor man's training loop. Let's check the loss value. A freshly initialized model predicts roughly uniformly over the vocabulary, so its initial loss should be close to -ln(1/50257) ≈ 10.825, the loss of a random guess. For single-batch overfitting, we expect the loss to start close to 10.825 and then drop close to 0. Let's take one batch (x, y) and run the training step 50 times. I also add similar logging to calculate the training speed.
As we can see, the loss behaves exactly as expected, and the training throughput is around 400–500k tokens/sec. That is already 40x faster than the initial, unoptimized PyTorch version in Andrej's video. Note that we run the JAX script on one A100 GPU, which should rule out hardware differences in the speed comparison. There is no .to(device) call to move your model or data from the host CPU to the GPU, which is one of the benefits of JAX!
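As a quick sketch of why no `.to(device)` is needed: arrays created with `jax.numpy` are placed on JAX's default device, which is the first GPU when one is available. `jax.device_put` is still there for explicit placement. The array names below are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((2, 3))  # created directly on the default device
print(jax.devices())  # devices visible to this process (GPUs first if present)
print(x.devices())    # the set of devices holding x

# explicit placement remains available when you need it
y = jax.device_put(jnp.arange(6), jax.devices()[0])
```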
So that's it and we're done. We'll make training 10x faster in Part 2 with further optimizations…
Part 2: The journey of training optimization to 1350k tokens/sec on a single GPU!
“Unless otherwise stated, all images are the property of the author”