The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.

Simple installation from PyPI

pip install pytorch-lightning  


View the docs here
Note: links in the docs are temporarily broken because we recently switched orgs from williamfalcon/pytorch-lightning to pytorchlightning/pytorch-lightning (Jan 15, 2020).

As a temporary hack, when you get the 404, replace with


Copy and run this COLAB!

What is it?

Lightning is a very lightweight wrapper on PyTorch that decouples the science code from the engineering code. It's more of a style guide than a framework. Once you refactor your code, Lightning can automate most of the non-research code.

To use Lightning, simply refactor your research code into the LightningModule format (the science) and Lightning will automate the rest (the engineering). Lightning guarantees tested, correct, modern best practices for the automated parts.

  • If you are a researcher, Lightning is highly flexible: you can modify everything down to the way .backward is called or distributed training is set up.
  • If you are a scientist or production team, Lightning is very simple to use, with best-practice defaults.

What does Lightning control for me?

Everything in Blue!
This is how Lightning separates the science (red) from the engineering (blue).
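
As a rough illustration of that split, here's a framework-free toy sketch in plain Python. `ToyModule` and `toy_fit` are hypothetical names made up for this example, not part of Lightning, and the loop is far simpler than what the real Trainer does. It only shows which side owns which code: the module holds the model, loss and update rule, while the fit function holds the looping and bookkeeping.

```python
class ToyModule:
    """The 'science': model, loss, and update rule live here (hypothetical class)."""

    def __init__(self):
        self.w = 0.0  # single weight for the model y_hat = w * x

    def training_step(self, batch):
        # mean squared error and its gradient with respect to w
        loss = sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)
        return {'loss': loss, 'grad': grad}

    def optimizer_step(self, grad, lr=0.05):
        # plain gradient descent
        self.w -= lr * grad


def toy_fit(module, batches, max_epochs):
    """The 'engineering': epoch/batch loops and loss bookkeeping live here."""
    history = []
    for _ in range(max_epochs):
        epoch_loss = 0.0
        for batch in batches:
            out = module.training_step(batch)
            module.optimizer_step(out['grad'])
            epoch_loss += out['loss']
        history.append(epoch_loss / len(batches))
    return history


# toy data generated from y = 2 * x, split into two batches
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]
module = ToyModule()
history = toy_fit(module, batches, max_epochs=20)
print(f'final w = {module.w:.3f}, final loss = {history[-1]:.6f}')
```

Swapping in a different model only touches `ToyModule`; the loop in `toy_fit` stays the same, which is the point of the separation.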


How much effort is it to convert?

You're probably tired of switching frameworks at this point. But refactoring into the Lightning format is a very quick process (i.e., hours). Check out this tutorial.

Starting a new project?

Use our seed-project aimed at reproducibility!

Why do I want to use Lightning?

Every research project starts the same way: a model, a training loop, a validation loop, etc. As your research advances, you're likely to need distributed training, 16-bit precision, checkpointing, gradient accumulation, etc.

Lightning sets up all of that state-of-the-art training boilerplate for you so you can focus on the research.

How do I use it?

Think about Lightning as refactoring your research code instead of using a new framework. The research code goes into a LightningModule which you fit using a Trainer.

The LightningModule defines a system such as seq-2-seq, GAN, etc. It can also define a simple classifier such as the example below.

To use Lightning, do 2 things:

  1. Define a LightningModule
    WARNING: This syntax is for version 0.5.0+ where abbreviations were removed.
    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision import transforms
    import pytorch_lightning as pl
    class CoolSystem(pl.LightningModule):
        def __init__(self):
            super(CoolSystem, self).__init__()
            # not the best model...
            self.l1 = torch.nn.Linear(28 * 28, 10)
        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))
        def training_step(self, batch, batch_idx):
            # REQUIRED
            x, y = batch
            y_hat = self.forward(x)
            loss = F.cross_entropy(y_hat, y)
            tensorboard_logs = {'train_loss': loss}
            return {'loss': loss, 'log': tensorboard_logs}
        def validation_step(self, batch, batch_idx):
            # OPTIONAL
            x, y = batch
            y_hat = self.forward(x)
            return {'val_loss': F.cross_entropy(y_hat, y)}
        def validation_end(self, outputs):
            # OPTIONAL
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            tensorboard_logs = {'val_loss': avg_loss}
            return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        def test_step(self, batch, batch_idx):
            # OPTIONAL
            x, y = batch
            y_hat = self.forward(x)
            return {'test_loss': F.cross_entropy(y_hat, y)}
        def test_end(self, outputs):
            # OPTIONAL
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            tensorboard_logs = {'test_loss': avg_loss}
            return {'avg_test_loss': avg_loss, 'log': tensorboard_logs}
        def configure_optimizers(self):
            # REQUIRED
            # can return multiple optimizers and learning_rate schedulers
            # (LBFGS is automatically supported, no need for a closure function)
            return torch.optim.Adam(self.parameters(), lr=0.02)
        def train_dataloader(self):
            # REQUIRED
            return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
        def val_dataloader(self):
            # OPTIONAL
            return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
        def test_dataloader(self):
            # OPTIONAL
            return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
  2. Fit with a trainer
    from pytorch_lightning import Trainer
    model = CoolSystem()
    # most basic trainer, uses good defaults
    trainer = Trainer()
    trainer.fit(model)

Trainer sets up a tensorboard logger, early stopping and checkpointing by default (you can modify all of them or use something other than tensorboard).
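
To make the early-stopping default more concrete, here's a minimal, framework-free sketch of patience-based early stopping. `ToyEarlyStopping`, `patience` and `min_delta` are illustrative names chosen for this sketch. This is not Lightning's own callback, whose behavior and options may differ; it only shows the general idea of stopping once the monitored validation loss stops improving.

```python
class ToyEarlyStopping:
    """Stop once the loss has failed to improve for `patience` checks in a row.

    Hypothetical helper for illustration only, not Lightning's callback.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # how many non-improving checks to tolerate
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float('inf')
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # improvement: remember it and reset the counter
            self.best = val_loss
            self.wait = 0
            return False
        # no improvement: count it and stop once patience is used up
        self.wait += 1
        return self.wait >= self.patience


# made-up validation losses, one per epoch
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.60, 0.59]
stopper = ToyEarlyStopping(patience=3)
stopped_epoch = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(val_loss):
        stopped_epoch = epoch
        break

print(f'stopped at epoch {stopped_epoch}, best val_loss {stopper.best}')
```

With these made-up losses the loop ends before the last value is seen, which is the usual trade-off of a small `patience`: less wasted compute, but a later improvement can be missed.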

Here are more advanced examples

# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(max_epochs=1, train_percent_check=0.1)

# train on 4 gpus (lightning chooses GPUs for you)
# trainer = Trainer(max_epochs=1, gpus=4, distributed_backend='ddp')  

# train on 4 gpus (you choose GPUs)
# trainer = Trainer(max_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')   

# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
# trainer = Trainer(max_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')

# train (1 epoch only here for demo)
trainer.fit(model)

# view tensorboard logs
print(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')
print('and going to http://localhost:6006 on your browser')

When you're all done, you can even run the test set separately.

trainer.test()


Could be as complex as seq-2-seq + attention

# define what happens for training here
def training_step(self, batch, batch_idx):
    x, y = batch
    # define your own forward and loss calculation
    hidden_states = self.encoder(x)
    # even as complex as a seq-2-seq + attn model
    # (this is just a toy, non-working example to illustrate)
    start_token = '<SOS>'
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):
        attn_context = self.attention_nn(hidden_states, start_token)
        pred = self.decoder(start_token, attn_context, last_hidden) 
        last_hidden = pred
        pred = self.predict_nn(pred)
        loss += self.loss(pred, y[step])
    # toy example as well
    loss = loss / max_seq_len
    return {'loss': loss} 

Or as basic as CNN image classification

# define what happens for validation here
def validation_step(self, batch, batch_idx):    
    x, y = batch
    # or as basic as a CNN classification
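    out = self.forward(x)
    loss = my_loss(out, y)
    # validation_end below reads 'val_loss' and 'val_acc', so return both
    # (accuracy assumes `out` holds per-class scores and `y` holds class indices)
    acc = (out.argmax(dim=1) == y).float().mean()
    return {'val_loss': loss, 'val_acc': acc}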

And you also decide how to collate the output of all validation steps

def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of individual outputs of each validation step
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    logs = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    result = {'log': logs}
    return result
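
To see what this aggregation computes without needing a model, here's a standalone sketch that runs the same averaging on plain Python floats instead of tensors (so there is no `.item()` call). The function name `aggregate_validation_outputs` and the sample values are made up for illustration.

```python
def aggregate_validation_outputs(outputs):
    """Average each metric over the list of per-step output dicts."""
    val_loss_mean = sum(o['val_loss'] for o in outputs) / len(outputs)
    val_acc_mean = sum(o['val_acc'] for o in outputs) / len(outputs)
    return {'log': {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}}


# one dict per validation step, as a validation_step would return them
outputs = [
    {'val_loss': 0.5, 'val_acc': 0.75},
    {'val_loss': 0.25, 'val_acc': 1.0},
    {'val_loss': 0.75, 'val_acc': 0.5},
]
result = aggregate_validation_outputs(outputs)
print(result['log'])
```

Note that averaging per-step means weights every batch equally, so a smaller final batch counts as much as a full one. If that matters, weight each step by its batch size instead.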