GAN Ensembling
Ensembling with Deep Generative Views.
Lucy Chai, Jun-Yan Zhu, Eli Shechtman, Phillip Isola, Richard Zhang
CVPR 2021
Prerequisites
- Linux
- Python 3
- NVIDIA GPU + CUDA CuDNN
Table of Contents:
- Colab - run a limited demo version without local installation
- Setup - download required resources
- Quickstart - short demonstration code snippet
- Notebooks - jupyter notebooks for visualization
- Pipeline - details on full pipeline
We project an input image into the latent space of a pre-trained GAN and perturb it slightly to obtain modifications of the input image. These alternative views from the GAN are ensembled at test-time, together with the original image, in a downstream classification task.
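As a rough illustration of test-time ensembling, the sketch below combines a classifier's softmax outputs on the original image with the average of its outputs on the GAN-generated views, using a single mixing weight. It assumes the ensemble is a weighted average of class probabilities. The function names and the alpha parameter are hypothetical and not part of this repository's API, and logits are plain Python lists rather than tensors.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one vector of class logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(original_logits: Sequence[float],
                     view_logits: Sequence[Sequence[float]],
                     alpha: float = 0.5) -> Tuple[List[float], int]:
    """Combine predictions on the original image and its GAN-generated views.

    Returns alpha * p(original) + (1 - alpha) * mean_k p(view_k) and its argmax.
    alpha is a hypothetical weight: 1.0 uses only the original image,
    0.0 uses only the GAN views.
    """
    if not view_logits:
        raise ValueError("need at least one GAN view")
    p_orig = softmax(original_logits)
    view_probs = [softmax(v) for v in view_logits]
    n_classes = len(p_orig)
    # average the class probabilities over all GAN views
    p_views = [sum(p[c] for p in view_probs) / len(view_probs) for c in range(n_classes)]
    combined = [alpha * po + (1 - alpha) * pv for po, pv in zip(p_orig, p_views)]
    label = max(range(n_classes), key=combined.__getitem__)
    return combined, label


# Usage: two-class logits for the original image and for two GAN views.
probs, label = ensemble_predict([2.0, 0.0], [[0.5, 1.0], [0.0, 1.5]], alpha=0.5)
```

Averaging probabilities rather than logits keeps each view's contribution bounded, so a single badly reconstructed view cannot dominate the combined prediction.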
To synthesize deep generative views, we first align (Aligned Input) and reconstruct an image by finding the corresponding latent code in StyleGAN2 (GAN Reconstruction). We then investigate different approaches to produce image variations using the GAN, such as style-mixing on fine layers (Style-mix Fine), which predominantly changes color, or coarse layers (Style-mix Coarse), which changes pose.
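The style-mixing idea can be sketched on a W+ latent stored as one style vector per generator layer: the perturbed code keeps the input's vectors on most layers and takes another latent's vectors on a chosen layer range. This is a minimal sketch that assumes a 14-layer W+ code for illustration. The ranges in COARSE_LAYERS and FINE_LAYERS are hypothetical placeholders (the repository defines its own), and style_mix is an illustrative helper, not the repository's perturb_stylemix.

```python
from typing import List, Sequence

# A W+ latent: one style vector per generator layer.
Latent = List[List[float]]

# Hypothetical layer ranges for a 14-layer W+ code; the repo sets its own.
COARSE_LAYERS = range(0, 4)   # early layers (pose)
FINE_LAYERS = range(8, 14)    # late layers (mostly color)


def style_mix(w_input: Latent, w_other: Latent, layers: Sequence[int]) -> Latent:
    """Return a copy of w_input whose style vectors at `layers` come from w_other."""
    if len(w_input) != len(w_other):
        raise ValueError("latents must have the same number of layers")
    mixed = [list(vec) for vec in w_input]  # copy so the input latent is not modified
    for i in layers:
        mixed[i] = list(w_other[i])
    return mixed


# Usage: mix a second latent into the fine or the coarse layers of the input latent.
num_layers, dim = 14, 4
w_input = [[0.0] * dim for _ in range(num_layers)]
w_random = [[1.0] * dim for _ in range(num_layers)]
w_fine = style_mix(w_input, w_random, FINE_LAYERS)
w_coarse = style_mix(w_input, w_random, COARSE_LAYERS)
```

Decoding w_fine or w_coarse with the generator would then give the color-varied or pose-varied view, respectively.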
Colab
This Colab Notebook demonstrates the basic latent code perturbation and classification procedure in a simplified setting on the aligned cat dataset.
Setup
- Clone this repo:
  git clone https://github.com/chail/gan-ensembling.git
  cd gan-ensembling
- Install dependencies:
  - We provide a Conda environment.yml file listing the dependencies. You can create the Conda environment using:
    conda env create -f environment.yml
- Download resources:
  - We provide a script for downloading associated resources.
  - It downloads precomputed latent codes (cat: 291M, car: 121M, celebahq: 1.8G, cifar10: 883M), a subset of trained models (592M), precomputed results (1.3G), and associated libraries.
  - Fetch the resources by running
    bash resources/download_resources.sh
  - Note, Optional: to run the StyleGAN ID-invert models, the models need to be downloaded separately. Follow the directions here to obtain styleganinv_ffhq256_encoder.pth and styleganinv_ffhq256_generator.pth, and place them in models/pretrain.
  - Note, Optional: the download script downloads only a subset of the pretrained models for the demo notebook. For further experiments, the additional pretrained models (total 7.0G) can be downloaded here; they include 40 binary face attribute classifiers and classifiers trained on the different perturbation methods for the remaining datasets.
- Download external datasets:
  - CelebA-HQ: Follow the instructions here to create the CelebA-HQ dataset, and place the CelebA-HQ images in the directory dataset/celebahq/images/images.
  - Cars: This dataset is a subset of Cars196. Download the images from here and the devkit from here. (We subset their training images into train/test/val partitions.) Place the images in the directory dataset/cars/images/images and the devkit in dataset/cars/devkit.
  - The processed and aligned cat images are downloaded with the above resources, and the cifar10 dataset is downloaded via the PyTorch wrapper.
An example of the directory organization is below:
dataset/celebahq/
    images/images/
        000004.png
        000009.png
        000014.png
        ...
    latents/
    latents_idinvert/
dataset/cars/
    devkit/
        cars_meta.mat
        cars_test_annos.mat
        cars_train_annos.mat
        ...
    images/images/
        00001.jpg
        00002.jpg
        00003.jpg
        ...
    latents/
dataset/catface/
    images/
    latents/
dataset/cifar10/
    cifar-10-batches-py/
    latents/
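A quick way to confirm this layout before running any experiments is to check that each expected directory exists. This is a minimal sketch using only the standard library. EXPECTED_DIRS is transcribed from the example tree (the optional latents_idinvert/ directory is left out), and missing_dirs is a hypothetical helper, not part of the repository.

```python
from pathlib import Path
from typing import List

# Directories transcribed from the example layout; latents_idinvert/ is optional and omitted.
EXPECTED_DIRS = [
    "dataset/celebahq/images/images",
    "dataset/celebahq/latents",
    "dataset/cars/devkit",
    "dataset/cars/images/images",
    "dataset/cars/latents",
    "dataset/catface/images",
    "dataset/catface/latents",
    "dataset/cifar10/cifar-10-batches-py",
    "dataset/cifar10/latents",
]


def missing_dirs(repo_root: str = ".") -> List[str]:
    """Return the expected dataset directories that do not exist under repo_root."""
    root = Path(repo_root)
    return [d for d in EXPECTED_DIRS if not (root / d).is_dir()]


if __name__ == "__main__":
    missing = missing_dirs(".")
    if missing:
        print("Missing directories:")
        for d in missing:
            print("  " + d)
    else:
        print("All expected dataset directories are present.")
```

Run it from the repository root. Note that cifar-10-batches-py/ may only appear after the PyTorch wrapper has downloaded CIFAR-10 for the first time.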
Quickstart
Once the datasets and precomputed resources are downloaded, the following code snippet demonstrates how to reconstruct an image with the GAN and perturb it. Additional examples are in notebooks/demo.ipynb.
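import data
from networks import domain_generator

# dataset, generator, and the CelebA-HQ attribute used as the classification label
dataset_name = 'celebahq'
generator_name = 'stylegan2'
attribute_name = 'Smiling'

# validation split; with load_w=True each item also includes its precomputed latent code
val_transform = data.get_transform(dataset_name, 'imval')
dset = data.get_dataset(dataset_name, 'val', attribute_name, load_w=True, transform=val_transform)
generator = domain_generator.define_generator(generator_name, dataset_name)

# pick one example; [None] adds a batch dimension and .cuda() moves it to the GPU
index = 100
original_image = dset[index][0][None].cuda()
latent = dset[index][1][None].cuda()

# reconstruct the image from its latent code
gan_reconstruction = generator.decode(latent)

# sample 4 latent codes (seed 0) and style-mix them into the fine layers of the input latent
mix_latent = generator.seed2w(n=4, seed=0)
perturbed_im = generator.perturb_stylemix(latent, 'fine', mix_latent, n=4)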
Notebooks
Important: First, set up the symlinks required for the notebooks:
bash notebooks/setup_notebooks.sh
and add the conda environment to the Jupyter kernels:
python -m ipykernel install --user --name gan-ensembling
The provided notebooks are:
- notebooks/demo.ipynb: basic usage example
- notebooks/evaluate_ensemble.ipynb: plot classification test accuracy as a function of ensemble weight
- notebooks/plot_precomputed_evaluations.ipynb: generate the figures in the paper
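The accuracy-versus-weight curve can be approximated with a simple sweep. This sketch assumes per-example class probabilities are already available for each original image and, averaged over its GAN views, for its perturbations, and that the ensemble is a weighted average controlled by one weight. accuracy_vs_weight and its arguments are hypothetical names, not the notebook's code.

```python
from typing import List, Sequence


def accuracy_vs_weight(p_orig: Sequence[Sequence[float]],
                       p_views: Sequence[Sequence[float]],
                       labels: Sequence[int],
                       weights: Sequence[float]) -> List[float]:
    """For each weight alpha, accuracy of alpha * p_orig + (1 - alpha) * p_views.

    p_orig[i] and p_views[i] are class-probability vectors for example i;
    p_views[i] is assumed to be already averaged over that example's GAN views.
    """
    if not labels:
        raise ValueError("need at least one example")
    accs = []
    for alpha in weights:
        correct = 0
        for po, pv, y in zip(p_orig, p_views, labels):
            combined = [alpha * a + (1 - alpha) * b for a, b in zip(po, pv)]
            pred = max(range(len(combined)), key=combined.__getitem__)
            correct += int(pred == y)
        accs.append(correct / len(labels))
    return accs


# Usage: two examples, two classes, three candidate weights.
p_orig = [[0.9, 0.1], [0.4, 0.6]]
p_views = [[0.6, 0.4], [0.7, 0.3]]
labels = [0, 0]
curve = accuracy_vs_weight(p_orig, p_views, labels, [0.0, 0.5, 1.0])
```

Plotting curve against the weights gives the kind of accuracy-versus-weight plot the notebook produces; a weight of 1.0 reproduces the original-image-only accuracy.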
Full Pipeline
The full pipeline contains three main parts:
- optimize latent codes
- train classifiers
- evaluate the ensemble of GAN-generated images.
Examples for each step of the pipeline are contained in the following scripts:
bash scripts/optimize_latent/examples.sh
bash scripts/train_classifier/examples.sh
bash scripts/eval_ensemble/examples.sh
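The first pipeline step, optimizing latent codes, searches for a latent whose generated output matches the target image. The toy sketch below shows only that core idea under a strong simplification: the "generator" is a small linear map w -> A w and the loss is squared pixel error, minimized by plain gradient descent. The repository's projection operates on StyleGAN2 with its own losses and settings, so treat this as an illustration of optimization-based inversion, not its implementation; project and matvec are hypothetical helpers.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(A: Matrix, v: Vector) -> Vector:
    """Matrix-vector product A v."""
    return [sum(a * b for a, b in zip(row, v)) for row in A]


def project(A: Matrix, x: Vector, steps: int = 500, lr: float = 0.1) -> Vector:
    """Toy GAN inversion: minimize 0.5 * ||A w - x||^2 over w by gradient descent.

    The linear map w -> A w stands in for a real generator such as StyleGAN2.
    """
    dim = len(A[0])
    w = [0.0] * dim  # start from the zero latent
    for _ in range(steps):
        residual = [g - t for g, t in zip(matvec(A, w), x)]
        # gradient of the loss: A^T (A w - x)
        grad = [sum(A[i][j] * residual[i] for i in range(len(A))) for j in range(dim)]
        w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return w


# Usage: a "target image" produced from a known latent, then recovered by projection.
A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x = matvec(A, [0.5, -1.0])
w_hat = project(A, x)  # approximately [0.5, -1.0]
```

With a real generator the loss is non-convex, which is why practical projection methods typically add better initialization and perceptual losses; the toy case converges because it is a simple quadratic.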
To add to the pipeline:
- Data: in the data/ directory, add the dataset in data/__init__.py and create the dataset class and transformation functions. See data/data_*.py for examples.
- Generator: modify networks/domain_generator.py to add the generator in domain_generator.define_generator. The perturbation ranges for each dataset and generator are specified in networks/perturb_settings.py.
- Classifier: modify networks/domain_classifiers.py to add the classifier in domain_classifiers.define_classifier.
Acknowledgements
We thank the authors of these repositories:
- Gan Seeing for GAN and visualization utilities
- StyleGAN 2 Pytorch for the PyTorch implementation of StyleGAN 2 and pretrained models (license)
- Stylegan 2 ADA Pytorch for the class-conditional StyleGAN 2 CIFAR10 generator (license)
- StyleGAN In-domain inversion for the in-domain stylegan generator and encoder (license)
- Pytorch CIFAR for CIFAR10 classification (license)
- Latent Composition for some code and remaining encoders (license)
- Cat dataset images are from the Oxford-IIIT Pet Dataset (license), aligned using the Frederic landmark detector (license).
Citation
If you use this code for your research, please cite our paper:
@inproceedings{chai2021ensembling,
title={Ensembling with Deep Generative Views.},
author={Chai, Lucy and Zhu, Jun-Yan and Shechtman, Eli and Isola, Phillip and Zhang, Richard},
booktitle={CVPR},
year={2021}
}