---
abstract: |
  Transformers process tokens in parallel but are temporally shallow: at position $t$, each layer attends to key--value pairs computed from the previous layer's output, yielding a depth capped by the number of layers. Recurrent models offer unbounded temporal depth but suffer from optimization instability and historically underutilize modern accelerators. We introduce the *Recurrent Transformer*, a simple architectural change where *each layer* attends to key--value pairs computed from its own activations, yielding layerwise recurrent memory while preserving standard autoregressive decoding cost. We show that the architecture can emulate both (i) a conventional Transformer and (ii) token-to-token recurrent updates under mild assumptions, while avoiding optimization instability. Naively, prefill/training appears bandwidth-bound with effective arithmetic intensity near $1$ because keys and values are revealed sequentially; we give an exact tiling-based algorithm that preserves the mathematical computation while reducing HBM traffic from $\Theta(N^2)$ to $\Theta(N\log N)$, increasing effective arithmetic intensity to $\Theta(N/\log N)$ for sequence length $N$. On 150M- and 300M-parameter C4 pretraining, `\methodnames`{=latex} improve cross-entropy over a parameter-matched Transformer baseline and achieve the improvement with fewer layers at a fixed parameter count, suggesting that recurrence can trade depth for width, thus reducing KV cache memory footprint and inference latency. Code is available at <https://github.com/geniucos/recurrent-transformer>.
bibliography:
- references.bib
title: |
  The Recurrent Transformer:\
  Greater Effective Depth and Efficient Decoding
---

```{=latex}
\newcommand{\costin}[1]{{\color{magenta} Costin:#1}}
```
```{=latex}
\newcommand{\depen}[1]{{\color{blue} Depen:#1}}
```
```{=latex}
\newcommand{\alex}[1]{{\color{orange} Alex:#1}}
```
```{=latex}
\newcommand{\figleft}{{\em (Left)}}
```
```{=latex}
\newcommand{\figcenter}{{\em (Center)}}
```
```{=latex}
\newcommand{\figright}{{\em (Right)}}
```
```{=latex}
\newcommand{\figtop}{{\em (Top)}}
```
```{=latex}
\newcommand{\figbottom}{{\em (Bottom)}}
```
```{=latex}
\newcommand{\captiona}{{\em (a)}}
```
```{=latex}
\newcommand{\captionb}{{\em (b)}}
```
```{=latex}
\newcommand{\captionc}{{\em (c)}}
```
```{=latex}
\newcommand{\captiond}{{\em (d)}}
```
```{=latex}
\newcommand{\newterm}[1]{{\bf #1}}
```
```{=latex}
\def\figref#1{figure~\ref{#1}}
```
```{=latex}
\def\Figref#1{Figure~\ref{#1}}
```
```{=latex}
\def\twofigref#1#2{figures \ref{#1} and \ref{#2}}
```
```{=latex}
\def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}}
```
```{=latex}
\def\secref#1{section~\ref{#1}}
```
```{=latex}
\def\Secref#1{Section~\ref{#1}}
```
```{=latex}
\def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}}
```
```{=latex}
\def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}}
```
```{=latex}
\def\eqref#1{equation~\ref{#1}}
```
```{=latex}
\def\Eqref#1{Equation~\ref{#1}}
```
```{=latex}
\def\plaineqref#1{\ref{#1}}
```
```{=latex}
\def\chapref#1{chapter~\ref{#1}}
```
```{=latex}
\def\Chapref#1{Chapter~\ref{#1}}
```
```{=latex}
\def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}}
```
```{=latex}
\def\algref#1{algorithm~\ref{#1}}
```
```{=latex}
\def\Algref#1{Algorithm~\ref{#1}}
```
```{=latex}
\def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\partref#1{part~\ref{#1}}
```
```{=latex}
\def\Partref#1{Part~\ref{#1}}
```
```{=latex}
\def\twopartref#1#2{parts \ref{#1} and \ref{#2}}
```
```{=latex}
\def\ceil#1{\lceil #1 \rceil}
```
```{=latex}
\def\floor#1{\lfloor #1 \rfloor}
```
```{=latex}
\def\1{\bm{1}}
```
```{=latex}
\newcommand{\train}{\mathcal{D}}
```
```{=latex}
\newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}}
```
```{=latex}
\newcommand{\test}{\mathcal{D_{\mathrm{test}}}}
```
```{=latex}
\def\eps{{\epsilon}}
```
```{=latex}
\def\reta{{\textnormal{$\eta$}}}
```
```{=latex}
\def\ra{{\textnormal{a}}}
```
```{=latex}
\def\rb{{\textnormal{b}}}
```
```{=latex}
\def\rc{{\textnormal{c}}}
```
```{=latex}
\def\rd{{\textnormal{d}}}
```
```{=latex}
\def\re{{\textnormal{e}}}
```
```{=latex}
\def\rf{{\textnormal{f}}}
```
```{=latex}
\def\rg{{\textnormal{g}}}
```
```{=latex}
\def\rh{{\textnormal{h}}}
```
```{=latex}
\def\ri{{\textnormal{i}}}
```
```{=latex}
\def\rj{{\textnormal{j}}}
```
```{=latex}
\def\rk{{\textnormal{k}}}
```
```{=latex}
\def\rl{{\textnormal{l}}}
```
```{=latex}
\def\rn{{\textnormal{n}}}
```
```{=latex}
\def\ro{{\textnormal{o}}}
```
```{=latex}
\def\rp{{\textnormal{p}}}
```
```{=latex}
\def\rq{{\textnormal{q}}}
```
```{=latex}
\def\rr{{\textnormal{r}}}
```
```{=latex}
\def\rs{{\textnormal{s}}}
```
```{=latex}
\def\rt{{\textnormal{t}}}
```
```{=latex}
\def\ru{{\textnormal{u}}}
```
```{=latex}
\def\rv{{\textnormal{v}}}
```
```{=latex}
\def\rw{{\textnormal{w}}}
```
```{=latex}
\def\rx{{\textnormal{x}}}
```
```{=latex}
\def\ry{{\textnormal{y}}}
```
```{=latex}
\def\rz{{\textnormal{z}}}
```
```{=latex}
\def\rvepsilon{{\mathbf{\epsilon}}}
```
```{=latex}
\def\rvtheta{{\mathbf{\theta}}}
```
```{=latex}
\def\rva{{\mathbf{a}}}
```
```{=latex}
\def\rvb{{\mathbf{b}}}
```
```{=latex}
\def\rvc{{\mathbf{c}}}
```
```{=latex}
\def\rvd{{\mathbf{d}}}
```
```{=latex}
\def\rve{{\mathbf{e}}}
```
```{=latex}
\def\rvf{{\mathbf{f}}}
```
```{=latex}
\def\rvg{{\mathbf{g}}}
```
```{=latex}
\def\rvh{{\mathbf{h}}}
```
```{=latex}
\def\rvi{{\mathbf{i}}}
```
```{=latex}
\def\rvj{{\mathbf{j}}}
```
```{=latex}
\def\rvk{{\mathbf{k}}}
```
```{=latex}
\def\rvl{{\mathbf{l}}}
```
```{=latex}
\def\rvm{{\mathbf{m}}}
```
```{=latex}
\def\rvn{{\mathbf{n}}}
```
```{=latex}
\def\rvo{{\mathbf{o}}}
```
```{=latex}
\def\rvp{{\mathbf{p}}}
```
```{=latex}
\def\rvq{{\mathbf{q}}}
```
```{=latex}
\def\rvr{{\mathbf{r}}}
```
```{=latex}
\def\rvs{{\mathbf{s}}}
```
```{=latex}
\def\rvt{{\mathbf{t}}}
```
```{=latex}
\def\rvu{{\mathbf{u}}}
```
```{=latex}
\def\rvv{{\mathbf{v}}}
```
```{=latex}
\def\rvw{{\mathbf{w}}}
```
```{=latex}
\def\rvx{{\mathbf{x}}}
```
```{=latex}
\def\rvy{{\mathbf{y}}}
```
```{=latex}
\def\rvz{{\mathbf{z}}}
```
```{=latex}
\def\erva{{\textnormal{a}}}
```
```{=latex}
\def\ervb{{\textnormal{b}}}
```
```{=latex}
\def\ervc{{\textnormal{c}}}
```
```{=latex}
\def\ervd{{\textnormal{d}}}
```
```{=latex}
\def\erve{{\textnormal{e}}}
```
```{=latex}
\def\ervf{{\textnormal{f}}}
```
```{=latex}
\def\ervg{{\textnormal{g}}}
```
```{=latex}
\def\ervh{{\textnormal{h}}}
```
```{=latex}
\def\ervi{{\textnormal{i}}}
```
```{=latex}
\def\ervj{{\textnormal{j}}}
```
```{=latex}
\def\ervk{{\textnormal{k}}}
```
```{=latex}
\def\ervl{{\textnormal{l}}}
```
```{=latex}
\def\ervm{{\textnormal{m}}}
```
```{=latex}
\def\ervn{{\textnormal{n}}}
```
```{=latex}
\def\ervo{{\textnormal{o}}}
```
```{=latex}
\def\ervp{{\textnormal{p}}}
```
```{=latex}
\def\ervq{{\textnormal{q}}}
```
```{=latex}
\def\ervr{{\textnormal{r}}}
```
```{=latex}
\def\ervs{{\textnormal{s}}}
```
```{=latex}
\def\ervt{{\textnormal{t}}}
```
```{=latex}
\def\ervu{{\textnormal{u}}}
```
```{=latex}
\def\ervv{{\textnormal{v}}}
```
```{=latex}
\def\ervw{{\textnormal{w}}}
```
```{=latex}
\def\ervx{{\textnormal{x}}}
```
```{=latex}
\def\ervy{{\textnormal{y}}}
```
```{=latex}
\def\ervz{{\textnormal{z}}}
```
```{=latex}
\def\rmA{{\mathbf{A}}}
```
```{=latex}
\def\rmB{{\mathbf{B}}}
```
```{=latex}
\def\rmC{{\mathbf{C}}}
```
```{=latex}
\def\rmD{{\mathbf{D}}}
```
```{=latex}
\def\rmE{{\mathbf{E}}}
```
```{=latex}
\def\rmF{{\mathbf{F}}}
```
```{=latex}
\def\rmG{{\mathbf{G}}}
```
```{=latex}
\def\rmH{{\mathbf{H}}}
```
```{=latex}
\def\rmI{{\mathbf{I}}}
```
```{=latex}
\def\rmJ{{\mathbf{J}}}
```
```{=latex}
\def\rmK{{\mathbf{K}}}
```
```{=latex}
\def\rmL{{\mathbf{L}}}
```
```{=latex}
\def\rmM{{\mathbf{M}}}
```
```{=latex}
\def\rmN{{\mathbf{N}}}
```
```{=latex}
\def\rmO{{\mathbf{O}}}
```
```{=latex}
\def\rmP{{\mathbf{P}}}
```
```{=latex}
\def\rmQ{{\mathbf{Q}}}
```
```{=latex}
\def\rmR{{\mathbf{R}}}
```
```{=latex}
\def\rmS{{\mathbf{S}}}
```
```{=latex}
\def\rmT{{\mathbf{T}}}
```
```{=latex}
\def\rmU{{\mathbf{U}}}
```
```{=latex}
\def\rmV{{\mathbf{V}}}
```
```{=latex}
\def\rmW{{\mathbf{W}}}
```
```{=latex}
\def\rmX{{\mathbf{X}}}
```
```{=latex}
\def\rmY{{\mathbf{Y}}}
```
```{=latex}
\def\rmZ{{\mathbf{Z}}}
```
```{=latex}
\def\ermA{{\textnormal{A}}}
```
```{=latex}
\def\ermB{{\textnormal{B}}}
```
```{=latex}
\def\ermC{{\textnormal{C}}}
```
```{=latex}
\def\ermD{{\textnormal{D}}}
```
```{=latex}
\def\ermE{{\textnormal{E}}}
```
```{=latex}
\def\ermF{{\textnormal{F}}}
```
```{=latex}
\def\ermG{{\textnormal{G}}}
```
```{=latex}
\def\ermH{{\textnormal{H}}}
```
```{=latex}
\def\ermI{{\textnormal{I}}}
```
```{=latex}
\def\ermJ{{\textnormal{J}}}
```
```{=latex}
\def\ermK{{\textnormal{K}}}
```
```{=latex}
\def\ermL{{\textnormal{L}}}
```
```{=latex}
\def\ermM{{\textnormal{M}}}
```
```{=latex}
\def\ermN{{\textnormal{N}}}
```
```{=latex}
\def\ermO{{\textnormal{O}}}
```
```{=latex}
\def\ermP{{\textnormal{P}}}
```
```{=latex}
\def\ermQ{{\textnormal{Q}}}
```
```{=latex}
\def\ermR{{\textnormal{R}}}
```
```{=latex}
\def\ermS{{\textnormal{S}}}
```
```{=latex}
\def\ermT{{\textnormal{T}}}
```
```{=latex}
\def\ermU{{\textnormal{U}}}
```
```{=latex}
\def\ermV{{\textnormal{V}}}
```
```{=latex}
\def\ermW{{\textnormal{W}}}
```
```{=latex}
\def\ermX{{\textnormal{X}}}
```
```{=latex}
\def\ermY{{\textnormal{Y}}}
```
```{=latex}
\def\ermZ{{\textnormal{Z}}}
```
```{=latex}
\def\vzero{{\bm{0}}}
```
```{=latex}
\def\vone{{\bm{1}}}
```
```{=latex}
\def\vmu{{\bm{\mu}}}
```
```{=latex}
\def\vtheta{{\bm{\theta}}}
```
```{=latex}
\def\va{{\bm{a}}}
```
```{=latex}
\def\vb{{\bm{b}}}
```
```{=latex}
\def\vc{{\bm{c}}}
```
```{=latex}
\def\vd{{\bm{d}}}
```
```{=latex}
\def\ve{{\bm{e}}}
```
```{=latex}
\def\vf{{\bm{f}}}
```
```{=latex}
\def\vg{{\bm{g}}}
```
```{=latex}
\def\vh{{\bm{h}}}
```
```{=latex}
\def\vi{{\bm{i}}}
```
```{=latex}
\def\vj{{\bm{j}}}
```
```{=latex}
\def\vk{{\bm{k}}}
```
```{=latex}
\def\vl{{\bm{l}}}
```
```{=latex}
\def\vm{{\bm{m}}}
```
```{=latex}
\def\vn{{\bm{n}}}
```
```{=latex}
\def\vo{{\bm{o}}}
```
```{=latex}
\def\vp{{\bm{p}}}
```
```{=latex}
\def\vq{{\bm{q}}}
```
```{=latex}
\def\vr{{\bm{r}}}
```
```{=latex}
\def\vs{{\bm{s}}}
```
```{=latex}
\def\vt{{\bm{t}}}
```
```{=latex}
\def\vu{{\bm{u}}}
```
```{=latex}
\def\vv{{\bm{v}}}
```
```{=latex}
\def\vw{{\bm{w}}}
```
```{=latex}
\def\vx{{\bm{x}}}
```
```{=latex}
\def\vy{{\bm{y}}}
```
```{=latex}
\def\vz{{\bm{z}}}
```
```{=latex}
\def\evalpha{{\alpha}}
```
```{=latex}
\def\evbeta{{\beta}}
```
```{=latex}
\def\evepsilon{{\epsilon}}
```
```{=latex}
\def\evlambda{{\lambda}}
```
```{=latex}
\def\evomega{{\omega}}
```
```{=latex}
\def\evmu{{\mu}}
```
```{=latex}
\def\evpsi{{\psi}}
```
```{=latex}
\def\evsigma{{\sigma}}
```
```{=latex}
\def\evtheta{{\theta}}
```
```{=latex}
\def\eva{{a}}
```
```{=latex}
\def\evb{{b}}
```
```{=latex}
\def\evc{{c}}
```
```{=latex}
\def\evd{{d}}
```
```{=latex}
\def\eve{{e}}
```
```{=latex}
\def\evf{{f}}
```
```{=latex}
\def\evg{{g}}
```
```{=latex}
\def\evh{{h}}
```
```{=latex}
\def\evi{{i}}
```
```{=latex}
\def\evj{{j}}
```
```{=latex}
\def\evk{{k}}
```
```{=latex}
\def\evl{{l}}
```
```{=latex}
\def\evm{{m}}
```
```{=latex}
\def\evn{{n}}
```
```{=latex}
\def\evo{{o}}
```
```{=latex}
\def\evp{{p}}
```
```{=latex}
\def\evq{{q}}
```
```{=latex}
\def\evr{{r}}
```
```{=latex}
\def\evs{{s}}
```
```{=latex}
\def\evt{{t}}
```
```{=latex}
\def\evu{{u}}
```
```{=latex}
\def\evv{{v}}
```
```{=latex}
\def\evw{{w}}
```
```{=latex}
\def\evx{{x}}
```
```{=latex}
\def\evy{{y}}
```
```{=latex}
\def\evz{{z}}
```
```{=latex}
\def\mA{{\bm{A}}}
```
```{=latex}
\def\mB{{\bm{B}}}
```
```{=latex}
\def\mC{{\bm{C}}}
```
```{=latex}
\def\mD{{\bm{D}}}
```
```{=latex}
\def\mE{{\bm{E}}}
```
```{=latex}
\def\mF{{\bm{F}}}
```
```{=latex}
\def\mG{{\bm{G}}}
```
```{=latex}
\def\mH{{\bm{H}}}
```
```{=latex}
\def\mI{{\bm{I}}}
```
```{=latex}
\def\mJ{{\bm{J}}}
```
```{=latex}
\def\mK{{\bm{K}}}
```
```{=latex}
\def\mL{{\bm{L}}}
```
```{=latex}
\def\mM{{\bm{M}}}
```
```{=latex}
\def\mN{{\bm{N}}}
```
```{=latex}
\def\mO{{\bm{O}}}
```
```{=latex}
\def\mP{{\bm{P}}}
```
```{=latex}
\def\mQ{{\bm{Q}}}
```
```{=latex}
\def\mR{{\bm{R}}}
```
```{=latex}
\def\mS{{\bm{S}}}
```
```{=latex}
\def\mT{{\bm{T}}}
```
```{=latex}
\def\mU{{\bm{U}}}
```
```{=latex}
\def\mV{{\bm{V}}}
```
```{=latex}
\def\mW{{\bm{W}}}
```
```{=latex}
\def\mX{{\bm{X}}}
```
```{=latex}
\def\mY{{\bm{Y}}}
```
```{=latex}
\def\mZ{{\bm{Z}}}
```
```{=latex}
\def\mBeta{{\bm{\beta}}}
```
```{=latex}
\def\mPhi{{\bm{\Phi}}}
```
```{=latex}
\def\mLambda{{\bm{\Lambda}}}
```
```{=latex}
\def\mSigma{{\bm{\Sigma}}}
```
```{=latex}
\newcommand{\tens}[1]{\bm{\mathsfit{#1}}}
```
```{=latex}
\def\tA{{\tens{A}}}
```
```{=latex}
\def\tB{{\tens{B}}}
```
```{=latex}
\def\tC{{\tens{C}}}
```
```{=latex}
\def\tD{{\tens{D}}}
```
```{=latex}
\def\tE{{\tens{E}}}
```
```{=latex}
\def\tF{{\tens{F}}}
```
```{=latex}
\def\tG{{\tens{G}}}
```
```{=latex}
\def\tH{{\tens{H}}}
```
```{=latex}
\def\tI{{\tens{I}}}
```
```{=latex}
\def\tJ{{\tens{J}}}
```
```{=latex}
\def\tK{{\tens{K}}}
```
```{=latex}
\def\tL{{\tens{L}}}
```
```{=latex}
\def\tM{{\tens{M}}}
```
```{=latex}
\def\tN{{\tens{N}}}
```
```{=latex}
\def\tO{{\tens{O}}}
```
```{=latex}
\def\tP{{\tens{P}}}
```
```{=latex}
\def\tQ{{\tens{Q}}}
```
```{=latex}
\def\tR{{\tens{R}}}
```
```{=latex}
\def\tS{{\tens{S}}}
```
```{=latex}
\def\tT{{\tens{T}}}
```
```{=latex}
\def\tU{{\tens{U}}}
```
```{=latex}
\def\tV{{\tens{V}}}
```
```{=latex}
\def\tW{{\tens{W}}}
```
```{=latex}
\def\tX{{\tens{X}}}
```
```{=latex}
\def\tY{{\tens{Y}}}
```
```{=latex}
\def\tZ{{\tens{Z}}}
```
```{=latex}
\def\gA{{\mathcal{A}}}
```
```{=latex}
\def\gB{{\mathcal{B}}}
```
```{=latex}
\def\gC{{\mathcal{C}}}
```
```{=latex}
\def\gD{{\mathcal{D}}}
```
```{=latex}
\def\gE{{\mathcal{E}}}
```
```{=latex}
\def\gF{{\mathcal{F}}}
```
```{=latex}
\def\gG{{\mathcal{G}}}
```
```{=latex}
\def\gH{{\mathcal{H}}}
```
```{=latex}
\def\gI{{\mathcal{I}}}
```
```{=latex}
\def\gJ{{\mathcal{J}}}
```
```{=latex}
\def\gK{{\mathcal{K}}}
```
```{=latex}
\def\gL{{\mathcal{L}}}
```
```{=latex}
\def\gM{{\mathcal{M}}}
```
```{=latex}
\def\gN{{\mathcal{N}}}
```
```{=latex}
\def\gO{{\mathcal{O}}}
```
```{=latex}
\def\gP{{\mathcal{P}}}
```
```{=latex}
\def\gQ{{\mathcal{Q}}}
```
```{=latex}
\def\gR{{\mathcal{R}}}
```
```{=latex}
\def\gS{{\mathcal{S}}}
```
```{=latex}
\def\gT{{\mathcal{T}}}
```
```{=latex}
\def\gU{{\mathcal{U}}}
```
```{=latex}
\def\gV{{\mathcal{V}}}
```
```{=latex}
\def\gW{{\mathcal{W}}}
```
```{=latex}
\def\gX{{\mathcal{X}}}
```
```{=latex}
\def\gY{{\mathcal{Y}}}
```
```{=latex}
\def\gZ{{\mathcal{Z}}}
```
```{=latex}
\def\sA{{\mathbb{A}}}
```
```{=latex}
\def\sB{{\mathbb{B}}}
```
```{=latex}
\def\sC{{\mathbb{C}}}
```
```{=latex}
\def\sD{{\mathbb{D}}}
```
```{=latex}
\def\sF{{\mathbb{F}}}
```
```{=latex}
\def\sG{{\mathbb{G}}}
```
```{=latex}
\def\sH{{\mathbb{H}}}
```
```{=latex}
\def\sI{{\mathbb{I}}}
```
```{=latex}
\def\sJ{{\mathbb{J}}}
```
```{=latex}
\def\sK{{\mathbb{K}}}
```
```{=latex}
\def\sL{{\mathbb{L}}}
```
```{=latex}
\def\sM{{\mathbb{M}}}
```
```{=latex}
\def\sN{{\mathbb{N}}}
```
```{=latex}
\def\sO{{\mathbb{O}}}
```
```{=latex}
\def\sP{{\mathbb{P}}}
```
```{=latex}
\def\sQ{{\mathbb{Q}}}
```
```{=latex}
\def\sR{{\mathbb{R}}}
```
```{=latex}
\def\sS{{\mathbb{S}}}
```
```{=latex}
\def\sT{{\mathbb{T}}}
```
```{=latex}
\def\sU{{\mathbb{U}}}
```
```{=latex}
\def\sV{{\mathbb{V}}}
```
```{=latex}
\def\sW{{\mathbb{W}}}
```
```{=latex}
\def\sX{{\mathbb{X}}}
```
```{=latex}
\def\sY{{\mathbb{Y}}}
```
```{=latex}
\def\sZ{{\mathbb{Z}}}
```
```{=latex}
\def\emLambda{{\Lambda}}
```
```{=latex}
\def\emA{{A}}
```
```{=latex}
\def\emB{{B}}
```
```{=latex}
\def\emC{{C}}
```
```{=latex}
\def\emD{{D}}
```
```{=latex}
\def\emE{{E}}
```
```{=latex}
\def\emF{{F}}
```
```{=latex}
\def\emG{{G}}
```
```{=latex}
\def\emH{{H}}
```
```{=latex}
\def\emI{{I}}
```
```{=latex}
\def\emJ{{J}}
```
```{=latex}
\def\emK{{K}}
```
```{=latex}
\def\emL{{L}}
```
```{=latex}
\def\emM{{M}}
```
```{=latex}
\def\emN{{N}}
```
```{=latex}
\def\emO{{O}}
```
```{=latex}
\def\emP{{P}}
```
```{=latex}
\def\emQ{{Q}}
```
```{=latex}
\def\emR{{R}}
```
```{=latex}
\def\emS{{S}}
```
```{=latex}
\def\emT{{T}}
```
```{=latex}
\def\emU{{U}}
```
```{=latex}
\def\emV{{V}}
```
```{=latex}
\def\emW{{W}}
```
```{=latex}
\def\emX{{X}}
```
```{=latex}
\def\emY{{Y}}
```
```{=latex}
\def\emZ{{Z}}
```
```{=latex}
\def\emSigma{{\Sigma}}
```
```{=latex}
\newcommand{\etens}[1]{\mathsfit{#1}}
```
```{=latex}
\def\etLambda{{\etens{\Lambda}}}
```
```{=latex}
\def\etA{{\etens{A}}}
```
```{=latex}
\def\etB{{\etens{B}}}
```
```{=latex}
\def\etC{{\etens{C}}}
```
```{=latex}
\def\etD{{\etens{D}}}
```
```{=latex}
\def\etE{{\etens{E}}}
```
```{=latex}
\def\etF{{\etens{F}}}
```
```{=latex}
\def\etG{{\etens{G}}}
```
```{=latex}
\def\etH{{\etens{H}}}
```
```{=latex}
\def\etI{{\etens{I}}}
```
```{=latex}
\def\etJ{{\etens{J}}}
```
```{=latex}
\def\etK{{\etens{K}}}
```
```{=latex}
\def\etL{{\etens{L}}}
```
```{=latex}
\def\etM{{\etens{M}}}
```
```{=latex}
\def\etN{{\etens{N}}}
```
```{=latex}
\def\etO{{\etens{O}}}
```
```{=latex}
\def\etP{{\etens{P}}}
```
```{=latex}
\def\etQ{{\etens{Q}}}
```
```{=latex}
\def\etR{{\etens{R}}}
```
```{=latex}
\def\etS{{\etens{S}}}
```
```{=latex}
\def\etT{{\etens{T}}}
```
```{=latex}
\def\etU{{\etens{U}}}
```
```{=latex}
\def\etV{{\etens{V}}}
```
```{=latex}
\def\etW{{\etens{W}}}
```
```{=latex}
\def\etX{{\etens{X}}}
```
```{=latex}
\def\etY{{\etens{Y}}}
```
```{=latex}
\def\etZ{{\etens{Z}}}
```
```{=latex}
\newcommand{\pdata}{p_{\rm{data}}}
```
```{=latex}
\newcommand{\ptrain}{\hat{p}_{\rm{data}}}
```
```{=latex}
\newcommand{\Ptrain}{\hat{P}_{\rm{data}}}
```
```{=latex}
\newcommand{\pmodel}{p_{\rm{model}}}
```
```{=latex}
\newcommand{\Pmodel}{P_{\rm{model}}}
```
```{=latex}
\newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}}
```
```{=latex}
\newcommand{\pencode}{p_{\rm{encoder}}}
```
```{=latex}
\newcommand{\pdecode}{p_{\rm{decoder}}}
```
```{=latex}
\newcommand{\precons}{p_{\rm{reconstruct}}}
```
```{=latex}
\newcommand{\laplace}{\mathrm{Laplace}}
```
```{=latex}
\newcommand{\E}{\mathbb{E}}
```
```{=latex}
\newcommand{\Ls}{\mathcal{L}}
```
```{=latex}
\newcommand{\R}{\mathbb{R}}
```
```{=latex}
\newcommand{\N}{\mathbb{N}}
```
```{=latex}
\newcommand{\emp}{\tilde{p}}
```
```{=latex}
\newcommand{\lr}{\alpha}
```
```{=latex}
\newcommand{\reg}{\lambda}
```
```{=latex}
\newcommand{\rect}{\mathrm{rectifier}}
```
```{=latex}
\newcommand{\softmax}{\mathrm{softmax}}
```
```{=latex}
\newcommand{\sigmoid}{\sigma}
```
```{=latex}
\newcommand{\softplus}{\zeta}
```
```{=latex}
\newcommand{\KL}{D_{\mathrm{KL}}}
```
```{=latex}
\newcommand{\Var}{\mathrm{Var}}
```
```{=latex}
\newcommand{\standarderror}{\mathrm{SE}}
```
```{=latex}
\newcommand{\Cov}{\mathrm{Cov}}
```
```{=latex}
\newcommand{\normlzero}{L^0}
```
```{=latex}
\newcommand{\normlone}{L^1}
```
```{=latex}
\newcommand{\normltwo}{L^2}
```
```{=latex}
\newcommand{\normlp}{L^p}
```
```{=latex}
\newcommand{\normmax}{L^\infty}
```
```{=latex}
\newcommand{\parents}{Pa}
```
```{=latex}
\newcommand{\LN}{\mathrm{LN}}
```
```{=latex}
\newcommand{\RMS}{\mathrm{RMS}}
```
```{=latex}
\newcommand{\MLP}{\mathrm{MLP}}
```
```{=latex}
\newcommand{\Attn}{\mathrm{Attn}}
```
```{=latex}
\newcommand{\lnorm}[2]{\Vert #1 \Vert_{#2}}
```
```{=latex}
\newcommand{\ltwonorm}[1]{\lnorm{#1}{2}}
```
```{=latex}
\newcommand{\innerp}[2]{\langle{#1, #2}\rangle}
```
```{=latex}
\newcommand{\pluseq}{\mathrel{\raisebox{0.19ex}{$\scriptstyle+$}}=}
```
```{=latex}
\newcommand{\defeq}{\triangleq}
```
```{=latex}
\DeclareMathOperator{\sign}{sign}
```
```{=latex}
\DeclareMathOperator{\Tr}{Tr}
```
```{=latex}
\newcommand{\balpha}{\boldsymbol \alpha}
```
```{=latex}
\newcommand{\bbeta}{\boldsymbol \beta}
```
```{=latex}
\newcommand{\bgamma}{\boldsymbol \gamma}
```
```{=latex}
\renewcommand{\hat}{\widehat}
```
```{=latex}
\renewcommand{\tilde}{\widetilde}
```
```{=latex}
\renewcommand{\>}{{\rightarrow}}
```
```{=latex}
\renewcommand{\=}{\stackrel{\triangle}{=}}
```
```{=latex}
\newcommand{\half}{\textstyle{\frac{1}{2}}}
```
```{=latex}
\newcommand{\grad}{\nabla}
```
```{=latex}
\newcommand{\conv}{\operatorname{conv}}
```
```{=latex}
\newcommand{\block}{\mathrm{block}}
```
```{=latex}
\newcommand{\mixer}{\mathrm{mixer}}
```
```{=latex}
\newcommand{\sampler}{\mathrm{sampler}}
```
```{=latex}
\newcommand{\agg}{\mathrm{agg}}
```
```{=latex}
\newcommand{\frmstate}{\mathrm{read}}
```
```{=latex}
\newcommand{\cont}{\mathrm{cont}}
```
```{=latex}
\newcommand{\poly}{\operatorname{poly}}
```
```{=latex}
\newcommand{\rank}{\operatorname{rank}}
```
```{=latex}
\newcommand{\support}{\operatorname{support}}
```
```{=latex}
\newcommand{\bX}{{\mathbf X}}
```
```{=latex}
\newcommand{\z}{{\mathbf z}}
```
```{=latex}
\newcommand{\x}{{\mathbf x}}
```
```{=latex}
\newcommand{\g}{{\mathbf g}}
```
```{=latex}
\newcommand{\ba}{{\mathbf a}}
```
```{=latex}
\newcommand{\methodname}{Recurrent Transformer}
```
```{=latex}
\newcommand{\methodnames}{Recurrent Transformers}
```
```{=latex}
\newcommand{\methodnameabv}{\textsc{RT}}
```
```{=latex}
\newcommand{\qkRMS}{\mathrm{\textcolor{magenta}{RMS}}}
```
```{=latex}
\newcommand{\temp}{\textcolor{red}{\mathrm{temp}}}
```
```{=latex}
\newcommand{\zmark}[1]{\textcolor{blue!55!black}{#1}}
```
```{=latex}
\maketitle
```
Introduction
============

Transformers [@vaswani2017attention] are highly effective sequence models, but their computation across positions is structurally shallow: within each layer, position $t$ attends to key--value pairs computed from the previous layer embeddings, allowing essentially at most one interaction per layer between any pair of positions. A growing body of theory studies the fundamental limitations implied by bounded depth in attention models, including circuit-complexity characterizations of what low-depth Transformers can and cannot represent [@merrill2022saturated; @liu2022shortcuts]. These perspectives motivate architectures that achieve greater effective depth.

We introduce the `\newterm{\methodname}`{=latex} (`\methodnameabv`{=latex}), a simple modification of how key--value pairs are computed that makes each layer temporally recurrent. In a standard Transformer, at layer $\ell$, the key--value pair at position $t$ is computed from the layer-$(\ell-1)$ representation at that position and can then be attended to by later positions $t' > t$. In the `\methodname`{=latex}, by contrast, the key--value pair at position $t$ in layer $\ell$ is computed from that position's output at layer $\ell$, rather than from its layer-$(\ell-1)$ representation. Consequently, a later position $t' > t$ at layer $\ell$ attends to a representation at $t$ that already reflects layer-$\ell$ attention and MLP computation. Importantly, `\methodname{}`{=latex} performs this recurrence separately within each layer, so each layer maintains its own key--value memory. This differs from the Feedback Transformer [@fan2020feedback], which uses a shared memory across layers, and this layerwise separation is a key reason why our architecture can be implemented efficiently.

We motivate `\methodname{}`{=latex}'s design through the lenses of representation, optimization, and computational efficiency:

```{=latex}
\begin{figure*}[t]
  
  \includegraphics[width=\textwidth]{diagrams/rt_flow.pdf}
  \caption{One layer of the \methodname{} mapping input embeddings $\vx_1,\ldots,\vx_4$ to output embeddings $\vz_1,\ldots,\vz_4$. The \emph{persistent} key--value pairs are a function of the layer's output and are used for all subsequent attention computations. The \emph{temporary} key--value pairs are used only at the time they are computed and are then discarded; they exist to avoid ill-defined attention since, for example, $\va_2$ cannot attend to $(\vk_2, \vv_2)$, which itself depends on $\va_2$. This is in contrast to a vanilla Transformer, which uses these same input-derived key--value pairs for all subsequent attention computations as well.}
  \label{fig:rt-arch}
\end{figure*}
```
#### (i) Representational perspective.

`\methodnames{}`{=latex} retain per-token key--value memory just like a Transformer, but increase the space of computations that can be expressed within a single layer by allowing later positions to attend to representations that have already undergone attention and MLP processing. Under mild assumptions, `\methodnames{}`{=latex} can emulate standard Transformer behavior; conversely, by restricting attention to the previous position, they can implement token-to-token recurrent computation. This positions `\methodname{}`{=latex} between fully parallel attention and purely recurrent state-space computation, while avoiding a capped-memory bottleneck.

#### (ii) Training stability.

Viewing the model as a directed computation graph over positions, classical RNNs transmit information from position $i$ to $j$ only through the length-$(j-i)$ chain of intermediate states. The potentially large length of such paths gives rise to vanishing and exploding gradient phenomena [@bengio1994long; @pascanu2013difficulty], making it hard to ensure information flow between distant positions. `\methodname{}`{=latex} alleviates this by introducing many additional multi-hop paths, corresponding to repeated attend+MLP applications across positions within a layer, while still permitting direct one-hop attention interactions between any two positions. In practice, we find that this architecture, together with appropriate normalization before key--value computation and standard depth-wise residual scaling [@bordelon2023depthwise; @yang2023tensor], trains stably. We expand on this view, and on why exploding gradients are not expected to be an issue, in Section `\ref{sec:paths}`{=latex}.
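To make the path picture concrete, the following minimal sketch counts directed paths in two idealized dependency graphs over positions: a chain, where each position receives only from its predecessor, and a dense causal graph, where each position receives from every earlier one. The function name `count_paths` and both graphs are illustrative assumptions and are not taken from the paper's analysis.

```python
def count_paths(n_positions, edges_from):
    """Count directed paths from position 0 to every position.

    edges_from(t) returns the earlier positions s < t that feed into t.
    """
    paths = [0] * n_positions
    paths[0] = 1  # the empty path from position 0 to itself
    for t in range(1, n_positions):
        paths[t] = sum(paths[s] for s in edges_from(t))
    return paths


# Chain (idealized classical RNN): position t receives only from t - 1.
rnn_paths = count_paths(6, lambda t: [t - 1])

# Dense causal graph (idealized same-layer recurrence with attention):
# position t receives from every earlier position s < t.
dense_paths = count_paths(6, lambda t: range(t))

print(rnn_paths)    # a single path to each position
print(dense_paths)  # the count doubles with each additional position
```

In the chain, the only route from position $0$ to position $t$ has length $t$. In the dense graph there are $2^{t-1}$ routes for $t\ge 1$, including the direct one-hop edge.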

#### (iii) Training-time efficiency.

A naive implementation of `\methodname{}`{=latex} training/prefill is sequential in position and appears bandwidth-bound: keys and values are revealed one position at a time, and each query must aggregate over a linearly growing prefix, leading to a very low effective arithmetic intensity of $\Theta(1)$ under the Roofline model [@williams2009roofline]. We give an *exact* tiling algorithm that preserves the mathematical attention computation while reorganizing memory movement, reducing high-bandwidth memory (HBM) traffic from $\Theta(N^2)$ to $\Theta(N\log N)$ and raising effective arithmetic intensity to $\Theta(N/\log N)$. Our key observation is that, during training/prefill, the full sequence of queries is available in advance even though persistent key--value pairs are revealed causally. This makes it possible to reorganize the computation into a tiled schedule, in the spirit of Flash Inference [@flashInference], that reuses each revealed key--value tile across many future queries before it is evicted from fast memory. The final algorithm interleaves attention blocks and MLP computation while employing the same methodology as [@rabe2021blockAttention; @dao2022flashattention] to accumulate attention contributions.
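The accumulation step can be illustrated in isolation. The sketch below is not the paper's tiling schedule; it only shows the standard running-max, running-normalizer trick for a single query, which lets key--value tiles be consumed one at a time without changing the result. The function names, the tile size, and the toy data are assumptions made for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attn_reference(keys, values, q):
    """Softmax attention computed in one pass over all key-value pairs."""
    scores = [dot(k, q) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / z for d in range(dim)]


def attn_tiled(keys, values, q, tile):
    """Same result, visiting the keys and values one tile at a time."""
    dim = len(values[0])
    m = -math.inf      # running maximum score
    z = 0.0            # running normalizer, expressed relative to m
    acc = [0.0] * dim  # running unnormalized output, relative to m
    for start in range(0, len(keys), tile):
        ks = keys[start:start + tile]
        vs = values[start:start + tile]
        scores = [dot(k, q) for k in ks]
        m_new = max(m, max(scores))
        rescale = math.exp(m - m_new)  # equals 0.0 on the first tile
        weights = [math.exp(s - m_new) for s in scores]
        z = z * rescale + sum(weights)
        acc = [a * rescale + sum(w * v[d] for w, v in zip(weights, vs))
               for d, a in enumerate(acc)]
        m = m_new
    return [a / z for a in acc]


keys = [[0.1 * i, -0.2 * i] for i in range(1, 8)]
values = [[float(i), 1.0] for i in range(1, 8)]
query = [1.0, -0.5]
full = attn_reference(keys, values, query)
tiled = attn_tiled(keys, values, query, tile=3)
```

Because only `m`, `z`, and `acc` persist between tiles, a tile of keys and values can be loaded once, applied to every query that needs it, and then dropped.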

#### (iv) Depth to inference efficiency.

Crucially, the additional effective temporal depth can translate into a better depth--width tradeoff: at fixed parameter count, achieving the same quality with fewer layers reduces the amount of stored key--value state and the corresponding decode-time memory traffic. Our experiments support this regime, with shallower `\methodname{}`{=latex} models outperforming deeper Transformer baselines.

#### Contributions.

-   In Section `\ref{sec:arch}`{=latex}, we propose the `\methodname{}`{=latex} (`\methodnameabv`{=latex}), a layerwise recurrent attention architecture that computes each layer's key--value pairs from that layer's outputs rather than from the previous layer's representations.

-   In Section `\ref{sec:repr}`{=latex}, we provide representational arguments showing `\methodname{}`{=latex} can emulate standard self-attention behavior and can implement token-to-token recurrent computation via attention concentration under mild assumptions.

-   In Section `\ref{sec:paths}`{=latex}, we provide a path-based analysis of training stability in `\methodname{}`{=latex}, showing how the architecture combines additional multi-hop computation with direct one-hop attention paths, and giving theoretical evidence in a simplified setting that neither exploding gradients nor vanishing gradients are expected under appropriate scaling.

-   In Section `\ref{sec:tiling}`{=latex}, we provide an *exact*, IO-aware tiling algorithm for prefill/training that preserves the mathematical attention computation while reducing memory traffic from $\Theta(N^2)$ to $\Theta(N\log N)$ and increasing effective arithmetic intensity from $\Theta(1)$ to $\Theta(N/\log N)$.

-   In Section `\ref{sec:comp_challenges}`{=latex}, we outline various computational challenges and design choices required to make `\methodname{}`{=latex} training more efficient and practical.

-   In Section `\ref{sec:expts}`{=latex}, we present empirical results on 300M-parameter C4 pretraining showing improved cross-entropy over parameter-matched Transformer baselines and favorable depth--width tradeoffs at fixed parameter count (as shown in Figure `\ref{fig:c4-300m}`{=latex}). In particular, a $6$-layer `\methodname{}`{=latex} performs comparably to its $12$-layer counterpart at fixed parameter count, reducing KV cache size by approximately 30% and lowering decode-time memory traffic, thereby improving inference efficiency. Additional results for the 150M-parameter model are provided in Appendix `\ref{app:150m-pretrain}`{=latex}.
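As a back-of-the-envelope check on the KV cache figure, the sketch below assumes that the cache holds one key vector and one value vector of size equal to the model width per layer and per token, with identical precision and no grouped- or multi-query sharing. Under that assumption, the per-token cache size is proportional to layers times width. The helper name `kv_entries_per_token` is hypothetical; the layer counts and widths are the ones in the 300M table.

```python
def kv_entries_per_token(layers, width):
    # Assumption: one key and one value vector of size `width` per layer.
    return 2 * layers * width


rt6 = kv_entries_per_token(6, 2048)
deep12 = kv_entries_per_token(12, 1408)
deep24 = kv_entries_per_token(24, 1024)

saving_vs_12 = 1 - rt6 / deep12  # roughly 0.27
saving_vs_24 = 1 - rt6 / deep24  # exactly 0.5

print(f"6 vs 12 layers: {saving_vs_12:.1%} fewer cached entries per token")
print(f"6 vs 24 layers: {saving_vs_24:.1%} fewer cached entries per token")
```

With these assumptions, moving from 12 layers at width 1408 to 6 layers at width 2048 removes a little over a quarter of the cached entries per token, in line with the approximate figure quoted above.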

![C4 pretraining: loss curves for the 300M-parameter model trained on the C4 dataset.](plots/best_runs_val_300m_512.png){#fig:c4-300m width="\\linewidth"}

```{=latex}
\hfill
```
```{=latex}
\captionof{table}{C4 pretraining loss for 300M parameter model.}
```
`\label{tab:c4-512-300m}`{=latex}

  Model                      Layers   Width    Val CE $\downarrow$
  ------------------------- -------- -------- ---------------------
  Transformer                  6      $2048$         $2.917$
  Transformer                  12     $1408$         $2.896$
  Transformer                  24     $1024$         $2.892$
  `\methodname{}`{=latex}      12     $1408$         $2.867$
  `\methodname{}`{=latex}      6      $2048$         $2.86$

Architectural overview and notation {#sec:arch}
===================================

#### Architectural overview.

Relative to a standard causal Transformer, the defining change in `\methodname{}`{=latex} is where the key--value pairs exposed to future positions come from. In a standard Transformer, the key--value pair at position $i$ is computed from the layer input at that position. In `\methodname{}`{=latex}, by contrast, the *persistent* key--value pair at position $i$ is computed from that position's layer output. Consequently, later positions attend to earlier positions whose representations have already undergone same-layer attention and MLP computation, making each layer recurrent along the temporal axis.

This creates a circularity at the current position: because the layer output at position $i$ also attends to the current position, the persistent pair $(\vk_i,\vv_i)$ cannot itself be used while computing that output. To resolve this, `\methodname{}`{=latex} distinguishes between two kinds of key--value pairs. A *temporary* pair, computed from the current layer input, is used only when evaluating attention at the current position. A *persistent* pair, computed from the resulting layer output, is then stored and made available to all later positions.
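The distinction between temporary and persistent pairs can be sketched in a few lines. This is a deliberately stripped-down toy and not the layer defined in this paper: it assumes identity query, key, and value projections, omits normalization and residual scaling, and uses an elementwise `tanh` as a stand-in MLP. The names `toy_layer` and `toy_mlp` are hypothetical.

```python
import math


def attn(pairs, q):
    """Softmax attention over a list of (key, value) pairs for one query."""
    scores = [sum(a * b for a, b in zip(k, q)) for k, _ in pairs]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    dim = len(pairs[0][1])
    return [sum(w * v[d] for w, (_, v) in zip(weights, pairs)) / z
            for d in range(dim)]


def toy_mlp(x):
    # Stand-in for the MLP block (assumption, chosen only for brevity).
    return [math.tanh(c) for c in x]


def toy_layer(xs, recurrent):
    stored = []  # key-value pairs visible to later positions
    outputs = []
    for x in xs:
        temporary = (x, x)  # built from the layer input at this position
        a = attn(stored + [temporary], x)
        h = [xi + ai for xi, ai in zip(x, a)]
        z = [hi + mi for hi, mi in zip(h, toy_mlp(h))]
        # Recurrent variant: persist a pair built from the layer output z.
        # Standard variant: keep the input-based pair for later positions.
        stored.append((z, z) if recurrent else temporary)
        outputs.append(z)
    return outputs


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
standard = toy_layer(xs, recurrent=False)
recurrent = toy_layer(xs, recurrent=True)
```

At the first position both variants see only the temporary pair, so their outputs coincide. From the second position onward, the recurrent variant attends to pairs derived from earlier outputs, and the results diverge.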

#### Notation.

We present the single-head formulation; multihead attention applies the same construction independently per head and then uses the usual output projection. We assume a sequence length of $N$ and use $L$ for the number of stacked layers. Let $D$ be the embedding dimension and consider a single layer with inputs $\vx_1,\ldots,\vx_N\in\R^D$. Let $\MLP:\R^D\to\R^D$ denote the MLP block and let $\RMS:\R^D\to\R^D$ denote Root Mean Square normalization [@RMSNorm]. While in practice we use learnable parameters, for presentation and analysis we take $\RMS(\vx)=\sqrt{D} \cdot \vx / \ltwonorm{\vx}$. We use (magenta) $\qkRMS$ to distinguish query/key normalization [@dehghani2023scalingQKNorm].
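A minimal sketch of the parameter-free normalization used for presentation follows. The function name `rms_norm` is an assumption, and the sketch omits both the learnable gain and the small epsilon that practical implementations usually add to guard against a zero vector.

```python
import math


def rms_norm(x):
    """Parameter-free RMS normalization: sqrt(D) * x / ||x||_2."""
    d = len(x)
    norm = math.sqrt(sum(c * c for c in x))
    return [math.sqrt(d) * c / norm for c in x]


y = rms_norm([3.0, 4.0])
y_scaled = rms_norm([30.0, 40.0])  # rescaling the input leaves the output unchanged
mean_square = sum(c * c for c in y) / len(y)
```

The output always has mean square $1$, equivalently Euclidean norm $\sqrt{D}$, regardless of the scale of the input.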

The attention operator $\Attn:(\R^D\times\R^D)^{*}\times \R^D\to\R^D$ maps a sequence of key--value pairs $(\vk_1,\vv_1),\ldots,(\vk_\ell,\vv_\ell)$ and a query $\vq$ to $$\begin{aligned}
\Attn\big((\vk_1,\vv_1),\ldots,(\vk_\ell,\vv_\ell),\vq\big) = \sum_{i=1}^\ell \vv_i \cdot
\frac{\exp(\innerp{\vk_i}{\vq})}
{\sum_{j=1}^\ell \exp(\innerp{\vk_j}{\vq})}.
\end{aligned}$$

We use projection matrices $Q,K,V\in\R^{D\times D}$ to compute queries, keys, and values from an embedding. Following standard Transformer parameterizations [@bordelon2023depthwise; @yang2023tensor], we use pre-LN [@xiong2020layerPreLN] and assume attention and MLP residual updates are initialized/parameterized with an appropriate $1/\sqrt{L}$ scale, so that chaining maps of the form $\vx\mapsto \vx+ \frac{1}{\sqrt{L}} \{\Attn,\MLP\}(\RMS(\vx))$ is well-behaved.

The Transformer layer {#sec:baseline-layer}
---------------------

We first recall a standard *causal* decoder-only Transformer layer [@vaswani2017attention]. Given inputs $\vx_1,\ldots,\vx_N\in\R^D$, position $i$ forms its query, key, and value from the current layer input: $$\begin{aligned}
\vq_i &= \qkRMS[Q\,\RMS(\vx_i)], \\
\vk_i &= \qkRMS[K\,\RMS(\vx_i)], \\
\vv_i &= V\,\RMS(\vx_i).\end{aligned}$$ The attention output at position $i$ is then computed by attending over the prefix of key--value pairs available up to that position: $$\begin{aligned}
\va_i &= \Attn\big((\vk_1,\vv_1),\ldots,(\vk_i,\vv_i),\vq_i\big).\end{aligned}$$ Finally, the layer output is obtained by adding the attention and MLP residual branches: $$\begin{aligned}
\vy_i &= \vx_i + \frac{1}{\sqrt{L}}\left(\va_i + \MLP[\RMS(\vx_i+\frac{1}{\sqrt{L}}\va_i)]\right).\end{aligned}$$

The key structural point is that, in a standard Transformer, the key--value pair stored at position $i$ is computed from the layer input at the same position.

The `\methodname{}`{=latex} layer {#sec:rt-layer}
---------------------------------

`\methodname{}`{=latex} layers (illustrated in Figure `\ref{fig:rt-arch}`{=latex}) differ from standard Transformer layers only in how the key--value pairs exposed to future positions are formed. At position $i$, `\methodname{}`{=latex} first forms the query together with a *temporary* key--value pair from the current layer input: $$\begin{aligned}
\vq_i &= \qkRMS[Q\,\RMS(\vx_i)], \\
\vk_i^{\temp} &= \qkRMS[K\,\RMS(\vx_i)], \\
\vv_i^{\temp} &= V\,\RMS(\vx_i).\end{aligned}$$ These definitions are identical to the Transformer's query, key, and value projections at position $i$. The attention output at position $i$ is then computed using the persistent key--value pairs from earlier positions together with the temporary pair at the current position: $$\begin{aligned}
\va_i
&=
\Attn\big(
(\vk_1,\vv_1),\ldots,(\vk_{i-1},\vv_{i-1}),
(\vk_i^{\temp},\vv_i^{\temp}),
\vq_i
\big).\end{aligned}$$ We next form the layer output representation $$\begin{aligned}
\vz_i &= \vx_i + \frac{1}{\sqrt{L}}\left(\va_i + \MLP[\RMS(\vx_i+\frac{1}{\sqrt{L}}\va_i)]\right),\end{aligned}$$ which is both the representation passed to the next layer and the source from which the persistent key--value pair at position $i$ is computed. We define that persistent pair by projecting from this output: $$\begin{aligned}
\label{eq:persistentK}
\vk_i &= \qkRMS[K\,\RMS(\vz_i)], \\
\label{eq:persistentV}
\vv_i &= V\,\RMS(\vz_i).\end{aligned}$$

The temporary pair $(\vk_i^{\temp}, \vv_i^{\temp})$ is thus used only to compute attention at position $i$; it is not exposed to future positions. The persistent pair, by contrast, is defined only after $\vz_i$ has been formed and is then stored for use by all later positions. Unlike in a standard Transformer, future positions therefore attend not to a pair computed from the layer input at position $i$, but to one computed from the already-updated representation $\vz_i$.

We reuse the same projection matrices $K$ and $V$ for both the temporary and persistent key--value pairs. Consequently, `\methodname{}`{=latex} does not introduce additional key/value projection parameters relative to a Transformer; this reuse also preserves a shared semantics between the temporary and persistent key--value representations.
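To make the data flow concrete, the following NumPy sketch evaluates one single-head layer position by position. It is an illustration under stated assumptions, not the released implementation: the two-layer ReLU MLP, the random weights, the helper names (`rms`, `attn`, `layer_forward`) and the flag `persistent_from_output` are all hypothetical, and both $\RMS$ and $\qkRMS$ are taken to be the parameter-free normalization from the notation paragraph.

```python
import numpy as np

def rms(x):
    # Parameter-free RMS norm: sqrt(D) * x / ||x||_2 (stands in for RMS and qkRMS).
    return np.sqrt(x.shape[-1]) * x / np.linalg.norm(x, axis=-1, keepdims=True)

def attn(keys, vals, q):
    # Softmax attention of one query over a stack of key-value pairs.
    s = keys @ q
    w = np.exp(s - s.max())                      # subtract the max for stability
    return (w / w.sum()) @ vals

def layer_forward(x, p, L, persistent_from_output=True):
    """Sequential reference for one layer; x has shape (N, D).

    persistent_from_output=True  -> Recurrent Transformer layer (KV from z_i)
    persistent_from_output=False -> standard causal Transformer layer (KV from x_i)
    """
    N, D = x.shape
    k, v, z = np.zeros((N, D)), np.zeros((N, D)), np.zeros((N, D))
    for i in range(N):
        u = rms(x[i])
        q_i = rms(p["Q"] @ u)
        k_tmp, v_tmp = rms(p["K"] @ u), p["V"] @ u           # temporary pair
        keys = np.vstack([k[:i], k_tmp[None]])               # persistent prefix + temporary
        vals = np.vstack([v[:i], v_tmp[None]])
        a = attn(keys, vals, q_i)
        h = x[i] + a / np.sqrt(L)
        z[i] = h + p["W2"] @ np.maximum(p["W1"] @ rms(h), 0.0) / np.sqrt(L)
        src = rms(z[i]) if persistent_from_output else u
        k[i], v[i] = rms(p["K"] @ src), p["V"] @ src         # persistent pair, same K and V
    return z

rng = np.random.default_rng(0)
N, D, L = 6, 8, 4                                            # toy sizes (assumed)
p = {name: rng.standard_normal((D, D)) / np.sqrt(D)
     for name in ("Q", "K", "V", "W1", "W2")}
x = rng.standard_normal((N, D))
z_rt = layer_forward(x, p, L, persistent_from_output=True)
z_tf = layer_forward(x, p, L, persistent_from_output=False)
```

The two modes agree at the first position, which attends only to its own temporary pair, and generally differ afterwards, since later positions read pairs derived from $\vz$ rather than $\vx$.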

Closest Related Work {#sec:closest-related-work}
--------------------

The closest representational relatives are Feedback Transformer variants. Feedback Transformer [@fan2020feedback] uses a cross-layer feedback memory shared across depth, essentially maintaining a single list of key--value pairs computed from the whole model's output rather than independently at each layer. Staircase Attention [@ju2021staircase] generalizes Feedback Transformers, studying recurrent processing and caching variants with weight sharing -- still at the model level rather than the layer level. This separation matters not only representationally but also computationally: within an `\methodnameabv{}`{=latex} layer, all queries are available early, which is the enabling condition behind our efficient training methodology (Section `\ref{sec:tiling}`{=latex}).

TransformerFAM [@hwang2024transformerfam] is closer in that it also operates independently at each layer and allows later positions to access more processed representations. However, it does so through a bounded memory that is read from and written to via attention. By contrast, `\methodname{}`{=latex} retains per-token persistent key--value memory rather than compressing past information into a fixed-size state. This difference is important both for avoiding a bounded-memory bottleneck and for the representational results of Section `\ref{sec:repr-transformers}`{=latex}.

Representational Perspective {#sec:repr}
============================

In this section, we show theoretically that `\methodname{}`{=latex} can emulate both a Transformer and an RNN under mild assumptions. It therefore subsumes both RNNs and Transformers, at least in representational power.

Representing Transformers {#sec:repr-transformers}
-------------------------

Intuitively, `\methodnames{}`{=latex} can recover the behavior of a standard Transformer of lower width by ensuring that the persistent key--value pairs computed from $\vz_i$ track those that would have been computed from $\vx_i$ via $K$ and $V$ projections. We concretize this statement below:

```{=latex}
\begin{thm_samy}[informal]Any width-$d'$ Transformer can be approximately simulated by a width-$d=3d'$ \methodname{}: the simulated Transformer activations can be embedded into disjoint feature groups of the \methodnameabv's embeddings. The \methodnameabv{} layer can be parameterized so that (i) attention scores are preserved and (ii) the layer output exactly tracks the Transformer layer output.
\label{thm:infrml-gen}
\end{thm_samy}
```
At a high level, the construction relies on representing smaller Transformer states inside a larger embedding of `\methodnameabv{}`{=latex} by dedicating disjoint feature blocks to different roles. One block stores a protected copy of the layer input $\vx_i$ so that when `\methodnameabv{}`{=latex} later computes persistent keys and values from the layer output $\vz_i$, it can still recover exactly the same key--value pairs the Transformer would have computed from $\vx_i$. A separate block is used to hold the attention contribution $\va_i$, so that when adding it to $\vx_i$ prior to applying the MLP, the contents of $\vx_i$ are protected from being lost. In this way, later tokens see identical attention scores, while the layer output matches the Transformer's layer output in the designated block. The width overhead of a factor of $3$, rather than $2$, is a subtle technical requirement for stacking multiple layers; a single layer can be replicated with an overhead of $2$. The complete construction and formal proof are provided in Appendix `\ref{app:transformer-sim}`{=latex}.

Representing token-to-token recurrence {#sec:repr-rnn}
--------------------------------------

If, using positional embeddings or biases, attention concentrates on the previous position, `\methodnames{}`{=latex} implement an RNN-like update. Formally, if $\va_i$ is dominated by the previous persistent value $\vv_{i-1}$, i.e. $\innerp{\vq_i}{\vk_{i-1}} \gg \innerp{\vq_i}{\vk^{\temp}_i}$ and $\innerp{\vq_i}{\vk_{i-1}} \gg \innerp{\vq_i}{\vk_j}$ for any $j < i - 1$, then $\va_i \approx \vv_{i-1}$ and, omitting the $1/\sqrt{L}$ residual scaling for readability, we get: $$\begin{aligned}
\vz_i &\approx \vx_i + \vv_{i-1} + \MLP[\RMS(\vx_i+\vv_{i-1})] = V\,\RMS(\vz_{i-1}) + \vx_i + \MLP[\RMS(\vx_i+V\,\RMS(\vz_{i-1}))]\end{aligned}$$ Under the additional simplifying assumption that $V$ is the identity, this becomes a particular state recurrence with a skip connection: $$\begin{aligned}
\vz_i = \RMS(\vz_{i-1}) + \vx_i + \MLP[\RMS(\vx_i+\RMS(\vz_{i-1}))]\end{aligned}$$ We do not claim to reproduce gated RNNs/LSTMs, nor that training would lead the model to learn such structures. We stress only that, representationally, `\methodname{}`{=latex} is rich enough to express explicit iterative computation within a layer while also retaining full-prefix per-token memory (which was required to simulate Transformers in the previous section).

Crucially, once an architecture can represent such iterative computation, a natural question is whether the classic learnability issues of RNNs [@bengio1994long; @pascanu2013difficulty] impede training. Section `\ref{sec:paths}`{=latex} explains why the multi-hop dynamics of `\methodnames{}`{=latex} can still train stably.

Why temporal depth matters
--------------------------

Transformers are shallow-through-time: deeper iterative computation along the sequence must be simulated primarily by stacking layers. Theory on low-depth attention models and finite-automata tracking problems suggests that bounded depth can have concrete consequences, with shallow Transformers being representationally insufficient to simulate certain automata [@liu2022shortcuts] and, more generally, being bounded by $\mathsf{TC}^0$, a class of shallow circuits [@merrill2022saturated]. `\methodnames{}`{=latex} expose additional temporal depth within each layer. This is complementary to the depth obtained from stacking layers, pointing to the potential of matching a Transformer's effective depth while using fewer layers. We corroborate this hypothesis empirically in Section `\ref{sec:expts}`{=latex}.

Training Stability of `\methodname{}`{=latex} {#sec:paths}
=============================================

In this section, we explain how `\methodname{}`{=latex} manages to avoid degenerate dynamics such as gradient vanishing or exploding through depth. We formalize our arguments by viewing the model as a directed computation graph over positions: there is an edge $i\to j$ when the computation at position $j$ *directly* depends on quantities computed at position $i$. In a classical RNN, information (and gradients) from position $i$ to $j$ must traverse the full chain $i\to i\!+\!1\to\cdots\to j$, and repeated composition along long chains leads to vanishing/exploding gradient phenomena [@bengio1994long; @pascanu2013difficulty]. This chain topology forces all influence from position $i$ to $j$ through $(j-i)$ successive state transitions. Stabilizing training typically requires these transitions to be close to contractive, but then the influence of $\vx_i$ on $\vx_j$ shrinks rapidly with $(j-i)$, making distant information difficult to transmit. While in RNNs this issue can be alleviated through careful initialization schemes [@orvieto2023resurrecting], our method takes advantage of the fact that there are both direct hops and additional multi-hop paths between positions.

As in a standard Transformer, token $j$ can directly attend to any earlier token $i<j$ via the stored key--value pair $(\vk_i,\vv_i)$, creating a one-hop information path $i\!\to\! j$. The key difference is that in `\methodname{}`{=latex}, the stored pair $(\vk_i,\vv_i)$ is computed from the *layer output* $\vz_i$, and $\vz_i$ already includes the result of attending to earlier stored pairs. Consequently, information can propagate not only directly from $i$ to $j$, but also *indirectly*.

Concretely, a multi-hop path from token $1$ to token $4$ (Figure `\ref{fig:rt-arch}`{=latex}) can go through intermediate write--read steps: $$\vx_1 \to \vz_1 \to (\vk_1,\vv_1) \to \va_2 \to \vz_2 \to (\vk_2,\vv_2) \to \va_4 \to \vz_4$$ Here each step $\vz_t \to (\vk_t,\vv_t)$ is a *write* to the layer's persistent memory, and each step $(\vk_t,\vv_t) \to \va_{t'}$ is a *read* by attention at any later position $t'>t$. Chaining these write--read operations yields multi-hop influence paths whose length scales with the distance between positions, enabling within-layer iterative computation (as in Section `\ref{sec:repr-rnn}`{=latex}) while preserving the direct one-hop attention routes of a Transformer.

#### Dampening long paths without eliminating long-range access.

Multi-hop paths are only useful if they do not explode. In practice, two levers dominate. First, standard depth-wise scaling conventions for residual branches keep per-layer updates in a stable range [@bordelon2023depthwise; @yang2023tensor]. Second, normalization preceding computation of persistent keys/values (the $\RMS(\vz_i)$ inside Equations `\ref{eq:persistentK}`{=latex}-`\ref{eq:persistentV}`{=latex}) controls magnitudes even though $\vz_i$ is a sum of multiple components. Empirically, these choices place long multi-hop influences on the vanishing end: longer chains have smaller effect. Unlike a pure RNN, this does not remove long-range access because direct attention edges remain available even when long chains are damped.

In the theorem below, for a heavily simplified setup without normalization, we show that gradients do not explode at initialization. Normalization further helps stability: it allows stable training of `\methodname{}`{=latex} with higher learning rates, as demonstrated empirically in Appendix `\ref{app:layernorm_stable}`{=latex}.

```{=latex}
\begin{theorem} \label{thm:train_stable}
    Consider a simplified single-layer, uniform-attention-only \methodnameabv{} with inputs $x_1,\ldots,x_n$ and outputs $z_1,\ldots,z_n$, where
    \[ z_k = x_k + \frac{\alpha}{k} \left( V x_k + V \sum_{j=1}^{k-1} z_j \right), \]
    with $\alpha$ a scalar denoting the scaling of the residual and $V$ the value matrix. Then, for $k \geq 2$,

    \[ \frac{\partial z_k}{\partial x_1} = \frac{1}{k!} \sum_{r=1}^k {k \brack r}\,\alpha^r V^r, \]

    where ${k \brack r}$ denotes the number of permutations of $k$ elements having exactly $r$ cycles.
\end{theorem}
```
As the total number of permutations is $k!$, the theorem above shows that as long as the maximum eigenvalue of $\alpha V$ is smaller than $1$, the gradient of $z_k$ with respect to $x_1$ does not explode. Thus, for orthonormal initialization, for any $\alpha < 1$, we expect to be in this regime. Moreover, since the overall gradient is summed over paths of various lengths (given by $r$ in the above expression), the gradient does not vanish even when the maximum eigenvalue of $\alpha V$ is smaller than $1$. In particular, since ${k \brack 1} = (k-1)!$, the $r=1$ term in the above expression is $\frac{\alpha}{k} \cdot V$, which is precisely the gradient a vanilla Transformer would yield. The proof of this theorem can be found in Appendix `\ref{app:train_stable}`{=latex}.
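As a sanity check on the closed form, the coefficients can be compared exactly with those obtained by differentiating the recurrence directly. The sketch below is only a verification aid: it represents each Jacobian $\partial z_k/\partial x_1$ as a polynomial in $\alpha V$ (valid because only powers of the single matrix $V$ appear), and the function names are hypothetical.

```python
import math
from fractions import Fraction

def stirling_first_unsigned(n):
    # Row n of the unsigned Stirling numbers of the first kind:
    # row[r] = number of permutations of n elements with exactly r cycles.
    row = [1]                                     # n = 0
    for m in range(1, n + 1):
        prev = row + [0]
        row = [0] * (m + 1)
        for r in range(1, m + 1):
            row[r] = (m - 1) * prev[r] + prev[r - 1]
    return row

def jacobian_coeffs(n):
    # Coefficients, in powers of (alpha V), of dz_k/dx_1 for k = 1..n, from
    # z_k = x_k + (alpha/k) (V x_k + V sum_{j<k} z_j).
    jac = {1: [Fraction(1), Fraction(1)]}         # dz_1/dx_1 = I + alpha V
    running = list(jac[1])                        # sum_{j<k} dz_j/dx_1
    for k in range(2, n + 1):
        # dz_k/dx_1 = (alpha V / k) * running: shift powers up by one, divide by k
        jac[k] = [Fraction(0)] + [c / k for c in running]
        running = [a + b for a, b in zip(running + [Fraction(0)], jac[k])]
    return jac

n_max = 8
jac = jacobian_coeffs(n_max)
formula_holds = all(
    jac[k] == [Fraction(c, math.factorial(k)) for c in stirling_first_unsigned(k)]
    for k in range(2, n_max + 1)
)
# The r = 1 coefficient is the direct one-hop term alpha V / k.
one_hop_terms = [jac[k][1] for k in range(2, n_max + 1)]
```

Because the arithmetic uses `Fraction`, agreement is exact rather than up to floating-point tolerance.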

![We use the tiling of [@flashInference] to increase arithmetic intensity during the forward pass. Tiling is needed because $(\vk_t, \vv_t)$ only becomes available after the attention output $\va_t$ is computed, which in turn happens once position $t$ has attended to all previous KVs.](diagrams/RT_fast_tiling_8.png){#fig:tiling width="0.6\\columnwidth"}

Exact Tiling for Training and Prefill {#sec:tiling}
=====================================

#### What makes naive evaluation slow.

During training/prefill, `\methodnames{}`{=latex} are fundamentally sequential in position; to compute the persistent pair $(\vk_t,\vv_t)$ we must first compute $\vz_t$, and $\vz_t$ depends on $\va_t$, which aggregates over all previous persistent key--value pairs: $\{(\vk_i, \vv_i)\}_{i<t}$. Therefore, a naive implementation reveals persistent keys/values one position at a time by having each new query aggregate over a growing prefix, yielding low reuse and high memory traffic.

#### A short Roofline view: why we care about arithmetic intensity.

The Roofline model [@williams2009roofline] bounds attainable throughput by either peak compute or peak memory bandwidth depending on arithmetic intensity (FLOPs per byte moved). When attention repeatedly streams large prefixes of keys/values to produce small incremental updates, effective arithmetic intensity (AI) can be close to constant, making the operation bandwidth-bound even on large accelerators. This is the regime where reorganizing memory movement (even without changing the math) can give large wins.

#### Enabling observation: within-layer queries are available early.

Despite the sequential reveal of persistent keys/values, all queries $\{\vq_i\}_{i=1}^N$ in a layer depend only on the layer input $\{\vx_i\}$ and can be computed up front in parallel. This means that one can eagerly \"aggregate\" the contribution of the key--value pairs available so far to any future query, not just the immediately upcoming one. For example, after $(\vk_4, \vv_4)$ is computed, the naive approach waits until the next step, when $\va_5$ is needed; it then scans the whole prefix of $4$ key--value pairs, \"inquiring\" about *just one* query ($\vq_5$). This \"just one\" is what gives the arithmetic intensity of $\approx 1$. Alternatively, one can already account for how these pairs contribute to $\va_5,\ldots,\va_8$, inquiring about $4$ queries at once and thus raising the arithmetic intensity to $\approx 4$.[^1]

A very similar regime is exploited in the Flash Inference framework [@flashInference] -- while their framework is meant for decoding, one forward pass of `\methodnameabv{}`{=latex} is essentially a sequence of decode steps. It applies to our case since the computation of interest is:

-   contribution-based: $\va_i$ can be accumulated over different groups of key--value pairs;

-   independent of future $\vz$'s: all queries are readily available.

The second condition also clarifies why the same approach cannot extend to cross-layer feedback architectures [@fan2020feedback; @ju2021staircase]: when future queries indirectly depend on feedback that is only produced after running later layers, queries are not all available early.

#### Exact tiling schedule.

Our algorithm is an exact evaluation algorithm: it computes the same attention outputs up to floating-point reordering effects. The schedule follows the tiling in Figure `\ref{fig:tiling}`{=latex}. It interleaves

-   computing $\vz_t$ (via $\MLP$) as soon as $\va_t$ is available, which then reveals the new persistent key--value pair $(\vk_t, \vv_t)$, and

-   updating attention accumulators for several future queries that are already known, by processing the newly freed tile.

For example, as $\va_6$ becomes available, $\vz_6$ and then $(\vk_6, \vv_6)$ are computed, and one can then process the entire contribution of $\{(\vk_5, \vv_5), (\vk_6, \vv_6)\}$ to $\{\va_7, \va_8\}$ (by \"asking\" queries $\vq_7, \vq_8$). To aggregate attention contributions, we maintain the same online softmax statistics as [@rabe2021blockAttention; @dao2022flashattention] (running attention score maxima and normalizing factors) so that contributions from multiple key/value tiles can be accumulated stably. The full algorithm description is available in Appendix `\ref{app:complete-algo}`{=latex}.
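The schedule can be sketched in NumPy as follows. This is a simplified, single-head, unbatched illustration rather than the algorithm of Appendix `\ref{app:complete-algo}`{=latex}: the tile rule used here (after 1-indexed position $t$, the last $2^s$ persistent pairs serve the next $2^s$ queries, where $2^s$ is the largest power of two dividing $t$) is our reading of the examples above, and the ReLU MLP, weights, and function names are assumptions.

```python
import numpy as np

def rms(x):
    return np.sqrt(x.shape[-1]) * x / np.linalg.norm(x, axis=-1, keepdims=True)

def precompute(x, p):
    # Queries and temporary pairs depend only on the layer input, so all are known early.
    u = rms(x)
    return rms(u @ p["Q"].T), rms(u @ p["K"].T), u @ p["V"].T

def finish_position(x_t, a_t, p, L):
    # z_t from the attention output, then the persistent pair (k_t, v_t) from z_t.
    h = x_t + a_t / np.sqrt(L)
    z_t = h + p["W2"] @ np.maximum(p["W1"] @ rms(h), 0.0) / np.sqrt(L)
    return z_t, rms(p["K"] @ rms(z_t)), p["V"] @ rms(z_t)

def naive_forward(x, p, L):
    N, D = x.shape
    q, k_tmp, v_tmp = precompute(x, p)
    k, v, z = np.zeros((N, D)), np.zeros((N, D)), np.zeros((N, D))
    for t in range(N):                            # one query against the whole prefix
        keys = np.vstack([k[:t], k_tmp[t:t + 1]])
        vals = np.vstack([v[:t], v_tmp[t:t + 1]])
        s = keys @ q[t]
        w = np.exp(s - s.max())
        z[t], k[t], v[t] = finish_position(x[t], (w / w.sum()) @ vals, p, L)
    return z

def tiled_forward(x, p, L):
    N, D = x.shape
    q, k_tmp, v_tmp = precompute(x, p)
    k, v, z = np.zeros((N, D)), np.zeros((N, D)), np.zeros((N, D))
    m = np.full(N, -np.inf)                       # running score maxima
    l = np.zeros(N)                               # running softmax normalizers
    acc = np.zeros((N, D))                        # running unnormalized outputs
    for t in range(N):
        # Fold in the temporary pair; every persistent pair j < t is already accumulated.
        s = q[t] @ k_tmp[t]
        m_new = max(m[t], s)
        old, new = np.exp(m[t] - m_new), np.exp(s - m_new)
        a_t = (old * acc[t] + new * v_tmp[t]) / (old * l[t] + new)
        z[t], k[t], v[t] = finish_position(x[t], a_t, p, L)
        # Newly completed tile: the last `size` keys serve the next `size` queries.
        pos = t + 1                               # 1-indexed position just completed
        size = pos & -pos                         # largest power of two dividing pos
        lo, hi = pos, min(pos + size, N)
        if lo < hi:
            S = q[lo:hi] @ k[pos - size:pos].T
            m_new = np.maximum(m[lo:hi], S.max(axis=1))
            scale = np.exp(m[lo:hi] - m_new)
            P = np.exp(S - m_new[:, None])
            l[lo:hi] = scale * l[lo:hi] + P.sum(axis=1)
            acc[lo:hi] = scale[:, None] * acc[lo:hi] + P @ v[pos - size:pos]
            m[lo:hi] = m_new
    return z

rng = np.random.default_rng(0)
D, L = 8, 4                                       # toy sizes (assumed)
p = {name: rng.standard_normal((D, D)) / np.sqrt(D)
     for name in ("Q", "K", "V", "W1", "W2")}
x = rng.standard_normal((16, D))
max_gap = np.abs(naive_forward(x, p, L) - tiled_forward(x, p, L)).max()
```

Here `max_gap` measures the largest elementwise difference between the two schedules; any nonzero value comes only from the different order of floating-point accumulation.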

![One-layer forward-pass latency as a function of sequence length at batch size $512$ on a single H100 GPU with $1024$ width. The naive recurrent implementation shows approximately quadratic growth with context length, whereas the tiled implementation scales much closer to linearly. This matches the intended effect of the tiled schedule, which increases reuse of loaded key--value pairs across multiple future queries. We also include the vanilla Transformer baseline for reference.](plots/latency/latency_vs_seq_len_bs512_transformer.png){#fig:latency-scaling width="0.6\\linewidth"}

#### Asymptotics.

Counting HBM movement, the naive one-query-at-a-time implementation incurs $\Theta(N^2)$ memory traffic. The tiled schedule reduces traffic to $\Theta(N\log N)$ by reusing streamed key/value tiles across many queries, while attention FLOPs remain $\Theta(N^2)$. Consequently, effective arithmetic intensity increases from $\Theta(1)$ to $\Theta(N/\log N)$. The gains of this tiling approach can be seen in Figure `\ref{fig:latency-scaling}`{=latex}: while the latency of a naive eager implementation grows approximately quadratically with context length, our method exhibits near-linear scaling.
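These counts can be reproduced with a few lines of Python. The sketch uses the number of persistent key--value rows loaded as a rough proxy for HBM traffic (it ignores query and accumulator traffic, which is of the same order, as well as caching effects), and it assumes the power-of-two tile rule illustrated in the examples above; the function names are hypothetical.

```python
def naive_kv_loads(N):
    # The query at position t streams all t - 1 earlier persistent pairs.
    return sum(t - 1 for t in range(1, N + 1))

def tiled_kv_loads(N):
    # After position pos, a tile of lowbit(pos) pairs is loaded once and
    # reused by up to lowbit(pos) future queries.
    return sum(pos & -pos for pos in range(1, N))

def tiled_score_evaluations(N):
    # Number of (key, query) scores computed by the tiled schedule.
    return sum((pos & -pos) * (min(pos + (pos & -pos), N) - pos)
               for pos in range(1, N))

N = 1024
naive_loads = naive_kv_loads(N)              # 523776 = N (N - 1) / 2
tiled_loads = tiled_kv_loads(N)              # 5120   = (N / 2) * log2(N)
same_work = tiled_score_evaluations(N) == N * (N - 1) // 2
```

The number of score evaluations is identical in both schedules, so the reduction in loaded rows translates directly into higher reuse per loaded pair.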

A deep dive into the computational challenges {#sec:comp_challenges}
=============================================

In this section, we outline the key computational design decisions required to make `\methodname{}`{=latex} training practically efficient, enabling the execution of our language modeling experiments. In contrast to the tiling algorithm discussed earlier, which focuses on algorithmic structure, the emphasis here is on implementation-level optimizations. These changes do not alter asymptotic complexity, but instead yield meaningful constant-factor speedups that are critical for reducing overall training time.

#### The setup.

We run all experiments on H100 GPUs, with training carried out on a single device at a time. Our implementation is in PyTorch [@paszke2017automaticPyTorch]. Beyond the algorithmic tiling strategy of Section `\ref{sec:tiling}`{=latex}, we make several implementation choices aimed at improving hardware utilization and reducing overheads. While we successfully use `torch.compile` to fuse a number of components, we do not rely on custom kernels in the current implementation. We leave such kernel-level optimization to future work and focus here on the algorithmic improvements introduced by the architecture and computation schedule.

Unless otherwise noted, all latency measurements assume hidden dimension $1024$ and $16$ attention heads, are averaged over $5$ runs after $3$ warmup runs, and have standard deviation below $1$ ms. Latencies are reported on a per-layer basis: they measure the computation of a single layer's map $\vx \mapsto \vz$, including both attention and MLP computation, but excluding embedding, unembedding, and loss computation, which are identical across Recurrent Transformers and Transformers.

MLPs and batch size {#sec:batch-size}
-------------------

While the tiling algorithm greatly improves the arithmetic intensity of the attention component of `\methodnames{}`{=latex}, the MLP computations must be interleaved with it and can become the dominant cost of a forward pass. The reason is that, unlike in a standard Transformer, the MLP does not receive all $B \times N$ tokens at once. Instead, it processes only $B$ tokens at a time over $N$ iterations, one position at a time. As a result, the per-device batch size $B$ directly controls the arithmetic intensity of the MLP; in the regime $B = O(d)$, this intensity is approximately linear in $B$.
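A back-of-the-envelope estimate illustrates this dependence on $B$. The sketch below assumes a two-matrix MLP with a $4\times$ hidden expansion and $2$-byte elements, and it counts only matmul FLOPs and one read or write of each operand; these modeling choices and the function name are assumptions made for illustration, not measurements.

```python
def mlp_arithmetic_intensity(B, d, expansion=4, bytes_per_element=2):
    """Approximate FLOPs per byte for one MLP call on B tokens of width d."""
    h = expansion * d
    flops = 2 * (2 * B * d * h)                   # two matmuls, 2 FLOPs per multiply-add
    elements = 2 * d * h + 2 * (B * d + B * h)    # both weight matrices + in/out activations
    return flops / (bytes_per_element * elements)

d = 1024
intensities = {B: mlp_arithmetic_intensity(B, d) for B in (32, 64, 128, 256, 512)}
```

Under these assumptions the estimate simplifies to $Bdh/(dh + Bd + Bh)$ with $h = 4d$, which is close to $B$ when $B \ll d$ and flattens as $B$ approaches $d$.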

In practice, for the model scales we consider, $B = 512$ provides a favorable trade-off between GPU utilization and activation memory. This constraint is fundamental to recurrent-in-time architectures and is not specific to the attention mechanism; similar issues arise in classical recurrent models such as LSTMs. In particular, sustaining such a batch size on a single device is already challenging from the perspective of activation memory, even for models in the 150M--300M parameter range. Under ordinary circumstances, one would employ gradient accumulation, but that defeats the purpose here: increasing the total batch size via accumulation does not increase the effective batching seen by the MLP, and therefore does not improve arithmetic intensity.

#### Activation Checkpointing

We therefore rely on activation checkpointing. One useful property of our computation is that once the inputs $\vx$ to each layer are stored, the outputs $\vz$ can be retained essentially \"for free,\" since they serve as the inputs to the next layer. This leads to a substantially cheaper recomputation procedure. Once the persistent key--value pairs have been reconstructed from $\vz$ in parallel, the remaining intermediate quantities can be recovered in a fully parallel manner. In particular, the attention-related intermediates can be recomputed without replaying the slow sequential process by which the $\vz_i$ were originally revealed one at a time. Consequently, although we still incur the standard cost of checkpointing, the recomputation overhead is meaningfully smaller than that of the original forward pass.
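The parallel recomputation can be sketched as follows: given stored $\vx$ and $\vz$, all persistent pairs are rebuilt with batched projections, and every $\va_i$ follows from a single masked attention in which the diagonal entry is replaced by the temporary pair. This is a single-head NumPy illustration under assumed names and a toy ReLU MLP; the sequential loop is included only to produce reference values.

```python
import numpy as np

def rms(x):
    return np.sqrt(x.shape[-1]) * x / np.linalg.norm(x, axis=-1, keepdims=True)

def sequential_forward(x, p, L):
    # Slow reference: reveals z_i (and hence k_i, v_i) one position at a time.
    N, D = x.shape
    k, v, z, a = (np.zeros((N, D)) for _ in range(4))
    for i in range(N):
        u = rms(x[i])
        keys = np.vstack([k[:i], rms(p["K"] @ u)[None]])
        vals = np.vstack([v[:i], (p["V"] @ u)[None]])
        s = keys @ rms(p["Q"] @ u)
        w = np.exp(s - s.max())
        a[i] = (w / w.sum()) @ vals
        h = x[i] + a[i] / np.sqrt(L)
        z[i] = h + p["W2"] @ np.maximum(p["W1"] @ rms(h), 0.0) / np.sqrt(L)
        k[i], v[i] = rms(p["K"] @ rms(z[i])), p["V"] @ rms(z[i])
    return z, a

def recompute_attention(x, z, p):
    # Parallel recomputation of all a_i from the checkpointed x and z.
    xn, zn = rms(x), rms(z)
    q = rms(xn @ p["Q"].T)
    k, v = rms(zn @ p["K"].T), zn @ p["V"].T            # all persistent pairs at once
    k_tmp, v_tmp = rms(xn @ p["K"].T), xn @ p["V"].T    # all temporary pairs at once
    N = x.shape[0]
    S = q @ k.T                                         # scores against persistent keys
    S[np.triu_indices(N)] = -np.inf                     # keep strictly earlier positions
    s_tmp = np.sum(q * k_tmp, axis=1)                   # score against own temporary key
    m = np.maximum(S.max(axis=1), s_tmp)
    P, p_tmp = np.exp(S - m[:, None]), np.exp(s_tmp - m)
    return (P @ v + p_tmp[:, None] * v_tmp) / (P.sum(axis=1) + p_tmp)[:, None]

rng = np.random.default_rng(0)
N, D, L = 12, 8, 4                                      # toy sizes (assumed)
p = {name: rng.standard_normal((D, D)) / np.sqrt(D)
     for name in ("Q", "K", "V", "W1", "W2")}
x = rng.standard_normal((N, D))
z, a_sequential = sequential_forward(x, p, L)
a_parallel = recompute_attention(x, z, p)
```

Once `a_parallel` is available, the MLP inputs and outputs at every position can likewise be recomputed in one batched call, since they depend only on $\vx_i$ and $\va_i$.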

#### Critical batch size.

Even if memory permitted arbitrarily large batches, there remains a statistical efficiency limit to how many tokens can be processed per optimizer step before optimization quality begins to degrade [@McCandlish2018Empirical; @Shallue2018Measuring]. In our setup, this critical batch size is around $256$K tokens per optimizer iteration for Transformer models in the 150M and 300M parameter range [@zhang2025how]. Accordingly, throughout our experiments we use sequences of length $512$ and batch size $512$, corresponding to $256$K tokens per iteration. In Appendix `\ref{app:150m-pretrain}`{=latex}, we further verify that the critical batch size of `\methodname{}`{=latex} is not below this value.

Using CUDA Graphs
-----------------

::: {#tab:cudagraph}
  batch size     without CUDA Graphs   with CUDA Graphs
  ------------ --------------------- ------------------
  32                          277.08              38.85
  64                          279.23              42.28
  128                         279.42              48.95
  256                         276.05              61.39
  512                         277.33              81.73
  1024                        275.49             134.09
  2048                        293.70             240.83

  : Forward-pass latency (ms) of one Recurrent Transformer layer for sequences of $512$ tokens each. CPU overhead dominates at lower batch sizes, and we employ CUDA Graphs to mitigate this.
:::

Both the attention and MLP portions of our architecture involve $O(N)$ launches of moderately small kernels. Even when the underlying kernels are themselves compute-bound, the amount of work per kernel can be small enough that CPU-side dispatch overhead becomes a dominant bottleneck. Ordinarily, launch overhead is hidden because future kernels can be enqueued while current ones are still executing. Here, however, the kernels are sufficiently short-lived that this overlap is limited, and dispatch latency becomes visible on the critical path.

For this reason, we use CUDA Graphs, recording the full forward (and backward) pass computation and replaying it with a single launch. This turns out to be crucial for performance. The resulting latency improvements are reported in Table `\ref{tab:cudagraph}`{=latex}. One noteworthy feature of Table `\ref{tab:cudagraph}`{=latex} is that, without CUDA Graphs, latency remains nearly flat across a wide range of batch sizes, indicating that dispatch overhead rather than arithmetic work is the main bottleneck. With CUDA Graphs enabled, latency scales much more meaningfully with batch size, reflecting the underlying compute cost more faithfully.

Cache locality and memory access pattern
----------------------------------------

The tiling schedule also has a favorable memory-access pattern. As we iterate through positions, the portions of the KV cache accessed at successive steps overlap heavily and are often quite small: on average involving only $O(\log N)$ positions. To exploit this locality, we store the KV cache with the position dimension first, rather than the more conventional batch- or head-major layout. This ensures that the slices accessed by each tiled update are contiguous in memory, improving cache locality and reducing unnecessary memory movement.
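The effect of the layout on contiguity can be illustrated with NumPy arrays standing in for the cache tensors. The shapes and axis order below are assumptions chosen for the example; the text above only states that the position dimension comes first.

```python
import numpy as np

N, B, H, d = 16, 4, 2, 8                      # positions, batch, heads, head dim (toy sizes)
pos_major = np.zeros((N, B, H, d), dtype=np.float32)
batch_major = np.zeros((B, H, N, d), dtype=np.float32)

pos, size = 12, 4                             # tile covering the last 4 positions before 12
tile_pos_major = pos_major[pos - size:pos]            # one contiguous block of memory
tile_batch_major = batch_major[:, :, pos - size:pos]  # B * H separate strided chunks

contiguous_pos_major = tile_pos_major.flags["C_CONTIGUOUS"]
contiguous_batch_major = tile_batch_major.flags["C_CONTIGUOUS"]
```

With the position axis first, a tile of consecutive positions is a single contiguous range for all batch elements and heads, whereas the batch-major slice touches one small chunk per (batch, head) pair.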

The backward pass
-----------------

To avoid repeated `cat` operations, which would increase both memory traffic and peak memory usage, we preallocate the persistent KV cache and write to it in place. Since PyTorch autograd does not support in-place modification of tensors that are needed for the backward pass, we implement a custom backward pass.

A naive implementation would simply mimic the reverse traversal that autograd would have carried out on the corresponding computational graph. However, the structure of the computation allows a more parallel schedule. In particular, within the reverse loop over positions, one only needs to accumulate the gradients with respect to $(\vk_i, \vv_i)$. The reason is that, before moving from position $i$ to position $i-1$ and propagating the effect of $\va_{i-1}$ onto earlier key--value pairs, one must already have the final gradient with respect to $\vz_{i-1}$, which itself depends in part on $(\vk_{i-1}, \vv_{i-1})$. By contrast, gradients with respect to $\vx$, $\vq$, the temporary key--value pairs, and the model parameters do not impose such immediate dependencies and can therefore be computed outside the loop in a batched manner via larger kernels. This substantially improves parallelism in the backward pass.

Experiments {#sec:expts}
===========

We evaluate `\methodname{}`{=latex} on synthetic tasks designed to stretch models' representation ability, as well as language modeling.

Synthetic diagnostics
---------------------

We use the MAD suite [@poli2024mad] as a diagnostic for hybrid architectures. We also include the copy task [@jelassi2024repeat], which is provably impossible to solve with models of finite memory (this class includes all forms of RNNs, including SSMs [@gu2024mamba]). These diagnostics are not intended as long-range benchmarks; they isolate whether recurrence is being used in the intended way. Since we want to measure the effective depth of a layer, we compare one Transformer layer to one layer of `\methodnameabv{}`{=latex}, otherwise preserving the same model configurations as in [@poli2024mad]. The precise hyperparameter details are provided in Appendix `\ref{app:hyper_synth}`{=latex}. The sequence-level accuracies are displayed in Figure `\ref{fig:synth}`{=latex} and show that `\methodnameabv `{=latex}s significantly outperform Transformers, which do not achieve meaningful performance on any of the tasks.

```{=latex}
\begin{figure*}[t]
  
  \includegraphics[width=\textwidth]{plots/all_tasks_ExactMatch.pdf}
  \caption{Sequence-level accuracy of the \methodname{} and a regular Transformer on MAD synthetic tasks and the copy task. \methodnameabv{} outperforms Transformers on all tasks but compression. Neither model achieves non-trivial performance on compression at the sequence level, but both achieve meaningful token-level accuracy, with \methodnameabv{} still in the lead, as shown in Appendix~\ref{app:synthetics-token-level}.}
  \label{fig:synth}
\end{figure*}
```
::: {#tab:c4-512-300m-downstream}
  Model          Layers       piqa CE          hellaswag CE        arc easy CE       openbook qa CE          sciq CE         winogrande CE
  ------------- -------- ------------------ ------------------ ------------------- ------------------- ------------------- ------------------
  Transformer      6          $5.536$            $4.334$            $11.128$            $12.408$            $14.514$            $8.413$
  Transformer      12         $5.508$            $4.156$            $10.894$             $12.36$             $14.15$            $7.809$
  Recurrent        12     $\textbf{5.276}$   $\textbf{4.052}$   $\textbf{10.336}$   $\textbf{11.557}$       $13.384$        $\textbf{7.628}$
  Recurrent        6          $5.356$            $4.122$            $10.388$             $11.65$        $\textbf{13.206}$       $7.914$

  : Downstream cross-entropy (CE) on the ground-truth answers for the 300M model (lower is better).
:::

Language modeling on C4 (300M parameters)
-----------------------------------------

We implement the `\methodname{}`{=latex} on top of the OLMo-2 [@olmo20242] codebase and pretrain 300M and 150M parameter models on C4 [@C4raffel2020exploring] for $1\times$ Chinchilla tokens ($\approx 6$B and $\approx 3$B tokens, respectively). The precise hyperparameters are provided in Appendix `\ref{app:hyper_c4}`{=latex}. Figure `\ref{fig:c4-300m}`{=latex} shows the cross-entropy loss of the 300M model during training for Transformers and `\methodnames{}`{=latex} at $12$ layers and at a parameter-equivalent $6$ layers ($d$ scaled up by roughly $\sqrt{2}$, from $1408$ to $2048$). The standard configuration in previous works such as @zhao2025deconstructing uses $24$ layers, which performs comparably to $12$ layers, as shown in Table `\ref{tab:c4-512-300m}`{=latex}. Table `\ref{tab:c4-512-300m}`{=latex} also contains the final cross-entropy values. `\methodnameabv `{=latex}s outperform Transformers by a meaningful cross-entropy margin of $0.03$ at $12$ layers and $0.057$ at $6$ layers. Notably, the layerwise recurrence shifts the depth--width optimum at fixed parameter count.

We evaluate the model's downstream performance on six multiple-choice tasks used in OLMo [@groeneveld2024olmoacceleratingsciencelanguage]: PIQA [@Bisk2020], HellaSwag [@Zellers2019], ARC Easy [@Clark2018ARC], OpenBookQA [@Mihaylov2018], SciQ [@Welbl2017] and Winogrande [@Sakaguchi2020]. We report the model's CE loss on the ground-truth answers in Table `\ref{tab:c4-512-300m-downstream}`{=latex}. We use this metric because it is known to be smoother at small scales [@bhagia2025establishingtaskscalinglaws]. We also provide the downstream accuracies in Appendix Table `\ref{tab:c4-512-300m-downstream-acc}`{=latex}, but most of them are close to random (Winogrande is close to 50%, while most others are 6-7% above random). The results for the 150M model are provided in Appendix `\ref{app:150m-pretrain}`{=latex}.

Depth--width tradeoffs and decoding footprint
---------------------------------------------

For autoregressive decoding, `\methodname{}`{=latex} exhibits essentially the same per-token attention behavior as a Transformer with comparable depth and width: each new token attends to cached keys and values computed from preceding tokens. If, however, `\methodname{}`{=latex} achieves comparable model quality using fewer layers---reduced by a factor of $\alpha$---while keeping the total parameter count fixed, the size of the key--value (KV) cache decreases by a factor of $\sqrt{\alpha}$. This follows from the corresponding increase in model width by only $\sqrt{\alpha}$. A smaller KV cache directly reduces memory traffic during decoding, which can lead to higher throughput in bandwidth-limited settings. More broadly, trading off depth for width may be advantageous for decoding, since increased width can be more effectively parallelized using techniques such as tensor parallelism. We leave a detailed evaluation of fully optimized decoding latency to future work.
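The arithmetic behind the $\sqrt{\alpha}$ factor can be checked with a short sketch. It rests on two simplifying assumptions of ours: non-embedding parameters scale as layers $\times$ width$^2$, and the KV cache stores one key and one value vector of size width per layer and per token. The helper names are hypothetical and are not part of the released code.

```python
import math

def kv_cache_entries_per_token(layers: int, width: int) -> int:
    """Cached scalars per token: one key and one value vector per layer (assumed model)."""
    return 2 * layers * width

def width_at_fixed_params(width: int, alpha: float) -> float:
    """Width that keeps layers * width**2 constant when layers shrink by a factor alpha."""
    return width * math.sqrt(alpha)

def idealized_cache_reduction(alpha: float) -> float:
    """Layers shrink by alpha while width grows by sqrt(alpha), so the cache shrinks by sqrt(alpha)."""
    return alpha / math.sqrt(alpha)

# Configurations of the 300M experiments: 12 layers at width 1408 vs 6 layers at width 2048.
deep = kv_cache_entries_per_token(layers=12, width=1408)
shallow = kv_cache_entries_per_token(layers=6, width=2048)
measured_reduction = deep / shallow

print(f"idealized reduction for alpha=2: {idealized_cache_reduction(2.0):.3f}")
print(f"reduction for the 300M configs:  {measured_reduction:.3f}")
print(f"exactly parameter-matched width: {width_at_fixed_params(1408, 2.0):.0f}")
```

The reduction for the actual configurations is slightly below $\sqrt{2}$ because the 6-layer width of $2048$ is a little larger than $1408\sqrt{2}$.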

Related Work {#sec:related}
============

We group the most relevant prior work by the core bottleneck it imposes and by whether it introduces recurrence/feedback-like computation beyond standard feedforward self-attention.

#### Bounded-memory sequence models.

Classical RNNs and modern state-space models maintain a fixed-size state that is updated recurrently [@bengio1994long; @pascanu2013difficulty; @LSTMhochreiter1997long; @smith2022simplified; @gu2024mamba]. Linear-attention/retention variants also admit recurrent formulations with bounded state [@katharopoulos2020transformers; @sun2023retentive; @peng2023rwkv]. While computationally attractive, bounded-state families cannot in general preserve information that scales with sequence length; @jelassi2024repeat highlight this limitation by proving that such models cannot perform the copy task. This limitation does not apply to `\methodnameabv{}`{=latex} (as shown in Section `\ref{sec:repr-transformers}`{=latex}).

#### Segment recurrence and memory mechanisms.

Transformer-XL and follow-ups process context in segments and potentially summarize them via attention mechanisms [@dai2019transformerxl; @rae1911compressive]. RMTs take this a step further by summarizing the feedback information across segments [@bulatov2022rmt]. These methods primarily address efficient long-context handling rather than layerwise recurrent computation over token states. Furthermore, they share the limitation of classic RNNs, namely bounded memory.

#### Recurrent/feedback Transformers.

Feedback Transformers [@fan2020feedback], Staircase Attention [@ju2021staircase], and TransformerFAM [@hwang2024transformerfam] introduce recurrent/feedback-style computation in Transformer blocks. As discussed in Section `\ref{sec:closest-related-work}`{=latex}, `\methodname{}`{=latex} is *layerwise* recurrent with separate per-layer key--value collections (rather than cross-layer shared feedback), and this structure is also what enables our efficient training/prefill (Section `\ref{sec:tiling}`{=latex}).

Discussion and Conclusion {#sec:conclusion}
=========================

In this work, we introduce `\methodname{}`{=latex}, which integrates three key ideas: (i) a deeper-in-time representation within each layer, (ii) a path-based perspective that enables longer multi-hop influences while retaining direct attention access by operating closer to the vanishing-gradient regime, and (iii) an exact, I/O-aware evaluation algorithm that makes training and prefill practical by reducing memory traffic without altering the underlying computation. Our results clearly demonstrate the increased effective depth provided by introducing layerwise recurrence.

We view this work as a proof of concept that opens several directions for future research. As with any new architecture, the introduction of layerwise recurrence may alter tuning behavior and scaling laws, potentially shifting the optimal depth--width trade-off. In addition, the current design can be extended with blocking, in which recurrence is executed over blocks of steps, yielding a controllable trade-off between recurrent depth and training speed. Finally, while the proposed tiling algorithm is exact and delivers measurable gains, further improvements are likely achievable through fully optimized kernels and extensions of existing parallelization techniques, which we leave to future work.

In conclusion, layerwise recurrence provides a simple and principled mechanism for exposing additional temporal depth while retaining a memory that scales with sequence length and preserves the full representational capacity of the Transformer. When combined with an exact tiling strategy that enables computation reuse and reduces HBM traffic, `\methodnames{}`{=latex} make recurrent-in-layer training and prefill feasible in practice and shift depth--width trade-offs at a fixed parameter count. This shift is also beneficial at decode time, where a reduced KV cache size leads to improved efficiency.

Acknowledgements {#acknowledgements .unnumbered}
================

DM, AM, MK acknowledge the support of a Kempner Institute Graduate Research Fellowship. The authors acknowledge that this work has been made possible in part by a gift from the Chan Zuckerberg Initiative Foundation to establish the Kempner Institute for the Study of Natural and Artificial Intelligence. SK, CO and DM acknowledge support from the Office of Naval Research under award N0001422-1-2377 and the National Science Foundation Grant under award \#IIS 2229881. DM is also supported by a Simons Investigator Fellowship, NSF grant DMS-2134157, DARPA grant W911NF2010021, and DOE grant DE-SC0022199.

```{=latex}
\bibliographystyle{abbrvnat}
```
```{=latex}
\newpage
```
```{=latex}
\appendix
```
```{=latex}
\onecolumn
```
Simulating Transformers with `\methodname`{=latex}
==================================================

Transformer Generalization Theorem Statement {#app:transformer-sim}
--------------------------------------------

The approximate part of Informal Theorem `\ref{thm:infrml-gen}`{=latex} refers to the fact that the statement holds exactly only when no $\RMS$s are used. This is a small technicality required to make the statement exact. We restate both architectures without $\RMS$s (including inside attention projections and the MLP) and give an exact representation construction in this setting.

#### Norm-free architectures.

*Transformer (width $d'$).* Given inputs $\vx_1^T,\ldots,\vx_N^T\in\R^{d'}$ and parameters $Q^T,K^T,V^T\in\R^{d'\times d'}$ and $\MLP^T:\R^{d'}\to\R^{d'}$, the outputs $\vy_1^T,\ldots,\vy_N^T\in\R^{d'}$ are computed by: $$\begin{aligned}
& \vq_i^T = Q^T\vx_i^T \qquad \vk_i^T = K^T\vx_i^T \qquad \vv_i^T = V^T\vx_i^T \\
& \va_i^T = \Attn\big((\vk_1^T,\vv_1^T),\ldots,(\vk_i^T,\vv_i^T),\vq_i^T\big) \\
& \vy_i^T = \vx_i^T + \va_i^T + \MLP^T[\vx_i^T+\va_i^T] .\end{aligned}$$

*`\methodname{}`{=latex} (width $d$).* Given inputs $\vx_1,\ldots,\vx_N\in\R^{d}$ and parameters $Q,K,V\in\R^{d\times d}$ and $\MLP:\R^{d}\to\R^{d}$, the outputs $\vz_1,\ldots,\vz_N\in\R^{d}$ are defined via: $$\begin{aligned}
& \vq_i = Q\vx_i \qquad \vk_i^{\temp} = K\vx_i \qquad \vv_i^{\temp} = V\vx_i \\
& \va_i = \Attn\big((\vk_1,\vv_1),\ldots,(\vk_{i-1},\vv_{i-1}),(\vk_i^{\temp},\vv_i^{\temp}),\vq_i\big) \\
& \vz_i = \vx_i + \va_i + \MLP[\vx_i+\va_i] \\
& \vk_i = K\vz_i \qquad \vv_i = V\vz_i .\end{aligned}$$

```{=latex}
\begin{theorem}[Transformer Generalization]\label{thm:transformer-containment-normfree}
Assuming neither architecture uses $\RMS$s, any width-$d'$ Transformer (of arbitrary depth) can be simulated by a width-$d=3d'$ \methodname{} of as many layers.
There exists a parameterization of \methodnameabv{} such that, for any input sequence, the Transformer's activations at every layer are embedded into disjoint feature groups of the \methodnameabv{} activations. This is achieved while ensuring that:
(i) attention scores match those of the Transformer at every position and every layer
and (ii) the layer output exactly tracks the Transformer layer output.
\end{theorem}
```
Proof {#app:transformer-sim-proof}
-----

#### Three blocks and the per-layer invariant.

Let $d=3d'$ and decompose $\R^{d}$ into three $d'$-dimensional blocks $$\begin{aligned}
& \R^{3d'} = \mathcal{C}\oplus\mathcal{L}\oplus\mathcal{S} .\end{aligned}$$ We call them *carry* ($\mathcal{C}$), *live* ($\mathcal{L}$) and *scratch* ($\mathcal{S}$). Carry is the only block that $K$ and $V$ read from (so both temporary and persistent key/value pairs depend on it); live holds the next-layer activation and scratch holds attention outputs so residual addition does not corrupt carry.

Fix one Transformer layer (parameters $Q^T,K^T,V^T,\MLP^T$) and assume, for every position $1 \leq i \leq N$ that: $$\begin{aligned}
& \vx_i = (\vx_i^T \; ; \; * \; ; \; 0) .\end{aligned}$$ We will construct `\methodnameabv`{=latex}'s layer parameters ($Q,K,V,\MLP$) so that the output satisfies: $$\begin{aligned}
& \vz_i = (\vx_i^T \; ; \; \vy_i^T \; ; \; 0) .\end{aligned}$$ Then a swap between the dimensional blocks $\mathcal{C}$ and $\mathcal{L}$ restores the same input form (namely, the output follows the $(\vy_i^T \; ; \; * \; ; \; 0)$ pattern following the swap) enabling stacking. Define block-sparse linear maps $$\begin{aligned}
& Q(\vc \; ; \; \vl \; ; \; \vs) = (Q^T\vc \; ; \; 0 \; ; \; 0) \\
& K(\vc \; ; \; \vl \; ; \; \vs) = (K^T\vc \; ; \; 0 \; ; \; 0) \\
& V(\vc \; ; \; \vl \; ; \; \vs) = (0 \; ; \; 0 \; ; \; V^T\vc) .\end{aligned}$$ Because $K$ and $V$ read only from carry, and the construction leaves carry unchanged within the layer, the fact that they are shared between temporary and persistent pairs is automatically respected.

#### Attention matches (induction on position).

We prove by induction on $i$ that attention scores match and that the attention output lands in scratch.

Induction hypothesis: for all $j<i$ the persistent pairs match the embedded Transformer pairs $$\begin{aligned}
& \vk_j = (\vk_j^T \; ; \; 0 \; ; \; 0) \qquad \vv_j = (0 \; ; \; 0 \; ; \vv_j^T) .\end{aligned}$$ Using the input form $\vx_i = (\vx_i^T \; ; \; * \; ; \; 0)$ we have $$\begin{aligned}
& \vq_i = (\vq_i^T \; ; \; 0 \; ; \; 0)
\qquad
\vk_i^{\temp} = (\vk_i^T \; ; \; 0 \; ; \; 0)
\qquad
\vv_i^{\temp} = (0 \; ; \; 0 \; ; \vv_i^T) .\end{aligned}$$ Therefore the logits $\innerp{\vk_j}{\vq_i}$ match $\innerp{\vk_j^T}{\vq_i^T}$ for all $j\le i$ (with the temporary key $\vk_i^{\temp}$ used at $j=i$), and the attention output matches as well: $$\begin{aligned}
& \va_i = (0 \; ; \; 0 \; ; \va_i^T) .\end{aligned}$$

It thus follows that $\vx_i+\va_i = (\vx_i^T \; ; \; * \; ; \; \va_i^T)$. Define $\MLP[\cdot]$ on such inputs so that: $$\begin{aligned}
& \MLP[(\vx_i^T \; ; \; * \; ; \; \va_i^T)]
=
(0 \; ; \; \vx_i^T + \va_i^T + \MLP^T[\vx_i^T+\va_i^T] - *\; ; \; -\va_i^T) .\end{aligned}$$ Substituting into the `\methodname{}`{=latex} update shows why this choice is natural $$\begin{aligned}
& \vz_i = \vx_i + \va_i + \MLP[\vx_i+\va_i] \\
& \vz_i
= (\vx_i^T \; ; \; * \; ; \; 0) + (0 \; ; \; 0 \; ; \va_i^T)
+ (0 \; ; \; \vx_i^T + \va_i^T + \MLP^T[\vx_i^T+\va_i^T] - * \; ; \; -\va_i^T) \\
& \vz_i
= (\vx_i^T \; ; \; \vx_i^T + \va_i^T + \MLP^T[\vx_i^T+\va_i^T] \; ; \; 0) \\
& \vz_i = (\vx_i^T \; ; \; \vy_i^T \; ; \; 0) .\end{aligned}$$ In particular the persistent pairs for position $i$ satisfy $$\begin{aligned}
& \vk_i = K\vz_i = (\vk_i^T \; ; \; 0 \; ; \; 0)
\qquad
\vv_i = V\vz_i = (0 \; ; \; 0 \; ; \vv_i^T) .\end{aligned}$$ This closes the induction and proves equality of attention scores at every position within the layer.

#### Stacking layers.

After one layer we have $\vz_i = (\vx_i^T \; ; \; \vy_i^T \; ; \; 0)$. To simulate the next Transformer layer, the next `\methodname{}`{=latex} layer must see carry equal to $\vy_i^T$ while scratch remains $0$. Swapping carry and live between layers, we get: $$\begin{aligned}
& (\vx_i^T \; ; \; \vy_i^T \; ; \; 0)\mapsto(\vy_i^T \; ; \; \vx_i^T \; ; \; 0) .\end{aligned}$$ This restores the input form $(\vx_i^T \; ; \; * \; ; \; 0)$ for the next layer, whose Transformer input is the current layer's output $\vy_i^T$. Since each `\methodname{}`{=latex} layer has its own parameters, the swap can be absorbed into the next layer's $(Q,K,V,\MLP)$ choice. Iterating over depth completes the simulation of an arbitrary-depth Transformer.

#### Remark (single layer vs stacking and why $3d'$ is needed).

A single layer can be simulated with $2d'$ by preserving carry and writing the output elsewhere. Stacking forces an additional scratch subspace: attention outputs must be representable and cancelable without corrupting the carry block that $K,V$ read while the live block stores the next-layer activation. This is why the clean exact construction uses $3d'$.
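The single-layer construction can be sanity-checked numerically. The following is a minimal pure-Python sketch on random data, assuming the norm-free layers defined above; the helper names (`embed`, `embed_mlp`, and so on) and the toy $\tanh$ MLP standing in for $\MLP^T$ are our own illustrative choices, not the paper's implementation.

```python
import math
import random

def matvec(M, x):
    return [sum(m * c for m, c in zip(row, x)) for row in M]

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def attn(keys, values, q):
    """Softmax attention of a single query over the given key-value pairs."""
    logits = [sum(kc * qc for kc, qc in zip(k, q)) for k in keys]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, values)) / total
            for c in range(len(values[0]))]

def transformer_layer(xs, Q, K, V, mlp):
    ks = [matvec(K, x) for x in xs]
    vs = [matvec(V, x) for x in xs]
    ys = []
    for i, x in enumerate(xs):
        h = add(x, attn(ks[:i + 1], vs[:i + 1], matvec(Q, x)))
        ys.append(add(h, mlp(h)))
    return ys

def recurrent_layer(xs, Q, K, V, mlp):
    ks, vs, zs = [], [], []
    for x in xs:
        # Past positions use persistent pairs; the current one uses a temporary pair.
        a = attn(ks + [matvec(K, x)], vs + [matvec(V, x)], matvec(Q, x))
        h = add(x, a)
        z = add(h, mlp(h))
        ks.append(matvec(K, z))  # persistent pair is computed from z, not x
        vs.append(matvec(V, z))
        zs.append(z)
    return zs

def embed(M, out_block, dp):
    """3d' x 3d' map that reads carry (block 0) and writes M @ carry into out_block."""
    big = [[0.0] * (3 * dp) for _ in range(3 * dp)]
    for r in range(dp):
        big[out_block * dp + r][:dp] = M[r]
    return big

def embed_mlp(mlp_t, dp):
    """MLP of the construction: writes y^T into live and cancels scratch."""
    def mlp(h):
        carry, live, scratch = h[:dp], h[dp:2 * dp], h[2 * dp:]
        ht = add(carry, scratch)   # x^T + a^T
        y = add(ht, mlp_t(ht))     # Transformer layer output y^T
        return [0.0] * dp + [yc - lc for yc, lc in zip(y, live)] + [-s for s in scratch]
    return mlp

random.seed(0)
dp, n = 2, 5  # toy Transformer width d' and sequence length

def rand_vec(k):
    return [random.uniform(-1, 1) for _ in range(k)]

Qt, Kt, Vt, W = ([rand_vec(dp) for _ in range(dp)] for _ in range(4))
mlp_t = lambda h: [math.tanh(c) for c in matvec(W, h)]  # hypothetical stand-in MLP
xs_t = [rand_vec(dp) for _ in range(n)]
ys_t = transformer_layer(xs_t, Qt, Kt, Vt, mlp_t)

xs = [x + rand_vec(dp) + [0.0] * dp for x in xs_t]  # (x^T ; * ; 0) with arbitrary live block
zs = recurrent_layer(xs, embed(Qt, 0, dp), embed(Kt, 0, dp), embed(Vt, 2, dp),
                     embed_mlp(mlp_t, dp))

live_err = max(abs(z[dp + c] - y[c]) for z, y in zip(zs, ys_t) for c in range(dp))
carry_err = max(abs(z[c] - x[c]) for z, x in zip(zs, xs_t) for c in range(dp))
scratch_err = max(abs(z[2 * dp + c]) for z in zs for c in range(dp))
print(live_err, carry_err, scratch_err)
```

Under these assumptions, the live block of each $\vz_i$ should agree with $\vy_i^T$ up to floating-point error, while carry is preserved and scratch returns to zero, which is the per-layer invariant used in the proof.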

```{=latex}
\newpage
```
Training Stability of `\methodname`{=latex} {#app:train_stable}
===========================================

The proof for Theorem `\ref{thm:train_stable}`{=latex} is provided below.

```{=latex}
\begin{proof}
Given the update, for $k\ge 2$ we have
\[ \frac{\partial z_k}{\partial x_1} = \frac{\alpha}{k} V \sum_{j=1}^{k-1} \frac{\partial z_j}{\partial x_1} \]

Moreover, 
\[ \frac{\partial z_1}{\partial x_1} = I + \alpha V \]
We denote $\frac{\partial z_k}{\partial x_1}$ by $f(k)$ and define
\[
S(k):=\sum_{j=1}^k f(j).
\]
Then for $k\ge 2$, the recurrence gives
\[
f(k)=\frac{\alpha}{k}V\,S(k-1).
\]
Hence
\begin{align*}
S(k)
&=S(k-1)+f(k)\\
&=S(k-1)+\frac{\alpha}{k}V\,S(k-1)\\
&=\left(I+\frac{\alpha}{k}V\right)S(k-1).
\end{align*}
Since
\[
S(1)=f(1)=I+\alpha V,
\]
we obtain by iterating the above relation that
\[
S(k)=\prod_{m=1}^k \left(I+\frac{\alpha}{m}V\right).
\]
Therefore, for $k\ge 2$,
\begin{align*}
f(k)
&=S(k)-S(k-1)\\
&=\left(I+\frac{\alpha}{k}V\right)S(k-1)-S(k-1)\\
&=\frac{\alpha}{k}V\,S(k-1)\\
&=\frac{\alpha}{k}V\prod_{m=1}^{k-1}\left(I+\frac{\alpha}{m}V\right).
\end{align*}

Now set $M=\alpha V$. Then
\begin{align*}
f(k)
&=\frac{M}{k}\prod_{m=1}^{k-1}\left(I+\frac{M}{m}\right)\\
&=\frac{M}{k}\prod_{m=1}^{k-1}\frac{M+mI}{m}\\
&=\frac{M}{k}\cdot \frac{1}{(k-1)!}\prod_{m=1}^{k-1}(M+mI)\\
&=\frac{1}{k!}\prod_{m=0}^{k-1}(M+mI).
\end{align*}
Substituting back $M=\alpha V$ gives
\[
f(k)=\frac{1}{k!}\prod_{m=0}^{k-1}(\alpha V+mI).
\]

Finally, we use the standard rising-factorial expansion
\[
x(x+1)\cdots(x+k-1)=\sum_{r=0}^k {k \brack r}x^r,
\]
where ${k \brack r}$ denotes the unsigned Stirling numbers of the first kind, i.e., the number of permutations of $k$ elements with exactly $r$ cycles. Since this is a polynomial identity, it remains valid when the scalar variable $x$ is replaced by the matrix $\alpha V$ (and each constant $m$ by $mI$), so we obtain
\[
\prod_{m=0}^{k-1}(\alpha V+mI)=\sum_{r=0}^k {k \brack r}\,\alpha^r V^r.
\]
Since ${k \brack 0}=0$ for $k\ge 1$, this becomes
\[
\prod_{m=0}^{k-1}(\alpha V+mI)=\sum_{r=1}^k {k \brack r}\,\alpha^r V^r.
\]
Therefore,
\[
f(k)=\frac{1}{k!}\sum_{r=1}^k {k \brack r}\,\alpha^r V^r,
\]
as claimed.
\end{proof}
```
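As a quick check of the closed form, the sketch below compares it with the recurrence for $k\ge 2$ using exact rational arithmetic. It uses a scalar $x$ as a stand-in for $\alpha V$, which is an assumption made for brevity; because both sides are polynomials in $\alpha V$, the scalar identity carries over to matrices. The function names are hypothetical.

```python
from fractions import Fraction
from math import factorial

def stirling_first_unsigned(k):
    """Row [c(k,0), ..., c(k,k)] of unsigned Stirling numbers of the first kind."""
    row = [1]  # c(0,0) = 1
    for n in range(k):
        new = [0] * (len(row) + 1)
        for r, c in enumerate(row):
            new[r] += n * c      # c(n+1, r)   += n * c(n, r)
            new[r + 1] += c      # c(n+1, r+1) += c(n, r)
        row = new
    return row

def jacobian_by_recurrence(k_max, x):
    """f(1) = 1 + x and f(k) = (x / k) * sum_{j<k} f(j), with scalar x standing in for alpha*V."""
    f = {1: 1 + x}
    running_sum = f[1]
    for k in range(2, k_max + 1):
        f[k] = x / k * running_sum
        running_sum += f[k]
    return f

def jacobian_closed_form(k, x):
    """(1/k!) * sum_{r=1}^k c(k,r) x^r, the closed form derived in the proof (k >= 2)."""
    row = stirling_first_unsigned(k)
    return sum(row[r] * x**r for r in range(1, k + 1)) / factorial(k)

x = Fraction(3, 7)  # arbitrary rational test value
f = jacobian_by_recurrence(12, x)
closed_form_matches = all(f[k] == jacobian_closed_form(k, x) for k in range(2, 13))
print(closed_form_matches)
```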
```{=latex}
\newpage
```
More on computational efficiency {#app:complete-algo}
================================

We start by providing the full forward-pass algorithm for evaluating one `\methodnameabv{}`{=latex} layer exactly (Algorithm `\ref{alg:rt-fwd}`{=latex}). The schedule follows Figure `\ref{fig:tiling}`{=latex}: persistent key--value pairs $(\vk_t,\vv_t)$ are revealed sequentially (only after $\vz_t$ is computed), but queries $\{\vq_i\}_{i=1}^N$ are available from the very beginning. We take advantage of this by immediately attending to newly available key--value pairs across an entire range of future queries rather than only the next query, increasing KV-access reuse while preserving the model's exact computation.

Concretely, when token $t$ finishes, the tile size $P$ is chosen as the largest power of $2$ that divides $t$. Algorithm `\ref{alg:rt-fwd}`{=latex} immediately incorporates the contribution of $(\vk_{t-P+1:t},\vv_{t-P+1:t})$ into the attention accumulators of the next query block $q_{t+1:t+P}$ (truncated at $N$). Over the full run, every query position accumulates contributions from every earlier key--value pair exactly once, matching naive causal attention up to floating-point reordering effects.
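The exactly-once property of this schedule can be checked by brute force. The sketch below enumerates the persistent-KV tiles issued by the schedule (the temporary self pair is handled separately and is not counted) and also tallies how many key--value rows are read in total. The function names are hypothetical, and the row count is a simplification of ours that ignores query and accumulator traffic.

```python
from collections import Counter

def tile_schedule(n):
    """Yield ((u, v), (s, e)): query range and key range of each tile, 1-indexed, inclusive."""
    for t in range(1, n + 1):
        p = t & -t                   # largest power of 2 dividing t
        u, v = t + 1, min(t + p, n)  # next query block, truncated at n
        if u <= v:
            yield (u, v), (t - p + 1, t)

def covers_causal_pairs_once(n):
    """True if every (query i, key j) with j < i is touched by exactly one tile."""
    counts = Counter()
    for (u, v), (s, e) in tile_schedule(n):
        for i in range(u, v + 1):
            for j in range(s, e + 1):
                counts[(i, j)] += 1
    expected = {(i, j) for i in range(1, n + 1) for j in range(1, i)}
    return set(counts) == expected and all(c == 1 for c in counts.values())

def kv_rows_loaded(n):
    """Total key-value rows read across all tiles (each tile reads its KV segment once)."""
    return sum(e - s + 1 for _, (s, e) in tile_schedule(n))

for n in (8, 37, 256):
    naive = n * (n - 1) // 2  # one KV row read per causal (query, key) pair
    print(n, covers_causal_pairs_once(n), kv_rows_loaded(n), naive)
```

For $N=8$, for example, the tiles read $12$ key--value rows in total, whereas attending one query at a time reads $28$.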

```{=latex}
\begin{algorithm}[t]\caption{Exact tiled forward pass for one \methodname{} layer (training/prefill)}
\label{alg:rt-fwd}
\begin{algorithmic}[1]
\REQUIRE Inputs $\vx_{1:N}$ for a single layer
\REQUIRE Projections $Q,K,V$ and MLP block $\MLP$
\REQUIRE Query tile size $B$ (power of $2$)
\ENSURE Outputs $\vz_{1:N}$ and persistent pairs $(\vk_{1:N},\vv_{1:N})$

\STATE Compute queries in parallel: $\vq_i \leftarrow \qkRMS[Q\,\RMS(\vx_i)]$ for $i=1,\ldots,N$
\STATE Initialize running stats for all queries:
\STATE $m_i \leftarrow -\infty$ \quad $l_i \leftarrow 0$ \quad $\vo_i \leftarrow 0$ for $i=1,\ldots,N$
\STATE Initialize persistent buffers $\vk_{1:N},\vv_{1:N}$ as empty

\FOR{$t=1: N$}
    \STATE $\vk_t^{\temp} \leftarrow \qkRMS[K\,\RMS(\vx_t)]$ \qquad $\vv_t^{\temp} \leftarrow V\,\RMS(\vx_t)$
    \STATE $\textsc{UpdateTile}\big(q_{t:t}\;,\; \vk_t^{\temp}\;,\;\vv_t^{\temp}\big)$ \COMMENT{temporary self contribution}

    \STATE $\va_t \leftarrow \vo_t / l_t$
    \STATE $\vz_t \leftarrow \vx_t + \va_t + \MLP[\RMS(\vx_t+\va_t)]$
    \STATE $\vk_t \leftarrow \qkRMS[K\,\RMS(\zmark{\vz_t})]$ \qquad $\vv_t \leftarrow V\,\RMS(\zmark{\vz_t})$ \COMMENT{persistent KV pair revealed}

    \STATE $P \leftarrow 2^{\nu_2(t)}$ \COMMENT{largest power of $2$ dividing $t$}
    \STATE $(u, v) \leftarrow (t+1, \min(t+P,N))$
    \IF{$u \le v$}
        \STATE $\textsc{UpdateTile}\big(q_{u:v}\;,\;\vk_{t-P+1:t}\;,\;\vv_{t-P+1:t}\big)$
        \COMMENT{have the next query block attend to the newly-available KV segment}
    \ENDIF
\ENDFOR
\end{algorithmic}
\end{algorithm}
```
```{=latex}
\begin{algorithm}
\caption{$\textsc{UpdateTile}(q_{u:v},\vk_{s:e},\vv_{s:e})$: online-softmax update for a query tile}
\label{alg:update-tile}
\begin{algorithmic}[1]
\REQUIRE Query indices $u\!:\!v$ with queries $\vq_{u:v}$
\REQUIRE A key--value tile $\vk_{s:e},\vv_{s:e}$ (persistent) or a single temporary pair $(\vk_t^{\temp},\vv_t^{\temp})$
\REQUIRE Running stats $(m_{u:v},l_{u:v},\vo_{u:v})$
\ENSURE Updated $(m_{u:v},l_{u:v},\vo_{u:v})$ corresponding to including this tile

\STATE Compute tile logits $\alpha_{i,j} \leftarrow \innerp{\vq_i}{\vk_j}$ for all $i\in[u,v]$ and $j\in[s,e]$
\STATE Compute per-query tile maxima $m^{\text{tile}}_i \leftarrow \max_{j\in[s,e]} \alpha_{i,j}$ for all $i\in[u,v]$
\STATE Compute new maxima $m^{\text{new}}_i \leftarrow \max(m_i,\; m^{\text{tile}}_i)$ for all $i\in[u,v]$

\STATE Rescale old accumulators:
\STATE $\vo_i \leftarrow \vo_i \cdot \exp(m_i - m^{\text{new}}_i)$ \qquad $l_i \leftarrow l_i \cdot \exp(m_i - m^{\text{new}}_i)$ for all $i\in[u,v]$

\STATE Accumulate this tile:
\STATE $\vo_i \leftarrow \vo_i + \sum_{j=s}^e \vv_j \exp(\alpha_{i,j} - m^{\text{new}}_i)$ for all $i\in[u,v]$
\STATE $l_i \leftarrow l_i + \sum_{j=s}^e \exp(\alpha_{i,j} - m^{\text{new}}_i)$ for all $i\in[u,v]$

\STATE Finalize maxima: $m_i \leftarrow m^{\text{new}}_i$ for all $i\in[u,v]$
\end{algorithmic}
\end{algorithm}
```
To accumulate attention contributions from multiple ranges of key--value pairs, we follow the approach of [@rabe2021blockAttention; @dao2022flashattention]: for each query position (or query tile), we maintain the standard online-softmax running statistics:

-   a running max logit $m$

-   a running normalizer $l$

-   and a running numerator vector $\vo$

When a new contribution tile is processed, these statistics are updated by rescaling the existing accumulators and adding the tile's contribution computed relative to the updated maximum (keeping the logit maxima is only required for numerical stability). After all prefix tiles have been incorporated for position $t$, the exact attention output is recovered as $\va_t=\vo_t/l_t$.

Algorithm `\ref{alg:update-tile}`{=latex} spells out the $\textsc{UpdateTile}$ primitive used by the forward schedule. It takes a query range and a range of key--value pairs and updates $(m, l,\vo)$ for all queries in the tile in a vectorized manner.
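The update can be illustrated with a minimal pure-Python sketch for a single query. The vectorization over a query tile is omitted, and the function names and toy data are assumptions made for illustration. It folds three key--value tiles of uneven size into $(m, l, \vo)$ and compares $\vo/l$ with softmax attention over all pairs at once.

```python
import math

def update_tile(stats, q, keys, values):
    """Fold one key-value tile into the running (m, l, o) statistics of one query."""
    m, l, o = stats
    logits = [sum(qc * kc for qc, kc in zip(q, k)) for k in keys]
    m_new = max(m, max(logits))
    scale = math.exp(m - m_new)  # equals 0.0 on the first tile, where m = -inf
    l = l * scale                # rescale old accumulators to the new maximum
    o = [oc * scale for oc in o]
    for logit, v in zip(logits, values):
        w = math.exp(logit - m_new)
        l += w
        o = [oc + w * vc for oc, vc in zip(o, v)]
    return m_new, l, o

def softmax_attention(q, keys, values):
    """Reference: softmax attention of one query over all pairs at once."""
    logits = [sum(qc * kc for qc, kc in zip(q, k)) for k in keys]
    top = max(logits)
    w = [math.exp(x - top) for x in logits]
    total = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, values)) / total
            for c in range(len(values[0]))]

dim, n = 3, 6
q = [0.5, -1.0, 2.0]
keys = [[math.sin(i + c) for c in range(dim)] for i in range(n)]
values = [[math.cos(2 * i + c) for c in range(dim)] for i in range(n)]

stats = (float("-inf"), 0.0, [0.0] * dim)
for start, stop in [(0, 1), (1, 4), (4, 6)]:  # three tiles of uneven size
    stats = update_tile(stats, q, keys[start:stop], values[start:stop])

m, l, o = stats
tiled = [oc / l for oc in o]
reference = softmax_attention(q, keys, values)
max_diff = max(abs(a - b) for a, b in zip(tiled, reference))
print(max_diff)
```

The two results should agree up to floating-point reordering effects, since rescaling by $\exp(m - m^{\text{new}})$ only changes the reference point of the exponentials.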

```{=latex}
\newpage
```
Hyperparameter details {#app:hyper}
======================

C4 pretraining experiments {#app:hyper_c4}
--------------------------

For the C4 pretraining experiments in Figure `\ref{fig:c4-300m}`{=latex}, we used a Transformer with 300M non-embedding parameters. For the 12-layer experiments, the model width was $1408$, the MLP width $5632$ and the number of heads $22$ (keeping the per-head dimension at $64$). For the 6-layer experiments, the width was adjusted to $2048$, with an MLP width of $8192$ and $32$ heads. The maximum sequence length was fixed to $512$ and the models were trained for $1\times$ Chinchilla tokens ($\approx 6$B tokens), corresponding to $25$k steps at batch size $512$. We used ALiBi positional embeddings [@press2022trainshorttestlong] with a maximum ALiBi bias of $8.0$. The throughput of `\methodname{}`{=latex} at 12 layers was $42$k tokens/sec, compared to $132$k tokens/sec for the vanilla Transformer. At 6 layers, it was $49$k tokens/sec, compared to $153$k tokens/sec for the vanilla Transformer.
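The token budget and the relative training cost implied by these numbers can be reproduced with a few lines of arithmetic. The sketch assumes that the batch size counts sequences of $512$ tokens and that $1\times$ Chinchilla corresponds to the common rule of thumb of roughly $20$ tokens per parameter; both are our assumptions, and the variable names are illustrative.

```python
steps = 25_000
batch_size = 512   # sequences per step (assumed unit)
seq_len = 512      # tokens per sequence
tokens_seen = steps * batch_size * seq_len

# Rule-of-thumb Chinchilla budget of about 20 tokens per parameter (assumption).
chinchilla_tokens = 20 * 300_000_000

# Throughput figures quoted above, in thousands of tokens per second.
slowdown_12_layers = 132 / 42
slowdown_6_layers = 153 / 49

print(f"tokens seen:       {tokens_seen / 1e9:.2f}B")
print(f"Chinchilla budget: {chinchilla_tokens / 1e9:.2f}B")
print(f"training slowdown: {slowdown_12_layers:.2f}x (12 layers), "
      f"{slowdown_6_layers:.2f}x (6 layers)")
```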

We used the Adam optimizer, tuning over $\eta \in \{1e-3, 3e-3, 1e-2\}, \beta_1 = 0.9, \beta_2 \in \{0.95, 0.99\}, \epsilon = 1e-8$, with the weight decay set to $0.0$. We used a warmup followed by a cosine schedule, with the warmup accounting for $40\%$ of training, as found to be optimal for this scale in previous work [@zhao2025deconstructing].

Synthetic experiments {#app:hyper_synth}
---------------------

For the synthetic experiments in Figure `\ref{fig:synth}`{=latex}, we used a single-layer Transformer with model width $128$, MLP width $512$ and $16$ heads, as in @poli2024mad. We used ALiBi positional embeddings with the maximum ALiBi bias set to $8.0$. We used the AdamW optimizer, tuning over $\eta \in \{1e-4, 5e-4, 1e-3, 5e-4\}, \beta_1 = 0.9, \beta_2 = 0.98, \epsilon = 1e-8, \lambda \in \{ 0.0, 0.1 \}$, where $\lambda$ represents the weight decay.

More experiments and results
============================

Synthetics Token Level {#app:synthetics-token-level}
----------------------

Figure `\ref{fig:synth-token-acc}`{=latex} shows the token-level accuracies for the different synthetic tasks. On the compression task, where neither Transformers nor the `\methodnameabv{}`{=latex} achieve non-trivial sequence-level performance, the accuracy becomes non-trivial at the token level and the gap between the two architectures remains prominent.

```{=latex}
\begin{figure*}[t]
  
  \includegraphics[width=\textwidth]{plots/all_tasks_TokenAccuracy.pdf}
  \caption{Token level accuracies on synthetic diagnostics (MAD + copy).}
  \label{fig:synth-token-acc}
\end{figure*}
```
RMSNorm Ablation {#app:layernorm_stable}
----------------

In this section, we ablate the RMSNorm used in `\methodname{}`{=latex}, i.e., we replace Equations `\ref{eq:persistentK}`{=latex} and `\ref{eq:persistentV}`{=latex} with

$$\begin{aligned}
\vk_i &= \qkRMS(K\,\zmark{\vz_i}) \\
\vv_i &= V\,\zmark{\vz_i}\end{aligned}$$

The best performance with this setup was obtained with $\eta = 1e-3$; higher learning rates were unstable. In comparison, with RMSNorm, learning rates up to $1e-2$ are stable, although $3e-3$ turns out to be optimal. The results are shown in Figure `\ref{fig:LN-ablate}`{=latex}. As can be seen, performance is significantly worse without the RMSNorm.

![C4 pretraining: Ablating the use of RMSNorm in `\methodname{}`{=latex} for 150M parameter model at 512 batch size.](plots/best_runs_LN.png){#fig:LN-ablate width="0.7\\linewidth"}

C4 pretraining (150M scale) {#app:150m-pretrain}
---------------------------

For the 150M parameter model, the 12-layer experiments used a model width of 1024, an MLP width of 4096 and 16 heads. For the 6-layer experiments, the width was adjusted to 1408, with an MLP width of 5632 and 22 heads. The model was trained for $1\times$ Chinchilla tokens ($\approx$ 3B tokens). We provide loss curves in Figures `\ref{fig:c4-512}`{=latex} and `\ref{fig:c4-256}`{=latex} at batch sizes $512$ and $256$, respectively. These batch sizes are chosen so that the critical batch size of $256$K of the 150M model is not exceeded [@zhang2025how], while keeping the batch size large enough for the MLPs to be compute-bound. Figures `\ref{fig:c4-512}`{=latex} and `\ref{fig:c4-256}`{=latex} show benefits at both batch sizes. The corresponding losses are displayed in Tables `\ref{tab:c4-512}`{=latex} and `\ref{tab:c4-256}`{=latex}. We also report the downstream performance of these models, in terms of cross-entropy loss on the ground-truth answers, in Tables `\ref{tab:c4-512-downstream-ce}`{=latex} and `\ref{tab:c4-256-downstream-CE}`{=latex}, respectively, and the downstream accuracy in Tables `\ref{tab:c4-512-downstream-acc}`{=latex} and `\ref{tab:c4-256-downstream-acc}`{=latex}, respectively.

::: {#tab:c4-512}
  Model                      Layers   Width    Val CE $\downarrow$
  ------------------------- -------- -------- ---------------------
  Transformer                  6      $1408$         $3.097$
  Transformer                  12     $1024$         $3.067$
  `\methodname{}`{=latex}      6      $1408$         $3.049$
  `\methodname{}`{=latex}      12     $1024$         $3.046$

  : C4 pretraining loss at 150M parameters at batch size 512.
:::

![C4 pretraining: loss curve for the 150M parameter model at batch size 512.](plots/best_runs_val_512.png){#fig:c4-512 width="0.7\\linewidth"}

::: {#tab:c4-512-downstream-ce}
  Model          Layers       piqa CE          hellaswag CE        arc easy CE       openbook qa CE         sciq CE         winogrande CE    
  ------------- -------- ------------------ ------------------ ------------------- ------------------- ------------------ ------------------
  Transformer      6          $5.791$            $4.460$            $11.419$            $12.509$             $15.3$            $7.912$       
  Transformer      12         $5.911$            $4.371$             $11.87$            $13.471$            $14.948$           $8.365$       
  Recurrent        12     $\textbf{5.606}$   $\textbf{4.274}$   $\textbf{11.124}$   $\textbf{12.067}$   $\textbf{14.43}$   $\textbf{7.645}$  
  Recurrent        6          $5.689$            $4.327$            $11.242$            $12.252$            $14.529$           $8.314$       

  : Downstream cross-entropy (CE) on the ground-truth answers for the 150M model at batch size 512 (lower is better).
:::

::: {#tab:c4-512-downstream-acc}
  Model          Layers       piqa acc        hellaswag acc       arc easy acc     openbook qa acc       sciq acc        winogrande acc   
  ------------- -------- ------------------ ------------------ ------------------ ----------------- ------------------ ------------------
  Transformer      6          $60.55$            $28.34$            $32.11$            $35.6$            $34.29$            $49.49$       
  Transformer      12         $60.94$            $28.41$             $31.4$            $37.4$        $\textbf{37.17}$   $\textbf{52.01}$  
  Recurrent        12     $\textbf{61.26}$   $\textbf{29.67}$   $\textbf{32.63}$       $36.2$            $36.28$            $50.82$       
  Recurrent        6          $61.21$            $29.06$            $31.23$        $\textbf{38.4}$       $31.08$            $48.93$       

  : Downstream accuracy for the 150M model at batch size 512.
:::

::: {#tab:c4-256}
  Model                      Layers   Width    Val CE $\downarrow$
  ------------------------- -------- -------- ---------------------
  Transformer                  6      $1440$         $3.091$
  Transformer                  12     $1024$         $3.059$
  `\methodname{}`{=latex}      6      $1440$         $3.037$
  `\methodname{}`{=latex}      12     $1024$         $3.036$

  : C4 pretraining loss at 150M parameters, training at batch-size 256.
:::

![C4 pretraining: loss curve for the 150M parameter model at batch size 256.](plots/best_runs_val_256.png){#fig:c4-256 width="0.7\\linewidth"}

::: {#tab:c4-256-downstream-CE}
  Model          Layers       piqa CE         hellaswag CE        arc easy CE       openbook qa CE          sciq CE         winogrande CE    
  ------------- -------- ------------------ ----------------- ------------------- ------------------- ------------------- ------------------
  Transformer      6          $5.851$            $4.43$            $11.792$             $13.42$            $15.061$            $8.359$       
  Transformer      12         $6.018$            $4.535$           $12.552$            $13.999$            $16.643$            $8.883$       
  Recurrent        12         $5.512$            $4.231$           $10.822$             $12.28$            $14.149$            $7.842$       
  Recurrent        6      $\textbf{5.387}$   $\textbf{4.15}$   $\textbf{10.516}$   $\textbf{11.529}$   $\textbf{13.494}$   $\textbf{7.682}$  

  : Downstream cross-entropy (CE) on the ground-truth answers for the 150M model at batch size 256 (lower is better).
:::

::: {#tab:c4-256-downstream-acc}
  Model          Layers       piqa acc        hellaswag acc       arc easy acc     openbook qa acc       sciq acc        winogrande acc   
  ------------- -------- ------------------ ------------------ ------------------ ----------------- ------------------ ------------------
  Transformer      6          $60.94$            $28.59$            $31.58$        $\textbf{37.8}$       $39.16$             $51.3$       
  Transformer      12          $60.5$            $28.57$            $32.63$            $37.6$             $36.5$        $\textbf{52.88}$  
  Recurrent        12         $61.92$            $29.18$            $31.75$            $36.6$            $35.29$             $49.8$       
  Recurrent        6      $\textbf{62.19}$   $\textbf{29.28}$   $\textbf{32.81}$   $\textbf{37.8}$   $\textbf{39.71}$       $50.51$       

  : Downstream accuracy for the 150M model at batch size 256.
:::

Downstream accuracy of 300M parameter transformer
-------------------------------------------------

::: {#tab:c4-512-300m-downstream-acc}
  Model          Layers       piqa acc        hellaswag acc       arc easy acc     openbook qa acc       sciq acc        winogrande acc
  ------------- -------- ------------------ ------------------ ------------------ ----------------- ------------------ ------------------
  Transformer      6      $\textbf{63.49}$       $31.67$        $\textbf{33.51}$       $37.8$             $34.4$            $50.36$
  Transformer      12          $62.4$            $31.82$            $32.46$        $\textbf{38.2}$        $34.4$        $\textbf{51.85}$
  Recurrent        12     $\textbf{63.49}$   $\textbf{33.02}$       $31.58$             $37$         $\textbf{35.95}$       $50.59$
  Recurrent        6          $63.27$            $32.66$            $30.88$            $37.2$            $35.62$            $49.96$

  : Downstream accuracy for the 300M model.
:::

In Table `\ref{tab:c4-512-300m-downstream-acc}`{=latex}, we provide the downstream accuracy for the 300M model.

[^1]: Since we also need to load the queries, for $\text{cnt}_q$ queries to attend to $\text{cnt}_{kv}$ key--value pairs we get an arithmetic intensity (AI) of $\Theta(2\cdot \text{cnt}_q \cdot \text{cnt}_{kv} / (\text{cnt}_q + 2\text{cnt}_{kv}))$.
