Building a GAN network
This notebook was run on Google Colab.
# Guard cell: deliberately halts execution at this point by raising
# SystemExit (see the captured output below). Presumably this prevents the
# full training run further down from starting by accident, e.g. on
# "Run all"; remove this cell to execute the rest of the notebook.
_STOP_MESSAGE = "Stop right there!"
raise SystemExit(_STOP_MESSAGE)
An exception has occurred, use %tb to see the full traceback.
SystemExit: Stop right there!
1# Import the required library functions
2from keras.models import Sequential
3from numpy import hstack, zeros, ones
4from numpy.random import rand, randn
5from keras.layers import Dense
6import matplotlib.pyplot as plt
def define_gen(latent_dim, n_outputs=2):
    """Build the (uncompiled) generator network.

    Architecture: one hidden ``Dense`` layer of 15 ReLU units with He-uniform
    initialisation that reads a ``latent_dim``-long input vector, followed by
    a ``Dense`` output layer with ``n_outputs`` units.

    The output layer uses a *linear* activation because the generator has to
    emit unbounded, continuous real values (the coordinates of a sample
    point), which a squashing activation such as sigmoid could not produce.

    The model is not compiled here; in this notebook it is trained through
    the combined model built by ``define_gan``.

    Args:
        latent_dim: Length of the latent (noise) vector fed to the generator.
        n_outputs: Number of values produced per generated sample.

    Returns:
        An uncompiled ``keras.models.Sequential`` model.
    """
    hidden_layer = Dense(
        15,
        activation="relu",
        kernel_initializer="he_uniform",
        input_dim=latent_dim,
    )
    output_layer = Dense(n_outputs, activation="linear")
    return Sequential([hidden_layer, output_layer])
def define_disc(n_inputs=2):
    """Build and compile the discriminator network.

    The discriminator is a binary classifier deciding whether a sample is
    real (label 1) or generated (label 0). It has one hidden layer of 25 ReLU
    units with He-uniform initialisation over an ``n_inputs``-long input,
    then a single sigmoid unit. The sigmoid output lies in (0, 1) and can be
    read as the probability that the input is real.

    It is compiled with binary cross-entropy loss, the Adam optimizer and an
    accuracy metric.

    Args:
        n_inputs: Number of values per sample. This must match the
            generator's ``n_outputs``, because generated samples are fed
            straight into this model.

    Returns:
        A compiled ``keras.models.Sequential`` model.
    """
    stack = [
        Dense(
            25,
            activation="relu",
            kernel_initializer="he_uniform",
            input_dim=n_inputs,
        ),
        Dense(1, activation="sigmoid"),
    ]
    classifier = Sequential(stack)
    classifier.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    return classifier
def define_gan(generator, discriminator):
    """Chain generator -> discriminator into the model that trains the generator.

    The discriminator's ``trainable`` flag is switched off *before* the
    combined model is compiled. The intent is that ``train_on_batch`` on the
    returned model updates only the generator's weights, while the
    discriminator (compiled separately beforehand, e.g. by ``define_disc``)
    keeps its own trainable configuration.

    NOTE(review): this relies on Keras snapshotting the ``trainable`` state at
    ``compile()`` time (the behaviour documented for Keras 2). Confirm that
    the installed Keras version still trains the separately compiled
    discriminator after this flag is set to False.

    Args:
        generator: Model mapping latent vectors to samples.
        discriminator: Model mapping samples to a real/fake probability.

    Returns:
        The compiled combined ``Sequential`` model (binary cross-entropy, Adam).
    """
    discriminator.trainable = False
    combined = Sequential([generator, discriminator])
    combined.compile(optimizer="adam", loss="binary_crossentropy")
    return combined
def generate_real(n):
    """Sample ``n`` points from the "real" target distribution.

    Each point is ``(x, x * x)`` with ``x`` drawn uniformly from
    ``[-0.5, 0.5)``, so the real data lies on a parabola. These are the
    samples the generated (fake) points are compared against, and every one
    is labelled 1 ("real").

    Args:
        n: Number of samples to draw.

    Returns:
        Tuple ``(x, y)``: ``x`` has shape ``(n, 2)`` and ``y`` is an
        ``(n, 1)`` array of ones.
    """
    # One uniform draw per sample, shifted from [0, 1) to [-0.5, 0.5).
    xs = (rand(n) - 0.5).reshape(n, 1)
    points = hstack((xs, xs * xs))
    labels = ones((n, 1))
    return points, labels
def gen_latent_points(latent_dim, n):
    """Draw ``n`` random latent vectors to feed the generator.

    Values are independent standard-normal draws (``numpy.random.randn``),
    arranged as one ``latent_dim``-long row per sample.

    Args:
        latent_dim: Length of each latent vector.
        n: Number of latent vectors.

    Returns:
        Array of shape ``(n, latent_dim)``.
    """
    return randn(latent_dim * n).reshape(n, latent_dim)
