PyTorch Wrapper

PyTorch Wrapper is a library that provides a systematic and extensible way to build, train, evaluate, and tune deep learning models using PyTorch. It also provides several ready-to-use modules and functions for fast model development.

Installation

From PyPI

pip install pytorch-wrapper

From Source

git clone https://github.com/jkoutsikakis/pytorch-wrapper.git
cd pytorch-wrapper
pip install .

Basic abstract usage pattern

import torch
import pytorch_wrapper as pw

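# torch.utils.data.DataLoader objects for training, validation (used for early stopping) and final evaluation
train_dataloader = ...
val_dataloader = ...
dev_dataloader = ...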

evaluators = {'acc': pw.evaluators.AccuracyEvaluator()}  # ... add further evaluators here
loss_wrapper = pw.loss_wrappers.GenericPointWiseLossWrapper(torch.nn.BCEWithLogitsLoss())

model = ...  # any torch.nn.Module

# Wrap the model; use torch.device('cpu') if no CUDA GPU is available
system = pw.System(model=model, device=torch.device('cuda'))

optimizer = torch.optim.Adam(system.model.parameters())

system.train(
    loss_wrapper,
    optimizer,
    train_data_loader=train_dataloader,
    evaluators=evaluators,
    evaluation_data_loaders={'val': val_dataloader},
    callbacks=[
        pw.training_callbacks.EarlyStoppingCriterionCallback(
            patience=3,
            evaluation_data_loader_key='val',
            evaluator_key='acc',
            tmp_best_state_filepath='current_best.weights'
        )
    ]
)

results = system.evaluate(dev_dataloader, evaluators)

predictions = system.predict(dev_dataloader)

system.save_model_state('model.weights')
system.load_model_state('model.weights')
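
The example pairs torch.nn.BCEWithLogitsLoss, which expects raw logits, with an 'acc' evaluator. The pure-Python sketch below illustrates what binary accuracy over logits amounts to: apply the sigmoid, threshold at 0.5, and compare with the 0/1 targets. It assumes one logit per example and a 0.5 threshold; whether AccuracyEvaluator behaves exactly this way by default is an assumption, and sigmoid and binary_accuracy are hypothetical helpers, not part of pytorch_wrapper.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # avoids overflow of exp(-x) for large negative x
    return z / (1.0 + z)


def binary_accuracy(logits, targets, threshold=0.5):
    """Hypothetical helper: fraction of examples whose thresholded
    sigmoid(logit) matches the 0/1 target."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    if not logits:
        raise ValueError("need at least one example")
    correct = sum(
        int((sigmoid(z) >= threshold) == bool(t))
        for z, t in zip(logits, targets)
    )
    return correct / len(logits)


print(binary_accuracy([2.0, -1.0, 0.3, -0.2], [1, 0, 0, 0]))  # 0.75
```

Thresholding the sigmoid at 0.5 is equivalent to checking whether the logit is non-negative, which is why no separate sigmoid layer is needed in the model when BCEWithLogitsLoss is used.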
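
EarlyStoppingCriterionCallback(patience=3, ...) monitors the 'acc' score on the 'val' loader. The sketch below illustrates the general patience rule on a precomputed list of per-epoch validation scores: track the best score, and stop after `patience` consecutive epochs without improvement. It assumes a higher-is-better metric and strict improvement; early_stopping_epoch is a hypothetical helper, not the library's implementation, and the comment only marks where the callback would presumably save the best weights to tmp_best_state_filepath.

```python
def early_stopping_epoch(val_scores, patience=3):
    """Hypothetical helper: return (best_epoch, stop_epoch) for a
    higher-is-better metric under a simple patience rule."""
    if not val_scores:
        raise ValueError("need at least one validation score")
    best_score = float("-inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
            # Assumed: the real callback would write the current best
            # weights to tmp_best_state_filepath at this point.
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch  # stop training here
    return best_epoch, len(val_scores) - 1  # ran out of epochs


print(early_stopping_epoch([0.70, 0.75, 0.74, 0.76, 0.75, 0.75, 0.73, 0.80], patience=3))  # (3, 6)
```

For these scores training halts at epoch 6 with epoch 3 as the best, so the later 0.80 is never reached; a larger patience trades extra epochs for a lower chance of stopping too early.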
