Convex Potential Flows

"Convex Potential Flows: Universal Probability Distributions with Optimal Transport and Convex Optimization" by Chin-Wei Huang, Ricky T. Q. Chen, Christos Tsirigotis, Aaron Courville. In ICLR 2021.

Dependencies:

run pip install -r requirements.txt

Datasets

Experiments

•• Important ••
Unless otherwise specified, the loss (labeled "negative log likelihood") printed during training
is not an actual estimate of the negative log likelihood; it is a "surrogate" loss function explained in the paper:
differentiating this surrogate loss gives a stochastic estimate of the gradient of the true objective.

To clarify:
When the model is in .train() mode, the forward_transform_stochastic function is used, which yields a stochastic estimate of
the "gradient" of the logdet.
When in .eval() mode, a stochastic estimate of the logdet itself (using the Lanczos method) is computed.
The forward_transform_bruteforce function computes the logdet exactly.
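
For intuition, here is a minimal sketch of such a surrogate term (this is not the repository's implementation; F is assumed to be a scalar-output convex potential module and cg_solve a hypothetical conjugate-gradient solver for symmetric positive-definite systems). The value of the returned scalar is not the logdet, but its parameter gradient is, in expectation over v, the gradient of the logdet of the Hessian of F:

import torch

def logdet_grad_surrogate(F, x, cg_solve):
    # Hutchinson-style probe vector with identity covariance
    x = x.requires_grad_(True)
    v = torch.randn_like(x)
    grad_F = torch.autograd.grad(F(x).sum(), x, create_graph=True)[0]

    def hvp(u):
        # Hessian-vector product (grad^2 F(x)) u, differentiable w.r.t. parameters
        return torch.autograd.grad(grad_F, x, grad_outputs=u,
                                   retain_graph=True, create_graph=True)[0]

    a = cg_solve(hvp, v).detach()  # a ~= (grad^2 F)^{-1} v, treated as a constant
    # d/dtheta of a^T (grad^2 F) v equals v^T (grad^2 F)^{-1} (d grad^2 F / dtheta) v,
    # an unbiased estimate of d logdet(grad^2 F) / dtheta
    return (a * hvp(v)).sum()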

As an example, we've used the following context wrapper in train_tabular.py to obtain a likelihood estimate
throughout training:

import contextlib

import torch


@contextlib.contextmanager
def eval_ctx(flow, bruteforce=False, debug=False, no_grad=True):
    """Temporarily switch the flow to evaluation settings, then restore training mode."""
    flow.eval()
    for f in flow.flows[1::2]:  # the layers with a no_bruteforce switch (the convex potential flows)
        f.no_bruteforce = not bruteforce
    torch.autograd.set_detect_anomaly(debug)
    try:
        with torch.set_grad_enabled(mode=not no_grad):
            yield
    finally:
        torch.autograd.set_detect_anomaly(False)
        for f in flow.flows[1::2]:
            f.no_bruteforce = True
        flow.train()

Setting no_bruteforce to False forces the corresponding flow layers to compute the logdet exactly in .eval() mode (this is what passing bruteforce=True to the wrapper above does).
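
For instance, one can wrap a validation pass as follows (a usage sketch only; val_loader and compute_nll are hypothetical placeholders):

with eval_ctx(flow, bruteforce=True):
    # exact logdet, gradients disabled: suitable for reporting a validation likelihood
    val_nll = sum(compute_nll(flow, x) for x in val_loader) / len(val_loader)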

Toy 2D experiments

To reproduce the toy experiments, run the following example command:

python train_toy.py --dataset=EightGaussian --nblocks=1 --depth=20 --dimh=32

Here's the learned density

[figure: EightGaussian_1_cpflow_20_32]

When only one flow block (--nblocks=1) is used, the script also produces a few additional plots
for analyzing the flow, such as the

(Convex) potential function

[figure: EightGaussian_1_20_32_contour]

and the corresponding gradient distortion map

[figure: EightGaussian_1_20_32_z_meshgrid]

For the 8-Gaussian experiment, we've color-coded the samples to visualize the encodings:

[figure: EightGaussian_1_20_32_z_encode]
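
For reference, the convex potential plotted above is parametrized by an input-convex neural network (ICNN), and the flow itself is the gradient map of that potential. Below is a heavily simplified sketch of this construction, not the repository's ICNN class: the layer sizes, the softplus reparametrization of the nonnegative weights, and the omission of terms that make the potential strongly convex are all simplifications.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyICNN(nn.Module):
    # forward(x) is convex in x: the hidden-to-hidden weights are nonnegative
    # (via softplus) and the activation is convex and nondecreasing.
    def __init__(self, dim=2, hidden=32, depth=3):
        super().__init__()
        self.Wx = nn.ModuleList([nn.Linear(dim, hidden) for _ in range(depth)])
        self.Wz = nn.ParameterList([nn.Parameter(0.1 * torch.randn(hidden, hidden))
                                    for _ in range(depth - 1)])
        self.wout = nn.Parameter(0.1 * torch.randn(hidden))

    def forward(self, x):
        z = F.softplus(self.Wx[0](x))
        for lin_x, Wz in zip(self.Wx[1:], self.Wz):
            z = F.softplus(lin_x(x) + z @ F.softplus(Wz))
        return z @ F.softplus(self.wout)  # scalar convex potential per sample

def gradient_map(potential, x):
    # the flow: z = grad F(x), a monotone map (invertible when F is strictly convex)
    x = x.requires_grad_(True)
    return torch.autograd.grad(potential(x).sum(), x, create_graph=True)[0]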

Toy image point cloud

We can also set --img_file to learn the "density" of a 2D image as follows:

python train_toy.py --img_file=imgs/github.png --dimh=64 --depth=10

[figure: Img2dData_1_10_64_x_sample]

Toy conditional 2D experiments

We also have a toy conditional experiment to assess the representational power of
the partial input convex neural network (PICNN). The dataset is a 1D mixture of Gaussians
whose weighting coefficient is the conditioning variable (the values shown in the legend of the figure below).

python train_toy_cond.py 

Running the above command will generate the following conditional density curves:

[figure: 1dMOG]
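
To see how the conditioning input can enter such a potential while preserving convexity in x, here is a simplified sketch (illustrative only, not the repository's PICNN class, which routes the context through a richer, gated path; here the context only produces per-layer biases):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyConditionalPotential(nn.Module):
    # Convex in x for every fixed context c: the dependence on c is unconstrained
    # (a small MLP), but for fixed c it only contributes constant bias shifts.
    def __init__(self, dim=1, cdim=1, hidden=32, depth=3):
        super().__init__()
        self.context = nn.Sequential(nn.Linear(cdim, hidden), nn.Tanh(),
                                     nn.Linear(hidden, hidden * depth))
        self.Wx = nn.ModuleList([nn.Linear(dim, hidden) for _ in range(depth)])
        self.Wz = nn.ParameterList([nn.Parameter(0.1 * torch.randn(hidden, hidden))
                                    for _ in range(depth - 1)])
        self.wout = nn.Parameter(0.1 * torch.randn(hidden))

    def forward(self, x, c):
        biases = self.context(c).chunk(len(self.Wx), dim=-1)
        z = F.softplus(self.Wx[0](x) + biases[0])
        for lin_x, Wz, b in zip(self.Wx[1:], self.Wz, biases[1:]):
            z = F.softplus(lin_x(x) + z @ F.softplus(Wz) + b)
        return z @ F.softplus(self.wout)  # convex in x, unconstrained in c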

OT map learning

To learn the optimal transport map (between Gaussians), run

python train_ot.py 

(Modify dimx = 2 in the code to run higher-dimensional experiments.)

CP-Flow will learn to transform the input Gaussian

[figure: OT_x_cpflow]

into a prior standard Gaussian "monotonically"

[figure: OT_z_cpflow]

This means the transport map is the most efficient one in the OT sense
(in contrast, IAF also learns a transport map with 0 KL, but it has a higher transport cost):

[figure: OT_dimx2]
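
"Transport cost" here refers to the expected squared displacement between a sample and its image under the map. A minimal sketch of how one could compare two learned maps on the same samples (cpflow_map and iaf_map are hypothetical callables standing in for the trained models):

import torch

def quadratic_transport_cost(transport_map, x):
    # E ||x - T(x)||^2 over a batch of samples from the source distribution
    z = transport_map(x)
    return ((x - z) ** 2).sum(dim=1).mean()

x = torch.randn(10000, 2)  # placeholder samples from the source distribution (dimx = 2)
# cost_cpflow = quadratic_transport_cost(cpflow_map, x)
# cost_iaf = quadratic_transport_cost(iaf_map, x)  # expected to be higher, per the comparison above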

Larger scale experiments

For the larger-scale experiments reported in the paper, run the corresponding training scripts in the repository (for instance, train_tabular.py for the tabular benchmarks).
