Accelerating AI/ML Model Training with Custom Operators: Part 3.A
This is a direct sequel to a previous post on the topic of implementing custom TPU operations with Pallas. Of particular interest are custom kernels that take advantage of the unique properties of the TPU architecture in a way that optimizes runtime performance. In this post, we will attempt to demonstrate this opportunity by applying the power of Pallas to the challenge of running sequential algorithms that are interleaved within a predominantly parallelizable deep learning (DL) workload.
We will focus on non-maximum suppression (NMS) of bounding box proposals as a representative algorithm and explore ways to optimize its implementation. An important component of computer vision (CV) object detection solutions (e.g., Mask R-CNN), NMS is commonly used to filter out overlapping bounding boxes and keep only the "best" ones. NMS receives a list of bounding box proposals, an associated list of scores, and an IOU threshold, and proceeds greedily and iteratively: choose the remaining box with the highest score and disqualify all other boxes whose IOU with it exceeds the given threshold. The fact that the box chosen in the n-th iteration depends on the preceding n-1 steps of the algorithm dictates the sequential nature of its implementation. See here and here to learn more about the rationale behind NMS and how it is implemented. Although we have chosen to focus on one specific algorithm, most of our discussion should carry over to other sequential algorithms.
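To ground the discussion, the IOU (intersection over union) of two boxes is the area of their intersection divided by the area of their union. The following is a minimal sketch of our own for computing the IOU of a pair of boxes given in (left, top, right, bottom) format; it is included for illustration only and is not taken from the references above:

def iou(box_a, box_b):
    # boxes are in (left, top, right, bottom) format
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / max(union, 1e-5)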
Offloading sequential algorithms to the CPU
The presence of a sequential algorithm within a predominantly parallelizable ML model (e.g., Mask R-CNN) presents an interesting challenge. While GPUs, commonly used for these types of workloads, excel at executing parallel operations such as matrix multiplication, they can significantly underperform compared to CPUs when handling sequential algorithms. This often leads to compute graphs that include crossovers between the GPU and the CPU, where the GPU handles the parallel operations and the CPU handles the sequential ones. NMS is a prime example of a sequential algorithm that is commonly offloaded to the CPU. In fact, a close analysis of torchvision's "CUDA" implementation of NMS reveals that it even runs a significant portion of the algorithm on the CPU.
Although offloading sequential operations to the CPU can improve runtime performance, there are several potential drawbacks to consider:
- Cross-device execution between the CPU and GPU typically requires multiple points of synchronization between the devices, which commonly results in idle time on the GPU while it waits for the CPU to complete its tasks. Given that the GPU is typically the most expensive component of the training platform, our goal is to minimize such idle time.
- In standard ML workflows, the CPU is responsible for preparing and feeding data to the model, which resides on the GPU. If the data input pipeline involves compute-intensive processing, this can overload the CPU and lead to "starvation" on the GPU. In such scenarios, offloading parts of the model computation to the CPU could further exacerbate this problem.
To avoid these drawbacks, alternative approaches could be considered, such as replacing the sequential algorithm with a comparable alternative (e.g., the one suggested here), settling for a slow or suboptimal GPU implementation of the sequential algorithm, or running the workload on the CPU, each of which comes with its own potential trade-offs.
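To make the cross-device pattern described above concrete, the sketch below shows one way a host-side (CPU) sequential op could be invoked from within a JAX compute graph using jax.pure_callback. The nms_on_cpu stand-in and the fixed output size of 128 are our own assumptions for illustration; a real host-side NMS routine would replace the stand-in body:

import jax
import jax.numpy as jnp
import numpy as np

def nms_on_cpu(boxes, scores):
    # stand-in for a real host-side NMS routine; here we simply
    # return the indices of the 128 highest-scoring boxes
    return np.argsort(np.asarray(scores))[::-1][:128].astype(np.int32)

@jax.jit
def graph_with_cpu_nms(boxes, scores):
    # ... parallelizable (accelerator-side) ops would run here ...
    out_spec = jax.ShapeDtypeStruct((128,), jnp.int32)
    # the callback executes on the host; the accelerator must wait
    # for its result, which introduces a synchronization point
    keep = jax.pure_callback(nms_on_cpu, out_spec, boxes, scores)
    # ... more parallelizable ops would run here ...
    return keep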
Sequential algorithms on TPU
This is where the unique architecture of the TPU could present an opportunity. Unlike GPUs, TPUs are sequential processors. While their ability to run highly vectorized operations makes them competitive with GPUs on parallelizable operations like matrix multiplication, their sequential nature could make them especially well-suited for running ML workloads that include a mix of sequential and parallel components. Armed with the Pallas extension to JAX, our new tool for creating TPU kernels, we will evaluate this opportunity by implementing and benchmarking a custom NMS implementation for TPU.
Disclaimers
The NMS implementations we will share below are intended for demonstration purposes only. We have not made any significant effort to optimize them or verify their robustness, durability, or accuracy. Please note that, at the time of this writing, Pallas is an experimental feature still under active development. The code we share (based on JAX version 0.4.32) may become outdated by the time you read this. Be sure to refer to the most up-to-date APIs and resources available for your Pallas development. Please do not interpret our mention of any algorithm, library, or API as an endorsement of its use.
We start with a simple NMS implementation in numpy, which will serve as a reference point for performance comparison:
import numpy as np

def nms_cpu(boxes, scores, max_output_size, threshold=0.1):
    epsilon = 1e-5

    # Convert bounding boxes and scores to numpy
    boxes = np.array(boxes)
    scores = np.array(scores)

    # coordinates of bounding boxes
    start_x = boxes[:, 0]
    start_y = boxes[:, 1]
    end_x = boxes[:, 2]
    end_y = boxes[:, 3]

    # Compute areas of bounding boxes
    areas = (end_x - start_x) * (end_y - start_y)

    # Sort by confidence score of bounding boxes
    order = np.argsort(scores)

    # Picked bounding boxes
    picked_boxes = []

    # Iterate over bounding boxes
    while order.size > 0 and len(picked_boxes) < max_output_size:
        # The index of the remaining box with the highest score
        index = order[-1]

        # Pick the bounding box with largest confidence score
        picked_boxes.append(index.item())

        # Compute coordinates of intersection
        x1 = np.maximum(start_x[index], start_x[order[:-1]])
        x2 = np.minimum(end_x[index], end_x[order[:-1]])
        y1 = np.maximum(start_y[index], start_y[order[:-1]])
        y2 = np.minimum(end_y[index], end_y[order[:-1]])

        # Compute areas of intersection and union
        w = np.maximum(x2 - x1, 0.0)
        h = np.maximum(y2 - y1, 0.0)
        intersection = w * h
        union = areas[index] + areas[order[:-1]] - intersection

        # Compute the ratio between intersection and union
        ratio = intersection / np.clip(union, epsilon, None)

        # discard boxes above overlap threshold
        keep = np.where(ratio < threshold)
        order = order[keep]

    return picked_boxes
To evaluate the performance of our NMS function, we generate a batch of random boxes and scores (as JAX tensors) and run the script on a Google Cloud TPU v5e system, using the same environment and the same benchmarking utility as in our previous post. For this experiment, we specify the CPU as the JAX default device:
import jax
from jax import random
import jax.numpy as jnp

def generate_random_boxes(run_on_cpu=False):
    if run_on_cpu:
        jax.config.update('jax_default_device', jax.devices('cpu')[0])
    else:
        jax.config.update('jax_default_device', jax.devices('tpu')[0])
    n_boxes = 1024
    img_size = 1024
    k1, k2, k3 = random.split(random.key(0), 3)

    # Randomly generate box sizes and positions
    box_sizes = random.randint(k1,
                               shape=(n_boxes, 2),
                               minval=1,
                               maxval=img_size)
    top_left = random.randint(k2,
                              shape=(n_boxes, 2),
                              minval=0,
                              maxval=img_size - 1)
    bottom_right = jnp.clip(top_left + box_sizes, 0, img_size - 1)

    # Concatenate top-left and bottom-right coordinates
    rand_boxes = jnp.concatenate((top_left, bottom_right),
                                 axis=1).astype(jnp.bfloat16)
    rand_scores = jax.random.uniform(k3,
                                     shape=(n_boxes,),
                                     minval=0.0,
                                     maxval=1.0)
    return rand_boxes, rand_scores

rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=True)

time = benchmark(nms_cpu)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_cpu: {time}')
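The benchmark utility invoked above comes from our previous post. For readers following along without it, a minimal sketch of a utility consistent with how it is called here might look as follows (our own assumption, not the exact utility from that post):

import time as time_module

def benchmark(f, warmup=10, iters=100):
    def timed(*args, **kwargs):
        # warmup runs, e.g., to exclude one-time JIT compilation cost
        for _ in range(warmup):
            jax.block_until_ready(f(*args, **kwargs))
        start = time_module.perf_counter()
        for _ in range(iters):
            jax.block_until_ready(f(*args, **kwargs))
        # return the average runtime in milliseconds
        return 1000 * (time_module.perf_counter() - start) / iters
    return timed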
The resulting average execution time is 2.99 milliseconds. Note that we assume the input and output tensors reside on the CPU. If they reside on the TPU, the time to copy them between devices should also be taken into account.
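For example, an explicit transfer of the tensors to the CPU in JAX could look like the following minimal sketch:

# hypothetical: copy boxes and scores that reside on the TPU over to the CPU
cpu_device = jax.devices('cpu')[0]
boxes_on_cpu = jax.device_put(rand_boxes, cpu_device)
scores_on_cpu = jax.device_put(rand_scores, cpu_device)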
If our NMS function is a component within a larger compute graph running on the TPU, we might prefer a TPU-compatible implementation to avoid the drawbacks of cross-device execution. The following code block contains a JAX implementation of NMS designed specifically to enable acceleration via JIT compilation. Denoting the number of boxes by N, we begin by calculating the IOU between each of the N(N-1) pairs of boxes and preparing an NxN boolean tensor (mask_threshold) in which the (i,j)-th entry indicates whether the IOU between boxes i and j exceeds the predefined threshold.
To simplify the iterative box selection, we create a copy of the mask tensor (mask_threshold2) in which the diagonal elements are zeroed out to prevent a box from suppressing itself. We further define two score-tracking tensors: out_scores, which retains the scores of the chosen boxes (and zeroes the scores of the eliminated ones), and remaining_scores, which maintains the scores of the boxes still under consideration. We then use the jax.lax.while_loop function to iteratively choose boxes while updating the out_scores and remaining_scores tensors. Note that the output format of this function differs from that of the previous function and may need to be adjusted to fit into subsequent steps of the compute graph.
import functools

# Given N boxes, calculates mask_threshold, an NxN boolean mask
# where the (i,j) entry indicates whether the IOU of boxes i and j
# exceeds the threshold. Returns mask_threshold, mask_threshold2
# (which is equivalent to mask_threshold with a zeroed diagonal), and
# the scores modified so that all values are greater than 0
def init_tensors(boxes, scores, threshold=0.1):
    epsilon = 1e-5

    # Extract left, top, right, bottom coordinates
    left = boxes[:, 0]
    top = boxes[:, 1]
    right = boxes[:, 2]
    bottom = boxes[:, 3]

    # Compute areas of boxes
    areas = (right - left) * (bottom - top)

    # Calculate intersection points
    inter_l = jnp.maximum(left[None, :], left[:, None])
    inter_t = jnp.maximum(top[None, :], top[:, None])
    inter_r = jnp.minimum(right[None, :], right[:, None])
    inter_b = jnp.minimum(bottom[None, :], bottom[:, None])

    # Width, height, and area of the intersection
    inter_w = jnp.clip(inter_r - inter_l, 0)
    inter_h = jnp.clip(inter_b - inter_t, 0)
    inter_area = inter_w * inter_h

    # Union of the areas
    union = areas[None, :] + areas[:, None] - inter_area

    # IoU calculation
    iou = inter_area / jnp.clip(union, epsilon)

    # Shift scores to be greater than zero
    out_scores = scores - jnp.min(scores) + epsilon

    # Create mask based on IoU threshold
    mask_threshold = iou > threshold

    # Create mask excluding diagonal (i.e., self IoU is ignored)
    mask_threshold2 = mask_threshold * (1 - jnp.eye(mask_threshold.shape[0],
                                                    dtype=mask_threshold.dtype))

    return mask_threshold, mask_threshold2, out_scores


@functools.partial(jax.jit, static_argnames=('max_output_size', 'threshold'))
def nms_jax(boxes, scores, max_output_size, threshold=0.1):
    # initialize mask and score tensors
    mask_threshold, mask_threshold2, out_scores = init_tensors(boxes,
                                                               scores,
                                                               threshold)
    # The out_scores tensor will retain the scores of the chosen boxes
    # and zero the scores of the eliminated ones.
    # remaining_scores will maintain non-zero scores for boxes that
    # have not been chosen or eliminated
    remaining_scores = out_scores.copy()

    def choose_box(state):
        i, remaining_scores, out_scores = state
        # choose index of box with highest score from remaining scores
        index = jnp.argmax(remaining_scores)
        # check validity of chosen box
        valid = remaining_scores[index] > 0
        # If valid, zero all scores with IOU greater than threshold
        # (including the chosen index)
        remaining_scores = jnp.where(mask_threshold[index] * valid,
                                     0,
                                     remaining_scores)
        # zero the scores of the eliminated boxes (not including
        # the chosen index)
        out_scores = jnp.where(mask_threshold2[index] * valid,
                               0,
                               out_scores)
        i = i + 1
        return i, remaining_scores, out_scores

    def cond_fun(state):
        i, _, _ = state
        return i < max_output_size

    i = 0
    state = (i, remaining_scores, out_scores)

    _, _, out_scores = jax.lax.while_loop(cond_fun, choose_box, state)

    # Output the resultant scores. To extract the chosen boxes,
    # take the max_output_size highest scores:
    #   min = jnp.minimum(jnp.count_nonzero(scores), max_output_size)
    #   indexes = jnp.argsort(out_scores, descending=True)[:min]
    return out_scores

# nms_jax can be run on either the CPU or the TPU
rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=True)
time = benchmark(nms_jax)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_jax on CPU: {time}')

rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=False)
time = benchmark(nms_jax)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_jax on TPU: {time}')
The execution times of this NMS implementation are 1.231 and 0.416 milliseconds on CPU and TPU, respectively.
We now present a custom implementation of NMS in which we explicitly take advantage of the fact that on TPUs Pallas kernels run in a sequential manner. Our implementation uses two boolean mask matrices and two score tensors, similar to the approach of our previous function.
We define a kernel function, choose_box, which is responsible for selecting the next box and updating the score tensors, which are maintained in scratch memory. We invoke the kernel over a one-dimensional grid whose number of steps (i.e., the grid size) is determined by the max_output_size parameter.
Please note that due to some limitations (at the time of this writing) in the operations supported by Pallas, some acrobatics are required to implement both the "argmax" function and the validity check for the selected boxes. For the sake of brevity, we omit the technical details and refer the interested reader to the comments in the code below.
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

# argmax helper function
def pallas_argmax(scores, n_boxes):
    # we assume that the index of each box is stored in the
    # least significant bits of the score (see below)
    idx = jnp.max(scores.astype(float)).astype(int) % n_boxes
    return idx

# Pallas kernel definition
def choose_box(scores, thresh_mask1, thresh_mask2, ret_scores,
               scores_scratch, remaining_scores_scratch, *, nsteps, n_boxes):
    # initialize scratch memory on first step
    @pl.when(pl.program_id(0) == 0)
    def _():
        scores_scratch[...] = scores[...]
        remaining_scores_scratch[...] = scores[...]

    remaining_scores = remaining_scores_scratch[...]

    # choose box
    idx = pallas_argmax(remaining_scores, n_boxes)

    # we use any to verify the validity of the chosen box due
    # to limitations on indexing in pallas
    valid = (remaining_scores > 0).any()

    # update the score tensors
    remaining_scores_scratch[...] = jnp.where(thresh_mask1[idx, ...] * valid,
                                              0,
                                              remaining_scores)
    scores_scratch[...] = jnp.where(thresh_mask2[idx, ...] * valid,
                                    0,
                                    scores_scratch[...])

    # set return value on final step
    @pl.when(pl.program_id(0) == nsteps - 1)
    def _():
        ret_scores[...] = scores_scratch[...]

@functools.partial(jax.jit, static_argnames=('max_output_size', 'threshold'))
def nms_pallas(boxes, scores, max_output_size, threshold=0.1):
    n_boxes = scores.size
    mask_threshold, mask_threshold2, scores = init_tensors(boxes,
                                                           scores,
                                                           threshold)

    # In order to work around the Pallas argsort limitation
    # we create a new scores tensor with the same ordering of
    # the input scores tensor in which the index of each score
    # in the ordering is encoded in the least significant bits
    sorted = jnp.argsort(scores, descending=True)

    # descending integers: n_boxes-1, ..., 2, 1, 0
    descending = jnp.flip(jnp.arange(n_boxes))

    # new scores in descending order with the least significant
    # bits carrying the argsort of the input scores
    ordered_scores = n_boxes * descending + sorted

    # new scores with same ordering as input scores
    scores = jnp.empty_like(ordered_scores
                            ).at[sorted].set(ordered_scores)

    grid = (max_output_size,)
    return pl.pallas_call(
        functools.partial(choose_box,
                          nsteps=max_output_size,
                          n_boxes=n_boxes),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec(block_shape=(n_boxes,)),
                pl.BlockSpec(block_shape=(n_boxes, n_boxes)),
                pl.BlockSpec(block_shape=(n_boxes, n_boxes)),
            ],
            out_specs=pl.BlockSpec(block_shape=(n_boxes,)),
            scratch_shapes=[pltpu.VMEM((n_boxes,), scores.dtype),
                            pltpu.VMEM((n_boxes,), scores.dtype)],
            grid=grid,
        ),
        out_shape=jax.ShapeDtypeStruct((n_boxes,), scores.dtype),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("arbitrary",)))
    )(scores, mask_threshold, mask_threshold2)

rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=False)
time = benchmark(nms_pallas)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_pallas: {time}')
The average execution time of our custom NMS operator is 0.139 milliseconds, making it about three times faster than our native JAX implementation. This result highlights the potential of tailoring the implementation of sequential algorithms to the unique properties of the TPU architecture.
Note that in our Pallas kernel implementation, we load the full input tensors into the TPU's VMEM. Given the limited capacity of VMEM, increasing the input size (i.e., the number of bounding boxes) will likely lead to memory issues. Typically, such limitations can be addressed by partitioning the inputs using BlockSpecs. Unfortunately, applying this approach would break the current NMS implementation. Implementing NMS over input partitions would require a different design, which is beyond the scope of this post.
The results of our experiments are summarized in the following table:

| Implementation | Device | Avg. runtime (ms) |
| --- | --- | --- |
| nms_cpu | CPU | 2.99 |
| nms_jax | CPU | 1.231 |
| nms_jax | TPU | 0.416 |
| nms_pallas | TPU | 0.139 |
These results demonstrate the potential for running full ML compute graphs on TPU, even when they include sequential components. The performance improvement demonstrated by our Pallas NMS operator, in particular, highlights the opportunity to customize kernels in a way that takes advantage of the strengths of TPUs.
In our previous post, we learned about the opportunity to build custom TPU operators using the Pallas extension for JAX. Maximizing this opportunity requires tailoring the kernel implementations to the specific properties of the TPU architecture. In this post, we focused on the sequential nature of the TPU processor and used it to optimize a custom NMS kernel. While extending the solution to support an unrestricted number of bounding boxes would require further work, the core principles we have discussed remain applicable.
Even at this experimental phase of its development, Pallas has some limitations that may require creative workarounds. But its strength and potential are clearly evident, and we anticipate that they will only increase as the framework matures.