PyTorch Lightning Structure

How do we refactor our PyTorch code to fit the PyTorch Lightning framework?

There are 5 main areas of a research project that you need to refactor (a minimal skeleton showing where each piece lives follows this list):

  1. The model
    • Model definition (init method) (REQUIRED)
    • Forward pass (forward method) (REQUIRED)
  2. The data and dataloader
    • Data download and preparation; this shouldn't live inside the dataloader methods (prepare_data method)
    • Batching training data (train_dataloader method) (REQUIRED)
    • Batching validation data (val_dataloader method)
    • Batching test data (test_dataloader method)
  3. The optimiser
    • Initialising and setting up your optimiser (configure_optimizers method) (REQUIRED)
  4. The loss
    • Set up the loss function and compute the loss (doesn’t have to be a separate method, can just be included in training_step)
  5. The training / validation / testing computations
    • Training process (training_step method) (REQUIRED)
    • Validation process (validation_step method)
    • Testing process (test_step method)
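Before diving into each piece, here is how those methods sit inside a single LightningModule. This is just a bare skeleton (MyModel is a placeholder name); the bodies get filled in section by section below:

from pytorch_lightning.core.lightning import LightningModule

class MyModel(LightningModule):
    def __init__(self):                    # REQUIRED: model definition
        super().__init__()

    def forward(self, x):                  # REQUIRED: forward pass
        ...

    def prepare_data(self):                # download / prepare data
        ...

    def train_dataloader(self):            # REQUIRED: batch training data
        ...

    def val_dataloader(self):              # batch validation data
        ...

    def test_dataloader(self):             # batch test data
        ...

    def configure_optimizers(self):        # REQUIRED: set up the optimiser
        ...

    def training_step(self, batch, batch_idx):    # REQUIRED: training process
        ...

    def validation_step(self, batch, batch_idx):  # validation process
        ...

    def test_step(self, batch, batch_idx):        # testing process
        ...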
How do we train our Lightning model?

By using the Trainer class provided by PyTorch Lightning. It covers .fit (training and validation) and .test (testing).
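In its simplest form it looks like this (a sketch; MyModel stands in for your LightningModule, and .test assumes you also defined test_dataloader and test_step):

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer()
trainer.fit(model)   # training (and validation, if you defined the hooks)
trainer.test(model)  # testing (needs test_dataloader and test_step)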

What other benefits come with PyTorch Lightning?
  1. You can easily plot the results by running TensorBoard, since all the logging (engineering) is handled by PyTorch Lightning.
  2. A bonus feature (amazing!): if you keep calling .fit, it continues training the model from where it left off!

Practical Implementation of the 5 main areas listed above

The Model
Import Dependencies
In [3]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
The Example MNIST Model
In [6]:
def __init__(self):
    super().__init__()

    # Plain PyTorch linear layers (a simple MLP, not a CNN)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 256)
    self.layer_3 = torch.nn.Linear(256, 10)

def forward(self, x):

    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) --> (b, 1 * 28 * 28)
    x = x.view(batch_size, -1)

    # Input flowing through all three layers!
    x = self.layer_1(x)
    x = torch.relu(x)
    x = self.layer_2(x)
    x = torch.relu(x)
    x = self.layer_3(x)

    # log-probability distribution over the labels
    x = torch.log_softmax(x, dim=1)

    return x
The Data
Import Dependencies
In [5]:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms
MNIST Data

The prepare_data method is used to download the data and get it ready for the processing pipeline. The train_dataloader method takes the downloaded data, runs it through the predefined transformation pipeline, and batches it!

In [7]:
def prepare_data(self):
    MNIST(os.getcwd(), train=True, download=True)

def train_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
    
    return DataLoader(mnist_train, batch_size=64)
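The random_split import above isn't used by train_dataloader, but it is the natural tool for a val_dataloader. A sketch (not part of the original model; the 55,000/5,000 split is an arbitrary choice):

def val_dataloader(self):
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    mnist_full = MNIST(os.getcwd(), train=True, download=False, transform=transform)
    # carve a validation set out of the 60,000 training images;
    # in practice do the split once (or fix the seed) so train and val don't overlap
    _, mnist_val = random_split(mnist_full, [55000, 5000])
    return DataLoader(mnist_val, batch_size=64)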
The Optimiser
In [9]:
from torch.optim import Adam

def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
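configure_optimizers can also hand back a learning-rate scheduler alongside the optimiser, as a pair of lists. A sketch (the StepLR settings here are illustrative, not tuned):

from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.5)  # halve the LR every 10 epochs
    return [optimizer], [scheduler]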
The Loss
In [10]:
def cross_entropy_loss(self, logits, labels):
    # forward already applies log_softmax, so NLL loss here is exactly cross-entropy
    return F.nll_loss(logits, labels)
The Training / Validation / Test Computations
In [11]:
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.cross_entropy_loss(logits, y)
    
    logs = {'loss': loss}
    return {'loss': loss, 'log': logs}
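If you defined a val_dataloader, the validation loop mirrors this almost line for line. A sketch using the same dict-based API as the rest of this post (validation_epoch_end aggregates the per-batch results):

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.cross_entropy_loss(logits, y)
    return {'val_loss': loss}

def validation_epoch_end(self, outputs):
    # average the per-batch losses returned by validation_step
    avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
    return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}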

Putting it all together

In [13]:
class LightningMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x
    
    def prepare_data(self):
        MNIST(os.getcwd(), train=True, download=True)

    def train_dataloader(self):
        transform=transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,))])
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        return DataLoader(mnist_train, batch_size=64)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)  # cross-entropy, since forward returns log_softmax

        # add logging
        logs = {'loss': loss}
        return {'loss': loss, 'log': logs}
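Before handing the model to the Trainer, a quick sanity check that the shapes line up (the fake batch below just mimics four greyscale 28x28 MNIST images):

model = LightningMNIST()
dummy = torch.randn(4, 1, 28, 28)  # batch of 4 fake images
out = model(dummy)
print(out.shape)  # torch.Size([4, 10]) -- log-probabilities over the 10 digits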

Training with Trainer

In [14]:
from pytorch_lightning import Trainer

model = LightningMNIST()
trainer = Trainer()
trainer.fit(model)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/train-images-idx3-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/rong2/Desktop/pytorch_learning/notebooks/MNIST/raw
Processing...
Done!
  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K
Out[14]:
1
Training continues where it left off when you run .fit again
In [17]:
trainer.fit(model)
  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K 
1 | layer_2 | Linear | 33 K  
2 | layer_3 | Linear | 2 K
Out[17]:
1

Logging with Tensorboard

In [19]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard