metrics
Machine learning metrics for distributed, scalable PyTorch applications.
Installation
Simple installation from PyPI
```bash
pip install torchmetrics
```
Other installations
Install using conda
```bash
conda install -c conda-forge torchmetrics
```
Pip from source
```bash
# with git
pip install git+https://github.com/PytorchLightning/metrics.git@master
```
Pip from archive
```bash
pip install https://github.com/PyTorchLightning/metrics/archive/master.zip
```
What is TorchMetrics
TorchMetrics is a collection of 25+ PyTorch metric implementations and an easy-to-use API to create custom metrics. It offers:
- A standardized interface to increase reproducibility
- Reduced boilerplate
- Automatic accumulation over batches
- Metrics optimized for distributed training
- Automatic synchronization between multiple devices
You can use TorchMetrics with any PyTorch model or with PyTorch Lightning to enjoy additional features such as:
- Module metrics are automatically placed on the correct device.
- Native support for logging metrics in Lightning to reduce even more boilerplate.
Using TorchMetrics
Module metrics
Module-based metrics contain internal metric states (similar to the parameters of a PyTorch module) that automate accumulation and synchronization across devices:
- Automatic accumulation over multiple batches
- Automatic synchronization between multiple devices
- Metric arithmetic
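Why keep raw states instead of per-batch results? The dependency-free sketch below uses plain Python lists of class labels in place of tensors and a hypothetical batch_counts helper. With unequal batch sizes, averaging per-batch accuracies gives a different answer from accumulating correct/total counts and dividing once. Internal metric states make the second approach possible.

```python
# Hypothetical helper: counts correct predictions in one batch of class labels.
def batch_counts(preds, target):
    correct = sum(int(p == t) for p, t in zip(preds, target))
    return correct, len(target)

# Two batches of unequal size (plain integer class labels, not tensors).
batches = [
    ([0, 1, 2, 2], [0, 1, 2, 0]),  # 3 of 4 correct
    ([1, 1], [0, 0]),              # 0 of 2 correct
]

# Naive approach: average the per-batch accuracies.
batch_accs = []
for preds, target in batches:
    c, n = batch_counts(preds, target)
    batch_accs.append(c / n)
naive = sum(batch_accs) / len(batch_accs)

# State-based approach: accumulate raw counts, divide once at the end.
correct_total, seen_total = 0, 0
for preds, target in batches:
    c, n = batch_counts(preds, target)
    correct_total += c
    seen_total += n
accumulated = correct_total / seen_total

print(f"mean of batch accuracies: {naive:.3f}")         # 0.375
print(f"accuracy over all samples: {accumulated:.3f}")  # 0.500
```

Only the accumulated value equals the accuracy over the full dataset, which is why metric states store counts rather than per-batch results.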
This can be run on CPU, a single GPU, or multiple GPUs.
For the single GPU/CPU case:
```python
import torch

# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.Accuracy()

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))

    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric accumulated over all batches
acc = metric.compute()
print(f"Accuracy on all data: {acc}")
```
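In the loop above, calling metric(preds, target) returns the value for the current batch while also updating the accumulated state that compute() later reads. The stdlib-only sketch below mimics that contract with a hypothetical RunningAccuracy class over plain label lists. It illustrates the behavior only and is not how torchmetrics implements it.

```python
class RunningAccuracy:
    """Hypothetical stand-in for a module metric; not torchmetrics code."""

    def __init__(self):
        self.reset()

    def reset(self):
        # accumulated states, analogous to a module metric's internal states
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        self.correct += sum(int(p == t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

    def __call__(self, preds, target):
        # batch-only value, computed from this batch's counts
        batch_correct = sum(int(p == t) for p, t in zip(preds, target))
        # the accumulated states are updated as a side effect
        self.update(preds, target)
        return batch_correct / len(target)


metric = RunningAccuracy()
print(metric([0, 1], [0, 1]))  # 1.0  (batch 1 only)
print(metric([0, 1], [1, 1]))  # 0.5  (batch 2 only)
print(metric.compute())        # 0.75 (3 of 4 samples overall)
metric.reset()                 # ready for a new epoch
```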
Module metric usage remains the same when using multiple GPUs or multiple nodes.
Example using DDP
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torchmetrics


def metric_ddp(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # initialize metric
    metric = torchmetrics.Accuracy()

    # define a model and append your metric to it
    # this allows metric states to be placed on correct accelerators when
    # .to(device) is called on the model
    model = nn.Linear(10, 10)
    model.metric = metric
    model = model.to(rank)

    # initialize DDP
    model = DDP(model, device_ids=[rank])

    n_epochs = 5
    # this shows iteration over multiple training epochs
    for n in range(n_epochs):
        # this will be replaced by a DataLoader with a DistributedSampler
        n_batches = 10
        for i in range(n_batches):
            # simulate a classification problem; inputs must be on the
            # same device as the metric states
            preds = torch.randn(10, 5).softmax(dim=-1).to(rank)
            target = torch.randint(5, (10,)).to(rank)

            # metric on current batch
            acc = metric(preds, target)
            if rank == 0:  # print only for rank 0
                print(f"Accuracy on batch {i}: {acc}")

        # metric on all batches, synchronized across all accelerators;
        # the result is the same on every accelerator
        acc = metric.compute()
        print(f"Accuracy on all data: {acc}, accelerator rank: {rank}")

        # reset internal state so the metric is ready for new data
        metric.reset()

    # clean up the process group
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2  # number of GPUs to parallelize over
    mp.spawn(metric_ddp, args=(world_size,), nprocs=world_size, join=True)
```
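The claim that the final accuracy is the same on every accelerator follows from how states are synchronized. Each process holds its own correct/total counts, and those counts are reduced across processes before the final division. The stdlib-only simulation below uses two hypothetical ranks and a hypothetical sync_sum helper standing in for the real all-reduce with a "sum" reduction.

```python
# Simulated per-process metric states after one epoch (2 hypothetical ranks).
# Each rank only saw its own shard of the data.
rank_states = [
    {"correct": 7, "total": 10},  # rank 0
    {"correct": 4, "total": 10},  # rank 1
]


def sync_sum(states, key):
    # Hypothetical stand-in for an all-reduce with a "sum" reduction:
    # after it, every rank holds the same combined value.
    return sum(state[key] for state in states)


correct = sync_sum(rank_states, "correct")
total = sync_sum(rank_states, "total")
global_acc = correct / total
print(global_acc)  # 0.55, identical on every rank after syncing

# Without syncing, each rank would report only its local value.
local_accs = [s["correct"] / s["total"] for s in rank_states]
print(local_accs)  # [0.7, 0.4]
```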
Implementing your own Module metric
Implementing your own metric is as easy as subclassing a torch.nn.Module: subclass torchmetrics.Metric and implement the update and compute methods:
```python
import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # call `self.add_state` for every internal state that is needed for the metric's computations
        # dist_reduce_fx indicates the function that should be used to reduce
        # state from multiple processes
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # convert (N, C) class probabilities to predicted labels of shape (N,)
        if preds.ndim == target.ndim + 1:
            preds = preds.argmax(dim=-1)
        assert preds.shape == target.shape
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # compute final result
        return self.correct.float() / self.total
```
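The same recipe carries over to other metrics: keep states that can be combined by a simple reduction, update them per batch, and derive the result only in compute. As a stdlib-only illustration (not torchmetrics code; the class name is hypothetical), a mean absolute error needs a running sum of absolute errors and a sample count, both of which are safe to sum across processes.

```python
class MeanAbsoluteErrorSketch:
    """Hypothetical plain-Python mirror of the add_state/update/compute pattern."""

    def __init__(self):
        # both states are additive, so a "sum" reduction across processes is valid
        self.sum_abs_error = 0.0
        self.total = 0

    def update(self, preds, target):
        assert len(preds) == len(target)
        self.sum_abs_error += sum(abs(p - t) for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.sum_abs_error / self.total


mae = MeanAbsoluteErrorSketch()
mae.update([2.0, 3.0], [1.0, 3.0])  # absolute errors 1.0 and 0.0
mae.update([0.0], [4.0])            # absolute error 4.0
print(mae.compute())                # 5.0 / 3
```

Storing a per-process mean instead would be a poor choice: summing means across processes is meaningless, and averaging them is only correct when every process saw the same number of samples.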
Functional metrics
Similar to torch.nn, most metrics have both a module-based and a functional version.
The functional versions are simple Python functions that take torch.Tensors as input and return the corresponding metric as a torch.Tensor.
```python
import torch

# import our library
import torchmetrics

# simulate a classification problem
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

acc = torchmetrics.functional.accuracy(preds, target)
```
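In the example above, each row of preds is a probability distribution over 5 classes, and accuracy compares the most likely class with target. The stdlib-only sketch below shows what such a stateless computation boils down to. The accuracy function here is hypothetical, works on nested lists, and assumes argmax decoding of multiclass probabilities. It has no internal state, so each call depends only on its inputs.

```python
def accuracy(probs, target):
    """Hypothetical stateless accuracy over plain lists; not the torchmetrics function."""
    # each row holds one score per class; the predicted label is the argmax
    pred_labels = [max(range(len(row)), key=row.__getitem__) for row in probs]
    correct = sum(int(p == t) for p, t in zip(pred_labels, target))
    return correct / len(target)


probs = [
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.6, 0.3, 0.1],  # predicts class 0
    [0.2, 0.2, 0.6],  # predicts class 2
]
print(accuracy(probs, [1, 0, 1]))  # 2 of 3 correct
print(accuracy(probs, [1, 0, 2]))  # 1.0; the earlier call left nothing behind
```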