---
abstract: |
  A successful paradigm in representation learning is to perform self-supervised pretraining using tasks based on mini-batch statistics (e.g., SimCLR, VICReg, SwAV, MSN). We show that the formulation of all these methods contains an overlooked prior to learn features that enable uniform clustering of the data. While this prior has led to remarkably semantic representations when pretraining on class-balanced data, such as ImageNet, we demonstrate that it can hamper performance when pretraining on class-imbalanced data. By moving away from conventional uniformity priors and instead preferring power-law distributed feature clusters, we show that one can improve the quality of the learned representations on real-world class-imbalanced datasets. To demonstrate this, we develop an extension of the Masked Siamese Networks (MSN) method to support the use of arbitrary feature priors.
author:
- |
  Mahmoud Assran$^{1,2,3}$[^1], Randall Balestriero$^{1}$, Quentin Duval$^{1}$, Florian Bordes$^{1,3,4}$, Ishan Misra$^{1}$\
  **Piotr Bojanowski$^{1}$, Pascal Vincent$^{1}$, Michael Rabbat$^{1,3}$, Nicolas Ballas$^{1}$\
  $^{1}$Meta AI (FAIR)\
  $^{2}$McGill University, ECE\
  $^{3}$Mila, Quebec AI Institute\
  $^{4}$Universite de Montreal, DIRO**
bibliography:
- refs.bib
title: |
  The hidden uniform cluster prior in\
  self-supervised learning
---

```{=latex}
\newcommand{\fix}{\marginpar{FIX}}
```
```{=latex}
\newcommand{\new}{\marginpar{NEW}}
```
```{=latex}
\newcommand{\norm}[1]{\left\lVert#1\right\rVert}
```
```{=latex}
\newcommand{\defeq}{\coloneqq}
```
```{=latex}
\newcommand{\mido}[1]{{\color{blue}#1}}
```
```{=latex}
\newcommand{\florian}[1]{{\color{green}#1}}
```
```{=latex}
\newcommand{\quentin}[1]{{\color{purple}#1}}
```
```{=latex}
\newcommand{\figleft}{{\em (Left)}}
```
```{=latex}
\newcommand{\figcenter}{{\em (Center)}}
```
```{=latex}
\newcommand{\figright}{{\em (Right)}}
```
```{=latex}
\newcommand{\figtop}{{\em (Top)}}
```
```{=latex}
\newcommand{\figbottom}{{\em (Bottom)}}
```
```{=latex}
\newcommand{\captiona}{{\em (a)}}
```
```{=latex}
\newcommand{\captionb}{{\em (b)}}
```
```{=latex}
\newcommand{\captionc}{{\em (c)}}
```
```{=latex}
\newcommand{\captiond}{{\em (d)}}
```
```{=latex}
\newcommand{\newterm}[1]{{\bf #1}}
```
```{=latex}
\def\figref#1{figure~\ref{#1}}
```
```{=latex}
\def\Figref#1{Figure~\ref{#1}}
```
```{=latex}
\def\twofigref#1#2{figures \ref{#1} and \ref{#2}}
```
```{=latex}
\def\quadfigref#1#2#3#4{figures \ref{#1}, \ref{#2}, \ref{#3} and \ref{#4}}
```
```{=latex}
\def\secref#1{section~\ref{#1}}
```
```{=latex}
\def\Secref#1{Section~\ref{#1}}
```
```{=latex}
\def\twosecrefs#1#2{sections \ref{#1} and \ref{#2}}
```
```{=latex}
\def\secrefs#1#2#3{sections \ref{#1}, \ref{#2} and \ref{#3}}
```
```{=latex}
\def\eqref#1{equation~\ref{#1}}
```
```{=latex}
\def\Eqref#1{Equation~\ref{#1}}
```
```{=latex}
\def\plaineqref#1{\ref{#1}}
```
```{=latex}
\def\chapref#1{chapter~\ref{#1}}
```
```{=latex}
\def\Chapref#1{Chapter~\ref{#1}}
```
```{=latex}
\def\rangechapref#1#2{chapters\ref{#1}--\ref{#2}}
```
```{=latex}
\def\algref#1{algorithm~\ref{#1}}
```
```{=latex}
\def\Algref#1{Algorithm~\ref{#1}}
```
```{=latex}
\def\twoalgref#1#2{algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\Twoalgref#1#2{Algorithms \ref{#1} and \ref{#2}}
```
```{=latex}
\def\partref#1{part~\ref{#1}}
```
```{=latex}
\def\Partref#1{Part~\ref{#1}}
```
```{=latex}
\def\twopartref#1#2{parts \ref{#1} and \ref{#2}}
```
```{=latex}
\def\ceil#1{\lceil #1 \rceil}
```
```{=latex}
\def\floor#1{\lfloor #1 \rfloor}
```
```{=latex}
\def\1{\bm{1}}
```
```{=latex}
\newcommand{\train}{\mathcal{D}}
```
```{=latex}
\newcommand{\valid}{\mathcal{D_{\mathrm{valid}}}}
```
```{=latex}
\newcommand{\test}{\mathcal{D_{\mathrm{test}}}}
```
```{=latex}
\def\eps{{\epsilon}}
```
```{=latex}
\def\reta{{\textnormal{$\eta$}}}
```
```{=latex}
\def\ra{{\textnormal{a}}}
```
```{=latex}
\def\rb{{\textnormal{b}}}
```
```{=latex}
\def\rc{{\textnormal{c}}}
```
```{=latex}
\def\rd{{\textnormal{d}}}
```
```{=latex}
\def\re{{\textnormal{e}}}
```
```{=latex}
\def\rf{{\textnormal{f}}}
```
```{=latex}
\def\rg{{\textnormal{g}}}
```
```{=latex}
\def\rh{{\textnormal{h}}}
```
```{=latex}
\def\ri{{\textnormal{i}}}
```
```{=latex}
\def\rj{{\textnormal{j}}}
```
```{=latex}
\def\rk{{\textnormal{k}}}
```
```{=latex}
\def\rl{{\textnormal{l}}}
```
```{=latex}
\def\rn{{\textnormal{n}}}
```
```{=latex}
\def\ro{{\textnormal{o}}}
```
```{=latex}
\def\rp{{\textnormal{p}}}
```
```{=latex}
\def\rq{{\textnormal{q}}}
```
```{=latex}
\def\rr{{\textnormal{r}}}
```
```{=latex}
\def\rs{{\textnormal{s}}}
```
```{=latex}
\def\rt{{\textnormal{t}}}
```
```{=latex}
\def\ru{{\textnormal{u}}}
```
```{=latex}
\def\rv{{\textnormal{v}}}
```
```{=latex}
\def\rw{{\textnormal{w}}}
```
```{=latex}
\def\rx{{\textnormal{x}}}
```
```{=latex}
\def\ry{{\textnormal{y}}}
```
```{=latex}
\def\rz{{\textnormal{z}}}
```
```{=latex}
\def\rvepsilon{{\mathbf{\epsilon}}}
```
```{=latex}
\def\rvtheta{{\mathbf{\theta}}}
```
```{=latex}
\def\rva{{\mathbf{a}}}
```
```{=latex}
\def\rvb{{\mathbf{b}}}
```
```{=latex}
\def\rvc{{\mathbf{c}}}
```
```{=latex}
\def\rvd{{\mathbf{d}}}
```
```{=latex}
\def\rve{{\mathbf{e}}}
```
```{=latex}
\def\rvf{{\mathbf{f}}}
```
```{=latex}
\def\rvg{{\mathbf{g}}}
```
```{=latex}
\def\rvh{{\mathbf{h}}}
```
```{=latex}
\def\rvi{{\mathbf{i}}}
```
```{=latex}
\def\rvj{{\mathbf{j}}}
```
```{=latex}
\def\rvk{{\mathbf{k}}}
```
```{=latex}
\def\rvl{{\mathbf{l}}}
```
```{=latex}
\def\rvm{{\mathbf{m}}}
```
```{=latex}
\def\rvn{{\mathbf{n}}}
```
```{=latex}
\def\rvo{{\mathbf{o}}}
```
```{=latex}
\def\rvp{{\mathbf{p}}}
```
```{=latex}
\def\rvq{{\mathbf{q}}}
```
```{=latex}
\def\rvr{{\mathbf{r}}}
```
```{=latex}
\def\rvs{{\mathbf{s}}}
```
```{=latex}
\def\rvt{{\mathbf{t}}}
```
```{=latex}
\def\rvu{{\mathbf{u}}}
```
```{=latex}
\def\rvv{{\mathbf{v}}}
```
```{=latex}
\def\rvw{{\mathbf{w}}}
```
```{=latex}
\def\rvx{{\mathbf{x}}}
```
```{=latex}
\def\rvy{{\mathbf{y}}}
```
```{=latex}
\def\rvz{{\mathbf{z}}}
```
```{=latex}
\def\erva{{\textnormal{a}}}
```
```{=latex}
\def\ervb{{\textnormal{b}}}
```
```{=latex}
\def\ervc{{\textnormal{c}}}
```
```{=latex}
\def\ervd{{\textnormal{d}}}
```
```{=latex}
\def\erve{{\textnormal{e}}}
```
```{=latex}
\def\ervf{{\textnormal{f}}}
```
```{=latex}
\def\ervg{{\textnormal{g}}}
```
```{=latex}
\def\ervh{{\textnormal{h}}}
```
```{=latex}
\def\ervi{{\textnormal{i}}}
```
```{=latex}
\def\ervj{{\textnormal{j}}}
```
```{=latex}
\def\ervk{{\textnormal{k}}}
```
```{=latex}
\def\ervl{{\textnormal{l}}}
```
```{=latex}
\def\ervm{{\textnormal{m}}}
```
```{=latex}
\def\ervn{{\textnormal{n}}}
```
```{=latex}
\def\ervo{{\textnormal{o}}}
```
```{=latex}
\def\ervp{{\textnormal{p}}}
```
```{=latex}
\def\ervq{{\textnormal{q}}}
```
```{=latex}
\def\ervr{{\textnormal{r}}}
```
```{=latex}
\def\ervs{{\textnormal{s}}}
```
```{=latex}
\def\ervt{{\textnormal{t}}}
```
```{=latex}
\def\ervu{{\textnormal{u}}}
```
```{=latex}
\def\ervv{{\textnormal{v}}}
```
```{=latex}
\def\ervw{{\textnormal{w}}}
```
```{=latex}
\def\ervx{{\textnormal{x}}}
```
```{=latex}
\def\ervy{{\textnormal{y}}}
```
```{=latex}
\def\ervz{{\textnormal{z}}}
```
```{=latex}
\def\rmA{{\mathbf{A}}}
```
```{=latex}
\def\rmB{{\mathbf{B}}}
```
```{=latex}
\def\rmC{{\mathbf{C}}}
```
```{=latex}
\def\rmD{{\mathbf{D}}}
```
```{=latex}
\def\rmE{{\mathbf{E}}}
```
```{=latex}
\def\rmF{{\mathbf{F}}}
```
```{=latex}
\def\rmG{{\mathbf{G}}}
```
```{=latex}
\def\rmH{{\mathbf{H}}}
```
```{=latex}
\def\rmI{{\mathbf{I}}}
```
```{=latex}
\def\rmJ{{\mathbf{J}}}
```
```{=latex}
\def\rmK{{\mathbf{K}}}
```
```{=latex}
\def\rmL{{\mathbf{L}}}
```
```{=latex}
\def\rmM{{\mathbf{M}}}
```
```{=latex}
\def\rmN{{\mathbf{N}}}
```
```{=latex}
\def\rmO{{\mathbf{O}}}
```
```{=latex}
\def\rmP{{\mathbf{P}}}
```
```{=latex}
\def\rmQ{{\mathbf{Q}}}
```
```{=latex}
\def\rmR{{\mathbf{R}}}
```
```{=latex}
\def\rmS{{\mathbf{S}}}
```
```{=latex}
\def\rmT{{\mathbf{T}}}
```
```{=latex}
\def\rmU{{\mathbf{U}}}
```
```{=latex}
\def\rmV{{\mathbf{V}}}
```
```{=latex}
\def\rmW{{\mathbf{W}}}
```
```{=latex}
\def\rmX{{\mathbf{X}}}
```
```{=latex}
\def\rmY{{\mathbf{Y}}}
```
```{=latex}
\def\rmZ{{\mathbf{Z}}}
```
```{=latex}
\def\ermA{{\textnormal{A}}}
```
```{=latex}
\def\ermB{{\textnormal{B}}}
```
```{=latex}
\def\ermC{{\textnormal{C}}}
```
```{=latex}
\def\ermD{{\textnormal{D}}}
```
```{=latex}
\def\ermE{{\textnormal{E}}}
```
```{=latex}
\def\ermF{{\textnormal{F}}}
```
```{=latex}
\def\ermG{{\textnormal{G}}}
```
```{=latex}
\def\ermH{{\textnormal{H}}}
```
```{=latex}
\def\ermI{{\textnormal{I}}}
```
```{=latex}
\def\ermJ{{\textnormal{J}}}
```
```{=latex}
\def\ermK{{\textnormal{K}}}
```
```{=latex}
\def\ermL{{\textnormal{L}}}
```
```{=latex}
\def\ermM{{\textnormal{M}}}
```
```{=latex}
\def\ermN{{\textnormal{N}}}
```
```{=latex}
\def\ermO{{\textnormal{O}}}
```
```{=latex}
\def\ermP{{\textnormal{P}}}
```
```{=latex}
\def\ermQ{{\textnormal{Q}}}
```
```{=latex}
\def\ermR{{\textnormal{R}}}
```
```{=latex}
\def\ermS{{\textnormal{S}}}
```
```{=latex}
\def\ermT{{\textnormal{T}}}
```
```{=latex}
\def\ermU{{\textnormal{U}}}
```
```{=latex}
\def\ermV{{\textnormal{V}}}
```
```{=latex}
\def\ermW{{\textnormal{W}}}
```
```{=latex}
\def\ermX{{\textnormal{X}}}
```
```{=latex}
\def\ermY{{\textnormal{Y}}}
```
```{=latex}
\def\ermZ{{\textnormal{Z}}}
```
```{=latex}
\def\vzero{{\bm{0}}}
```
```{=latex}
\def\vone{{\bm{1}}}
```
```{=latex}
\def\vmu{{\bm{\mu}}}
```
```{=latex}
\def\vtheta{{\bm{\theta}}}
```
```{=latex}
\def\va{{\bm{a}}}
```
```{=latex}
\def\vb{{\bm{b}}}
```
```{=latex}
\def\vc{{\bm{c}}}
```
```{=latex}
\def\vd{{\bm{d}}}
```
```{=latex}
\def\ve{{\bm{e}}}
```
```{=latex}
\def\vf{{\bm{f}}}
```
```{=latex}
\def\vg{{\bm{g}}}
```
```{=latex}
\def\vh{{\bm{h}}}
```
```{=latex}
\def\vi{{\bm{i}}}
```
```{=latex}
\def\vj{{\bm{j}}}
```
```{=latex}
\def\vk{{\bm{k}}}
```
```{=latex}
\def\vl{{\bm{l}}}
```
```{=latex}
\def\vm{{\bm{m}}}
```
```{=latex}
\def\vn{{\bm{n}}}
```
```{=latex}
\def\vo{{\bm{o}}}
```
```{=latex}
\def\vp{{\bm{p}}}
```
```{=latex}
\def\vq{{\bm{q}}}
```
```{=latex}
\def\vr{{\bm{r}}}
```
```{=latex}
\def\vs{{\bm{s}}}
```
```{=latex}
\def\vt{{\bm{t}}}
```
```{=latex}
\def\vu{{\bm{u}}}
```
```{=latex}
\def\vv{{\bm{v}}}
```
```{=latex}
\def\vw{{\bm{w}}}
```
```{=latex}
\def\vx{{\bm{x}}}
```
```{=latex}
\def\vy{{\bm{y}}}
```
```{=latex}
\def\vz{{\bm{z}}}
```
```{=latex}
\def\vpi{{\bm{\pi}}}
```
```{=latex}
\def\evalpha{{\alpha}}
```
```{=latex}
\def\evbeta{{\beta}}
```
```{=latex}
\def\evepsilon{{\epsilon}}
```
```{=latex}
\def\evlambda{{\lambda}}
```
```{=latex}
\def\evomega{{\omega}}
```
```{=latex}
\def\evmu{{\mu}}
```
```{=latex}
\def\evpsi{{\psi}}
```
```{=latex}
\def\evsigma{{\sigma}}
```
```{=latex}
\def\evtheta{{\theta}}
```
```{=latex}
\def\eva{{a}}
```
```{=latex}
\def\evb{{b}}
```
```{=latex}
\def\evc{{c}}
```
```{=latex}
\def\evd{{d}}
```
```{=latex}
\def\eve{{e}}
```
```{=latex}
\def\evf{{f}}
```
```{=latex}
\def\evg{{g}}
```
```{=latex}
\def\evh{{h}}
```
```{=latex}
\def\evi{{i}}
```
```{=latex}
\def\evj{{j}}
```
```{=latex}
\def\evk{{k}}
```
```{=latex}
\def\evl{{l}}
```
```{=latex}
\def\evm{{m}}
```
```{=latex}
\def\evn{{n}}
```
```{=latex}
\def\evo{{o}}
```
```{=latex}
\def\evp{{p}}
```
```{=latex}
\def\evq{{q}}
```
```{=latex}
\def\evr{{r}}
```
```{=latex}
\def\evs{{s}}
```
```{=latex}
\def\evt{{t}}
```
```{=latex}
\def\evu{{u}}
```
```{=latex}
\def\evv{{v}}
```
```{=latex}
\def\evw{{w}}
```
```{=latex}
\def\evx{{x}}
```
```{=latex}
\def\evy{{y}}
```
```{=latex}
\def\evz{{z}}
```
```{=latex}
\def\mA{{\bm{A}}}
```
```{=latex}
\def\mB{{\bm{B}}}
```
```{=latex}
\def\mC{{\bm{C}}}
```
```{=latex}
\def\mD{{\bm{D}}}
```
```{=latex}
\def\mE{{\bm{E}}}
```
```{=latex}
\def\mF{{\bm{F}}}
```
```{=latex}
\def\mG{{\bm{G}}}
```
```{=latex}
\def\mH{{\bm{H}}}
```
```{=latex}
\def\mI{{\bm{I}}}
```
```{=latex}
\def\mJ{{\bm{J}}}
```
```{=latex}
\def\mK{{\bm{K}}}
```
```{=latex}
\def\mL{{\bm{L}}}
```
```{=latex}
\def\mM{{\bm{M}}}
```
```{=latex}
\def\mN{{\bm{N}}}
```
```{=latex}
\def\mO{{\bm{O}}}
```
```{=latex}
\def\mP{{\bm{P}}}
```
```{=latex}
\def\mQ{{\bm{Q}}}
```
```{=latex}
\def\mR{{\bm{R}}}
```
```{=latex}
\def\mS{{\bm{S}}}
```
```{=latex}
\def\mT{{\bm{T}}}
```
```{=latex}
\def\mU{{\bm{U}}}
```
```{=latex}
\def\mV{{\bm{V}}}
```
```{=latex}
\def\mW{{\bm{W}}}
```
```{=latex}
\def\mX{{\bm{X}}}
```
```{=latex}
\def\mY{{\bm{Y}}}
```
```{=latex}
\def\mZ{{\bm{Z}}}
```
```{=latex}
\def\mBeta{{\bm{\beta}}}
```
```{=latex}
\def\mPhi{{\bm{\Phi}}}
```
```{=latex}
\def\mLambda{{\bm{\Lambda}}}
```
```{=latex}
\def\mSigma{{\bm{\Sigma}}}
```
```{=latex}
\newcommand{\tens}[1]{\bm{\mathsfit{#1}}}
```
```{=latex}
\def\tA{{\tens{A}}}
```
```{=latex}
\def\tB{{\tens{B}}}
```
```{=latex}
\def\tC{{\tens{C}}}
```
```{=latex}
\def\tD{{\tens{D}}}
```
```{=latex}
\def\tE{{\tens{E}}}
```
```{=latex}
\def\tF{{\tens{F}}}
```
```{=latex}
\def\tG{{\tens{G}}}
```
```{=latex}
\def\tH{{\tens{H}}}
```
```{=latex}
\def\tI{{\tens{I}}}
```
```{=latex}
\def\tJ{{\tens{J}}}
```
```{=latex}
\def\tK{{\tens{K}}}
```
```{=latex}
\def\tL{{\tens{L}}}
```
```{=latex}
\def\tM{{\tens{M}}}
```
```{=latex}
\def\tN{{\tens{N}}}
```
```{=latex}
\def\tO{{\tens{O}}}
```
```{=latex}
\def\tP{{\tens{P}}}
```
```{=latex}
\def\tQ{{\tens{Q}}}
```
```{=latex}
\def\tR{{\tens{R}}}
```
```{=latex}
\def\tS{{\tens{S}}}
```
```{=latex}
\def\tT{{\tens{T}}}
```
```{=latex}
\def\tU{{\tens{U}}}
```
```{=latex}
\def\tV{{\tens{V}}}
```
```{=latex}
\def\tW{{\tens{W}}}
```
```{=latex}
\def\tX{{\tens{X}}}
```
```{=latex}
\def\tY{{\tens{Y}}}
```
```{=latex}
\def\tZ{{\tens{Z}}}
```
```{=latex}
\def\gA{{\mathcal{A}}}
```
```{=latex}
\def\gB{{\mathcal{B}}}
```
```{=latex}
\def\gC{{\mathcal{C}}}
```
```{=latex}
\def\gD{{\mathcal{D}}}
```
```{=latex}
\def\gE{{\mathcal{E}}}
```
```{=latex}
\def\gF{{\mathcal{F}}}
```
```{=latex}
\def\gG{{\mathcal{G}}}
```
```{=latex}
\def\gH{{\mathcal{H}}}
```
```{=latex}
\def\gI{{\mathcal{I}}}
```
```{=latex}
\def\gJ{{\mathcal{J}}}
```
```{=latex}
\def\gK{{\mathcal{K}}}
```
```{=latex}
\def\gL{{\mathcal{L}}}
```
```{=latex}
\def\gM{{\mathcal{M}}}
```
```{=latex}
\def\gN{{\mathcal{N}}}
```
```{=latex}
\def\gO{{\mathcal{O}}}
```
```{=latex}
\def\gP{{\mathcal{P}}}
```
```{=latex}
\def\gQ{{\mathcal{Q}}}
```
```{=latex}
\def\gR{{\mathcal{R}}}
```
```{=latex}
\def\gS{{\mathcal{S}}}
```
```{=latex}
\def\gT{{\mathcal{T}}}
```
```{=latex}
\def\gU{{\mathcal{U}}}
```
```{=latex}
\def\gV{{\mathcal{V}}}
```
```{=latex}
\def\gW{{\mathcal{W}}}
```
```{=latex}
\def\gX{{\mathcal{X}}}
```
```{=latex}
\def\gY{{\mathcal{Y}}}
```
```{=latex}
\def\gZ{{\mathcal{Z}}}
```
```{=latex}
\def\sA{{\mathbb{A}}}
```
```{=latex}
\def\sB{{\mathbb{B}}}
```
```{=latex}
\def\sC{{\mathbb{C}}}
```
```{=latex}
\def\sD{{\mathbb{D}}}
```
```{=latex}
\def\sF{{\mathbb{F}}}
```
```{=latex}
\def\sG{{\mathbb{G}}}
```
```{=latex}
\def\sH{{\mathbb{H}}}
```
```{=latex}
\def\sI{{\mathbb{I}}}
```
```{=latex}
\def\sJ{{\mathbb{J}}}
```
```{=latex}
\def\sK{{\mathbb{K}}}
```
```{=latex}
\def\sL{{\mathbb{L}}}
```
```{=latex}
\def\sM{{\mathbb{M}}}
```
```{=latex}
\def\sN{{\mathbb{N}}}
```
```{=latex}
\def\sO{{\mathbb{O}}}
```
```{=latex}
\def\sP{{\mathbb{P}}}
```
```{=latex}
\def\sQ{{\mathbb{Q}}}
```
```{=latex}
\def\sR{{\mathbb{R}}}
```
```{=latex}
\def\sS{{\mathbb{S}}}
```
```{=latex}
\def\sT{{\mathbb{T}}}
```
```{=latex}
\def\sU{{\mathbb{U}}}
```
```{=latex}
\def\sV{{\mathbb{V}}}
```
```{=latex}
\def\sW{{\mathbb{W}}}
```
```{=latex}
\def\sX{{\mathbb{X}}}
```
```{=latex}
\def\sY{{\mathbb{Y}}}
```
```{=latex}
\def\sZ{{\mathbb{Z}}}
```
```{=latex}
\def\emLambda{{\Lambda}}
```
```{=latex}
\def\emA{{A}}
```
```{=latex}
\def\emB{{B}}
```
```{=latex}
\def\emC{{C}}
```
```{=latex}
\def\emD{{D}}
```
```{=latex}
\def\emE{{E}}
```
```{=latex}
\def\emF{{F}}
```
```{=latex}
\def\emG{{G}}
```
```{=latex}
\def\emH{{H}}
```
```{=latex}
\def\emI{{I}}
```
```{=latex}
\def\emJ{{J}}
```
```{=latex}
\def\emK{{K}}
```
```{=latex}
\def\emL{{L}}
```
```{=latex}
\def\emM{{M}}
```
```{=latex}
\def\emN{{N}}
```
```{=latex}
\def\emO{{O}}
```
```{=latex}
\def\emP{{P}}
```
```{=latex}
\def\emQ{{Q}}
```
```{=latex}
\def\emR{{R}}
```
```{=latex}
\def\emS{{S}}
```
```{=latex}
\def\emT{{T}}
```
```{=latex}
\def\emU{{U}}
```
```{=latex}
\def\emV{{V}}
```
```{=latex}
\def\emW{{W}}
```
```{=latex}
\def\emX{{X}}
```
```{=latex}
\def\emY{{Y}}
```
```{=latex}
\def\emZ{{Z}}
```
```{=latex}
\def\emSigma{{\Sigma}}
```
```{=latex}
\newcommand{\etens}[1]{\mathsfit{#1}}
```
```{=latex}
\def\etLambda{{\etens{\Lambda}}}
```
```{=latex}
\def\etA{{\etens{A}}}
```
```{=latex}
\def\etB{{\etens{B}}}
```
```{=latex}
\def\etC{{\etens{C}}}
```
```{=latex}
\def\etD{{\etens{D}}}
```
```{=latex}
\def\etE{{\etens{E}}}
```
```{=latex}
\def\etF{{\etens{F}}}
```
```{=latex}
\def\etG{{\etens{G}}}
```
```{=latex}
\def\etH{{\etens{H}}}
```
```{=latex}
\def\etI{{\etens{I}}}
```
```{=latex}
\def\etJ{{\etens{J}}}
```
```{=latex}
\def\etK{{\etens{K}}}
```
```{=latex}
\def\etL{{\etens{L}}}
```
```{=latex}
\def\etM{{\etens{M}}}
```
```{=latex}
\def\etN{{\etens{N}}}
```
```{=latex}
\def\etO{{\etens{O}}}
```
```{=latex}
\def\etP{{\etens{P}}}
```
```{=latex}
\def\etQ{{\etens{Q}}}
```
```{=latex}
\def\etR{{\etens{R}}}
```
```{=latex}
\def\etS{{\etens{S}}}
```
```{=latex}
\def\etT{{\etens{T}}}
```
```{=latex}
\def\etU{{\etens{U}}}
```
```{=latex}
\def\etV{{\etens{V}}}
```
```{=latex}
\def\etW{{\etens{W}}}
```
```{=latex}
\def\etX{{\etens{X}}}
```
```{=latex}
\def\etY{{\etens{Y}}}
```
```{=latex}
\def\etZ{{\etens{Z}}}
```
```{=latex}
\newcommand{\pdata}{p_{\rm{data}}}
```
```{=latex}
\newcommand{\ptrain}{\hat{p}_{\rm{data}}}
```
```{=latex}
\newcommand{\Ptrain}{\hat{P}_{\rm{data}}}
```
```{=latex}
\newcommand{\pmodel}{p_{\rm{model}}}
```
```{=latex}
\newcommand{\Pmodel}{P_{\rm{model}}}
```
```{=latex}
\newcommand{\ptildemodel}{\tilde{p}_{\rm{model}}}
```
```{=latex}
\newcommand{\pencode}{p_{\rm{encoder}}}
```
```{=latex}
\newcommand{\pdecode}{p_{\rm{decoder}}}
```
```{=latex}
\newcommand{\precons}{p_{\rm{reconstruct}}}
```
```{=latex}
\newcommand{\laplace}{\mathrm{Laplace}}
```
```{=latex}
\newcommand{\E}{\mathbb{E}}
```
```{=latex}
\newcommand{\Ls}{\mathcal{L}}
```
```{=latex}
\newcommand{\R}{\mathbb{R}}
```
```{=latex}
\newcommand{\emp}{\tilde{p}}
```
```{=latex}
\newcommand{\lr}{\alpha}
```
```{=latex}
\newcommand{\reg}{\lambda}
```
```{=latex}
\newcommand{\rect}{\mathrm{rectifier}}
```
```{=latex}
\newcommand{\softmax}{\mathrm{softmax}}
```
```{=latex}
\newcommand{\sigmoid}{\sigma}
```
```{=latex}
\newcommand{\softplus}{\zeta}
```
```{=latex}
\newcommand{\KL}{D_{\mathrm{KL}}}
```
```{=latex}
\newcommand{\Var}{\mathrm{Var}}
```
```{=latex}
\newcommand{\standarderror}{\mathrm{SE}}
```
```{=latex}
\newcommand{\Cov}{\mathrm{Cov}}
```
```{=latex}
\newcommand{\normlzero}{L^0}
```
```{=latex}
\newcommand{\normlone}{L^1}
```
```{=latex}
\newcommand{\normltwo}{L^2}
```
```{=latex}
\newcommand{\normlp}{L^p}
```
```{=latex}
\newcommand{\normmax}{L^\infty}
```
```{=latex}
\newcommand{\parents}{Pa}
```
```{=latex}
\DeclareMathOperator*{\argmax}{arg\,max}
```
```{=latex}
\DeclareMathOperator*{\argmin}{arg\,min}
```
```{=latex}
\DeclareMathOperator{\sign}{sign}
```
```{=latex}
\DeclareMathOperator{\Tr}{Tr}
```
```{=latex}
\let\ab\allowbreak
```
```{=latex}
\maketitle
```
Introduction
============

Self-supervised pretraining has emerged as a highly effective strategy for unsupervised representation learning, with remarkable advances demonstrated by joint-embedding methods [@chen2020simple; @caron2021emerging; @bardes2021vicreg; @assran2022masked]. In the context of visual data, these approaches typically learn representations by training a neural network encoder to produce similar embeddings for two or more views of the same image. However, since outputting a constant vector regardless of the input would satisfy this objective, one of the main challenges with joint-embedding methods is to prevent such pathological solutions. A common remedy is to employ a regularizer that maximizes the volume of space occupied by the representations. This is sometimes referred to as the volume maximization principle. In practice, the volume maximization principle is implemented in a variety of ways, for example, by contrasting negative samples [@bromley1993signature; @he2019moco; @chen2020simple], by removing correlations in the feature space [@bardes2021vicreg; @zbontar2021barlow], or by finding high entropy clusterings of the data [@asano2019self; @caron2020unsupervised; @assran2021semi; @assran2022masked]. When pretrained on the ImageNet dataset [@russakovsky2015imagenet], these methods have been shown to produce representations that encode highly semantic features [@caron2020unsupervised; @caron2021emerging; @assran2022masked].

However, the commonly used ImageNet-1K dataset is relatively class-balanced, which is in contrast to most real-world settings, where data is often *class-imbalanced* and semantic concepts follow a long-tailed power-law distribution [@newman2005power; @mahajan2018exploring; @van2018inaturalist]. Indeed, it has been shown that pretraining the same joint-embedding methods on long-tailed datasets can lead to significant drops in performance [@tian2021divide]. Such an observation is problematic in that it significantly hinders the applicability of modern research advances with joint-embedding methods to real-world settings.

In this work, we explore the use of joint-embedding methods for class-imbalanced datasets. First, we theoretically show that current methods with volume maximization regularizers, such as VICReg [@bardes2021vicreg], SwAV [@caron2020unsupervised], MSN [@assran2022masked], and (with limited assumptions) SimCLR [@chen2020simple], have a uniform feature prior; i.e., a bias to learn features that enable grouping the data into clusters of roughly equal size. Consequently, these joint-embedding methods will penalize features that do not uniformly cluster the data, even if such features correlate well with class information; see Figure `\ref{fig:kmeans}`{=latex}.

Second, we empirically validate that joint-embedding methods employing volume maximization regularizers are sensitive to the mini-batch class distributions. These approaches fail to learn class-discriminative features when the samples within a mini-batch do not follow a uniform class distribution. This observation partially explains why performance degrades when pretraining with real-world data, where sampled mini-batches often contain highly imbalanced class distributions. `\setlength{\columnsep}{4.5mm}`{=latex}

```{=latex}
\begin{wrapfigure}[33]{r}{0.3\linewidth}
% \begin{figure}[t]
    \centering
    \begin{subfigure}{\linewidth}
        \centering
        \includegraphics[width=\linewidth]{assets/kmeans_0.5.pdf}
        \caption{K-means with class-balanced data}
        \label{fig:visu_rcdm_balanced}
    \end{subfigure}\\[1em]
    \begin{subfigure}{\linewidth}
        \centering
        \includegraphics[width=\linewidth]{assets/kmeans_0.95.pdf}
        \caption{K-means with class-imbalanced data}
        \label{fig:visu_rcdm_unbalanced}
    \end{subfigure}
    \caption{Impact of the uniform cluster prior in K-means when the class distribution of the data is imbalanced. K-means clustering is depicted in color (green vs red). Ground-truth cluster separation is depicted with a dotted black line. When the uniform feature prior is not satisfied, K-means can identify undesirable features for discriminating between data points.}
    \label{fig:kmeans}
\end{wrapfigure}
```
Finally, based on this observation, we propose to move away from conventional uniformity priors and instead reformulate self-supervised criteria to prefer long-tailed feature priors that are more aligned with the distribution of semantic concepts in real-world datasets. In particular, we extend Masked Siamese Networks (MSN) of @assran2022masked to support the use of arbitrary feature priors, and refer to this extension as *Prior Matching for Siamese Networks (PMSN)*. When pretraining on the iNaturalist 2018 dataset [@van2018inaturalist], which is naturally long-tailed, we demonstrate that moving away from uniform priors leads to more semantic representations and improved transfer on downstream tasks.

Background
==========

Given the recent success of joint-embedding methods, there is a growing literature that aims to build a better understanding of their behaviour. Several works have sought to develop generalization bounds for joint-embedding methods with volume maximization penalties [@arora2019theoretical; @balestriero2022contrastive]. Other works have sought to better understand the differences between various volume maximization penalties and connect them under limited assumptions [@garrido2022duality]. In general, it has been shown that $\ell_2$-normalized contrastive losses can be decomposed into an "alignment" plus volume maximization component that scatters the representations uniformly on the unit hypersphere [@wang2020understanding]. Following this observation, other works [@chen2021intriguing] have sought to reformulate contrastive losses to scatter representations either (a) uniformly on the unit hypercube, or (b) onto Gaussian distributions (which have the highest entropy amongst all distributions with a given variance). There is also theoretical work [@tian2021understanding] which aims to understand why certain joint-embedding methods, such as BYOL [@grill2020bootstrap], can avoid representation collapse without explicit use of a volume maximization penalty.

While these works have helped build our understanding of the training dynamics of joint-embedding methods, they do not directly explain why empirical use of these methods with real-world class-imbalanced data has often led to a degradation in downstream task performance [@tian2021divide; @goyal2022vision] (see Appendix `\ref{apndx:relatedwork}`{=latex} for a broader discussion of related work). In this work, we explore the use of joint-embedding methods with class-imbalanced data. In particular, we theoretically show that a broad range of methods (beyond contrastive) prescribe a uniform *feature* prior, and that this prior is detrimental when pretraining with class-imbalanced data.

Uniform priors in modern self-supervised learning {#sec:theory}
=================================================

In this section, we theoretically show that common SSL methods such as VICReg [@bardes2021vicreg], SwAV [@caron2020unsupervised], MSN [@assran2022masked], and (with limited assumptions) SimCLR [@chen2020simple], correspond to variants of K-means, and thereby impose a uniform cluster prior; i.e., a bias to learn features that enable uniform clustering of the data. The governing assumption in K-means is the presence of isotropic data clusters, with roughly an equal number of samples per cluster [@wu2009adapting; @liang2012k]. When this assumption is not satisfied in practice, K-means may learn undesirable features for discriminating between samples (cf. Figure `\ref{fig:kmeans}`{=latex}).

Background: K-means formulations and the uniform prior
------------------------------------------------------

#### Explicit (Centroid) K-means.

Recall that K-means proposes a centroid-based clustering of the data. In particular, given a set of $N$ data points $\{\vx_n\}^N_{n=1}$, K-means partitions the elements into $K$ disjoint groups $\sX_1,\dots,\sX_K$, such that $\sum^K_{i=1}\lvert\sX_i\rvert = N$. The K-means objective can be written as $$\label{eq:kmeans}
    \min_{\{\sX_k\}_{k=1}^K}
    \sum_{k=1}^K \sum_{\vx \in \sX_k} \|\vx-\mu_k\|^2_2,$$ where the optimization problem is to identify the members of the disjoint sets $\sX_1, \ldots, \sX_K$, and $\mu_k$ (the $k^\text{th}$ cluster centroid) is precisely the average of the members of $\sX_k$.
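
A minimal NumPy sketch of this objective, evaluated for a fixed partition, is given below. The function name `kmeans_objective`, the integer label encoding of the groups $\sX_k$, and the toy points are illustrative assumptions and are not taken from any reference implementation.

```python
import numpy as np

def kmeans_objective(x, labels, num_clusters):
    """Sum of squared distances of each point to the mean of its group.

    x: (N, D) data points; labels: (N,) integers in [0, num_clusters).
    """
    total = 0.0
    for k in range(num_clusters):
        members = x[labels == k]
        if len(members) == 0:
            continue  # an empty group contributes nothing to the sum
        centroid = members.mean(axis=0)  # mu_k is the average of the members
        total += np.sum((members - centroid) ** 2)
    return total

# Toy example (hypothetical data): four points on a line, two candidate partitions.
x = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 0.0], [12.0, 0.0]])
near_pairs = np.array([0, 0, 1, 1])  # groups neighbouring points together
far_pairs = np.array([0, 1, 0, 1])   # groups distant points together
print(kmeans_objective(x, near_pairs, 2))  # 4.0
print(kmeans_objective(x, far_pairs, 2))   # 100.0
```

The partition that groups neighbouring points attains the lower value, which is the partition the minimization would select among these two candidates.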

#### Implicit (Centroid) K-means.

One important note that we will carry through our study is that K-means does not require explicit computation of the cluster centroids $\mu_k$ to evaluate the objective. To see this standard result, one can relate the sum of pairwise distances to the sum of radial distances for any data partition, $$\begin{aligned}
    \sum_{k=1}^K \sum_{\vx \in\sX_k} \lVert\vx-\mu_k\rVert^2_2 & = \sum_{k=1}^K \frac{1}{2\lvert\sX_k\rvert} \sum_{\vx, \vx^\prime \in \sX_k} \lVert\vx-\vx'\rVert^2_2. \label{eq2}\end{aligned}$$ This result suggests that the K-means loss can be minimized by either learning cluster centroids (explicit K-means), or by learning cluster memberships (implicit K-means). We make this relation precise in Proposition `\ref{prop:kmeans}`{=latex}, which is proven in Appendix `\ref{proof:kmeans1}`{=latex} and was already noted in @zha2001spectral [@awasthi2015relax].
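
The identity can be checked numerically. The sketch below assumes random Gaussian data and a random partition, both chosen only for illustration; the function names are hypothetical.

```python
import numpy as np

def centroid_form(x, labels, num_clusters):
    """Left-hand side: squared radial distances to each group mean."""
    total = 0.0
    for k in range(num_clusters):
        members = x[labels == k]
        if len(members) > 0:
            total += np.sum((members - members.mean(axis=0)) ** 2)
    return total

def pairwise_form(x, labels, num_clusters):
    """Right-hand side: squared pairwise distances, scaled by 1 / (2 |X_k|)."""
    total = 0.0
    for k in range(num_clusters):
        members = x[labels == k]
        if len(members) > 0:
            # Differences over all ordered pairs (x, x') within the group.
            diffs = members[:, None, :] - members[None, :, :]
            total += np.sum(diffs ** 2) / (2 * len(members))
    return total

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))
labels = rng.integers(0, 4, size=50)  # arbitrary partition into 4 groups
lhs = centroid_form(x, labels, 4)
rhs = pairwise_form(x, labels, 4)
print(np.isclose(lhs, rhs))  # True
```

Note that `pairwise_form` never computes a centroid, which is the sense in which the objective can be evaluated from cluster memberships alone.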

```{=latex}
\begin{proposition}
    \label{prop:kmeans}
    The explicit K-means problem, defined by learning a set of $K$ centroids $\mu_1,\dots,\mu_{K}$, admits the same global optimum as the implicit K-means problem, defined by learning a cluster membership matrix $\mP \in \{0,1\}^{N \times K}$, such that $\mP\mathbf{1}_{K}=\mathbf{1}_{N}$.
\end{proposition}
```
The fundamental assumptions governing the success of K-means lie in having clusters with roughly the same number of samples and intra-cluster data covariance that is isotropic with the form $\sigma \mI$, where $\sigma > 0$ and $\mI$ is the identity matrix [@wu2009adapting; @liang2012k]. In the sequel, we will show that various SSL methods can be seen as employing either explicit or implicit K-means.
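
To make the role of the equal-size assumption concrete, the following sketch compares the K-means objective of two partitions on a toy one-dimensional dataset. The data (a wide majority class of 95 points and a tight minority class of 5 points) and the threshold at 6.0 are assumptions chosen for illustration, in the spirit of Figure `\ref{fig:kmeans}`{=latex}.

```python
import numpy as np

def kmeans_cost(x, labels):
    """K-means objective of a 1-D partition given integer labels."""
    return sum(
        np.sum((x[labels == k] - x[labels == k].mean()) ** 2)
        for k in np.unique(labels)
    )

# Toy data (hypothetical): wide majority class and tight minority class.
majority = np.linspace(0.0, 10.0, 95)
minority = np.linspace(11.0, 12.0, 5)
x = np.concatenate([majority, minority])

# Partition that follows the class labels (95 vs 5 points).
true_labels = np.concatenate([np.zeros(95, dtype=int), np.ones(5, dtype=int)])
# Partition that cuts through the majority class, giving two larger groups.
split_labels = (x > 6.0).astype(int)

cost_true = kmeans_cost(x, true_labels)
cost_split = kmeans_cost(x, split_labels)
print(cost_true > cost_split)  # True
```

On this toy data the objective prefers the partition that cuts through the majority class over the one aligned with the class labels, so minimizing it would not recover the imbalanced ground-truth grouping.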

How K-means naturally emerges from self-supervised losses {#sec:swav}
---------------------------------------------------------

In this section, we demonstrate how standard SSL methods naturally employ K-means at their core. Although this result might seem intuitive for methods that explicitly compute a clustering of the representations, we surprisingly found that this is also the case for some methods (e.g., VICReg) that do not explicitly involve centroid parameters or a clustering step.

**Implicit K-means: VICReg and SimCLR.**  The VICReg loss [@bardes2021vicreg] consists of three terms: one measuring the $\ell_2$ distance between the positive view pairs, one encouraging the off-diagonal entries of the embedding covariance matrix to go to $0$, and one encouraging its diagonal entries to be greater than $1$. One standard simplification made when studying VICReg is to merge the variance and covariance terms into a single term, so that the new VICReg loss is given by $$\begin{aligned}
    \mathcal{L}=&\alpha \|\Cov(\mZ) - \mI \|_F^2 +\frac{\gamma}{N} \sum_{i,j=1}^{N}(\mG)_{i,j}\|\vz_{i}-\vz_{j}\|_2^2,\label{eq:VICR2}\end{aligned}$$ where $\Cov(\mZ)$ is the covariance matrix of the vectors $\vz_1,\ldots,\vz_N$, the constants $\alpha, \gamma > 0$ are hyper-parameters, and $(\mG)_{i,j} \in \{0,1\}$ is equal to 1 when the representations $\vz_i$ and $\vz_j$ correspond to positive views of the same image. From this formulation, we relate VICReg to the implicit K-means algorithm in Proposition `\ref{prop:vicreg}`{=latex}, which is proven in Appendix `\ref{proof:kmeans2}`{=latex}. In particular, note how the summation term on the right-hand side of `\eqref{eq:VICR2}`{=latex} recovers the right-hand side of `\eqref{eq2}`{=latex}.
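
A minimal NumPy sketch of this simplified loss follows. It is not the official VICReg implementation. It assumes that $\Cov(\mZ)$ is the sample covariance with the $N-1$ normalization used by `np.cov`, and that $\mG$ is a symmetric binary matrix with a zero diagonal; the function name and the toy batch are hypothetical.

```python
import numpy as np

def simplified_vicreg_loss(z, g, alpha, gamma):
    """Simplified VICReg loss.

    z: (N, D) embeddings; g: (N, N) binary positive-pair matrix.
    """
    n, d = z.shape
    # Sample covariance of the embedding dimensions (assumed N - 1 normalization).
    cov = np.atleast_2d(np.cov(z, rowvar=False))
    cov_term = np.sum((cov - np.eye(d)) ** 2)  # squared Frobenius norm
    # Squared Euclidean distance between every pair of embeddings, shape (N, N).
    sq_dists = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    invariance_term = np.sum(g * sq_dists) / n
    return alpha * cov_term + gamma * invariance_term

# Toy batch (hypothetical): 4 embeddings forming positive pairs (0, 1) and (2, 3).
rng = np.random.default_rng(0)
z = rng.normal(size=(4, 2))
g = np.zeros((4, 4))
g[0, 1] = g[1, 0] = g[2, 3] = g[3, 2] = 1.0
print(simplified_vicreg_loss(z, g, alpha=1.0, gamma=25.0))
```

The invariance term only sums distances between embeddings that $\mG$ marks as positives, so it plays the role of the within-group pairwise distances in the implicit K-means form.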

```{=latex}
\begin{proposition}
\label{prop:vicreg}
    VICReg with hyper-parameters $\gamma \gg \alpha$ recovers the K-means loss from~\eqref{eq:kmeans} on the embeddings, with an additional regularizer enforcing orthogonality of the centroids.
\end{proposition}
```
The hyper-parameter requirement $\gamma \gg \alpha$ is commonly employed with VICReg in practice to ensure invariance to the pre-defined set of data-augmentations used to construct the positive image views. In fact, in the opposite case, where $\alpha \gg \gamma$, a degenerate whitening can be learned without keeping any information about the input samples due to the non-linearity of the deep neural network mapping, as shown in @balestriero2022contrastive. Note that under mild conditions, the VICReg loss has been shown to be equivalent to the SimCLR loss [@garrido2022duality]. In such scenarios, Proposition `\ref{prop:vicreg}`{=latex} would also directly apply to the SimCLR method.

**Explicit K-means with soft constraints: MSN.**   A common relaxation of constrained K-means is to remove the hard cluster assignment constraint, such that the condition $\mP_{i,j} \in \{0,1\}$ is generalized to $\mP_{i,j} \in [0,1]$ [@wang2010learning]. This relaxes constrained K-means to the more general Gaussian Mixture Model (GMM) formulation, where each data point can partially belong to all clusters with some probability mass. Since the derivation of the GMM loss (i.e., the ELBO and log-likelihood) is beyond the scope of this section, we simply state the objective as $$\min_{\{\mu_k,\Sigma_k\}_{k=1}^K}\sum_{\vx \in \sX}\sum_{k=1}^{K}
     \frac{[p(\vx)]_k}{2}\lVert\vx-\mu_k\rVert^2_{\Sigma_k^{-1}}+N\sum_{k=1}^{K}\log\det(\Sigma_k)+
     \sum_{\vx \in \sX}D_{\text{KL}}(p(\vx)\ \lVert\ \vpi),\label{eq:GMM}$$ where $\vpi$ is the cluster prior and $[p(\vx)]_k$ is the posterior probability of input $\vx$ belonging to cluster $k$, obtained from $$p(\vx)=\text{softmax}\left(\mW^T \vx+\log (\vpi)-\frac{1}{2}\lVert \vx \rVert_2^2 - \frac{1}{2}\text{diag}(\mW^T\mW)\right),$$ where $\mW=[\mu_1,\dots,\mu_K]$ concatenates the centroids, and $\text{diag}(\cdot)$ extracts the diagonal of its matrix argument into a column vector.
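The posterior above can be sketched in a few lines of pure Python. This is an illustrative sketch under the stated formula only; the function names are hypothetical, and the centroids are passed as a list of vectors rather than as the matrix $\mW$. Note that each logit equals $\log [\vpi]_k - \frac{1}{2}\lVert \vx - \mu_k \rVert_2^2$, so the term $-\frac{1}{2}\lVert \vx \rVert_2^2$ is shared by all clusters and cancels in the softmax.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gmm_posterior(x, centroids, prior):
    """softmax(W^T x + log(pi) - 0.5 * ||x||^2 - 0.5 * diag(W^T W)) for one input x."""
    x_sq = sum(v * v for v in x)
    logits = []
    for mu, pi_k in zip(centroids, prior):
        dot = sum(a * b for a, b in zip(mu, x))   # entry k of W^T x
        mu_sq = sum(a * a for a in mu)            # entry k of diag(W^T W)
        logits.append(dot + math.log(pi_k) - 0.5 * x_sq - 0.5 * mu_sq)
    return softmax(logits)
```

When all centroids have the same norm and the prior is uniform, every term except $\mW^T \vx$ is constant across clusters, so the posterior reduces to $\text{softmax}(\mW^T\vx)$.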

Of particular interest to us is the case in which the prior is set to the uniform distribution, $[\vpi]_k = \nicefrac{1}{K}$ for all $k \in [K]$, the covariance matrix is isotropic, $\Sigma_k=\sigma \mI$ for $\sigma \geq 0$, and the centroids and data vectors are $\ell_2$-normalized, in which case `\eqref{eq:GMM}`{=latex} simplifies to $$\min_{\mu_1,\ldots,\mu_K}\sum_{\vx \in \sX}\sum_{k=1}^{K}
     \frac{[\text{softmax}(\mW^T\vx)]_k}{2}\|\vx-\mu_k\|^2_{2}-\sum_{\vx\in \sX} H\left(\text{softmax}(\mW^T\vx)\right).
     \label{eq:GMM_uni}$$ One can also express the simplified GMM objective in `\eqref{eq:GMM_uni}`{=latex} in terms of the marginal entropy (with a summation inside the entropy term) by following the ELBO derivations of @hoffman2016elbo. The Masked Siamese Network (MSN) loss [@assran2022masked] with positive pairs $\vx_n,\vx_n^+$ and posteriors $\vp_n=\text{softmax}(\mW^T\vx_n/\sigma)$, $\vp_n^+=\text{softmax}(\mW^T\vx_n^+/\sigma)$, with temperature $\sigma$, is $$\label{eq:msn}
    \frac{1}{N} \sum^N_{n=1} H(\vp_{n}^+, \vp_n) - \lambda H({\overline{\vp}}),$$ where $\lambda > 0$, and $\overline{\vp} \defeq \frac{1}{N} \sum^N_{n=1} \vp_n$. We show that the MSN objective can be seen as a K-means variant employing an explicit cluster-membership penalty (cf. Appendix `\ref{proof:kmeans3}`{=latex}).
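The two terms of `\eqref{eq:msn}`{=latex} can be illustrated with a small pure-Python sketch. It is not the released MSN implementation: the function names are hypothetical, and the posteriors are assumed to be given as lists of strictly positive probabilities (as produced by a softmax).

```python
import math


def cross_entropy(target, pred):
    """H(target, pred) = -sum_k target_k * log(pred_k)."""
    return -sum(t * math.log(q) for t, q in zip(target, pred) if t > 0)


def entropy(p):
    """Shannon entropy H(p) = H(p, p)."""
    return cross_entropy(p, p)


def msn_loss(preds, targets, lam):
    """(1/N) * sum_n H(p_n^+, p_n) - lam * H(mean_n p_n)."""
    n, k = len(preds), len(preds[0])
    # Mean prediction over the mini-batch, one entry per cluster.
    mean_pred = [sum(p[j] for p in preds) / n for j in range(k)]
    consistency = sum(cross_entropy(t, p) for t, p in zip(targets, preds)) / n
    return consistency - lam * entropy(mean_pred)
```

For a fixed consistency term, the loss is lowest when the mean prediction is uniform over the clusters, which is where the uniform cluster prior enters.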

```{=latex}
\begin{proposition}
\label{prop:msn}
    MSN recovers the GMM loss from~\eqref{eq:GMM_uni} with the variation that the $\ell_2$ distance is replaced with the cross-entropy distance of the posterior.
\end{proposition}
```
**Explicit K-means with hard constraints: SwAV.**  Among popular K-means variants, one particular version requires specification of the cardinality of the clusters [@kleindessner2019fair; @bradley2000constrained; @rujeerapaiboon2019size]. Specifically, the number of items in each cluster $\lvert\sX_k\rvert$ is strictly enforced to take a specific value $N_k$. The resulting K-means formulation becomes $$\min_{\{\sX_k\}_{k=1}^K \ \text{s.t.}\ |\sX_k|=N_k}
    \sum_{k=1}^K \sum_{\vx \in \sX_k} \|\vx-\mu_k\|^2_2.\label{eq:constrained}$$ When $N_k=\frac{N}{K}$, we obtain strict enforcement of the uniform cluster prior, which is otherwise implicit in the formulation, but not strictly enforced. This variant of K-means was previously used in @wang2010learning. The SwAV [@caron2020unsupervised] loss with positive posterior pairs $\vp_n,\vp_n^+$ (defined similarly to those in `\eqref{eq:msn}`{=latex}) is $$\label{eq:swav}
    \frac{1}{N} \sum^N_{n=1} H(\vp_{n}^+, \vp_n)
    \quad \text{subject to}
    \quad \mP \mathbf{1}_N = \frac{N}{K} \mathbf{1}_K;
    \quad \mP^\top \mathbf{1}_K = \mathbf{1}_N,$$ where $\mP = [\vp^+_1, \ldots, \vp^+_N]$ concatenates the predictions. In @caron2020unsupervised, these constraints are enforced in practice by projecting the matrix $\mP$ onto the constraint set in each iteration using the Sinkhorn-Knopp algorithm [@cuturi2013sinkhorn]. SwAV thus recovers a strictly constrained variant of K-means (cf. Appendix `\ref{proof:swav}`{=latex}).
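The projection onto the constraint set can be sketched as alternating row and column rescaling. The following pure-Python sketch is a generic Sinkhorn-Knopp iteration on an arbitrary strictly positive $K \times N$ matrix; the function name, the default iteration count, and how the input matrix is obtained from the network outputs are assumptions that are not taken from the SwAV code.

```python
def sinkhorn_knopp(scores, num_iters=50):
    """Rescale a strictly positive K x N matrix so that each row sums to N / K
    and each column sums to 1 (approximately, after enough iterations)."""
    k, n = len(scores), len(scores[0])
    P = [row[:] for row in scores]
    for _ in range(num_iters):
        # Row step: every cluster (row) receives the same total mass N / K.
        for r in range(k):
            row_sum = sum(P[r])
            P[r] = [v * (n / k) / row_sum for v in P[r]]
        # Column step: every sample (column) becomes a distribution over clusters.
        for c in range(n):
            col_sum = sum(P[r][c] for r in range(k))
            for r in range(k):
                P[r][c] /= col_sum
    return P
```

Because the last step normalizes the columns, the column constraint holds up to rounding error, while the row constraint is met only up to the convergence tolerance.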

```{=latex}
\begin{proposition}
\label{prop:swav}
SwAV recovers the constrained K-means loss from~\eqref{eq:constrained} with constraint enforcement through Sinkhorn-Knopp normalization.
\end{proposition}
```
Note that for all presented SSL methods and their corresponding K-means variants, enforcing normalization of the centroids and features does not reduce the generality of the results; in this case, we simply obtain a correspondence to similar variants of spherical K-means [@hornik2012spherical].

Negative effect of uniform priors for class-imbalanced data {#sec:class_stratified}
===========================================================

As proven in Section `\ref{sec:theory}`{=latex}, joint-embedding methods employing volume maximization regularizers have a uniform feature prior. Following this observation, in this section we empirically demonstrate that such methods are sensitive to non-uniform mini-batch class distributions.

#### Experimental setup.

We explore three joint-embedding methods employing diverse collapse prevention strategies: SimCLR [@chen2020simple], VICReg [@bardes2021vicreg], and MSN [@assran2022masked]. We also compare the performance of those models to instance-based methods such as MAE [@he2021masked] and data2vec [@baevski2022data2vec], which do not employ volume maximization regularizers. In this evaluation, all models are pretrained on the ImageNet-1K dataset without access to the class labels. To explore sensitivity to mini-batch class distributions, we explore two dichotomous sampling strategies (which do require semantic knowledge of the image classes).

One strategy, termed *class-balanced sampling*, constructs the mini-batches in each iteration by first randomly selecting 960 classes out of the 1000 ImageNet-1K classes, and then sampling an equal number of images from each class.[^2] Given that ImageNet is relatively class-balanced, this strategy produces mini-batches with similar statistics to traditional uniform sampling. Another strategy, termed *class-imbalanced sampling*, constructs the mini-batches in each iteration by first randomly selecting $K \ll 960$ classes out of the 1000 ImageNet-1K classes, and then sampling an equal number of images from each class, such that the total mini-batch size is the same as under the class-balanced sampling strategy. We pick $K$ as the smallest value needed to reach the default batch size for each respective method: $8$ for SimCLR (which has a default batch size of $4096$) and $2$ for VICReg and MSN (which have default batch sizes of $1024$). Note that the class-imbalanced strategy maintains the same marginal probability of sampling individual data points as compared to the class-balanced sampling strategy (see Appendix `\ref{app:marginalprob}`{=latex} for a derivation of this equivalence).
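Both strategies follow the same two-step recipe and differ only in the number of classes drawn per mini-batch. The sketch below illustrates that recipe in pure Python; the function name, the dictionary-based dataset layout, and sampling without replacement within a class are assumptions made for illustration and are not taken from our data-loading code.

```python
import random


def sample_class_stratified_batch(indices_by_class, num_classes, batch_size, rng):
    """Pick `num_classes` classes at random, then draw an equal number of
    image indices from each, for a total of `batch_size` indices.

    `indices_by_class` maps a class id to the list of image indices in that class.
    """
    if batch_size % num_classes != 0:
        raise ValueError("batch_size must be divisible by num_classes")
    per_class = batch_size // num_classes
    # Step 1: choose which classes appear in this mini-batch.
    chosen = rng.sample(sorted(indices_by_class), num_classes)
    # Step 2: draw the same number of images from each chosen class.
    batch = []
    for c in chosen:
        batch.extend(rng.sample(indices_by_class[c], per_class))
    return batch
```

Under this sketch, a large `num_classes` corresponds to class-balanced sampling and a small `num_classes` to class-imbalanced sampling, with `batch_size` held fixed.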

After pretraining all models using the various sampling strategies, we evaluate performance on a wide range of downstream tasks requiring different levels of abstraction, i.e., classification with CIFAR100 [@krizhevsky2009learning], Places205 [@zhou2014learning], and iNat18 [@van2018inaturalist]; object counting with Clevr/Count [@johnson2017clevr]; and depth prediction with Clevr/Dist [@johnson2017clevr] and KITTI [@geiger2013vision]. We also evaluate in-distribution performance of ImageNet classification [@russakovsky2015imagenet; @chen2020simple]. Additional pretraining and evaluation details can be found in Appendix `\ref{apndx:classstratified_sampling}`{=latex}; the full set of results can be found in Appendix `\ref{apndx:classstratified_results}`{=latex}.

```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:class_stratified_sampling}
                       ImageNet                       iNat18                      CIFAR100                    Places205                  Clevr/Count          Clevr/Dist   KITTI
  ---------- ----------------------------- ---------------------------- ---------------------------- --------------------------- --------------------------- ------------ -------
      SimCLR  `\color{red}`{=latex} -11.2   `\color{red}`{=latex}-8.0    `\color{red}`{=latex}-10.2   `\color{red}`{=latex}-5.2   `\color{red}`{=latex}-4.3      +0.9      +1.2
         MSN  `\color{red}`{=latex}-17.7    `\color{red}`{=latex}-15.1   `\color{red}`{=latex}-13.2   `\color{red}`{=latex}-4.6   `\color{red}`{=latex}-6.4      +1.9      -1.6
      VICReg  `\color{red}`{=latex}-17.7    `\color{red}`{=latex}-17.3   `\color{red}`{=latex}-12.0   `\color{red}`{=latex}-6.0   `\color{red}`{=latex}-3.0      +0.7      -1.1
    data2vec             -0.8                          +0.3                         -1.6                        +0.0                        -2.1                 -1.5      -0.1
         MAE             -0.1                          +1.4                         +2.5                        +0.1                        -0.8                 +0.0      +0.0

  : **Transfer:** Evaluation of the pretraining mini-batch sampling distribution on various downstream tasks. Each cell reports the task performance when pretraining with class-imbalanced sampling minus the task performance when pretraining with class-balanced sampling, so negative values indicate a drop under class-imbalanced sampling. Sampling imbalanced mini-batches during pretraining leads to a significant drop in image classification tasks for joint-embedding methods with volume maximization priors (SimCLR, MSN, VICReg), whereas instance-based methods, which do not employ such regularization (MAE, data2vec), are relatively unaffected.
:::

#### Empirical observations.

As can be seen in Table `\ref{tb:class_stratified_sampling}`{=latex}, the performance of joint-embedding methods employing volume maximization regularizers degrades significantly on all the semantic downstream tasks (IN1K, iNat18, CIFAR100, Places205, Clevr/Count) when the mini-batches sampled during pretraining are not class-balanced (e.g., drops by as much as 17.7 top-1 on IN1K), but remains relatively stable (and in some cases marginally improves) on the low-level depth prediction tasks Clevr/Dist and KITTI, suggesting that class-imbalanced pretraining leads the model to capture lower-level (less semantic) features.

By contrast, evaluations with the instance-based methods data2vec and MAE in Table `\ref{tb:class_stratified_sampling}`{=latex} show different trends. The MAE method employs a simple pixel-reconstruction loss for representation learning, and thus does not explicitly compute mini-batch statistics during pretraining. The data2vec method is more similar to MSN, SimCLR, and VICReg in that it utilizes a joint-embedding architecture; however, in contrast to those methods, data2vec does not explicitly compute mini-batch statistics during pretraining, and instead relies on architectural heuristics and careful hyperparameter choices to prevent collapse. When evaluating these methods with class-balanced pretraining versus class-imbalanced pretraining, we observe comparatively little change in downstream task performance (at most 2.5 points in absolute value). Only methods with explicit volume maximization terms exhibit strong sensitivity to the mini-batch class distribution.

#### Visualizing learned prototype vectors.

In Figure `\ref{fig:visu_rcdm_sampling}`{=latex}, we use RCDM [@bordes2022high] to visualize the prototypes learned by an MSN model pretrained on IN1K with either class-balanced or class-imbalanced mini-batch distributions. A prototype here refers to a row in the weight matrix of the final linear layer in the encoder. Each row in Figure `\ref{fig:visu_rcdm_sampling}`{=latex} corresponds to samples generated by conditioning on a *single prototype* using various random seeds. Characteristics that remain constant across a row in Figure `\ref{fig:visu_rcdm_sampling}`{=latex} reflect information contained in the prototype, whereas characteristics that vary reflect information that is not contained (i.e., to which the representations are invariant). When pretraining with class-balanced mini-batches, the emergent features tend to be associated with high-level concepts, such as specific ImageNet classes (Figure `\ref{fig:visu_rcdm_balanced}`{=latex}). In contrast, when pretraining with class-imbalanced mini-batches, the learned features tend to be associated with low-level concepts, such as shape, pose, or texture (Figure `\ref{fig:visu_rcdm_unbalanced}`{=latex}).

```{=latex}
\centering
```
![Pretraining with class-balanced mini-batch sampling](assets/MSN_fig_balanced.png){#fig:visu_rcdm_balanced width="\\linewidth"}

```{=latex}
\hfill
```
```{=latex}
\centering
```
![Pretraining with class-imbalanced mini-batch sampling](assets/MSN_fig_unbalanced.png){#fig:visu_rcdm_unbalanced width="\\linewidth"}

Prior Matching for Siamese Networks {#sec:priors}
===================================

Section `\ref{sec:class_stratified}`{=latex} validates that in settings where the samples within a mini-batch do not follow a uniform class distribution, joint-embedding methods with volume maximization penalties encode less class-oriented features, and perform worse on downstream semantic classification tasks. In this section, we demonstrate that using alternative (long-tailed) feature priors can lead to representations of a higher semantic level when pretraining on real-world class-imbalanced data.

#### Siamese Networks with Arbitrary Prior.

In @assran2022masked, the MSN prior is explicitly set to the uniform distribution. As discussed in Section `\ref{sec:theory}`{=latex} (see `\eqref{eq:GMM}`{=latex}), the KL penalty in MSN explicitly encourages learning representations that enable uniform clustering of the data. However, when pretraining with class-imbalanced data, the semantic concepts of interest no longer satisfy the assumptions of the uniformity prior. In particular, natural observations "in the wild" tend to follow long-tailed (often power-law) distributions [@newman2005power; @mahajan2018exploring; @van2018inaturalist]. Based on this observation, we introduce Prior Matching for Siamese Networks (PMSN), which extends MSN to support the use of arbitrary feature priors. Specifically, we modify MSN by replacing the negative entropy term in `\eqref{eq:msn}`{=latex} with the KL-divergence to a user-specified distribution.[^3] For instance, we can instantiate PMSN as $$\label{eq:msn_prior}
     \frac{1}{N} \sum^N_{i=1} H(\vp_{i}^+, \vp_i) + \lambda D_{\text{KL}}( \overline{\vp}\ \lVert\ \vp_{\text{PL}(\tau)}),$$ where $\vp_{\text{PL}(\tau)}$ is a power-law distribution with exponent $\tau > 0$.
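The prior-matching term can be illustrated with a short pure-Python sketch. The function names are hypothetical, the prior is assumed to assign mass proportional to $k^{-\tau}$ to the $k^{\text{th}}$ cluster, and which cluster index receives the largest mass is arbitrary. This is a sketch of the regularizer only, not of the full training procedure.

```python
import math


def power_law_prior(num_clusters, tau):
    """Distribution with [p]_k proportional to k ** (-tau), k = 1..num_clusters."""
    weights = [k ** -tau for k in range(1, num_clusters + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def kl_divergence(p, q):
    """D_KL(p || q) = sum_k p_k * log(p_k / q_k); assumes q_k > 0 everywhere."""
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def pmsn_regularizer(preds, prior):
    """KL divergence from the mini-batch mean prediction to the chosen prior."""
    n, k = len(preds), len(preds[0])
    mean_pred = [sum(p[j] for p in preds) / n for j in range(k)]
    return kl_divergence(mean_pred, prior)
```

Setting `tau = 0` yields the uniform prior, for which the regularizer equals $\log K - H(\overline{\vp})$; up to the constant $\log K$, this is the negative entropy term of the original MSN loss.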

Toy setting
-----------

We claim that the uniform prior in joint-embedding methods significantly impacts the features captured in their learned representations. To illustrate this point on well-known data, we construct a toy setting to study how changes in the prior affect the suppression and selection of features by the encoder. We take grayscaled CIFAR10 [@krizhevsky2009learning] images and overlay one of ten MNIST digits [@lecun-mnisthandwrittendigit-2010] in the top left corner, such that the overall distribution of MNIST digits in the dataset follows a power-law distribution with exponent $0.5$.[^4] Next, we perform self-supervised pretraining on this dataset using PMSN. We compare pretraining using a uniform feature prior to pretraining using a power-law prior with exponent equal to $0.5$, which corresponds to the true power-law distribution of MNIST digits in the toy dataset.
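The construction of one toy example can be sketched as follows. This is an illustrative pure-Python sketch, not the script used to build the dataset: images are plain lists of pixel rows, the digit is pasted over the corner pixels (blending is not modeled), and the assignment of digits to power-law ranks (digit $d$ gets rank $d+1$) is an assumption.

```python
import random


def digit_distribution(tau=0.5, num_digits=10):
    """Power-law distribution over digits; digit d is assumed to have rank d + 1."""
    weights = [(rank + 1) ** -tau for rank in range(num_digits)]
    total = sum(weights)
    return [w / total for w in weights]


def overlay_top_left(image, patch):
    """Return a copy of `image` (list of pixel rows) with `patch` pasted at (0, 0)."""
    out = [row[:] for row in image]
    for r, patch_row in enumerate(patch):
        out[r][:len(patch_row)] = patch_row
    return out


def make_toy_example(image, digit_patches, probs, rng):
    """Sample a digit from `probs` and overlay its patch on a grayscale image."""
    digit = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return overlay_top_left(image, digit_patches[digit]), digit
```

Here `digit_patches` is a hypothetical lookup from a digit to one of its MNIST images, already resized to the desired corner size.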

The first column in Figure `\ref{fig:cifar10-nn}`{=latex} shows reference images from this toy dataset; the images in subsequent columns visualize the corresponding nearest neighbours in the embedding space of the pretrained models. When pretraining using a power-law prior, the MNIST digit is encoded by the model, and the nearest neighbours all have the same digit class (Figure `\ref{fig:cifar10-nn-powerlaw}`{=latex}). However, when pretraining with a uniform prior, MNIST digit information is discarded by the model, and therefore the nearest neighbours have different digit classes (Figure `\ref{fig:cifar10-nn-uniform}`{=latex}). In particular, since the MNIST digit "feature" does not admit a uniform clustering of the data, it is discarded when pretraining with the conventional volume maximization penalty. This toy experiment provides insight into how semantic features in class-imbalanced datasets can be suppressed by the encoder and how alternative (non-uniform) priors can be used to recover such features.
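The nearest-neighbour comparison can be made quantitative with a small retrieval sketch. The code below is illustrative only: the function names are hypothetical, and the use of cosine similarity as the retrieval metric is an assumption, since the text does not specify the metric.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nearest_neighbours(query, embeddings, k):
    """Indices of the k embeddings most similar to `query`, most similar first."""
    ranked = sorted(range(len(embeddings)),
                    key=lambda i: cosine_similarity(query, embeddings[i]),
                    reverse=True)
    return ranked[:k]


def digit_agreement(query_digit, neighbour_ids, digits):
    """Fraction of retrieved neighbours whose digit label matches the query's."""
    hits = sum(1 for i in neighbour_ids if digits[i] == query_digit)
    return hits / len(neighbour_ids)
```

An agreement close to 1 would indicate that the digit is encoded in the representation, whereas an agreement near the base rate of the digit would indicate that it has been discarded.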

```{=latex}
\centering
```
![Ref.\
 ](assets/cifar10-ref.png){width="\\textwidth"}

```{=latex}
\hfill
```
```{=latex}
\centering
```
![Nearest Neighbours when pretraining with power-law prior\
**(Matching MNIST digit distribution)**](assets/cifar10-powerlaw.png){#fig:cifar10-nn-powerlaw width="\\textwidth"}

```{=latex}
\hfill
```
```{=latex}
\centering
```
![Nearest Neighbours when pretraining with uniform prior\
**(Not matching MNIST digit distribution)**](assets/cifar10-uniform.png){#fig:cifar10-nn-uniform width="\\textwidth"}

Natural class-imbalanced setting
--------------------------------

In this section, we examine the downstream task performance obtained by pretraining PMSN in more realistic settings. We examine pretraining with both IN1K, which is relatively class-balanced, and the iNaturalist18 (iNat18) dataset [@van2018inaturalist], which is relatively class-imbalanced. The iNat18 dataset contains approximately 430K images from 8142 different species of plants and animals; since some species are more abundant and easier to photograph than others, the class distribution of images in this dataset naturally follows a long-tailed distribution. After pretraining with either iNat18 or IN1K, we examine downstream task performance on the same set of tasks used in Section `\ref{sec:class_stratified}`{=latex}. Additional details about the experimental setup can be found in Appendix `\ref{apndx:priors}`{=latex}.

#### Results with true class distribution.

Table `\ref{tb:inat_class_prior}`{=latex} compares iNat18 pretraining with PMSN using a uniform distribution or the *true* class distribution of iNat18. In particular, if we have $D$ images in the dataset and $D_k$ images in the $k^{\text{th}}$ class category, then we define the class-prior, $\vp_{\text{class}}$, such that $[\vp_{\text{class}}]_k = \nicefrac{D_k}{D}$. Of course this information is not usually assumed to be known in self-supervised pretraining; we consider it here to illustrate the effect of the uniform prior on representation learning. As shown in Table `\ref{tb:inat_class_prior}`{=latex}, simply replacing the uniform prior with the long-tailed class-prior improves performance on *all* downstream tasks: object classification with CIFAR100 (+0.9%), iNat18 (+3.0%), Places205 (+1.4%), object counting with Clevr/Count (+2.1%), and depth prediction with Clevr/Dist (+2.3%) and KITTI (+2.7%).
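The class-prior defined above is simply the vector of empirical class frequencies. As a minimal sketch (the function name and the list-of-labels input are illustrative assumptions):

```python
from collections import Counter


def class_prior(labels):
    """Empirical class frequencies D_k / D, keyed by class label."""
    counts = Counter(labels)
    total = len(labels)
    return {c: counts[c] / total for c in sorted(counts)}
```

How the entries of this vector are matched to the cluster indices of the model is not addressed by this sketch.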

These results clearly support the intuition that moving away from the uniformity prior when pretraining with class-imbalanced data can improve the quality of the representations that are learned. However, note that we have used the true class distribution in these experiments. While this distribution may be estimated using weak supervision sources often available for internet-scale data (e.g., image captions or hashtags), we would prefer general methods that do not require precise a priori knowledge of the class distribution.

```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:inat_class_prior}
                                                                                  Prior            iNat18                       CIFAR100                      Places205                    Clevr/Count                   Clevr/Dist                       KITTI
  ------------------------------------------------------------------------------------- ----------------------------- ----------------------------- ----------------------------- ----------------------------- ----------------------------- -----------------------------
    `\multicolumn{7}{c}{\scriptsize\bf Pretrained on iNaturalist18 (ViT-S/16)}`{=latex}                                                                                                                                                       
                                                                                uniform             29.1                          59.4                           36.9                         69.4                          56.8                          68.2
                                                                            class-prior           **32.1**                      **60.3**                      **38.3**                      **71.5**                      **59.1**                      **70.9**
                                                                               $\Delta$  `\color{blue}`{=latex} +3.0   `\color{blue}`{=latex} +0.9   `\color{blue}`{=latex} +1.4   `\color{blue}`{=latex} +2.1   `\color{blue}`{=latex} +2.3   `\color{blue}`{=latex} +2.7

  : **True Class Prior:** Replacing the uniform prior with the class prior computed from the true class distribution leads to gains on all downstream tasks.
:::

#### Results with general power-law distributions.

Table `\ref{tb:inat_powerlaw_priors}`{=latex} evaluates models pretrained using PMSN on the ImageNet-1K and iNat18 datasets when we do not use precise knowledge of the class distribution. In particular, we explore a power-law prior, $\vp_{\text{PL}(\tau)}$, such that $[\vp_{\text{PL}(\tau)}]_k \propto \nicefrac{1}{k^{\tau}}$ with power-law exponent $\tau = 0.25$. As expected, use of a power-law prior improves downstream task performance when pretraining on the class-imbalanced iNat18 dataset on every task except KITTI, where it is unchanged (top half of Table `\ref{tb:inat_powerlaw_priors}`{=latex}), but degrades performance on every task except Clevr/Dist when pretraining on the class-balanced ImageNet-1K dataset (bottom half of Table `\ref{tb:inat_powerlaw_priors}`{=latex}). These results indicate that it is preferable to match the prior distribution in self-supervised algorithms to the distribution of semantic concepts in the pretraining dataset. In the case of ImageNet pretraining, the uniform prior more closely matches the distribution of classes in the dataset, and thus we expect it to achieve strong downstream task performance in that setting.
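To give a sense of how far such a prior departs from the uniform distribution, one can normalize the definition above and compare the largest and smallest cluster masses. The number of clusters $K = 1024$ used below is a hypothetical value chosen only to make the arithmetic concrete; it is not a setting reported in this section. $$\begin{aligned}
    [\vp_{\text{PL}(\tau)}]_k &= \frac{k^{-\tau}}{\sum_{j=1}^{K} j^{-\tau}}, \qquad k = 1, \ldots, K, \\
    \frac{[\vp_{\text{PL}(\tau)}]_1}{[\vp_{\text{PL}(\tau)}]_K} &= K^{\tau}, \qquad \text{e.g., } 1024^{0.25} = 2^{2.5} \approx 5.66 .\end{aligned}$$ Under this assumption, the most populated cluster is expected to hold roughly $5.66$ times the mass of the least populated one, whereas $\tau = 0$ gives a ratio of $1$ and recovers the uniform prior.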

While these results support the observation that one can still improve SSL pretraining on class-imbalanced datasets without having a priori knowledge of the class distribution, comparing to Table `\ref{tb:inat_class_prior}`{=latex} shows that it would be preferable to use the true class distribution when such information is available. In short, we consider the results of this section as a demonstration of the effect of non-uniform feature priors and do not claim to have completely solved the issue of class-imbalanced pretraining.

```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:inat_powerlaw_priors}
                                                                                  Prior            iNat18                       CIFAR100                      Places205                  Clevr/Count                   Clevr/Dist                     KITTI
  ------------------------------------------------------------------------------------- ----------------------------- ----------------------------- ---------------------------- ---------------------------- ---------------------------- ---------------------------
    `\multicolumn{7}{c}{\scriptsize\bf Pretrained on iNaturalist18 (ViT-S/16)}`{=latex}                                                                                                                                                    
                                                                                uniform             29.1                          59.4                          36.9                         69.4                         56.8                        68.2
                                                                              power-law             30.1                          60.1                          37.7                         71.1                         58.9                        68.2
                                                                               $\Delta$  `\color{blue}`{=latex} +1.0   `\color{blue}`{=latex} +0.7   `\color{blue}`{=latex}+0.8   `\color{blue}`{=latex}+1.7   `\color{blue}`{=latex}+2.1             +0.0
         `\multicolumn{7}{c}{\scriptsize\bf Pretrained on ImageNet (ViT-B/16)}`{=latex}                                                                                                                                                    
                                                                                uniform             41.9                          81.7                          54.0                         74.3                         57.8                        73.3
                                                                              power-law             22.0                          63.8                          38.3                         71.4                         65.8                        66.6
                                                                               $\Delta$  `\color{red}`{=latex}-19.9    `\color{red}`{=latex}-17.8    `\color{red}`{=latex}-15.7   `\color{red}`{=latex}-2.9    `\color{blue}`{=latex}+8.0   `\color{red}`{=latex}-6.7

  : **Power-Law Prior:** PMSN with a power-law prior achieves better downstream performance than PMSN with a uniform prior when the pretraining dataset has a long-tailed class distribution. The power-law prior hurts performance on most tasks when the pretraining dataset is class-balanced. The feature prior should therefore match the class distribution.
:::

#### Visualizing learned prototype vectors.

In Figure `\ref{fig:visu_rcdm_inat}`{=latex}, we use RCDM to visualize the prototypes of a PMSN model pretrained with either power-law or uniform priors on the iNat18 dataset. The features that emerge when pretraining with a power-law prior are more closely associated with high-level concepts such as specific image classes. For example, one can recognize specific types of birds and plants in Figure `\ref{fig:visu_rcdm_inat_powerlaw}`{=latex}, which is not the case with the samples generated using the uniform prior prototypes in Figure `\ref{fig:visu_rcdm_inat_uniforme}`{=latex}. These qualitative results further highlight the effect of the overlooked uniform prior in self-supervised learning with class-imbalanced data.

```{=latex}
\centering
```
![ Pretraining on iNat18 with power-law prior\
**(Prior matching iNat18 class distribution)** ](assets/MSN2_prior_powerlaw.png){#fig:visu_rcdm_inat_powerlaw width="\\textwidth"}

```{=latex}
\centering
```
![ Pretraining on iNat18 with uniform prior.\
**(Prior not aligned with iNat18 class distribution)** ](assets/MSN2_prior_uniform.png){#fig:visu_rcdm_inat_uniforme width="\\textwidth"}

Conclusion
==========

In this work, we show that many common self-supervised learning frameworks have a prior to capture features that enable uniform clustering of the data, and as such, require class-balanced datasets to learn class-discriminative features. By reformulating self-supervised criteria to prefer power-law distributed features, one can improve the quality of the learned representations on real-world class-imbalanced datasets.

```{=latex}
\vfill
```
```{=latex}
\pagebreak
```
Reproducibility statement {#reproducibility-statement .unnumbered}
=========================

To facilitate reproducibility, we provide details on our pretraining and evaluation protocol in Appendices `\ref{apndx:classstratified_sampling}`{=latex} and `\ref{apndx:priors}`{=latex}. When pretraining using existing methods, we leverage publicly available implementations along with the default hyperparameters; see Appendix `\ref{apndx:classstratified_sampling}`{=latex} for details. For evaluation, we use the publicly available VISSL codebase [@goyal2021vissl]; specific evaluation configurations are provided in Appendix `\ref{apndx:classstratified_sampling_eval}`{=latex}. The training details for the PMSN experiments are provided in Appendix `\ref{apndx:priors}`{=latex}. Finally, the proofs for all propositions in Section `\ref{sec:theory}`{=latex} are provided in Appendix `\ref{apndx:proofs}`{=latex}.

```{=latex}
\bibliographystyle{ICLRtemplate/iclr2023_conference}
```
```{=latex}
\vfill
```
```{=latex}
\pagebreak
```
```{=latex}
\appendix
```
Broader related work {#apndx:relatedwork}
====================

Joint-embedding architectures are an active line of research in self-supervised representation learning [@wu2018unsupervised; @he2019moco; @chen2020exploring; @grill2020bootstrap; @chen2020mocov2; @caron2021emerging; @bardes2021vicreg; @zhou2021ibotyes]. These approaches rely on invariance-based pretraining, where a neural network encoder is trained to output similar embeddings for two or more views of the same image. To avoid pathological solutions, joint-embedding approaches use explicit regularization [@chen2020simple; @caron2021emerging; @bardes2021vicreg; @assran2021semi] or architectural constraints [@grill2020bootstrap; @chen2020exploring].

Explicit regularization usually maximizes the volume of space occupied by the representations. Regularization can be implemented using various strategies such as contrasting negative samples [@bromley1993signature; @he2019moco; @chen2020simple], variance-covariance regularization [@bardes2021vicreg; @zbontar2021barlow], or by maximizing the entropy of the representations [@asano2019self; @caron2020unsupervised; @caron2021emerging; @assran2021semi; @assran2022masked]. Alternative collapse-prevention approaches rely on architectural design choices, such as stopping the gradient flow in one of the Siamese Network branches [@chen2020simple], using a momentum encoder to compute the network targets [@grill2020bootstrap], or using an asymmetric prediction head [@grill2020bootstrap; @chen2020simple; @baevski2022data2vec]. Recent theoretical work [@tian2021understanding] explores why certain joint-embedding methods with architectural constraints avoid representation collapse without explicit use of a volume maximization penalty; the implicit collapse prevention mechanisms here are not mutually exclusive. Recent empirical work [@bordes2022guillotine] studies the invariance properties of these pretrained representations. The work of @mitrovic2020representation has drawn connections between invariance and causality when the data-augmentations used during pretraining manipulate specific factors of variation in the data-generating distribution. Other works have studied the computational efficiency of joint-embedding methods, and demonstrated how small amounts of supervision can be used to accelerate convergence [@assran2020supervision].

While joint-embedding architectures are usually leveraged to learn global image representations, some works explore the use of these architectures for learning local and dense representations [@chen2022intra; @gidaris2020learning]. More recently, @lecun2022path proposes an architecture based on joint-embedding approaches to learn a generic world model, capturing both dense local features as well as global image features.

Orthogonal to the contributions of invariance-based pretraining, another line of work attempts to learn representations by artificially masking parts of the input and training a network to reconstruct the hidden content [@vincent2010stacked]. Auto-regressive models, and denoising auto-encoders, in particular, predict clean visual inputs from noisy views [@chen2020generative; @vincent2010stacked; @he2021masked; @bao2021beit; @baevski2022data2vec]. Mask noise is usually used to perturb the images, and those approaches predict the masked inputs either at the pixel level [@dosovitskiy2020image; @he2021masked; @xie2019unsupervised] or at the token level, using a pixel (often patch-level) tokenizer [@bao2021beit; @wei2021masked]. While these works demonstrate impressive scalability, they usually learn features at a low level of semantic abstraction compared to joint-embedding approaches [@assran2022masked].

More recently, a set of approaches attempts to combine joint-embedding and reconstruction-based approaches [@zhou2021ibotyes; @el2021large] by pairing an invariance pretraining loss with a patch-level reconstruction loss.

Relation to the InfoMax principle
=================================

A longstanding conviction in unsupervised representation learning is that the resulting representations should be both maximally informative about the inputs, while also satisfying certain simplicity constraints [@linsker1988self; @goodfellow2016deep]. The former objective is often referred to as the information-maximization principle (InfoMax), while the latter is sometimes referred to as the parsimony principle [@ma2022principles], which is crucial to the problem formulation. Indeed, @bridle1991unsupervised, one of the first works to empirically explore unsupervised representation learning via information-maximization, found that, in the absence of additional constraints, the resulting InfoMax representations may not be very useful. The more recent analysis of @tschannen2019mutual also argues that simplicity constraints are essential to the success of modern representation learning methods built on the InfoMax principle. Historically, simplicity constraints were enforced by encouraging the learned representations to be sparse, low-dimensional, or disentangled, i.e., the individual dimensions of the representation vector should be statistically independent.

Modern state-of-the-art approaches for unsupervised representation learning still frequently employ an information-maximization formulation [@hjelm2018learning; @bachman2019learning; @krause2010discriminative; @hu2017learning; @oord2018representation], but with the simplicity constraints manifested in self-supervised loss terms. One example is the widespread view-invariance penalty [@misra2020self], often coupled with independence [@zbontar2021barlow; @bardes2021vicreg] or low-dimensionality constraints, e.g., by projecting representations on the unit hypersphere [@chen2020simple; @he2019moco; @grill2020bootstrap].

To better understand these methods, recall that the mutual information, $I(\cdot, \cdot)$, between a latent vector $Z$ and data $X$ can be written as $$I(Z,X) = H[Z] - H[Z|X],$$ where $H[Z]$ is the marginal entropy of $Z$, and $H[Z|X]$ is the expected entropy of the posterior distribution. In learning a representation of the data, $Z$, that maximizes the mutual information $I(Z,X)$, we thereby seek to maximize the marginal entropy of $Z$, i.e., search for uniformly distributed feature embeddings.
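As a small numerical illustration of this decomposition, the following sketch computes $H[Z]$, $H[Z|X]$, and $I(Z,X)$ for a discrete joint distribution. The joint tables and function names are hypothetical and chosen only for illustration; the point is that, for a deterministic encoder (where $H[Z|X]=0$), the mutual information reduces to the marginal entropy, which is largest when the codes are uniformly distributed.

```python
import math

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution given as a list."""
    return -sum(q * math.log(q) for q in p if q > 0)

def mutual_information(joint):
    """I(Z, X) = H[Z] - H[Z|X] for a joint table with joint[x][z] = p(x, z)."""
    num_z = len(joint[0])
    p_x = [sum(row) for row in joint]
    p_z = [sum(row[z] for row in joint) for z in range(num_z)]
    # H[Z|X] is the expected entropy of p(z | x) under p(x).
    h_z_given_x = sum(
        px * entropy([pxz / px for pxz in row])
        for px, row in zip(p_x, joint) if px > 0
    )
    return entropy(p_z) - h_z_given_x

# Hypothetical example: 4 equally likely inputs and a deterministic encoder.
# Uniform codes: each input maps to its own code.
uniform_codes = [[0.25, 0, 0, 0], [0, 0.25, 0, 0], [0, 0, 0.25, 0], [0, 0, 0, 0.25]]
# Skewed codes: three of the four inputs share code 0.
skewed_codes = [[0.25, 0, 0, 0], [0.25, 0, 0, 0], [0.25, 0, 0, 0], [0, 0.25, 0, 0]]

mi_uniform = mutual_information(uniform_codes)
mi_skewed = mutual_information(skewed_codes)
```

In this toy case the uniform code assignment attains the maximum value $\log 4$, while the skewed assignment yields a strictly smaller mutual information.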

However, as we show in this work, perhaps the features that we wish our representations to capture are not necessarily those with the highest marginal entropy. It is often the case that the semantic concepts we wish to capture actually follow a long-tailed distribution in the wild. Such desirable features would be penalized under existing information-maximization frameworks. Thus, in the absence of finer grained notions of information, perhaps it is necessary to reconsider the longstanding conviction of seeking representations that maximize information content.

Theoretical guarantees {#apndx:proofs}
======================

Proof of proposition `\ref{prop:kmeans}`{=latex} {#proof:kmeans1}
------------------------------------------------

We will demonstrate that $$\begin{aligned}
    \min_{\mu_1,\dots,\mu_K}\sum_{n=1}^{N}\min_{c=1,\dots,K} \lVert\vx_n-\mu_c \rVert^2_2 =\min_{\mP\in \{0,1\}^{N \times K}:\mP\mathbf{1}_{K}=\mathbf{1}_{N}}\sum_{k=1}^{K}\sum_{n,n^\prime=1}^{N}  \frac{\mP_{n,k}\mP_{n^\prime,k}}{\mathbf{1}_{N}^T\mP_{.,k}}\lVert\vx_{n}-\vx_{n'} \rVert^2_2,\end{aligned}$$ where the left-hand side minimizes over the values of the $K$ centroids and the inner minimization identifies the cluster membership of sample $\vx_n$. To do so, one should first notice that, almost surely, for any input $\vx_n$, the minimum $\min_{c=1,\dots,K} \lVert\vx_n-\mu_c \rVert^2_2$ is attained for a single centroid. Indeed, for any continuous distribution on the data (and/or centroids), the probability that a sample $\vx$ lies exactly equidistant from two (or more) centroids is $0$, since the set of such points has dimension $D-1$ and therefore measure $0$. Hence, we first obtain $$\begin{aligned}
    \min_{\mu_1,\dots,\mu_K}\sum_{n=1}^{N}\min_{c=1,\dots,K} \lVert\vx_n-\mu_c \rVert^2_2 =\min_{\mu_1,\dots,\mu_K}\sum_{n=1}^{N}\min_{\vp \in \{0,1\}^{K}:\vp^T\mathbf{1}_K=1}\sum_{k=1}^{K} \vp_k\lVert\vx_n-\mu_k \rVert^2_2,\end{aligned}$$ and because each sub-problem is independent, we can pull the inner minimizations out of the sum to obtain $$\begin{aligned}
    \min_{\mu_1,\dots,\mu_K}\sum_{n=1}^{N}\min_{c=1,\dots,K} \lVert\vx_n-\mu_c \rVert^2_2 =\min_{\mu_1,\dots,\mu_K}\min_{\mP\in \{0,1\}^{N \times K}:\mP\mathbf{1}_{K}=\mathbf{1}_{N}}\sum_{n=1}^{N}\sum_{k=1}^{K} \mP_{n,k}\lVert\vx_n-\mu_k \rVert^2_2,\end{aligned}$$ and we can now switch the minimization order to obtain $$\begin{aligned}
    \min_{\mP\in \{0,1\}^{N \times K}:\mP\mathbf{1}_{K}=\mathbf{1}_{N}}\min_{\mu_1,\dots,\mu_K}\sum_{n=1}^{N}\sum_{k=1}^{K} \mP_{n,k}\lVert\vx_n-\mu_k \rVert^2_2=\min_{\mP\in \{0,1\}^{N \times K}:\mP\mathbf{1}_{K}=\mathbf{1}_{N}}\sum_{n=1}^{N}\sum_{k=1}^{K} \mP_{n,k}\lVert\vx_n-\mu^*_k \rVert^2_2,\end{aligned}$$ with $\mu^*_k$ the mean of the samples within cluster $k$, i.e., $\frac{1}{(\mP^T\mathbf{1}_N)_k}\sum_{n=1}^{N}\mP_{n,k}\vx_n$. The only thing left to show is that the sum of distances to the cluster means is equivalent to the pairwise distances between all the points within each cluster with an appropriate normalization, which is a standard result; see, e.g., @zha2001spectral.
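The standard result invoked in the last step can be checked numerically. The sketch below, which uses a small hypothetical two-dimensional cluster, compares the sum of squared distances to the cluster mean against the sum of pairwise squared distances divided by the cluster size. When the pairwise sum runs over all ordered pairs $(n, n')$, as it does in this sketch, the two quantities agree up to a constant factor of $2$, which does not affect the minimizing assignment.

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two points given as tuples."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def within_cluster_cost(points):
    """Sum of squared distances from each point to the cluster mean."""
    n = len(points)
    mean = tuple(sum(coord) / n for coord in zip(*points))
    return sum(sq_dist(p, mean) for p in points)

def pairwise_cost(points):
    """Sum over all ordered pairs of squared distances, divided by cluster size."""
    n = len(points)
    return sum(sq_dist(p, q) for p in points for q in points) / n

# Hypothetical cluster used only for illustration.
cluster = [(0.0, 0.0), (2.0, 0.0), (0.0, 2.0), (2.0, 2.0), (1.0, 4.0)]
centroid_form = within_cluster_cost(cluster)
pairwise_form = pairwise_cost(cluster)
```

Here `pairwise_form` equals twice `centroid_form`, so minimizing either expression over cluster assignments selects the same partition.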

Proof of proposition `\ref{prop:vicreg}`{=latex} {#proof:kmeans2}
------------------------------------------------

#### Notation.

Before establishing the theoretical connections between common joint-embedding methods for self-supervised learning and the K-means method, we first define some notation to facilitate the discussion. Given a "stand-alone" dataset, $\mX'\in\mathbb{R}^{N^\prime\times 3HW}$, we construct the pretraining data $\mX \in \R^{N\times 3HW}$ by repeatedly perturbing the elements of the stand-alone dataset. Specifically, $$\mX \triangleq [\text{View}_1(\mX')^T,\dots,\text{View}_V(\mX')^T]^T,$$ where $\text{View}_i(.)$ is a sample-wise image transformation, e.g., random crop, color jitter, or patch masking. We also define the ground-truth similarity matrix, $\mG \in \{0,1\}^{N \times N}$, given by $$\mG_{i,j}=
    \begin{cases}
        (\vx_i \sim \vx_j), \;\; i \neq j\\
        0, \;\;\;\;\;\;\;\;\;\;\;\;\;\;\; i=j
    \end{cases},$$ where the $\sim$ operator returns $1$ if its two arguments are positively related (i.e., correspond to different views of the same sample). We also define the matrix of *embeddings* obtained from a model $f_{\theta}(\cdot)$ as $$\begin{aligned}
    \mZ\triangleq[\vz_1,\dots,\vz_N]^T\in\mathbb{R}^{N \times D}\;\;\text{ with }\;\vz_n\triangleq f_{\theta}(\vx_n).\label{eq:embedding}\end{aligned}$$ Now we are ready to present our first equivalence relations for the VICReg and SimCLR methods. The VICReg loss [@bardes2021vicreg], which is a function of $\mX$ and $\mG$, can be defined as $$\begin{aligned}
\mathcal{L}\hspace{-0.05cm}=&\alpha\hspace{-0.05cm} \sum_{k=1}^{K}\max\hspace{-0.05cm}\left(\hspace{-0.05cm}0,1\hspace{-0.05cm}-\hspace{-0.05cm}\sqrt{\Cov(\mZ)_{k,k}}\right)\hspace{-0.08cm}+\hspace{-0.05cm} \beta  \sum_{k=1}^{K}\sum_{\substack{\ell=1\\\ell \not = k}}^{K} \Cov(\mZ)^2_{k,\ell} +\hspace{-0.05cm}\frac{\gamma}{N} \sum_{i=1}^{N}\sum_{j=1}^{N}(\mG)_{i,j}\|\vz_{i}-\vz_{j}\|_2^2,\end{aligned}$$ but this desired result relies on the simplification employed in `\eqref{eq:VICR2}`{=latex}. To prove our statement, we first remind the reader of a common result that we will heavily rely on: the decomposition of the covariance matrix into within- and between-cluster covariance matrices. Let us assume for simplicity that $\mZ$ is already centered. We can now decompose the covariance into $$\begin{aligned}
    \Cov(\mZ)=&\frac{1}{N}\mZ^T\mH\mZ=\frac{1}{N}\left(\mZ^T\mG\mZ+\mZ^T(\mI-\mG)\mZ\right),\end{aligned}$$ where the two terms are now the *between cluster* and *within cluster* covariances, with $\mH$ the centering matrix defined by $\mI-\mathbf{1}_{N}\mathbf{1}_{N}^T/N$. Let us first consider the LHS of the VICReg loss to simplify it into $$\begin{aligned}
     \left\|\frac{1}{N}\mZ^T\mZ-\mI\right\|_F^2=& \left\|\frac{1}{N}\mZ^T\mG\mZ+\frac{1}{N}\mZ^T(\mI-\mG)\mZ-\mI\right\|_F^2\\
    =& \left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2+ \left\|\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right\|_F^2\\
    &\hspace{4.5cm}+\frac{2}{N} Tr\left(\mZ^T(\mI-\mG)\mZ(\frac{1}{N}\mZ^T\mG\mZ-\mI)\right)\\
    =& \left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2+ \left\|\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right\|_F^2-\frac{2}{N} Tr\left(\mZ^T(\mI-\mG)\mZ\right)\\
    &\hspace{5.5cm}+\frac{2}{N^2} Tr\left(\mZ^T(\mI-\mG)\mZ\mZ^T\mG\mZ\right)\\
    &\hspace{-1.1cm}= \left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2+\sum_i \lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)\left[\lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)-2\right]\\
    &+\frac{2}{N^2} Tr\left(\mZ^T(\mI-\mG)\mZ\mZ^T\mG\mZ\right),\end{aligned}$$ where $\lambda_i(\mM)$ returns the $i^{\rm th}$ eigenvalue of its matrix argument $\mM$; the last equality is obtained by noticing that $\| \mM \|_F^2 = \sum_i\lambda_i(\mM)^2$ and that for a symmetric positive semidefinite matrix, $Tr(\mM)=\sum_i\lambda_i(\mM)$; we thus obtain the following upper and lower bounds $$\begin{aligned}
     \left\|\frac{1}{N}\mZ^T\mZ-\mI\right\|_F^2\leq& \left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2+\sum_i \lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)\\
     &\times \left[\lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)-2+\lambda_i \left(\frac{1}{N}\mZ^T\mG\mZ\right)\right]\\
     \left\|\frac{1}{N}\mZ^T\mZ-\mI\right\|_F^2\geq& \left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2+\sum_i \lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)\\
     &\times \left[\lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)-2+\lambda_{K+1-i} \left(\frac{1}{N}\mZ^T\mG\mZ\right)\right]\\
     \left\|\frac{1}{N}\mZ^T\mZ-\mI\right\|_F^2\geq& \left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2+\sum_i \left[\lambda_i \left(\frac{1}{N}\mZ^T(\mI-\mG)\mZ\right)-1\right]^2-K,\end{aligned}$$ from which it becomes clear that to minimize the variance+covariance terms, one must either maximize the intra-cluster variance, or the inter-cluster variance, or both. However, the intra-cluster variance is exactly the invariance term since it can be expressed as $$\begin{aligned}
    \sum_{i=1}^{N}\sum_{j=1}^{N}(\mG)_{i,j}\|\vz_{i}-\vz_{j}\|_2^2=2 Tr(\mZ^T(\mI-\mG)\mZ),\end{aligned}$$ and thus the only way to minimize the invariance term while also minimizing the variance+covariance terms is to minimize $\left\|\frac{1}{N}\mZ^T\mG\mZ-\mI\right\|_F^2$. We thereby recover that VICReg's loss corresponds to the K-means loss plus a regularizer $\|\mu\mu^T-\mI \|_F^2$, as in $$\begin{aligned}
    \frac{\gamma}{N}\sum_{k=1}^{K}\sum_{\vx \in \sX_k}\|\vx-\mu_k\|_2^2+\alpha\|\mM^T\mM-\mI \|_F^2,\end{aligned}$$ with $\mM\triangleq [\mu_1,\dots,\mu_K]$, where the number of centroids $K$ is given by $\min(\dim(\vz),rank(\mG+\mI))$, the centroids are given by $\mu_k=\frac{1}{\langle \mP_{.,k},\mathbf{1}\rangle}\sum_{n=1}^{N}\mP_{n,k}\vx_n$ with $\mG=\mP^T\mD\mP$, and finally $\sX_k=\{\vx_n \in \sX:\mP_{n,k}>0\}$.

Proof of proposition `\ref{prop:msn}`{=latex} {#proof:kmeans3}
---------------------------------------------

The proof is relatively straightforward and follows the same principle as the proof showing how GMM recovers K-means. The only difference is that we take the zero-noise limit of the MSN loss (`\eqref{eq:msn}`{=latex}) instead of the GMM loss (`\eqref{eq:GMM_uni}`{=latex}), and we will see that we recover a constrained version of K-means with extra constraints on the cluster distribution. The MSN loss is defined for positive pairs $\vx_n,\vx_n^+$ and estimates the corresponding cluster posteriors $\vp_n,\vp_n^+$ via $\text{softmax}(\mM^T\vx_n/\sigma)$ and $\text{softmax}(\mM^T\vx^+_n/\sigma^+)$, respectively, with $\mM\triangleq [\mu_1,\dots,\mu_K]$ as we employed in VICReg; the two temperature parameters are commonly $\sigma^+ \gg \sigma>0$, and this asymmetry is known as sharpening. Without loss of generality, we consider the $\sigma$MSN loss, i.e., the MSN loss re-scaled by $\sigma$; this does not alter the training dynamics, as the learning rate can be adapted accordingly, but will simplify our derivations below. We thus also replace $\lambda$ with $\lambda/\sigma$. In this setting, we see that $\text{softmax}(\mM^T\vx^+_n/\sigma^+)_k=\delta(k-k(n))$, where we hereafter denote $k(n)=\argmin_{c}\|\mu_c-\vx_n^+\|_2$ as the cluster assignment of the positive view of $\vx_n$. With this notation, we can finally derive $$\begin{aligned}
\lim_{\sigma \mapsto 0 } \sigma MSN =&\lim_{\sigma \mapsto 0 }\left[\frac{\sigma}{N} \sum^N_{n=1} H(\vp_{n}^+, \vp_n) + \lambda D_{\text{KL}}({\overline{\vp}\ \lVert\ \vp_\text{prior}})\right]\\
&\hspace{-2cm}=\lim_{\sigma \mapsto 0 }\left[-\frac{\sigma}{N}\sum_{n=1}^{N}\sum_{k=1}^{K}\delta(k-k(n))\log\left(\text{softmax}(\mM^T\vx_n/\sigma)_k\right)\right]+\lambda \sum_{k=1}^{K} \frac{N_k}{N}\log\left(\frac{N_k/N}{(\vp_{\rm prior})_k}\right)\end{aligned}$$ where we could push the limit inside $D_{\text{KL}}$ since we assume that no cluster is empty, and thus the KL function is continuous as $\sigma\mapsto 0$, and since $\lim_{\sigma \mapsto 0}\overline{\vp}=(\frac{N_1}{N},\dots,\frac{N_K}{N})$. Here $N_k$ is the number of samples that are assigned to cluster $k$. The assumption that no cluster is empty, which is used only for simplicity, can be removed easily by noticing that $\lim_{u\mapsto 0} u\log(u)=0$. Now considering the left term of the loss, we obtain a direct simplification of the log-softmax as follows $$\begin{gathered}
    \lim_{\sigma \mapsto 0 }\left[-\frac{\sigma}{N}\sum_{n=1}^{N}\sum_{k=1}^{K}\delta(k-k(n))\log\left(\text{softmax}(\mM^T\vx_n/\sigma)_k\right)\right]\\=\frac{1}{N}\sum_{n=1}^{N}\underbrace{\|\vx_n-\mu_{k(n)} \|_2^2-\min_{k=1,\dots,K}\|\vx_n-\mu_{k} \|_2^2}_{=0 \iff \argmin_c \|\vx_n-\mu_c\|_2=\argmin_c\|\vx_n^+-\mu_c\|_2},\end{gathered}$$ which is minimized whenever all the samples $\vx_n$ that have their positive views $\vx_n^+$ associated to the same centroid belong to the same cluster; thus the overall MSN loss can be written as $$\begin{aligned}
    \sum_{k=1}^{K}\sum_{\vx \in \sX_k}\|\vx-\mu_k\|_2^2+\lambda \sum_{k=1}^{K} \frac{N_k}{N}\log\left(\frac{N_k/N}{(\vp_{\rm prior})_k}\right),\end{aligned}$$ where $\sX_k = \{\vx \in \sX : \argmin_{c=1,\dots,K}\|\vx^+-\mu_c\|_2=k \}$.
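To make the role of the regularizer concrete, the sketch below evaluates the term $\sum_{k} \frac{N_k}{N}\log\frac{N_k/N}{(\vp_{\rm prior})_k}$ for a hypothetical set of cluster sizes, once under a uniform prior and once under a power-law prior. The cluster counts, the exponent of the power law, and the helper names are assumptions made purely for illustration.

```python
import math

def power_law_prior(num_clusters, exponent):
    """Prior with mass proportional to k^(-exponent) for k = 1, ..., K."""
    weights = [k ** (-exponent) for k in range(1, num_clusters + 1)]
    total = sum(weights)
    return [w / total for w in weights]

def kl_to_prior(cluster_sizes, prior):
    """KL divergence between empirical cluster frequencies N_k / N and the prior."""
    n = sum(cluster_sizes)
    # Empty clusters contribute zero, since u * log(u) -> 0 as u -> 0.
    return sum(
        (nk / n) * math.log((nk / n) / pk)
        for nk, pk in zip(cluster_sizes, prior) if nk > 0
    )

# Hypothetical long-tailed hard assignment of N = 1000 samples to K = 4 clusters.
cluster_sizes = [480, 240, 160, 120]
uniform_prior = [0.25] * 4
long_tailed_prior = power_law_prior(4, exponent=1.0)

kl_uniform = kl_to_prior(cluster_sizes, uniform_prior)
kl_power_law = kl_to_prior(cluster_sizes, long_tailed_prior)
```

In this example the empirical cluster frequencies match the power-law prior, so the penalty vanishes (up to floating-point error), whereas the uniform prior penalizes the same long-tailed assignment.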

Proof of proposition `\ref{prop:swav}`{=latex} {#proof:swav}
----------------------------------------------

The proof for SwAV mainly relies on the same development as done for MSN. That is, we saw that with enough sharpening, the cross-entropy between $\vp_{n}^+$ and $\vp_n$ falls back to the K-means like term with an extra margin that is minimized only when the two views are in the same cluster. The difference arises in that SwAV explicitly adds linear constraints on the cluster-membership matrix whereas MSN was employing a KL-divergence between the average posterior and the prior. The connection between the linear constraint and the SwAV Sinkhorn-Knopp procedure has been made precisely in @wang2010learning, where it was shown that the latter solved a relaxed optimization problem for which the cluster-membership is no longer constrained to be $0$ or $1$.

Pretraining and evaluation details for section `\ref{sec:class_stratified}`{=latex} {#apndx:classstratified_sampling}
===================================================================================

Pretraining protocol
--------------------

#### SimCLR.

We use the VISSL [@goyal2021vissl] code base to pretrain a ResNet-50 with SimCLR [@chen2020simple], with a batch size of $4096$ for $300$ epochs. Our pretraining follows the standard hyperparameters defined in @chen2020simple. The learning rate follows the default cosine schedule with a $10$ epoch warmup. We use a temperature of $0.1$ for the contrastive loss and LARS [@you2017large] as an optimizer. We modify the sampler to force $K$ different classes inside each mini-batch, where $K$ is set to $8$ for the class-imbalanced sampling experiments and $960$ for the class-balanced sampling experiments.

#### VICReg.

We pretrain a ResNet-50 with VICReg [@bardes2021vicreg] using the LARS optimizer with a batch size of $1024$ for $300$ epochs using the official code base, which is publicly available: <https://github.com/facebookresearch/vicreg>. Our pretraining follows the standard hyperparameters defined in @bardes2021vicreg. The learning rate follows the default cosine schedule with a $10$ epoch warmup. We set the dimensions of the expander MLP to the default $8192-8192-8192$. Weight decay is set to $10^{-6}$, the variance and invariance coefficients are set to $25.0$, and the covariance coefficient is set to $1.0$. We modify the sampler to force $K$ different classes inside each mini-batch, where $K$ is set to $2$ for the class-imbalanced sampling experiments and $960$ for the class-balanced sampling experiments.

#### MSN.

We pretrain a ViT-B/16 with MSN [@assran2022masked] using the AdamW optimizer with a batch size of $1024$ for $300$ epochs and $1024$ prototypes using the official code base, which is publicly available: <https://github.com/facebookresearch/msn>. Our pretraining follows the standard hyperparameters defined in @assran2022masked. The learning rate follows the default cosine schedule with a $15$ epoch warmup. Weight decay is linearly increased from $0.04$ to $0.4$. Gradient clipping is set to $3.0$, the entropy coefficient is set to $1.0$, the temperature is set to $0.1$, the sharpening exponent is set to $0.25$, and the masking ratio is set to $0.5$. The number of random-mask views in each iteration is set to $1$, and the number of focal-mask views in each iteration is set to $10$. We modify the sampler to force $K$ different classes inside each mini-batch, where $K$ is set to $2$ for the class-imbalanced sampling experiments and $960$ for the class-balanced sampling experiments.

#### MAE.

We pretrain a ViT-L/16 with MAE [@he2021masked] using the AdamW optimizer with a batch size of $1024$ for $800$ epochs using the official code base, which is publicly available: <https://github.com/facebookresearch/mae>. Our pretraining follows the standard hyperparameters defined in @he2021masked. The learning rate follows the default cosine schedule with a $40$ epoch warmup. Weight decay is set to $0.05$, and the masking ratio is set to $0.75$. We modify the sampler to force $K$ different classes inside each mini-batch, where $K$ is set to $2$ for the class-imbalanced sampling experiments and $960$ for the class-balanced sampling experiments.

#### data2vec.

We pretrain a ViT-B/16 with data2vec [@baevski2022data2vec] using the AdamW optimizer with a batch size of $2048$ for $800$ epochs using the official code base, which is publicly available: <http://github.com/facebookresearch/data2vec_vision/tree/main/beit>. Our pretraining follows the standard hyperparameters defined in @baevski2022data2vec. Specifically, the learning rate follows the default cosine schedule with a $10$ epoch warmup. Path drop is set to $0.25$, gradient clipping is set to $3.0$, weight decay is set to $0.05$, and the target layers are set to $[6,7,8,9,10,11]$. We modify the sampler to force $K$ different classes inside each mini-batch, where $K$ is set to $2$ for the class-imbalanced sampling experiments and $960$ for the class-balanced sampling experiments.
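The sampler modification shared by the methods above can be sketched as follows. This is a minimal illustration, not the implementation used in VISSL or in the official code bases: the function name, its arguments, and the toy label list are hypothetical, and the sketch assumes that the batch size is divisible by $K$ and that every class contains enough images.

```python
import random
from collections import defaultdict

def sample_class_stratified_batch(labels, batch_size, classes_per_batch, rng):
    """Return dataset indices for one mini-batch containing exactly K classes.

    labels: list where labels[i] is the class of image i.
    classes_per_batch: K, the number of distinct classes forced into the batch.
    """
    if batch_size % classes_per_batch != 0:
        raise ValueError("batch_size must be divisible by classes_per_batch")
    per_class = batch_size // classes_per_batch

    # Group dataset indices by class label.
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)

    # First draw K classes without replacement, then images within each class.
    chosen = rng.sample(sorted(indices_by_class), classes_per_batch)
    batch = []
    for label in chosen:
        batch.extend(rng.sample(indices_by_class[label], per_class))
    return batch

# Toy dataset: 10 classes with 100 images each (hypothetical).
toy_labels = [c for c in range(10) for _ in range(100)]
rng = random.Random(0)
imbalanced_batch = sample_class_stratified_batch(toy_labels, 64, 2, rng)
balanced_batch = sample_class_stratified_batch(toy_labels, 64, 8, rng)
```

Both batches have the same size; only the number of distinct classes they contain differs, which is the quantity varied in the class-imbalanced and class-balanced experiments.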

Evaluation protocol {#apndx:classstratified_sampling_eval}
-------------------

For linear evaluation, we use the default linear evaluation configurations of VISSL [@goyal2021vissl] to evaluate our models on the following datasets: ImageNet [@russakovsky2015imagenet], iNaturalist18 [@van2018inaturalist], CIFAR100 [@krizhevsky2009learning], Clevr/Count [@johnson2017clevr], Clevr/Dist [@johnson2017clevr], KITTI/Dist [@geiger2013vision] and Places205 [@zhou2014learning].

For pretrained models based on Vision Transformers [@dosovitskiy2020image], we report the best linear classifier number among the following representations:

-   the concatenation of the class tokens of the last $4$ layers [@caron2021emerging],

-   the representation of the last layer of the class token.

For pretrained models based on ResNet50 architectures [@he2016deep], we follow the evaluation protocol of SEER [@goyal2022vision] and report the best linear classifier number among the following representations:

-   the final representation layer (of dimension $2048$ for a ResNet50),

-   an adaptive average pooling of the last feature map, concatenated to get $8192$ dimensions.

We also follow the default VISSL [@goyal2021vissl] configuration and attach $2$ linear heads per chosen representation, one composed of a single linear layer, and one with an added batch normalization [@https://doi.org/10.48550/arxiv.1502.03167] before the linear layer.

Pretraining and evaluation details for section `\ref{sec:priors}`{=latex} {#apndx:priors}
=========================================================================

Pretraining protocol
--------------------

### Toy setting

We pretrain a ViT-Tiny/4 with MSN [@assran2022masked] using the AdamW optimizer with a batch size of $1024$ for $300$ epochs and $10$ prototypes using the official code base, which is publicly available: <https://github.com/facebookresearch/msn>. The learning rate follows the default cosine schedule with a $15$ epoch warmup. Weight decay is linearly increased from $0.04$ to $0.4$. Gradient clipping is set to $0.0$, the entropy coefficient is set to $100.0$, the temperature is set to $0.1$, the sharpening exponent is set to $0.25$, and the masking ratio is set to $0.05$. The number of random-mask views in each iteration is set to $1$, and the number of focal-mask views in each iteration is set to $0$. To accommodate the lower-resolution images, we modify the scale of random-resized-crop data augmentation to $(0.5, 1.0)$ and train without Gaussian-Blur.

### Natural class-imbalanced setting

We pretrain a ViT-S/16 with MSN [@assran2022masked] using the AdamW optimizer with a batch size of $4096$ for $300$ epochs and $8142$ prototypes using the official code base, which is publicly available: <https://github.com/facebookresearch/msn>. The learning rate follows the default cosine schedule with a $15$ epoch warmup. Weight decay is linearly increased from $0.04$ to $0.4$. Gradient clipping is set to $3.0$, the entropy coefficient is set to $5.0$, the temperature is set to $0.1$, the sharpening exponent is set to $0.25$, and the masking ratio is set to $0.15$. The number of random-mask views in each iteration is set to $1$, and the number of focal-mask views in each iteration is set to $10$.

Evaluation protocol {#evaluation-protocol}
-------------------

We use the exact same evaluation protocol as in Appendix `\ref{apndx:classstratified_sampling_eval}`{=latex}.

Alternatives to feature priors for section `\ref{sec:priors}`{=latex}: data sampling
====================================================================================

As an alternative to changing the feature prior, perhaps one can devise more intelligent sampling strategies to align the class distribution in the sampled mini-batches with the implicit uniform prior in self-supervised algorithms. Here, we investigate the impact of the mini-batch sampling scheme when pretraining on iNat18. We compare (uniform) random sampling to unthresholded inverse square-root frequency sampling [@https://doi.org/10.48550/arxiv.1310.4546], which is commonly applied in the context of (weakly) supervised learning on internet-scale class-imbalanced data. Specifically, the prescription of inverse square-root frequency sampling is to weight each image inversely proportionally to the square root of the frequency of its class; equivalently, one samples a class with probability proportional to the square root of its frequency, and then samples images uniformly within the class. For example, if we have $D$ images in the dataset and $D_k$ images in the $k^{\text{th}}$ class category, then the probability of sampling class $k$ is proportional to $\nicefrac{\sqrt{D_k}}{D}$ (up to normalization over the classes), as opposed to the traditional $\nicefrac{D_k}{D}$ probability under uniform sampling. The effect of this sampling strategy is to mitigate class-imbalances by oversampling underrepresented classes in the dataset. While such an effect is desirable, one limitation of this strategy is that it induces an implicit reduction in the dataset size; for example, we may see the same images from the tail of the distribution very frequently and, given finite training time, may never see some of the images in the head. Another limitation of this strategy is that it requires knowing the class label (or a weak class label) for every image in the training set.
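The two class-sampling distributions can be compared on a small example. In the sketch below the class counts are hypothetical, and the square-root weights are normalized so that they sum to one, which is an assumption about how the scheme would be implemented in practice.

```python
import math

def uniform_class_probs(class_counts):
    """Class probabilities induced by sampling images uniformly: D_k / D."""
    total = sum(class_counts)
    return [count / total for count in class_counts]

def sqrt_class_probs(class_counts):
    """Class probabilities proportional to sqrt(D_k), normalized to sum to one."""
    weights = [math.sqrt(count) for count in class_counts]
    total = sum(weights)
    return [w / total for w in weights]

# Hypothetical long-tailed dataset with three classes.
class_counts = [10000, 100, 1]
p_uniform = uniform_class_probs(class_counts)
p_sqrt = sqrt_class_probs(class_counts)

# Per-image sampling probability within each class under the square-root scheme.
per_image = [p / count for p, count in zip(p_sqrt, class_counts)]
```

Under the square-root scheme the head class loses probability mass and the tail class gains it; correspondingly, the single image in the rarest class is drawn far more often than any individual image in the head class, which illustrates the implicit reduction in dataset size discussed above.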

Despite these limitations, we explore inverse frequency sampling in Table `\ref{tb:inatsampling}`{=latex} and observe that, except on Clevr/Dist, it does not improve upon uniform random sampling. We hypothesize that inverse frequency sampling tends to aggressively over-sample classes with very few examples, reducing the effective number of images seen per epoch and thus degrading the quality of the learned representations. This result suggests that changing the feature prior is a more viable solution for pretraining with class-imbalanced data.

```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:inatsampling}
                                                                               Sampling   iNat18    CIFAR100   Place205   Clevr/Count   Clevr/Dist    KITTI
  ------------------------------------------------------------------------------------- ---------- ---------- ---------- ------------- ------------ ----------
    `\multicolumn{7}{c}{\scriptsize\bf Pretrained on iNaturalist18 (ViT-S/16)}`{=latex}                                                             
                                                                                Uniform  **29.1**   **59.4**   **36.9**    **69.4**        56.8      **68.2**
                                                               Inverse Square Root Freq    23.4       58.4       35.1        68.3        **59.8**      66.6

  : **Impact of Sampling Scheme**: Changing the mini-batch distribution by using an inverse square root sampling strategy on iNat18 does not improve downstream performance of the learned representations.
:::

Marginal sampling probabilities {#app:marginalprob}
===============================

The class-imbalanced strategy in Section `\ref{sec:class_stratified}`{=latex} maintains the same marginal probability of sampling individual data points as compared to the class-balanced sampling strategy; i.e., the probability of sampling a particular data point in each iteration is unchanged. To see this, suppose we have a mini-batch of size $B = n \times 960$ for some integer $n > 0$. Under class-balanced sampling, we will thus sample $960$ classes in each iteration, and then $n$ images per class. Under class-imbalanced sampling, we will first sample $2$ classes in each iteration, and then $n \times 480$ images per class so that the overall batch size is $B$. Now consider the probability of sampling a data point $x$ that comes from a class $C$ in our dataset containing $N \geq 480 \cdot n$ samples. Under class-balanced sampling, the probability of sampling the data point $x$ can be factored as $$p_{\text{balanced}}(x) = p(x|C)p(C) = \frac{{N-1 \choose n-1}}{{N \choose  n}} \frac{{999 \choose 959}}{{1000 \choose 960}} = \frac{n}{N} \frac{960}{1000}.$$ Under class-imbalanced sampling, the probability of sampling $x$ can be factored as $$p_{\text{imbalanced}}(x) = p(x|C)p(C) =  \frac{{N-1 \choose 480 \cdot n-1}}{{N \choose 480 \cdot n}} \frac{{999 \choose 1}}{{1000 \choose 2}} = \frac{480\cdot n}{N} \frac{2}{1000},$$ from which it is clear that $p_{\text{balanced}}(x)$ is equal to $p_{\text{imbalanced}}(x)$.
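The equality can also be verified with exact arithmetic. The sketch below evaluates both factorizations with binomial coefficients; the values of $n$ and $N$ are arbitrary choices for illustration, subject to $N \geq 480 \cdot n$, and the function name is hypothetical.

```python
from fractions import Fraction
from math import comb

NUM_CLASSES = 1000  # number of classes in the dataset, as in the text

def marginal_prob(class_size, classes_per_batch, images_per_class):
    """Probability that a fixed image is drawn in one mini-batch.

    Factored as p(x | C) * p(C), sampling without replacement at both levels.
    """
    p_class = Fraction(
        comb(NUM_CLASSES - 1, classes_per_batch - 1),
        comb(NUM_CLASSES, classes_per_batch),
    )
    p_image_given_class = Fraction(
        comb(class_size - 1, images_per_class - 1),
        comb(class_size, images_per_class),
    )
    return p_image_given_class * p_class

n, N = 2, 1300  # hypothetical values with N >= 480 * n
p_balanced = marginal_prob(N, classes_per_batch=960, images_per_class=n)
p_imbalanced = marginal_prob(N, classes_per_batch=2, images_per_class=480 * n)
```

Because the computation uses exact rational arithmetic, the two probabilities can be compared for equality directly rather than up to a numerical tolerance.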

Visualizing prototypes with RCDM {#apndx:rcdm}
================================

We use the RCDM framework [@bordes2022high] to visualize the representations and prototypes learned with MSN. RCDM trains a conditional generative diffusion model, which maps a noise vector to pixel space using a neural network representation as conditioning. Many works [@dale2; @imagen] have demonstrated the potential of conditional diffusion models for image generation; as highlighted by @bordes2022high, they are also very useful for better understanding what is learned by neural networks.

During training, RCDM takes as input a noisy image $\hat{x}_{t}$ (corrupted with a noise vector $\epsilon_t$ such that $\hat{x}_{t} = \vx + \epsilon_t$) and the representation vector $\vy$ of the image $\vx$ computed by MSN. Then, RCDM is trained, with a denoising score matching loss [@vincent2011connection; @ddpm_2020], to reconstruct the image $\vx$ that was used to compute the representation vector $\vy$. More formally, we define an RCDM neural network $g_{\eta}(\hat{x}_{t}, \vy)$ that learns to predict the noise component $\epsilon_t$ of $\hat{x}_{t}$, i.e., by minimizing $\|g_{\eta}(\hat{x}_{t}, \vy) - \epsilon_t\|_2^2$. As demonstrated by @bordes2022high, RCDM extracts as much information as possible from the representation vector $\vy$ in order to faithfully reconstruct the image.

The conditioning vector $\vy$ is computed from $\vx$ using a pretrained and frozen MSN model. MSN generates as output the probability distribution $\vp$ that an image $\vx$ belongs to a given cluster, $\vp =\text{softmax}(\mW^T p_\gamma(f_\theta(\vx)))$, where $\mW$ is the matrix concatenating all the prototypes (or cluster centroids), $p_\gamma$ is the projection head, and $f_\theta$ is the encoder. To visualize the prototypes contained in $\mW$, we first train a conditional generative diffusion model that takes the input of the last linear layer as conditioning, i.e., $\vy=p_\gamma(f_\theta(\vx))$. After training, we replace the image embedding with a learned prototype and generate the corresponding output in pixel space by setting $\vy=\mW_i$, where $i$ is randomly selected.

To summarize, for every image in the training set, we gather its embedding (of dimension 256)[^5] from a trained MSN model. Then, we use these embeddings as conditioning for RCDM, which is trained to reconstruct the image associated with a given embedding. When training is complete, we replace the projector's embedding with a prototype learned with MSN (which also has dimension 256). By doing so, we can visualize which information is associated with each prototype (or cluster) learned with MSN. For every RCDM training, we used the same default settings as the ones on <https://github.com/facebookresearch/RCDM>. We train each network for 200000 iterations.

Full set of results for section `\ref{sec:class_stratified}`{=latex} {#apndx:classstratified_results}
====================================================================

In this section, we report the full experimental results for Section `\ref{sec:class_stratified}`{=latex}. Tables `\ref{tb:simclr}`{=latex}, `\ref{tb:msn}`{=latex}, `\ref{tb:vicreg}`{=latex}, `\ref{tb:data2vec}`{=latex}, and `\ref{tb:mae}`{=latex} show the results of SimCLR, MSN, VICReg, data2vec, and MAE on the CIFAR100, CIFAR100 $1\%$, Place205, Clevr/Count, Clevr/Dist, and KITTI downstream tasks. Additionally, we report performance on ImageNet linear and ImageNet low-shot $1\%$ evaluation in Tables `\ref{tb:indistribution}`{=latex} and `\ref{tb:indistribution_lowshot}`{=latex}. Figure `\ref{fig:visu_class_stratified}`{=latex} shows a summary of the main results.

```{=latex}
\centering
```
![Visual representations of the results of Table `\ref{tb:class_stratified_sampling}`{=latex}. Methods relying on volume maximization regularizers all exhibit similar performance alteration across diverse transfer tasks.](assets/balanced_vs_imbalanced.png){#fig:visu_class_stratified width="\\textwidth"}

```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:simclr}
                                CIFAR100                  CIFAR100 $1\%$                   Place205                   Clevr/Count           Clevr/Dist   KITTI
  ------------------- ---------------------------- ----------------------------- ---------------------------- ---------------------------- ------------ -------
    960 cls per batch             69.9                         31.4                          52.1                         77.4                 65.5      70.5
      8 cls per batch             63.4                         21.25                         47.1                         73.8                 66.5      71.7
             $\Delta$  [-6.4]{style="color: red"}   [-10.2]{style="color: red"}   [-5.2]{style="color: red"}   [-4.3]{style="color: red"}      +0.9      +1.2

  : **SimCLR:** Evaluation of the mini-batch sampling distribution on downstream transfer tasks.
:::

```{=latex}
\vspace{1.5em}
```
```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:msn}
                                CIFAR100                   CIFAR100 $1\%$                   Place205                   Clevr/Count           Clevr/Dist   KITTI
  ------------------- ----------------------------- ----------------------------- ---------------------------- ---------------------------- ------------ -------
    960 cls per batch             84.3                          46.2                          81.0                         56.7                 63.7      73.2
      2 cls per batch             71.4                          26.4                          76.4                         50.3                 65.6      71.5
             $\Delta$  [-12.9]{style="color: red"}   [-13.2]{style="color: red"}   [-4.6]{style="color: red"}   [-6.4]{style="color: red"}      +1.9      -1.6

  : **MSN:** Evaluation of the mini-batch sampling distribution on downstream transfer tasks.
:::

```{=latex}
\vspace{1.5em}
```
```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:vicreg}
                                CIFAR100                  CIFAR100 $1\%$                   Place205                   Clevr/Count           Clevr/Dist   KITTI
  ------------------- ---------------------------- ----------------------------- ---------------------------- ---------------------------- ------------ -------
    960 cls per batch             69.7                         28.9                          51.0                         79.8                 69.0      73.3
      2 cls per batch             60.8                         16.9                          44.9                         76.8                69.84      72.1
             $\Delta$  [-8.9]{style="color: red"}   [-12.0]{style="color: red"}   [-6.0]{style="color: red"}   [-3.0]{style="color: red"}      +0.7      -1.1

  : **VICReg:** Evaluation of the mini-batch sampling distribution on downstream transfer tasks.
:::

```{=latex}
\vspace{1.5em}
```
```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:data2vec}
                       CIFAR100   CIFAR100 $1\%$   Place205          Clevr/Count           Clevr/Dist   KITTI
  ------------------- ---------- ---------------- ---------- ---------------------------- ------------ -------
    960 cls per batch    50.3          13.7          37.0                76.8                 49.7      65.3
      2 cls per batch    48.6          13.1          37.0                74.7                 48.2      65.1
             $\Delta$    -1.7          -0.5           0       [-2.1]{style="color: red"}      +1.5      -0.2

  : **data2vec:** Evaluation of the mini-batch sampling distribution on downstream transfer tasks.
:::

```{=latex}
\vspace{1.5em}
```
```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:mae}
                       CIFAR100   CIFAR100 $1\%$   Place205   Clevr/Count   Clevr/Dist   KITTI
  ------------------- ---------- ---------------- ---------- ------------- ------------ -------
    960 cls per batch    75.0          28.3          50.4        90.4          72.4      70.0
      2 cls per batch    75.4          30.8          50.3        89.6          71.7      70.0
             $\Delta$    +0.4          +2.5          -0.1        -0.8          -0.7      +0.0

  : **MAE:** Evaluation of the mini-batch sampling distribution on downstream transfer tasks.
:::

```{=latex}
\vspace{1.5em}
```
```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:indistribution}
                                         SimCLR                          MSN                         VICReg              data2vec   MAE
  --------------------------- ----------------------------- ----------------------------- ----------------------------- ---------- ------
      class balanced sampling             66.9                          77.1                          69.1                 41.5     65.9
    class imbalanced sampling             55.8                          59.4                          51.4                 40.6     65.8
                     $\Delta$  [-11.1]{style="color: red"}   [-17.7]{style="color: red"}   [-17.7]{style="color: red"}     -0.8     -0.1

  : **In-distribution:** Evaluation of the mini-batch sampling distribution on in-distribution ImageNet linear evaluation.
:::

```{=latex}
\vspace{1.5em}
```
```{=latex}
\centering
```
```{=latex}
\footnotesize
```
::: {#tb:indistribution_lowshot}
                                           MSN                         VICReg              data2vec   MAE
  --------------------------- ----------------------------- ----------------------------- ---------- ------
      class balanced sampling             66.2                          48.6                 27.4     35.1
    class imbalanced sampling             28.0                          18.1                 31.4     34.8
                     $\Delta$  [-38.2]{style="color: red"}   [-30.5]{style="color: red"}     +4.0     -0.3

  : **In-distribution low-shot:** Evaluation of the mini-batch sampling distribution on in-distribution ImageNet linear evaluation using only $1\%$ of the training set.
:::

[^1]: `massran@meta.com`

[^2]: We choose 960 classes instead of 1000 so that the overall batch size is divisible by the number of GPUs utilized for distributed training.

[^3]: The negative entropy in `\eqref{eq:msn}`{=latex} is simply the KL-divergence to the uniform distribution plus a constant.

[^4]: Note the image distribution is still uniform over the CIFAR10 classes.

[^5]: This is the embedding used to perform the clustering with respect to the prototypes.
