nn.Transformer for Language Modelling

The structure is as follows:

  1. Embedding layer
  2. Positional encoding layer
  3. nn.TransformerEncoder, consisting of multiple nn.TransformerEncoderLayer modules (with an attention mask so tokens cannot attend to future positions; see the short mask sketch after this list)
  4. Final linear layer, followed by a softmax, to produce a probability distribution over the vocabulary
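
As a quick illustration of the mask in step 3, the snippet below builds a square "subsequent" mask for a toy sequence of length 5, using the same logic as _generate_square_subsequent_mask in section 3. Positions filled with -inf are blocked, so each token can only attend to itself and earlier tokens.

import torch

sz = 5  # toy sequence length, chosen just for illustration
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])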

1. Import dependencies

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import TransformerEncoder, TransformerEncoderLayer

2. PositionalEncoding Class

In [6]:
class PositionalEncoding(nn.Module):
    
    def __init__(self, d_model, dropout = 0.1, max_len = 5000):
        super(PositionalEncoding, self).__init__()
        
        self.dropout = nn.Dropout(p = dropout)
        
        pe = torch.zeros(max_len, d_model)
        # torch.arange returns a 1-D tensor with values from 0 up to (but not including) max_len
        position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1) # unsqueeze(1) adds a new dimension at axis 1, turning shape (max_len,) into (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        
        # register buffer means the variable pe is not learnable (pe is not a model parameter)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
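
As a quick sanity check, we can pass a dummy batch through the module. Using a toy d_model of 8 and an input shaped (seq_len, batch, d_model), the layout forward expects, the output keeps the same shape:

pos_enc = PositionalEncoding(d_model = 8, dropout = 0.1)
dummy = torch.zeros(10, 2, 8)  # (seq_len, batch, d_model)
print(pos_enc(dummy).shape)    # torch.Size([10, 2, 8]) - same shape, with positional values added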

3. TransformerModel Class

In [8]:
class TransformerModel(nn.Module):
    
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        
        self.init_weights()
    
    def _generate_square_subsequent_mask(self, sz):
        # torch.triu returns the upper triangular part of a matrix; the transpose turns it into a lower-triangular mask, so position i can only attend to positions <= i
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
    
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, src):
        
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
        
        src = self.encoder(src) * math.sqrt(self.ninp) # convert input sequence to embeddings
        src = self.pos_encoder(src) # add positional encoding
        
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        
        return output
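
To confirm the pieces fit together, here is a minimal forward pass with some made-up hyperparameters (illustrative only, not the values used for training):

demo_model = TransformerModel(ntoken = 1000, ninp = 64, nhead = 4, nhid = 128, nlayers = 2, dropout = 0.2)
demo_src = torch.randint(0, 1000, (35, 20))  # (seq_len, batch) of token indices
print(demo_model(demo_src).shape)            # torch.Size([35, 20, 1000]) - a score for every vocabulary word at every position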

4. Load and Batch data

We will use torchtext.datasets to load the WikiText-2 dataset and a batchify() function to arrange the data into columns (batches), trimming off any leftover tokens that don't divide evenly into a batch.

In [11]:
import torchtext
from torchtext.data.utils import get_tokenizer
In [12]:
TEXT = torchtext.data.Field(tokenize = get_tokenizer('basic_english'), 
                            init_token = '<sos>', 
                            eos_token = '<eos>', 
                            lower = True)

train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
downloading wikitext-2-v1.zip
wikitext-2-v1.zip: 100%|██████████| 4.48M/4.48M [00:00<00:00, 6.31MB/s]
extracting
In [24]:
print(train_txt.examples[0].text[:100])
['<eos>', '=', 'valkyria', 'chronicles', 'iii', '=', '<eos>', '<eos>', 'senjō', 'no', 'valkyria', '3', '<unk>', 'chronicles', '(', 'japanese', '戦場のヴァルキュリア3', ',', 'lit', '.', 'valkyria', 'of', 'the', 'battlefield', '3', ')', ',', 'commonly', 'referred', 'to', 'as', 'valkyria', 'chronicles', 'iii', 'outside', 'japan', ',', 'is', 'a', 'tactical', 'role', '@-@', 'playing', 'video', 'game', 'developed', 'by', 'sega', 'and', 'media', '.', 'vision', 'for', 'the', 'playstation', 'portable', '.', 'released', 'in', 'january', '2011', 'in', 'japan', ',', 'it', 'is', 'the', 'third', 'game', 'in', 'the', 'valkyria', 'series', '.', '<unk>', 'the', 'same', 'fusion', 'of', 'tactical', 'and', 'real', '@-@', 'time', 'gameplay', 'as', 'its', 'predecessors', ',', 'the', 'story', 'runs', 'parallel', 'to', 'the', 'first', 'game', 'and', 'follows', 'the']
In [25]:
def batchify(data, bsz):
    data = TEXT.numericalize([data.examples[0].text])
    
    # Divide the dataset into bsz parts
    nbatch = data.size(0) // bsz
    
    # Trim off any extra elements that don't fit evenly into a batch
    data = data.narrow(0, 0, nbatch * bsz)
    
    # Evenly divide data across batches
    data = data.view(bsz, -1).t().contiguous()
    
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)
In [28]:
train_data.shape
Out[28]:
torch.Size([104335, 20])
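
To see what batchify is doing, here is a toy example with the numbers 0-25 and a batch size of 4: the 26 tokens are trimmed to 24 and arranged into 4 columns of 6, so each column reads top to bottom as its own contiguous stream.

toy = torch.arange(26)                # stand-in for a numericalized corpus
bsz = 4
nbatch = toy.size(0) // bsz           # 6
toy = toy.narrow(0, 0, nbatch * bsz)  # drop the 2 leftover tokens
print(toy.view(bsz, -1).t())
# tensor([[ 0,  6, 12, 18],
#         [ 1,  7, 13, 19],
#         [ 2,  8, 14, 20],
#         [ 3,  9, 15, 21],
#         [ 4, 10, 16, 22],
#         [ 5, 11, 17, 23]])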

Here, we want a function that generates input and target sequences based on a specified backpropagation-through-time (bptt) length.

In [33]:
bptt = 35

def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i: i + seq_len]
    target = source[i + 1: i + 1 + seq_len].view(-1)
    
    return data, target
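
A quick look at the shapes get_batch produces: data is a (bptt, batch) block, and target is the same block shifted one position ahead and flattened, ready for a cross-entropy loss.

data, targets = get_batch(train_data, 0)
print(data.shape)     # torch.Size([35, 20])
print(targets.shape)  # torch.Size([700])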