Energy-Based models and structured prediction¶

Energy-Based Models (EBMs) discover data dependencies by applying a measure of compatibility (scalar energy) to each configuration of the variables. For a model to make a prediction or decision (inference) it needs to set the value of observed variables to 1 and finding values of the remaining variables that minimize that ā€œenergyā€ level.

In this notebook we're going to work with structured prediction. Structured prediction broadly refers to any problem involving predicting structured values, as opposed to plain scalars. Examples of structured outputs include graphs and text.

We're going to work with text. The task is to transcribe a word from an image. The difficulty here is that different words have different lengths, so we can't just have fixed number of outputs.

We will implement a dynamic programming algorithm to align the text image with the predictions. Optionally, we will also compare our solution to GTN framework.

Resources¶

GTN framework has support for finding Viterbi paths and training the prediction. The below links can be helpful in solving this homework:

  • CTC
  • Weighted Automata in ML

Additional Links:

  • GTN
  • GTN Documentation
  • GTN Applications
InĀ [1]:
# !pip install torchvision==0.16.0
InĀ [2]:
# !pip install gtn==0.0.0

Dataset¶

The first thing to do is implementing the dataset. We're going to create a dataset that creates images of random words. We'll also include some augmentations, such as jitter (moving the character horizontally).

InĀ [3]:
! mkdir fonts
! curl -L https://drive.google.com/uc\?id\=\{12c-EkGHJlYA9dE7nXEEJkXxwmXarMSAk\} -o fonts/Anonymous.ttf
mkdir: fonts: File exists
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 60060  100 60060    0     0  37762      0  0:00:01  0:00:01 --:--:--  124k
InĀ [4]:
from PIL import ImageDraw, ImageFont
import string
import random
import torch
from torch.optim import Adam
from tqdm.notebook import tqdm
import torchvision
from torchvision import transforms
from PIL import Image # PIL is a library to process images
from matplotlib import pyplot as plt
import gtn
torch.manual_seed(0)

simple_transforms = transforms.Compose([
                                    transforms.ToTensor(), 
                                ])

class SimpleWordsDataset(torch.utils.data.IterableDataset):

  def __init__(self, max_length, len=100, jitter=False, noise=False):
    self.max_length = max_length
    self.transforms = transforms.ToTensor()
    self.len = len
    self.jitter = jitter
    self.noise = noise
  
  def __len__(self):
    return self.len

  def __iter__(self):
    for _ in range(self.len):
        text = ''.join([random.choice(string.ascii_lowercase) for i in range(self.max_length)])
        img = self.draw_text(text, jitter=self.jitter, noise=self.noise)
        yield img, text
  
  def draw_text(self, text, length=None, jitter=False, noise=False):
    if length == None:
        length = 18 * len(text)
    img = Image.new('L', (length, 32))
    fnt = ImageFont.truetype("fonts/Anonymous.ttf", 20)

    d = ImageDraw.Draw(img)
    pos = (0, 5)
    if jitter:
        pos = (random.randint(0, 7), 5)
    else:
        pos = (0, 5)
    d.text(pos, text, fill=1, font=fnt)

    img = self.transforms(img)
    img[img > 0] = 1 
    
    if noise:
        img += torch.bernoulli(torch.ones_like(img) * 0.1)
        img = img.clamp(0, 1)
        

    return img[0]

sds = SimpleWordsDataset(1, jitter=True, noise=False)
img = next(iter(sds))[0]
plt.imshow(img)
Out[4]:
<matplotlib.image.AxesImage at 0x1593c4ac0>
No description has been provided for this image

We can look at what the entire alphabet looks like in this dataset.

InĀ [5]:
fig, ax = plt.subplots(3, 9, figsize=(12, 6), dpi=200)

for i, c in enumerate(string.ascii_lowercase):
    row = i // 9
    col = i % 9
    ax[row][col].imshow(sds.draw_text(c))
    ax[row][col].axis('off')
ax[2][8].axis('off')
    
plt.show()
No description has been provided for this image

We can also put the entire alphabet in one image.

InĀ [6]:
alphabet = sds.draw_text(string.ascii_lowercase, 340)
plt.figure(dpi=200)
plt.imshow(alphabet)
plt.axis('off')
Out[6]:
(-0.5, 339.5, 31.5, -0.5)
No description has been provided for this image

Model definition¶

Before we define the model, we define the size of our alphabet. Our alphabet consists of lowercase English letters, and additionally a special character used for space between symbols or before and after the word. For the first part of this assignment, we don't need that extra character.

Our end goal is to learn to transcribe words of arbitrary length. However, first, we pre-train our simple convolutional neural net to recognize single characters. In order to be able to use the same model for one character and for entire words, we are going to design the model in a way that makes sure that the output size for one character (or when input image size is 32x18) is 1x27, and Kx27 whenever the input image is wider. K here will depend on particular architecture of the network, and is affected by strides, poolings, among other things. A little bit more formally, our model $f_\theta$, for an input image $x$ gives output energies $l = f_\theta(x)$. If $x \in \mathbb{R}^{32 \times 18}$, then $l \in \mathbb{R}^{1 \times 27}$. If $x \in \mathbb{R}^{32 \times 100}$ for example, our model may output $l \in \mathbb{R}^{10 \times 27}$, where $l_i$ corresponds to a particular window in $x$, for example from $x_{0, 9i}$ to $x_{32, 9i + 18}$ (again, this will depend on the particular architecture).

