The idea of Generative Adversarial Networks, or GANs, was introduced by Goodfellow and his colleagues (1) in 2014, and they soon became extremely popular in computer vision and imaging. The AI field has developed rapidly over the last 10 years and the number of new algorithms keeps growing, yet the simplicity and brilliance of this concept remain impressive. So today I want to show how powerful these networks can be at removing clouds from RGB (red, green, blue) satellite images.
Preparing a properly balanced, sufficiently large and well-preprocessed CV dataset takes a huge amount of time, so I decided to explore what Kaggle has to offer. The dataset I found most appropriate for this task is EuroSat (2), which has an open license. It contains 27,000 labeled 64×64-pixel RGB images from Sentinel-2 and is designed for multi-class classification.
We are not interested in the classification itself, but one of the main characteristics of the EuroSat dataset is that all its images have a clear sky. That's exactly what we need. Following the approach in (3), we will use these Sentinel-2 shots as targets and create the inputs by adding noise (clouds) to them.
So let's prepare our data before we actually talk about GANs. First of all, we need to download the data and merge all the classes into a single directory.
The complete Python code is available on GitHub.
import numpy as np
import pandas as pd
import random
from os import listdir, mkdir, rename
from os.path import join, exists
import shutil
import datetime
import matplotlib.pyplot as plt
from highlight_text import ax_text, fig_text
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

classes = listdir('./EuroSat')  # one sub-folder per land-cover class
path_target = './EuroSat/all_targets'
path_input = './EuroSat/all_inputs'

# RUN IT ONLY ONCE TO RENAME THE FILES IN THE UNPACKED ARCHIVE
mkdir(path_input)
mkdir(path_target)
k = 1
for kind in classes:
    path = join('./EuroSat', str(kind))
    for i, f in enumerate(listdir(path)):
        # copy every image into one folder and give it a running number: 1.jpg, 2.jpg, ...
        shutil.copyfile(join(path, f), join(path_target, f))
        rename(join(path_target, f), join(path_target, f'{k}.jpg'))
        k += 1
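The renaming loop above is meant to run only once: a second run fails at `mkdir`, and even without that it would list `all_inputs` and `all_targets` among the classes. Below is a minimal sketch of a re-runnable variant, assuming the same folder layout. `merge_classes` and its `skip` parameter are hypothetical names, not part of the original code, and the sketch only fills the target folder (the input folder still has to exist before the noise step).

```python
import shutil
from pathlib import Path

def merge_classes(root, target_dir, skip=("all_inputs", "all_targets")):
    """Copy every image from the class folders under `root` into `target_dir` as 1.jpg, 2.jpg, ...

    Hypothetical helper: folders named in `skip` (the output folders) are ignored, and sorting
    makes the numbering reproducible, since os.listdir returns entries in arbitrary order.
    Returns the number of files copied.
    """
    root, target_dir = Path(root), Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)  # no error if the folder already exists
    k = 0
    class_dirs = sorted(p for p in root.iterdir() if p.is_dir() and p.name not in skip)
    for class_dir in class_dirs:
        for f in sorted(class_dir.iterdir()):
            k += 1
            shutil.copyfile(f, target_dir / f"{k}.jpg")  # copy straight to the numbered name
    return k

# usage, assuming the layout used above:
# n = merge_classes('./EuroSat', './EuroSat/all_targets')
```

Copying directly to the numbered name also avoids the temporary file that `copyfile` followed by `rename` creates.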
The second important step is to generate noise. You could use different approaches, for example randomly masking some pixels or adding Gaussian noise, but in this article I want to try something new to me: Perlin noise. It was invented in the 1980s by Ken Perlin (4) while he was developing cinematic smoke effects. This type of noise looks more organic than plain random noise. Let me show you.
from noise import pnoise2  # Perlin noise from the `noise` package (pip install noise)

def generate_perlin_noise(width, height, scale, octaves, persistence, lacunarity):
    noise = np.zeros((height, width))
    for i in range(height):
        for j in range(width):
            noise[i][j] = pnoise2(i / scale,
                                  j / scale,
                                  octaves=octaves,
                                  persistence=persistence,
                                  lacunarity=lacunarity,
                                  repeatx=width,
                                  repeaty=height,
                                  base=0)
    return noise

def normalize_noise(noise):
    # rescale the noise to [0, 1]
    min_val = noise.min()
    max_val = noise.max()
    return (noise - min_val) / (max_val - min_val)

def generate_clouds(width, height, base_scale, octaves, persistence, lacunarity):
    clouds = np.zeros((height, width))
    for octave in range(1, octaves + 1):
        scale = base_scale / octave  # each layer has a higher frequency...
        layer = generate_perlin_noise(width, height, scale, 1, persistence, lacunarity)
        clouds += layer * (persistence ** octave)  # ...and a smaller amplitude
    clouds = normalize_noise(clouds)
    return clouds

def overlay_clouds(image, clouds, alpha=0.5):
    clouds_rgb = np.stack([clouds] * 3, axis=-1)  # grayscale clouds -> 3 identical channels
    image = image.astype(float) / 255.0
    clouds_rgb = clouds_rgb.astype(float)
    blended = image * (1 - alpha) + clouds_rgb * alpha  # alpha blending
    blended = (blended * 255).astype(np.uint8)
    return blended

width, height = 64, 64
octaves = 12  # number of noise layers combined
persistence = 0.5  # lower persistence reduces the amplitude of higher-frequency octaves
lacunarity = 2  # higher lacunarity increases the frequency of higher-frequency octaves
n_images = len(listdir(path_target))
for i in range(n_images):
    base_scale = random.uniform(5, 120)  # noise frequency
    alpha = random.uniform(0, 1)  # transparency
    clouds = generate_clouds(width, height, base_scale, octaves, persistence, lacunarity)
    img = np.asarray(Image.open(join(path_target, f'{i+1}.jpg')))
    image = Image.fromarray(overlay_clouds(img, clouds, alpha))
    image.save(join(path_input, f'{i+1}.jpg'))
    print(f'Processed {i+1}/{n_images}')

idx = np.random.randint(1, n_images + 1)  # files are numbered from 1
fig, ax = plt.subplots(1, 2)
ax[0].imshow(np.asarray(Image.open(join(path_target, f'{idx}.jpg'))))
ax[1].imshow(np.asarray(Image.open(join(path_input, f'{idx}.jpg'))))
ax[0].set_title("Target")
ax[0].axis('off')
ax[1].set_title("Input")
ax[1].axis('off')
plt.show()
As you can see above, the clouds look fairly realistic: they vary in “density” and texture, much like real ones.
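If you would rather not depend on the `noise` package, or want to avoid the pixel-by-pixel `pnoise2` loop, the same idea can be vectorized in NumPy. The sketch below is a minimal gradient (Perlin-style) noise under stated assumptions, not the `noise` package's implementation: random unit gradients sit on a lattice, each pixel takes dot products with its four corner gradients, and the results are blended with Perlin's fade curve 6t⁵ − 15t⁴ + 10t³. `perlin_2d` and `fractal_clouds` are hypothetical names; unlike `generate_clouds` above, the octaves here double the frequency each time (lacunarity 2) instead of dividing the scale by the octave number.

```python
import numpy as np

def perlin_2d(shape, res, rng):
    """One octave of 2D gradient (Perlin-style) noise.

    shape: (height, width) of the output in pixels.
    res:   (rows, cols) of lattice cells the image is divided into.
    rng:   a NumPy Generator used to draw the random corner gradients.
    """
    h, w = shape
    ry, rx = res
    # lattice coordinates of every pixel
    y, x = np.meshgrid(np.arange(h) * ry / h, np.arange(w) * rx / w, indexing="ij")
    y0, x0 = np.floor(y).astype(int), np.floor(x).astype(int)
    fy, fx = y - y0, x - x0  # position inside the cell, in [0, 1)
    # one random unit gradient per lattice corner
    angles = 2 * np.pi * rng.random((ry + 1, rx + 1))
    gy, gx = np.sin(angles), np.cos(angles)

    def corner(dy, dx):
        # dot product of the corner's gradient with the offset from that corner
        return gy[y0 + dy, x0 + dx] * (fy - dy) + gx[y0 + dy, x0 + dx] * (fx - dx)

    def fade(t):
        # Perlin's fade curve 6t^5 - 15t^4 + 10t^3
        return t * t * t * (t * (t * 6 - 15) + 10)

    u, v = fade(fx), fade(fy)
    top = corner(0, 0) * (1 - u) + corner(0, 1) * u
    bottom = corner(1, 0) * (1 - u) + corner(1, 1) * u
    return top * (1 - v) + bottom * v


def fractal_clouds(size=64, base_res=2, octaves=4, persistence=0.5, seed=0):
    """Sum several octaves and rescale to [0, 1] (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    total = np.zeros((size, size))
    for o in range(octaves):
        res = base_res * 2 ** o  # double the frequency each octave...
        total += persistence ** o * perlin_2d((size, size), (res, res), rng)  # ...and shrink the amplitude
    return (total - total.min()) / (total.max() - total.min())

clouds = fractal_clouds(seed=42)
print(clouds.shape, clouds.min(), clouds.max())  # (64, 64) 0.0 1.0
```

The result can be passed to `overlay_clouds` in place of `generate_clouds(...)`. Because the noise is exactly zero at lattice points and varies smoothly in between, it has the soft, blob-like structure that makes it look like clouds rather than static.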
Now that we have a dataset ready to use, let's talk about GANs.
To illustrate the idea, imagine that you are traveling through Southeast Asia and urgently need a hoodie because it is too cold outside. At the nearest market you find a small store selling brand-name clothing. The seller brings you a nice hoodie to try on, saying that it is the famous ExpensiveButNotWorthIt brand. You take a closer look and conclude that it is obviously fake. The seller says, 'Wait a second, I have the REAL one.' He returns with another hoodie, which looks more like the original but is still fake. After several iterations like this, the seller brings an indistinguishable copy of the legendary ExpensiveButNotWorthIt, and you happily buy it. This is basically how GANs work!
In GAN terms, you, the buyer, are the discriminator (D). The discriminator's goal is to distinguish real objects from fake ones, i.e. to solve a binary classification task. The seller is the generator (G), since he tries to generate a high-quality fake. The two networks are trained in alternation, each trying to outperform the other, so in the end we get a high-quality fake.
The original training procedure looks like this:
- Sample input noise (in our case, images with clouds).
- Feed the noise to G and collect the prediction.
- Calculate the loss of D from 2 predictions: one for the output of G and one for the real data.
- Update the weights of D.
- Sample the input noise again.
- Feed the noise to G and collect the prediction.
- Calculate the loss of G by feeding its prediction to D.
- Update the weights of G.
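The eight steps above can be traced end to end on a deliberately tiny problem. The sketch below is a hypothetical toy, not the article's image model: 1D "real" data drawn from N(4, 1.25²), a linear generator G(z) = a·z + b, a logistic discriminator D(x) = sigmoid(w·x + c), and gradients derived by hand in NumPy instead of autograd. As is common practice (and as suggested in (1)), the generator step uses the non-saturating loss −log D(G(z)).

```python
import numpy as np

def log_sigmoid(s):
    # numerically stable log(sigmoid(s)) = -log(1 + exp(-s))
    return -np.logaddexp(0.0, -s)

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def train_toy_gan(steps=2000, batch=128, lr=0.01, seed=0):
    """Alternating GAN updates on 1D toy data (hypothetical example).

    G(z) = a*z + b and D(x) = sigmoid(w*x + c); gradients are derived by hand.
    """
    rng = np.random.default_rng(seed)
    a, b = 1.0, 0.0   # generator parameters
    w, c = 0.1, 0.0   # discriminator parameters
    history = []
    for _ in range(steps):
        # 1-2. sample noise and real data, feed the noise to G
        z = rng.standard_normal(batch)
        x_real = rng.normal(4.0, 1.25, batch)
        x_fake = a * z + b
        # 3. D loss: -[log D(x_real) + log(1 - D(G(z)))], using log(1 - sigmoid(s)) = log sigmoid(-s)
        s_real, s_fake = w * x_real + c, w * x_fake + c
        loss_d = -(log_sigmoid(s_real).mean() + log_sigmoid(-s_fake).mean())
        # 4. update D (G is held fixed)
        d_real, d_fake = sigmoid(s_real), sigmoid(s_fake)
        w -= lr * ((-(1 - d_real) * x_real).mean() + (d_fake * x_fake).mean())
        c -= lr * ((-(1 - d_real)).mean() + d_fake.mean())
        # 5-6. sample fresh noise and feed it to G
        z = rng.standard_normal(batch)
        x_fake = a * z + b
        # 7. G loss through the updated D (non-saturating form: -log D(G(z)))
        s_fake = w * x_fake + c
        loss_g = -log_sigmoid(s_fake).mean()
        # 8. update G (D is held fixed); dLoss/dx for each sample is -(1 - D(x)) * w
        dx = -(1 - sigmoid(s_fake)) * w
        a -= lr * (dx * z).mean()
        b -= lr * dx.mean()
        history.append((loss_d, loss_g))
    return (a, b), (w, c), history

(a, b), (w, c), history = train_toy_gan()
print(f"generated samples ~ N({b:.2f}, {abs(a):.2f}^2); last losses D={history[-1][0]:.3f}, G={history[-1][1]:.3f}")
```

With a linear discriminator this toy game can essentially only push the generated mean toward the real mean of 4, and plain alternating gradient descent may circle around that point rather than settle, a small-scale reminder of why GAN training needs care.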
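Formally, G and D play a two-player minimax game with the value function V(D, G):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]$$

G is trained to minimize the term log(1 − D(G(z))), while D is trained to maximize V, i.e. to push D(x) towards 1 for real samples and D(G(z)) towards 0 for fakes (in this notation x is a real data sample and z is noise).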
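To make V concrete, here is a small NumPy sketch that estimates it by Monte Carlo. The 1D distributions and the functions `value_function`, `G_good`, `G_bad`, `D_half` and `D_logit` are hypothetical, chosen only for illustration. The first call checks the result from (1) that when the generator matches the data distribution, the optimal discriminator outputs 1/2 everywhere and V = −log 4. The second shows that against a poor generator even a simple discriminator scores well above that value.

```python
import numpy as np

def value_function(D, G, x_real, z):
    """Monte Carlo estimate of V(D, G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]."""
    return np.mean(np.log(D(x_real))) + np.mean(np.log(1.0 - D(G(z))))

rng = np.random.default_rng(0)
x_real = rng.normal(4.0, 1.0, 10_000)  # toy "real" data: N(4, 1)
z = rng.standard_normal(10_000)        # noise: N(0, 1)

G_good = lambda z: z + 4.0               # hypothetical generator whose output matches N(4, 1)
D_half = lambda x: np.full_like(x, 0.5)  # optimal D when p_g = p_data
print(value_function(D_half, G_good, x_real, z))  # -log(4), about -1.386

G_bad = lambda z: z                                   # hypothetical generator stuck at N(0, 1)
D_logit = lambda x: 1.0 / (1.0 + np.exp(-(x - 2.0)))  # a simple "threshold at 2" discriminator
print(value_function(D_logit, G_bad, x_real, z))      # well above -log(4)
```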
Now let's try to implement it in PyTorch!
In the original paper, the authors use a multilayer perceptron (MLP), often referred to simply as an ANN, but I want to try a slightly more complex approach: the UNet (5) architecture as the generator and ResNet (6) as the discriminator. Both are well-known CNN architectures, so I won't explain them here (let me know in the comments if I should write a separate article about them).
Let's build them. Discriminator:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import Subset

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels))
        self.downsample = downsample
        self.relu = nn.ReLU()
        self.out_channels = out_channels

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)  # match the shape of the skip connection
        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block=ResidualBlock, all_connections=(3, 4, 6, 3)):
        super(ResNet, self).__init__()
        self.inputs = 16
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU())  # 16x64x64
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # 16x32x32
        self.layer0 = self.makeLayer(block, 16, all_connections[0], stride=1)  # connections = 3, shape: 16x32x32
        self.layer1 = self.makeLayer(block, 32, all_connections[1], stride=2)  # connections = 4, shape: 32x16x16
        self.layer2 = self.makeLayer(block, 128, all_connections[2], stride=2)  # connections = 6, shape: 128x8x8
        self.layer3 = self.makeLayer(block, 256, all_connections[3], stride=2)  # connections = 3, shape: 256x4x4
        self.avgpool = nn.AvgPool2d(4, stride=1)  # 256x1x1
        self.fc = nn.Linear(256, 1)

    def makeLayer(self, block, outputs, connections, stride=1):
        downsample = None
        if stride != 1 or self.inputs != outputs:
            # 1x1 convolution so that the residual matches the new shape
            downsample = nn.Sequential(
                nn.Conv2d(self.inputs, outputs, kernel_size=1, stride=stride),
                nn.BatchNorm2d(outputs),
            )
        layers = []
        layers.append(block(self.inputs, outputs, stride, downsample))
        self.inputs = outputs
        for i in range(1, connections):
            layers.append(block(self.inputs, outputs))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(-1, 256)
        x = self.fc(x).flatten()
        return torch.sigmoid(x)  # probability that each image is real
Generator:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by BatchNorm and ReLU; the spatial size is preserved."""
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ENCODER
        self.conv_1 = DoubleConv(3, 32)  # 32x64x64
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 32x32x32
        self.conv_2 = DoubleConv(32, 64)  # 64x32x32
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 64x16x16
        self.conv_3 = DoubleConv(64, 128)  # 128x16x16
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 128x8x8
        self.conv_4 = DoubleConv(128, 256)  # 256x8x8
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 256x4x4
        self.conv_5 = DoubleConv(256, 512)  # 512x4x4
        # DECODER (each upsampled map is concatenated with the matching encoder output)
        self.upconv_1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)  # 256x8x8
        self.conv_6 = DoubleConv(512, 256)  # 256x8x8
        self.upconv_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)  # 128x16x16
        self.conv_7 = DoubleConv(256, 128)  # 128x16x16
        self.upconv_3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)  # 64x32x32
        self.conv_8 = DoubleConv(128, 64)  # 64x32x32
        self.upconv_4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)  # 32x64x64
        self.conv_9 = DoubleConv(64, 32)  # 32x64x64
        self.output = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1)  # 3x64x64

    def forward(self, batch):
        conv_1_out = self.conv_1(batch)
        conv_2_out = self.conv_2(self.pool_1(conv_1_out))
        conv_3_out = self.conv_3(self.pool_2(conv_2_out))
        conv_4_out = self.conv_4(self.pool_3(conv_3_out))
        conv_5_out = self.conv_5(self.pool_4(conv_4_out))
        conv_6_out = self.conv_6(torch.cat((self.upconv_1(conv_5_out), conv_4_out), dim=1))
        conv_7_out = self.conv_7(torch.cat((self.upconv_2(conv_6_out), conv_3_out), dim=1))
        conv_8_out = self.conv_8(torch.cat((self.upconv_3(conv_7_out), conv_2_out), dim=1))
        conv_9_out = self.conv_9(torch.cat((self.upconv_4(conv_8_out), conv_1_out), dim=1))
        output = self.output(conv_9_out)
        return torch.sigmoid(output)  # pixel values in [0, 1]
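The spatial sizes in the UNet can be checked without PyTorch. This pure-Python sketch applies the output-size formulas from the PyTorch `Conv2d`/`MaxPool2d` and `ConvTranspose2d` documentation (assuming dilation 1 and no output padding) and confirms that every upsampled map matches its skip connection, which `torch.cat(..., dim=1)` requires. `conv_out`, `convT_out`, `double_conv_out` and `trace_unet` are hypothetical helpers.

```python
def conv_out(n, kernel, stride=1, padding=0):
    """Output size of Conv2d / MaxPool2d with dilation 1: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * padding - kernel) // stride + 1

def convT_out(n, kernel, stride=1, padding=0):
    """Output size of ConvTranspose2d with dilation 1 and output_padding 0: (n - 1)s - 2p + k."""
    return (n - 1) * stride - 2 * padding + kernel

def double_conv_out(n):
    # two 3x3 convolutions with padding 1 keep the size unchanged
    return conv_out(conv_out(n, 3, padding=1), 3, padding=1)

def trace_unet(size=64, depth=4):
    """Spatial sizes of the encoder outputs, the bottleneck and the decoder outputs."""
    skips, n = [], size
    for _ in range(depth):
        n = double_conv_out(n)
        skips.append(n)                  # saved for the skip connection
        n = conv_out(n, 2, stride=2)     # MaxPool2d(kernel_size=2, stride=2)
    bottleneck = n = double_conv_out(n)  # conv_5
    decoder = []
    for skip in reversed(skips):
        n = convT_out(n, 2, stride=2)    # ConvTranspose2d(kernel_size=2, stride=2)
        if n != skip:
            raise ValueError(f"upsampled size {n} does not match skip size {skip}")
        n = double_conv_out(n)
        decoder.append(n)
    return skips, bottleneck, decoder

print(trace_unet())  # ([64, 32, 16, 8], 4, [8, 16, 32, 64])
```

For instance, `trace_unet(60)` raises: with four pooling steps the input size has to be a multiple of 2⁴ = 16, which the 64×64 EuroSat images satisfy.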
Next, we wrap our data in a PyTorch dataset that serves it in batches (we will split it into training and test sets shortly):
class dataset(Dataset):
    def __init__(self, batch_size, images_paths, targets, img_size=64):
        self.batch_size = batch_size
        self.img_size = img_size
        self.images_paths = images_paths
        self.targets = targets
        self.len = len(self.images_paths) // batch_size  # an incomplete last batch is dropped
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # HxWxC uint8 image -> CxHxW float tensor in [0, 1]
        ])
        self.batch_im = [self.images_paths[idx * self.batch_size:(idx + 1) * self.batch_size] for idx in range(self.len)]
        self.batch_t = [self.targets[idx * self.batch_size:(idx + 1) * self.batch_size] for idx in range(self.len)]

    def __getitem__(self, idx):
        pred = torch.stack([
            self.transform(Image.open(join(path_input, file_name)))
            for file_name in self.batch_im[idx]
        ])
        # an input and its target share the same file name (k.jpg),
        # so the input names also select the matching targets
        target = torch.stack([
            self.transform(Image.open(join(path_target, file_name)))
            for file_name in self.batch_im[idx]
        ])
        return pred, target

    def __len__(self):
        return self.len
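The batching logic of the class above can be seen in isolation with plain Python. In this sketch, `make_batches` is a hypothetical helper that mirrors it: the trailing partial batch is dropped (`len // batch_size`), and because an input and its clear-sky target share a file name, the same name indexes both folders.

```python
from os.path import join

def make_batches(file_names, batch_size, input_dir, target_dir):
    """Group file names into full batches of (input paths, target paths).

    Hypothetical helper: the trailing partial batch is dropped, and the same
    file name is joined with both folders so inputs and targets stay paired.
    """
    n_batches = len(file_names) // batch_size
    batches = []
    for b in range(n_batches):
        names = file_names[b * batch_size:(b + 1) * batch_size]
        batches.append(([join(input_dir, n) for n in names],
                        [join(target_dir, n) for n in names]))
    return batches

batches = make_batches([f"{i}.jpg" for i in range(1, 11)], 4, "all_inputs", "all_targets")
print(len(batches))  # 2: files 9.jpg and 10.jpg do not fill a third batch
```

A common alternative is a per-image `Dataset` wrapped in `torch.utils.data.DataLoader(..., batch_size=64, shuffle=True, drop_last=True)`, which also reshuffles the data every epoch.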
Perfect. It's time to write the training loop. Before doing so, let's define our loss functions and optimizers:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
num_epochs = 15
learning_rate_D = 1e-5
learning_rate_G = 1e-4
discriminator = ResNet()
generator = UNet()
# the discriminator already ends with a sigmoid, so use plain BCE on probabilities
# (BCEWithLogitsLoss would apply a second sigmoid)
bce = nn.BCELoss()
l1loss = nn.L1Loss()
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate_D)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate_G)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.1)
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.1)
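As you can see, these losses differ from the ones in the original GAN algorithm. In particular, I added L1Loss. The idea is that we are not generating an arbitrary image from noise: we want to keep most of the information in the input and only remove the clouds. The loss of G then becomes

$$\mathcal{L}_G = \log\left(1 - D(G(z))\right) + \lambda \left| G(z) - y \right|$$

instead of just

$$\mathcal{L}_G = \log\left(1 - D(G(z))\right)$$

where y is the clear-sky target, |G(z) − y| is the mean absolute (L1) difference, and λ is an arbitrary coefficient that balances the two loss components (the training loop below uses λ = 100). In the code, the adversarial term is implemented in the common non-saturating form −log D(G(z)) (BCE against a label of 1) rather than by minimizing log(1 − D(G(z))) directly.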
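A toy worked example of this generator objective in NumPy, assuming made-up discriminator outputs and pixel values; `generator_loss` is a hypothetical helper that mirrors the training loop (non-saturating adversarial term plus λ times the mean absolute error). The second half shows why probabilities, which is what a sigmoid-terminated discriminator outputs, should not be fed to a logits-based BCE: the sigmoid would be applied twice.

```python
import numpy as np

def generator_loss(d_fake, fake, target, lam=100.0, eps=1e-12):
    """Toy version of the generator objective (hypothetical helper).

    d_fake: discriminator probabilities D(G(z)) for a batch of generated images
    fake, target: generated images G(z) and clear-sky targets y, same shape
    """
    # adversarial term in non-saturating form: -log D(G(z)), i.e. BCE against label 1
    adversarial = -np.mean(np.log(np.clip(d_fake, eps, 1.0)))
    # L1 term: mean absolute difference, as nn.L1Loss computes by default
    l1 = np.mean(np.abs(fake - target))
    return adversarial + lam * l1

d_fake = np.array([0.5, 0.5])  # D is unsure about both fakes
fake = np.zeros((2, 2))        # toy "generated" pixels
target = np.full((2, 2), 0.1)  # toy clear-sky pixels
print(generator_loss(d_fake, fake, target))  # log(2) + 100 * 0.1, about 10.69

# Probabilities fed to a logits-based BCE get squashed by a second sigmoid:
p = 0.99                      # a confident, correct "real" prediction
print(-np.log(p))             # BCE on the probability: about 0.01
print(np.log1p(np.exp(-p)))   # BCE-with-logits on the same number: about 0.32
```

With λ = 100, even a small pixel error dominates the adversarial term, which keeps the generator close to the target. In the second half, a logits-based BCE with label 1 can never drop below log(1 + e⁻¹) ≈ 0.31 when its inputs are probabilities in [0, 1], however confident the discriminator is.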
Finally, let's split the data to start the training process:
test_ratio, train_ratio = 0.3, 0.7
num_test = int(len(listdir(path_target)) * test_ratio)
num_train = len(listdir(path_target)) - num_test
img_size = (64, 64)
print("Number of train samples:", num_train)
print("Number of test samples:", num_test)

random.seed(231)
train_idxs = np.array(random.sample(range(num_test + num_train), num_train))
mask = np.ones(num_train + num_test, dtype=bool)
mask[train_idxs] = False  # True marks the test samples

# inputs and targets are paired by file name (k.jpg in both folders),
# so one shuffled list of names serves both
features = sorted(listdir(path_input))
random.Random(231).shuffle(features)
targets = list(features)

train_input_img_paths = np.array(features)[train_idxs]
train_target_img_path = np.array(targets)[train_idxs]
test_input_img_paths = np.array(features)[mask]
test_target_img_path = np.array(targets)[mask]

train_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=train_input_img_paths, targets=train_target_img_path)
test_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=test_input_img_paths, targets=test_target_img_path)
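The same 70/30 split can be written with a single shuffle of one list of file names, which keeps each input aligned with its target by construction. A minimal sketch, assuming (as above) that inputs and targets share file names; `split_file_names` is a hypothetical helper, and it rounds the test size instead of truncating it.

```python
import random

def split_file_names(file_names, test_ratio=0.3, seed=231):
    """Return (train, test) lists of file names from one reproducible shuffle."""
    names = sorted(file_names)              # fix the order first: os.listdir order is arbitrary
    random.Random(seed).shuffle(names)      # one shuffle shared by inputs and targets
    n_test = round(len(names) * test_ratio)
    return names[n_test:], names[:n_test]

train, test = split_file_names([f"{i}.jpg" for i in range(1, 101)])
print(len(train), len(test))  # 70 30
```

Each returned name can then be joined with both `path_input` and `path_target`, so there is no second list that could fall out of step.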
Now we can run our training loop:
train_loss_G, train_loss_D, val_loss_G, val_loss_D = [], [], [], []
all_loss_G, all_loss_D = [], []
best_generator_epoch_val_loss, best_discriminator_epoch_val_loss = np.inf, np.inf
generator.to(device)
discriminator.to(device)

for epoch in range(num_epochs):
    discriminator.train()
    generator.train()
    discriminator_epoch_loss, generator_epoch_loss = 0, 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        '''1. Training the Discriminator (ResNet)'''
        optimizer_D.zero_grad()
        fake = generator(inputs).detach()  # no gradients flow into G in this step
        pred_fake = discriminator(fake)
        loss_fake = bce(pred_fake, torch.zeros_like(pred_fake))
        pred_real = discriminator(targets)
        loss_real = bce(pred_real, torch.ones_like(pred_real))
        loss_D = (loss_fake + loss_real) / 2
        loss_D.backward()
        optimizer_D.step()
        discriminator_epoch_loss += loss_D.item()
        all_loss_D.append(loss_D.item())

        '''2. Training the Generator (UNet)'''
        optimizer_G.zero_grad()
        fake = generator(inputs)
        pred_fake = discriminator(fake)
        loss_G_bce = bce(pred_fake, torch.ones_like(pred_fake))  # try to fool D
        loss_G_l1 = l1loss(fake, targets) * 100  # stay close to the clear-sky target (lambda = 100)
        loss_G = loss_G_bce + loss_G_l1
        loss_G.backward()
        optimizer_G.step()
        generator_epoch_loss += loss_G.item()
        all_loss_G.append(loss_G.item())

    discriminator_epoch_loss /= len(train_loader)
    generator_epoch_loss /= len(train_loader)
    train_loss_D.append(discriminator_epoch_loss)
    train_loss_G.append(generator_epoch_loss)

    discriminator.eval()
    generator.eval()
    discriminator_epoch_val_loss, generator_epoch_val_loss = 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            fake = generator(inputs)
            pred = discriminator(fake)
            loss_G_bce = bce(pred, torch.ones_like(pred))
            loss_G_l1 = l1loss(fake, targets) * 100
            loss_G = loss_G_bce + loss_G_l1
            loss_D = bce(pred, torch.zeros_like(pred))
            discriminator_epoch_val_loss += loss_D.item()
            generator_epoch_val_loss += loss_G.item()
    discriminator_epoch_val_loss /= len(test_loader)
    generator_epoch_val_loss /= len(test_loader)
    val_loss_D.append(discriminator_epoch_val_loss)
    val_loss_G.append(generator_epoch_val_loss)
    print(f"------Epoch [{epoch+1}/{num_epochs}]------\nTrain Loss D: {discriminator_epoch_loss:.4f}, Val Loss D: {discriminator_epoch_val_loss:.4f}")
    print(f'Train Loss G: {generator_epoch_loss:.4f}, Val Loss G: {generator_epoch_val_loss:.4f}')

    # keep the checkpoints with the lowest validation loss
    if discriminator_epoch_val_loss < best_discriminator_epoch_val_loss:
        best_discriminator_epoch_val_loss = discriminator_epoch_val_loss
        torch.save(discriminator.state_dict(), "discriminator.pth")
    if generator_epoch_val_loss < best_generator_epoch_val_loss:
        best_generator_epoch_val_loss = generator_epoch_val_loss
        torch.save(generator.state_dict(), "generator.pth")
    #scheduler_D.step()
    #scheduler_G.step()

    # show one test sample: input, target and generated image
    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(np.transpose(inputs.cpu().numpy()[7], (1, 2, 0)))
    ax[1].imshow(np.transpose(targets.cpu().numpy()[7], (1, 2, 0)))
    ax[2].imshow(np.transpose(fake.cpu().numpy()[7], (1, 2, 0)))
    plt.show()
Once training has finished, we can plot the losses. This code was partly adapted from this cool website:
from matplotlib.font_manager import FontProperties

background_color = '#001219'
font = FontProperties(fname='LexendDeca-VariableFont_wght.ttf')  # local font file
fig, ax = plt.subplots(1, 2, figsize=(16, 9))
fig.set_facecolor(background_color)
ax[0].set_facecolor(background_color)
ax[1].set_facecolor(background_color)
ax[0].plot(range(len(all_loss_G)), all_loss_G, color='#bc6c25', lw=0.5)
ax[1].plot(range(len(all_loss_D)), all_loss_D, color='#00b4d8', lw=0.5)
# highlight the maximum and minimum of each curve
ax[0].scatter(
    [np.array(all_loss_G).argmax(), np.array(all_loss_G).argmin()],
    [np.array(all_loss_G).max(), np.array(all_loss_G).min()],
    s=30, color='#bc6c25',
)
ax[1].scatter(
    [np.array(all_loss_D).argmax(), np.array(all_loss_D).argmin()],
    [np.array(all_loss_D).max(), np.array(all_loss_D).min()],
    s=30, color='#00b4d8',
)
ax_text(
    np.array(all_loss_G).argmax()+60, np.array(all_loss_G).max()+0.1,
    f'{round(np.array(all_loss_G).max(),1)}',
    fontsize=13, color='#bc6c25',
    font=font,
    ax=ax[0]
)
ax_text(
    np.array(all_loss_G).argmin()+60, np.array(all_loss_G).min()-0.1,
    f'{round(np.array(all_loss_G).min(),1)}',
    fontsize=13, color='#bc6c25',
    font=font,
    ax=ax[0]
)
ax_text(
    np.array(all_loss_D).argmax()+60, np.array(all_loss_D).max()+0.01,
    f'{round(np.array(all_loss_D).max(),1)}',
    fontsize=13, color='#00b4d8',
    font=font,
    ax=ax[1]
)
ax_text(
    np.array(all_loss_D).argmin()+60, np.array(all_loss_D).min()-0.005,
    f'{round(np.array(all_loss_D).min(),1)}',
    fontsize=13, color='#00b4d8',
    font=font,
    ax=ax[1]
)
for i in range(2):
    ax[i].tick_params(axis='x', colors='white')
    ax[i].tick_params(axis='y', colors='white')
    ax[i].spines['left'].set_color('white')
    ax[i].spines['bottom'].set_color('white')
    # all_loss_G / all_loss_D hold one value per training batch
    ax[i].set_xlabel('Iteration', color='white', fontproperties=font, fontsize=13)
    ax[i].set_ylabel('Loss', color='white', fontproperties=font, fontsize=13)
ax[0].set_title('Generator', color='white', fontproperties=font, fontsize=18)
ax[1].set_title('Discriminator', color='white', fontproperties=font, fontsize=18)
plt.savefig('Loss.jpg')
plt.show()
# ax[0].set_axis_off()
# ax[1].set_axis_off()