Vision-aided GAN
video (3m) | website | paper
Can the collective knowledge from a large bank of pretrained vision models be leveraged to improve GAN training? If so, with so many models to choose from, which one(s) should be selected, and in what manner are they most effective?
We find that pretrained computer vision models can significantly improve performance when used in an ensemble of discriminators. We propose an effective selection mechanism: probing the linear separability of real and fake samples in each pretrained model's embedding space, choosing the most accurate model, and progressively adding it to the discriminator ensemble. Our method improves GAN training in both limited-data and large-scale settings.
Ensembling Off-the-shelf Models for GAN Training
Nupur Kumari, Richard Zhang, Eli Shechtman, Jun-Yan Zhu
arXiv 2112.09130, 2021
Quantitative Comparison
Our method outperforms recent GAN training methods by a large margin, especially in the limited-sample setting. On LSUN Cat, we achieve an FID similar to that of StyleGAN2 trained on the full dataset while using only 0.7% of it. On the full dataset, our method improves FID by 1.5x to 2x on the cat, church, and horse categories of LSUN.
Example Results
Below, we show visual comparisons between the baseline StyleGAN2-ADA and our model (Vision-aided GAN) for the same randomly sampled latent code.
Interpolation Videos
Latent interpolation results of models trained with our method on AnimalFace Cat (160 images), Dog (389 images), and Bridge-of-Sighs (100 photos).
Requirements
- 64-bit Python 3.8 and PyTorch 1.8.0 (or later). See https://pytorch.org/ for PyTorch install instructions.
- CUDA toolkit 11.0 or later.
- Python libraries: see requirements.txt
- The StyleGAN2 code relies heavily on custom PyTorch extensions. For details, please refer to the stylegan2-ada-pytorch repo.
Setting up Off-the-shelf Computer Vision models
CLIP(ViT): we modify model.py to return intermediate features of the transformer model. To set it up, follow these steps:
git clone https://github.com/openai/CLIP.git
mv vision-aided-gan/training/clip_model.py CLIP/clip/model.py
cd CLIP
python setup.py install
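Once installed, the patched CLIP loads through the standard API. A minimal sketch is below; the patched model.py additionally exposes intermediate transformer features for the discriminator, and their exact return format is repo-specific:

import torch
import clip
from PIL import Image

# load the CLIP ViT model with the modified model.py installed above
model, preprocess = clip.load("ViT-B/32", device="cuda")

image = preprocess(Image.open("example.png")).unsqueeze(0).to("cuda")
with torch.no_grad():
    features = model.encode_image(image)  # standard CLIP image embedding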
DINO(ViT): the model is downloaded automatically from torch hub.
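For reference, the torch hub download amounts to the following (assuming the ViT-S/16 variant):

import torch

# DINO ViT-S/16 is fetched from torch hub on first use and cached locally
dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
dino.eval()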
VGG-16: the model is downloaded automatically.
Swin-T(MoBY): Create a “pretrained-models” directory and save the downloaded model there.
Swin-T(Object Detection): follow the steps below for setup. Download the model here and save it in the “pretrained-models” directory.
git clone https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
cd Swin-Transformer-Object-Detection
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
Replace {cu_version} and {torch_version} with your CUDA and PyTorch versions (e.g., cu111 and torch1.8.0 for the requirements above).
python setup.py install
Swin-T(Segmentation): follow the steps below for setup. Download the model here and save it in the “pretrained-models” directory.
git clone https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation.git
cd Swin-Transformer-Semantic-Segmentation
Before installing, remove the assert statement from __init__.py.
python setup.py install
Face Parsing: download the model here and save it in the “pretrained-models” directory.
Face Normals: download the model here and save it in the “pretrained-models” directory.
Datasets
Dataset preparation is the same as in stylegan2-ada-pytorch.
Example setup for LSUN Church
git clone https://github.com/fyu/lsun.git
cd lsun
python3 download.py -c church_outdoor
unzip church_outdoor_train_lmdb.zip
cd ../vision-aided-gan
python dataset_tool.py --source <path-to>/church_outdoor_train_lmdb/ --dest <path-to-datasets>/church1k.zip --max-images 1000 --transform=center-crop --width=256 --height=256
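To sanity-check the output, you can count the entries in the generated zip (a quick one-liner; the path is the destination used above):

python -c "import zipfile; print(len(zipfile.ZipFile('<path-to-datasets>/church1k.zip').namelist()))"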
Datasets can be downloaded from their respective websites:
FFHQ, LSUN Categories, AFHQ, AnimalFace Dog, AnimalFace Cat, 100-shot Bridge-of-Sighs
Training new networks
Model selection: returns the computer vision model with the highest linear-probe accuracy, evaluated for the best-FID model in a folder or for the given network file.
python model_selection.py --data mydataset.zip --network <mynetworkfolder or mynetworkpklfile>
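For intuition, the selection criterion (linear separability of real vs. generated samples in a frozen pretrained embedding) can be sketched as follows; embed_fn and model_bank are illustrative placeholders, not the repo's actual API:

import numpy as np
from sklearn.linear_model import LogisticRegression

def probe_accuracy(embed_fn, real_imgs, fake_imgs):
    # embed real and generated samples with a frozen pretrained model
    feats = np.concatenate([embed_fn(real_imgs), embed_fn(fake_imgs)])
    labels = np.concatenate([np.ones(len(real_imgs)), np.zeros(len(fake_imgs))])
    # hold out half of the samples to measure probe accuracy
    idx = np.random.permutation(len(feats))
    train, val = idx[:len(idx) // 2], idx[len(idx) // 2:]
    clf = LogisticRegression(max_iter=1000).fit(feats[train], labels[train])
    return clf.score(feats[val], labels[val])

# choose the pretrained model whose embedding best separates real from fake:
# best = max(model_bank, key=lambda m: probe_accuracy(m, real_imgs, fake_imgs))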
Example command for training from scratch with a single pretrained network:
python train.py --outdir=training-models/ --data=mydataset.zip --gpus 2 --metrics fid50k_full --kimg 25000 --cfg ffhq1k --cv input-dino-output-conv_multi_level --cv-loss multilevel_s --augcv ada --ada-target-cv 0.3 --augpipecv bgc --batch 16 --mirror 1 --aug ada --augpipe bgc --snap 25 --warmup 1
Training configuration corresponding to training with vision-aided-loss (a minimal sketch of this loss follows the list):
- --cv=input-dino-output-conv_multi_level: pretrained network and its configuration.
- --warmup=0: should be enabled when training from scratch; introduces our loss after training with 500k images.
- --cv-loss=multilevel: which loss to use on the pretrained-model-based discriminator.
- --augcv=ada: performs ADA augmentation on the pretrained-model-based discriminator.
- --augpipecv=bgc: ADA augmentation strategy.
- --ada-target-cv=0.3: adjusts the ADA target value for the pretrained-model-based discriminator.
- --exact-resume=0: enables exact resume along with the optimizer state.
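At its core, each vision-aided discriminator is a frozen pretrained feature extractor with a small trainable head, and its GAN loss is added to the original StyleGAN2 discriminator loss. A minimal single-level sketch follows; the repo's actual heads, multi-level losses, and wiring differ, and backbone and feat_dim are placeholders:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VisionAidedD(nn.Module):
    def __init__(self, backbone, feat_dim):
        super().__init__()
        self.backbone = backbone            # frozen pretrained model
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        self.head = nn.Linear(feat_dim, 1)  # small trainable head

    def forward(self, x):
        with torch.no_grad():
            f = self.backbone(x)            # frozen features
        return self.head(f)                 # real/fake logit

# non-saturating GAN loss on the extra discriminator, added to the
# original StyleGAN2 discriminator loss during training
def vision_aided_d_loss(cv_d, real, fake):
    return F.softplus(-cv_d(real)).mean() + F.softplus(cv_d(fake)).mean()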
Miscellaneous configurations:
- --appendname='': additional string to append to the training directory name.
- --wandb-log=0: enables wandb logging.
- --clean=0: enables FID calculation using clean-fid if the real distribution statistics are pre-calculated (see the sketch below).
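For reference, pre-calculating the real-distribution statistics with clean-fid looks roughly like this; the stats name and folder paths are placeholders:

from cleanfid import fid

# one-time: cache statistics of the real dataset under a custom name
fid.make_custom_stats("mydataset-stats", "<path-to-real-images>", mode="clean")

# later: score generated images against the cached statistics
score = fid.compute_fid("<path-to-generated-images>", dataset_name="mydataset-stats",
                        mode="clean", dataset_split="custom")
print(score)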
Pretrained models can be downloaded at this link.
Acknowledgments
We thank Muyang Li, Sheng-Yu Wang, Chonghyuk (Andrew) Song for proofreading the draft. We are also grateful to Alexei A. Efros, Sheng-Yu Wang, Taesung Park, and William Peebles for helpful comments and discussion. Our codebase is built on stylegan2-ada-pytorch and DiffAugment.