Imagen – Pytorch (wip)

Implementation of Imagen, Google's text-to-image neural network that beats DALL-E 2, in PyTorch. It is the new state of the art (SOTA) for text-to-image synthesis.

Architecturally, it is actually much simpler than DALL-E 2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (an attention network). It also uses dynamic clipping (called dynamic thresholding in the paper) for improved classifier-free guidance, noise level conditioning, and a memory-efficient U-Net design.
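As a rough illustration of two of the pieces named above, here is a minimal PyTorch sketch of classifier-free guidance and dynamic thresholding, written from the Imagen paper's description rather than taken from this repository. The function names `guided_eps`, `predict_x0` and `dynamic_threshold`, and the default percentile of 0.95, are hypothetical choices for illustration only.

```python
import torch


def guided_eps(eps_cond, eps_uncond, cond_scale):
    # Classifier-free guidance: start from the unconditional noise prediction and
    # move toward (and, when cond_scale > 1, past) the text-conditioned one.
    return eps_uncond + cond_scale * (eps_cond - eps_uncond)


def predict_x0(x_t, eps, alpha_bar):
    # Invert x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps to get the implied clean image.
    ab = alpha_bar.view(-1, *([1] * (x_t.ndim - 1)))
    return (x_t - (1 - ab).sqrt() * eps) / ab.sqrt()


def dynamic_threshold(x0, percentile=0.95):
    # Per sample, s is the `percentile` quantile of |x0|, but never below 1
    # (the percentile value is an assumed hyperparameter, not from this repo).
    flat = x0.reshape(x0.shape[0], -1).abs()
    s = torch.quantile(flat, percentile, dim=1).clamp(min=1.0)
    s = s.view(-1, *([1] * (x0.ndim - 1)))  # broadcast over C, H, W
    # Clip to [-s, s], then rescale into [-1, 1]; with s == 1 this is plain clipping.
    return torch.maximum(torch.minimum(x0, s), -s) / s
```

The motivation is that large guidance weights push the predicted x0 outside [-1, 1]; static clipping then saturates many pixels, while dynamic thresholding rescales the whole sample back into range.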

It appears that neither CLIP nor a prior network is needed after all. And so research continues.

Install

$ pip install imagen-pytorch

Usage

import torch
from imagen_pytorch import Unet, Imagen

# unets for imagen

unet1 = Unet(
    dim = 32,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
    dim = 32,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

# imagen, which contains the unets above (the base unet and the super-resolution ones)

imagen = Imagen(
    unet = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 100,
    cond_drop_prob = 0.5
).cuda()

# mock images (you will need a lot of these) and text embeddings from a large T5

text_embeds = torch.randn(4, 256, 512).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# feed images into imagen, training each unet in the cascade

for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, unet_number = i)
    loss.backward()

# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
])

images.shape # (3, 3, 256, 256)
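The training loop above hides what each call with a given `unet_number` involves. As a rough mental model only (not this library's internals), here is a self-contained PyTorch sketch of one training step for a single unet in a cascade. The `denoiser(x_t, t, text_embeds, lowres)` signature, the cosine `alpha_bar` schedule, bilinear resizing, and zeroing the text embeddings to represent a dropped condition are all assumptions made for illustration.

```python
import math

import torch
import torch.nn.functional as F


def alpha_bar(t):
    # Cumulative signal level for continuous time t in [0, 1] (assumed cosine schedule).
    return torch.cos(t * math.pi / 2) ** 2


def cascade_training_step(denoiser, images, text_embeds, image_size,
                          lowres_size=None, cond_drop_prob=0.5):
    """Epsilon-prediction DDPM loss for one (hypothetical) unet in a cascade."""
    b = images.shape[0]
    # Resize the full-resolution training images to this stage's output size.
    x0 = F.interpolate(images, size=image_size, mode="bilinear", align_corners=False)

    # Super-resolution stages are conditioned on a degraded copy:
    # downsample to the previous stage's size, then upsample back.
    lowres = None
    if lowres_size is not None:
        lowres = F.interpolate(x0, size=lowres_size, mode="bilinear", align_corners=False)
        lowres = F.interpolate(lowres, size=image_size, mode="bilinear", align_corners=False)

    # Classifier-free guidance training: drop the text condition for a random
    # subset of the batch (zeros stand in for a "null" condition in this sketch).
    keep = (torch.rand(b, device=images.device) >= cond_drop_prob).float()
    text_embeds = text_embeds * keep.view(b, 1, 1)

    # Forward diffusion q(x_t | x_0) at a random time, then regress the added noise.
    t = torch.rand(b, device=images.device)
    ab = alpha_bar(t).view(b, 1, 1, 1)
    noise = torch.randn_like(x0)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
    pred_noise = denoiser(x_t, t, text_embeds, lowres)
    return F.mse_loss(pred_noise, noise)
```

With `image_sizes = (64, 256)` as in the usage example, the base stage would correspond to `image_size=64`, and the super-resolution stage to `image_size=256, lowres_size=64`.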

Todo

  • use huggingface transformers for T5-small text embeddings
  • add dynamic thresholding
  • add dynamic thresholding to the DALLE2 and video-diffusion repositories as well
  • allow one to set T5-large (and perhaps add a small factory method that takes in any huggingface transformer)
  • separate unet into base unet and SR3 unet
  • build the Efficient U-Net from the paper
  • add noise level conditioning following the pseudocode in the appendix, and figure out the sweep they do at inference time
  • port over some training code from DALLE2
  • figure out if learned variance was used at all, and remove it if it was inconsequential
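Our reading of the noise level conditioning item is noise conditioning augmentation: the low-resolution image given to a super-resolution unet is corrupted with forward-diffusion noise at some level, and that level is passed to the unet as an extra condition. The sketch below is an assumption-laden illustration, not the paper's pseudocode or this repository's API; `noise_augment`, `alpha_bar` and the cosine schedule are hypothetical.

```python
import math

import torch


def alpha_bar(t):
    # Cumulative signal level for a noise level in [0, 1] (assumed cosine schedule).
    return torch.cos(t * math.pi / 2) ** 2


def noise_augment(lowres, aug_level=None):
    """Corrupt the low-res conditioning image; return it with the level used."""
    b = lowres.shape[0]
    if aug_level is None:
        # Training: sample a random augmentation level per example.
        aug_level = torch.rand(b, device=lowres.device)
    ab = alpha_bar(aug_level).view(b, 1, 1, 1)
    noisy = ab.sqrt() * lowres + (1 - ab).sqrt() * torch.randn_like(lowres)
    # The super-resolution unet would receive aug_level as extra conditioning.
    return noisy, aug_level
```

At sampling time one would instead pass a fixed level, for example `noise_augment(lowres, torch.full((lowres.shape[0],), 0.2))`; choosing that level is presumably what the inference-time sweep is about.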

Citations

@misc{Saharia2022,
    title   = {Imagen: unprecedented photorealism × deep level of language understanding},
    author  = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
    year    = {2022}
}
