few-shot-gan-adaptation

Official repository for Few-shot Image Generation via Cross-domain Correspondence (CVPR '21)

Utkarsh Ojha, Yijun Li, Jingwan Lu, Alexei A. Efros, Yong Jae Lee, Eli Shechtman, Richard Zhang

Adobe Research, UC Davis, UC Berkeley

[Concept figure]
The datasets and generated images used for the evaluations shown in Tables 1 and 2 can be downloaded via the links in the evaluation sections below.

Overview

[Method diagram]

Our method adapts the source GAN so that a one-to-one correspondence is preserved between the source images Gs(z) and the target images Gs→t(z).

Sample images from a model

To generate images from a pre-trained GAN, run the following command:

CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_target model_name

Here, model_name follows the source_target naming convention, e.g. ffhq_sketches. Use the --load_noise option to load the noise vectors used for some of the figures in the paper (Figures 1-4). For example:

CUDA_VISIBLE_DEVICES=0 python generate.py --ckpt_target ffhq_sketches --load_noise noise.pt
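
If you want to generate from your own fixed set of latents instead of the provided noise.pt, something like the following may work. This is only a sketch: the exact tensor format expected by --load_noise (assumed here to be a single [n_sample, 512] tensor saved with torch.save) should be checked against generate.py.

# Sketch: save a fixed batch of latent codes for reuse across runs.
# Assumption (check against generate.py): --load_noise expects a single tensor
# of z vectors, shaped [n_sample, latent_dim], saved with torch.save.
import torch

n_sample, latent_dim = 25, 512         # 25 latents -> a 5x5 image grid
torch.manual_seed(0)                   # fix the seed so the latents are reproducible
z = torch.randn(n_sample, latent_dim)  # z ~ N(0, I), StyleGAN2's latent prior
torch.save(z, "my_noise.pt")           # then pass --load_noise my_noise.pt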

Visualizing correspondence results

To see what the same noise vectors produce in the source and adapted models, i.e. Gs(z) and Gs→t(z), run the following command(s):

# generate two 5x5 image grids, one for the source model and one for the target model
CUDA_VISIBLE_DEVICES=0 python3 generate.py --ckpt_source source_ffhq --ckpt_target ffhq_caricatures --load_noise noise.pt

# visualize the interpolations of source and target
CUDA_VISIBLE_DEVICES=0 python3 generate.py --ckpt_source source_ffhq --ckpt_target ffhq_caricatures --load_noise noise.pt --mode interpolate
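
Conceptually, the interpolate mode walks linearly between pairs of latent codes and decodes every intermediate z with both the source and the adapted generator, so the two rows stay in correspondence frame by frame. The sketch below illustrates that idea only; it is not the repo's generate.py, and it assumes already-loaded rosinality-style StyleGAN2 Generator objects (g_source and g_target are placeholder names).

# Sketch of latent interpolation with corresponding source/target outputs.
# Assumptions: g_source and g_target are loaded rosinality StyleGAN2 Generators,
# callable as g([z]) -> (images, latents); the names are placeholders.
import torch

@torch.no_grad()
def interpolate_pair(g_source, g_target, z_a, z_b, steps=8):
    frames_src, frames_tgt = [], []
    for t in torch.linspace(0.0, 1.0, steps):
        z = (1.0 - t) * z_a + t * z_b   # linear interpolation in z space
        img_s, _ = g_source([z])        # G_s(z)
        img_t, _ = g_target([z])        # G_{s->t}(z), same latent
        frames_src.append(img_s)
        frames_tgt.append(img_t)
    return torch.cat(frames_src), torch.cat(frames_tgt)

# Hypothetical usage: z_a and z_b are two [1, 512] latent codes.
# z_a, z_b = torch.randn(2, 1, 512).unbind(0)
# src_row, tgt_row = interpolate_pair(g_source, g_target, z_a, z_b)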

Hand gesture experiments

We collected images of random hand gestures performed on a plain surface (~18k images) and used them to train a source model from scratch. We then adapted it to two different target domains: landscape images and Google Maps. The goal was to see whether, during inference, interpolating the hand gestures results in meaningful variations in the target images. Run the following commands to see the results:

CUDA_VISIBLE_DEVICES=0 python3 generate.py --ckpt_source source_hand --ckpt_target hand_maps --load_noise noise.pt --mode interpolate
CUDA_VISIBLE_DEVICES=0 python3 generate.py --ckpt_source source_hand --ckpt_target hand_landscapes --load_noise noise.pt --mode interpolate

Evaluating FID

Three sets of images are used to obtain the results in Table 1:

  • A set of real images from a target domain -- Rtest
  • 10 images from the above set (Rtest) used to train the algorithm -- Rtrain
  • 5000 generated images using the GAN-based method -- F

The following table provides links to each of these image sets:

Domain       Rtrain   Rtest   F
Babies       link     link    link
Sunglasses   link     link    link
Sketches     link     link    link

Rtrain is provided just to illustrate what the algorithm sees; it is not used for computing the FID score.

Download and unzip the image sets into a directory of your choice, then compute the FID score (using pytorch-fid) between the real (Rtest) and fake (F) images by running the following command:

python -m pytorch_fid /path/to/real/images /path/to/fake/images
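
If you prefer to compute the score from Python rather than the CLI, pytorch-fid also exposes a function for this; a minimal sketch is below (the folder paths are placeholders for wherever you unzipped Rtest and F, and the signature assumes a recent pytorch-fid release).

# Sketch: compute FID between R_test and F using pytorch-fid's Python API.
# Assumption: a recent pytorch-fid release exposing calculate_fid_given_paths;
# the two folder paths are placeholders.
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = "cuda" if torch.cuda.is_available() else "cpu"
fid = calculate_fid_given_paths(
    ["/path/to/real/images", "/path/to/fake/images"],  # [R_test, F]
    batch_size=50,
    device=device,
    dims=2048,  # standard InceptionV3 pool3 features
)
print(f"FID: {fid:.2f}")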

Evaluating intra-cluster distance

Download the entire set of images from this link (1.1 GB); these are used for the results in Table 2. The collection is organized as follows:

cluster_centers
└── amedeo                  # target domain -- one of [amedeo, sketches]
    └── ours                # baseline -- one of [tgan, tgan_ada, freezeD, ewc, ours]
        └── c0              # cluster id -- there are 10 clusters [c0, c1, ..., c9]
            ├── center.png  # cluster center -- one of the 10 training images; each cluster has its own center
            ├── img0.png    # generated images assigned to this cluster's center according to the LPIPS metric
            ├── img1.png
            └── ...

Unzip the file, and then run the following command to compute the results for a baseline on a dataset:

CUDA_VISIBLE_DEVICES=0 python3 feat_cluster.py --baseline <baseline> --dataset <target_domain> --mode intra_cluster_dist

For example:

CUDA_VISIBLE_DEVICES=0 python3 feat_cluster.py --baseline tgan --dataset sketches --mode intra_cluster_dist
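
For intuition, the intra-cluster distance roughly corresponds to the average LPIPS distance between each cluster center and the generated images assigned to it, averaged over the clusters. The sketch below illustrates that computation using the lpips package and the folder layout shown above; it is not the repo's feat_cluster.py, and the paths and preprocessing (256x256 resize, inputs scaled to [-1, 1]) are assumptions.

# Sketch: mean LPIPS distance between a cluster center and its assigned members.
# NOT the repo's feat_cluster.py -- an illustration using the `lpips` package and
# the cluster_centers/<domain>/<baseline>/c*/ layout shown above (assumed paths).
import glob, os
import torch
import lpips
from PIL import Image
from torchvision import transforms

to_tensor = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),                       # [0, 1]
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # LPIPS expects inputs in [-1, 1]
])
loss_fn = lpips.LPIPS(net="vgg")

def intra_cluster_dist(cluster_dir):
    load = lambda p: to_tensor(Image.open(p).convert("RGB")).unsqueeze(0)
    center = load(os.path.join(cluster_dir, "center.png"))
    members = sorted(glob.glob(os.path.join(cluster_dir, "img*.png")))
    with torch.no_grad():
        dists = [loss_fn(center, load(p)).item() for p in members]
    return sum(dists) / len(dists)

# Hypothetical usage: average over the 10 clusters of one baseline/domain pair.
dirs = sorted(glob.glob("cluster_centers/sketches/tgan/c*"))
print(sum(intra_cluster_dist(d) for d in dirs) / len(dirs))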

We also provide a utility to visualize the closest and farthest members of a cluster, as in Figure 14 (reproduced below), using the following command:

CUDA_VISIBLE_DEVICES=0 python3 feat_cluster.py --baseline tgan --dataset sketches --mode visualize_members

The command saves the generated image closest to a center as closest.png and the one farthest from it as farthest.png.

[Cluster members figure]

Note: We cannot share the images for the caricature domain due to license issues.

Training (adapting) your own GAN

Choose the source domain

  • Only the pre-trained model is needed, i.e. no need for access to the source data.
  • Refer to the first column of the pre-trained models table above.
  • If you wish to use some other source model, make sure that it follows the generator architecture defined in this PyTorch implementation of StyleGAN2 (a rough loading sketch is shown after this list).
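
As a rough sketch of what loading such a source model might look like with the rosinality implementation linked above (the resolution, latent size, mapping depth, checkpoint path, and the "g_ema" key are assumptions you should check against your own checkpoint):

# Sketch: loading a custom source generator with the rosinality stylegan2-pytorch code.
# Assumptions: model.py is rosinality's model file; the checkpoint stores EMA generator
# weights under "g_ema"; the model was trained at 256x256 with a 512-d latent and an
# 8-layer mapping network. Check all of these against your checkpoint.
import torch
from model import Generator  # rosinality stylegan2-pytorch

g_source = Generator(size=256, style_dim=512, n_mlp=8)
ckpt = torch.load("checkpoints/my_source.pt", map_location="cpu")
g_source.load_state_dict(ckpt["g_ema"], strict=False)
g_source.eval()

# Quick sanity check: decode one random latent.
with torch.no_grad():
    img, _ = g_source([torch.randn(1, 512)])
print(img.shape)  # expected: [1, 3, 256, 256]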

Bibtex

If you find our code useful, please cite our paper:

@inproceedings{ojha2021few-shot-gan,
  title={Few-shot Image Generation via Cross-domain Correspondence},
  author={Ojha, Utkarsh and Li, Yijun and Lu, Jingwan and Efros, Alexei A. and Lee, Yong Jae and Shechtman, Eli and Zhang, Richard},
  booktitle={CVPR},
  year={2021}
}

GitHub

https://github.com/utkarshojha/few-shot-gan-adaptation