Extracción de relaciones mejorada mediante el ajuste de Llama3–8B con un conjunto de datos sintéticos creado con Llama3–70B
La extracción de relaciones (RE) es la tarea de extraer relaciones de texto no estructurado para identificar conexiones entre varias entidades nombradas. Se realiza junto con el reconocimiento de entidades nombradas (NER) y es un paso esencial en el proceso de procesamiento del lenguaje natural. Con el auge de los modelos de lenguajes grandes (LLM), los enfoques supervisados tradicionales que implican etiquetar intervalos de entidades y clasificar relaciones (si las hay) entre ellas se mejoran o se reemplazan por completo por enfoques basados en LLM (1).
Llama3 es el lanzamiento importante más reciente en el dominio de GenerativeAI (ai.meta.com/blog/meta-llama-3/” rel=”noopener ugc nofollow” target=”_blank”>2). El modelo base está disponible en dos tamaños, 8B y 70B, y se espera que pronto se lance un modelo 400B. Estos modelos están disponibles en la plataforma HuggingFace; ver (3) para detalles. La variante 70B impulsa el nuevo sitio web de chat de Meta ai” rel=”noopener ugc nofollow” target=”_blank”>Meta.ai y muestra un rendimiento comparable al de ChatGPT. El modelo 8B se encuentra entre los de mayor rendimiento de su clase. La arquitectura de Llama3 es similar a la de Llama2, y el aumento del rendimiento se debe principalmente a la actualización de datos. El modelo viene con un tokenizador actualizado y una ventana de contexto ampliada. Está etiquetado como de código abierto, aunque sólo se publica un pequeño porcentaje de los datos. En general, es un modelo excelente y no puedo esperar para probarlo.
Llama3–70B puede producir resultados sorprendentes, pero debido a su tamaño no es práctico, es prohibitivamente caro y difícil de usar en sistemas locales. Por lo tanto, para aprovechar sus capacidades, hacemos que Llama3–70B le enseñe al Llama3–8B más pequeño la tarea de extracción de relaciones de texto no estructurado.
Específicamente, con la ayuda de Llama3–70B, construimos un conjunto de datos de ajuste supervisado destinado a la extracción de relaciones. Luego utilizamos este conjunto de datos para ajustar Llama3–8B para mejorar sus capacidades de extracción de relaciones.
Para reproducir el código en el Cuaderno de Google Colab asociado a este blog, necesitarás:
- Credenciales de HuggingFace (para guardar el modelo afinado, opcional) y acceso a Llama3, que se puede obtener siguiendo las instrucciones de una de las tarjetas de los modelos;
- un gratis GroqCloud cuenta (puede iniciar sesión con una cuenta de Google) y una clave API correspondiente.
Para este proyecto utilicé un Google Colab Pro equipado con una GPU A100 y una configuración de alta RAM.
Comenzamos instalando todas las bibliotecas necesarias:
!pip install -q groq
!pip install -U accelerate bitsandbytes datasets evaluate
!pip install -U peft transformers trl
Me alegró mucho notar que toda la configuración funcionó desde el principio sin problemas de dependencias ni necesidad de instalación. transformers
de la fuente, a pesar de la novedad del modelo.
También debemos dar acceso a Goggle Colab a la unidad y a los archivos y configurar el directorio de trabajo:
# For Google Colab settings
from google.colab import userdata, drive# This will prompt for authorization
drive.mount('/content/drive')
# Set the working directory
%cd '/content/drive/MyDrive/postedBlogs/llama3RE'
Para aquellos que deseen cargar el modelo en HuggingFace Hub, debemos cargar las credenciales del Hub. En mi caso, estos se almacenan en los secretos de Google Colab, a los que se puede acceder mediante el botón de la izquierda. Este paso es opcional.
# For Hugging Face Hub setting
from huggingface_hub import login# Upload the HuggingFace token (should have WRITE access) from Colab secrets
HF = userdata.get('HF')
# This is needed to upload the model to HuggingFace
login(token=HF,add_to_git_credential=True)
También agregué algunas variables de ruta para simplificar el acceso a los archivos:
# Create a path variable for the data folder
data_path = '/content/drive/MyDrive/postedBlogs/llama3RE/datas/'# Full fine-tuning dataset
sft_dataset_file = f'{data_path}sft_train_data.json'
# Data collected from the the mini-test
mini_data_path = f'{data_path}mini_data.json'
# Test data containing all three outputs
all_tests_data = f'{data_path}all_tests.json'
# The adjusted training dataset
train_data_path = f'{data_path}sft_train_data.json'
# Create a path variable for the SFT model to be saved locally
sft_model_path = '/content/drive/MyDrive/llama3RE/Llama3_RE/'
Ahora que nuestro espacio de trabajo está configurado, podemos pasar al primer paso, que es construir un conjunto de datos sintéticos para la tarea de extracción de relaciones.
Hay varios conjuntos de datos de extracción de relaciones disponibles, siendo el más conocido el CONLL04 conjunto de datos. Además, existen excelentes conjuntos de datos como web_nlgdisponible en HuggingFace, y SciREX desarrollado por AllenAI. Sin embargo, la mayoría de estos conjuntos de datos vienen con licencias restrictivas.
Inspirado en el formato del web_nlg
conjunto de datos construiremos nuestro propio conjunto de datos. Este enfoque será particularmente útil si planeamos ajustar un modelo entrenado en nuestro conjunto de datos. Para empezar, necesitamos una colección de oraciones cortas para nuestra tarea de extracción de relaciones. Podemos compilar este corpus de varias maneras.
Reúna una colección de oraciones
Usaremos databricks-dolly-15k, un conjunto de datos de código abierto generado por empleados de Databricks en 2023. Este conjunto de datos está diseñado para un ajuste supervisado e incluye cuatro características: instrucción, contexto, respuesta y categoría. Después de analizar las ocho categorías, decidí conservar la primera frase del contexto del information_extraction
categoría. Los pasos de análisis de datos se describen a continuación:
from datasets import load_dataset# Load the dataset
dataset = load_dataset("databricks/databricks-dolly-15k")
# Choose the desired category from the dataset
ie_category = (e for e in dataset("train") if e("category")=="information_extraction")
# Retain only the context from each instance
ie_context = (e("context") for e in ie_category)
# Split the text into sentences (at the period) and keep the first sentence
reduced_context = (text.split('.')(0) + '.' for text in ie_context)
# Retain sequences of specified lengths only (use character length)
sampler = (e for e in reduced_context if 30 < len(e) < 170)
El proceso de selección arroja un conjunto de datos que comprende 1.041 frases. Dado que se trata de un miniproyecto, no seleccioné las oraciones y, como resultado, es posible que algunas muestras no sean ideales para nuestra tarea. En un proyecto designado para producción, seleccionaría cuidadosamente solo las oraciones más apropiadas. Sin embargo, para los fines de este proyecto, este conjunto de datos será suficiente.
Formatear los datos
Primero necesitamos crear un mensaje del sistema que definirá el mensaje de entrada e indicará al modelo cómo generar las respuestas:
system_message = """You are an experienced annontator.
Extract all entities and the relations between them from the following text.
Write the answer as a triple entity1|relationship|entitity2.
Do not add anything else.
Example Text: Alice is from France.
Answer: Alice|is from|France.
"""
Dado que se trata de una fase experimental, mantengo las exigencias del modelo al mínimo. Probé varios otros mensajes, incluidos algunos que solicitaban resultados en formato CoNLL donde se clasifican las entidades, y el modelo funcionó bastante bien. Sin embargo, en aras de la simplicidad, por ahora nos ceñiremos a lo básico.
También necesitamos convertir los datos a un formato conversacional:
messages = ((
{"role": "system","content": f"{system_message}"},
{"role": "user", "content": e}) for e in sampler)
El cliente y la API de Groq
Llama3 se lanzó hace apenas unos días y la disponibilidad de opciones de API aún es limitada. Si bien hay una interfaz de chat disponible para Llama3–70B, este proyecto requiere una API que pueda procesar mis 1000 oraciones con un par de líneas de código. Encontré esto excelente Video de Youtube que explica cómo utilizar la API de GroqCloud de forma gratuita. Para obtener más detalles, consulte el vídeo.
Solo un recordatorio: deberá iniciar sesión y recuperar una clave API gratuita del GroqCloud sitio web. Mi clave API ya está guardada en los secretos de Google Colab. Empezamos inicializando el cliente Groq:
import os
from groq import Groqgclient = Groq(
api_key=userdata.get("GROQ"),
)
A continuación necesitamos definir un par de funciones auxiliares que nos permitirán interactuar con el ai/” rel=”noopener ugc nofollow” target=”_blank”>Meta.ai interfaz de chat de manera efectiva (estos están adaptados del Video de Youtube):
import time
from tqdm import tqdmdef process_data(prompt):
"""Send one request and retrieve model's generation."""
chat_completion = gclient.chat.completions.create(
messages=prompt, # input prompt to send to the model
model="llama3-70b-8192", # according to GroqCloud labeling
temperature=0.5, # controls diversity
max_tokens=128, # max number tokens to generate
top_p=1, # proportion of likelihood weighted options to consider
stop=None, # string that signals to stop generating
stream=False, # if set partial messages are sent
)
return chat_completion.choices(0).message.content
def send_messages(messages):
"""Process messages in batches with a pause between batches."""
batch_size = 10
answers = ()
for i in tqdm(range(0, len(messages), batch_size)): # batches of size 10
batch = messages(i:i+10) # get the next batch of messages
for message in batch:
output = process_data(message)
answers.append(output)
if i + 10 < len(messages): # check if there are batches left
time.sleep(10) # wait for 10 seconds
return answers
La primera función process_data()
Sirve como contenedor para la función de finalización del chat del cliente Groq. La segunda función send_messages()
, procesa los datos en pequeños lotes. Si sigue el enlace Configuración en la página del área de juegos de Groq, encontrará un enlace para Límites que detalla las condiciones bajo las cuales podemos usar la API gratuita, incluidos límites en la cantidad de solicitudes y tokens generados. Para evitar exceder estos límites, agregué un retraso de 10 segundos después de cada lote de 10 mensajes, aunque en mi caso no era estrictamente necesario. Quizás quieras experimentar con estas configuraciones.
Lo que queda ahora es generar nuestros datos de extracción de relaciones e integrarlos con el conjunto de datos inicial:
# Data generation with Llama3-70B
answers = send_messages(messages)# Combine input data with the generated dataset
combined_dataset = ({'text': user, 'gold_re': output} for user, output in zip(sampler, answers))
Antes de continuar con el ajuste del modelo, es importante evaluar su rendimiento en varias muestras para determinar si realmente es necesario realizar un ajuste.
Construyendo un conjunto de datos de prueba
Seleccionaremos 20 muestras del conjunto de datos que acabamos de construir y las dejaremos a un lado para realizar pruebas. El resto del conjunto de datos se utilizará para realizar ajustes.
import random
random.seed(17)# Select 20 random entries
mini_data = random.sample(combined_dataset, 20)
# Build conversational format
parsed_mini_data = (({'role': 'system', 'content': system_message},
{'role': 'user', 'content': e('text')}) for e in mini_data)
# Create the training set
train_data = (item for item in combined_dataset if item not in mini_data)
Usaremos la API de GroqCloud y las utilidades definidas anteriormente, especificando model=llama3-8b-8192
mientras que el resto de la función permanece sin cambios. En este caso, podemos procesar directamente nuestro pequeño conjunto de datos sin preocuparnos de exceder los límites de la API.
A continuación se muestra un resultado de muestra que proporciona el original. text
la generación Llama3-70B denota gold_re
y la generación Llama3-8B etiquetada test_re
.
{'text': 'Long before any knowledge of electricity existed, people were aware of shocks from electric fish.',
'gold_re': 'people|were aware of|shocks\nshocks|from|electric fish\nelectric fish|had|electricity',
'test_re': 'electric fish|were aware of|shocks'}
Para obtener el conjunto de datos de prueba completo, consulte la Cuaderno de Google Colab.
Solo con este ejemplo, queda claro que Llama3–8B podría beneficiarse de algunas mejoras en sus capacidades de extracción de relaciones. Trabajemos para mejorar eso.
Utilizaremos un arsenal completo de técnicas para ayudarnos, incluidas QLoRA y Flash Attention. No profundizaré en los detalles de la elección de hiperparámetros aquí, pero si está interesado en explorar más a fondo, consulte estas excelentes referencias (4) y (5).
La GPU A100 admite Flash Attention y bfloat16, y posee alrededor de 40 GB de memoria, que es suficiente para nuestras necesidades de ajuste.
Preparación del conjunto de datos SFT
Comenzamos analizando el conjunto de datos en un formato conversacional, que incluye un mensaje del sistema, texto de entrada y la respuesta deseada, que derivamos de la generación Llama3–70B. Luego lo guardamos como un conjunto de datos de HuggingFace:
def create_conversation(sample):
return {
"messages": (
{"role": "system","content": system_message},
{"role": "user", "content": sample("text")},
{"role": "assistant", "content": sample("gold_re")}
)
}from datasets import load_dataset, Dataset
train_dataset = Dataset.from_list(train_data)
# Transform to conversational format
train_dataset = train_dataset.map(create_conversation,
remove_columns=train_dataset.features,
batched=False)
Elige el modelo
model_id = "meta-llama/Meta-Llama-3-8B"
Cargar el tokenizador
from transformers import AutoTokenizer# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id,
use_fast=True,
trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
# Set a maximum length
tokenizer.model_max_length = 512
Elija los parámetros de cuantificación
from transformers import BitsAndBytesConfigbnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
Cargar el modelo
from transformers import AutoModelForCausalLM
from peft import prepare_model_for_kbit_training
from trl import setup_chat_formatdevice_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
attn_implementation="flash_attention_2",
quantization_config=bnb_config
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)
Configuración LoRA
from peft import LoraConfig# According to Sebastian Raschka findings
peft_config = LoraConfig(
lora_alpha=128, #32
lora_dropout=0.05,
r=256, #16
bias="none",
target_modules=("q_proj", "o_proj", "gate_proj", "up_proj",
"down_proj", "k_proj", "v_proj"),
task_type="CAUSAL_LM",
)
Los mejores resultados se logran al apuntar a todas las capas lineales. Si le preocupan las limitaciones de memoria, puede resultar beneficioso optar por valores más estándar, como alfa=32 y rango=16, ya que estas configuraciones dan como resultado muchos menos parámetros.
Argumentos de entrenamiento
from transformers import TrainingArguments# Adapted from Phil Schmid blogpost
args = TrainingArguments(
output_dir=sft_model_path, # directory to save the model and repository id
num_train_epochs=2, # number of training epochs
per_device_train_batch_size=4, # batch size per device during training
gradient_accumulation_steps=2, # number of steps before performing a backward/update pass
gradient_checkpointing=True, # use gradient checkpointing to save memory, use in distributed training
optim="adamw_8bit", # choose paged_adamw_8bit if not enough memory
logging_steps=10, # log every 10 steps
save_strategy="epoch", # save checkpoint every epoch
learning_rate=2e-4, # learning rate, based on QLoRA paper
bf16=True, # use bfloat16 precision
tf32=True, # use tf32 precision
max_grad_norm=0.3, # max gradient norm based on QLoRA paper
warmup_ratio=0.03, # warmup ratio based on QLoRA paper
lr_scheduler_type="constant", # use constant learning rate scheduler
push_to_hub=True, # push model to Hugging Face hub
hub_model_id="llama3-8b-sft-qlora-re",
report_to="tensorboard", # report metrics to tensorboard
)
Si elige guardar el modelo localmente, puede omitir los últimos tres parámetros. También es posible que necesite ajustar el per_device_batch_size
y gradient_accumulation_steps
para evitar errores de falta de memoria (OOM).
Inicializar el entrenador y entrenar el modelo
from trl import SFTTrainertrainer = SFTTrainer(
model=model,
args=args,
train_dataset=sft_dataset,
peft_config=peft_config,
max_seq_length=512,
tokenizer=tokenizer,
packing=False, # True if the dataset is large
dataset_kwargs={
"add_special_tokens": False, # the template adds the special tokens
"append_concat_token": False, # no need to add additional separator token
}
)
trainer.train()
trainer.save_model()
La capacitación, incluido el guardado del modelo, duró unos 10 minutos.
Borremos la memoria para prepararnos para las pruebas de inferencia. Si está utilizando una GPU con menos memoria y encuentra errores de CUDA sin memoria (OOM), es posible que deba reiniciar el tiempo de ejecución.
import torch
import gc
del model
del tokenizer
gc.collect()
torch.cuda.empty_cache()
En este paso final cargaremos el modelo base a media precisión junto con el adaptador Peft. Para esta prueba, he optado por no fusionar el modelo con el adaptador.
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline
import torch# HF model
peft_model_id = "solanaO/llama3-8b-sft-qlora-re"
# Load Model with PEFT adapter
model = AutoPeftModelForCausalLM.from_pretrained(
peft_model_id,
device_map="auto",
torch_dtype=torch.float16,
offload_buffers=True
)
A continuación, cargamos el tokenizador:
okenizer = AutoTokenizer.from_pretrained(peft_model_id)tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
Y construimos el canal de generación de texto:
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
Cargamos el conjunto de datos de prueba, que consta de las 20 muestras que reservamos anteriormente, y formateamos los datos en un estilo conversacional. Sin embargo, esta vez omitimos el mensaje del asistente y lo formateamos como un conjunto de datos de Hugging Face:
def create_input_prompt(sample):
return {
"messages": (
{"role": "system","content": system_message},
{"role": "user", "content": sample("text")},
)
}from datasets import Dataset
test_dataset = Dataset.from_list(mini_data)
# Transform to conversational format
test_dataset = test_dataset.map(create_input_prompt,
remove_columns=test_dataset.features,
batched=False)
Una prueba de muestra
Generemos resultados de extracción de relaciones usando SFT Llama3–8B y compárelos con los dos resultados anteriores en una sola instancia:
Generate the input prompt
prompt = pipe.tokenizer.apply_chat_template(test_dataset(2)("messages")(:2),
tokenize=False,
add_generation_prompt=True)
# Generate the output
outputs = pipe(prompt,
max_new_tokens=128,
do_sample=False,
temperature=0.1,
top_k=50,
top_p=0.1,
)
# Display the results
print(f"Question: {test_dataset(2)('messages')(1)('content')}\n")
print(f"Gold-RE: {test_sampler(2)('gold_re')}\n")
print(f"LLama3-8B-RE: {test_sampler(2)('test_re')}\n")
print(f"SFT-Llama3-8B-RE: {outputs(0)('generated_text')(len(prompt):).strip()}")
Obtenemos lo siguiente:
Question: Long before any knowledge of electricity existed, people were aware of shocks from electric fish.Gold-RE: people|were aware of|shocks
shocks|from|electric fish
electric fish|had|electricity
LLama3-8B-RE: electric fish|were aware of|shocks
SFT-Llama3-8B-RE: people|were aware of|shocks
shocks|from|electric fish
En este ejemplo, observamos mejoras significativas en las capacidades de extracción de relaciones de Llama3–8B mediante ajustes. A pesar de que el conjunto de datos de ajuste no es muy limpio ni particularmente grande, los resultados son impresionantes.
Para obtener los resultados completos del conjunto de datos de 20 muestras, consulte el Cuaderno de Google Colab. Tenga en cuenta que la prueba de inferencia lleva más tiempo porque cargamos el modelo con media precisión.
En conclusión, al utilizar Llama3–70B y un conjunto de datos disponible, creamos con éxito un conjunto de datos sintéticos que luego se utilizó para ajustar Llama3–8B para una tarea específica. Este proceso no sólo nos familiarizó con Llama3, sino que también nos permitió aplicar técnicas sencillas de Hugging Face. Observamos que trabajar con Llama3 se parece mucho a la experiencia con Llama2, siendo las mejoras notables una mayor calidad de salida y un tokenizador más efectivo.
Para aquellos interesados en ampliar aún más los límites, considere desafiar el modelo con tareas más complejas, como categorizar entidades y relaciones, y utilizar estas clasificaciones para crear un gráfico de conocimiento.
- Somin Wadhwa, Silvio Amir, Byron C. Wallace, Revisando la extracción de relaciones en la era de los grandes modelos de lenguaje, arXiv.2305.05003 (2023).
- Meta, Presentamos Meta Llama 3: el LLM disponible abiertamente más capaz hasta la fecha, 18 de abril de 2024 (ai.meta.com/blog/meta-llama-3/” rel=”noopener ugc nofollow” target=”_blank”>enlace).
- Philipp Schmid, Omar Sanseviero, Pedro Cuenca, Youndes Belkada, Leandro von Werra, Bienvenido Llama 3: el nuevo LLM abierto de Met, 18 de abril de 2024.
- Sebastián Raschka, Consejos prácticos para perfeccionar los LLM utilizando LoRA (adaptación de bajo rango)Ahead of ai, 19 de noviembre de 2023.
- Philipp Schmid, Cómo perfeccionar los LLM en 2024 con Hugging Face, 22 de enero de 2024.
databricks-dolly-15K en la plataforma Hugging Face (CC BY-SA 3.0)