Below is a drawing that explains the sliding window concept. We use the same neural net with the same weights to get $l_1, l_2, l_3$, the only difference is receptive field. $l_1$ is looks at the leftmost part, at character 'c', $l_2$ looks at 'a', and $l_3$ looks at 't'. The receptive field may or may not overlap, depending on how you design your convolutions.

cat.png

InĀ [7]:
# constants for number of classes in total, and for the special extra character for empty space
ALPHABET_SIZE = 27
BETWEEN = 26
InĀ [8]:
from torch import nn

class SimpleNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # TODO
        self.conv=nn.Conv2d(1, 512, kernel_size=(32, 18), stride=(1, 4), padding="valid")
        self.linear=nn.Linear(512, 27)
        
    def forward(self, x):
        # TODO
        return self.linear(self.conv(x).squeeze(axis=-2).permute(0, 2, 1))
        

Let's initalize the model and apply it to the alphabet image:

InĀ [9]:
model = SimpleNet()
alphabet_energies = model(alphabet.view(1, 1, *alphabet.shape))

def plot_energies(ce):
    fig=plt.figure(dpi=200)
    ax = plt.axes()
    im = ax.imshow(ce.cpu().T)
    
    ax.set_xlabel('window locations →')
    ax.set_ylabel('← classes')
    ax.xaxis.set_label_position('top') 
    ax.set_xticks([])
    ax.set_yticks([])
    
    cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.02,ax.get_position().height])
    plt.colorbar(im, cax=cax) 
    
plot_energies(alphabet_energies[0].detach())
No description has been provided for this image

So far we only see random outputs, because the classifier is untrained.

Train with one character¶

Now we train the model we've created on a dataset where images contain only single characters. Note the changed cross_entropy function.

InĀ [10]:
LR=1e-2
EPOCHS=15
InĀ [11]:
def train_model(model, epochs, dataloader, criterion, optimizer):
    # TODO
    model.train()
    pbar=tqdm(range(epochs))
    for epoch in pbar:
        train_loss=0.0
        for images, target in dataloader:
            # images=images.unsqueeze(1).cuda()
            images=images.unsqueeze(1)
            # target=target.cuda()
            
            optimizer.zero_grad()
            
            out=model(images)
            loss=criterion(out.squeeze(), target=target)
            loss.backward()
            optimizer.step()
            train_loss+=loss.item()
        train_loss/=len(dataloader)
        pbar.set_postfix({'Train Loss': train_loss})
InĀ [12]:
from tqdm.notebook import tqdm
import torch.optim as optim

def cross_entropy(energies, *args, **kwargs):
    """ We use energies, and therefore we need to use log soft arg min instead
        of log soft arg max. To do that we just multiply energies by -1. """
    return nn.functional.cross_entropy(-1 * energies, *args, **kwargs)

def simple_collate_fn(samples):
    images, annotations = zip(*samples)
    images = list(images)
    annotations = list(annotations)
    annotations = list(map(lambda c : torch.tensor(ord(c) - ord('a')), annotations))
    m_width = max(18, max([i.shape[1] for i in images]))
    for i in range(len(images)):
        images[i] = torch.nn.functional.pad(images[i], (0, m_width - images[i].shape[-1]))
        
    if len(images) == 1:
        return images[0].unsqueeze(0), torch.stack(annotations)
    else:
        return torch.stack(images), torch.stack(annotations)

sds = SimpleWordsDataset(1, len=1000, jitter=True, noise=False)
dataloader = torch.utils.data.DataLoader(sds, batch_size=16, num_workers=0, collate_fn=simple_collate_fn)

# model.cuda()
# TODO: initialize optimizer
optimizer=Adam(model.parameters(), lr=LR)
# TODO: train the model on the one-character dataset
train_model(model, EPOCHS, dataloader, cross_entropy, optimizer)
  0%|          | 0/15 [00:00<?, ?it/s]
InĀ [13]:
def get_accuracy(model, dataset):
    cnt = 0
    for i, l in dataset:
        # energies = model(i.unsqueeze(0).unsqueeze(0).cuda())[0, 0]
        energies = model(i.unsqueeze(0).unsqueeze(0))[0, 0]
        x = energies.argmin(dim=-1)
        cnt += int(x == (ord(l[0]) - ord('a')))
    return cnt / len(dataset)
        
tds = SimpleWordsDataset(1, len=100)
assert get_accuracy(model, tds) == 1.0, 'Your model doesn\'t achieve 100% accuracy for 1 character'

Now, to see how our model would work with more than one character, we apply the model to a bigger input - the image of the alphabet we saw earlier. We extract the energies for each window and show them.

InĀ [14]:
# alphabet_energies_post_train = model(alphabet.cuda().view(1, 1, *alphabet.shape))
alphabet_energies_post_train = model(alphabet.view(1, 1, *alphabet.shape))
plot_energies(alphabet_energies_post_train[0].detach())
No description has been provided for this image

Training with multiple characters¶

Now, we want to train our model to not only recognize the letters, but also to recognize space in-between so that we can use it for transcription later.

This is where complications begin. When transcribing a word from an image, we don't know beforehand how long the word is going to be. We can use our convolutional neural network we've pretrained on single characters to get prediction of character probabilities for all the positions of an input window in the new input image, but we don't know beforehand how to match those predictions with the target label. Training with incorrect matching can lead to wrong model, so in order to be able to train a network to transcribe words, we need a way to find these pairings.

