Una implementación de un extremo a otro de un Pytorch Transformer, en la que cubriremos conceptos clave como autoatención, codificadores, decodificadores y mucho más.
Cuando decidí profundizar en las arquitecturas Transformer, a menudo me sentía frustrado al leer o mirar tutoriales en línea porque sentía que siempre se perdían algo:
- Los tutoriales oficiales de Tensorflow o Pytorch utilizaron sus propias API, por lo que se mantuvieron en un alto nivel y me obligaron a tener que acceder a su código base para ver qué había debajo del capó. Miles de líneas de código requieren mucho tiempo y no siempre son fáciles de leer.
- Otros tutoriales con código personalizado que encontré (enlaces al final del artículo) a menudo simplificaban demasiado los casos de uso y no abordaban conceptos como el enmascaramiento del manejo por lotes de secuencias de longitud variable.
Por lo tanto, decidí escribir mi propio Transformer para asegurarme de entender los conceptos y poder usarlo con cualquier conjunto de datos.
Por lo tanto, durante este artículo seguiremos un enfoque metódico en el que implementaremos un transformador capa por capa y bloque por bloque.
Obviamente, hay muchas implementaciones diferentes, así como API de alto nivel de Pytorch o Tensorflow, que ya están disponibles en el mercado, con, estoy seguro, un mejor rendimiento que el modelo que crearemos.
“Está bien, pero ¿por qué no utilizar las implementaciones TF/Pytorch entonces”?
El propósito de este artículo es educativo y no tengo ninguna intención de superar las implementaciones de Pytorch o Tensorflow. Creo que la teoría y el código detrás de los transformadores no son sencillos, por eso espero que seguir este tutorial paso a paso te permita comprender mejor estos conceptos y sentirte más cómodo al crear tu propio código. más tarde.
Otra razón para construir su propio transformador desde cero es que le permitirá comprender completamente cómo utilizar las API anteriores. Si nos fijamos en la implementación de Pytorch del forward()
método de la clase Transformer, verá muchas palabras clave oscuras como:
Si ya está familiarizado con estas palabras clave, puede omitir este artículo.
De lo contrario, este artículo le guiará a través de cada una de estas palabras clave con los conceptos subyacentes.
Si ya has oído hablar de ChatGPT o Gemini, entonces ya conociste a un transformador. En realidad, la “T” de ChatGPT significa Transformer.
La arquitectura fue acuñada por primera vez en 2017 por investigadores de Google en el artículo “La atención es todo lo que necesitas”. Es bastante revolucionario ya que los modelos anteriores utilizados para realizar aprendizaje de secuencia a secuencia (traducción automática, voz a texto, etc.) dependían de RNN que eran computacionalmente costosos en el sentido de que tenían que procesar secuencias paso a paso, mientras que Transformers Solo es necesario mirar una vez la secuencia completa, moviendo la complejidad temporal de O(n) a O(1).
Las aplicaciones de los transformadores son bastante amplias en el ámbito de la PNL e incluyen traducción de idiomas, respuesta a preguntas, resumen de documentos, generación de texto, etc.
La arquitectura general de un transformador es la siguiente:
El primer bloque que implementaremos es en realidad la parte más importante de un transformador y se llama Atención de cabezales múltiples. Veamos dónde se ubica en la arquitectura general.
La atención es un mecanismo que en realidad no es específico de los transformadores y que ya se utilizó en los modelos secuencia a secuencia RNN.
import torch
import torch.nn as nn
import mathclass MultiHeadAttention(nn.Module):
def __init__(self, hidden_dim=256, num_heads=4):
"""
input_dim: Dimensionality of the input.
num_heads: The number of attention heads to split the input into.
"""
super(MultiHeadAttention, self).__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads"
self.Wv = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Value part
self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Key part
self.Wq = nn.Linear(hidden_dim, hidden_dim, bias=False) # the Query part
self.Wo = nn.Linear(hidden_dim, hidden_dim, bias=False) # the output layer
def check_sdpa_inputs(self, x):
assert x.size(1) == self.num_heads, f"Expected size of x to be ({-1, self.num_heads, -1, self.hidden_dim // self.num_heads}), got {x.size()}"
assert x.size(3) == self.hidden_dim // self.num_heads
def scaled_dot_product_attention(
self,
query,
key,
value,
attention_mask=None,
key_padding_mask=None):
"""
query : tensor of shape (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
key : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
value : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
key_padding_mask : tensor of shape (sequence_length, key_sequence_length)
"""
self.check_sdpa_inputs(query)
self.check_sdpa_inputs(key)
self.check_sdpa_inputs(value)
d_k = query.size(-1)
tgt_len, src_len = query.size(-2), key.size(-2)
# logits = (B, H, tgt_len, E) * (B, H, E, src_len) = (B, H, tgt_len, src_len)
logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Attention mask here
if attention_mask is not None:
if attention_mask.dim() == 2:
assert attention_mask.size() == (tgt_len, src_len)
attention_mask = attention_mask.unsqueeze(0)
logits = logits + attention_mask
else:
raise ValueError(f"Attention mask size {attention_mask.size()}")
# Key mask here
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2) # Broadcast over batch size, num heads
logits = logits + key_padding_mask
attention = torch.softmax(logits, dim=-1)
output = torch.matmul(attention, value) # (batch_size, num_heads, sequence_length, hidden_dim)
return output, attention
def split_into_heads(self, x, num_heads):
batch_size, seq_length, hidden_dim = x.size()
x = x.view(batch_size, seq_length, num_heads, hidden_dim // num_heads)
return x.transpose(1, 2) # Final dim will be (batch_size, num_heads, seq_length, , hidden_dim // num_heads)
def combine_heads(self, x):
batch_size, num_heads, seq_length, head_hidden_dim = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_length, num_heads * head_hidden_dim)
def forward(
self,
q,
k,
v,
attention_mask=None,
key_padding_mask=None):
"""
q : tensor of shape (batch_size, query_sequence_length, hidden_dim)
k : tensor of shape (batch_size, key_sequence_length, hidden_dim)
v : tensor of shape (batch_size, key_sequence_length, hidden_dim)
attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
key_padding_mask : tensor of shape (sequence_length, key_sequence_length)
"""
q = self.Wq(q)
k = self.Wk(k)
v = self.Wv(v)
q = self.split_into_heads(q, self.num_heads)
k = self.split_into_heads(k, self.num_heads)
v = self.split_into_heads(v, self.num_heads)
# attn_values, attn_weights = self.multihead_attn(q, k, v, attn_mask=attention_mask)
attn_values, attn_weights = self.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attention_mask=attention_mask,
key_padding_mask=key_padding_mask,
)
grouped = self.combine_heads(attn_values)
output = self.Wo(grouped)
self.attention_weigths = attn_weights
return output
Necesitamos explicar algunos conceptos aquí.
1) Consultas, Claves y Valores.
El consulta es la información que estás tratando de hacer coincidir,
El llave y valores son la información almacenada.
Piense en eso como si estuviera usando un diccionario: siempre que use un diccionario de Python, si su consulta no coincide con las claves del diccionario, no se le devolverá nada. Pero ¿qué pasa si queremos que nuestro diccionario devuelva una combinación de información bastante parecida? Como si tuviéramos:
d = {"panther": 1, "bear": 10, "dog":3}
d("wolf") = 0.2*d("panther") + 0.7*d("dog") + 0.1*d("bear")
Básicamente, de esto se trata la atención: mirar diferentes partes de sus datos y combinarlas para obtener una síntesis como respuesta a su consulta.
La parte relevante del código es esta, donde calculamos los pesos de atención entre la consulta y las claves.
logits = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # we compute the weights of attention
Y este, donde aplicamos los pesos normalizados a los valores:
attention = torch.softmax(logits, dim=-1)
output = torch.matmul(attention, value) # (batch_size, num_heads, sequence_length, hidden_dim)
2) Atención enmascaramiento y relleno.
Al prestar atención a partes de una entrada secuencial, no queremos incluir información inútil o prohibida.
La información inútil es, por ejemplo, el relleno: nuestro modelo debe ignorar los símbolos de relleno, utilizados para alinear todas las secuencias de un lote con el mismo tamaño de secuencia. Volveremos a eso en la última sección.
La información prohibida es un poco más compleja. Cuando se entrena, un modelo aprende a codificar la secuencia de entrada y alinear objetivos con las entradas. Sin embargo, como el proceso de inferencia implica observar tokens emitidos previamente para predecir el siguiente (piense en la generación de texto en ChatGPT), debemos aplicar las mismas reglas durante el entrenamiento.
Por eso aplicamos un máscara causal para garantizar que los objetivos, en cada paso de tiempo, solo puedan ver información del pasado. Aquí está la sección correspondiente donde se aplica la máscara (el cálculo de la máscara se cubre al final)
if attention_mask is not None:
if attention_mask.dim() == 2:
assert attention_mask.size() == (tgt_len, src_len)
attention_mask = attention_mask.unsqueeze(0)
logits = logits + attention_mask
Corresponde a la siguiente parte del Transformador:
Al recibir y tratar una entrada, un transformador no tiene sentido de orden ya que mira la secuencia como un todo, a diferencia de lo que hacen los RNN. Por lo tanto, necesitamos agregar un indicio de orden temporal para que el transformador pueda aprender las dependencias.
Los detalles específicos de cómo funciona la codificación posicional están fuera del alcance de este artículo, pero no dudes en leer el documento original para comprenderlos.
# Taken from https://pytorch.org/tutorials/beginner/transformer_tutorial.html#define-the-model
class PositionalEncoding(nn.Module):def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe(:, 0::2) = torch.sin(position * div_term)
pe(:, 1::2) = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Arguments:
x: Tensor, shape ``(batch_size, seq_len, embedding_dim)``
"""
x = x + self.pe(:, :x.size(1), :)
return x
¡Estamos a punto de tener un codificador completo funcionando! El codificador es la parte izquierda del transformador.
Agregaremos una pequeña parte a nuestro código, que es la parte Feed Forward:
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model: int, d_ff: int):
super(PositionWiseFeedForward, self).__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
¡Juntando las piezas, obtenemos un módulo codificador!
class EncoderBlock(nn.Module):
def __init__(self, n_dim: int, dropout: float, n_heads: int):
super(EncoderBlock, self).__init__()
self.mha = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
self.norm1 = nn.LayerNorm(n_dim)
self.ff = PositionWiseFeedForward(n_dim, n_dim)
self.norm2 = nn.LayerNorm(n_dim)
self.dropout = nn.Dropout(dropout)def forward(self, x, src_padding_mask=None):
assert x.ndim==3, "Expected input to be 3-dim, got {}".format(x.ndim)
att_output = self.mha(x, x, x, key_padding_mask=src_padding_mask)
x = x + self.dropout(self.norm1(att_output))
ff_output = self.ff(x)
output = x + self.norm2(ff_output)
return output
Como se muestra en el diagrama, el codificador en realidad contiene N bloques o capas de codificador, así como una capa de incrustación para nuestras entradas. Por lo tanto, creemos un codificador agregando los bloques Incrustación, Codificación posicional y Codificador:
class Encoder(nn.Module):
def __init__(
self,
vocab_size: int,
n_dim: int,
dropout: float,
n_encoder_blocks: int,
n_heads: int):super(Encoder, self).__init__()
self.n_dim = n_dim
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=n_dim
)
self.positional_encoding = PositionalEncoding(
d_model=n_dim,
dropout=dropout
)
self.encoder_blocks = nn.ModuleList((
EncoderBlock(n_dim, dropout, n_heads) for _ in range(n_encoder_blocks)
))
def forward(self, x, padding_mask=None):
x = self.embedding(x) * math.sqrt(self.n_dim)
x = self.positional_encoding(x)
for block in self.encoder_blocks:
x = block(x=x, src_padding_mask=padding_mask)
return x
La parte del decodificador es la parte de la izquierda y requiere un poco más de elaboración.
Hay algo llamado Atención de múltiples cabezas enmascaradas. Recuerda lo que dijimos antes sobre máscara causal ? Bueno esto pasa aquí. Usaremos el parámetro atencion_mask de nuestro módulo de atención de cabezales múltiples para representar esto (más detalles sobre cómo calculamos la máscara al final):
# Stuff beforeself.self_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
masked_att_output = self.self_attention(
q=tgt,
k=tgt,
v=tgt,
attention_mask=tgt_mask, <-- HERE IS THE CAUSAL MASK
key_padding_mask=tgt_padding_mask)
# Stuff after
La segunda atención se llama. atención cruzada. ¡Utilizará la consulta del decodificador para hacer coincidir la clave y los valores del codificador! Cuidado: pueden tener diferentes longitudes durante el entrenamiento, por lo que suele ser una buena práctica definir claramente las formas esperadas de las entradas de la siguiente manera:
def scaled_dot_product_attention(
self,
query,
key,
value,
attention_mask=None,
key_padding_mask=None):
"""
query : tensor of shape (batch_size, num_heads, query_sequence_length, hidden_dim//num_heads)
key : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
value : tensor of shape (batch_size, num_heads, key_sequence_length, hidden_dim//num_heads)
attention_mask : tensor of shape (query_sequence_length, key_sequence_length)
key_padding_mask : tensor of shape (sequence_length, key_sequence_length)"""
Y aquí está la parte donde usamos la salida del codificador, llamada memoriacon nuestra entrada de decodificador:
# Stuff before
self.cross_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
cross_att_output = self.cross_attention(
q=x1,
k=memory,
v=memory,
attention_mask=None, <-- NO CAUSAL MASK HERE
key_padding_mask=memory_padding_mask) <-- WE NEED TO USE THE PADDING OF THE SOURCE
# Stuff after
Juntando las piezas, terminamos con esto para el Decodificador:
class DecoderBlock(nn.Module):
def __init__(self, n_dim: int, dropout: float, n_heads: int):
super(DecoderBlock, self).__init__()# The first Multi-Head Attention has a mask to avoid looking at the future
self.self_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
self.norm1 = nn.LayerNorm(n_dim)
# The second Multi-Head Attention will take inputs from the encoder as key/value inputs
self.cross_attention = MultiHeadAttention(hidden_dim=n_dim, num_heads=n_heads)
self.norm2 = nn.LayerNorm(n_dim)
self.ff = PositionWiseFeedForward(n_dim, n_dim)
self.norm3 = nn.LayerNorm(n_dim)
# self.dropout = nn.Dropout(dropout)
def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, memory_padding_mask=None):
masked_att_output = self.self_attention(
q=tgt, k=tgt, v=tgt, attention_mask=tgt_mask, key_padding_mask=tgt_padding_mask)
x1 = tgt + self.norm1(masked_att_output)
cross_att_output = self.cross_attention(
q=x1, k=memory, v=memory, attention_mask=None, key_padding_mask=memory_padding_mask)
x2 = x1 + self.norm2(cross_att_output)
ff_output = self.ff(x2)
output = x2 + self.norm3(ff_output)
return output
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
n_dim: int,
dropout: float,
max_seq_len: int,
n_decoder_blocks: int,
n_heads: int):
super(Decoder, self).__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=n_dim
)
self.positional_encoding = PositionalEncoding(
d_model=n_dim,
dropout=dropout
)
self.decoder_blocks = nn.ModuleList((
DecoderBlock(n_dim, dropout, n_heads) for _ in range(n_decoder_blocks)
))
def forward(self, tgt, memory, tgt_mask=None, tgt_padding_mask=None, memory_padding_mask=None):
x = self.embedding(tgt)
x = self.positional_encoding(x)
for block in self.decoder_blocks:
x = block(x, memory, tgt_mask=tgt_mask, tgt_padding_mask=tgt_padding_mask, memory_padding_mask=memory_padding_mask)
return x
Recuerde la sección Atención de múltiples cabezales donde mencionamos excluir ciertas partes de las entradas al prestar atención.
Durante el entrenamiento, consideramos lotes de entradas y objetivos, donde cada instancia puede tener una longitud variable. Considere el siguiente ejemplo en el que agrupamos 4 palabras: plátano, sandía, pera y arándano. Para procesarlas como un solo lote, debemos alinear todas las palabras con la longitud de la palabra más larga (sandía). Por lo tanto, agregaremos una ficha adicional, PAD, a cada palabra para que todas terminen con la misma longitud que una sandía.
En la siguiente imagen, la tabla superior representa los datos sin procesar, la tabla inferior la versión codificada:
En nuestro caso, queremos excluir los índices de relleno de los pesos de atención que se calculan. Por lo tanto, podemos calcular una máscara de la siguiente manera, tanto para los datos de origen como para los de destino:
padding_mask = (x == PAD_IDX)
¿Qué pasa ahora con las máscaras causales? Bueno, si queremos, en cada paso de tiempo, que el modelo pueda atender solo pasos en el pasado, esto significa que para cada paso de tiempo T, el modelo solo puede atender cada paso t para t en 1…T. Es un bucle for doble, por lo tanto podemos usar una matriz para calcular eso:
def generate_square_subsequent_mask(size: int):
"""Generate a triangular (size, size) mask. From PyTorch docs."""
mask = (1 - torch.triu(torch.ones(size, size), diagonal=1)).bool()
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
¡Construyamos ahora nuestro Transformer juntando piezas!
En nuestro caso de uso, usaremos un conjunto de datos muy simple para mostrar cómo aprenden realmente los Transformers.
“¿Pero por qué utilizar un transformador para invertir palabras? ¡Ya sé cómo hacer eso en Python con word(::-1)!”
El objetivo aquí es ver si el mecanismo de atención del Transformador funciona. Lo que esperamos es ver que los pesos de atención se muevan de derecha a izquierda cuando se les da una secuencia de entrada. Si es así, esto significa que nuestro Transformer ha aprendido una gramática muy simple, que simplemente lee de derecha a izquierda, y podría generalizar a gramáticas más complejas al realizar traducciones de idiomas de la vida real.
Primero comencemos con nuestra clase Transformer personalizada:
import torch
import torch.nn as nn
import mathfrom .encoder import Encoder
from .decoder import Decoder
class Transformer(nn.Module):
def __init__(self, **kwargs):
super(Transformer, self).__init__()
for k, v in kwargs.items():
print(f" * {k}={v}")
self.vocab_size = kwargs.get('vocab_size')
self.model_dim = kwargs.get('model_dim')
self.dropout = kwargs.get('dropout')
self.n_encoder_layers = kwargs.get('n_encoder_layers')
self.n_decoder_layers = kwargs.get('n_decoder_layers')
self.n_heads = kwargs.get('n_heads')
self.batch_size = kwargs.get('batch_size')
self.PAD_IDX = kwargs.get('pad_idx', 0)
self.encoder = Encoder(
self.vocab_size, self.model_dim, self.dropout, self.n_encoder_layers, self.n_heads)
self.decoder = Decoder(
self.vocab_size, self.model_dim, self.dropout, self.n_decoder_layers, self.n_heads)
self.fc = nn.Linear(self.model_dim, self.vocab_size)
@staticmethod
def generate_square_subsequent_mask(size: int):
"""Generate a triangular (size, size) mask. From PyTorch docs."""
mask = (1 - torch.triu(torch.ones(size, size), diagonal=1)).bool()
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
def encode(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Input
x: (B, S) with elements in (0, C) where C is num_classes
Output
(B, S, E) embedding
"""
mask = (x == self.PAD_IDX).float()
encoder_padding_mask = mask.masked_fill(mask == 1, float('-inf'))
# (B, S, E)
encoder_output = self.encoder(
x,
padding_mask=encoder_padding_mask
)
return encoder_output, encoder_padding_mask
def decode(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
memory_padding_mask=None
) -> torch.Tensor:
"""
B = Batch size
S = Source sequence length
L = Target sequence length
E = Model dimension
Input
encoded_x: (B, S, E)
y: (B, L) with elements in (0, C) where C is num_classes
Output
(B, L, C) logits
"""
mask = (tgt == self.PAD_IDX).float()
tgt_padding_mask = mask.masked_fill(mask == 1, float('-inf'))
decoder_output = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=self.generate_square_subsequent_mask(tgt.size(1)),
tgt_padding_mask=tgt_padding_mask,
memory_padding_mask=memory_padding_mask,
)
output = self.fc(decoder_output) # shape (B, L, C)
return output
def forward(
self,
x: torch.Tensor,
y: torch.Tensor,
) -> torch.Tensor:
"""
Input
x: (B, Sx) with elements in (0, C) where C is num_classes
y: (B, Sy) with elements in (0, C) where C is num_classes
Output
(B, L, C) logits
"""
# Encoder output shape (B, S, E)
encoder_output, encoder_padding_mask = self.encode(x)
# Decoder output shape (B, L, C)
decoder_output = self.decode(
tgt=y,
memory=encoder_output,
memory_padding_mask=encoder_padding_mask
)
return decoder_output
Realizar inferencia con decodificación codiciosa
Necesitamos agregar un método que actuará como el famoso model.predict
de scikit.learn. El objetivo es pedirle al modelo que genere predicciones dinámicamente dada una entrada. Durante la inferencia, no hay objetivo: el modelo comienza generando un token atendiendo a la salida y usa su propia predicción para continuar emitiendo tokens. Es por eso que esos modelos a menudo se denominan modelos autorregresivos, ya que utilizan predicciones pasadas para predecir la siguiente.
El problema de la decodificación codiciosa es que considera el token con mayor probabilidad en cada paso. Esto puede dar lugar a predicciones muy malas si las primeras fichas son completamente erróneas. Existen otros métodos de decodificación, como la búsqueda Beam, que considera una lista corta de secuencias candidatas (piense en mantener los tokens top-k en cada paso de tiempo en lugar del argmax) y devuelve la secuencia con la probabilidad total más alta.
Por ahora, implementemos la decodificación codiciosa y agréguela a nuestro modelo Transformer:
def predict(
self,
x: torch.Tensor,
sos_idx: int=1,
eos_idx: int=2,
max_length: int=None
) -> torch.Tensor:
"""
Method to use at inference time. Predict y from x one token at a time. This method is greedy
decoding. Beam search can be used instead for a potential accuracy boost.Input
x: str
Output
(B, L, C) logits
"""
# Pad the tokens with beginning and end of sentence tokens
x = torch.cat((
torch.tensor((sos_idx)),
x,
torch.tensor((eos_idx)))
).unsqueeze(0)
encoder_output, mask = self.transformer.encode(x) # (B, S, E)
if not max_length:
max_length = x.size(1)
outputs = torch.ones((x.size()(0), max_length)).type_as(x).long() * sos_idx
for step in range(1, max_length):
y = outputs(:, :step)
probs = self.transformer.decode(y, encoder_output)
output = torch.argmax(probs, dim=-1)
# Uncomment if you want to see step by step predicitons
# print(f"Knowing {y} we output {output(:, -1)}")
if output(:, -1).detach().numpy() in (eos_idx, sos_idx):
break
outputs(:, step) = output(:, -1)
return outputs
Creando datos de juguetes
Definimos un pequeño conjunto de datos que invierte palabras, lo que significa que “helloworld” devolverá “dlrowolleh”:
import numpy as np
import torch
from torch.utils.data import Datasetnp.random.seed(0)
def generate_random_string():
len = np.random.randint(10, 20)
return "".join((chr(x) for x in np.random.randint(97, 97+26, len)))
class ReverseDataset(Dataset):
def __init__(self, n_samples, pad_idx, sos_idx, eos_idx):
super(ReverseDataset, self).__init__()
self.pad_idx = pad_idx
self.sos_idx = sos_idx
self.eos_idx = eos_idx
self.values = (generate_random_string() for _ in range(n_samples))
self.labels = (x(::-1) for x in self.values)
def __len__(self):
return len(self.values) # number of samples in the dataset
def __getitem__(self, index):
return self.text_transform(self.values(index).rstrip("\n")), \
self.text_transform(self.labels(index).rstrip("\n"))
def text_transform(self, x):
return torch.tensor((self.sos_idx) + (ord(z)-97+3 for z in x) + (self.eos_idx)
Ahora definiremos los pasos de formación y evaluación:
PAD_IDX = 0
SOS_IDX = 1
EOS_IDX = 2def train(model, optimizer, loader, loss_fn, epoch):
model.train()
losses = 0
acc = 0
history_loss = ()
history_acc = ()
with tqdm(loader, position=0, leave=True) as tepoch:
for x, y in tepoch:
tepoch.set_description(f"Epoch {epoch}")
optimizer.zero_grad()
logits = model(x, y(:, :-1))
loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y(:, 1:).contiguous().view(-1))
loss.backward()
optimizer.step()
losses += loss.item()
preds = logits.argmax(dim=-1)
masked_pred = preds * (y(:, 1:)!=PAD_IDX)
accuracy = (masked_pred == y(:, 1:)).float().mean()
acc += accuracy.item()
history_loss.append(loss.item())
history_acc.append(accuracy.item())
tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy.item())
return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc
def evaluate(model, loader, loss_fn):
model.eval()
losses = 0
acc = 0
history_loss = ()
history_acc = ()
for x, y in tqdm(loader, position=0, leave=True):
logits = model(x, y(:, :-1))
loss = loss_fn(logits.contiguous().view(-1, model.vocab_size), y(:, 1:).contiguous().view(-1))
losses += loss.item()
preds = logits.argmax(dim=-1)
masked_pred = preds * (y(:, 1:)!=PAD_IDX)
accuracy = (masked_pred == y(:, 1:)).float().mean()
acc += accuracy.item()
history_loss.append(loss.item())
history_acc.append(accuracy.item())
return losses / len(list(loader)), acc / len(list(loader)), history_loss, history_acc
Y entrena el modelo durante un par de épocas:
def collate_fn(batch):
"""
This function pads inputs with PAD_IDX to have batches of equal length
"""
src_batch, tgt_batch = (), ()
for src_sample, tgt_sample in batch:
src_batch.append(src_sample)
tgt_batch.append(tgt_sample)src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=True)
tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=True)
return src_batch, tgt_batch
# Model hyperparameters
args = {
'vocab_size': 128,
'model_dim': 128,
'dropout': 0.1,
'n_encoder_layers': 1,
'n_decoder_layers': 1,
'n_heads': 4
}
# Define model here
model = Transformer(**args)
# Instantiate datasets
train_iter = ReverseDataset(50000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
eval_iter = ReverseDataset(10000, pad_idx=PAD_IDX, sos_idx=SOS_IDX, eos_idx=EOS_IDX)
dataloader_train = DataLoader(train_iter, batch_size=256, collate_fn=collate_fn)
dataloader_val = DataLoader(eval_iter, batch_size=256, collate_fn=collate_fn)
# During debugging, we ensure sources and targets are indeed reversed
# s, t = next(iter(dataloader_train))
# print(s(:4, ...))
# print(t(:4, ...))
# print(s.size())
# Initialize model parameters
for p in model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
# Define loss function : we ignore logits which are padding tokens
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
# Save history to dictionnary
history = {
'train_loss': (),
'eval_loss': (),
'train_acc': (),
'eval_acc': ()
}
# Main loop
for epoch in range(1, 4):
start_time = time.time()
train_loss, train_acc, hist_loss, hist_acc = train(model, optimizer, dataloader_train, loss_fn, epoch)
history('train_loss') += hist_loss
history('train_acc') += hist_acc
end_time = time.time()
val_loss, val_acc, hist_loss, hist_acc = evaluate(model, dataloader_val, loss_fn)
history('eval_loss') += hist_loss
history('eval_acc') += hist_acc
print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Train acc: {train_acc:.3f}, Val loss: {val_loss:.3f}, Val acc: {val_acc:.3f} "f"Epoch time = {(end_time - start_time):.3f}s"))
Visualiza la atención
Definimos una pequeña función para acceder a los pesos de los cabezales de atención:
fig = plt.figure(figsize=(10., 10.))
images = model.decoder.decoder_blocks(0).cross_attention.attention_weigths(0,...).detach().numpy()
grid = ImageGrid(fig, 111, # similar to subplot(111)
nrows_ncols=(2, 2), # creates 2x2 grid of axes
axes_pad=0.1, # pad between axes in inch.
)for ax, im in zip(grid, images):
# Iterating over the grid returns the Axes.
ax.imshow(im)
Podemos ver un bonito patrón de derecha a izquierda cuando leemos los pesos desde arriba. Las partes verticales en la parte inferior del eje y seguramente pueden representar pesos enmascarados debido a la máscara de relleno
Probando nuestro modelo!
Para probar nuestro modelo con nuevos datos, definiremos un poco Translator
clase para ayudarnos con la decodificación:
class Translator(nn.Module):
def __init__(self, transformer):
super(Translator, self).__init__()
self.transformer = transformer@staticmethod
def str_to_tokens(s):
return (ord(z)-97+3 for z in s)
@staticmethod
def tokens_to_str(tokens):
return "".join((chr(x+94) for x in tokens))
def __call__(self, sentence, max_length=None, pad=False):
x = torch.tensor(self.str_to_tokens(sentence))
outputs = self.transformer.predict(sentence)
return self.tokens_to_str(outputs(0))
Deberías poder ver lo siguiente:
Y si imprimimos el cabezal de atención observaremos lo siguiente:
fig = plt.figure()
images = model.decoder.decoder_blocks(0).cross_attention.attention_weigths(0,...).detach().numpy().mean(axis=0)fig, ax = plt.subplots(1,1, figsize=(10., 10.))
# Iterating over the grid returs the Axes.
ax.set_yticks(range(len(out)))
ax.set_xticks(range(len(sentence)))
ax.xaxis.set_label_position('top')
ax.set_xticklabels(iter(sentence))
ax.set_yticklabels((f"step {i}" for i in range(len(out))))
ax.imshow(images)
¡Podemos ver claramente que el modelo atiende de derecha a izquierda al invertir nuestra oración “reversethis”! (El paso 0 en realidad recibe el token de comienzo de oración).
Eso es todo, ahora puede escribir Transformer y usarlo con conjuntos de datos más grandes para realizar la traducción automática de crear su propio BERT, por ejemplo.
Quería que este tutorial le mostrara las advertencias al escribir un Transformer: el relleno y el enmascaramiento son quizás las partes que requieren más atención (juego de palabras no intencionado), ya que definirán el buen desempeño del modelo durante la inferencia.
En los siguientes artículos, veremos cómo crear su propio modelo BERT y cómo usar Equinox, una biblioteca de alto rendimiento además de JAX.
Manténganse al tanto !
(+) “El transformador anotado”
(+) “Transformadores desde cero“
(+) “Traducción automática neuronal con Transformer y Keras”
(+) “El transformador ilustrado”
(+) Tutorial de aprendizaje profundo de la Universidad de Amsterdam
(+) Tutorial de Pytorch sobre transformadores