def gen_fake(generator, latent_dim, n):
    """Use the generator to create ``n`` fake samples labelled 0.

    Draws ``n`` latent vectors with ``gen_latent_points`` and maps them
    through ``generator``.

    Args:
        generator: Keras model mapping ``(n, latent_dim)`` latent vectors to
            samples.
        latent_dim: Length of each latent vector.
        n: Number of fake samples to create.

    Returns:
        Tuple ``(x, y)``: the generator's output for the ``n`` latent
        vectors, and an ``(n, 1)`` array of zeros ("fake" labels).
    """
    x_input = gen_latent_points(latent_dim, n)
    # This runs on every training step. Recent Keras versions default
    # predict() to verbose="auto", which prints a progress bar per call.
    # Silence it so the training log stays readable.
    x = generator.predict(x_input, verbose=0)
    y = zeros((n, 1))
    return x, y
def performance_summary(epoch, generator, discriminator, latent_dim, n=100):
    """Report discriminator accuracy and plot real vs. generated samples.

    Draws ``n`` fresh real samples and ``n`` fresh fake samples. It evaluates
    the discriminator's accuracy on each set separately and prints
    ``epoch acc_real acc_fake``. It then shows a scatter plot with the real
    points in green and the generated points in red.

    Args:
        epoch: Label printed in front of the accuracies.
        generator: Model used to create the fake samples.
        discriminator: Model being evaluated. It is assumed to be compiled
            with exactly one metric (accuracy, as ``define_disc`` does), so
            ``evaluate`` returns ``[loss, accuracy]``.
        latent_dim: Length of the generator's latent vectors.
        n: Number of real and of fake samples to draw.
    """
    real_x, real_y = generate_real(n)
    _loss_real, real_acc = discriminator.evaluate(real_x, real_y, verbose=0)
    fake_x, fake_y = gen_fake(generator, latent_dim, n)
    _loss_fake, fake_acc = discriminator.evaluate(fake_x, fake_y, verbose=0)
    print(epoch, real_acc, fake_acc)
    for samples, colour in ((real_x, "green"), (fake_x, "red")):
        plt.scatter(samples[:, 0], samples[:, 1], color=colour)
    plt.show()
def train(
    g_model,
    d_model,
    gan_model,
    latent_dim,
    n_epochs=2000,
    n_batch=128,
    n_eval=100,
):
    """Run the alternating GAN training loop.

    Each epoch is one iteration (no dataset is swept). It has two steps:

    1. Update the discriminator on ``n_batch // 2`` real samples (label 1),
       then separately on ``n_batch // 2`` generated samples (label 0).
    2. Update the generator through ``gan_model`` on ``n_batch`` latent
       vectors labelled 1 ("real"). This rewards generator outputs that the
       discriminator classifies as real. The discriminator's weights are
       meant to stay fixed during this step; that depends on how
       ``gan_model`` was built and compiled.

    Every ``n_eval`` epochs, ``performance_summary`` reports progress.

    Args:
        g_model: Generator model.
        d_model: Compiled discriminator model.
        gan_model: Compiled combined generator -> discriminator model.
        latent_dim: Length of the generator's latent vectors.
        n_epochs: Number of training iterations.
        n_batch: Samples per generator update. Each discriminator update uses
            half of this.
        n_eval: Evaluate every ``n_eval`` epochs. ``0`` or ``None`` disables
            periodic evaluation (``0`` previously raised
            ``ZeroDivisionError``).
    """
    half_batch = n_batch // 2
    for epoch in range(n_epochs):
        # Step 1: discriminator on real, then on fake samples.
        x_real, y_real = generate_real(half_batch)
        x_fake, y_fake = gen_fake(g_model, latent_dim, half_batch)
        d_model.train_on_batch(x_real, y_real)
        d_model.train_on_batch(x_fake, y_fake)
        # Step 2: generator, via the combined model, with "real" targets.
        x_gan = gen_latent_points(latent_dim, n_batch)
        y_gan = ones((n_batch, 1))
        gan_model.train_on_batch(x_gan, y_gan)
        if n_eval and (epoch + 1) % n_eval == 0:
            # The label passed on is the zero-based index of the epoch that
            # just finished (e.g. 99 after the 100th epoch).
            performance_summary(epoch, g_model, d_model, latent_dim)
# Build the three models and run the training loop with train()'s default
# settings (2000 epochs, batch size 128, evaluation every 100 epochs).
dim_latent = 5  # length of each latent (noise) vector fed to the generator
gen = define_gen(latent_dim=dim_latent)
disc = define_disc()
model_gan = define_gan(generator=gen, discriminator=disc)
train(gen, disc, model_gan, latent_dim=dim_latent)
99 0.6600000262260437 0.15000000596046448
199 0.7200000286102295 0.8700000047683716
299 0.7799999713897705 0.6200000047683716
399 0.8100000023841858 0.4300000071525574
499 0.7400000095367432 0.3700000047683716
599 0.800000011920929 0.3100000023841858
699 0.9100000262260437 0.49000000953674316
799 0.8199999928474426 0.6200000047683716
899 0.7699999809265137 0.5899999737739563
999 0.7099999785423279 0.5799999833106995
1099 0.6499999761581421 0.5899999737739563
1199 0.6499999761581421 0.6600000262260437
1299 0.6299999952316284 0.5699999928474426
1399 0.7699999809265137 0.6700000166893005
1499 0.6899999976158142 0.699999988079071
1599 0.6399999856948853 0.6899999976158142
1699 0.6600000262260437 0.6200000047683716
1799 0.5299999713897705 0.6800000071525574
1899 0.41999998688697815 0.6200000047683716
1999 0.3700000047683716 0.7799999713897705