dl.png

The importance of pairings can be demonstrated by the drawing above. If we map $l_1, l_2, l_3, l_4$ to 'c', 'a', 't', '_' respectively, we'll correctly train the system, but if we put $l_1, l_2, l_3, l_4$ with 'a', 'a', 't', 't', we'd have a very wrong classifier.

To formalize this, we use energy-based models' framework. Let's define the energy $E(x, y, z)$ as the sum of cross-entropies for a particular pairing between probabilities our model gives for input image $x$ and text transcription $y$, and pairing $z$. $z$ is a function $z : \{1, 2, \dots, \vert l \vert \} \to \{1, 2, \dots, \vert y \vert)$, $l$ here is the energies output of our convolutional neural net $l = f_\theta(x)$. $z$ maps each energy vector in $l$ to an element in the output sequence $y$. We want the mappings to make sense, so $z$ should be a non-decreasing function $z(i) \leq z(i+1)$, and it shouldn't skip characters, i.e. $\forall_i \exists_j z(j)=i$.

Energy is then $E(x, y, z) = C(z) + \sum_{i=1}^{\vert l \vert} l_i[z(i)]$ , $C(z)$ is some extra term that allows us to penalize certain pairings, and $l_i[z(i)]$ is the energy of $z(i)$-th symbol on position $i$.

In this particular context, we define $C(z)$ to be infinity for impossible pairings: $$C(z) = \begin{cases} \infty \; \text{if} \; z(1) \neq 1 \vee z(\vert l \vert) \neq \vert y \vert \vee \exists_{i, 1\leq 1 \leq \vert l \vert - 1} z(i) > z(i+1) \vee z(i) < z(i+1) - 1\\ 0 \; \text{otherwise} \end{cases} $$

Then, the free energy $F(x, y) = \arg \min_z E(x, y, z)$. In other words, the free energy is the energy of the best pairing between the probabilities provided by our model, and the target labels.

When training, we are going to use cross-entropies along the best path: $\ell(x, y, z) = \sum_{i=1}^{\vert l \vert}H(y_{z(i)}, \sigma(l_i))$, where $H$ is cross-entropy, $\sigma$ is soft-argmin needed to convert energies to a distribution.

First, let's write functions that would calculate the needed cross entropies $H(y_{z(i)}, \sigma(l_i))$, and energies for us.

InĀ [15]:
def build_path_matrix(energies, targets):
    # inputs: 
    #    energies, shape is BATCH_SIZE x L x 27
    #    targets, shape is BATCH_SIZE x T
    # L is \vert l \vert
    # T is \vert y \vert
    # 
    # outputs:
    #    a matrix of shape BATCH_SIZE x L x T
    #    where output[i, j, k] = energies[i, j, targets[i, k]]
    #
    # Note: you're not allowed to use for loops. The calculation has to be vectorized.
    # you may want to use repeat and repeat_interleave.
    # TODO
    batch_size=energies.shape[0]
    L=energies.shape[1]
    T=targets.shape[-1]
    targets=targets.unsqueeze(1).repeat(1, L, 1)
    # output=torch.gather(energies, 2, targets.cuda())
    output=torch.gather(energies, 2, targets)
    return output

def build_ce_matrix(energies, targets):
    # inputs: 
    #    energies, shape is BATCH_SIZE x L x 27
    #    targets, shape is BATCH_SIZE x T
    # L is \ververtt = targets.unsqueeze(1).repeat(1,energies.shape[1],1)t l \vert
    # T is \vert y \vert
    # 
    # outputs:
    #    a matrix ce of shape BATCH_SIZE x L x T
    #    where ce[i, j, k] = cross_entropy(energies[i, j], targets[i, k])
    #
    # Note: you're not allowed to use for loops. The calculation has to be vectorized.
    # you may want to use repeat and repeat_interleave.
    # TODO
    batch_size=energies.shape[0]
    L=energies.shape[1]
    T=targets.shape[-1]
    energies= energies.permute(0, 2, 1).unsqueeze(-1).repeat(1,1,1, T)
    targets = targets.unsqueeze(1).repeat(1, L, 1)
    return cross_entropy(energies, targets, reduction='none')

Another thing we will need is a transformation for our label $y$. We don't want to use it as is, we want to insert some special label after each character, so, for example 'cat' becomes 'c_a_t_'. This extra '_' models the separation between words, allowing our model to distinguish between strings 'aa' and 'a' in its output. This is then used in inference - we can just get the most likely character for each position from $l = f_\theta(x)$ (for example 'aa_bb_ccc_'), and then remove duplicate characters ('a_b_c_'), and then remove _ (abc). Let's implement a function that would change the string in this manner, and then map all characters to values from 0 to 26, with 0 to 25 corresponding to a-z, and 26 corresponding to _:

InĀ [16]:
def transform_word(s):
    # input: a string
    # output: a tensor of shape 2*len(s)
    # TODO
    encoded_str=[]
    for c in s:
        encoded_str.append(ord(c)-ord('a'))
        encoded_str.append(26)
    return torch.tensor(encoded_str)

Now, let's plot energy table built on our model's prediction for alphabet image.

