Relational Recurrent Neural Networks

RNNs with relational memory that enables reasoning across time

Relational Recurrent Neural Networks combine the temporal processing of RNNs with the relational reasoning of attention mechanisms. The result is a recurrent memory in which stored memories can interact with one another.

Motivation

A standard LSTM compresses everything it remembers into a single memory vector, with no mechanism for the parts of that state to interact. But complex reasoning requires:

  • Multiple pieces of information stored simultaneously
  • Interactions between stored memories
  • Dynamic retrieval based on relationships

Relational Memory Core (RMC)

The RMC maintains multiple memory slots M = [m_1, m_2, \ldots, m_N] that interact via attention:

M^{t+1} = \text{MHDPA}(M^t) + \text{MLP}(\text{MHDPA}(M^t))

where MHDPA is Multi-Head Dot Product Attention.
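As a rough sketch (not the paper's reference implementation), this update can be written in a few lines of PyTorch, with nn.MultiheadAttention standing in for MHDPA and a small two-layer MLP for the residual transform; RelationalMemoryUpdate, d_model, and the slot count are illustrative choices:

```python
import torch
import torch.nn as nn

class RelationalMemoryUpdate(nn.Module):
    """One relational update: M_{t+1} = MHDPA(M_t) + MLP(MHDPA(M_t))."""

    def __init__(self, d_model=64, num_heads=4):
        super().__init__()
        self.mhdpa = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model)
        )

    def forward(self, memory):
        # memory: (batch, num_slots, d_model); every slot attends to every slot
        attended, _ = self.mhdpa(memory, memory, memory)
        # residual MLP applied to the attended memory
        return attended + self.mlp(attended)

memory = torch.randn(2, 4, 64)              # batch of 2, N = 4 memory slots
next_memory = RelationalMemoryUpdate()(memory)
print(next_memory.shape)                    # torch.Size([2, 4, 64])
```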

Key Innovation

Memories attend to each other, not just to inputs:

A = \text{softmax}\left(\frac{M W_Q (M W_K)^T}{\sqrt{d_k}}\right) M W_V

This allows reasoning about relationships between stored facts.
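To make the formula concrete, here is a single-head NumPy version of that attention over the memory matrix M; the projection matrices W_q, W_k, W_v are random placeholders and the helper name is made up for illustration:

```python
import numpy as np

def memory_self_attention(M, W_q, W_k, W_v):
    """A = softmax(M W_q (M W_k)^T / sqrt(d_k)) M W_v  -- slots attend to slots."""
    Q, K, V = M @ W_q, M @ W_k, M @ W_v
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                  # (num_slots, num_slots)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V                               # new slot contents

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 8))                          # 4 memory slots, dim 8
W_q, W_k, W_v = (rng.normal(size=(8, 8)) for _ in range(3))
A = memory_self_attention(M, W_q, W_k, W_v)
print(A.shape)                                       # (4, 8)
```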

Interactive Demo

Watch memory slots interact via attention:

[Demo: memory slots m1–m4 attend to one another via MHDPA (Multi-Head Dot Product Attention), and the attended memories are then combined with the previous memory through LSTM-like gates, implementing M^{t+1} = \text{MHDPA}(M^t) + \text{MLP}(\text{MHDPA}(M^t)). Memories update by attending to each other, enabling relational reasoning over time.]

Gating Mechanism

Like LSTMs, RMC uses gates for stable updates:

\tilde{M} = \sigma(W_g[\tilde{A}; M]) \odot \tilde{A} + (1 - \sigma(W_g[\tilde{A}; M])) \odot M

This prevents catastrophic forgetting of important memories.
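A minimal NumPy sketch of that gate, assuming \tilde{A} is the attended memory, M the previous memory, and W_g acts on the concatenation [\tilde{A}; M] (the paper uses separate LSTM-style input and forget gates; a single shared gate is used here for brevity):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_memory_update(A_tilde, M, W_g):
    """M~ = sigma(W_g [A~; M]) * A~ + (1 - sigma(W_g [A~; M])) * M."""
    concat = np.concatenate([A_tilde, M], axis=-1)   # [A~; M], shape (num_slots, 2*d)
    gate = sigmoid(concat @ W_g)                     # per-slot, per-feature gate in (0, 1)
    return gate * A_tilde + (1.0 - gate) * M         # interpolate new vs. old memory

rng = np.random.default_rng(1)
M = rng.normal(size=(4, 8))            # previous memory slots
A_tilde = rng.normal(size=(4, 8))      # attended (candidate) memory
W_g = rng.normal(size=(16, 8))         # maps [A~; M] -> gate logits
M_new = gated_memory_update(A_tilde, M, W_g)
print(M_new.shape)                     # (4, 8)
```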

Architecture

Input → Linear projection → Concatenate with memories
     → Multi-head self-attention over all slots
     → MLP (residual)
     → Gated update
     → Output from attended memories
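Putting the pieces together, one recurrent step could look like the following sketch (PyTorch; RMCCell and all of its internals are illustrative names and dimensions, not the published implementation):

```python
import torch
import torch.nn as nn

class RMCCell(nn.Module):
    """One recurrent step of a relational memory core (simplified sketch)."""

    def __init__(self, input_dim, d_model=64, num_slots=4, num_heads=4):
        super().__init__()
        self.in_proj = nn.Linear(input_dim, d_model)          # input -> slot space
        self.mhdpa = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(),
                                 nn.Linear(d_model, d_model))
        self.gate = nn.Linear(2 * d_model, d_model)
        self.out = nn.Linear(num_slots * d_model, d_model)

    def forward(self, x, memory):
        # x: (batch, input_dim); memory: (batch, num_slots, d_model)
        x = self.in_proj(x).unsqueeze(1)                      # (batch, 1, d_model)
        augmented = torch.cat([memory, x], dim=1)             # slots + projected input
        # memory slots attend over the slots and the input row
        attended, _ = self.mhdpa(memory, augmented, augmented)
        candidate = attended + self.mlp(attended)             # residual MLP
        # LSTM-style gate between candidate and previous memory
        g = torch.sigmoid(self.gate(torch.cat([candidate, memory], dim=-1)))
        new_memory = g * candidate + (1 - g) * memory
        output = self.out(new_memory.flatten(1))              # read out all slots
        return output, new_memory

cell = RMCCell(input_dim=10)
memory = torch.zeros(2, 4, 64)
for x in torch.randn(5, 2, 10):                              # sequence of length 5
    y, memory = cell(x, memory)
print(y.shape, memory.shape)                                 # (2, 64) (2, 4, 64)
```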

Results

Language Modeling (WikiText-103)

Model               Perplexity
LSTM                48.7
Transformer         44.1
Relational Memory   31.6

Nth Farthest Task

Task: Given N objects, find the Nth farthest from a query.
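To make the setup concrete, a single example could be generated as follows; the object count, dimensionality, and encoding are assumptions for illustration rather than the paper's exact protocol:

```python
import numpy as np

rng = np.random.default_rng(0)

def nth_farthest_example(num_objects=8, dim=16):
    """Sample objects plus the label: which object is the n-th farthest from object m?"""
    objects = rng.normal(size=(num_objects, dim))
    n = rng.integers(1, num_objects + 1)        # which rank ("n-th farthest") is asked for
    m = rng.integers(num_objects)               # index of the query object
    dists = np.linalg.norm(objects - objects[m], axis=1)
    target = int(np.argsort(-dists)[n - 1])     # index of the n-th farthest object
    return objects, n, m, target

objects, n, m, target = nth_farthest_example()
print(f"object {target} is the {n}-th farthest from object {m}")
```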

Model   Accuracy
LSTM    17%
DNC     37%
RMC     91%

Why It Works

  1. Multiple memories: Can store several facts
  2. Memory interaction: Facts can “talk” to each other
  3. Attention routing: Dynamic retrieval based on relevance
  4. Temporal integration: Processes sequences naturally

Connection to Transformers

RMC shares key architectural ideas with the Transformer:

  • Multi-head attention
  • Residual connections
  • Layer normalization

The main difference: RMC processes sequences recurrently, while Transformers process in parallel.

Key Paper

Santoro et al., "Relational recurrent neural networks," NeurIPS 2018 (arXiv:1806.01822).
