Pytorch-Struct

A library of tested GPU implementations of core structured prediction algorithms for deep learning applications (or, an implementation of "Inside-Outside and Forward-Backward Algorithms Are Just Backprop").
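The parenthetical is the central idea: run the inside (or forward) pass to get the log-partition log Z, and let autograd produce the outside (or backward) pass, whose gradients are the marginals. The sketch below illustrates this in plain PyTorch for a deliberately simplified model: unlabeled binary trees scored only by their spans (no nonterminals). The function `cky_logz` and its span-score layout are hypothetical illustrations, not this library's API.

```python
import torch

def cky_logz(span_scores):
    """Inside algorithm over unlabeled binary trees (a CKY-style chart).

    span_scores: (batch, n + 1, n + 1); span_scores[b, i, j] scores the
    constituent covering words i .. j - 1. Entries with i >= j are ignored.
    Returns log Z, the log-sum over all binary trees of their summed span scores.
    """
    n = span_scores.shape[-1] - 1
    beta = {}  # beta[i, j]: inside log-score of span (i, j), shape (batch,)
    for i in range(n):
        beta[i, i + 1] = span_scores[:, i, i + 1]
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            j = i + width
            # log-sum over split points k of left child (i, k) and right child (k, j)
            splits = torch.stack([beta[i, k] + beta[k, j] for k in range(i + 1, j)], dim=-1)
            beta[i, j] = span_scores[:, i, j] + torch.logsumexp(splits, dim=-1)
    return beta[0, n]

# With all-zero scores every tree has weight 1, so Z counts the trees.
n = 4
print(cky_logz(torch.zeros(1, n + 1, n + 1, dtype=torch.float64)).exp())  # approximately 5

# The outside pass is backprop: d log Z / d score(i, j) = P(span (i, j) is in the tree)
scores = torch.randn(1, n + 1, n + 1, dtype=torch.float64, requires_grad=True)
(span_marginals,) = torch.autograd.grad(cky_logz(scores).sum(), scores)
```

Five is the Catalan number C(3), the count of binary trees over four words. The full span (0, n) and every single-word span occur in every tree, so their entries in `span_marginals` are exactly 1, while interior spans receive fractional probabilities.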

Getting Started

pip install .

import torch_struct
import torch
batch, N = 10, 100
scores = torch.rand(batch, N, N, requires_grad=True)

# Tree marginals
marginals = torch_struct.deptree(scores)

# Tree Argmax (0/1 arc indicators of the best tree) and its score per batch item
argmax = torch_struct.deptree(scores, semiring=torch_struct.MaxSemiring)
max_score = torch.mul(argmax, scores).sum(dim=(-2, -1))

# Tree Counts
ones = torch.ones(batch, N, N)
ntrees = torch_struct.deptree(ones, semiring=torch_struct.StdSemiring)

# Tree Sample
sample = torch_struct.deptree(scores, semiring=torch_struct.SampledSemiring)

# Tree Partition
v, _ = torch_struct.deptree_inside(scores)

# Tree Max
v, _ = torch_struct.deptree_inside(scores, semiring=torch_struct.MaxSemiring)

Library

Current algorithms implemented:

  • Linear Chain (CRF / HMM)
  • Semi-Markov (CRF / HSMM)
  • Dependency Parsing (Projective and Non-Projective)
  • CKY (CFG)
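For non-projective dependency parsing, a standard technique (not necessarily the one this library uses) is the matrix-tree theorem: the partition function over all dependency trees equals the determinant of a graph Laplacian, so `torch.logdet` gives log Z and backprop gives arc marginals. A minimal sketch, assuming the multi-root formulation (the root may take several children) and a hypothetical `nonprojective_logz` that takes `(batch, n + 1, n + 1)` arc scores with index 0 as the root:

```python
import torch

def nonprojective_logz(arc_scores):
    """Matrix-tree theorem, multi-root variant.

    arc_scores: (batch, n + 1, n + 1); arc_scores[b, h, m] is the log-weight of
    the arc from head h to modifier m, with index 0 the artificial root.
    Returns log Z, where Z sums exp(total arc score) over all spanning trees.
    """
    n = arc_scores.shape[-1] - 1
    no_self_loops = 1.0 - torch.eye(n + 1, dtype=arc_scores.dtype, device=arc_scores.device)
    w = arc_scores.exp() * no_self_loops
    # Diagonal: total incoming weight of each word (from the root or any other word)
    incoming = w[:, :, 1:].sum(dim=1)
    # Laplacian over words 1..n; the root contributes only through the diagonal
    laplacian = torch.diag_embed(incoming) - w[:, 1:, 1:]
    return torch.logdet(laplacian)

# All-zero scores: Z counts trees, (n + 1) ** (n - 1) = 16 for n = 3 words
print(nonprojective_logz(torch.zeros(1, 4, 4, dtype=torch.float64)).exp())  # approximately 16

# Arc marginals by backprop: marginals[b, h, m] = P(arc h -> m is in the tree)
scores = torch.randn(2, 5, 5, dtype=torch.float64, requires_grad=True)
(marginals,) = torch.autograd.grad(nonprojective_logz(scores).sum(), scores)
```

Every word has exactly one head, so `marginals[:, :, 1:].sum(dim=1)` is all ones. The sketch exponentiates scores directly; a production version would guard against overflow, and enforcing a single root child requires a modified Laplacian that is not shown here.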

Design Strategy:

  1. Minimal implementations. Most are about 10 lines.
  2. Batched for GPU.
  3. Code can be ported to other backends.

Semirings:

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
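To show how one dynamic program serves several semirings, here is a sketch of the forward algorithm for a linear chain in which only the semiring's "plus" changes: `logsumexp` for log marginals, `max` for MAP. Backprop through the log version yields edge marginals; through the max version it yields 0/1 indicators of the highest-scoring path. The names `chain_forward`, `chain_marginals`, `log_plus` and `max_plus` are hypothetical, and sampling (which needs a specialized backward pass) is not shown.

```python
import torch

def log_plus(x, dim):
    # Log semiring: "plus" is logsumexp, "times" is +
    return torch.logsumexp(x, dim=dim)

def max_plus(x, dim):
    # Max (Viterbi) semiring: "plus" is max, "times" is +
    return x.max(dim=dim).values

def chain_forward(log_potentials, plus):
    """Forward algorithm over a linear chain under the semiring given by `plus`.

    log_potentials: (batch, T - 1, C, C); entry [b, t, i, j] scores label i at
    position t followed by label j at position t + 1.
    """
    batch, steps, C, _ = log_potentials.shape
    alpha = log_potentials.new_zeros(batch, C)
    for t in range(steps):
        # extend every prefix by one transition, then "sum" out the previous label i
        alpha = plus(alpha.unsqueeze(2) + log_potentials[:, t], 1)
    return plus(alpha, 1)

def chain_marginals(log_potentials, plus):
    # The backward pass is just backprop through the forward pass
    phi = log_potentials.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(chain_forward(phi, plus).sum(), phi)
    return grad

torch.manual_seed(0)
phi = torch.randn(2, 4, 3, 3, dtype=torch.float64)
edge_marg = chain_marginals(phi, log_plus)  # P(y_t = i, y_{t+1} = j)
best_path = chain_marginals(phi, max_plus)  # 1.0 on the edges of the best path, else 0.0
best_score = chain_forward(phi, max_plus)
```

Any path uses exactly one edge per position, so `edge_marg` and `best_path` both sum to 1 over their last two dimensions, and `(best_path * phi).sum(dim=(-3, -2, -1))` recovers `best_score`.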

Example: https://github.com/harvardnlp/pytorch-struct/blob/master/notebooks/Examples.ipynb

Applications

Application examples (to come):

  • Structured Attention
  • EM training
  • Structured VAE
  • Posterior Regularization
