DatasetGAN
This is the official code and data release for:
DatasetGAN: Efficient Labeled Data Factory with Minimal Human Effort
Yuxuan Zhang*, Huan Ling*, Jun Gao, Kangxue Yin, Jean-Francois Lafleche, Adela Barriuso, Antonio Torralba, Sanja Fidler
Requirements
- Python 3.6 or 3.7 is supported.
- PyTorch 1.4.0+ is recommended.
- This code is tested with the CUDA 10.2 toolkit and cuDNN 7.5.
- Please check the Python package requirements in requirements.txt and install them with pip install -r requirements.txt
Download the dataset from Google Drive and put it in the folder ./datasetGAN/dataset_release. Please be aware that the DatasetGAN dataset is released under the Creative Commons BY-NC 4.0 license by NVIDIA Corporation.
Download the pretrained checkpoint from StyleGAN and convert the TensorFlow checkpoint to PyTorch. Put the checkpoints in the folder ./datasetGAN/dataset_release/stylegan_pretrain. Please be aware that any code dependencies and checkpoints related to StyleGAN are released under the Creative Commons BY-NC 4.0 license by NVIDIA Corporation.
Note: a good example of converting a StyleGAN TensorFlow checkpoint to PyTorch is available at this link.
Training
To reproduce the paper DatasetGAN: Efficient Labeled Data Factory with Minimal Human Effort:
cd datasetGAN
- Run Step 1: Interpreter training.
- Run Step 2: Sampling to generate a massive annotation-image dataset.
- Run Step 3: Train the downstream task.
1. Interpreter Training
python train_interpreter.py --exp experiments/<exp_name>.json
Note: Training on 16 images takes around one hour and requires roughly 160 GB of RAM. One can cache the data returned by the prepare_data function to disk, but this increases training time due to the I/O burden.
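As a hedged sketch of the caching idea above (prepare_data is the function name mentioned in the note, but the wrapper, cache path, and array layout here are illustrative assumptions, not the repository's API):

```python
import os
import numpy as np

def cached_prepare_data(prepare_data, cache_path):
    """Wrap a prepare_data-style function with a simple on-disk cache.

    Assumes prepare_data returns (features, labels) as NumPy arrays.
    cache_path should end in .npz, since np.savez appends that suffix.
    """
    if os.path.exists(cache_path):
        # Reloading from disk trades extra I/O time for a lower RAM peak.
        with np.load(cache_path) as cached:
            return cached["features"], cached["labels"]
    features, labels = prepare_data()
    np.savez(cache_path, features=features, labels=labels)
    return features, labels
```

Saving each array separately with np.save would additionally allow memory-mapped reloading via np.load(..., mmap_mode="r"), which the .npz container does not support.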
Example of the annotation schema for the Face class. Please refer to the paper for other classes.
2. Run GAN Sampling
python train_interpreter.py \
--generate_data True --exp experiments/<exp_name>.json \
--resume [path-to-trained-interpreter-from-step-1] \
--num_sample [num-samples]
To run sampling processes in parallel
sh datasetGAN/script/generate_face_dataset.sh
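The shell script above covers the face experiment. As an illustrative alternative (a sketch, not part of the repository: GPU pinning via CUDA_VISIBLE_DEVICES and the helper names are assumptions, while the flag names are taken from the sampling command above), the parallel runs could also be launched from Python:

```python
import os
import subprocess

def build_sampling_command(exp, resume, num_sample):
    """Build the Step 2 sampling command as an argv list."""
    return [
        "python", "train_interpreter.py",
        "--generate_data", "True",
        "--exp", exp,
        "--resume", resume,
        "--num_sample", str(num_sample),
    ]

def launch_parallel_sampling(exp, resume, samples_per_gpu, gpus):
    """Run one sampling process per GPU, pinned via CUDA_VISIBLE_DEVICES."""
    procs = []
    for gpu in gpus:
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu))
        procs.append(subprocess.Popen(
            build_sampling_command(exp, resume, samples_per_gpu), env=env))
    # Wait for all sampling processes to finish and collect exit codes.
    return [p.wait() for p in procs]
```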
Example of sampling images and annotation:
3. Train Downstream Task
python train_deeplab.py \
--data_path [path-to-generated-dataset-from-step-2] \
--exp experiments/<exp_name>.json
Inference
python test_deeplab_cross_validation.py --exp experiments/face_34.json \
--resume [path-to-downstream-task-checkpoint] --cross_validate True
June 21st Update:
For interpreter training, we changed the upsampling method from nearest to bilinear upsampling and updated the results in Table 1. The table reports mIoU.
![](https://github.com/nv-tlabs/datasetGAN_release/raw/master/figs/new_table.png)
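For reference, mIoU averages per-class intersection-over-union, skipping classes absent from both prediction and ground truth. A minimal NumPy sketch (an illustration of the metric, not the repository's evaluation code):

```python
import numpy as np

def mean_iou(pred, target, num_classes):
    """Mean intersection-over-union over classes.

    pred and target are integer label maps of the same shape; classes
    that appear in neither prediction nor ground truth are skipped.
    """
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        union = np.logical_or(pred_c, target_c).sum()
        if union == 0:
            continue  # class not present; do not penalize it
        intersection = np.logical_and(pred_c, target_c).sum()
        ious.append(intersection / union)
    return float(np.mean(ious))
```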
Citations
Please use the following citation if you use our data or code:
@inproceedings{zhang2021datasetgan,
title={Datasetgan: Efficient labeled data factory with minimal human effort},
author={Zhang, Yuxuan and Ling, Huan and Gao, Jun and Yin, Kangxue and Lafleche, Jean-Francois and Barriuso, Adela and Torralba, Antonio and Fidler, Sanja},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={10145--10155},
year={2021}
}