InĀ [17]:
def plot_pm(pm, path=None):
    fig=plt.figure(dpi=200)
    ax = plt.axes()
    im = ax.imshow(pm.cpu().T)
    
    ax.set_xlabel('window locations →')
    ax.set_ylabel('← label characters')
    ax.xaxis.set_label_position('top') 
    ax.set_xticks([])
    ax.set_yticks([])
    
    if path is not None:
        for i in range(len(path) - 1):
            ax.plot(*path[i], *path[i+1], marker = 'o', markersize=0.5, linewidth=10, color='r', alpha=1)

    cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.02,ax.get_position().height])
    plt.colorbar(im, cax=cax) 

# energies = model(alphabet.cuda().view(1, 1, *alphabet.shape))
energies = model(alphabet.view(1, 1, *alphabet.shape))
targets = transform_word(string.ascii_lowercase).unsqueeze(0)


pm = build_path_matrix(energies, targets)
plot_pm(pm[0].detach())
No description has been provided for this image

Now let's implement a function that would tell us the energy of a particular path (i.e. pairing).

InĀ [18]:
def checkValidMapping(path, T):
    for i in range(1, len(path)):
        if path[i]<path[i-1]:
            return False
    return True

def path_energy(pm, path):
    # inputs:
    #   pm - a matrix of energies 
    #    L - energies length
    #    T - targets length
    #   path - list of length L that maps each energy vector to an element in T
    # returns:
    #   energy - sum of energies on the path, or 2**30 if the mapping is invalid
    # TODO
    T=pm.shape[1]
    if checkValidMapping(path, T):
        energy=0.0
        for i, c in enumerate(path):
            energy+=pm[i,c]
        return energy
    else:
        return torch.tensor(2**30)

Now we can check some randomly generated paths and see the associated energies for our alphabet image:

InĀ [19]:
path = torch.zeros(energies.shape[1] - 1)
path[:targets.shape[1] - 1] = 1
path = [0] + list(map(lambda x : x.int().item(), path[torch.randperm(path.shape[0])].cumsum(dim=-1)))
points = list(zip(range(energies.shape[1]), path))

plot_pm(pm[0].detach(), points)
print('energy is', path_energy(pm[0], path).item())
energy is -1444.7196044921875
No description has been provided for this image

Now, generate two paths with bad energies, print their energies and plot them.

InĀ [20]:
# TODO
def getBadPath(pm, topk=1):
    path=[]
    for i in range(pm.shape[1]):
        path.append(torch.topk(pm.squeeze()[i], topk)[1][topk-1].item())
    return path
InĀ [21]:
bad_path1 = getBadPath(pm)
bad_points1 = list(zip(range(energies.shape[1]), bad_path1))
plot_pm(pm[0].detach(), bad_points1)
print('energy is', path_energy(pm[0], bad_path1).item())
energy is 1073741824
No description has been provided for this image
InĀ [22]:
bad_path2 = getBadPath(pm,2)
bad_points2 = list(zip(range(energies.shape[1]), bad_path2))
plot_pm(pm[0].detach(), bad_points2)
print('energy is', path_energy(pm[0], bad_path2).item())
energy is 1073741824
No description has been provided for this image

Part - 1: Viterbi¶

Optimal path finding¶

Now, we're going to implement the finding of the optimal path. To do that, we're going to use Viterbi algorithm, which in this case is a simple dynamic programming problem. In this context, it's a simple dynamic programming algorithm that for each pair i, j, calculates the minimum cost of the path that goes from 0-th index in the energies and 0-th index in the target, to i-th index in the energies, and j-th index in the target. We can memorize the values in a 2-dimensional array, let's call it dp. Then we have the following transitions:

dp[0, 0] = pm[0, 0]
dp[i, j] = min(dp[i - 1, j], dp[i - 1, j - 1]) + pm[i, j]

The optimal path can be recovered if we memorize which cell we came from for each dp[i, j].

Below, you'll need to implement this algorithm:

InĀ [23]:
def find_path(pm):
    # inputs:
    #   pm - a tensor of shape LxT with energies
    #     L is length of energies array
    #     T is target sequence length
    # NOTE: this is slow because it's not vectorized to work with batches.
    #  output:
    #     a tuple of three elements:
    #         1. sum of energies on the best path,
    #         2. list of tuples - points of the best path in the pm matrix 
    #         3. the dp array

    # TODO
    L=pm.shape[0]
    T=pm.shape[1]
    dp=torch.tensor([[0.0]*T for _ in range(L)], device=pm.device)
    direction_array=[[None]*T for _ in range(L)]
    dp[0][0]=pm[0][0]
    direction_array[0][0]=(0,0)
    
    for j in range(1, T):
        dp[0][j]=2**30
        direction_array[0][j]=(0,j)
        
    for i in range(1, L):
        dp[i][0] = dp[i-1][0] + pm[i][0]
        direction_array[i][0]=(i-1,0)
        
    for i in range(1, L):
        for j in range(1, T):
            dp[i][j] = min(dp[i-1][j], dp[i-1][j-1]) + pm[i][j]
            if dp[i-1][j] < dp[i-1][j-1]:
                direction_array[i][j]=(i-1, j)
            else:
                direction_array[i][j]=(i-1, j-1)
    

    path=[]
    j=dp[L-1].min(-1)[1].item()
    for i in range(L-1, -1, -1):
        path.append(j)
        direction = direction_array[i][j]
        j=direction[1]
    
    path.reverse()
        
    points = list(zip(range(pm.shape[0]), path))
    return tuple([path_energy(pm, path), points, dp])

