A step-by-step guide to building a Thai multilingual subword tokenizer based on a BPE algorithm trained on Thai and English datasets using only Python
The main task of the tokenizer is to translate the raw input text (Thai in our case, but it can be any other language) into numbers and pass them to the transformer model. The transformer model then generates its output as numbers, and the tokenizer translates those numbers back into text that end users can understand. The high-level diagram below describes this flow.
Generally, many of us are only interested in learning how the transformer architecture works in depth, and we often overlook the detailed study of important components such as the tokenizer. Understanding how the tokenizer works in depth, and having a good handle on its functionality, gives us a good head start in improving our model's accuracy and performance.
Along with the tokenizer, some of the most important parts of the LLM implementation process are data preprocessing, evaluation, guardrails, and testing and monitoring. I highly recommend studying these topics in more detail. I realized the importance of these components only after working on the actual production implementation of my basic multilingual ThaiLLM model.
Why do you need a Thai or any other foreign language tokenizer?
- Let's say you use a generic English-based tokenizer to pre-train a multilingual large language model for a language such as Thai, Hindi, Indonesian, Arabic, Chinese, etc. In that case, your model is unlikely to produce output that makes sense for your specific domain or use cases. Creating your own tokenizer in the language of your choice certainly helps make the output of your model much more coherent and understandable.
- Creating your own tokenizer also gives you complete control over how comprehensive and inclusive the vocabulary is. With a more comprehensive vocabulary, each token covers more text, so during the attention mechanism the model can attend to and learn from more content within the limited context length of the sequence. This makes the learning more coherent, which ultimately helps the model produce better inferences.
The good news is that once you've finished creating the Thai tokenizer, you can easily create a tokenizer in any other language. All the creation steps are the same, except that you'll train on a dataset in the language of your choice.
Now that we have good reasons to create our own tokenizer, below are the steps to create our Thai language tokenizer.
- Let's build our own BPE algorithm
- Training the tokenizer
- Tokenizer encoding and decoding function
- Load and test the tokenizer
Step 1: Build our own BPE (byte pair encoding) algorithm:
The BPE algorithm is used in many popular LLMs, such as Llama and GPT, to build their tokenizers. We could reuse one of those tokenizers if our model were based on English, but since we are creating a Thai tokenizer, the best option is to build our own BPE algorithm from scratch and use it to create our tokenizer. Let's first understand how the BPE algorithm works with the help of the simple flowchart below, and then we'll start implementing it accordingly.
The examples in the flowchart are shown in English for easy understanding.
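To make the idea concrete before we write the real implementation, here is a tiny, self-contained illustration of a single merge step at the character level (the example string is made up purely for illustration):
from collections import Counter
# Toy example: count adjacent character pairs in a short English string.
word = "lowlowlower"
pairs = Counter(zip(word, word[1:]))
best = max(pairs, key=pairs.get) # the most frequent adjacent pair
print(best, pairs[best]) # ('l', 'o') 3
# A merge replaces every occurrence of that pair with a single new symbol.
# Here we simply uppercase the merged pair to make the replacement visible.
merged = word.replace("".join(best), "".join(best).upper())
print(merged) # LOwLOwLOwer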
Let's write code to implement the BPE algorithm for our Thai tokenizer.
# A simple practice example to get familiarization with utf-8 encoding to convert strings to bytes.
text = "How are you คุณเป็นอย่างไร" # Text string in both English and Thai
text_bytes = text.encode("utf-8")
print(f"Text in byte: {text_bytes}")text_list = list(text_bytes) # Converts text bytes to a list of integer
print(f"Text list in integer: {text_list}")
# As I don't want to reinvent the wheel, I will be referencing most of the code block from Andrej Karpathy's GitHub (https://github.com/karpathy/minbpe?tab=readme-ov-file).
# However, I'll be modifying the code blocks that are specific to building our Thai language tokenizer, and I'll also explain the code so that you understand how each block works and can easily adapt it for your own use case later.

# This module provides access to the Unicode Character Database (UCD), which defines character properties for all Unicode characters.
import unicodedata
# This function returns a dictionary with consecutive pairs of integers and their counts in the given list of integers.
def get_stats(ids, stats=None):
stats = {} if stats is None else stats
# The zip function lets us iterate over consecutive pairs of items from the two lists.
for pair in zip(ids, ids[1:]):
# If a pair already exists in the stats dictionary, add 1 to its count; otherwise start its count at 1.
stats[pair] = stats.get(pair, 0) + 1
return stats
# Once we find out the list of consecutive pairs of integers, we'll then replace those pairs with new integer tokens.
def merge(ids, pair, idx):
newids = []
i = 0
# Since we'll be merging pairs of ids, the ids list should contain at least 2 elements.
while i < len(ids):
# If the current id and the next id (i+1) match the given pair, and the current id is not the last in the list, replace the two consecutive ids with the given index value.
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
i += 2 # If the pair is matched, the next iteration starts after 2 positions in the list.
else:
newids.append(ids[i])
i += 1 # Since the current pair didn't match, move to the next position in the list.
# Returns the Merged Ids list
return newids
# This function checks each character using 'unicodedata.category', which returns a category starting with "C" for control characters; those need to be replaced with a readable representation.
def replace_control_characters(s: str) -> str:
chars = []
for ch in s:
# If the character is not a control character (its category doesn't start with "C"), append it to the chars list as-is.
if unicodedata.category(ch)[0] != "C":
chars.append(ch)
# If the character is a control character (its category starts with "C"), replace it with its escaped code point and append that to the chars list.
else:
chars.append(f"\\u{ord(ch):04x}")
return "".join(chars)
# Some tokens, such as control characters (e.g. escape characters), can't be decoded into valid strings.
# Hence, they need to be replaced with a readable character such as �.
def render_token(t: bytes) -> str:
s = t.decode('utf-8', errors='replace')
s = replace_control_characters(s)
return s
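Before moving on, here is a quick sanity check, using a made-up list of ids, that shows get_stats and merge performing a single BPE merge step:
# Quick sanity check on a toy list of integer ids (the values are arbitrary).
ids = [10, 20, 20, 30, 20, 30]
stats = get_stats(ids) # {(10, 20): 1, (20, 20): 1, (20, 30): 2, (30, 20): 1}
top_pair = max(stats, key=stats.get) # (20, 30) appears most often
new_ids = merge(ids, top_pair, 256) # replace every (20, 30) with the new token id 256
print(top_pair, new_ids) # (20, 30) [10, 20, 256, 256]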
The two functions get_stats and merge defined in the code blocks above are the core of the BPE algorithm for our Thai tokenizer. Now that the algorithm is ready, let's write the code to train our tokenizer.
Step 2: Train the tokenizer:
Training the tokenizer involves generating a vocabulary, which is a database of unique tokens (words and subwords), each assigned a unique index number. We will use the Thai Wiki dataset from Hugging Face to train our Thai tokenizer. Just as training an LLM requires a lot of data, you also need a good amount of data to train a tokenizer. You could use the same dataset to train both the LLM and the tokenizer, although it is not mandatory. For a multilingual LLM, it is advisable to use English and Thai datasets in roughly a 2:1 ratio, which is a standard approach followed by many practitioners.
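As a rough illustration of that ratio, here is a minimal sketch of how you might assemble a combined training text; the file names english_corpus.txt and thai_corpus.txt are placeholders for whatever datasets you actually use:
# A rough sketch of assembling an approximately 2:1 English:Thai training corpus.
# The file names below are placeholders; substitute your own datasets.
english_text = open("english_corpus.txt", "r", encoding="utf-8").read()
thai_text = open("thai_corpus.txt", "r", encoding="utf-8").read()
# Trim the English text so it is roughly twice the size of the Thai text (by character count).
english_text = english_text[:2 * len(thai_text)]
# Concatenate into a single training string for the tokenizer.
training_text = english_text + "\n" + thai_text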
Let's start writing the training code.
# Import Regular Expression
import regex as re

# Create a Thai Tokenizer class.
class ThaiTokenizer():
def __init__(self):
# The byte pair should be done within the related words or sentences that give a proper context. Pairing between unrelated words or sentences may give undesirable output.
# To prevent this behavior, we'll implement the LLama 3 regular expression pattern to make meaningful chunks of our text before implementing the byte pair algorithm.
self.pattern = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|(^\r\n\p{L}\p{N})?\p{L}+|\p{N}{1,3}| ?(^\s\p{L}\p{N})+(\r\n)*|\s*(\r\n)+|\s+(?!\S)|\s+"
self.compiled_pattern = re.compile(self.pattern)
# Special tokens are used to provide coherence in the sequence while training.
# Special tokens are assigned a unique index number and stored in vocabulary.
self.special_tokens = {
'<|begin_of_text|>': 1101,
'<|end_of_text|>': 1102,
'<|start_header_id|>': 1103,
'<|end_header_id|>': 1104,
'<|eot_id|>': 1105
}
# Initialize merges with empty dictionary
self.merges = {}
# Initialize the vocab dictionary by calling the function _build_vocab which is defined later in this class.
self.vocab = self._build_vocab()
# Tokenizer training function
def train(self, text, vocab_size):
# Make sure the vocab size is at least 256, since the base vocabulary reserves one token for each possible byte value (0-255).
assert vocab_size >= 256
# Total number of merges into the vocabulary.
num_merges = vocab_size - 256
# The first step is to make sure to split the text up into text chunks using the pattern defined above.
text_chunks = re.findall(self.compiled_pattern, text)
# Each text_chunks will be utf-8 encoded to bytes and then converted into an integer list.
ids = [list(ch.encode("utf-8")) for ch in text_chunks]
# Iteratively merge the most common pairs to create new tokens
merges = {} # (int, int) -> int
vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
# Until the total num_merges is reached, find the common pair of consecutive id in the ids list and start merging them to create a new token
for i in range(num_merges):
# Count the number of times every consecutive pair appears
stats = {}
for chunk_ids in ids:
# Passing in stats will update it in place, adding up counts
get_stats(chunk_ids, stats)
# Find the pair with the highest count
pair = max(stats, key=stats.get)
# Mint a new token: assign it the next available id
idx = 256 + i
# Replace all occurrences of pair in ids with idx
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
# Save the merge
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
# Save class variables to be used later during tokenizer encode and decode
self.merges = merges
self.vocab = vocab
# Function to return a vocab dictionary that combines the base bytes, merges and special tokens
def _build_vocab(self):
# The base vocabulary maps each possible byte value (0-255) to its own single-byte token.
vocab = {idx: bytes([idx]) for idx in range(256)}
# Iterate through merge dictionary and add into vocab dictionary
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
# Iterate through special token dictionary and add into vocab dictionary
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
# After training is complete, use the save function to save the model file and vocab file.
# The model file will be used to load the tokenizer for later use in the LLM
# Vocab file is just for the purpose of human verification
def save(self, file_prefix):
# Writing to model file
model_file = file_prefix + ".model" # model file name
# Model write begins
with open(model_file, 'w') as f:
f.write("thai tokenizer v1.0\n") # write the tokenizer version
f.write(f"{self.pattern}\n") # write the pattern used in tokenizer
f.write(f"{len(self.special_tokens)}\n") # write the length of special tokens
# Write each special token in the specific format like below
for tokens, idx in self.special_tokens.items():
f.write(f"{tokens} {idx}\n")
# Write only the keys part from the merges dict
for idx1, idx2 in self.merges:
f.write(f"{idx1} {idx2}\n")
# Writing to the vocab file
vocab_file = file_prefix + ".vocab" # vocab file name
# Change the position of keys and values of merge dict and store into inverted_merges
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
# Vocab write begins
with open(vocab_file, "w", encoding="utf-8") as f:
for idx, token in self.vocab.items():
# render_token function processes tokens and prevents distorted bytes by replacing them with readable character
s = render_token(token)
# If the index of vocab is present in merge dict, then find its child index, convert their corresponding bytes in vocab dict and write the characters
if idx in inverted_merges:
idx0, idx1 = inverted_merges[idx]
s0 = render_token(self.vocab[idx0])
s1 = render_token(self.vocab[idx1])
f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
# If the index of vocab is not present in the merge dict, just write its index and the corresponding string
else:
f.write(f"[{s}] {idx}\n")
# Function to load tokenizer model.
# This function is invoked only after the training is complete and the tokenizer model file is saved.
def load(self, model_file):
merges = {} # Initialize merges with an empty dict
special_tokens = {} # Initialize special_tokens with an empty dict
idx = 256 # The range 0-255 is already reserved in the vocab, so new merge indexes start from 256 onwards.
# Read model file
with open(model_file, 'r', encoding="utf-8") as f:
version = f.readline().strip() # Read the tokenizer version as defined during model file writing
self.pattern = f.readline().strip() # Read the pattern used in tokenizer
num_special = int(f.readline().strip()) # Read the length of special tokens
# Read all the special tokens and store in special_tokens dict defined earlier
for _ in range(num_special):
special, special_idx = f.readline().strip().split()
special_tokens[special] = int(special_idx)
# Read all the merge indexes from the file, form a key pair from each line, and store it in the merges dictionary defined earlier.
# The value for each key pair starts at idx (256, as defined above) and keeps increasing by 1.
for line in f:
idx1, idx2 = map(int, line.split())
merges[(idx1, idx2)] = idx
idx += 1
self.merges = merges
self.special_tokens = special_tokens
# Create a final vocabulary dictionary by combining merge, special_token and vocab (0-255). _build_vocab function helps to do just that.
self.vocab = self._build_vocab()
Step 3: Tokenizer encoding and decoding function:
- Tokenizer encoding: The tokenizer's encoding function analyzes the vocabulary and translates the input texts or prompts into a list of integer identifiers. These identifiers are then fed into the transformer blocks.
- Tokenizer Decoding: The tokenizer's decoding function analyzes the vocabulary and translates the list of IDs generated from the transformer's classifier block into output texts.
Let's take a look at the diagram below for further clarity.
Let's write code to implement the tokenizer's encoding and decoding function.
# Tokenizer encode function takes text as a string and returns integer ids list
def encode(self, text):
# Define a pattern to identify special tokens present in the text
special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
# Split special token (if present) from the rest of the text
special_chunks = re.split(special_pattern, text)
# Initialize empty ids list
ids = []
# Loop through each part in the special chunks list.
for part in special_chunks:
# If the part of the text is the special token, get the idx of the part from the special token dictionary and append it to the ids list.
if part in self.special_tokens:
ids.append(self.special_tokens[part])
# If the part of text is not a special token
else:
# Split the text into multiple chunks using the pattern we've defined earlier.
text_chunks = re.findall(self.compiled_pattern, part)
# All text chunks are encoded separately, then the results are joined
for chunk in text_chunks:
chunk_bytes = chunk.encode("utf-8") # Encode text to bytes
chunk_ids = list(chunk_bytes) # Convert bytes to list of integer
while len(chunk_ids) >= 2: # the chunk ids list must contain at least 2 ids to form a byte pair
# Count the number of times every consecutive pair appears
stats = get_stats(chunk_ids)
# Some merged tokens are themselves built from earlier merges, so we always pick the pair with the lowest merge index to apply the merges in the same order they were learned.
pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
# Break the loop and return if the pair is not present in the merges dictionary
if pair not in self.merges:
break
# Find the idx of the pair present in the merges dictionary
idx = self.merges[pair]
# Replace the occurrences of pair in ids list with this idx and continue
chunk_ids = merge(chunk_ids, pair, idx)
ids.extend(chunk_ids)
return ids
# Tokenizer decode function takes a list of integer ids and return strings
def decode(self, ids):
# Initialize empty byte list
part_bytes = []
# Change the position of keys and values of special_tokens dict and store into inverse_special_tokens
inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
# Loop through idx in the ids list
for idx in ids:
# If the idx is found in vocab dict, get the bytes of idx and append them into part_bytes list
if idx in self.vocab:
part_bytes.append(self.vocab[idx])
# If the idx is found in inverse_special_tokens dict, get the token string of the corresponding idx, convert it to bytes using utf-8 encode and then append it into part_bytes list
elif idx in inverse_special_tokens:
part_bytes.append(inverse_special_tokens[idx].encode("utf-8"))
# If the idx is not found in both vocab and special token dict, throw an invalid error
else:
raise ValueError(f"invalid token id: {idx}")
# Join all the individual bytes from the part_byte list
text_bytes = b"".join(part_bytes)
# Convert the bytes to a text string using the utf-8 decode function. Make sure to use errors="replace" so distorted bytes are replaced with a readable character such as �.
text = text_bytes.decode("utf-8", errors="replace")
return text
Step 4: Load and test the tokenizer:
Finally, here comes the best part of this article. In this section we will perform two interesting tasks.
- First, train our tokenizer on the Hugging Face Thai Wiki dataset. We have chosen a small dataset size (2.2 MB) to make training faster. However, for a real-world deployment, you should choose a much larger dataset to get better results. Once the training is complete, we will save the model.
- Second, we will load the saved tokenizer model and test the tokenizer's encoding and decoding functionality.
Let's dive into it.
# Train the tokenizer

import time # To calculate the duration of training completion
# Load raw training text data (Thai Wiki dataset) from Hugging Face. thai_wiki_small.txt: https://github.com/tamangmilan/thai_tokenizer
texts = open("/content/thai_wiki_small.txt", "r", encoding="utf-8").read()
texts = texts.strip()
# Define vocab size
vocab_size = 512
# Initialize a tokenizer model class
tokenizer = ThaiTokenizer()
# Start training the tokenizer
start_time = time.time()
tokenizer.train(texts, vocab_size)
end_time = time.time()
# Save tokenizer: you can change path and filename.
tokenizer.save("./models/thaitokenizer")
print(f"Total time to complete tokenizer training: {end_time-start_time:.2f} seconds")
# Output: Total time to complete tokenizer training: 186.11 seconds (3m 6s). (Note: training takes longer with a bigger vocab_size and less time with a smaller one.)
# Test the tokenizer

# Initialize a tokenizer model class
tokenizer = ThaiTokenizer()
# Load tokenizer model. This model was saved during training.
tokenizer.load("./models/thaitokenizer.model")
# Invoke and verify the tokenizer encode and decode function for English Language
eng_texts = "When society evolved in different lands"
print(f"English Text: {eng_texts}")
encoded_ids = tokenizer.encode(eng_texts)
print(f"Encoded Ids: {encoded_ids}")
decoded_texts = tokenizer.decode(encoded_ids)
print(f"Decoded Texts: {decoded_texts}\n")
# Invoke and verify the tokenizer encode and decode function for Thai Language
thai_texts = "เมื่อสังคมมีวิวัฒนาการขึ้นในดินแดนต่าง"
print(f"Thai Text: {thai_texts}")
thai_encoded_ids = tokenizer.encode(thai_texts)
print(f"Encoded Ids: {thai_encoded_ids}")
thai_decoded_texts = tokenizer.decode(thai_encoded_ids)
print(f"Decoded Texts: {thai_decoded_texts}")
Perfect! Our Thai Tokenizer can now encode and decode Thai and English text correctly and accurately.
Have you noticed that the encoded ids for the English text are longer than the encoded ids for the Thai text? This is because we only trained our tokenizer on a Thai dataset, so it could only build a proper vocabulary for the Thai language. Since we haven't trained it on an English dataset, the tokenizer has to encode English almost directly at the byte level, which results in longer lists of encoded ids. As I mentioned before, for a multilingual LLM you should train on English and Thai datasets in roughly a 2:1 ratio; this will give you balanced, high-quality results.
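If you want to see this effect numerically, a quick check like the one below (assuming the tokenizer loaded in Step 4 is still in memory) prints how many ids each input produces:
# Compare how many tokens each text produces with the Thai-only tokenizer.
eng_ids = tokenizer.encode("When society evolved in different lands")
thai_ids = tokenizer.encode("เมื่อสังคมมีวิวัฒนาการขึ้นในดินแดนต่าง")
print(f"English token count: {len(eng_ids)}")
print(f"Thai token count: {len(thai_ids)}")
# Since the vocabulary was learned from Thai text only, the English sentence falls back to
# (almost) byte-level tokens and produces a longer id list than the Thai sentence.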
And that's all! We have successfully created our own Thai tokenizer from scratch using only Python, and I think it turned out great. With this, you can easily create a tokenizer for any other language, which gives you a huge advantage when implementing your multilingual LLM.
Thank you very much for reading!
References
(1) Andrej Karpathy, GitHub: karpathy/minbpe (https://github.com/karpathy/minbpe)