PyTorch implementation of VAGAN: Visual Feature Attribution Using Wasserstein GANs

This code aims to reproduce results obtained in the paper "Visual Feature Attribution using Wasserstein GANs" (official repo, TensorFlow code)

This repository contains the code to reproduce results for the paper cited above, where the authors present a novel feature attribution technique based on Wasserstein Generative Adversarial Networks (WGAN). The method works for both synthetic (2D) and real 3D neuroimaging data; a brief description of the two datasets is given below.

Anomaly map examples

Here is an example of what the generator/mapper network should produce on synthetic data: an animation with one frame every 50 iterations, where the left side shows the input and the right side shows the anomaly map at iteration 50 * (its + 1), its being the frame index.


Synthetic Dataset

In order to quantitatively evaluate the performance of the examined visual attribution methods, we generated a synthetic dataset of 10000 112x112 images with two classes, which model a healthy control group (label 0) and a patient group (label 1). The images were split evenly across the two categories. We closely followed the synthetic data generation process described in [31][SubCMap: Subject and Condition Specific Effect Maps] where disease effects were studied in smaller cohorts of registered images. The control group (label 0) contained images with random iid Gaussian noise convolved with a Gaussian blurring filter. Examples are shown in Fig. 3. The patient images (label 1) also contained the noise, but additionally exhibited one of two disease effects which was generated from a ground-truth effect map: a square in the centre and a square in the lower right (subtype A), or a square in the centre and a square in the upper left (subtype B). Importantly, both disease subtypes shared the same label. The location of the off-centre squares was randomly offset in each direction by a maximum of 5 pixels. This moving effect was added to make the problem harder, but had no notable effect on the
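The generation process described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the repository's data generator: the blur width (blur_sigma), the square size (20x20, via half=10), the effect intensity, and the nominal square positions (centre at pixel 56, off-centre squares at pixel 84 for subtype A or 28 for subtype B, before the random offset) are hypothetical choices, since the text above does not specify them. make_sample and make_dataset are hypothetical helper names.

```python
import numpy as np

IMG_SIZE = 112  # 112x112 images, as stated above


def gaussian_kernel_1d(sigma, radius):
    """Normalised 1D Gaussian kernel of length 2 * radius + 1."""
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()


def gaussian_blur(img, sigma):
    """Separable Gaussian blur; mode='same' keeps the 112x112 size."""
    k = gaussian_kernel_1d(sigma, radius=int(3 * sigma))
    rows = np.apply_along_axis(lambda r: np.convolve(r, k, mode="same"), 1, img)
    return np.apply_along_axis(lambda c: np.convolve(c, k, mode="same"), 0, rows)


def make_sample(label, subtype, rng, blur_sigma=2.0, half=10, effect=1.0, max_offset=5):
    """Return (image, ground-truth effect map) for one synthetic subject.

    label 0: blurred iid Gaussian noise only (control).
    label 1: noise plus a centre square and an off-centre square, lower right
    (subtype "A") or upper left (subtype "B"), shifted by up to max_offset
    pixels in each direction. Sizes, positions and intensity are hypothetical.
    """
    noise = gaussian_blur(rng.normal(size=(IMG_SIZE, IMG_SIZE)), blur_sigma)
    effect_map = np.zeros((IMG_SIZE, IMG_SIZE))
    if label == 1:
        c = IMG_SIZE // 2
        effect_map[c - half:c + half, c - half:c + half] = effect
        base = 3 * IMG_SIZE // 4 if subtype == "A" else IMG_SIZE // 4
        dy, dx = rng.integers(-max_offset, max_offset + 1, size=2)
        y, x = base + dy, base + dx
        effect_map[y - half:y + half, x - half:x + half] = effect
    return noise + effect_map, effect_map


def make_dataset(n, seed=7):
    """n images split evenly between controls and patients; patient subtypes alternate."""
    rng = np.random.default_rng(seed)
    images, maps, labels = [], [], []
    for i in range(n):
        label = i % 2
        subtype = "A" if (i // 2) % 2 == 0 else "B"
        img, emap = make_sample(label, subtype, rng)
        images.append(img)
        maps.append(emap)
        labels.append(label)
    return np.stack(images), np.stack(maps), np.array(labels)
```

Calling make_dataset(10000) would give the dataset size quoted above; keeping the effect maps alongside the images is what makes a quantitative comparison against predicted anomaly maps possible.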


ADNI Dataset

Currently we have only implemented training on the synthetic dataset; we will work on implementing training on the ADNI dataset as soon as possible (pull requests are welcome, as always). We report the ADNI dataset details below for the sake of completeness.

"We selected 5778 3D T1-weighted MR images from 1288 subjects with either an MCI (label 0) or AD (label 1) diagnosis from the ADNI cohort. 2839 of the images were acquired using a 1.5T magnet, the remainder using a 3T magnet. The subjects are scanned at regular intervals as part of the ADNI study and a number of subjects converted from MCI to AD over the years. We did not use these correspondences for training, however, we took advantage of it for evaluation as will be described later. All images were processed using standard operations available in the FSL toolbox [52][Advances in functional and structural MR image analysis and implementation as FSL] in order to reorient and rigidly register the images to MNI space, crop them and correct for field inhomogeneities. We then skull-stripped the images using the ROBEX algorithm [24][Robust brain extraction across datasets and comparison with publicly available methods]. Lastly, we resampled all images to a resolution of 1.3 mm³ and normalised them to a range from -1 to 1. The final volumes had a size of 128x160x112 voxels."
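The last preprocessing step, rescaling intensities to the range [-1, 1], can be illustrated with a short NumPy function. It assumes a per-volume min-max rescaling; the text above does not say which scheme was used, so treat normalise_to_unit_range (a hypothetical helper) as a sketch only. The earlier steps (FSL registration, ROBEX skull stripping, resampling) are not reproduced here.

```python
import numpy as np

ADNI_SHAPE = (128, 160, 112)  # final volume size reported above


def normalise_to_unit_range(volume):
    """Linearly rescale a volume to [-1, 1] using its own min and max (assumed scheme)."""
    v = np.asarray(volume, dtype=np.float32)
    vmin, vmax = float(v.min()), float(v.max())
    if vmax == vmin:  # constant volume: avoid division by zero
        return np.zeros_like(v)
    return 2.0 * (v - vmin) / (vmax - vmin) - 1.0
```

A per-volume scheme is sensitive to outlier voxels; a percentile-based clip before rescaling is a common alternative, but whether the authors used one is not stated.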

"Data used in preparation of this article were obtained from
the Alzheimer's Disease Neuroimaging Initiative (ADNI) database. As such, the investigators within the ADNI contributed to the design and implementation of ADNI and/or provided data but did not participate in analysis or writing of this report. A complete listing of ADNI investigators can be found at:"



To train the WGAN on this task, cd into this repo's src root folder and execute:

$ python

This script takes the following command line options:

  • dataset_root: the root directory where the dataset is stored, defaults to '../dataset'

  • experiment: directory where samples and models will be saved, defaults to '../samples'

  • batch_size: input batch size, defaults to 32

  • image_size: the height / width of the input image to the network, defaults to 112

  • channels_number: number of input image channels, defaults to 1

  • num_filters_g: number of filters in the first layer of the generator, defaults to 16

  • num_filters_d: number of filters in the first layer of the discriminator, defaults to 16

  • nepochs: number of epochs to train for, defaults to 1000

  • d_iters: number of discriminator iterations per generator iteration, defaults to 5

  • learning_rate_g: learning rate for the generator, defaults to 1e-3

  • learning_rate_d: learning rate for the discriminator, defaults to 1e-3

  • beta1: beta1 for Adam, defaults to 0.0

  • cuda: enables CUDA (store_true flag)

  • manual_seed: seed used to initialize the random number generators, defaults to 7
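For reference, an argparse parser exposing the options listed above could look like the sketch below. The `--option_name` spelling of each flag and the build_parser helper are assumptions; check the training script for the exact names it accepts.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented options and defaults."""
    p = argparse.ArgumentParser(description="VAGAN training (sketch)")
    p.add_argument("--dataset_root", default="../dataset", help="root directory of the dataset")
    p.add_argument("--experiment", default="../samples", help="where samples and models are saved")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--image_size", type=int, default=112)
    p.add_argument("--channels_number", type=int, default=1)
    p.add_argument("--num_filters_g", type=int, default=16)
    p.add_argument("--num_filters_d", type=int, default=16)
    p.add_argument("--nepochs", type=int, default=1000)
    p.add_argument("--d_iters", type=int, default=5, help="critic steps per generator step")
    p.add_argument("--learning_rate_g", type=float, default=1e-3)
    p.add_argument("--learning_rate_d", type=float, default=1e-3)
    p.add_argument("--beta1", type=float, default=0.0, help="Adam beta1")
    p.add_argument("--cuda", action="store_true", help="enable CUDA")
    p.add_argument("--manual_seed", type=int, default=7)
    return p
```

For example, `build_parser().parse_args(["--cuda", "--batch_size", "64"])` enables CUDA and overrides the batch size while keeping every other default.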

Running the command without arguments will train the models with the default hyperparameter values (producing the results shown above).
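A WGAN-style training iteration alternates d_iters critic updates with one generator (mapper) update. The PyTorch sketch below shows one such iteration, assuming the VA-GAN formulation as we understand it from the paper: the mapper M produces an additive map for a patient image x, the critic D is trained to tell x + M(x) apart from real control images, and the mapper is trained to fool the critic while an L1 penalty keeps M(x) small, so that M(x) plays the role of the anomaly map. TinyMapper and TinyCritic are hypothetical stand-ins for the repository's U-Net generator and DCGAN discriminator; the WGAN-GP gradient penalty and the weights lam_l1=100 and lam_gp=10 are assumptions, not values taken from this repository.

```python
import torch
import torch.nn as nn


class TinyMapper(nn.Module):
    """Hypothetical stand-in for the U-Net mapper: outputs an additive map M(x)."""

    def __init__(self, channels=1, nf=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, nf, 3, padding=1), nn.ReLU(),
            nn.Conv2d(nf, channels, 3, padding=1), nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)  # same shape as x


class TinyCritic(nn.Module):
    """Hypothetical stand-in for the DCGAN-style critic: one score per image."""

    def __init__(self, channels=1, nf=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, nf, 4, stride=2, padding=1), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(nf, 1),
        )

    def forward(self, x):
        return self.net(x).view(-1)


def gradient_penalty(critic, real, fake):
    """WGAN-GP penalty on random interpolates between real and fake images (assumed)."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    (grad,) = torch.autograd.grad(critic(interp).sum(), interp, create_graph=True)
    return ((grad.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def train_step(mapper, critic, opt_g, opt_d, x_patient, x_control,
               d_iters=5, lam_l1=100.0, lam_gp=10.0):
    """One VA-GAN-style iteration: d_iters critic updates, then one mapper update."""
    for _ in range(d_iters):
        with torch.no_grad():  # the mapper is not updated during critic steps
            fake = x_patient + mapper(x_patient)
        # Critic: raise scores of real controls, lower scores of generated images
        loss_d = (critic(fake).mean() - critic(x_control).mean()
                  + lam_gp * gradient_penalty(critic, x_control, fake))
        opt_d.zero_grad()
        loss_d.backward()
        opt_d.step()

    # Mapper: make x + M(x) look like a control image while keeping M(x) small (L1)
    m = mapper(x_patient)
    loss_g = -critic(x_patient + m).mean() + lam_l1 * m.abs().mean()
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()
```

The optimizers would be built as, for example, `torch.optim.Adam(mapper.parameters(), lr=1e-3, betas=(0.0, 0.9))`: lr and beta1 match the documented defaults, while beta2 = 0.9 is an assumption. The critic's gradients accumulated during the mapper step are harmless because opt_d.zero_grad() clears them before the next critic update.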


We ported all the models found in the original repository to PyTorch; you can find all implemented models here:

Useful repositories and code

  • vagan-code: Repository from the authors of the reference paper

  • ganhacks: Starter from "How to Train a GAN?" at NIPS2016

  • WassersteinGAN: Code accompanying the paper "Wasserstein GAN"

  • wgan-gp: PyTorch implementation of the paper "Improved Training of Wasserstein GANs"

  • c3d-pytorch: Model used as discriminator in the reference paper

  • Pytorch-UNet: Model used as generator in this repository

  • dcgan: Model used as discriminator in this repository

.bib citation

Cite the paper as follows (copied and pasted from arXiv for you):

  author    = {Christian F. Baumgartner and
               Lisa M. Koch and
               Kerem Can Tezcan and
               Jia Xi Ang and
               Ender Konukoglu},
  title     = {Visual Feature Attribution using Wasserstein GANs},
  journal   = {CoRR},
  volume    = {abs/1711.08998},
  year      = {2017},
  url       = {},
  archivePrefix = {arXiv},
  eprint    = {1711.08998},
  timestamp = {Sun, 03 Dec 2017 12:38:15 +0100},
  biburl    = {},
  bibsource = {dblp computer science bibliography,}