A Simple Neural Network Module for Relational Reasoning

Relation Networks for learning to reason about object relationships

The paper *A Simple Neural Network Module for Relational Reasoning* (Santoro et al., 2017) introduced Relation Networks (RNs), a simple but powerful architecture for learning to reason about relationships between objects.

The Problem

Standard neural networks struggle with relational reasoning:

  • “Is object A larger than object B?”
  • “What is between the red and blue objects?”
  • “Are there more circles than squares?”

These require comparing pairs of entities—not easily captured by standard architectures.

Relation Networks

The key insight: explicitly consider all pairs of objects:

$$\text{RN}(O) = f_\phi\left(\sum_{i,j} g_\theta(o_i, o_j)\right)$$

where:

  • $o_i, o_j$ are object representations
  • $g_\theta$ processes each pair (the “relation” function)
  • $f_\phi$ aggregates all pairwise relations
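The formula above can be sketched directly in code. This is a minimal toy version, assuming nothing beyond the equation itself: the names `relation_network`, `g`, and `f` are illustrative stand-ins (the paper uses small MLPs for $g_\theta$ and $f_\phi$; here they are tanh/linear maps so the example runs without a deep-learning framework).

```python
import numpy as np

def relation_network(objects, g, f):
    # RN(O) = f( sum over all ordered pairs (i, j) of g(o_i, o_j) ).
    # g and f are any callables; small MLPs in the paper.
    pair_sum = sum(
        g(np.concatenate([o_i, o_j]))
        for o_i in objects
        for o_j in objects
    )
    return f(pair_sum)

# Toy g and f standing in for the MLPs.
rng = np.random.default_rng(0)
W_g = rng.normal(size=(8, 6))         # pair of 3-dim objects -> 8-dim relation
W_f = rng.normal(size=(2, 8))         # aggregated relation -> 2-dim output
g = lambda pair: np.tanh(W_g @ pair)
f = lambda agg: W_f @ agg

objects = [rng.normal(size=3) for _ in range(4)]   # n = 4 objects
out = relation_network(objects, g, f)
print(out.shape)  # (2,)
```

Because the sum runs over *all* ordered pairs, reversing the object list gives exactly the same output, which is the permutation invariance discussed below.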

Interactive Demo

Explore relational reasoning on simple visual scenes:

*(Interactive widget: select object pairs in a scene, answer visual-QA questions such as “What color is the object nearest to the red circle?”, and see the RN formula applied step by step.)*

Why Pairs Matter

For $n$ objects, the RN considers all $n^2$ ordered pairs. This:

  • Captures relations regardless of object order
  • Scales to variable numbers of objects
  • Avoids hardcoding specific relations
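Enumerating the pairs is just a Cartesian product of the object set with itself; this small snippet (object names are made up for illustration) shows why the count is exactly $n^2$ and why both orderings of each pair appear:

```python
from itertools import product

objects = ["red_circle", "blue_square", "green_circle"]

# All ordered pairs, including (o, o) self-pairs and both orderings.
pairs = list(product(objects, repeat=2))
print(len(pairs))  # 9 == n**2 for n == 3
```

Including both `(a, b)` and `(b, a)` lets $g_\theta$ learn asymmetric relations such as “left of” without any hardcoding.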

Architecture Details

For visual QA:

  1. CNN extracts feature map from image
  2. Objects = spatial locations in feature map
  3. Question embedding concatenated to each pair
  4. g network (MLP) processes each $(o_i, o_j, q)$ triple
  5. Sum over all pairs
  6. f network (MLP) produces answer
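The six steps above can be traced shape-by-shape. This sketch assumes hypothetical sizes (an 8×8 feature map, 32-dim question embedding, 10 answer classes) and replaces the CNN and the g/f MLPs with random linear maps, so only the data flow of the pipeline is being illustrated:

```python
import numpy as np

H, W, C = 8, 8, 24   # CNN output: 8x8 grid of 24-dim feature cells (assumed)
Q = 32               # question-embedding size (assumed)

rng = np.random.default_rng(1)
feature_map = rng.normal(size=(H, W, C))   # step 1: stand-in for CNN output
question = rng.normal(size=Q)

# Step 2: each spatial cell is an "object", tagged with its coordinates.
objects = [
    np.concatenate([feature_map[r, c], [r / H, c / W]])
    for r in range(H) for c in range(W)
]
D = C + 2   # object dim after coordinate tagging

# Steps 3-5: g over every (o_i, o_j, q) triple, summed over all pairs.
# A scaled linear map stands in for the g MLP.
W_g = rng.normal(size=(16, 2 * D + Q)) * 0.01
agg = sum(W_g @ np.concatenate([o_i, o_j, question])
          for o_i in objects for o_j in objects)

# Step 6: f produces answer logits (10 hypothetical answer classes).
W_f = rng.normal(size=(10, 16)) * 0.01
logits = W_f @ agg
print(logits.shape)  # (10,)
```

With 64 objects this already means 4,096 pairs per question, which is why the quadratic cost noted above matters in practice.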

Results on CLEVR

CLEVR is a visual reasoning benchmark with questions like “What size is the cylinder that is left of the brown metal thing?”

| Model | Accuracy |
| --- | --- |
| CNN + LSTM | 42.7% |
| CNN + LSTM + Attention | 68.5% |
| Relation Network | 95.5% |
| Human | 92.6% |

RNs achieved superhuman performance on this benchmark!

Key Properties

Permutation invariant: Summing over pairs is order-independent

Relation-centric: Explicitly models pairwise interactions

Data efficient: Strong inductive bias for relational tasks
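The permutation-invariance property is easy to verify numerically. In this self-contained check, `g` is a deliberately *asymmetric* toy relation function, showing that invariance comes from summing over all ordered pairs, not from any symmetry of g itself:

```python
import numpy as np

rng = np.random.default_rng(2)
objects = [rng.normal(size=4) for _ in range(5)]
g = lambda a, b: np.tanh(a + 2.0 * b)   # toy, asymmetric: g(a, b) != g(b, a)

def rn_sum(objs):
    # The RN aggregation step: sum g over all ordered pairs.
    return sum(g(a, b) for a in objs for b in objs)

# Reordering the object list leaves the pairwise sum unchanged.
print(np.allclose(rn_sum(objects), rn_sum(objects[::-1])))  # True
```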

Beyond Vision

RNs also improved:

  • Text QA (bAbI dataset)
  • Physical reasoning (predicting dynamics)
  • Graph problems (when combined with GNNs)

Connection to Attention

Self-attention can be viewed as a form of relation network:

$$\text{Attention}(Q, K, V)_i \approx \sum_j \text{softmax}(q_i \cdot k_j)\, v_j$$

Both aggregate pairwise interactions.
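The analogy can be made concrete: the loop below computes (unscaled, single-head) self-attention as an explicit sum over pairwise interactions, i.e. an RN whose pair function is the softmax-weighted value. Sizes and the random Q/K/V matrices are illustrative only:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(3)
n, d = 4, 6
Q, K, V = (rng.normal(size=(n, d)) for _ in range(3))

# Row i of the output is a weighted sum over pairwise interactions (i, j):
# structurally an RN, but with learned softmax weights instead of a uniform sum.
out = np.stack([
    sum(softmax(Q[i] @ K.T)[j] * V[j] for j in range(n))
    for i in range(n)
])
print(out.shape)  # (4, 6)
```

The key difference: an RN weights every pair equally and lets $g_\theta$ do the work, while attention learns data-dependent weights per pair.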

Key Paper

Santoro, A., Raposo, D., Barrett, D. G. T., Malinowski, M., Pascanu, R., Battaglia, P., & Lillicrap, T. (2017). “A Simple Neural Network Module for Relational Reasoning.” NeurIPS 2017. arXiv:1706.01427.
