### Enhancements for WGAN
1. **Gradient Penalty (WGAN-GP):**
- The **gradient penalty** is an extra term added to the discriminator (critic) loss to improve training stability. Instead of clipping weights, WGAN-GP enforces the 1-Lipschitz constraint required by the Wasserstein formulation by penalizing the critic whenever the norm of its gradient, measured at points interpolated between real and fake images, drifts away from 1.
2. **Model Evaluation:**
- We can implement a function to evaluate the quality of the generated images quantitatively using metrics like **Inception Score (IS)** or **Fréchet Inception Distance (FID)**.
3. **Model Checkpointing:**
- To save the best performing model during training, we can implement checkpointing. This saves the model weights at certain intervals, which allows for recovery and evaluation.
### Adding Gradient Penalty to WGAN
Let's modify the code to incorporate **WGAN-GP**. We will add the **gradient penalty** to the discriminator's loss.
#### Step 1: Define the Gradient Penalty
```python
import tensorflow as tf
def compute_gradient_penalty(batch_size, real_images, fake_images, discriminator):
    # Random interpolation weights, sampled uniformly in [0, 1] as in the WGAN-GP paper
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated_images = alpha * real_images + (1 - alpha) * fake_images
    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        validity = discriminator(interpolated_images, training=True)
    gradients = tape.gradient(validity, interpolated_images)
    # Penalize deviations of the gradient norm from 1 (the Lipschitz constraint)
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    penalty = tf.reduce_mean(tf.square(grad_norm - 1.0))
    return penalty
```
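As a quick sanity check (a hedged sketch; it assumes a `discriminator` model that accepts 32x32x3 inputs, like the one built later in the full script), you can call the function on random tensors and confirm it returns a single non-negative scalar:
```python
# Hypothetical smoke test with random "real" and "fake" batches
real = tf.random.uniform([8, 32, 32, 3], -1.0, 1.0)
fake = tf.random.uniform([8, 32, 32, 3], -1.0, 1.0)
print(float(compute_gradient_penalty(8, real, fake, discriminator)))
```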
#### Step 2: Modify the Discriminator Loss to Include the Gradient Penalty
Now, integrate the gradient penalty into the discriminator's loss during training. The total critic loss is the Wasserstein term (the mean critic score on fake images minus the mean score on real images) plus `lambda_gp` times the gradient penalty. Because the penalty has to influence the critic's weight updates, the critic step below is written with a `tf.GradientTape` instead of `train_on_batch`.
```python
def train_wgan_gp(generator, discriminator, wgan, epochs, batch_size, sample_interval, lambda_gp=10.0):
    # Assumes the stacked `wgan` model is compiled with the Wasserstein loss
    # mean(y_true * y_pred), as in the full code further below
    d_optimizer = RMSprop(learning_rate=0.00005)
    half_batch = batch_size // 2
    for epoch in range(epochs):
        # ---- Train the critic (discriminator) ----
        discriminator.trainable = True  # un-freeze the critic for its own update
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs = tf.convert_to_tensor(x_train[idx])
        noise = np.random.normal(0, 1, (half_batch, z_dim))
        gen_imgs = generator.predict(noise, verbose=0)
        with tf.GradientTape() as tape:
            real_validity = discriminator(imgs, training=True)
            fake_validity = discriminator(gen_imgs, training=True)
            gp = compute_gradient_penalty(half_batch, imgs, gen_imgs, discriminator)
            # Wasserstein critic loss plus the gradient penalty
            d_loss = tf.reduce_mean(fake_validity) - tf.reduce_mean(real_validity) + lambda_gp * gp
        d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
        # ---- Train the generator (critic frozen inside the stacked model) ----
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = wgan.train_on_batch(noise, -np.ones((batch_size, 1)))
        # Output progress
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {float(d_loss):.4f}] [G loss: {float(g_loss):.4f}] [GP: {float(gp):.4f}]")
            sample_images(epoch, generator)
```
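One refinement from the WGAN-GP paper that the loop above leaves out: the critic is normally updated several times per generator update (the paper uses five). Here is a sketch of the same loop with an inner critic loop, reusing the globals (`x_train`, `z_dim`), helpers, and imports from the rest of this post; the name `train_wgan_gp_ncritic` is mine:
```python
def train_wgan_gp_ncritic(generator, discriminator, wgan, epochs, batch_size,
                          sample_interval, lambda_gp=10.0, n_critic=5):
    d_optimizer = RMSprop(learning_rate=0.00005)
    half_batch = batch_size // 2
    for epoch in range(epochs):
        # Several critic updates per generator update, as in the WGAN-GP paper
        discriminator.trainable = True
        for _ in range(n_critic):
            idx = np.random.randint(0, x_train.shape[0], half_batch)
            imgs = tf.convert_to_tensor(x_train[idx])
            noise = np.random.normal(0, 1, (half_batch, z_dim))
            gen_imgs = generator.predict(noise, verbose=0)
            with tf.GradientTape() as tape:
                d_loss = (tf.reduce_mean(discriminator(gen_imgs, training=True))
                          - tf.reduce_mean(discriminator(imgs, training=True))
                          + lambda_gp * compute_gradient_penalty(half_batch, imgs, gen_imgs, discriminator))
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
        # One generator update (critic frozen inside the stacked model)
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = wgan.train_on_batch(noise, -np.ones((batch_size, 1)))
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {float(d_loss):.4f}] [G loss: {float(g_loss):.4f}]")
            sample_images(epoch, generator)
```
Training the critic more often costs wall-clock time roughly in proportion to `n_critic`, but it gives the generator a better-trained critic to learn from.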
### Step 3: Implement Model Evaluation (Inception Score)
The **Inception Score (IS)** is a commonly used metric for evaluating GAN samples. It compares the conditional label distribution p(y|x) that a pretrained InceptionV3 assigns to each generated image with the marginal distribution p(y) over all generated images, rewarding images that are individually recognizable and collectively diverse. The higher the Inception Score, the better.
Here is how you can implement it:
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

def calculate_inception_score(generator, n_samples=1000, batch_size=32):
    inception_model = InceptionV3(include_top=True, weights='imagenet')
    # Generate images and rescale from [-1, 1] (tanh output) to [0, 255] for preprocess_input
    noise = np.random.normal(0, 1, (n_samples, z_dim))
    gen_imgs = (generator.predict(noise, verbose=0) + 1.0) * 127.5
    preds = []
    for i in range(0, n_samples, batch_size):
        batch = tf.image.resize(gen_imgs[i:i + batch_size], (299, 299))  # InceptionV3 input size
        preds.append(inception_model.predict(preprocess_input(batch), verbose=0))
    p_yx = np.concatenate(preds, axis=0)            # conditional label distribution p(y|x)
    p_y = np.mean(p_yx, axis=0, keepdims=True)      # marginal label distribution p(y)
    # IS = exp( E_x[ KL(p(y|x) || p(y)) ] )
    kl = np.sum(p_yx * (np.log(p_yx + 1e-16) - np.log(p_y + 1e-16)), axis=1)
    return float(np.exp(np.mean(kl)))
```
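The enhancements list above also mentions the **Fréchet Inception Distance (FID)**, which compares generated images against real ones and is usually considered more reliable than IS. Here is a minimal sketch, assuming `scipy` is installed and reusing the imports from the IS snippet; the function name `calculate_fid` and its arguments are my own, and both image batches are expected in the same [-1, 1] range as the generator output:
```python
from scipy.linalg import sqrtm

def calculate_fid(real_images, fake_images, batch_size=32):
    # Pooled 2048-d InceptionV3 features instead of class probabilities
    model = InceptionV3(include_top=False, pooling='avg', weights='imagenet')

    def features(images):
        images = (np.asarray(images, dtype=np.float32) + 1.0) * 127.5  # [-1, 1] -> [0, 255]
        feats = []
        for i in range(0, len(images), batch_size):
            batch = tf.image.resize(images[i:i + batch_size], (299, 299))
            feats.append(model.predict(preprocess_input(batch), verbose=0))
        return np.concatenate(feats, axis=0)

    f_real, f_fake = features(real_images), features(fake_images)
    mu_r, mu_f = f_real.mean(axis=0), f_fake.mean(axis=0)
    cov_r = np.cov(f_real, rowvar=False)
    cov_f = np.cov(f_fake, rowvar=False)
    # FID = ||mu_r - mu_f||^2 + Tr(C_r + C_f - 2 (C_r C_f)^(1/2))
    covmean = sqrtm(cov_r @ cov_f)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(np.sum((mu_r - mu_f) ** 2) + np.trace(cov_r + cov_f - 2.0 * covmean))
```
For example, `calculate_fid(x_train[:500], generator.predict(np.random.normal(0, 1, (500, z_dim))))` would compare 500 real CIFAR-10 images with 500 generated ones; lower is better.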
### Step 4: Checkpointing and Saving the Model
Keras's **ModelCheckpoint** callback can save the best-performing weights automatically, but it only fires from `fit()`. With the custom WGAN training loop above, you can either keep the callback for any `fit()`-based training you do, or save weights manually at intervals (see the sketch after this code block).
```python
from tensorflow.keras.callbacks import ModelCheckpoint
# Callback for saving the model with the best validation performance
checkpoint = ModelCheckpoint('wgan_best_model.h5', save_best_only=True, save_weights_only=True)
# Add the callback to your training function or modify the training loop to save the model at intervals
```
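Since the WGAN loop above never calls `fit()`, a more direct option is to save weights from inside the loop. A minimal sketch, where `save_checkpoint` and the `checkpoints/` directory are placeholders of my own:
```python
import os

def save_checkpoint(epoch, generator, discriminator, directory='checkpoints'):
    # Save both networks' weights so training can be resumed or the generator evaluated later
    os.makedirs(directory, exist_ok=True)
    generator.save_weights(f'{directory}/generator_epoch_{epoch}.h5')
    discriminator.save_weights(f'{directory}/discriminator_epoch_{epoch}.h5')

# Call it from the training loop, e.g. next to sample_images:
#   if epoch % sample_interval == 0:
#       save_checkpoint(epoch, generator, discriminator)
```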
### Full Updated Code with Gradient Penalty and Model Evaluation
```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, BatchNormalization, LeakyReLU, Conv2D, UpSampling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10
import os
# 1. Data Loading and Preprocessing
(x_train, _), (_, _) = cifar10.load_data()
# Normalize the images to the range [-1, 1]
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
# Image shape and latent space dimension
img_shape = (32, 32, 3)
z_dim = 100 # Latent vector size
# 2. Building the Generator Model
def build_generator(z_dim):
    model = Sequential()
    model.add(Dense(256, input_dim=z_dim))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))
    return model
# 3. Building the Discriminator Model
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=img_shape, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(1))
    return model
# 4. Building the WGAN Model
# Wasserstein loss for the stacked generator+critic model: mean(y_true * y_pred)
def wasserstein_loss(y_true, y_pred):
    return tf.reduce_mean(y_true * y_pred)

def build_wgan(generator, discriminator):
    # The critic itself is trained with a custom GradientTape loop (so the gradient
    # penalty can be applied); only the stacked model for the generator is compiled here.
    discriminator.trainable = False
    z = tf.keras.Input(shape=(z_dim,))
    img = generator(z)
    validity = discriminator(img)
    wgan = tf.keras.Model(z, validity)
    wgan.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.00005))
    return wgan
# 5. Compute Gradient Penalty
def compute_gradient_penalty(batch_size, real_images, fake_images, discriminator):
    # Random interpolation weights, sampled uniformly in [0, 1] as in the WGAN-GP paper
    alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated_images = alpha * real_images + (1 - alpha) * fake_images
    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        validity = discriminator(interpolated_images, training=True)
    gradients = tape.gradient(validity, interpolated_images)
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    penalty = tf.reduce_mean(tf.square(grad_norm - 1.0))
    return penalty
# 6. Train WGAN-GP
def train_wgan_gp(generator, discriminator, wgan, epochs, batch_size, sample_interval, lambda_gp=10.0):
    d_optimizer = RMSprop(learning_rate=0.00005)
    half_batch = batch_size // 2
    for epoch in range(epochs):
        # ---- Train the critic (discriminator) ----
        discriminator.trainable = True
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        imgs = tf.convert_to_tensor(x_train[idx])
        noise = np.random.normal(0, 1, (half_batch, z_dim))
        gen_imgs = generator.predict(noise, verbose=0)
        with tf.GradientTape() as tape:
            real_validity = discriminator(imgs, training=True)
            fake_validity = discriminator(gen_imgs, training=True)
            gp = compute_gradient_penalty(half_batch, imgs, gen_imgs, discriminator)
            # Wasserstein critic loss plus the gradient penalty
            d_loss = tf.reduce_mean(fake_validity) - tf.reduce_mean(real_validity) + lambda_gp * gp
        d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
        # ---- Train the generator (critic frozen inside the stacked model) ----
        discriminator.trainable = False
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = wgan.train_on_batch(noise, -np.ones((batch_size, 1)))
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {float(d_loss):.4f}] [G loss: {float(g_loss):.4f}] [GP: {float(gp):.4f}]")
            sample_images(epoch, generator)
# 7. Function to Save and Display Generated Images
def sample_images(epoch, generator):
    noise = np.random.normal(0, 1, (9, z_dim))
    gen_imgs = generator.predict(noise)
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(3, 3)
    count = 0
    for i in range(3):
        for j in range(3):
            axs[i, j].imshow(gen_imgs[count])
            axs[i, j].axis('off')
            count += 1
    fig.savefig(f"images/epoch_{epoch}.png")
    plt.close()
# 8. Initialize Models and Train
os.makedirs('images', exist_ok=True)
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
wgan = build_wgan(generator, discriminator)
epochs = 10000
batch_size = 64
sample_interval = 1000
train_wgan_gp(generator, discriminator, wgan, epochs, batch_size, sample_interval)
```
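Once training finishes, a quick way to sanity-check the result (a minimal sketch; it assumes the `calculate_inception_score` helper from Step 3 has been defined, and the commented-out checkpoint path is a placeholder from the checkpointing sketch above):
```python
# Evaluate the trained generator and render one final sample grid
print("Inception Score:", calculate_inception_score(generator, n_samples=500))
# generator.load_weights('checkpoints/generator_best.h5')  # if a best checkpoint was saved
sample_images('final', generator)
```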