Emergent Symbol Binding Network (ESBN) - Pytorch

A usable implementation of the Emergent Symbol Binding Network (ESBN) in Pytorch. The authors propose that the main recurrent neural network interact with the input image representations only through a set of memory keys and values.

The input image representations are cast as memory values and are explicitly bound to memory keys that the network generates. To generate each new key, the network first retrieves a sum of all previous memory keys, weighted by the similarity of the incoming representation to each memory value in storage.
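The retrieve-then-bind loop described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not this library's implementation: similarity is assumed to be a dot product normalised with a softmax, the recurrent network is replaced by a hypothetical stand-in (`toy_controller`), and any gating or confidence signals the full model may use are omitted. The names `KeyValueMemory`, `run_episode`, and `toy_controller` are illustrative only.

```python
import math
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


class KeyValueMemory:
    """Hypothetical external memory: keys come from the controller, values from the encoder."""

    def __init__(self, key_dim: int) -> None:
        self.key_dim = key_dim
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def retrieve(self, z: Vector) -> Vector:
        # Empty memory: nothing has been bound yet, so return a zero key.
        if not self.keys:
            return [0.0] * self.key_dim
        # Weight every stored key by how similar z is to the value bound to it.
        weights = softmax([dot(z, v) for v in self.values])
        return [sum(w * k[i] for w, k in zip(weights, self.keys)) for i in range(self.key_dim)]

    def write(self, key: Vector, z: Vector) -> None:
        # Bind the controller-generated key to the current image representation.
        self.keys.append(key)
        self.values.append(z)


def run_episode(values: List[Vector], controller: Callable[[Vector], Vector], key_dim: int) -> List[Vector]:
    memory = KeyValueMemory(key_dim)
    generated = []
    for z in values:
        retrieved = memory.retrieve(z)  # the controller never sees z itself
        key = controller(retrieved)     # new key computed only from retrieved keys
        memory.write(key, z)
        generated.append(key)
    return generated


def toy_controller(retrieved: Vector) -> Vector:
    # Hypothetical stand-in for the recurrent network: shift the retrieved key by 1.
    return [r + 1.0 for r in retrieved]


# Two episodes with the same "A, B, A" pattern but swapped representations.
keys_ab = run_episode([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]], toy_controller, key_dim=2)
keys_ba = run_episode([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0]], toy_controller, key_dim=2)
print(keys_ab == keys_ba)  # True
```

Because the controller only ever receives retrieved keys, swapping which vectors play the roles of "A" and "B" leaves the generated keys unchanged; only the pattern of similarities between inputs reaches the controller.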

The authors report that this decoupling (indirection) of sensory processing from abstract processing allows the network to outperform the previous approaches they compared against, including transformers.


import torch
from esbn_pytorch import ESBN

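model = ESBN(
    value_dim = 64,     # size of the encoded image representations (memory values)
    key_dim = 64,       # size of the memory keys generated by the network
    hidden_dim = 512,   # hidden size of the recurrent network
    output_dim = 4      # size of the per-step output
)

images = torch.randn(3, 2, 3, 32, 32)  # (n, b, c, h, w): sequence of 3 steps, batch of 2
out = model(images)                    # (3, 2, 4), i.e. (n, b, o)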


Taylor W. Webb, Ishan Sinha, and Jonathan D. Cohen. Emergent Symbols through Binding in External Memory.