Let's take a look at the best path:

InĀ [24]:
free_energy, path, d = find_path(pm[0])
plot_pm(pm[0].cpu().detach(), path)
print('free energy is', free_energy.item())
free energy is -10255.6416015625
No description has been provided for this image

We can also visualize the dp array. You may need to tune clamping to see what it looks like.

InĀ [25]:
plt.figure(dpi=200)
# print(d)
plt.imshow(d.cpu().detach().T.clamp(torch.min(d).item(), 200))
plt.axis('off')
Out[25]:
(-0.5, 80.5, 51.5, -0.5)
No description has been provided for this image

Training loop¶

Now is time to train the network using our best path finder. We're going to use the energy loss function: $$\ell(x, y) = \sum_i H(y_{z(i)}, l_i)$$ Where $z$ is the best path we've found. This is akin to pushing down on the free energy $F(x, y)$, while pushing up everywhere else by nature of cross-entropy.

InĀ [26]:
def train_ebm_model(model, num_epochs, train_loader, criterion, optimizer):
    ''' Train EBM Model using find_path()'''
    pbar = tqdm(range(num_epochs))
    size = 0
    free_energies = []
    paths = []
    model.train()
    for epoch in pbar:
        total_train_loss = 0.0
        start_time = time.time()
        # TODO: implement the training loop
        for samples, targets in train_loader:
            optimizer.zero_grad()
            
            # samples=samples.cuda()
            # targets=targets.cuda()
            energies=model(samples.unsqueeze(1))
            pm=build_path_matrix(energies, targets)
            batch_loss=[]
            for b_index in range(pm.shape[0]):
                free_energy, best_path, _ = find_path(pm[b_index])
                target_indices=[ind[1] for ind in best_path]
                batch_loss.append(criterion(energies[b_index], targets[b_index, target_indices]))
            
            loss=sum(batch_loss)
            total_train_loss+=loss.item()
            loss.backward()
            optimizer.step()

        epoch_time = time.time() - start_time
        pbar.set_postfix({'train_loss': total_train_loss / len(sds), 'Epoch Time': epoch_time})

    return
InĀ [27]:
LR=1e-3
EPOCHS=15
InĀ [28]:
import copy
import time

def collate_fn(samples):
    """ A function to collate samples into batches for multi-character case"""
    images, annotations = zip(*samples)
    images = list(images)
    annotations = list(annotations)
    annotations = list(map(transform_word, annotations))
    m_width = max(18, max([i.shape[1] for i in images]))
    m_length = max(3, max([s.shape[0] for s in annotations]))
    for i in range(len(images)):
        images[i] = torch.nn.functional.pad(images[i], (0, m_width - images[i].shape[-1]))
        annotations[i] = torch.nn.functional.pad(annotations[i], (0, m_length - annotations[i].shape[0]), value=BETWEEN)
    if len(images) == 1:
        return images[0].unsqueeze(0), torch.stack(annotations)
    else:
        return torch.stack(images), torch.stack(annotations)
    
sds = SimpleWordsDataset(2, 2500) # for simplicity, we're training only on words of length two

BATCH_SIZE = 32
dataloader = torch.utils.data.DataLoader(sds, batch_size=BATCH_SIZE, num_workers=0, collate_fn=collate_fn)

# TODO: Make a copy of your model and re-initialize optimizer

# TODO: train the model using the train_ebm_model()
# note: remember that our best path finding algorithm is not batched, so you'll
# need a for loop to do loss calculation. 
# This is not ideal, as for loops are very slow, but for 
# demonstration purposes it will suffice. In practice, this will be
# unusable for any real problem unless it handles batching.
# also: remember that the loss is the sum of cross_entropies along the path, not 
# energies!
ebm_model=copy.deepcopy(model)
# ebm_model.cuda()
ebm_model.train()
optimizer=Adam(ebm_model.parameters(), lr=LR)
train_ebm_model(ebm_model, EPOCHS, dataloader, cross_entropy, optimizer)
  0%|          | 0/15 [00:00<?, ?it/s]

Let's check what the energy matrix looks like for the alphabet image now.

InĀ [29]:
# energies = ebm_model(alphabet.unsqueeze(0).unsqueeze(0).cuda())
energies = ebm_model(alphabet.unsqueeze(0).unsqueeze(0))
targets = transform_word(string.ascii_lowercase)
pm = build_path_matrix(energies, targets.unsqueeze(0))

free_energy, path, _ = find_path(pm[0])
plot_pm(pm[0].detach(), path)
print('free energy is', free_energy.item())
free energy is -13971.6572265625
No description has been provided for this image

We can also look at raw energies output:

InĀ [30]:
# alphabet_energy_post_train_viterbi = ebm_model(alphabet.cuda().view(1, 1, *alphabet.shape))
alphabet_energy_post_train_viterbi = ebm_model(alphabet.view(1, 1, *alphabet.shape))


plt.figure(dpi=200, figsize=(40, 10))
plt.imshow(alphabet_energy_post_train_viterbi.cpu().data[0].T)
plt.axis('off')
Out[30]:
(-0.5, 80.5, 26.5, -0.5)
No description has been provided for this image

Decoding¶

