5. Initialise TransformerModel instance

In [36]:
ntokens = len(TEXT.vocab.stoi) # vocab size
emsize = 200 # embedding dimension
nhid = 200 # hidden state dimension
nlayers = 2 # no. of TransformerEncoderLayer in TransformerEncoder
nhead  = 2 # no. of heads in multiheadattention models
dropout = 0.2

model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device) # move the model to the chosen device (CPU or GPU)
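
The device object here was created in an earlier cell of the notebook; it is assumed to follow the usual pattern below (a minimal sketch, not part of this cell):

# assumed earlier cell: use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")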

6. Train and Evaluate the Model

  1. Define the loss function – CrossEntropyLoss
  2. Define the optimiser – SGD (stochastic gradient descent)
  3. Define the learning rate – an initial lr of 5.0, decayed over the epochs by a StepLR scheduler
  4. Use the clip_grad_norm_ function to scale the gradients and prevent them from exploding
In [38]:
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95) # gamma - rate of decay
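
With step_size = 1 and gamma = 0.95, every call to scheduler.step() (made once per epoch further down) multiplies the learning rate by 0.95, so after k steps the rate is 5.0 * 0.95 ** k. A quick illustration of the schedule (not part of the original notebook):

# illustrative only: how the StepLR schedule decays the learning rate
for k in range(4):
    print('after {} scheduler steps: lr = {:.4f}'.format(k, 5.0 * 0.95 ** k))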

Training

  1. Turn on training mode
  2. Initialise variables such as total_loss, start_time, and ntokens
  3. For each batch:
    • Get the next batch for training with get_batch (see the sketch after the training code below)
    • Set gradients to zero
    • Forward pass data through to our Transformer
    • Compute loss
    • Backprop (autograd)
    • Update weights
    • Log the average loss every 200 batches
In [42]:
import time
import math # math.exp is used below to report perplexity
def train():
    model.train()
    
    total_loss = 0.
    start_time = time.time()
    ntokens = len(TEXT.vocab.stoi)
    
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad() # set all gradients to zero
        
        output = model(data) # forward pass
        
        loss = criterion(output.view(-1, ntokens), targets) # compute loss
        loss.backward() # backward prop through autograd
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # clipping to prevent exploding gradient
        
        optimizer.step() # update weights
        
        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            
            total_loss = 0
            start_time = time.time()
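
train() relies on bptt and get_batch(), which were defined earlier in the notebook when the data was batched. For reference, they are assumed to look roughly like the versions in the official PyTorch transformer tutorial (a sketch, not a cell from this notebook):

# assumed from an earlier cell: length of each training sequence (backprop-through-time window)
bptt = 35

def get_batch(source, i):
    # source has shape [full_sequence_length, batch_size]; slice out up to bptt time steps
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]                        # inputs
    target = source[i + 1:i + 1 + seq_len].reshape(-1)  # next-token targets, flattened
    return data, target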

Evaluation

  1. Turn eval mode on
  2. Initialise variables such as total_loss (eval) and ntokens
  3. Within no_grad (autograd turned off), for each batch:
    • Get the batch
    • Forward pass the eval data
    • Compute the total loss
    • Return the average (per-token) validation loss
In [43]:
def evaluate(eval_model, data_source):
    
    eval_model.eval()
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
    
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output = eval_model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item() # weight by batch length (the last batch can be shorter than bptt)
        
    return total_loss / (len(data_source) - 1)
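
Each batch loss is weighted by len(data) because the final slice can be shorter than bptt; dividing the weighted sum by len(data_source) - 1 (the total number of target positions) therefore gives an average cross-entropy per token, whose exponential is the perplexity reported in the logs, e.g. (illustrative):

val_ppl = math.exp(evaluate(model, val_data)) # perplexity = exp(per-token cross-entropy)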

The training process

For each epoch:

  1. Train through the whole dataset in batches
  2. Evaluate the model at the end of the epoch
  3. Adjust the learning rate through the scheduler
In [44]:
import copy # needed to snapshot the best model

best_val_loss = float("inf")
epochs = 3
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model) # snapshot the weights; a plain assignment would only keep a reference to the live model
    
    scheduler.step()
| epoch   1 |   200/ 2981 batches | lr 5.00 | ms/batch 678.52 | loss  8.05 | ppl  3141.77
| epoch   1 |   400/ 2981 batches | lr 5.00 | ms/batch 670.57 | loss  6.81 | ppl   907.39
| epoch   1 |   600/ 2981 batches | lr 5.00 | ms/batch 633.41 | loss  6.39 | ppl   596.21
| epoch   1 |   800/ 2981 batches | lr 5.00 | ms/batch 624.25 | loss  6.25 | ppl   517.00
| epoch   1 |  1000/ 2981 batches | lr 5.00 | ms/batch 634.62 | loss  6.12 | ppl   455.91
| epoch   1 |  1200/ 2981 batches | lr 5.00 | ms/batch 625.54 | loss  6.10 | ppl   446.01
| epoch   1 |  1400/ 2981 batches | lr 5.00 | ms/batch 673.19 | loss  6.05 | ppl   425.65
| epoch   1 |  1600/ 2981 batches | lr 5.00 | ms/batch 626.95 | loss  6.05 | ppl   425.57
| epoch   1 |  1800/ 2981 batches | lr 5.00 | ms/batch 650.32 | loss  5.96 | ppl   387.49
| epoch   1 |  2000/ 2981 batches | lr 5.00 | ms/batch 683.12 | loss  5.96 | ppl   386.46
| epoch   1 |  2200/ 2981 batches | lr 5.00 | ms/batch 693.66 | loss  5.85 | ppl   347.61
| epoch   1 |  2400/ 2981 batches | lr 5.00 | ms/batch 701.07 | loss  5.89 | ppl   363.04
| epoch   1 |  2600/ 2981 batches | lr 5.00 | ms/batch 694.86 | loss  5.91 | ppl   367.34
| epoch   1 |  2800/ 2981 batches | lr 5.00 | ms/batch 665.96 | loss  5.80 | ppl   331.24
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 2019.01s | valid loss  5.80 | valid ppl   330.54
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2981 batches | lr 4.51 | ms/batch 657.31 | loss  5.81 | ppl   334.17
| epoch   2 |   400/ 2981 batches | lr 4.51 | ms/batch 644.85 | loss  5.78 | ppl   323.31
| epoch   2 |   600/ 2981 batches | lr 4.51 | ms/batch 649.17 | loss  5.60 | ppl   271.25
| epoch   2 |   800/ 2981 batches | lr 4.51 | ms/batch 649.78 | loss  5.64 | ppl   281.39
| epoch   2 |  1000/ 2981 batches | lr 4.51 | ms/batch 687.45 | loss  5.59 | ppl   268.69
| epoch   2 |  1200/ 2981 batches | lr 4.51 | ms/batch 675.11 | loss  5.62 | ppl   276.31
| epoch   2 |  1400/ 2981 batches | lr 4.51 | ms/batch 694.79 | loss  5.63 | ppl   279.10
| epoch   2 |  1600/ 2981 batches | lr 4.51 | ms/batch 650.40 | loss  5.66 | ppl   285.76
| epoch   2 |  1800/ 2981 batches | lr 4.51 | ms/batch 652.19 | loss  5.59 | ppl   267.81
| epoch   2 |  2000/ 2981 batches | lr 4.51 | ms/batch 647.80 | loss  5.62 | ppl   276.41
| epoch   2 |  2200/ 2981 batches | lr 4.51 | ms/batch 647.48 | loss  5.51 | ppl   248.20
| epoch   2 |  2400/ 2981 batches | lr 4.51 | ms/batch 648.24 | loss  5.58 | ppl   265.53
| epoch   2 |  2600/ 2981 batches | lr 4.51 | ms/batch 647.55 | loss  5.60 | ppl   270.04
| epoch   2 |  2800/ 2981 batches | lr 4.51 | ms/batch 677.35 | loss  5.52 | ppl   249.35
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 2007.69s | valid loss  5.60 | valid ppl   270.49
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2981 batches | lr 4.29 | ms/batch 647.95 | loss  5.55 | ppl   256.18
| epoch   3 |   400/ 2981 batches | lr 4.29 | ms/batch 646.51 | loss  5.56 | ppl   259.75
| epoch   3 |   600/ 2981 batches | lr 4.29 | ms/batch 653.71 | loss  5.38 | ppl   215.98
| epoch   3 |   800/ 2981 batches | lr 4.29 | ms/batch 655.16 | loss  5.42 | ppl   226.74
| epoch   3 |  1000/ 2981 batches | lr 4.29 | ms/batch 681.45 | loss  5.39 | ppl   219.43
| epoch   3 |  1200/ 2981 batches | lr 4.29 | ms/batch 662.13 | loss  5.42 | ppl   225.75
| epoch   3 |  1400/ 2981 batches | lr 4.29 | ms/batch 643.32 | loss  5.46 | ppl   235.62
| epoch   3 |  1600/ 2981 batches | lr 4.29 | ms/batch 642.55 | loss  5.49 | ppl   241.67
| epoch   3 |  1800/ 2981 batches | lr 4.29 | ms/batch 639.31 | loss  5.42 | ppl   225.96
| epoch   3 |  2000/ 2981 batches | lr 4.29 | ms/batch 646.99 | loss  5.44 | ppl   231.19
| epoch   3 |  2200/ 2981 batches | lr 4.29 | ms/batch 659.40 | loss  5.33 | ppl   206.51
| epoch   3 |  2400/ 2981 batches | lr 4.29 | ms/batch 655.10 | loss  5.41 | ppl   222.67
| epoch   3 |  2600/ 2981 batches | lr 4.29 | ms/batch 656.79 | loss  5.42 | ppl   226.89
| epoch   3 |  2800/ 2981 batches | lr 4.29 | ms/batch 641.26 | loss  5.34 | ppl   209.49
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 1984.65s | valid loss  5.57 | valid ppl   261.78
-----------------------------------------------------------------------------------------

7. Final Evaluation on the Test Data

In [46]:
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
print('=' * 89)
=========================================================================================
| End of training | test loss  5.47 | test ppl   238.64
=========================================================================================