MIRO: Mutual Information Regularization with Oracle

Official PyTorch implementation of Domain Generalization by Mutual-Information Regularization with Pre-trained Models.

Junbum Cha, Kyungjae Lee, Sungrae Park, Sanghyuk Chun.

Preparation

Dependencies

pip install -r requirements.txt

Datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path

Environments

The main experiments used the environment listed below. Every main experiment was run on a single NVIDIA V100 GPU.

Environment:
	Python: 3.7.7
	PyTorch: 1.7.1
	Torchvision: 0.8.2
	CUDA: 10.1
	CUDNN: 7603
	NumPy: 1.21.4
	PIL: 7.2.0
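
To compare a local setup against the versions listed above, a short script along these lines can print the same fields. This is a minimal sketch, not part of the repository: the function name `collect_versions` is hypothetical, and packages that are not installed are reported as such rather than raising an error.

```python
import importlib
import sys


def collect_versions():
    """Return a mapping of component name to installed version string."""
    versions = {"Python": sys.version.split()[0]}
    # (label shown in the README, importable module name)
    packages = [
        ("PyTorch", "torch"),
        ("Torchvision", "torchvision"),
        ("NumPy", "numpy"),
        ("PIL", "PIL"),
    ]
    for label, module_name in packages:
        try:
            module = importlib.import_module(module_name)
            versions[label] = str(getattr(module, "__version__", "unknown"))
        except ImportError:
            versions[label] = "not installed"
    # CUDA and cuDNN versions are reported through PyTorch.
    try:
        import torch

        versions["CUDA"] = str(torch.version.cuda)
        versions["CUDNN"] = str(torch.backends.cudnn.version())
    except ImportError:
        versions["CUDA"] = "not installed"
        versions["CUDNN"] = "not installed"
    return versions


if __name__ == "__main__":
    print("Environment:")
    for name, version in collect_versions().items():
        print(f"\t{name}: {version}")
```

Matching versions exactly is not necessarily required; the list above only records what the main experiments used.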

How to Run

The train_all.py script conducts multiple leave-one-out cross-validations, one for each target domain.

python train_all.py exp_name --dataset PACS --data_dir /my/dataset/path --algorithm MIRO
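
The leave-one-out protocol can be illustrated with a small sketch, assuming it means that each domain is held out once as the target while the remaining domains are used for training. The helper `leave_one_out_splits` is hypothetical and only enumerates the splits; how train_all.py names or indexes domains internally is not shown in this README.

```python
def leave_one_out_splits(domains):
    """Yield (target, sources) pairs, holding out each domain exactly once."""
    for i, target in enumerate(domains):
        sources = domains[:i] + domains[i + 1:]
        yield target, sources


# The four PACS domains, written out by name for readability.
PACS_DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]

for target, sources in leave_one_out_splits(PACS_DOMAINS):
    print(f"target={target:<12} sources={sources}")
```

Under this reading, a dataset with four domains yields four training runs per trial seed.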

Main experiments

Run command with hyperparameters (HPs):

python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO \
    --dataset PACS \
    --lr 3e-5 \
    --resnet_dropout 0.0 \
    --weight_decay 0.0 \
    --ld 0.01 \
    --trial_seed 0

Our searched HPs:

                PACS    VLCS    OfficeHome  TerraIncognita  DomainNet
Learning rate   3e-5    1e-5    3e-5        3e-5            3e-5
Dropout         0.0     0.5     0.1         0.0             0.1
Weight decay    0.0     1e-6    1e-6        1e-4            0.0
Lambda          0.01    0.01    0.1         0.1             0.1
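
The searched HPs map onto the command-line flags used in the run command above: learning rate is --lr, dropout is --resnet_dropout, weight decay is --weight_decay, and lambda is --ld. The sketch below turns the table into a command string for a given dataset. The names `SEARCHED_HPS` and `build_command` are hypothetical helpers, and using the dataset name as the experiment name is an assumption made for illustration.

```python
import shlex

# Searched HPs from the table, keyed by the flag names used in the run command.
SEARCHED_HPS = {
    "PACS": {"lr": "3e-5", "resnet_dropout": "0.0", "weight_decay": "0.0", "ld": "0.01"},
    "VLCS": {"lr": "1e-5", "resnet_dropout": "0.5", "weight_decay": "1e-6", "ld": "0.01"},
    "OfficeHome": {"lr": "3e-5", "resnet_dropout": "0.1", "weight_decay": "1e-6", "ld": "0.1"},
    "TerraIncognita": {"lr": "3e-5", "resnet_dropout": "0.0", "weight_decay": "1e-4", "ld": "0.1"},
    "DomainNet": {"lr": "3e-5", "resnet_dropout": "0.1", "weight_decay": "0.0", "ld": "0.1"},
}


def build_command(dataset, data_dir, trial_seed=0):
    """Return the train_all.py command line for one dataset and trial seed."""
    args = [
        "python", "train_all.py", dataset,  # experiment name: dataset name (assumption)
        "--data_dir", data_dir,
        "--algorithm", "MIRO",
        "--dataset", dataset,
    ]
    for flag, value in SEARCHED_HPS[dataset].items():
        args += [f"--{flag}", value]
    args += ["--trial_seed", str(trial_seed)]
    # Quote each argument so paths with spaces survive copy-pasting into a shell.
    return " ".join(shlex.quote(arg) for arg in args)


if __name__ == "__main__":
    for name in SEARCHED_HPS:
        print(build_command(name, "/my/dataset/path"))
```

Printing the commands instead of running them keeps the sketch free of side effects; the output can be pasted into a shell or a job script.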

Combination with SWAD

Set --swad True to combine with SWAD.

python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO \
    --dataset PACS \
    --ld 0.01 \
    --swad True \
    --trial_seed 0

Experiments with various backbones

You can run MIRO with different backbones via the --model parameter:

# model is one of [resnet50, resnet50_barlowtwins, resnet50_moco, clip_resnet50, clip_vit-b16, swag_regnety_16gf]
python train_all.py exp_name --data_dir /my/dataset/path --algorithm MIRO \
    --dataset PACS --model resnet50

Prepare the checkpoint before running with MoCo v3 (resnet50_moco). The ResNet-50 MoCo v3 checkpoint trained for 1000 epochs can be downloaded here.

Reproduce the main results of the paper

We provide the commands to reproduce the main results of the paper (Table 1). Note that every result is averaged over three trials: run each command with --trial_seed set to 0, 1, and 2, then average the results.
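
The averaging step can be sketched as follows, assuming the per-seed accuracy has already been read from each run's output. This README does not describe where train_all.py writes its results, so the dictionary below is filled in by hand; `average_trials` is a hypothetical helper and the numbers are placeholders, not results from the paper.

```python
from statistics import mean, stdev

TRIAL_SEEDS = (0, 1, 2)


def average_trials(accuracy_by_seed):
    """Return (mean, sample standard deviation) over the three trial seeds."""
    missing = [seed for seed in TRIAL_SEEDS if seed not in accuracy_by_seed]
    if missing:
        raise ValueError(f"missing results for trial seeds: {missing}")
    values = [accuracy_by_seed[seed] for seed in TRIAL_SEEDS]
    return mean(values), stdev(values)


# Placeholder accuracies (in %) for illustration only; not real results.
example = {0: 80.0, 1: 82.0, 2: 84.0}
avg, std = average_trials(example)
print(f"{avg:.1f} +/- {std:.1f}")
```

Raising an error when a seed is missing avoids silently reporting an average over fewer than three trials.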

Main experiment (65.9% by MIRO)

python train_all.py PACS --data_dir /my/dataset/path --algorithm MIRO --dataset PACS --lr 3e-5 --resnet_dropout 0.0 --weight_decay 0.0 --ld 0.01
python train_all.py VLCS --data_dir /my/dataset/path --algorithm MIRO --dataset VLCS --lr 1e-5 --resnet_dropout 0.5 --weight_decay 1e-6 --ld 0.01
python train_all.py OfficeHome --data_dir /my/dataset/path --algorithm MIRO --dataset OfficeHome --lr 3e-5 --resnet_dropout 0.1 --weight_decay 1e-6 --ld 0.1
python train_all.py TerraIncognita --data_dir /my/dataset/path --algorithm MIRO --dataset TerraIncognita --lr 3e-5 --resnet_dropout 0.0 --weight_decay 1e-4 --ld 0.1
python train_all.py DomainNet --data_dir /my/dataset/path --algorithm MIRO --dataset DomainNet --lr 3e-5 --resnet_dropout 0.1 --weight_decay 0.0 --ld 0.1

Combination with SWAD (68.1% by MIRO + SWAD)

python train_all.py PACS --data_dir /my/dataset/path --algorithm MIRO --dataset PACS --ld 0.01 --swad True
python train_all.py VLCS --data_dir /my/dataset/path --algorithm MIRO --dataset VLCS --ld 0.01 --checkpoint_freq 50 --tolerance_ratio 0.2 --swad True
python train_all.py OfficeHome --data_dir /my/dataset/path --algorithm MIRO --dataset OfficeHome --ld 0.1 --swad True
python train_all.py TerraIncognita --data_dir /my/dataset/path --algorithm MIRO --dataset TerraIncognita --ld 0.1 --swad True
python train_all.py DomainNet --data_dir /my/dataset/path --algorithm MIRO --dataset DomainNet --ld 0.1 --checkpoint_freq 500 --swad True

Pushing the limits (77.3% by MIRO + SWAD + SWAG)

python train_all.py PACS --data_dir /my/dataset/path --algorithm MIRO --dataset PACS --ld 0.01 --model swag_regnety_16gf --batch_size 16 --swad True
python train_all.py VLCS --data_dir /my/dataset/path --algorithm MIRO --dataset VLCS --ld 0.01 --checkpoint_freq 50 --tolerance_ratio 0.2 --model swag_regnety_16gf --batch_size 16 --swad True
python train_all.py OfficeHome --data_dir /my/dataset/path --algorithm MIRO --dataset OfficeHome --ld 0.01 --model swag_regnety_16gf --batch_size 16 --swad True
python train_all.py TerraIncognita --data_dir /my/dataset/path --algorithm MIRO --dataset TerraIncognita --ld 0.01 --model swag_regnety_16gf --batch_size 16 --swad True
python train_all.py DomainNet --data_dir /my/dataset/path --algorithm MIRO --dataset DomainNet --ld 0.1 --checkpoint_freq 500 --model swag_regnety_16gf --batch_size 16 --swad True

Main Results

Citation

@article{cha2022miro,
  title={Domain Generalization by Mutual-Information Regularization with Pre-trained Models},
  author={Junbum Cha and Kyungjae Lee and Sungrae Park and Sanghyuk Chun},
  journal={arXiv preprint arXiv:2203.10789},
  year={2022}
}

License

This project is released under the MIT license, included here.

This project includes some code from facebookresearch/DomainBed (MIT license) and khanrc/swad (MIT license).
