The vanilla VAE shows distinct groups, while the CVAE has a more homogeneous distribution. The vanilla VAE must encode both the class and the within-class variation in the latent space, since no conditional signal is provided. The CVAE, however, does not need to learn to distinguish classes, so its latent space can focus on variation within classes. Therefore, a CVAE can potentially capture more information, since it does not have to spend capacity on basic class conditioning.
Two model architectures were created to test image generation. The first architecture was a convolutional CVAE with a conditional concatenation approach. All networks were created for Fashion-MNIST images of size 28×28 (784 pixels in total).
import torch
import torch.nn as nn

class ConcatConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4

        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, 32)

        # Latent space (with concatenated condition)
        self.fc_mu = nn.Linear(self.flatten_size + 32, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 32, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 32, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Concatenate condition with encoded input
        x = torch.cat((x, c), dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        c = self.label_embedding(c)
        # Concatenate condition with latent vector
        z = torch.cat((z, c), dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The CVAE encoder consists of 3 convolutional layers, each followed by a ReLU nonlinearity, and the encoder output is flattened. The class label is passed through an embedding layer and concatenated with the flattened encoder output. Two linear layers then produce μ and log σ² of the latent distribution, and the reparametrization trick is used to sample from it. The sampled latent vector, again concatenated with the output of the class embedding layer, is passed to the decoder. The decoder consists of 3 transposed convolutional layers; the first two are followed by a ReLU nonlinearity and the last by a sigmoid nonlinearity. The output of the decoder is a 28×28 generated image.
The other model architecture follows the same approach, but adds the conditional input instead of concatenating it. An important question was whether adding or concatenating the condition leads to better reconstruction and generation results.
class AdditiveConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4

        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, self.flatten_size)

        # Latent space (without concatenation)
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)

        # Decoder condition embedding
        self.decoder_label_embedding = nn.Embedding(num_classes, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Add condition to encoded input
        x = x + c
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        # Add condition to latent vector
        c = self.decoder_label_embedding(c)
        z = z + c
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The same loss function, given by the equation shown above, is used for all CVAEs.
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    """Computes the loss = -ELBO = Negative Log-Likelihood + KL Divergence.

    Args:
        recon_x: Decoder output.
        x: Ground truth.
        mu: Mean of Z
        logvar: Log-Variance of Z
    """
    # Reconstruction term: pixel-wise binary cross-entropy, summed over the batch
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x, c) and the standard normal prior
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
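For completeness, a minimal sketch of how this loss could be used to train the concatenation model is shown below; the optimizer, batch size, and epoch count are illustrative assumptions rather than the exact training configuration.
# Minimal training sketch (assumed hyperparameters, not the exact setup used).
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConcatConditionalVAE(latent_dim=128, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Fashion-MNIST images come as 28x28 grayscale tensors in [0, 1] with ToTensor()
train_data = datasets.FashionMNIST("data", train=True, download=True,
                                   transform=transforms.ToTensor())
loader = DataLoader(train_data, batch_size=128, shuffle=True)

for epoch in range(50):  # assumed epoch count; fewer than 100 epochs are needed
    for x, c in loader:
        x, c = x.to(device), c.to(device)
        recon, mu, logvar = model(x, c)
        loss = loss_function(recon, x, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()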
To evaluate model-generated images, three quantitative metrics are commonly used. The mean squared error (MSE) was calculated by summing the squared pixel-wise differences between the generated image and a ground truth image. The structural similarity index measure (SSIM) evaluates image quality by comparing two images based on structural information, luminance, and contrast (3). SSIM can be used to compare images of any size, while MSE scales with the number of pixels. The SSIM score ranges from -1 to 1, where 1 indicates identical images. Fréchet inception distance (FID) quantifies the realism and diversity of generated images. Since FID is a distance measure, lower scores indicate better reconstruction of a set of images.
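As a rough sketch, the first two metrics can be computed per image pair with NumPy and scikit-image; the function below is an illustration under the assumption of single-channel arrays scaled to [0, 1], not the exact evaluation code used.
# Metric sketch (assumes 2D numpy arrays in [0, 1]; not the exact evaluation code).
import numpy as np
from skimage.metrics import structural_similarity

def evaluate_pair(generated, ground_truth):
    # Sum of squared pixel differences, as described above
    mse = np.sum((generated - ground_truth) ** 2)
    # SSIM in [-1, 1]; data_range is the span of possible pixel values
    ssim = structural_similarity(generated, ground_truth, data_range=1.0)
    return mse, ssim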
Before expanding to full text-to-image generation, CVAE reconstruction and generation were first tested on Fashion-MNIST. Fashion-MNIST is an MNIST-like dataset that consists of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28×28 grayscale image associated with one of 10 class labels (4).
Preprocessing functions were created to extract the relevant keyword containing the class name from the input short text using a regular expression match. Additional descriptors (synonyms) were used for most classes to account for similar fashion items included in each class (e.g., coat and jacket).
import re

classes = {
    'T-shirt': 0,
    'Top': 0,
    'Trouser': 1,
    'Pants': 1,
    'Pullover': 2,
    'Sweater': 2,
    'Hoodie': 2,
    'Dress': 3,
    'Coat': 4,
    'Jacket': 4,
    'Sandal': 5,
    'Shirt': 6,
    'Sneaker': 7,
    'Shoe': 7,
    'Bag': 8,
    'Ankle boot': 9,
    'Boot': 9
}

def word_to_text(input_str, classes, model, device):
    label = class_embedding(input_str, classes)
    if label == -1:
        raise ValueError("No valid label found in input text")
    samples = sample_images(model, num_samples=4, label=label, device=device)
    plot_samples(samples, input_str, torch.tensor(label))
    return

def class_embedding(input_str, classes):
    # Case-insensitive whole-word match against each class keyword
    for key in classes.keys():
        template = f'(?i)\\b{key}\\b'
        output = re.search(template, input_str)
        if output:
            return classes[key]
    return -1
The class name is then converted to its class number and used as the conditional input to the CVAE. To generate an image, the class label extracted from the short text description is passed to the decoder together with a latent vector sampled from a standard Gaussian distribution.
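The sample_images helper referenced in word_to_text is not shown above; a minimal sketch of what such a helper could look like for the concatenation model, with the label repeated across the batch, is:
# Hypothetical sketch of the sample_images helper referenced in word_to_text,
# assuming the ConcatConditionalVAE defined earlier.
@torch.no_grad()
def sample_images(model, num_samples, label, device):
    model.eval()
    z = torch.randn(num_samples, model.latent_dim, device=device)           # latent samples
    c = torch.full((num_samples,), label, dtype=torch.long, device=device)  # class condition
    return model.decode(z, c).cpu()                                         # (num_samples, 1, 28, 28)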
Before testing generation, image reconstruction is tested to ensure the functionality of the CVAE. Because the convolutional network operates on small 28×28 images, it can be trained in under an hour in fewer than 100 epochs.
The reconstructions capture the general shape of the ground truth images, but they are missing sharp, high-frequency features. Any text or complex design patterns appear blurred in the model output. Entering any short text containing a Fashion-MNIST class produces results that resemble the reconstructed images.
The generated images have an MSE of 11 and an SSIM of 0.76. These constitute good generations, meaning that on simple, small images, CVAEs can generate quality images. GANs and DDPMs will produce higher quality images with complex features, but CVAEs can handle simple cases.
Extending image generation to text of any length requires more robust methods than regular expression matching. To do this, OpenAI's CLIP is used to convert text into a high-dimensional embedding vector. The embedding model is used in its ViT-B/32 configuration, which generates embeddings of length 512. A limitation of the CLIP model is its maximum token length of 77, and studies show an even smaller effective length of about 20 (5). Therefore, in cases where the input text contains multiple sentences, the text is split by sentence and each sentence is passed through the CLIP encoder. The resulting embeddings are averaged to create the final output embedding.
A long-text model requires much more complicated training data than Fashion-MNIST, so the COCO dataset was used. The COCO dataset has caption annotations (which are not completely robust, as discussed later) that can be passed to CLIP to obtain embeddings. However, COCO images are 640×480, which means that even with cropping transformations a larger network is needed. Both the conditional-input addition and concatenation architectures were tested for long text-to-image generation; the concatenation approach is shown here:
import clip

class cVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Frozen CLIP text encoder for conditioning
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.latent_dim = latent_dim

        # Modified encoder for 128x128 input
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),    # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 4, stride=2, padding=1), # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 512 * 4 * 4  # Flattened size from encoder

        # Process CLIP embeddings for encoder
        self.condition_processor_encoder = nn.Sequential(
            nn.Linear(512, 1024)
        )
        self.fc_mu = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 1024, latent_dim)

        self.decoder_input = nn.Linear(latent_dim + 512, 512 * 4 * 4)

        # Modified decoder for 128x128 output
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),    # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),    # 128x128
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),              # 128x128
            nn.Sigmoid()
        )

    def encode_condition(self, text):
        # Embed each sentence with CLIP and average the embeddings
        with torch.no_grad():
            embeddings = []
            for sentence in text:
                tokens = clip.tokenize(sentence).to(self.device)
                embeddings.append(self.clip_model.encode_text(tokens).type(torch.float32))
            return torch.mean(torch.stack(embeddings), dim=0)

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.condition_processor_encoder(c)
        x = torch.cat((x, c), dim=1)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        z = torch.cat((z, c), dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 512, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
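To illustrate how the pieces fit together at generation time, the sketch below embeds a caption, samples a latent vector, and decodes an image; the caption, the naive sentence split, and the latent dimension of 512 are illustrative assumptions, and the model is assumed to be already trained.
# Illustrative generation from a long caption (assumes a trained cVAE instance).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = cVAE(latent_dim=512).to(device)
model.eval()

caption = "A dog runs across a grassy field. The sky is clear and blue."
sentences = [s.strip() for s in caption.split(".") if s.strip()]  # naive sentence split

with torch.no_grad():
    c = model.encode_condition(sentences)                           # averaged CLIP embedding, shape (1, 512)
    z = torch.randn(c.shape[0], model.latent_dim, device=c.device)  # latent sample
    image = model.decode(z, c)                                      # tensor of shape (1, 3, 128, 128)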
Another important point of research was the generation and reconstruction of images of different sizes. Specifically, the COCO images were resized to 64×64, 128×128, and 256×256 (see the transform sketch below). After training the network at each size, reconstruction was tested first.
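The resizing can be handled with standard torchvision transforms; the pipeline below is a sketch for the 128×128 case and is not necessarily the exact set of transformations used.
# Example resize/crop pipeline for 128x128 COCO inputs (assumed, not the exact transforms used).
from torchvision import transforms

coco_transform = transforms.Compose([
    transforms.Resize(128),       # scale the shorter side to 128 pixels
    transforms.CenterCrop(128),   # crop to a square 128x128 image
    transforms.ToTensor()         # convert to a float tensor in [0, 1], shape (3, 128, 128)
])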
All image sizes lead to reconstructions with the general background, some characteristic contours, and correct colors. However, as the image size increases, more features can be recovered. This makes sense: although a model with a larger image size takes much longer to train, it can capture and learn more information.
For generation, it is extremely difficult to produce high-quality images. Most generated images have a plausible background to some extent but blurred features, which is expected from a CVAE. This occurs with both concatenation and addition of the conditional input, but the concatenation approach works better. This is likely because concatenated conditional inputs do not interfere with important features and keep the conditioning information distinctly preserved; the condition can be ignored where it is irrelevant. Additive conditional inputs, on the other hand, can interfere with existing features and disrupt the network as weights are updated during backpropagation.
All images generated by the COCO-trained models have a much lower SSIM of about 0.4 compared to the Fashion-MNIST SSIM. MSE scales with image size, so it is difficult to compare values directly. The FID for the COCO generations is in the 200s, further evidence that the images generated by the COCO CVAE are not robust.
The biggest limitation when trying to use a CVAE for image generation is, well, the CVAE. The amount of information that can be contained and reconstructed/generated depends largely on the size of the latent space, and the latent space required grows with the size of the output image. A latent space that is too small will not capture any meaningful information: a 28×28 image needs a much smaller latent space than a 64×64 image (since the pixel count scales with the square of the image side length). However, a latent space larger than the actual image adds unnecessary information and at that point simply creates a one-to-one mapping. For the COCO dataset, a latent space of at least 512 is needed to capture some features. And while CVAEs are generative models, a convolutional encoder and decoder is a fairly rudimentary network. The adversarial training of a GAN or the complex denoising process of a DDPM allows for much more complicated image generation.
Another important limitation in image generation is the dataset the model is trained on. Although the COCO dataset has annotations, they are not very detailed; the captions do not describe object locations or background details in depth. To train complex generative models, a more richly annotated dataset would be needed, and a complex feature vector from the CLIP encoder cannot be used effectively by a CVAE on COCO.
Although CVAEs and image generation on COCO have their limitations, this still creates a viable text-to-image model. More code and details can be provided, just get in touch!
(1) Kingma, Diederik P., et al. "Auto-Encoding Variational Bayes." arXiv:1312.6114 (2013).
(2) Sohn, Kihyuk, et al. "Learning Structured Output Representation using Deep Conditional Generative Models." NeurIPS Proceedings (2015).
(3) Nilsson, J., et al. "Understanding SSIM." arXiv:2102.12037 (2020).
(4) Xiao, Han, et al. "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms." arXiv:1708.07747 (2017) (MIT license).
(5) Zhang, B., et al. "Long-CLIP: Unlocking the Long-Text Capability of CLIP." arXiv:2403.15378 (2024).
A shout out to my group's project partners Jake Hession (Deloitte Consultant), Ashley Hong (Google SWE), and Julian Kuppel (Quant)!