The objective is to build a simple Wasserstein Generative Adversarial Network (WGAN) that generates images resembling those in the CIFAR-10 dataset, using basic deep learning building blocks. The project also shows how the Wasserstein loss affects the training stability of GANs.
**Concepts:**
- GANs: Understand Generative Adversarial Networks and their two components, the generator and the discriminator.
- Wasserstein loss: Understand the concept and how it stabilizes GAN training.
- CIFAR-10 dataset: Understand the dataset and pre-process images for training.
**Implementation Steps:**
1. **Generator Architecture:**
- Input: Random noise vector (latent vector)
- Output: Image similar to CIFAR-10 images (32x32 RGB)
- Basic layers such as **Dense**, **BatchNormalization**, and **UpSampling2D** can be combined, with the upsampling layers progressively increasing the spatial size of the image.
2. **Discriminator Architecture:**
- Input: Real or generated image (32x32 RGB)
- Output: A single unbounded score indicating the "realness" of the image (not a probability, so there is no sigmoid)
- Use **Conv2D** layers for feature extraction from the image with **LeakyReLU** activations and **BatchNormalization**
3. **WGAN Loss Function:**
- Implement **Wasserstein loss** for both the generator and the discriminator.
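- In a WGAN the discriminator (usually called the *critic*) does not classify images as real or fake; it learns to give real images higher scores than generated ones, and the gap between its average scores approximates the Wasserstein distance between the two distributions. The generator is trained to raise the critic's scores on its images.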
4. **Training:**
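- Train the discriminator several times for each generator update (the WGAN paper uses five critic updates per generator update).
- Use **RMSprop** as the optimizer, as in the WGAN paper, whose authors found that momentum-based optimizers such as Adam made training with weight clipping unstable.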
5. **Monitoring and Saving Results:**
- Save generated images at regular intervals during training to observe the progress.
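The loss and clipping rules from steps 3 and 4 can be made concrete with a framework-free NumPy sketch (the function names are illustrative, not Keras APIs): the critic minimizes the negated gap between its mean scores on real and generated images, the generator minimizes the negated mean score on generated images, and in the original WGAN every critic weight is clipped into a small box after each update. The clip value of 0.01 follows the WGAN paper.

```python
import numpy as np

def critic_loss(real_scores, fake_scores):
    """Critic maximizes E[D(real)] - E[D(fake)]; return the negative so it can be minimized."""
    return float(np.mean(fake_scores) - np.mean(real_scores))

def generator_loss(fake_scores):
    """Generator maximizes E[D(G(z))]; return the negative so it can be minimized."""
    return float(-np.mean(fake_scores))

def clip_weights(weights, clip_value=0.01):
    """Original-WGAN weight clipping: force every parameter into [-c, c] after a critic update."""
    return [np.clip(w, -clip_value, clip_value) for w in weights]

# Toy critic scores: real images already score higher than generated ones
real = np.array([2.0, 1.5, 2.5])
fake = np.array([-1.0, 0.0, -0.5])
print(critic_loss(real, fake))   # mean(fake) - mean(real) = -0.5 - 2.0 = -2.5
print(generator_loss(fake))      # -mean(fake) = 0.5
print(clip_weights([np.array([0.3, -0.005, -0.2])]))
```

A negative critic loss means the critic currently separates real from generated images; the generator's update tries to shrink that gap.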
**Dataset:**
Use the **CIFAR-10** dataset, a widely used benchmark of 60,000 32x32 color images in 10 classes (6,000 images per class), split into 50,000 training and 10,000 test images.
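Before training, the 8-bit pixel values are typically mapped to `[-1, 1]` so they match a `tanh` generator output, and mapped back to `[0, 1]` for display. A NumPy sketch of both directions, using a random stand-in batch rather than the real dataset (the helper names are illustrative):

```python
import numpy as np

def to_model_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (images_uint8.astype(np.float32) - 127.5) / 127.5

def to_display_range(images):
    """Map model-range values in [-1, 1] to [0, 1] for plt.imshow."""
    return 0.5 * images + 0.5

# Stand-in for a CIFAR-10 batch: shape (batch, 32, 32, 3), dtype uint8
batch = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
x = to_model_range(batch)
print(x.shape, x.dtype, x.min() >= -1.0, x.max() <= 1.0)
```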
**Requirements:**
- Basic knowledge of **Neural Networks** and **Generative Models**.
- Experience with **TensorFlow/Keras** or another deep learning framework.
**Deliverables:**
- Python code for the WGAN architecture.
- A short report of the results, including generated images and loss curves.
**Resources:**
- **WGAN Paper:** "Wasserstein GAN" by Arjovsky et al.
- **TensorFlow/Keras Documentation** for implementation details.
Implementing a Wasserstein GAN (WGAN) for Image Generation using CIFAR-10 Dataset
- paypal56_ab6mk6y7
- Site Admin
- Posts: 72
- Joined: Sat Oct 26, 2024 3:05 pm
Re: Implementing a Wasserstein GAN (WGAN) for Image Generation using CIFAR-10 Dataset
Here’s a complete solution for implementing a **Wasserstein GAN (WGAN)** for image generation on the **CIFAR-10 dataset**. It covers data preprocessing, the generator and discriminator (critic) models, the Wasserstein loss, and the training loop.
### Full Code for WGAN Image Generation
```python
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Dense, Reshape, Flatten, BatchNormalization, LeakyReLU, Conv2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop

# 1. Data Loading and Preprocessing
(x_train, _), (_, _) = cifar10.load_data()
# Normalize the images to [-1, 1] to match the generator's tanh output
x_train = (x_train.astype(np.float32) - 127.5) / 127.5

# Image shape and latent space dimension
img_shape = (32, 32, 3)
z_dim = 100  # Latent vector size

# 2. Building the Generator Model
def build_generator(z_dim):
    model = Sequential()
    model.add(tf.keras.Input(shape=(z_dim,)))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    # tanh keeps pixel values in [-1, 1], the same range as the training data
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(img_shape))
    return model

# 3. Building the Discriminator (Critic) Model
def build_discriminator(img_shape):
    model = Sequential()
    model.add(tf.keras.Input(shape=img_shape))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    # Linear output: an unbounded "realness" score, not a probability
    model.add(Dense(1))
    return model

# 4. Wasserstein Losses
def critic_loss(real_scores, fake_scores):
    # The critic maximizes E[D(real)] - E[D(fake)], so we minimize the negative
    return tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores)

def generator_loss(fake_scores):
    # The generator maximizes E[D(G(z))]
    return -tf.reduce_mean(fake_scores)

# 5. Models, Optimizers and Training Steps
# RMSprop with lr=5e-5 and clipping at 0.01 follow the WGAN paper's defaults
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
g_optimizer = RMSprop(learning_rate=0.00005)
d_optimizer = RMSprop(learning_rate=0.00005)
clip_value = 0.01

@tf.function
def train_critic(real_imgs):
    noise = tf.random.normal((tf.shape(real_imgs)[0], z_dim))
    with tf.GradientTape() as tape:
        fake_imgs = generator(noise, training=True)
        d_loss = critic_loss(discriminator(real_imgs, training=True),
                             discriminator(fake_imgs, training=True))
    grads = tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    # Weight clipping keeps the critic approximately Lipschitz, as in the original WGAN
    for w in discriminator.trainable_variables:
        w.assign(tf.clip_by_value(w, -clip_value, clip_value))
    return d_loss

@tf.function
def train_generator(batch_size):
    noise = tf.random.normal((batch_size, z_dim))
    with tf.GradientTape() as tape:
        fake_imgs = generator(noise, training=True)
        g_loss = generator_loss(discriminator(fake_imgs, training=True))
    # Only the generator's weights are updated here
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return g_loss

# 6. Function to Save Generated Images
def sample_images(epoch, generator):
    noise = np.random.normal(0, 1, (9, z_dim))
    gen_imgs = generator.predict(noise, verbose=0)
    # Rescale images from [-1, 1] to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(3, 3)
    count = 0
    for i in range(3):
        for j in range(3):
            axs[i, j].imshow(gen_imgs[count])
            axs[i, j].axis('off')
            count += 1
    fig.savefig(f"images/epoch_{epoch}.png")
    plt.close(fig)

# 7. Training Loop ("epochs" counts generator updates, not passes over the dataset)
def train_wgan(epochs, batch_size, sample_interval, n_critic=5):
    for epoch in range(epochs):
        # Train the critic several times per generator update
        for _ in range(n_critic):
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            d_loss = train_critic(tf.convert_to_tensor(x_train[idx]))
        g_loss = train_generator(batch_size)
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {float(d_loss):.4f}] [G loss: {float(g_loss):.4f}]")
            sample_images(epoch, generator)

# 8. Create Directory for Image Saving and Train
os.makedirs('images', exist_ok=True)
epochs = 10000
batch_size = 64
sample_interval = 1000
train_wgan(epochs, batch_size, sample_interval)
```
### Code Breakdown:
1. **Data Preprocessing:**
- The **CIFAR-10** training images are loaded and scaled to `[-1, 1]`, the same range as the generator's `tanh` output.
2. **Generator Model:**
- A fully connected network that maps a random noise vector (`z_dim = 100`) to a 32x32x3 image. It uses dense layers, leaky ReLU activations, and batch normalization, and ends with a `tanh` layer reshaped to `img_shape`.
3. **Discriminator (Critic) Model:**
- A convolutional neural network built from strided 2D convolutions and leaky ReLU, followed by a single linear unit. Its output is an unbounded score (higher means "more real"), not a real/fake probability.
4. **Wasserstein Loss and Optimizers:**
- The critic minimizes `mean(D(fake)) - mean(D(real))` and the generator minimizes `-mean(D(G(z)))`. Both use **RMSprop** with a learning rate of `5e-5`, and the critic's weights are clipped to `[-0.01, 0.01]` after every update, as in the original WGAN paper.
5. **Training Function:**
- Training alternates between the two networks:
- The critic is updated `n_critic = 5` times on batches of real and generated images.
- The generator is then updated once to raise the critic's score on its images.
- Unlike mean squared error or binary cross-entropy, the Wasserstein loss still gives the generator useful gradients when the critic separates real and fake images well, which is why it tends to train more stably than the traditional GAN loss.
6. **Image Generation:**
- Every `sample_interval` iterations, the generator produces a 3x3 grid of samples to track progress.
7. **Saving Generated Images:**
- The grids are saved to the `images/` directory as `epoch_<n>.png`.
Re: Implementing a Wasserstein GAN (WGAN) for Image Generation using CIFAR-10 Dataset
### Enhancements for WGAN
1. **Gradient Penalty (WGAN-GP):**
- The **gradient penalty** is added to the discriminator's loss to improve training stability. Instead of clipping the weights, it softly enforces the Lipschitz constraint that the Wasserstein loss relies on, by pushing the norm of the discriminator's gradient (with respect to its input) toward 1.
2. **Model Evaluation:**
- We can implement a function to evaluate the quality of the generated images quantitatively using metrics like **Inception Score (IS)** or **Fréchet Inception Distance (FID)**.
3. **Model Checkpointing:**
- To keep the best-performing model, we can add checkpointing, which saves the model weights at regular intervals so that training can be resumed and earlier models evaluated.
### Adding Gradient Penalty to WGAN
Let's modify the code to implement **WGAN-GP** by adding the **gradient penalty** to the discriminator's loss; in WGAN-GP this penalty replaces weight clipping.
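Written out, the WGAN-GP critic loss is the Wasserstein term plus a penalty on the norm of the critic's input gradient, evaluated at random points between real images $x$ and generated images $\tilde{x}$. In the code, `alpha` plays the role of $\alpha$ and `lambda_gp` the role of $\lambda$ (10 by default):

$$
L_D = \mathbb{E}_{\tilde{x} \sim P_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim P_r}\big[D(x)\big] + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big], \qquad \hat{x} = \alpha x + (1 - \alpha)\tilde{x}, \quad \alpha \sim U[0, 1]
$$

The generator loss is unchanged: $L_G = -\mathbb{E}_{z}\big[D(G(z))\big]$.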
#### Step 1: Define the Gradient Penalty
```python
import tensorflow as tf

def compute_gradient_penalty(batch_size, real_images, fake_images, discriminator):
    # One random interpolation weight per sample, alpha ~ U[0, 1], as in WGAN-GP
    alpha = tf.random.uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    interpolated_images = alpha * real_images + (1.0 - alpha) * fake_images
    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        validity = discriminator(interpolated_images, training=True)
    gradients = tape.gradient(validity, interpolated_images)
    # Per-sample L2 norm of the critic's gradient w.r.t. its input
    # (the small epsilon avoids an undefined sqrt gradient at 0)
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    # Penalize any deviation of the norm from 1
    penalty = tf.reduce_mean(tf.square(grad_norm - 1.0))
    return penalty
```
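As a TensorFlow-free sanity check of the penalty term, consider a hypothetical linear critic D(x) = w·x on flattened inputs: its gradient with respect to every input is simply w, so the penalty must equal (‖w‖₂ − 1)² whatever interpolation weights are drawn. The NumPy sketch below estimates the input gradient by central finite differences (standing in for `tf.GradientTape`) and compares the result with that closed form; the linear critic is chosen only because its gradient is known.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=12)  # weights of a toy linear critic D(x) = w . x on flattened inputs

def critic(x):
    """Hypothetical linear critic: one score per row of x."""
    return x @ w

def numerical_input_grad(f, x, h=1e-5):
    """Central finite-difference gradient of sum(f(x)) w.r.t. each element of x."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = h
        grad[idx] = (f(x + step).sum() - f(x - step).sum()) / (2 * h)
    return grad

def gradient_penalty(real, fake, f):
    alpha = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))  # alpha ~ U[0, 1] per sample
    interpolated = alpha * real + (1.0 - alpha) * fake
    grads = numerical_input_grad(f, interpolated)
    grad_norm = np.sqrt(np.sum(grads ** 2, axis=1))         # per-sample L2 norm
    return np.mean((grad_norm - 1.0) ** 2)

real = rng.normal(size=(8, 12))
fake = rng.normal(size=(8, 12))
gp = gradient_penalty(real, fake, critic)
print(gp, (np.linalg.norm(w) - 1.0) ** 2)  # the two values agree up to finite-difference error
```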
#### Step 2: Modify the Discriminator Loss to Include the Gradient Penalty
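Next, integrate the gradient penalty into the discriminator's loss during training. The critic's total loss is the Wasserstein loss plus `lambda_gp` times the gradient penalty. The penalty only affects training if it is part of the loss that is differentiated when the critic's weights are updated, so it has to be computed inside the same `tf.GradientTape` as the Wasserstein term.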
```python
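# Assumes x_train, z_dim, sample_images and compute_gradient_penalty from the main script
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import RMSprop

def train_wgan_gp(generator, discriminator, epochs, batch_size, sample_interval,
                  lambda_gp=10.0, n_critic=5):
    # RMSprop is kept from the WGAN code; the WGAN-GP paper uses Adam(1e-4, beta_1=0.0, beta_2=0.9)
    d_optimizer = RMSprop(learning_rate=0.00005)
    g_optimizer = RMSprop(learning_rate=0.00005)

    @tf.function
    def critic_step(real_imgs):
        noise = tf.random.normal((batch_size, z_dim))
        with tf.GradientTape() as tape:
            fake_imgs = generator(noise, training=True)
            # Wasserstein critic loss: E[D(fake)] - E[D(real)]
            w_loss = (tf.reduce_mean(discriminator(fake_imgs, training=True))
                      - tf.reduce_mean(discriminator(real_imgs, training=True)))
            gp = compute_gradient_penalty(batch_size, real_imgs, fake_imgs, discriminator)
            # The penalty is inside the tape, so it contributes to the critic's gradients
            d_loss = w_loss + lambda_gp * gp
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
        return d_loss, gp

    @tf.function
    def generator_step():
        noise = tf.random.normal((batch_size, z_dim))
        with tf.GradientTape() as tape:
            g_loss = -tf.reduce_mean(discriminator(generator(noise, training=True), training=True))
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
        return g_loss

    for epoch in range(epochs):
        # Several critic updates per generator update
        for _ in range(n_critic):
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            d_loss, gp = critic_step(tf.convert_to_tensor(x_train[idx]))
        g_loss = generator_step()
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {float(d_loss):.4f}] [G loss: {float(g_loss):.4f}] [GP: {float(gp):.4f}]")
            sample_images(epoch, generator)
```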
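#### Step 3: Implement Model Evaluation (Inception Score)
The **Inception Score (IS)** is a commonly used metric for evaluating images generated by GANs. It passes generated images through a pretrained Inception-v3 classifier and is high when each image gets a confident prediction (quality) while the predictions across images are spread over many classes (diversity); higher is better. Since Inception-v3 is trained on ImageNet and the 32x32 samples must be upscaled, treat the score as a rough relative measure and compute it over several thousand samples for a stable estimate.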
Here is how you can implement it:
```python
# Assumes z_dim from the main script
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import InceptionV3

def calculate_inception_score(generator, n_samples=1000, batch_size=32, eps=1e-12):
    # Pretrained ImageNet classifier; its top layer outputs softmax class probabilities
    inception_model = InceptionV3(include_top=True, weights='imagenet')
    predictions = []
    # Generate and classify in batches to keep memory use bounded at 299x299
    for start in range(0, n_samples, batch_size):
        n = min(batch_size, n_samples - start)
        noise = np.random.normal(0, 1, (n, z_dim))
        gen_imgs = generator.predict(noise, verbose=0)
        # InceptionV3 expects inputs in [-1, 1] (what preprocess_input produces from [0, 255]);
        # the tanh generator already outputs that range, so preprocess_input is not applied again
        images_resized = tf.image.resize(gen_imgs, (299, 299))
        predictions.append(inception_model.predict(images_resized, verbose=0))
    p_yx = np.concatenate(predictions, axis=0)   # p(y|x) for each generated image
    p_y = np.mean(p_yx, axis=0, keepdims=True)   # marginal class distribution p(y)
    # IS = exp(E_x[KL(p(y|x) || p(y))])
    kl = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(np.mean(kl)))
```
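The score only needs the classifier's predicted class probabilities, so the formula can be checked without Inception-v3. In this NumPy sketch (the `inception_score` helper and the toy probability matrices are illustrative), perfectly confident predictions spread evenly over k classes give the maximum score k, while identical predictions for every sample give the minimum score 1:

```python
import numpy as np

def inception_score(p_yx, eps=1e-12):
    """IS = exp(mean_x KL(p(y|x) || p(y))) for a matrix of per-sample class probabilities."""
    p_y = p_yx.mean(axis=0, keepdims=True)  # marginal class distribution
    kl = np.sum(p_yx * (np.log(p_yx + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))

# 10 samples, each assigned with certainty to a different one of 10 classes
confident_and_diverse = np.eye(10)
# 10 samples that all receive the same uniform prediction
uninformative = np.full((10, 10), 0.1)

print(inception_score(confident_and_diverse))  # close to 10.0
print(inception_score(uninformative))          # close to 1.0
```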
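#### Step 4: Checkpointing and Saving the Model
Keras's **ModelCheckpoint** callback only runs inside `model.fit()`, and its `save_best_only` option compares a monitored metric (by default `val_loss`). A GAN trained with a hand-written loop has neither, so the weights have to be saved explicitly, for example every `sample_interval` iterations.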
```python
import os

os.makedirs('checkpoints', exist_ok=True)

# Illustrative helper: call it from the training loop, e.g. when epoch % sample_interval == 0.
# The '.weights.h5' suffix is required by Keras 3 and also works with tf.keras 2.x.
def save_checkpoint(epoch, generator, discriminator):
    generator.save_weights(f'checkpoints/generator_{epoch}.weights.h5')
    discriminator.save_weights(f'checkpoints/discriminator_{epoch}.weights.h5')
```
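Periodic saving keeps every snapshot; to keep only the best one, the save decision can be driven by an evaluation score such as the Inception Score. A minimal sketch of that bookkeeping, using a hypothetical `BestScoreTracker` class and made-up scores (not real results); in the training loop, a `True` return value would be the cue to call `generator.save_weights(...)`:

```python
class BestScoreTracker:
    """Remember the best score seen so far (higher is better, as with the Inception Score)."""

    def __init__(self):
        self.best_score = float('-inf')
        self.best_step = None

    def update(self, step, score):
        """Return True if `score` is a new best, i.e. the model should be saved now."""
        if score > self.best_score:
            self.best_score = score
            self.best_step = step
            return True
        return False

tracker = BestScoreTracker()
for step, score in [(0, 1.8), (1000, 2.6), (2000, 2.4), (3000, 3.1)]:
    if tracker.update(step, score):
        print(f"step {step}: new best score {score}, save weights")
print(tracker.best_step, tracker.best_score)  # 3000 3.1
```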
### Full Updated Code with Gradient Penalty
```python
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import Dense, Reshape, Flatten, BatchNormalization, LeakyReLU, Conv2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import RMSprop

# 1. Data Loading and Preprocessing
(x_train, _), (_, _) = cifar10.load_data()
# Normalize the images to [-1, 1] to match the generator's tanh output
x_train = (x_train.astype(np.float32) - 127.5) / 127.5

# Image shape and latent space dimension
img_shape = (32, 32, 3)
z_dim = 100  # Latent vector size

# 2. Building the Generator Model
def build_generator(z_dim):
    model = Sequential()
    model.add(tf.keras.Input(shape=(z_dim,)))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(img_shape))
    return model

# 3. Building the Discriminator (Critic) Model: no batch norm, since the
#    gradient penalty is computed independently for each sample
def build_discriminator(img_shape):
    model = Sequential()
    model.add(tf.keras.Input(shape=img_shape))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(256, kernel_size=3, strides=2, padding='same'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(1))  # Unbounded critic score
    return model

# 4. Compute Gradient Penalty
def compute_gradient_penalty(batch_size, real_images, fake_images, discriminator):
    # One random interpolation weight per sample, alpha ~ U[0, 1]
    alpha = tf.random.uniform([batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    interpolated_images = alpha * real_images + (1.0 - alpha) * fake_images
    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        validity = discriminator(interpolated_images, training=True)
    gradients = tape.gradient(validity, interpolated_images)
    grad_norm = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(grad_norm - 1.0))

# 5. Train WGAN-GP ("epochs" counts generator updates, not passes over the dataset)
def train_wgan_gp(generator, discriminator, epochs, batch_size, sample_interval,
                  lambda_gp=10.0, n_critic=5):
    # RMSprop is kept from the WGAN code; the WGAN-GP paper uses Adam(1e-4, beta_1=0.0, beta_2=0.9)
    d_optimizer = RMSprop(learning_rate=0.00005)
    g_optimizer = RMSprop(learning_rate=0.00005)

    @tf.function
    def critic_step(real_imgs):
        noise = tf.random.normal((batch_size, z_dim))
        with tf.GradientTape() as tape:
            fake_imgs = generator(noise, training=True)
            w_loss = (tf.reduce_mean(discriminator(fake_imgs, training=True))
                      - tf.reduce_mean(discriminator(real_imgs, training=True)))
            gp = compute_gradient_penalty(batch_size, real_imgs, fake_imgs, discriminator)
            # The penalty is inside the tape, so it shapes the critic's update
            d_loss = w_loss + lambda_gp * gp
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
        return d_loss, gp

    @tf.function
    def generator_step():
        noise = tf.random.normal((batch_size, z_dim))
        with tf.GradientTape() as tape:
            g_loss = -tf.reduce_mean(discriminator(generator(noise, training=True), training=True))
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
        return g_loss

    for epoch in range(epochs):
        # Several critic updates per generator update
        for _ in range(n_critic):
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            d_loss, gp = critic_step(tf.convert_to_tensor(x_train[idx]))
        g_loss = generator_step()
        if epoch % sample_interval == 0:
            print(f"{epoch} [D loss: {float(d_loss):.4f}] [G loss: {float(g_loss):.4f}] [GP: {float(gp):.4f}]")
            sample_images(epoch, generator)

# 6. Function to Save Generated Images
def sample_images(epoch, generator):
    noise = np.random.normal(0, 1, (9, z_dim))
    gen_imgs = generator.predict(noise, verbose=0)
    # Rescale images from [-1, 1] to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(3, 3)
    count = 0
    for i in range(3):
        for j in range(3):
            axs[i, j].imshow(gen_imgs[count])
            axs[i, j].axis('off')
            count += 1
    fig.savefig(f"images/epoch_{epoch}.png")
    plt.close(fig)

# 7. Initialize Models and Train
os.makedirs('images', exist_ok=True)
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
epochs = 10000
batch_size = 64
sample_interval = 1000
train_wgan_gp(generator, discriminator, epochs, batch_size, sample_interval)
```