ARTICLE

From Deep Learning with PyTorch by Eli Stevens and Luca Antiga


In this article, we explore some of PyTorch’s capabilities by playing with generative adversarial networks (GANs).

In part two, we saw how to use a pre-trained model for image classification. Now, let’s take a look at GANs.

A pre-trained network that makes stuff up

Models falling under the name of Generative Adversarial Networks (GANs) are one of the most original outcomes of recent deep learning research. We’ll look into these later in our journey. For now, we’ll say this. In standard neural network architectures, we have one big network optimizing its weights in order to minimize a loss function related to, say, a classification task. In GANs, we have a pair of networks, named the generator and the discriminator.

Figure 1. Concept of a GAN game.

The generator has the task of producing realistic-looking images starting from an input. The discriminator has to tell whether a given image was fabricated by the generator or belongs to a set of real images. The generator’s end goal is to fool the discriminator into mixing up real and fake images. The discriminator’s end goal is to find out when it’s being tricked. This is called the GAN game.

Note that “Discriminator wins” or “Generator wins” shouldn’t be taken literally, as there’s no explicit match between the two. Both networks are associated with cost functions which depend on the outcome of the other network, and which are minimized in turn during training.
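"Minimized in turn" can be made concrete with a minimal sketch. It assumes toy fully connected networks on 2-D points instead of images. The names `generator`, `discriminator` and `train_step` are ours, for illustration only. The binary cross-entropy losses are one common GAN formulation, not necessarily the exact losses used to train the models in this article.

```python
import torch
import torch.nn as nn

# Toy stand-ins (hypothetical): the generator maps 8-dim noise to 2-dim "samples",
# and the discriminator outputs one logit per sample (high = "looks real").
torch.manual_seed(0)
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
discriminator = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))

opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

def train_step(real):
    batch = real.shape[0]
    ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

    # 1) Discriminator's turn: push real samples toward 1 and generated ones toward 0.
    #    detach() keeps this update from touching the generator's weights.
    fake = generator(torch.randn(batch, 8))
    d_loss = bce(discriminator(real), ones) + bce(discriminator(fake.detach()), zeros)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2) Generator's turn: its loss is low when the discriminator calls the fakes real.
    #    Only opt_g steps here; stale discriminator grads are zeroed next call.
    g_loss = bce(discriminator(fake), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()

# "Real" data here is just a Gaussian blob centered at (2, 2).
real_data = torch.randn(64, 2) + 2.0
d_loss, g_loss = train_step(real_data)
print(f"discriminator loss {d_loss:.3f}, generator loss {g_loss:.3f}")
```

Each loss depends on the other network's output. Repeating `train_step` over many batches is what "playing the game" amounts to in practice.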

This technique has led to generators that produce realistic images from noise and a conditioning signal, such as an attribute or another image. A well-trained generator learns a plausible model for generating real-world images.

CYCLEGAN

An interesting evolution of this concept is CycleGAN. A CycleGAN can turn images of one domain into images of another domain (and back), without the need to explicitly provide matching pairs in the training set.

Figure 2. A CycleGAN trained to the point where it can fool both discriminator networks.

In CycleGAN, the generator learns to produce an image conforming to a target distribution (Monet paintings, for instance). It starts from an image belonging to a different distribution (landscape photos, for instance). The aim is that the discriminator can’t tell whether the image produced from a landscape photo is a genuine Monet painting. At the same time, the resulting painting is sent through a different generator going the other way, Monet to photo in our case (!), to be judged by another discriminator on the other side. This is where the Cycle prefix in the acronym comes in. Creating such a cycle stabilizes the training process considerably, which addresses one of the original issues with GANs.
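In the CycleGAN paper, the cycle is enforced by a cycle-consistency loss. An image translated to the other domain and back should match the original, measured with an L1 distance. The sketch below shows that term only; in real training it is added to the adversarial losses. The 1x1 convolutions are hypothetical toy stand-ins for real image-to-image generators.

```python
import torch
import torch.nn as nn

def cycle_consistency_loss(g_xy, g_yx, x, y):
    """L1 cycle loss: x -> Y -> X should give back x, and y -> X -> Y should give back y."""
    l1 = nn.L1Loss()
    forward_cycle = l1(g_yx(g_xy(x)), x)   # e.g. photo -> Monet -> photo
    backward_cycle = l1(g_xy(g_yx(y)), y)  # e.g. Monet -> photo -> Monet
    return forward_cycle + backward_cycle

# Hypothetical toy generators: 1x1 convolutions mixing the three color channels.
torch.manual_seed(0)
photo_to_monet = nn.Conv2d(3, 3, kernel_size=1)
monet_to_photo = nn.Conv2d(3, 3, kernel_size=1)

photos = torch.rand(4, 3, 32, 32)
paintings = torch.rand(4, 3, 32, 32)
loss = cycle_consistency_loss(photo_to_monet, monet_to_photo, photos, paintings)
print(loss.item())

# Sanity check: identity "generators" reconstruct perfectly, so the loss is zero.
identity = nn.Identity()
print(cycle_consistency_loss(identity, identity, photos, paintings).item())  # 0.0
```

No photo is ever paired with a painting here. The loss only compares each image with its own round-trip reconstruction, which is why unpaired collections suffice.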

The fun part is that, at this point, we don’t need pairs of Monet/photos as ground truths. It’s enough to start from a collection of unrelated Monet works and landscape photos for the generators to learn their task, going beyond a purely supervised setting. The implications of this model go even further: the generator learns how to selectively change the appearance of objects in the scene without supervision on what’s what. No signal indicates that water is water and a tree is a tree. Yet they get translated into the way water and trees are represented in the Monet domain, and vice versa.

A NETWORK THAT TURNS HORSES INTO ZEBRAS

An even clearer example is the Horse2Zebra CycleGAN, which is what we’ll play with right now. In this case, the CycleGAN network was trained on a dataset of (unrelated) horse images and zebra images extracted from the ImageNet dataset. The network learns to take an image of one or more horses and turn them all into zebras, leaving the rest of the image as unmodified as possible. Humankind hasn’t held its breath over the last few million years for a tool that turns horses into zebras. Still, this task showcases the ability of these architectures to model complex real-world processes with distant supervision. These models have their limits, but there are hints that in the future we won’t be able to tell real from fake in a live video feed. That opens a can of worms that we’ll duly close right now.

It’s time to play with a pre-trained CycleGAN. This gives us the opportunity to take a closer look at how a network, a generator in this case, is implemented. Let’s do it right away: this is what a possible generator architecture for the horse-to-zebra task looks like. In our case it’s our old friend ResNet. We’ll show the full source code for the ResNetGenerator class, to demonstrate how compact it is given what it does. It takes an image and recognizes one or more horses in it by looking at the pixels. Then it individually modifies the values of those pixels so that the result looks like a credible zebra. We won’t recognize anything like that in the source code, because it’s not in there: the network is a scaffold, and the juice is in the weights.

Throughout the article we’ll walk ourselves through code piece by piece, trying to provide all the explanations for why things are a certain way. We’ll start off by breaking this rule. We don’t have the tools yet to understand the code in detail, but we can get a feel for what it’s like and what we’ll be able to create at the end of this journey.

# In[1]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):  # residual block: output = input + conv_block(input)

    def __init__(self, dim):
        super(ResNetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim)

    def build_conv_block(self, dim):
        conv_block = []

        # Reflection padding mirrors border pixels so the unpadded 3x3 conv keeps the size
        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim),
                       nn.ReLU(True)]

        conv_block += [nn.ReflectionPad2d(1)]

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                       nn.InstanceNorm2d(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)  # skip connection
        return out


class ResNetGenerator(nn.Module):

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):

        assert n_blocks >= 0
        super(ResNetGenerator, self).__init__()

        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # Initial 7x7 convolution: input_nc image channels -> ngf feature channels
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                 nn.InstanceNorm2d(ngf),
                 nn.ReLU(True)]

        # Downsampling: each stride-2 convolution halves height and width
        # and doubles the number of channels
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=True),
                      nn.InstanceNorm2d(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks at the lowest resolution
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        # Upsampling: each transposed convolution doubles height and width
        # and halves the number of channels
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=True),
                      nn.InstanceNorm2d(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        # Final 7x7 convolution back to output_nc channels; Tanh bounds outputs to [-1, 1]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)

Here we’ve declared two classes, ResNetGenerator and ResNetBlock. The latter is used by the former, and both of them derive from nn.Module. That is the PyTorch way to specify a portion of a neural network or, more elegantly, a piece of differentiable computation. Every instance of nn.Module can be called as a function, with the same arguments as specified in its forward method (input in this case).
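The sketch below shows what "called as a function" means. The `Scale` module is a hypothetical toy, not part of the generator. Calling the instance goes through `nn.Module.__call__`, which runs `forward` (plus any registered hooks). That is why modules are normally called directly rather than via `.forward(...)`.

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    """Hypothetical toy module: multiplies its input by a learnable scalar."""

    def __init__(self):
        super().__init__()
        self.factor = nn.Parameter(torch.tensor(2.0))  # registered as a parameter

    def forward(self, input):
        return input * self.factor

scale = Scale()
x = torch.tensor([1.0, 2.0, 3.0])
y = scale(x)                      # calling the instance runs forward(x)
print(y.tolist())                 # [2.0, 4.0, 6.0]
print(len(list(scale.parameters())))  # 1: the module tracks the parameters it owns
```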

Without delving into the details, we can recognize the building blocks that make up the computation: called modules in PyTorch and commonly called layers in other frameworks. We can spot linear functions, such as Conv2d, whereby an input image is convolved with learned filters to produce an output. We can also spot non-linear functions, such as Tanh and ReLU. All these are instantiated, accumulated in a list, model, and fed to an nn.Sequential container. When called with an input, the latter invokes each contained module with the output of the preceding module as input. This is one of the ways in which models can be defined in PyTorch.
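The chaining that nn.Sequential performs can be checked directly. Below is a minimal sketch with a small, arbitrary stack of the same kinds of layers used in the generator. Feeding each module the previous module's output by hand gives the same result as calling the container.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layers = [nn.Conv2d(3, 8, kernel_size=3, padding=1),
          nn.InstanceNorm2d(8),
          nn.ReLU(True),
          nn.Conv2d(8, 3, kernel_size=3, padding=1),
          nn.Tanh()]
seq = nn.Sequential(*layers)  # the container holds the very same module objects

x = torch.randn(1, 3, 16, 16)

# What nn.Sequential does internally: feed each module the previous output.
manual = x
for layer in layers:
    manual = layer(manual)

print(torch.allclose(seq(x), manual))  # True
```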

At this point we can instantiate the ResNetGenerator class with the default parameters:

# In[2]:  
netG = ResNetGenerator()

At this point, the model is created, but its weights are just random initial values. We mentioned earlier that we’d run a generator model that was pre-trained on the horse2zebra dataset. The weights of the model are saved in a .pth file, which is nothing but a pickle file of the model’s tensor parameters. We can load those into our ResNetGenerator using the load_state_dict method of nn.Module:

# In[3]:
model_path = 'horse2zebra_0.4.0.pth'
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

At this point, netG has acquired all the knowledge it gained during training. Note that this is fully equivalent to what happened when we loaded ResNet101 from torchvision, except that the torchvision.models.resnet101 function hid it from us.
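The save/load cycle behind a .pth file can be sketched end to end. This assumes a small hypothetical network in place of ResNetGenerator, and an in-memory buffer in place of a file on disk. A state dict maps parameter names to tensors. For ResNetGenerator, the names would be prefixed with `model.` because the layers live in its `self.model` attribute.

```python
import io
import torch
import torch.nn as nn

def make_net():
    # Hypothetical small network; the same recipe applies to ResNetGenerator.
    # InstanceNorm2d (affine=False by default) and ReLU have no parameters.
    return nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.InstanceNorm2d(4), nn.ReLU())

torch.manual_seed(0)
trained = make_net()
state = trained.state_dict()              # ordered mapping: parameter name -> tensor
print(list(state.keys()))                 # ['0.weight', '0.bias']

buffer = io.BytesIO()                     # stands in for a checkpoint file on disk
torch.save(state, buffer)
buffer.seek(0)

fresh = make_net()                        # new instance with different random weights
fresh.load_state_dict(torch.load(buffer))
print(all(torch.equal(a, b) for a, b in zip(trained.state_dict().values(),
                                            fresh.state_dict().values())))  # True
```

`load_state_dict` matches tensors by name, so the class definition must produce the same layer names the weights were saved with.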

Let’s put the network in eval mode, as we did for ResNet101:

# In[4]:
netG.eval()

# Out[4]:
ResNetGenerator(
  (model): Sequential(
    ...
  )
)
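eval() flips the `training` flag on the module and, recursively, on all its submodules. Layers such as dropout and batch normalization check that flag to switch between training and inference behavior. The sketch below uses a hypothetical toy network with dropout to make the switch visible. It also uses `torch.no_grad()`, a common companion for inference that skips recording the autograd graph.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
print(net.training, net[1].training)   # True True (modules start in training mode)

net.eval()                             # recursively sets training=False
print(net.training, net[1].training)   # False False

x = torch.randn(2, 4)
with torch.no_grad():                  # inference only: no autograd graph is recorded
    out1, out2 = net(x), net(x)
print(torch.equal(out1, out2))         # True: dropout is disabled in eval mode
```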

We’re ready to load a random image of a horse and see what our generator produces. First of all, we need to import PIL and torchvision:

# In[5]:
from PIL import Image
from torchvision import transforms

Then we define a few input transformations to make sure data enters the network with the right shape and size:

# In[6]:
preprocess = transforms.Compose([transforms.Resize(256),
                                 transforms.ToTensor()])

Let’s open a horse file:

# In[7]:
img = Image.open("horse.jpg")
img
Figure 3. A man riding a horse. A horse not having it.

Oh, there’s a dude on the horse. Not for long, judging by the picture. Anyhow, let’s pass it through preprocessing and turn it into a properly shaped batch:

# In[8]:
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)

We shouldn’t worry about the details right now. The important thing is that we follow from a distance. At this point, batch_t can be sent to our model:

# In[9]:
batch_out = netG(batch_t)
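The generator halves the spatial resolution twice and then doubles it twice. So its output has the input's size when height and width are multiples of 4; otherwise each side is rounded up to the next multiple of 4. The sketch below checks this with a hypothetical reduced stack. It keeps only the resampling layers, with the generator's kernel, stride, and padding settings. Channel counts are shrunk, and the layers that preserve size (7x7 convolutions with reflection padding, normalization, activations, residual blocks) are left out.

```python
import torch
import torch.nn as nn

# Resampling layers only, same kernel/stride/padding settings as ResNetGenerator.
resample = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),    # H -> ceil(H / 2)
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),   # H -> ceil(H / 2)
    nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2,
                       padding=1, output_padding=1),        # H -> 2H
    nn.ConvTranspose2d(8, 3, kernel_size=3, stride=2,
                       padding=1, output_padding=1),        # H -> 2H
)

shapes = {}
with torch.no_grad():
    for width in (316, 341):
        shapes[width] = tuple(resample(torch.zeros(1, 3, 256, width)).shape)
print(shapes)  # {316: (1, 3, 256, 316), 341: (1, 3, 256, 344)}
```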

batch_out is now the output of the generator, which we can convert back to an image:

# In[10]:
out_t = (batch_out.detach().squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
# out_img.save('zebra.jpg')
out_img

# Out[10]:
<PIL.Image.Image image mode=RGB size=316x256 at 0x1C0C8E4C550>
Figure 4. A man riding a zebra. A zebra not having it.

Oh, man. Who rides a zebra that way? The resulting image isn’t perfect, but consider how unusual it is for the network to find someone riding on top (sort of). It bears repeating that the learning process hasn’t passed through direct supervision, where humans have delineated tens of thousands of horses. The generator learned to produce an image that would fool the discriminator into thinking that it’s a zebra and that there’s nothing fishy about the image (clearly the discriminator has never been to a rodeo).

It’s hard to overstate the implications of this kind of work. Chances are we’ll see a lot of this technology in our future, probably in disparate aspects of our lives.

Numerous other fun generators have been developed using adversarial training or other approaches. Some of them are capable of creating credible human faces of non-existent individuals. Others can translate sketches into realistic-looking pictures of imaginary landscapes. Generative models are also being explored for producing realistic-sounding audio, credible text, or enjoyable music. It’s likely that these models will form the basis of future tools that support the creative process.

So far we’ve had a chance to play with a model that sees into images and a model that generates new images. We’ll end our tour with a model that involves one more fundamental ingredient: natural language.

A pre-trained network that describes scenes

In order to get first-hand experience with a model involving natural language, we’ll use a pre-trained image captioning model. It was generously provided by Ruotian Luo and implemented after the work on NeuralTalk2 by Andrej Karpathy. We maintain a clone of the code at [REF]. When presented with a natural image, this kind of model generates a caption in plain English describing the scene. Again, the interesting part is that the model is trained on a large dataset of images paired with sentence descriptions, e.g. “A Tabby cat is leaning on a wooden table, with one paw on a laser mouse and the other on a black laptop” [REF paper].

Figure 5. Concept of a captioning model.

This captioning model has two connected halves. The first half of the network learns to generate “descriptive” numerical representations of the scene (Tabby cat, laser mouse, paw), which are then taken as input to the second half. That half is a recurrent neural network which generates a coherent sentence by putting those descriptions together. The whole architecture is trained end-to-end on image-caption pairs.

This captioning model, like a few others that have been proposed, is an img2seq model, part of the seq2seq family. That family is a versatile kind of model specialized in encoding an input sequence (in this case, a sequence of pixels) into a vector. The vector is then decoded into another sequence (a sequence of characters or words).
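The two-halves design can be illustrated with a toy sketch. It is not NeuralTalk2's actual architecture. The class name `TinyCaptioner`, the vocabulary size, the layer sizes and the choice of a GRU are all hypothetical. A small CNN encodes the image into one vector, which becomes the initial state of a recurrent decoder. The decoder scores the next word at each step, and a single loss trains both halves end to end.

```python
import torch
import torch.nn as nn

class TinyCaptioner(nn.Module):
    """Hypothetical toy captioner: CNN encoder -> image vector -> GRU decoder over words."""

    def __init__(self, vocab_size=100, hidden=64):
        super().__init__()
        # First half: convolutional encoder summarizing the image as one vector.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, hidden))
        # Second half: recurrent decoder producing word scores step by step.
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.to_vocab = nn.Linear(hidden, vocab_size)

    def forward(self, images, captions):
        # Training-time pass (teacher forcing): the image vector initializes the
        # GRU hidden state, and the ground-truth words are fed in as inputs.
        h0 = self.encoder(images).unsqueeze(0)           # (1, batch, hidden)
        outputs, _ = self.rnn(self.embed(captions), h0)  # (batch, seq, hidden)
        return self.to_vocab(outputs)                    # (batch, seq, vocab)

torch.manual_seed(0)
model = TinyCaptioner()
images = torch.randn(2, 3, 64, 64)
captions = torch.randint(0, 100, (2, 7))   # word indices from a hypothetical vocabulary
logits = model(images, captions)
print(tuple(logits.shape))                 # (2, 7, 100)

# End-to-end training: predict each next word; gradients reach both the RNN and the CNN.
loss = nn.functional.cross_entropy(logits[:, :-1].reshape(-1, 100),
                                   captions[:, 1:].reshape(-1))
loss.backward()
print(model.encoder[0].weight.grad is not None)  # True
```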

NEURALTALK2

Back to the NeuralTalk2 model: we can find it at github.com/deep-learning-with-pytorch/ImageCaptioning.pytorch. We can just place a set of images in the data directory and run the following script:

python eval.py --model ./data/FC/fc-model.pth --infos_path ./data/FC/fc-infos.pkl --image_folder ./data

Let’s try with our horse.jpg image. It says “A person riding a horse on a beach”. Quite appropriate.

Now, for fun, let’s see if our CycleGAN can also fool this NeuralTalk2 model. Let’s add the zebra.jpg image to the data folder and rerun the model: “A group of zebras are standing in a field.” Well, it got the animal right, but it saw more than one of them in the image. This certainly isn’t a pose that the network has ever seen in a zebra. Nor has it ever seen a rider on top of a zebra (with some spurious zebra patterns). In addition, it’s likely that zebras are depicted in groups in the training dataset, and there might be some bias that one could investigate. The captioning network also missed the rider, probably for the same reason: it has never seen a rider on a zebra in the training dataset.

In any case, this is an impressive feat: we generated a fake image with an impossible situation and the captioning network was flexible enough to get the subject right.

We’d like to stress that something like this was extremely hard to achieve before the advent of deep learning. Now it can be obtained with under a thousand lines of code, a general-purpose architecture that knows nothing about horses or zebras, and a corpus of images and their descriptions (the MS COCO dataset, in this case). There is no hard-coded criterion or grammar: everything, including the sentence, emerges from patterns in the data.

The network architecture in this last case was more complex than the ones we’ve seen earlier — it has a convolutional part and a recurrent part.

That’s where we will stop for now. And remember, this is just a taste of what PyTorch can do.

For more information about the book, check it out on liveBook for free here.

About the authors:
Eli Stevens has worked in Silicon Valley for the past 15 years as a software engineer, and the past 7 years as Chief Technical Officer of a startup making medical device software. Luca Antiga is co-founder and CEO of an AI engineering company located in Bergamo, Italy, and a regular contributor to PyTorch.

Originally published at freecontent.manning.com.