Now we can use the model for decoding a word from an image. Let's pick some word, apply the model to it, and see energies.

InĀ [31]:
img = sds.draw_text('hello')
# energies = ebm_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = ebm_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
plot_energies(energies[0].detach().cpu())
No description has been provided for this image
No description has been provided for this image

You should see some characters light up. Now, let's implement a simple decoding algorithm. To decode, first we want to get most likely classes for all energies, and then do two things:

  1. segment strings using the divisors (our special character with index 26), and for each segment replace it with the most common character in that segment. Example: aaab_bab_ -> a_b. If some characters are equally common, you can pick random.
  2. remove all special divisor characters: a_b -> ab
InĀ [32]:
from collections import Counter
InĀ [33]:
def indices_to_str(indices):
    # inputs: indices - a tensor of most likely class indices
    # outputs: decoded string
    
    # TODO
    out_str=[]
    for ind in indices:
        if ind==26:
            out_str.append("_")
        else:
            out_str.append(chr(ind+ord('a')))
    seperated_words="".join(out_str).split("_")
    filtered_chars=[Counter(s).most_common(1)[0][0] for s in seperated_words if s]
    return "_".join(filtered_chars)
    
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
h_e_l_o_i

Part - 2: GTN¶

CTC Loss Function¶

The below cell has the CTC Loss implementation which will be needed for the training. It also has the viterbi implemented, so feel free to use it.

  • For calculating loss, CTCLoss() can be used.
  • For viterbi, CTCLossFunction.viterbi() can be used.

Brief Explanation:¶

GTN constructs 2 graphs-

  1. Alignment graph $A_y$ that can produce a sequence of any length which when collapsed gives the output.
  2. Emission graph $\mathcal{E}$ that can produce any sequence of length L, input's length. This can be considered as some sort of restriction over $A_y$.

Loss function is then calculated as $-\log P(y \mid X) = -(\log (A_y \circ \mathcal{E}) - \log \mathcal{E})$ where $\circ$ is the composition function.

InĀ [34]:
# Credits: Adapted from https://github.com/facebookresearch/gtn_applications
import torch.nn.functional as F
import torch.nn as nn

class CTCLossFunction(torch.autograd.Function):
    @staticmethod
    def create_ctc_graph(target, blank_idx):
        g_criterion = gtn.Graph(False)
        L = len(target)
        S = 2 * L + 1
        for l in range(S):
            idx = (l - 1) // 2
            g_criterion.add_node(l == 0, l == S - 1 or l == S - 2)
            label = target[idx] if l % 2 else blank_idx
            g_criterion.add_arc(l, l, label)
            if l > 0:
                g_criterion.add_arc(l - 1, l, label)
            if l % 2 and l > 1 and label != target[idx - 1]:
                g_criterion.add_arc(l - 2, l, label)
        g_criterion.arc_sort(False)
        return g_criterion

    @staticmethod
    def forward(ctx, log_probs, targets, blank_idx=0, reduction="none"):
        B, T, C = log_probs.shape
        losses = [None] * B
        scales = [None] * B
        emissions_graphs = [None] * B

        def process(b):
            # create emission graph
            g_emissions = gtn.linear_graph(T, C, log_probs.requires_grad)
            cpu_data = log_probs[b].cpu().contiguous()
            g_emissions.set_weights(cpu_data.data_ptr())

            # create criterion graph
            g_criterion = CTCLossFunction.create_ctc_graph(targets[b], blank_idx)
            # compose the graphs
            g_loss = gtn.negate(
                gtn.forward_score(gtn.intersect(g_emissions, g_criterion))
            )

            scale = 1.0
            if reduction == "mean":
                L = len(targets[b])
                scale = 1.0 / L if L > 0 else scale
            elif reduction != "none":
                raise ValueError("invalid value for reduction '" + str(reduction) + "'")

            # Save for backward:
            losses[b] = g_loss
            scales[b] = scale
            emissions_graphs[b] = g_emissions

        gtn.parallel_for(process, range(B))

        ctx.auxiliary_data = (losses, scales, emissions_graphs, log_probs.shape)
        loss = torch.tensor([losses[b].item() * scales[b] for b in range(B)])
        return torch.mean(loss.cuda() if log_probs.is_cuda else loss)

    @staticmethod
    def backward(ctx, grad_output):
        losses, scales, emissions_graphs, in_shape = ctx.auxiliary_data
        B, T, C = in_shape
        input_grad = torch.empty((B, T, C))

        def process(b):
            gtn.backward(losses[b], False)
            emissions = emissions_graphs[b]
            grad = emissions.grad().weights_to_numpy()
            input_grad[b] = torch.from_numpy(grad).view(1, T, C) * scales[b]

        gtn.parallel_for(process, range(B))

        if grad_output.is_cuda:
            input_grad = input_grad.cuda()
        input_grad *= grad_output / B

        return (
            input_grad,
            None,  # targets
            None,  # blank_idx
            None,  # reduction
        )
    
