In this post I will explain and demonstrate the concept of "structured generative AI": generative AI constrained to defined formats. By the end of the post, you will understand where and when it can be used and how to implement it, whether you are building a transformer model from scratch or using Hugging Face models. Additionally, we'll cover an important tip for tokenization that is especially relevant for structured languages.
One of the many uses of generative AI is as a translation tool. This often involves translating between two human languages, but it may also include computer languages or formats. For example, your application may need to translate natural (human) language to SQL:
Natural language: "Get customer names and emails of customers from the US"
SQL: "SELECT name, email FROM customers WHERE country = 'USA'"
Or to convert text data to JSON format:
Natural language: "I am John Doe, my phone number is 555-123-4567, and my friends are Anna and Sara"
JSON: {"name": "John Doe",
       "phone_number": "555-123-4567",
       "friends": ["Anna", "Sara"]}
Naturally, many more applications are possible for other structured languages. The training process for such tasks involves feeding natural-language examples along with their structured counterparts to an encoder-decoder model. Alternatively, it may be sufficient to leverage a pre-trained large language model (LLM).
While achieving 100% accuracy is unattainable, there is one class of errors we can eliminate entirely: syntax errors. These are violations of the language's format, such as replacing commas with periods, using table names that are not present in the SQL schema, or omitting closing brackets, all of which render the SQL or JSON unexecutable.
The fact that we are translating into a structured language means that the list of legitimate tokens at each generation step is limited and predetermined. If we could insert this knowledge into the generation process, we could avoid a wide range of incorrect outputs. This is the idea behind structured generative AI: constrain generation to a list of legitimate tokens.
A quick reminder on how tokens are generated
Whether an encoder-decoder or GPT architecture is used, token generation operates sequentially. The selection of each token depends on both the input and the previously generated tokens, and continues until an end-of-sequence token (such as </s>) is generated, which signals the completion of the sequence. At each step, a classifier assigns logit values to all tokens in the vocabulary, representing the probability of each token as the next selection. The next token is sampled based on those logits.
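For intuition, here is a simplified sketch of such a decoding loop. It assumes a Hugging Face causal (GPT-style) model that returns logits; an encoder-decoder model works the same way on the decoder side, and real generation code adds caching, batching and more elaborate sampling strategies.
import torch

def sample_tokens(model, input_ids, eos_token_id, max_new_tokens=20):
    # simplified decoding loop: `model` is assumed to be a causal LM returning logits
    generated = input_ids
    for _ in range(max_new_tokens):
        # logits for every token in the vocabulary, at the next position
        logits = model(generated).logits[:, -1, :]            # shape: (1, vocab_size)
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # sample the next token
        generated = torch.cat([generated, next_token], dim=-1)
        if next_token.item() == eos_token_id:
            break  # the end-of-sequence token completes the generation
    return generated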
Limit token generation
To limit token generation, we incorporate knowledge of the structure of the output language. Illegitimate tokens have their logits set to -inf, ensuring their exclusion from selection. For example, if only a comma or "FROM" is valid after "SELECT name", all other token logits are set to -inf.
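As a toy illustration of the masking step (the five-token vocabulary and the allowed IDs below are made up for this example), setting the logits of illegitimate tokens to -inf drives their probability to exactly zero after the softmax:
import torch

logits = torch.tensor([[2.0, 0.5, 1.2, 0.1, -0.3]])  # toy vocabulary of 5 tokens
allowed_ids = [1, 3]                                  # e.g. only "," and "FROM" are legitimate here

masked = torch.full_like(logits, float('-inf'))
masked[:, allowed_ids] = logits[:, allowed_ids]

print(torch.softmax(masked, dim=-1))
# all probability mass falls on tokens 1 and 3; every other token gets exactly 0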
If you are using Hugging Face, this can be implemented using a "logits processor". To use one, you implement a class with a __call__ method, which is called after the logits are calculated but before sampling. This method receives the input IDs and the logits of all tokens, and returns modified logits for all tokens.
I will demonstrate the code with a simplified example. First, we initialize the model; we will use BART in this case, but this can work with any model.
from transformers import BartForConditionalGeneration, BartTokenizerFast, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessorList, LogitsProcessor
import torch

name = 'facebook/bart-large'
tokenizer = BartTokenizerFast.from_pretrained(name, add_prefix_space=True)
pretrained_model = BartForConditionalGeneration.from_pretrained(name)
If we want to generate a translation from natural language to SQL, we can execute:
to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True)

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))
Returning:
'More emails from the us'
Since we did not fine-tune the model for text-to-SQL tasks, the result does not look like SQL. We will not train the model in this tutorial, but we will guide it toward generating an SQL query. We will achieve this by employing a function that maps each generated token to a list of allowed next tokens. For simplicity, we will consider only the last generated token, but more complicated mechanisms are easy to implement. We will use a dictionary that defines, for each token, which tokens are allowed to follow it. For example, the query must start with "SELECT" or "DELETE", and after "SELECT" only "name", "email" or "id" are allowed, since those are the columns in our schema.
rules = {'<s>': ['SELECT', 'DELETE'],  # beginning of the generation
         'SELECT': ['name', 'email', 'id'],  # names of columns in our schema
         'DELETE': ['name', 'email', 'id'],
         'name': [',', 'FROM'],
         'email': [',', 'FROM'],
         'id': [',', 'FROM'],
         ',': ['name', 'email', 'id'],
         'FROM': ['customers', 'vendors'],  # names of tables in our schema
         'customers': ['</s>'],
         'vendors': ['</s>'],  # end of the generation
         }
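To build intuition for what this dictionary permits, here is a small helper, separate from the generation code itself, that walks a token sequence through the rules:
def is_valid_sequence(tokens, rules):
    # check that every token is allowed to follow the one before it
    previous = '<s>'
    for token in tokens:
        if token not in rules.get(previous, []):
            return False
        previous = token
    return True

print(is_valid_sequence(['SELECT', 'name', ',', 'email', 'FROM', 'customers', '</s>'], rules))  # True
print(is_valid_sequence(['SELECT', 'FROM', 'customers', '</s>'], rules))  # False: a column must follow SELECT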
Now we need to convert these tokens to the IDs used by the model. This will happen inside a class inherited from LogitsProcessor.
def convert_token_to_id(token):
    return tokenizer(token, add_special_tokens=False)['input_ids'][0]


class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v]
                      for k, v in rules.items()}
Finally, we will implement the __call__ function, which is called after the logits are calculated. The function creates a new tensor filled with -inf, checks which IDs are legitimate according to the rules dictionary, and copies their scores into the new tensor. The result is a tensor that only has valid values for the legitimate tokens.
class SQLLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.rules = {convert_token_to_id(k): [convert_token_to_id(v0) for v0 in v]
                      for k, v in rules.items()}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        if not (input_ids == self.tokenizer.bos_token_id).any():
            # we must allow the start token to be generated before we start enforcing rules
            return scores
        # create a new tensor of -inf
        new_scores = torch.full((1, self.tokenizer.vocab_size), float('-inf'))
        # ids of the legitimate next tokens, given the last generated token
        legit_ids = self.rules[int(input_ids[0, -1])]
        # place their original scores in the new tensor
        new_scores[:, legit_ids] = scores[0, legit_ids]
        return new_scores
And that's it! Now we can run the generation with the logits processor:
to_translate = 'customers emails from the us'
words = to_translate.split()
tokenized_text = tokenizer([words], is_split_into_words=True, return_offsets_mapping=True)

logits_processor = LogitsProcessorList([SQLLogitsProcessor(tokenizer)])

out = pretrained_model.generate(
    torch.tensor(tokenized_text["input_ids"]),
    max_new_tokens=20,
    logits_processor=logits_processor
)
print(tokenizer.convert_tokens_to_string(
    tokenizer.convert_ids_to_tokens(
        out[0], skip_special_tokens=True)))
Returning:
SELECT email , email , id , email FROM customers
The result looks a bit strange, but remember: we didn't even train the model! We only constrained token generation according to specific rules. Notably, restricting generation does not interfere with training; the restrictions only apply during generation, after training. Therefore, when implemented properly, these constraints can only improve the generation accuracy.
Our simplistic implementation falls short of covering all SQL syntax. A real implementation should support more of the syntax, potentially consider not just the last token but several (as sketched below), and allow batch generation. Once these improvements are implemented, our trained model can reliably generate executable SQL queries, restricted to the valid table and column names of the schema. A similar approach can impose restrictions on JSON generation, ensuring the presence of required keys and the closing of brackets.
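As a hedged sketch of one possible extension (not part of the implementation above), the processor below keys the rules on the last two generated token IDs and supports batched generation; the pair_rules dictionary and the fallback behaviour are assumptions for illustration only.
class PairRuleLogitsProcessor(LogitsProcessor):
    def __init__(self, tokenizer: PreTrainedTokenizer, pair_rules: dict):
        # pair_rules maps a tuple of the last two token IDs to the allowed next IDs (illustrative)
        self.tokenizer = tokenizer
        self.pair_rules = pair_rules

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        new_scores = torch.full_like(scores, float('-inf'))
        for row, sequence in enumerate(input_ids):
            if sequence.shape[0] < 2:
                # not enough context yet: leave this row unconstrained
                new_scores[row] = scores[row]
                continue
            key = (int(sequence[-2]), int(sequence[-1]))
            legit_ids = self.pair_rules.get(key)
            if legit_ids is None:
                # unknown context: fall back to the unmodified scores
                new_scores[row] = scores[row]
            else:
                new_scores[row, legit_ids] = scores[row, legit_ids]
        return new_scores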
Be careful with tokenization
Tokenization is often overlooked, but correct tokenization is crucial when using generative AI for structured output. Under the hood, tokenization can have a real impact on your model's training. For example, suppose you fine-tune a model to translate text to JSON. As part of the fine-tuning process, you provide the model with text-JSON pairs, which it tokenizes. What will this tokenization look like?
While you read "[[" as two separate square brackets, the tokenizer converts them into a single ID, which the token classifier will treat as a completely different class from the single square bracket. This complicates the logic the model must learn (for example, remembering how many square brackets to close). Similarly, adding a space before a word can change its tokenization and thus its class ID. Again, this complicates the logic the model has to learn, since the weights connected to each of these IDs must be learned separately, for slightly different cases.
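You can check this with the tokenizer we loaded earlier. The exact tokens depend on the vocabulary, but with BPE tokenizers such as BART's the pattern below is typical:
# the same characters tokenize very differently with and without spaces
print(tokenizer.tokenize('[["Anna"]]'))      # '[[' is likely merged into a single token
print(tokenizer.tokenize('[ [ "Anna" ] ]'))  # each bracket gets its own, consistent token
print(tokenizer.tokenize('name'), tokenizer.tokenize(' name'))  # a leading space changes the token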
For easier learning, make sure each concept and punctuation mark is consistently converted to the same token, by adding spaces before words and characters.
Introducing spaced examples during fine-tuning simplifies the patterns the model has to learn, which improves the model's accuracy. During prediction, the model will generate the JSON with spaces, which you can then remove before parsing.
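A minimal sketch of such pre- and post-processing, assuming simple regex rules (a real implementation must be careful not to touch spaces inside string values):
import json
import re

def add_spaces(json_text: str) -> str:
    # surround structural punctuation with spaces so each symbol becomes its own token
    return re.sub(r'([{}\[\],:])', r' \1 ', json_text)

def remove_spaces(spaced_json: str) -> str:
    # collapse the extra spaces again before parsing the generated output
    return re.sub(r'\s*([{}\[\],:])\s*', r'\1', spaced_json)

example = '{"friends": ["Anna", "Sara"]}'
spaced = add_spaces(example)
print(spaced)  # every bracket, comma and colon is now surrounded by spaces
print(json.loads(remove_spaces(spaced)) == json.loads(example))  # True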
Summary
Generative AI offers a valuable approach for translating into a formatted language. By leveraging knowledge of the output structure, we can constrain the generative process, eliminating a whole class of errors and ensuring the executability of queries and the parsability of data structures.
Additionally, these formats may use punctuation and keywords to indicate certain meanings. Ensuring that the tokenization of these keywords is consistent can dramatically reduce the complexity of the patterns the model has to learn, thereby reducing the required size of the model and its training time, while increasing its accuracy.
Structured generative AI can effectively translate natural language into any structured format. These translations enable extracting information from text or generating queries, which is a powerful tool for many applications.