A basic, simple training framework for PyTorch that is easy to extend.

architecture figure


  • torch==1.7.0+cu110
  • torchvision==0.8.0
  • easydict==1.9
  • tensorboard==2.7
  • tensorboardX==2.4
  • PyYAML==5.3.1


  • Any module can be easily customized
  • Not abstract, easy to learn, develop and debug
  • Cuts down on repetitive work, making the training process easier to control in development and research
  • Friendly to multi-network interactive training, such as GAN, transfer learning, knowledge distillation, etc.


python --model resnet18 --save_dir cifar100_resnet18 --config_path ./configs/cifar100.yml



All CIFAR100 log files can be downloaded as a package here:

Network        Accuracy (%)   Log
resnet18       76.46          code: ewnd
resnet34       77.23          code: dq4r
resnet50       76.82          code: 1e62
resnet101      77.32          code: myfv
vgg11_bn       70.52          code: 2pun
vgg13_bn       73.71          code: 4vmm
mobilenetV2    68.99          code: e93w
shufflenet     71.17          code: lnvy
shufflenetV2   71.16          code: vmi6


All CIFAR10 log files can be downloaded as a package here:

Network        Accuracy (%)   Log
resnet18       94.92          code: a20j
resnet34       94.80          code: q8h1
resnet50       94.81          code: f3wr
resnet101      95.45          code: d3i8
vgg11_bn       92.21          code: di45
vgg13_bn       93.74          code: su1z
mobilenetV2    90.92          code: todf
shufflenet     92.06          code: 1xr2
shufflenetV2   91.61          code: 8swu

Demo Train

  • Import modules

    from src import DatasetBuilder, TransformBuilder, ModelBuilder, LossBuilder, LossWrapper, OptimizerBuilder, SchedulerBuilder, MetricBuilder
    from torch.utils.data import DataLoader
    from src import Controller
    from src.utils import AverageMeter
  • Load your dataloader and transform

    transform_name = 'cifar100_transform'  # your transform function name
    dataset_name = 'CIFAR100'  # your dataset class name
    train_transform, val_transform = TransformBuilder.load(transform_name)
    trainset, trainset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=train_transform, train=True)
    valset, valset_config = DatasetBuilder.load(dataset_name=dataset_name, transform=val_transform, train=False)
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)
  • Load model, loss wrapper, metrics, optimizer, schemes and controller

    epochs = 20
    model = ModelBuilder.load("resnet18", num_classes=100)
    # if categorical weights: LossBuilder.load("CrossEntropyLoss", weight=torch.tensor([1.0, 2.0]).float())
    loss_func1 = LossBuilder.load("CrossEntropyLoss")
    # Only one loss function, cross entropy, set its weight to 1.0
    loss_wrapper = LossWrapper([loss_func1], [1.0])
    model = model.cuda()
    loss_wrapper = loss_wrapper.cuda()
    metric_functions = [MetricBuilder.load('Accuracy')]
    optimizer, optimizer_param = OptimizerBuilder.load('Adam', model.parameters(), lr=0.1)
    scheduler, scheduler_param = SchedulerBuilder.load("cosine_annealing_lr", optimizer, max_epoch=epochs)
    controller = Controller(loss_wrapper=loss_wrapper, model=model, optimizer=optimizer)
  • Train!

    for epoch in range(epochs):
        # train
        loss_recorder = AverageMeter(type='scalar', name='total loss')
        loss_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["CrossEntropyLoss"])
        metric_list_recorder = AverageMeter(type='tuple', num_scalar=1, names=["Accuracy"])
        for (img, label) in train_loader:
            img = img.cuda()
            label = label.cuda()
            loss, loss_tuple, output_no_grad = controller.train_step(img, label)
            loss_recorder.update(loss.item(), img.size(0))
            loss_list_recorder.update(loss_tuple, img.size(0))
            metrics = tuple([func(output_no_grad, label) for func in metric_functions])
            metric_list_recorder.update(metrics, img.size(0))
        print(f"total loss:{loss_recorder.get_value()} loss_tuple:{loss_list_recorder} metrics:{metric_list_recorder}")
        # eval
        # ...
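
The recorders in the loop above receive the batch size with every update because they keep a sample-weighted running mean. A minimal sketch of that idea follows. ScalarMeter is a hypothetical stand-in written for illustration; it is not the AverageMeter shipped in src.utils, whose internals may differ.

```python
class ScalarMeter:
    """Hypothetical sample-weighted running average (illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.total = 0.0  # sum of value * n over all updates
        self.count = 0    # number of samples seen so far

    def update(self, value: float, n: int = 1) -> None:
        # `value` is a per-batch mean, so weight it by the batch size `n`.
        self.total += value * n
        self.count += n

    def get_value(self) -> float:
        # Return 0.0 before the first update to avoid dividing by zero.
        return self.total / self.count if self.count else 0.0


meter = ScalarMeter("total loss")
meter.update(2.0, n=128)  # a full batch
meter.update(1.0, n=64)   # a smaller final batch
epoch_mean = meter.get_value()  # (2.0 * 128 + 1.0 * 64) / 192
```

Weighting by the batch size keeps the epoch mean correct when the last batch is smaller than the others.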


Customize Dataset

  • In the src/datasets directory, define your customized dataset class.
  • The CIFAR class needs some initialization parameters, such as root, train and download, which can be specified in src/datasets/dataset_config.yml. Note that the transform also needs to be set; details can be found in Customize Transform.
  • In src/datasets/, import your dataset class. For example, if the MyDataset class is defined in the mydataset module, add from .mydataset import *
  • In configs/xxx.yml, set the dataset name to MyDataset
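
The import step matters because the dataset name given in the config has to be resolved to a class at runtime. The sketch below shows one way such a name-based lookup could work. The DATASETS registry, the register decorator and the load function are all hypothetical and are not taken from Torch-atom's DatasetBuilder.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical registry mapping a class name to the class itself.
DATASETS: Dict[str, type] = {}


def register(cls: type) -> type:
    """Record a dataset class under its own name."""
    DATASETS[cls.__name__] = cls
    return cls


@register
class MyDataset:
    def __init__(self, root: str, train: bool = True,
                 transform: Optional[Callable] = None) -> None:
        self.root = root
        self.train = train
        self.transform = transform


def load(dataset_name: str, **config: Any):
    """Instantiate the dataset whose class name matches `dataset_name`."""
    if dataset_name not in DATASETS:
        raise KeyError(f"unknown dataset: {dataset_name!r}")
    return DATASETS[dataset_name](**config)


# `root` and `train` play the role of values read from dataset_config.yml.
trainset = load("MyDataset", root="./data", train=True)
```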

Customize Transform

  • In src/datasets/, define your transform function, named my_transform_func, which returns train_transform and val_transform
  • In configs/xxx.yml, please set dataset.transform_name to my_transform_func

Customize Model

  • In the src/models directory, define your customized model and its module class MyModel.
  • In src/models/, import your model: use from .my_model import * in the try block and from my_model import * in the except block. The fallback exists only to make debugging convenient.
  • In configs/xxx.yml, set the model['name'] to MyModel

Customize Loss Function

  • In the src/losses directory, define your customized loss Module, following the existing CrossEntropyLoss.
  • Then import your loss Module, in the same way as for datasets and models.
  • If your model is supervised by multiple loss functions with different weights, the LossWrapper Module in src/losses/ is meant for that case.
  • In configs/xxx.yml, add your loss names and weights to train.criterion.names and train.criterion.loss_weights respectively.
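
With several weighted criteria, the total loss is the weighted sum of the individual terms, and the individual terms are kept for logging. The sketch below shows this on plain floats. weighted_loss is a hypothetical helper and not the LossWrapper Module itself; returning the unweighted terms in the tuple is an assumption made here.

```python
from typing import Callable, Sequence, Tuple

LossFn = Callable[[float, float], float]


def weighted_loss(loss_fns: Sequence[LossFn], weights: Sequence[float],
                  output: float, target: float) -> Tuple[float, Tuple[float, ...]]:
    """Return (weighted total, tuple of the individual unweighted terms)."""
    if len(loss_fns) != len(weights):
        raise ValueError("need exactly one weight per loss function")
    terms = tuple(fn(output, target) for fn in loss_fns)
    total = sum(w * t for w, t in zip(weights, terms))
    return total, terms


def squared_error(output: float, target: float) -> float:
    return (output - target) ** 2


def absolute_error(output: float, target: float) -> float:
    return abs(output - target)


# Two criteria with weights 1.0 and 0.5, evaluated on one scalar prediction.
total, terms = weighted_loss([squared_error, absolute_error], [1.0, 0.5],
                             output=3.0, target=1.0)
```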

Customize Optimizer

  • In the src/optimizer directory, define your customized optimizer alongside the existing ones. SGD and Adam are already defined there: parameters and lr must be specified, and any other parameters are passed through *args and **kwargs, as in the example below.

    # src/optimizer/
    from torch import optim

    def SGD(parameters, lr, *args, **kwargs) -> optim.Optimizer:
        # Pass lr positionally so that extra positional args cannot collide with it
        optimizer = optim.SGD(parameters, lr, *args, **kwargs)
        return optimizer
  • Other parameters, such as weight_decay, can be set in src/optimizer/optimizer_config.yml, as in the YAML below. Values written in the 5e-4 format are fine; they are converted in src/optimizer/

    # src/optimizer/optimizer_config.yml
      momentum: 0.9
      weight_decay: 5e-4
      dampening: 0
      nesterov: False
  • In configs/xxx.yml, set the train['lr'], and set the train['optimizer'] to SGD
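
The conversion mentioned above is needed because PyYAML follows the YAML 1.1 float pattern, which requires a decimal point, so 5e-4 is loaded as the string '5e-4' rather than as a float. A minimal sketch of such a conversion follows, assuming the parsed config is a flat dict. coerce_numeric is a hypothetical helper and may differ from what the code in src/optimizer/ actually does.

```python
from typing import Any, Dict


def coerce_numeric(params: Dict[str, Any]) -> Dict[str, Any]:
    """Turn numeric-looking strings such as '5e-4' into floats."""
    converted: Dict[str, Any] = {}
    for key, value in params.items():
        if isinstance(value, str):
            try:
                value = float(value)
            except ValueError:
                pass  # leave non-numeric strings untouched
        converted[key] = value
    return converted


# What a YAML loader might hand back for the SGD settings shown above.
raw = {"momentum": 0.9, "weight_decay": "5e-4", "dampening": 0, "nesterov": False}
sgd_kwargs = coerce_numeric(raw)
```

Only string values are touched, so booleans and numbers that were parsed correctly pass through unchanged.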

Customize Schemes

  • In src/schemes/, define your learning rate scheme function, named my_scheduler, which requires some params, such as optimizer, epochs and so on.
  • Other params can be specified easily in src/schemes/scheme_config.yml
  • In configs/xxx.yml, set the train.schedule to my_scheduler
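
For intuition about what a scheme function produces, the closed-form cosine annealing schedule can be written in plain Python. This is a sketch of the standard formula and not the code behind cosine_annealing_lr in src/schemes/. The names cosine_lr and eta_min are chosen here for illustration.

```python
import math


def cosine_lr(base_lr: float, epoch: int, max_epoch: int,
              eta_min: float = 0.0) -> float:
    """Learning rate at `epoch`, decaying from base_lr to eta_min over max_epoch."""
    cos_term = 1.0 + math.cos(math.pi * epoch / max_epoch)
    return eta_min + 0.5 * (base_lr - eta_min) * cos_term


# One value per epoch boundary, from the start (0) to the end (20) of training.
schedule = [cosine_lr(0.1, epoch, max_epoch=20) for epoch in range(21)]
```

The schedule starts at base_lr, passes through the midpoint between base_lr and eta_min at half of max_epoch, and ends at eta_min.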

Customize Metrics

  • In the src/metrics/ directory, define your metric, such as Accuracy, which computes the metric from predictions and targets and returns a scalar
  • Import your metric, for example, from .accuracy import *
  • Multiple metrics are supported: in configs/xxx.yml, add your metrics to train.metric.names. During training, the checkpoint saving strategy refers to train.metrics.key_metric_name; more details can be found in Customize Checkpoint Saving Strategy
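
A metric, as described above, takes predictions and targets and returns a single scalar. The sketch below shows that contract on plain Python lists. This accuracy function is hypothetical and is not the Accuracy metric in src/metrics/, which presumably operates on tensors.

```python
from typing import Sequence


def accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the target."""
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    if len(targets) == 0:
        return 0.0
    correct = 0
    for row, target in zip(scores, targets):
        # Index of the largest score in this row (first one wins on ties).
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
targets = [1, 0, 0, 0]
acc = accuracy(scores, targets)  # three of the four predictions match
```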

Customize Training and Validation Procedure for One Batch

  • In src/, feel free to build your own training and validation steps
  • The training step returns loss, loss_tuple and output_no_grad. loss_tuple and output_no_grad are only used for logging; whether loss carries a gradient is up to you.

Customize Checkpoint Saving Strategy

  • In general, a validation epoch is run after each training epoch. Torch-atom’s NetIO in src/utils/ saves the best state dict according to the key_metric_name and strategy set in configs/xxx.yml
  • Checkpoints can also be saved every save_freq epochs, which can be set in configs/xxx.yml as well
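
The two saving rules above can be sketched as a small tracker that decides, per epoch, whether the key metric improved and whether a periodic save is due. BestTracker and its "max"/"min" strategy values are hypothetical; this is not NetIO, and the strategy names accepted by Torch-atom may differ.

```python
from typing import Optional


class BestTracker:
    """Hypothetical helper deciding when a checkpoint should be written."""

    def __init__(self, strategy: str = "max", save_freq: int = 0) -> None:
        if strategy not in ("max", "min"):
            raise ValueError("strategy must be 'max' or 'min'")
        self.strategy = strategy
        self.save_freq = save_freq  # 0 disables periodic saving
        self.best: Optional[float] = None

    def is_best(self, value: float) -> bool:
        # The first value always counts as an improvement.
        if self.best is None:
            improved = True
        elif self.strategy == "max":
            improved = value > self.best
        else:
            improved = value < self.best
        if improved:
            self.best = value
        return improved

    def is_periodic(self, epoch: int) -> bool:
        # `epoch` is 0-based; save after every save_freq-th completed epoch.
        return self.save_freq > 0 and (epoch + 1) % self.save_freq == 0


tracker = BestTracker(strategy="max", save_freq=2)
val_accuracy = [0.60, 0.72, 0.70, 0.75]
best_epochs = [e for e, v in enumerate(val_accuracy) if tracker.is_best(v)]
periodic_epochs = [e for e in range(len(val_accuracy)) if tracker.is_periodic(e)]
```

A "min" strategy would suit a key metric such as validation loss, where lower is better.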


  • DDP training
  • More experiment results
  • More widely-used datasets and models
  • Some visualization code for analysis
    • bad case analysis
    • data augmentation visualization


Torch-atom draws ideas from, and was developed based on, the following projects:




If you find this project useful in your research, please consider citing it:

    title={Torch-atom: A basic and simple training framework for pytorch},
    author={Baitan Shao},
    howpublished = {\url{}},


The MIT License | Open Source Initiative


Please feel free to submit issues.