def viterbi(energies, targets, blank_idx=0):
    outputs = -1 * energies
    B, T, C = outputs.shape
    paths = [None] * B
    scores = [None] * B
    emissions_graphs = [None] * B
    def process(b):
        L = len(targets[b])
        # create emission graph
        g_emissions = gtn.linear_graph(T, C, outputs.requires_grad)
        cpu_data = outputs[b].cpu().contiguous()
        g_emissions.set_weights(cpu_data.data_ptr())

        # create criterion graph
        g_criterion = CTCLossFunction.create_ctc_graph(targets[b], blank_idx)
        g_score = gtn.viterbi_score(gtn.intersect(g_emissions, g_criterion))
        g_path = gtn.viterbi_path(gtn.intersect(g_emissions, g_criterion))
        paths[b] = [2*p if 2*p < L else i for i, p in enumerate(g_path.labels_to_list())]
        l = 0
        paths[b] = []
        for p in g_path.labels_to_list():
            if 2*p < L:
                l = p
                paths[b].append(2*p)
            else:
                paths[b].append(2*l + 1)
        scores[b] = -1 * g_score.item()
        emissions_graphs[b] = g_emissions

    gtn.parallel_for(process, range(B))

    return (scores, paths)

CTCLoss = CTCLossFunction.apply
InĀ [35]:
def train_gtn_model(model, num_epochs, train_loader, criterion, optimizer):
    ''' Train CTC Model using GTN'''
    pbar = tqdm(range(num_epochs))
    train_losses = []
    size = 0
    free_energies = []
    paths = []
    max_grad_norm = None
    if torch.cuda.is_available():
        model = model.cuda()
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    for epoch in pbar:
        total_train_loss = 0.0
        start_time = time.time()
        # TODO: implement the training loop
        for samples, targets in train_loader:
            optimizer.zero_grad()
            samples=samples.to(device)
            targets=targets.to(device)
            
            outputs = model(samples.unsqueeze(1))

            log_probs = F.log_softmax(-1.0*outputs, dim=-1)
            loss=criterion(log_probs, targets)
            total_train_loss+=loss.item()
            loss.backward()
            optimizer.step()
        epoch_time = time.time() - start_time
        train_losses.append(total_train_loss)
        pbar.set_postfix({'train_loss': total_train_loss/len(sds), 'Epoch Time': epoch_time})

    return
InĀ [36]:
LR=1e-3
EPOCHS=15
InĀ [37]:
# Similar to what we have done earlier but instead of find_path(), we will use
# GTN's framework.
sds = SimpleWordsDataset(3, 2500) # for simplicity, we're training only on words of length two

BATCH_SIZE = 32
dataloader = torch.utils.data.DataLoader(sds, batch_size=BATCH_SIZE, num_workers=0, collate_fn=collate_fn)

# TODO: Make another copy of the single character model and re-initialize optimizer
gtn_model=copy.deepcopy(model)
optimizer=Adam(gtn_model.parameters(), lr=LR)
# TODO: train the model
# note: remember that our best path finding algorithm is not batched, so you'll
# need a for loop to do loss calculation. 
# This is not ideal, as for loops are very slow, but for 
# demonstration purposes it will suffice. In practice, this will be
# unusable for any real problem unless it handles batching.

# also: remember that the loss is the sum of cross_entropies along the path, not 
# energies!
# gtn_model.cuda()
gtn_model.train()
train_gtn_model(gtn_model, EPOCHS, dataloader, CTCLoss, optimizer)
  0%|          | 0/15 [00:00<?, ?it/s]
InĀ [38]:
# energies = gtn_model(alphabet.unsqueeze(0).unsqueeze(0).cuda())
energies = gtn_model(alphabet.unsqueeze(0).unsqueeze(0))
targets = transform_word(string.ascii_lowercase)
pm = build_path_matrix(energies, targets.unsqueeze(0))

# TODO: Use the provided viterbi function to get score and path
# print(targets.shape, energies.shape)
score, path = viterbi(energies, targets.unsqueeze(0))
# print(path)
# path is obtained from the above TODO
points = list(zip(range(energies.shape[1]), path[0]))
plot_pm(pm[0].cpu().detach(), points)
print('energy is', score[0])
energy is -13924.0771484375
No description has been provided for this image
InĀ [39]:
img = sds.draw_text('hello')
# energies = gtn_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = gtn_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
plot_energies(energies[0].detach().cpu())
No description has been provided for this image
No description has been provided for this image
InĀ [40]:
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
h_e_l_o_a

Part - 3: Train Model with no pretraining¶

In part - 1 and part - 2, we have trained a model on single character first and then trained on multiple character sequence. Here, we will train a model using GTN directly on multiple character sequence.

Additionally, we will experiment this on custom handwritten data.

InĀ [41]:
# TODO
# Training the model on Anonymous data without pretraining
LR=1e-3
EPOCHS=20
sds = SimpleWordsDataset(3, 2500)
BATCH_SIZE = 32
dataloader = torch.utils.data.DataLoader(sds, batch_size=BATCH_SIZE, num_workers=0, collate_fn=collate_fn)

gtn_no_pretrained_model=SimpleNet()
optimizer=Adam(gtn_no_pretrained_model.parameters(), lr=LR)
# gtn_no_pretrained_model.cuda()
gtn_no_pretrained_model.train()
train_gtn_model(gtn_no_pretrained_model, EPOCHS, dataloader, CTCLoss, optimizer)
  0%|          | 0/20 [00:00<?, ?it/s]
InĀ [42]:
# energies = gtn_no_pretrained_model(alphabet.unsqueeze(0).unsqueeze(0).cuda())
energies = gtn_no_pretrained_model(alphabet.unsqueeze(0).unsqueeze(0))
targets = transform_word(string.ascii_lowercase)
pm = build_path_matrix(energies, targets.unsqueeze(0))

