Torch-atom

Introduction

A basic, simple training framework for PyTorch that is easy to extend.

architecture figure

Dependencies

  • torch==1.7.0+cu110
  • torchvision==0.8.0
  • easydict==1.9
  • tensorboard==2.7
  • tensorboardX==2.4
  • PyYAML==5.3.1

Features

  • Any module can be easily customized
  • Not abstract; easy to learn, develop and debug
  • Removes a lot of repetitive work, so the training process is easier to control in development and research
  • Friendly to multi-network interactive training, such as GAN, transfer learning, knowledge distillation, etc.

Train

python main.py --model resnet18 --save_dir cifar100_resnet18 --config_path ./configs/cifar100.yml
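
The flags in the command above could be parsed roughly as follows. This is a minimal sketch using argparse: the flag names come from the command, while `build_parser`, the default value and the help strings are assumptions, not necessarily what main.py does.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Flag names mirror the command line shown above; the default and the
    # help texts are illustrative assumptions.
    parser = argparse.ArgumentParser(description="Training entry point (sketch)")
    parser.add_argument("--model", type=str, required=True,
                        help="model name, e.g. resnet18")
    parser.add_argument("--save_dir", type=str, required=True,
                        help="directory for logs and checkpoints")
    parser.add_argument("--config_path", type=str, default="./configs/cifar100.yml",
                        help="path to the YAML config file")
    return parser


# Parse an explicit argument list instead of sys.argv to keep the sketch self-contained.
args = build_parser().parse_args(["--model", "resnet18", "--save_dir", "cifar100_resnet18"])
print(args.model, args.save_dir, args.config_path)
```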

Results

CIFAR100

All CIFAR100 log files can be downloaded as a single package here: pan.baidu (code: 3lp2)

Network        Accuracy (%)   Log
resnet18       76.46          pan.baidu code: ewnd
resnet34       77.23          pan.baidu code: dq4r
resnet50       76.82          pan.baidu code: 1e62
resnet101      77.32          pan.baidu code: myfv
vgg11_bn       70.52          pan.baidu code: 2pun
vgg13_bn       73.71          pan.baidu code: 4vmm
mobilenetV2    68.99          pan.baidu code: e93w
shufflenet     71.17          pan.baidu code: lnvy
shufflenetV2   71.16          pan.baidu code: vmi6

CIFAR10

All CIFAR10 log files can be downloaded as a single package here: pan.baidu (code: 3iqz)

Network        Accuracy (%)   Log
resnet18       94.92          pan.baidu code: a20j
resnet34       94.80          pan.baidu code: q8h1
resnet50       94.81          pan.baidu code: f3wr
resnet101      95.45          pan.baidu code: d3i8
vgg11_bn       92.21          pan.baidu code: di45
vgg13_bn       93.74          pan.baidu code: su1z
mobilenetV2    90.92          pan.baidu code: todf
shufflenet     92.06          pan.baidu code: 1xr2
shufflenetV2   91.61          pan.baidu code: 8swu

Demo Train

  • Import modules

    from src import DatasetBuilder, TransformBuilder, ModelBuilder, LossBuilder, LossWrapper, OptimizerBuilder, SchedulerBuilder, MetricBuilder
    from torch.utils.data import DataLoader
    from src import Controller
    from src.utils import AverageMeter
  • Load your transforms, datasets and dataloaders

    transform_name = 'cifar100_transform'  # your transform function name
    dataset_name = 'CIFAR100'  # your dataset class name
    train_transform, val_transform = TransformBuilder.load(transform_name)
    trainset, trainset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=train_transform, train=True)
    valset, valset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=val_transform, train=False)
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
  • Load the model, loss wrapper, metrics, optimizer, scheduler and controller

    epochs = 20
    
    model = ModelBuilder.load("resnet18", num_classes=100)
    
    # if categorical weights: LossBuilder.load("CrossEntropyLoss", weight=torch.tensor([1.0, 2.0]).float())
    loss_func1 = LossBuilder.load("CrossEntropyLoss")
    
    # Only one loss function, cross entropy, set its weight to 1.0
    loss_wrapper = LossWrapper([loss_func1], [1.0])
    
    model = model.cuda()
    loss_wrapper = loss_wrapper.cuda()
    metric_functions = [MetricBuilder.load('Accuracy')]
    
    optimizer, optimizer_param = OptimizerBuilder.load('Adam', model.parameters(), lr=0.1)
    scheduler, scheduler_param = SchedulerBuilder.load("cosine_annealing_lr", optimizer, max_epoch=epochs)
    controller = Controller(loss_wrapper=loss_wrapper, model=model, optimizer=optimizer)
  • Train!

    for epoch in range(epochs):
        # train
        model.train()
        loss_recorder = AverageMeter(type='scalar', name='total loss')
        loss_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["CrossEntropyLoss"])
        metric_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["Accuracy"])
        for (img, label) in train_loader:
            img = img.cuda()
            label = label.cuda()
            loss, loss_tuple, output_no_grad = controller.train_step(img, label)
            
            loss_recorder.update(loss.item(), img.size(0))
            loss_list_recorder.update(loss_tuple, img.size(0))
    
            metrics = tuple([func(output_no_grad, label) for func in metric_functions])
            metric_list_recorder.update(metrics, img.size(0))
    
        print(f"total loss:{loss_recorder.get_value()} loss_tuple:{loss_list_recorder} metrics:{metric_list_recorder}")
    
        # eval
        model.eval()
        # ...
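
The recorders in the loop are updated with a value and the batch size, which suggests a sample-weighted running mean. A minimal sketch of that idea for the scalar case follows. `RunningAverage` is a hypothetical stand-in and its behaviour is an assumption; it is not the AverageMeter from src.utils.

```python
class RunningAverage:
    """Sample-weighted running mean of a scalar (hypothetical stand-in)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.total = 0.0  # sum of value * n over all updates
        self.count = 0    # number of samples seen so far

    def update(self, value: float, n: int = 1) -> None:
        # `value` is a per-batch mean, so weight it by the batch size `n`.
        self.total += value * n
        self.count += n

    def get_value(self) -> float:
        # Return 0.0 before the first update to avoid dividing by zero.
        return self.total / self.count if self.count else 0.0


meter = RunningAverage("total loss")
meter.update(2.0, 128)  # a full batch
meter.update(1.0, 64)   # a smaller last batch counts for less
print(meter.get_value())
```

Weighting by batch size keeps the epoch average correct when the last batch is smaller than the others.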

Customize

Customize Dataset

  • In the src/datasets directory, define your customized dataset in mydataset.py, following cifar.py.
  • The CIFAR class needs some parameters for initialization, such as root, train and download, which can be specified in src/datasets/dataset_config.yml. Note that the transform needs to be defined in transforms.py; details can be found at Customize Transform.
  • In src/datasets/dataset_builder.py, import your dataset class. For example, if the MyDataset class is defined in mydataset.py, add from .mydataset import * to dataset_builder.py
  • In configs/xxx.yml, set dataset.name to MyDataset
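
As a rough illustration of the dataset steps above, here is a dependency-free sketch of what mydataset.py might contain. It only implements the map-style protocol (`__len__` and `__getitem__`). In the project the class would presumably subclass torch.utils.data.Dataset like the CIFAR class does, and the placeholder samples are an assumption made to keep the example runnable.

```python
from typing import Callable, List, Optional, Tuple


class MyDataset:
    """Toy map-style dataset (hypothetical); a real one would read files from `root`."""

    def __init__(self, root: str, train: bool = True,
                 transform: Optional[Callable] = None) -> None:
        self.root = root
        self.train = train
        self.transform = transform
        # Placeholder samples: (feature vector, label) pairs.
        n = 8 if train else 2
        self.samples: List[Tuple[List[float], int]] = [
            ([float(i)] * 4, i % 2) for i in range(n)
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[List[float], int]:
        data, label = self.samples[index]
        if self.transform is not None:
            data = self.transform(data)  # apply the train or val transform
        return data, label


trainset = MyDataset(root="./data", train=True,
                     transform=lambda x: [v / 10 for v in x])
print(len(trainset), trainset[3])
```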

Customize Transform

  • In src/datasets/transforms.py, define your transform function, named my_transform_func, which returns train_transform and val_transform
  • In configs/xxx.yml, please set dataset.transform_name to my_transform_func

Customize Model

  • In the src/models directory, define your customized model in a file such as my_model.py, containing the module class MyModel. Please refer to resnet.py
  • In src/models/model_builder.py, import your model: from .my_model import * in the try block, and from my_model import * in the except block. The fallback is only there to make debugging convenient.
  • In configs/xxx.yml, set model['name'] to MyModel
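
The import in model_builder.py presumably exists so that the name in the config can be resolved to a class. The sketch below shows one way a name-to-class lookup can work. It swaps the wildcard import for an explicit registry so it runs on its own; `MODEL_REGISTRY`, `register` and `load_model` are hypothetical names, not the project's ModelBuilder.

```python
from typing import Callable, Dict

# Hypothetical registry mapping class names to constructors.
MODEL_REGISTRY: Dict[str, Callable] = {}


def register(cls):
    """Class decorator that records the class under its own name."""
    MODEL_REGISTRY[cls.__name__] = cls
    return cls


@register
class MyModel:
    """Stand-in for a model class defined in my_model.py."""

    def __init__(self, num_classes: int = 10) -> None:
        self.num_classes = num_classes


def load_model(name: str, **kwargs):
    # Resolve the configured name to a class and build it with the given kwargs.
    if name not in MODEL_REGISTRY:
        raise KeyError(f"unknown model name: {name}")
    return MODEL_REGISTRY[name](**kwargs)


model = load_model("MyModel", num_classes=100)
print(type(model).__name__, model.num_classes)
```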

Customize Loss Function

  • In the src/losses directory, you can define a customized loss Module, such as CrossEntropyLoss in classification.py.
  • Then import your loss Module in loss_builder.py
  • If your model is supervised by multiple loss functions with different weights, the LossWrapper Module in src/losses/loss_wrapper.py may meet the requirement.
  • In configs/xxx.yml, add your loss names and weights to train.criterion.names and train.criterion.loss_weights respectively.
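
Combining several weighted losses amounts to a weighted sum, with the individual values kept for logging. The following sketch illustrates the arithmetic with plain floats standing in for tensors. `weighted_total` and the two toy loss functions are hypothetical and do not reproduce the LossWrapper implementation.

```python
from typing import Callable, Sequence, Tuple


def weighted_total(
    loss_funcs: Sequence[Callable[[float, float], float]],
    weights: Sequence[float],
    pred: float,
    target: float,
) -> Tuple[float, Tuple[float, ...]]:
    """Return (weighted sum of losses, tuple of individual losses)."""
    if len(loss_funcs) != len(weights):
        raise ValueError("need exactly one weight per loss function")
    losses = tuple(func(pred, target) for func in loss_funcs)
    total = sum(w * value for w, value in zip(weights, losses))
    return total, losses


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


def absolute_error(pred: float, target: float) -> float:
    return abs(pred - target)


# Two losses with weights 1.0 and 0.5.
total, parts = weighted_total([squared_error, absolute_error], [1.0, 0.5],
                              pred=3.0, target=1.0)
print(total, parts)
```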

Customize Optimizer

  • In the src/optimizer directory, define your customized optimizer in optimizers.py. For example, SGD and Adam are already defined; parameters and lr must be specified, and other params are passed through *args and **kwargs. Please refer to:

    # src/optimizer/optimizers.py
    def SGD(parameters, lr, *args, **kwargs) -> optim.Optimizer:
        optimizer = optim.SGD(parameters, lr=lr, *args, **kwargs)
        return optimizer
  • Other parameters, such as weight_decay, can be set in src/optimizer/optimizer_config.yml. Please refer to the YAML below. Writing a value in the 5e-4 format is fine; it is converted in src/optimizer/optimizer_builder.py.

    # src/optimizer/optimizer_config.yml
    SGD:
      momentum: 0.9
      weight_decay: 5e-4
      dampening: 0
      nesterov: False
  • In configs/xxx.yml, set train['lr'], and set train['optimizer'] to SGD
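
PyYAML loads a value written as 5e-4 (no decimal point) as the string '5e-4' rather than a float, which is presumably why the builder converts it. A minimal sketch of such a conversion is shown below. `coerce_numbers` is a hypothetical helper and the dictionary is hand-written to mimic the loaded SGD entry; the real logic in optimizer_builder.py may differ.

```python
from typing import Any, Dict


def coerce_numbers(config: Dict[str, Any]) -> Dict[str, Any]:
    """Convert string values that parse as numbers (e.g. '5e-4') to float."""
    result: Dict[str, Any] = {}
    for key, value in config.items():
        if isinstance(value, str):
            try:
                value = float(value)
            except ValueError:
                pass  # leave genuine strings untouched
        result[key] = value
    return result


# Roughly what a YAML loader might hand back for the SGD entry above.
raw = {"momentum": 0.9, "weight_decay": "5e-4", "dampening": 0, "nesterov": False}
optimizer_kwargs = coerce_numbers(raw)
print(optimizer_kwargs)
```

The resulting dictionary could then be passed on as keyword arguments, for example `SGD(parameters, lr, **optimizer_kwargs)`.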

Customize Schemes

  • In src/schemes/lr_schemes.py, define your learning rate scheme function, named my_scheduler, which requires some params, such as optimizer, epochs and so on.
  • Other params can be specified easily in src/schemes/scheme_config.yml
  • In configs/xxx.yml, set the train.schedule to my_scheduler
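
As an example of the kind of rule a scheme function might implement, the sketch below computes a cosine-annealing multiplier for the base learning rate. `cosine_factor` and `min_factor` are hypothetical names, and the function is written without torch so it runs on its own. A function like this could be wrapped in a lambda and handed to torch.optim.lr_scheduler.LambdaLR as `lr_lambda` inside my_scheduler.

```python
import math


def cosine_factor(epoch: int, max_epoch: int, min_factor: float = 0.0) -> float:
    """Multiplier for the base learning rate at `epoch` under cosine annealing."""
    # Clamp so epochs past max_epoch stay at the minimum factor.
    progress = min(epoch, max_epoch) / max_epoch
    return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + math.cos(math.pi * progress))


base_lr = 0.1
for epoch in (0, 10, 20):
    # The factor decays from 1 at the start to min_factor at max_epoch.
    print(epoch, base_lr * cosine_factor(epoch, max_epoch=20))
```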

Customize Metrics

  • In the src/metrics/ directory, define your metric, such as Accuracy in accuracy.py, which computes the metric from predictions and targets and returns a metric scalar
  • Import your metric in metric_builder.py, for example, from .accuracy import *
  • Multiple metrics are supported: in configs/xxx.yml, add your metrics to train.metric.names. During training, the checkpoint-saving strategy follows train.metrics.key_metric_name in configs/xxx.yml; more details can be found at Customize Checkpoint Saving Strategy
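
A metric takes predictions and targets and returns one scalar. The sketch below shows top-1 accuracy on plain Python lists. The function name and list-based inputs are assumptions for illustration; the project's Accuracy in accuracy.py operates on model outputs and may be written differently.

```python
from typing import Sequence


def accuracy(predictions: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the target label."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    if not targets:
        return 0.0
    correct = 0
    for scores, target in zip(predictions, targets):
        # Index of the largest score is the predicted class.
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == target:
            correct += 1
    return correct / len(targets)


scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
print(accuracy(scores, labels))
```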

Customize Training and Validation Procedure for One Batch

  • In src/controller.py, feel free to build your own training and validation steps
  • The training step returns loss, loss_tuple and output_no_grad. loss_tuple and output_no_grad are used only for logging; whether loss carries a gradient is up to you.

Customize Checkpoint Saving Strategy

  • After each training epoch, a validation epoch is generally performed. Torch-atom’s NetIO in src/utils/netio.py saves the best state dict according to key_metric_name and strategy in configs/xxx.yml
  • Checkpoints can also be saved every save_freq epochs, which can be set in configs/xxx.yml as well
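
The two saving rules above can be sketched as a small decision helper. `BestTracker` is hypothetical, and the strategy values 'max' and 'min' are assumptions about what the config accepts. The sketch only decides when to save and leaves the actual writing of the state dict out.

```python
from typing import Dict, Optional


class BestTracker:
    """Decides when to save, given a key metric and a 'max' or 'min' strategy."""

    def __init__(self, key_metric_name: str, strategy: str = "max",
                 save_freq: Optional[int] = None) -> None:
        if strategy not in ("max", "min"):
            raise ValueError("strategy must be 'max' or 'min'")
        self.key_metric_name = key_metric_name
        self.strategy = strategy
        self.save_freq = save_freq
        self.best: Optional[float] = None

    def is_best(self, metrics: Dict[str, float]) -> bool:
        # Compare the key metric with the best value seen so far.
        value = metrics[self.key_metric_name]
        improved = (
            self.best is None
            or (self.strategy == "max" and value > self.best)
            or (self.strategy == "min" and value < self.best)
        )
        if improved:
            self.best = value
        return improved

    def is_periodic(self, epoch: int) -> bool:
        # `epoch` is counted from 1 here.
        return self.save_freq is not None and epoch % self.save_freq == 0


tracker = BestTracker("Accuracy", strategy="max", save_freq=2)
for epoch, acc in enumerate([0.60, 0.72, 0.70], start=1):
    print(epoch, tracker.is_best({"Accuracy": acc}), tracker.is_periodic(epoch))
```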

Todo

  • DDP training
  • More experiment results
  • More widely-used datasets and models
  • Some visualization code for analysis
    • bad case analysis
    • data augmentation visualization

Acknowledgement

Torch-atom draws ideas from, and is built on, the following projects:

open-mmlab/mmclassification

weiaicunzai/pytorch-cifar100

Citation

If you find this project useful in your research, please consider citing:

@misc{2022torchatom,
    title={Torch-atom: A basic and simple training framework for pytorch},
    author={Baitan Shao},
    howpublished = {\url{https://github.com/shaoeric/torch-atom}},
    year={2022}
}

License

The MIT License

Finally

Please feel free to submit issues.
