High-Level Training framework for PyTorch
Pywick is a high-level PyTorch training framework that aims to get you
up and running quickly with state-of-the-art neural networks. Does the
world need another PyTorch framework? Probably not. But we started this
project when no good frameworks were available and it just kept growing.
So here we are.
Pywick tries to stay on the bleeding edge of research into neural networks. If you just wish to run a vanilla CNN, this is probably
going to be overkill. However, if you want to get lost in the world of neural networks, fine-tuning and hyperparameter optimization
for months on end, then this is probably the right place for you :)
Among other things, Pywick includes:
- State-of-the-art normalization, activation, loss functions and
optimizers not included in the standard PyTorch library.
- A high-level module for training with callbacks, constraints, metrics,
conditions and regularizers.
- Dozens of popular object classification and semantic segmentation models.
- Comprehensive data loading, augmentation, transforms, and sampling capability.
- Utility tensor functions.
- Useful meters.
- Basic GridSearch (exhaustive and random).
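The two GridSearch flavours differ mainly in how they walk the parameter grid: exhaustive search tries every combination, while random search tries a fixed budget of them. Below is a minimal stdlib sketch of that difference. It assumes a grid given as a dict that maps each hyperparameter name to its candidate values, and the function names are hypothetical, not Pywick's GridSearch API.

```python
import itertools
import random


def exhaustive_grid(param_grid):
    """Yield every combination of candidate values (the full Cartesian product)."""
    names = sorted(param_grid)
    for values in itertools.product(*(param_grid[n] for n in names)):
        yield dict(zip(names, values))


def random_grid(param_grid, n_iter, seed=None):
    """Return up to n_iter distinct combinations drawn uniformly from the full grid."""
    combos = list(exhaustive_grid(param_grid))
    rng = random.Random(seed)
    return rng.sample(combos, min(n_iter, len(combos)))


grid = {"lr": [0.1, 0.01, 0.001], "batch_size": [32, 64]}
print(len(list(exhaustive_grid(grid))))           # 6
print(len(random_grid(grid, n_iter=4, seed=0)))   # 4
```

This sketch materializes the whole grid before sampling, which is fine for small grids. For very large spaces, drawing each value independently is the cheaper variant of random search.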
pip install pywick
or a specific version from git (substitute the release tag you need for <version>):
pip install git+https://github.com/achaiah/pywick.git@<version>
The ModuleTrainer class provides a high-level training interface that abstracts
away the training loop while providing callbacks, constraints, initializers, regularizers, and more:
from pywick.modules import ModuleTrainer
from pywick.initializers import XavierUniform
from pywick.metrics import CategoricalAccuracySingleInput
import torch.nn as nn
import torch.nn.functional as F

# Define your model EXACTLY as normal
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(1600, 128)   # 64 channels * 5 * 5 for 1x28x28 inputs
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1600)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)   # log-probabilities over the 10 classes

model = Network()
trainer = ModuleTrainer(model)   # optionally supply cuda_devices as a parameter

initializers = [XavierUniform(bias=False, module_filter='fc*')]

# initialize metrics with top1 and top5
metrics = [CategoricalAccuracySingleInput(top_k=1), CategoricalAccuracySingleInput(top_k=5)]

trainer.compile(loss='cross_entropy',
                # callbacks=callbacks,          # define your callbacks here (e.g. model saver, LR scheduler)
                # regularizers=regularizers,    # define regularizers
                # constraints=constraints,      # define constraints
                optimizer='sgd',
                initializers=initializers,
                metrics=metrics)

# train_dataset_loader / val_dataset_loader: torch DataLoaders you have created beforehand
trainer.fit_loader(train_dataset_loader,
                   val_loader=val_dataset_loader,
                   num_epoch=20,
                   verbose=1)
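Conceptually, a trainer like ModuleTrainer owns the epoch loop and calls callback hooks at fixed points, so callbacks can record history or end training early. The pure-Python sketch below shows that pattern by fitting y = w * x with gradient descent. The class and hook names (Callback, on_epoch_end, stop_training) are hypothetical illustrations and are not Pywick's internals.

```python
class Callback:
    """Base class: subclasses override the hooks they care about."""

    def on_epoch_end(self, epoch, logs, trainer):
        pass


class LossRecorder(Callback):
    """Stores the loss reported at the end of every epoch (a tiny History)."""

    def __init__(self):
        self.losses = []

    def on_epoch_end(self, epoch, logs, trainer):
        self.losses.append(logs["loss"])


class StopBelow(Callback):
    """Asks the trainer to stop once the loss drops below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs, trainer):
        if logs["loss"] < self.threshold:
            trainer.stop_training = True


class TinyTrainer:
    """Fits y = w * x with full-batch gradient descent on mean squared error."""

    def __init__(self, lr=0.01, callbacks=None):
        self.w = 0.0
        self.lr = lr
        self.callbacks = callbacks or []
        self.stop_training = False  # callbacks may flip this to end training early

    def fit(self, xs, ys, num_epoch):
        n = len(xs)
        for epoch in range(num_epoch):
            # gradient of mean((w*x - y)^2) with respect to w
            grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
            self.w -= self.lr * grad
            loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
            for cb in self.callbacks:
                cb.on_epoch_end(epoch, {"loss": loss}, self)
            if self.stop_training:
                break
        return self


recorder = LossRecorder()
tiny = TinyTrainer(lr=0.05, callbacks=[recorder, StopBelow(1e-6)])
tiny.fit([1.0, 2.0, 3.0], [2.0, 4.0, 6.0], num_epoch=50)
print(round(tiny.w, 3), len(recorder.losses))
```

The trainer never needs to know what a callback does. It only guarantees when each hook runs, which is what lets savers, loggers and schedulers be mixed freely.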
You also have access to the standard evaluation and prediction functions:
loss = trainer.evaluate(x_train, y_train)
y_pred = trainer.predict(x_train)
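The top-1 and top-5 metrics configured in the trainer example count a prediction as correct when the true class is among the k highest-scoring classes. The framework-free sketch below shows that computation. `top_k_accuracy` is a hypothetical helper and is not the implementation behind CategoricalAccuracySingleInput.

```python
def top_k_accuracy(scores, targets, k=1):
    """Fraction of samples whose target index is among the k largest scores.

    scores:  list of per-class score lists, one per sample
    targets: list of integer class indices
    """
    correct = 0
    for row, target in zip(scores, targets):
        # indices of the k highest scores for this sample
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if target in top_k:
            correct += 1
    return correct / len(targets)


scores = [[0.1, 0.7, 0.2],   # highest score: class 1
          [0.5, 0.3, 0.2],   # highest score: class 0
          [0.2, 0.3, 0.5]]   # highest score: class 2
targets = [1, 1, 0]
print(top_k_accuracy(scores, targets, k=1))  # one of three correct
print(top_k_accuracy(scores, targets, k=2))  # two of three correct
```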
PyWick provides a wide range of callbacks, generally mimicking the interface found in Keras:
- CSVLogger - Logs epoch-level metrics to a CSV file
- CyclicLRScheduler - Cycles the learning rate between a minimum and a maximum value
- EarlyStopping - Provides the ability to stop training early based on supplied criteria
- History - Keeps a history of metrics etc. during the learning process
- LambdaCallback - Allows you to implement your own callbacks on the fly
- LRScheduler - Simple learning rate scheduler based on a function or a supplied schedule
- ModelCheckpoint - Comprehensive model saver
- ReduceLROnPlateau - Reduces the learning rate (LR) when a plateau has been reached
- SimpleModelCheckpoint - Simple model saver
- Additionally, a TensorboardLogger is incredibly easy to implement
via TensorboardX (TensorBoard support is now part of the PyTorch 1.1 release!)
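CyclicLRScheduler is described as cycling between a minimum and a maximum learning rate. One common form of that idea is the "triangular" policy from Leslie Smith's cyclical learning rates, sketched below in plain Python. The function and its parameter names (base_lr, max_lr, step_size) are illustrative assumptions and may not match the callback's actual arguments or policy.

```python
import math


def triangular_lr(iteration, base_lr=0.001, max_lr=0.006, step_size=2000):
    """Triangular cyclical learning rate.

    The LR rises linearly from base_lr to max_lr over step_size iterations,
    falls back to base_lr over the next step_size iterations, and repeats.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)  # 1 at cycle edges, 0 at the peak
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)


# Two full cycles with a short step size so the shape is easy to inspect
schedule = [triangular_lr(i, base_lr=0.1, max_lr=1.0, step_size=10) for i in range(41)]
print(min(schedule), max(schedule))
```

The rate peaks at odd multiples of step_size and returns to base_lr at even multiples. In a real run it would be updated once per batch rather than once per epoch.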
from pywick.callbacks import EarlyStopping

callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
trainer.set_callbacks(callbacks)
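Under the Keras-style semantics that this interface mimics, `EarlyStopping(monitor='val_loss', patience=5)` ends training once val_loss has gone 5 consecutive epochs without improving. The bookkeeping behind that rule fits in a few lines. The `EarlyStopper` class and its `min_delta` parameter are hypothetical and only illustrate the idea; they are not Pywick's callback.

```python
class EarlyStopper:
    """Signals a stop after `patience` epochs without improvement of a monitored loss."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.wait = 0               # epochs since the last improvement

    def step(self, val_loss):
        """Record one epoch's val_loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopper(patience=2)
history = [1.0, 0.8, 0.85, 0.9, 0.7]
for epoch, loss in enumerate(history):
    if stopper.step(loss):
        print(f"stopping at epoch {epoch}")  # stopping at epoch 3
        break
```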
PyWick also provides regularizers and constraints.
Both regularizers and constraints can be selectively applied to layers by matching module names against a pattern passed in the
module_filter argument (e.g. '*fc*'). Constraints can be explicit (hard) constraints applied at an arbitrary batch or
epoch frequency, or they can be implicit (soft) constraints similar to regularizers,
where the constraint deviation is added as a penalty to the total model loss.
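Patterns such as '*fc*' and '*conv*' read like shell-style wildcards. The stdlib sketch below selects the layers a regularizer or constraint would apply to, assuming glob-style matching on module names via Python's `fnmatch`. Whether Pywick matches exactly this way is an assumption, and the layer names are made up.

```python
from fnmatch import fnmatchcase


def select_modules(module_names, module_filter):
    """Return the module names that match a shell-style wildcard pattern (case-sensitive)."""
    return [name for name in module_names if fnmatchcase(name, module_filter)]


# Hypothetical names in the dotted style that PyTorch's named_modules() produces
names = ["features.conv1", "features.conv2", "classifier.fc1", "classifier.fc2"]
print(select_modules(names, "*fc*"))    # ['classifier.fc1', 'classifier.fc2']
print(select_modules(names, "*conv*"))  # ['features.conv1', 'features.conv2']
```

A leading `*` matters for nested models: 'fc*' only matches top-level names such as `fc1`, while '*fc*' also matches `classifier.fc1`.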
from pywick.constraints import MaxNorm, NonNeg
from pywick.regularizers import L1Regularizer

# hard constraint applied every 5 batches
hard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')
# implicit constraint added as a penalty term to model loss
soft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
constraints = [hard_constraint, soft_constraint]
trainer.set_constraints(constraints)

regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
trainer.set_regularizers(regularizers)
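The hard/soft distinction works like this. A hard MaxNorm-style constraint rescales weights in place so their norm never exceeds a limit. A soft (Lagrangian) NonNeg-style constraint and an L1 regularizer only add a penalty term to the loss. The pure-Python sketch below operates on a flat list of weights. The function names and the exact penalty forms are assumptions for illustration, not Pywick's formulas.

```python
import math


def max_norm_(weights, value=2.0):
    """Hard constraint: rescale weights in place so their L2 norm is at most `value`."""
    norm = math.sqrt(sum(w * w for w in weights))
    if norm > value:
        scale = value / norm
        for i, w in enumerate(weights):
            weights[i] = w * scale
    return weights


def non_neg_penalty(weights, scale=1e-3):
    """Soft constraint (assumed form): penalize only the negative part of each weight."""
    return scale * sum(-w for w in weights if w < 0)


def l1_penalty(weights, scale=1e-4):
    """L1 regularizer: scale * sum of absolute weights, added to the loss."""
    return scale * sum(abs(w) for w in weights)


w = [3.0, -4.0]             # L2 norm 5.0
max_norm_(w, value=2.0)     # rescaled in place, roughly [1.2, -1.6]
data_loss = 0.5             # hypothetical loss from the model's criterion
total_loss = data_loss + non_neg_penalty(w) + l1_penalty([0.5, -0.5])
print(total_loss)
```

The hard constraint changes the weights directly, so it can run on any batch or epoch schedule. The soft versions only steer the optimizer through the gradient of the extra loss terms.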
You can also fit directly on a
torch.utils.data.DataLoader, optionally with
a validation loader as well:
from pywick import TensorDataset
from torch.utils.data import DataLoader

train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=32)

val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=32)

trainer.fit_loader(train_loader, val_loader=val_loader, num_epoch=100)
Data Augmentation and Datasets
The PyWick package provides a wide range of data augmentation and transformation
tools that can be applied during data loading. The package also provides the flexible
FolderDataset and MultiFolderDataset classes to handle most dataset needs.
These transforms work directly on torch tensors
Additionally, we provide image-specific manipulations directly on tensors:
Affine Transforms (perform affine or affine-like transforms on torch tensors)
We also provide a class for stringing multiple affine transformations together so that only one interpolation takes place:
Blur and Scramble transforms (for tensors)
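Chaining affine transforms so that only one interpolation takes place is possible because affine maps compose into a single matrix. In homogeneous 2-D coordinates each transform is a 3x3 matrix, and applying A and then B is the same as applying the product B·A once. The dependency-free sketch below checks this on a single point. The helper names are hypothetical and are not the Pywick class.

```python
import math


def matmul3(a, b):
    """Multiply two 3x3 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def rotation(theta):
    """Counter-clockwise rotation by theta radians about the origin."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def translation(tx, ty):
    return [[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]]


def apply_transform(m, point):
    """Apply a homogeneous 3x3 affine transform to an (x, y) point."""
    x, y = point
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


rot, shift = rotation(math.pi / 2), translation(5.0, 0.0)
step_by_step = apply_transform(shift, apply_transform(rot, (1.0, 0.0)))  # rotate, then translate
combined = matmul3(shift, rot)                                            # one matrix: shift @ rot
print(apply_transform(combined, (1.0, 0.0)))                              # approximately (5.0, 1.0)
```

For images, the combined matrix is used to resample the pixel grid once, instead of interpolating (and blurring) after every individual transform.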
Datasets and Sampling
We provide datasets that supply general structure and iterators for sampling from, and applying transforms to, in-memory or out-of-memory data. In particular,
the FolderDataset has been designed to fit most of your dataset needs. It has extensive options for data filtering and manipulation.
It supports loading images for classification, segmentation and even arbitrary source/target mapping. Take a good look at its documentation for more info.
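A common layout for classification data, and one that a folder-based dataset can consume, is one sub-directory per class. The stdlib sketch below builds (path, label) pairs from such a layout. It illustrates the general convention only and is not FolderDataset's actual logic or option set. The helper name and the extension list are assumptions.

```python
import os
import tempfile


def index_class_folders(root, extensions=(".jpg", ".jpeg", ".png")):
    """Map root/<class_name>/<file> to (path, class_index) pairs.

    Class indices follow the sorted order of the class directory names;
    files with other extensions are skipped.
    """
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        folder = os.path.join(root, name)
        for fname in sorted(os.listdir(folder)):
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(folder, fname), class_to_idx[name]))
    return samples, class_to_idx


# Build a throwaway directory tree with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    for cls, files in {"cat": ["a.jpg", "b.png"], "dog": ["c.jpg", "notes.txt"]}.items():
        os.makedirs(os.path.join(root, cls))
        for f in files:
            open(os.path.join(root, cls, f), "w").close()
    samples, class_to_idx = index_class_folders(root)

print(class_to_idx)                                       # {'cat': 0, 'dog': 1}
print([(os.path.basename(p), y) for p, y in samples])     # [('a.jpg', 0), ('b.png', 0), ('c.jpg', 1)]
```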
In many scenarios it is important to ensure that your training set is properly balanced;
however, it may not be practical in real life to obtain such a perfect dataset. In these cases
you can use the
ImbalancedDatasetSampler as a drop-in replacement for the basic sampler provided
by the DataLoader. More information can be found here.
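import torch
from pywick.samplers import ImbalancedDatasetSampler

# args.batch_size and kwargs come from your own script (e.g. argparse options, num_workers)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           sampler=ImbalancedDatasetSampler(train_dataset),
                                           batch_size=args.batch_size,
                                           **kwargs)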
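The usual idea behind an imbalanced-dataset sampler is to draw each sample with probability inversely proportional to the size of its class, so that every class appears about equally often. The stdlib-only sketch below shows that weighting. The function names are hypothetical, and how ImbalancedDatasetSampler computes its weights internally is an assumption here.

```python
import random
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each sample by 1 / (number of samples in its class)."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def sample_indices(labels, num_samples, seed=None):
    """Draw dataset indices with replacement, balanced across classes in expectation."""
    weights = inverse_frequency_weights(labels)
    rng = random.Random(seed)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


labels = [0] * 90 + [1] * 10        # 9:1 class imbalance
idx = sample_indices(labels, num_samples=10000, seed=0)
drawn = Counter(labels[i] for i in idx)
print(drawn)                         # roughly 5000 draws of each class
```

Each class's weights sum to 1, so the minority class is oversampled (with replacement) to match the majority. In plain PyTorch, the same per-sample weights can be passed to `torch.utils.data.WeightedRandomSampler`.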
PyWick provides a few utility functions not commonly found:
- th_gather_nd (N-dimensional version of torch.gather)
- th_affine3d (affine transforms on torch.Tensors)
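An N-dimensional gather treats each row of an index array as a full coordinate into the source tensor and collects the element at each coordinate. One way to implement this is to flatten the source in row-major order and convert each coordinate to a linear offset using the strides. The pure-Python sketch below works on nested lists. `gather_nd` and its argument order are hypothetical and do not reproduce th_gather_nd's signature.

```python
def strides(shape):
    """Row-major (C-order) strides, measured in elements."""
    result, step = [], 1
    for dim in reversed(shape):
        result.append(step)
        step *= dim
    return list(reversed(result))


def flatten(nested):
    """Flatten an arbitrarily nested list in row-major order."""
    if not isinstance(nested, list):
        return [nested]
    return [x for item in nested for x in flatten(item)]


def gather_nd(values, shape, coords):
    """Pick values[c0][c1]...[cn] for every coordinate tuple in coords."""
    flat = flatten(values)
    st = strides(shape)
    # linear offset = sum(coordinate_i * stride_i)
    return [flat[sum(c * s for c, s in zip(coord, st))] for coord in coords]


x = [[[1, 2], [3, 4]],
     [[5, 6], [7, 8]]]   # shape (2, 2, 2)
print(gather_nd(x, (2, 2, 2), [(0, 1, 1), (1, 0, 0)]))  # [4, 5]
```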
Acknowledgements and References
We stand on the shoulders of (GitHub?) giants and couldn't have done
this without the rich GitHub ecosystem and community. This framework is
based in part on the excellent Torchsample framework
originally published by @ncullen93. Additionally, many models have been
gently borrowed/modified from @Cadene's pretrained models.