score, path = viterbi(energies, targets.unsqueeze(0))

points = list(zip(range(energies.shape[1]), path[0]))
plot_pm(pm[0].cpu().detach(), points)
print('energy is', score[0])
energy is -1631.3629150390625
No description has been provided for this image
InĀ [43]:
img = sds.draw_text('hello')
# energies = gtn_no_pretrained_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = gtn_no_pretrained_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
plot_energies(energies[0].detach().cpu())
No description has been provided for this image
No description has been provided for this image
InĀ [44]:
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
h_e_l_d_r

(Experimental) Collect or get your own handwritten dataset and test the model on that.

InĀ [45]:
#TODO
# Downloading custom data
! curl --output fonts/customFont.zip https://www.fontsquirrel.com/fonts/download/3dumb
! unzip -n fonts/customFont.zip -d fonts/3dumb
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  230k  100  230k    0     0   387k      0 --:--:-- --:--:-- --:--:--  387k
Archive:  fonts/customFont.zip
InĀ [46]:
class CustomWordsDataset(torch.utils.data.IterableDataset):

  def __init__(self, max_length, len=100, jitter=False, noise=False, custom_fonts_path=None):
    self.max_length = max_length
    self.transforms = transforms.ToTensor()
    self.len = len
    self.jitter = jitter
    self.noise = noise
    self.custom_fonts_path=custom_fonts_path
  
  def __len__(self):
    return self.len

  def __iter__(self):
    for _ in range(self.len):
        text = ''.join([random.choice(string.ascii_lowercase) for i in range(self.max_length)])
        img = self.draw_text(text, jitter=self.jitter, noise=self.noise)
        yield img, text
  
  def draw_text(self, text, length=None, jitter=False, noise=False):
    if length == None:
        length = 18 * len(text)
    img = Image.new('L', (length, 32))
    fnt = ImageFont.truetype("fonts/Anonymous.ttf" if not self.custom_fonts_path else self.custom_fonts_path, 20)

    d = ImageDraw.Draw(img)
    pos = (0, 5)
    if jitter:
        pos = (random.randint(0, 7), 5)
    else:
        pos = (0, 5)
    d.text(pos, text, fill=1, font=fnt)

    img = self.transforms(img)
    img[img > 0] = 1 
    
    if noise:
        img += torch.bernoulli(torch.ones_like(img) * 0.1)
        img = img.clamp(0, 1)
        

    return img[0]
InĀ [47]:
# Training model on Custom Handwritten Data from scratch
LR=1e-3
EPOCHS=20
sds = CustomWordsDataset(3, 2500, custom_fonts_path="./fonts/3dumb/2Dumb.ttf")
BATCH_SIZE = 32
dataloader = torch.utils.data.DataLoader(sds, batch_size=BATCH_SIZE, num_workers=0, collate_fn=collate_fn)

gtn_no_pretrained_model=SimpleNet()
optimizer=Adam(gtn_no_pretrained_model.parameters(), lr=LR)
# gtn_no_pretrained_model.cuda()
gtn_no_pretrained_model.train()
train_gtn_model(gtn_no_pretrained_model, EPOCHS, dataloader, CTCLoss, optimizer)
  0%|          | 0/20 [00:00<?, ?it/s]
InĀ [48]:
# energies = gtn_no_pretrained_model(alphabet.unsqueeze(0).unsqueeze(0).cuda())
energies = gtn_no_pretrained_model(alphabet.unsqueeze(0).unsqueeze(0))
targets = transform_word(string.ascii_lowercase)
pm = build_path_matrix(energies, targets.unsqueeze(0))

score, path = viterbi(energies, targets.unsqueeze(0))

points = list(zip(range(energies.shape[1]), path[0]))
plot_pm(pm[0].cpu().detach(), points)
print('energy is', score[0])
energy is -648.4822998046875
No description has been provided for this image
InĀ [49]:
img = sds.draw_text('hello')
# energies = gtn_no_pretrained_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = gtn_no_pretrained_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
plot_energies(energies[0].detach().cpu())
No description has been provided for this image
No description has been provided for this image
InĀ [50]:
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
h_e_l_o

Visualize some images and their predictions given by the model.

InĀ [51]:
#TODO
# Some visualizations and predictions of the model on custome dataset
img = sds.draw_text('energy based models')
# energies = gtn_no_pretrained_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = gtn_no_pretrained_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
e_n_e_r_g_y_b_a_s_e_d_m_o_d_e_l_s
No description has been provided for this image
InĀ [52]:
img = sds.draw_text('diffusion modeling')
# energies = gtn_no_pretrained_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = gtn_no_pretrained_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
d_i_f_f_u_s_i_o_n_m_m_o_d_e_l_n_g
No description has been provided for this image
InĀ [53]:
img = sds.draw_text('transformers')
# energies = gtn_no_pretrained_model(img.cuda().unsqueeze(0).unsqueeze(0))
energies = gtn_no_pretrained_model(img.unsqueeze(0).unsqueeze(0))
plt.imshow(img)
min_indices = energies[0].argmin(dim=-1)
print(indices_to_str(min_indices))
t_r_a_n_s_f_o_r_m_e_r_s
No description has been provided for this image
InĀ [54]:
# End