White-Box Transformers via Sparse Rate Reduction

@ NeurIPS 2023 (and more!)
Yaodong Yu¹
Sam Buchanan²
Druv Pai¹
Tianzhe Chu¹,³
Ziyang Wu¹
Shengbang Tong¹
Hao Bai⁴
Yuexiang Zhai¹
Ben Haeffele⁵
Yi Ma¹,⁶
¹UC Berkeley   ²TTIC   ³ShanghaiTech   ⁴UIUC   ⁵JHU   ⁶HKU

TL;DR: CRATE is a transformer-like architecture constructed from first principles. It achieves competitive performance on standard tasks and also enjoys side benefits such as mathematical interpretability, emergent segmentation, and semantically meaningful attention heads.

What is CRATE?

CRATE (Coding RATE transformer) is a white-box (mathematically interpretable) transformer architecture in which each layer performs a single step of an alternating minimization algorithm to optimize the sparse rate reduction objective \[\max_{f}\mathbb{E}_{\boldsymbol{Z} = f(\boldsymbol{X})}[\Delta R(\boldsymbol{Z} \mid \boldsymbol{U}_{[K]}) - \lambda \|\boldsymbol{Z}\|_{0}],\] where \(\boldsymbol{U}_{[K]} = (\boldsymbol{U}_{1}, \dots, \boldsymbol{U}_{K})\) are bases for \(K\) learned low-dimensional subspaces, the rate reduction \(\Delta R\) is the coding rate of all tokens minus the sum of the coding rates of their projections onto these subspaces, and the \(\ell^{0}\) norm promotes sparsity of the final token representations \(\boldsymbol{Z} = f(\boldsymbol{X})\).

The function \(f\) is defined as \[f = f^{L} \circ f^{L - 1} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}},\] where \(f^{\mathrm{pre}}\) is the pre-processing mapping and \(f^{\ell}\) is the \(\ell^{\mathrm{th}}\)-layer forward mapping, which transforms the token distribution to incrementally optimize the sparse rate reduction objective above. More specifically, \(f^{\ell}\) transforms the token representations \(\boldsymbol{Z}^{\ell}\) at the input of the \(\ell^{\mathrm{th}}\) layer into \(\boldsymbol{Z}^{\ell + 1}\) via the \(\texttt{MSSA}\) (Multi-Head Subspace Self-Attention) block followed by the \(\texttt{ISTA}\) (Iterative Shrinkage-Thresholding Algorithm) block, i.e., \[\boldsymbol{Z}^{\ell + 1} = f^{\ell}(\boldsymbol{Z}^{\ell}) = \texttt{ISTA}(\boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell})).\]
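To make the objective concrete, here is a minimal NumPy sketch that evaluates it for a single token matrix. It assumes tokens are the columns of \(\boldsymbol{Z} \in \mathbb{R}^{d \times N}\), the coding rate \(R(\boldsymbol{Z}) = \frac{1}{2}\log\det(\boldsymbol{I} + \frac{d}{N\epsilon^{2}}\boldsymbol{Z}^{\top}\boldsymbol{Z})\) with a hypothetical precision \(\epsilon = 0.5\), and \(\Delta R = R(\boldsymbol{Z}) - \sum_{k} R(\boldsymbol{U}_{k}^{\top}\boldsymbol{Z})\). The function names and the values of \(\lambda\) and \(\epsilon\) are illustrative assumptions, not taken from the CRATE code.

```python
import numpy as np

def coding_rate(Z: np.ndarray, eps: float = 0.5) -> float:
    """R(Z) = 1/2 * logdet(I + d / (N * eps^2) * Z^T Z) for tokens Z of shape (d, N)."""
    d, n = Z.shape
    alpha = d / (n * eps**2)
    # slogdet avoids overflow in det(); the matrix is positive definite, so sign == 1
    _, logdet = np.linalg.slogdet(np.eye(n) + alpha * Z.T @ Z)
    return 0.5 * logdet

def compression_rate(Z, U_list, eps=0.5):
    """R^c(Z; U_[K]) = sum_k R(U_k^T Z), with each basis U_k of shape (d, p)."""
    return sum(coding_rate(U.T @ Z, eps) for U in U_list)

def sparse_rate_reduction(Z, U_list, lam=0.1, eps=0.5):
    """Delta R(Z | U_[K]) - lam * ||Z||_0, where ||Z||_0 counts nonzero entries."""
    delta_r = coding_rate(Z, eps) - compression_rate(Z, U_list, eps)
    return delta_r - lam * np.count_nonzero(Z)

# Usage: K = 2 orthogonal 4-dimensional subspaces in R^8, N = 16 random tokens.
rng = np.random.default_rng(0)
d, N, K, p = 8, 16, 2, 4
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
U_list = [Q[:, k * p:(k + 1) * p] for k in range(K)]
Z = rng.standard_normal((d, N))
print(sparse_rate_reduction(Z, U_list))
```

Because the \(\ell^{0}\) norm is not differentiable, the CRATE derivation relaxes it to an \(\ell^{1}\) norm when deriving the sparsification step; this sketch only evaluates the objective.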

Architecture

The following figure presents an overview of the general CRATE architecture:

After encoding input data \(\boldsymbol{X}\) as a sequence of tokens \(\boldsymbol{Z}^1\), CRATE constructs a deep network that transforms the data to a canonical configuration of low-dimensional subspaces by successive compression against a local model for the distribution, generating \(\boldsymbol{Z}^{\ell+1/2}\), and sparsification against a global dictionary, generating \(\boldsymbol{Z}^{\ell+1}\). Repeatedly stacking these blocks and training the model parameters via backpropagation yields a powerful and interpretable representation of the data.
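The two half-steps of a layer can be sketched in NumPy as follows. This is a minimal sketch under simplifying assumptions: it follows the MSSA and ISTA formulas described in the CRATE paper but omits layer normalization, and `scale`, `step`, and `lam` are hypothetical scalar hyperparameters rather than values from the released code. Tokens are the columns of `Z` (shape `(d, N)`), `U_list` holds the head bases \(\boldsymbol{U}_{k}\) (shape `(d, p)`), and `D` plays the role of the global dictionary (shape `(d, d)`).

```python
import numpy as np

def softmax(x, axis=0):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def ssa(Z, U):
    """Subspace self-attention for one head: (U^T Z) softmax((U^T Z)^T (U^T Z))."""
    P = U.T @ Z                        # (p, N): tokens projected onto the subspace
    A = softmax(P.T @ P, axis=0)       # (N, N): each column sums to 1
    return P @ A                       # (p, N)

def mssa(Z, U_list, scale=0.1):
    """Multi-head SSA: lift each head back to R^d with its own U_k and sum."""
    return scale * sum(U @ ssa(Z, U) for U in U_list)

def ista(Z, D, step=0.1, lam=0.1):
    """One nonnegative ISTA step: ReLU(Z + step * D^T (Z - D Z) - step * lam)."""
    return np.maximum(Z + step * D.T @ (Z - D @ Z) - step * lam, 0.0)

def crate_layer(Z, U_list, D):
    Z_half = Z + mssa(Z, U_list)       # compression: Z^{l + 1/2}
    return ista(Z_half, D)             # sparsification: Z^{l + 1}

rng = np.random.default_rng(0)
d, N, K, p = 16, 10, 4, 4
U_list = [np.linalg.qr(rng.standard_normal((d, p)))[0] for _ in range(K)]
D = rng.standard_normal((d, d)) / np.sqrt(d)
Z_next = crate_layer(rng.standard_normal((d, N)), U_list, D)
print(Z_next.shape, (Z_next > 0).mean())   # shape and fraction of active entries
```

The composition mirrors the layer update \(\boldsymbol{Z}^{\ell + 1} = \texttt{ISTA}(\boldsymbol{Z}^{\ell} + \texttt{MSSA}(\boldsymbol{Z}^{\ell}))\); the ReLU inside `ista` is what keeps the output codes nonnegative and sparse.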

The full architecture is simply a stack of such layers, preceded by an initial tokenizer and followed by a final task-specific head (e.g., a classification head).
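The composition \(f = f^{L} \circ \cdots \circ f^{1} \circ f^{\mathrm{pre}}\) followed by a head can be sketched as below. This is an illustrative NumPy sketch: the ViT-style tokenizer (non-overlapping patches, a linear embedding, a prepended class token, and additive positional embeddings) is an assumption here rather than code taken from CRATE, the layers are placeholder callables standing in for CRATE layers, and all function names are hypothetical.

```python
import numpy as np

def patchify(img, patch):
    """Split an (H, W, C) image into non-overlapping patch x patch tiles, one column per tile."""
    H, W, C = img.shape
    gh, gw = H // patch, W // patch
    tiles = img[:gh * patch, :gw * patch].reshape(gh, patch, gw, patch, C)
    tiles = tiles.transpose(0, 2, 1, 3, 4).reshape(gh * gw, patch * patch * C)
    return tiles.T                                   # (patch * patch * C, num_patches)

def make_pre(W_embed, cls_token, pos_embed, patch):
    """Hypothetical f^pre: patch embedding, prepended class token, positional embedding."""
    def pre(img):
        tokens = W_embed @ patchify(img, patch)      # (d, num_patches)
        tokens = np.concatenate([cls_token[:, None], tokens], axis=1)
        return tokens + pos_embed                    # (d, num_patches + 1)
    return pre

def build_model(pre, layers, head):
    """Return the map x -> head(f^L(...f^1(f^pre(x))))."""
    def model(x):
        z = pre(x)
        for f in layers:                             # f^1 first, f^L last
            z = f(z)
        return head(z)
    return model

rng = np.random.default_rng(0)
d, patch, n_classes = 16, 8, 10
n_tokens = (32 // patch) ** 2 + 1
pre = make_pre(rng.standard_normal((d, patch * patch * 3)), rng.standard_normal(d),
               rng.standard_normal((d, n_tokens)), patch)
layers = [np.tanh] * 4                               # placeholders for CRATE layers
W_head = rng.standard_normal((n_classes, d))
model = build_model(pre, layers, lambda Z: W_head @ Z[:, 0])   # classify from the class token
logits = model(rng.standard_normal((32, 32, 3)))
print(logits.shape)                                  # (10,)
```

Keeping the tokenizer, the layer stack, and the head as separate callables reflects the point that only the head is task-specific.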

Classification

Below, the classification pipeline for CRATE is depicted. It is virtually identical to that of the popular vision transformer.

We train on the supervised image classification task with the softmax cross-entropy loss. CRATE achieves performance competitive with the usual vision transformer (ViT) trained on classification and shows similar scaling behavior, including above 80% top-1 accuracy on ImageNet-1K with 25% of the parameters of ViT.
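For reference, the softmax cross-entropy loss can be computed from classifier logits as in the minimal NumPy sketch below. The batching convention and function name are illustrative assumptions; in practice a deep learning framework's built-in loss would typically be used.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean softmax cross-entropy for logits of shape (B, C) and integer labels of shape (B,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # stability: largest logit becomes 0
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# Usage: with all-zero logits over 10 classes the loss equals log(10), about 2.3026.
print(softmax_cross_entropy(np.zeros((4, 10)), np.array([0, 3, 7, 9])))
```

Subtracting the per-row maximum before exponentiating leaves the loss unchanged but prevents overflow for large logits.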

Segmentation and Object Detection

An interesting phenomenon is that CRATE, even when trained only on supervised classification, learns to segment its input images, and these segmentations are easily recovered from its attention maps using the following pipeline (similar to that of DINO).

Such segmentations were previously seen only in transformer-like architectures trained with a complex self-supervised mechanism, as in DINO; in CRATE, segmentation emerges as a byproduct of supervised classification training. In particular, the model never receives any segmentation information, a priori or otherwise. Below, we show some example segmentations.
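One way to turn attention maps into segmentation masks, following the thresholding heuristic popularized by DINO, is sketched below. Assumptions: the attention tensor of one layer has shape `(heads, N + 1, N + 1)` with rows indexing queries and token 0 being the class token; the 60% attention-mass threshold is DINO's visualization choice, not necessarily the exact setting used for CRATE; the function names are hypothetical.

```python
import numpy as np

def cls_attention_maps(attn, grid_hw):
    """Per-head attention from the class token (row 0) to the patch tokens, as (heads, h, w) maps."""
    h, w = grid_hw
    return attn[:, 0, 1:].reshape(attn.shape[0], h, w)

def mass_threshold_mask(att_map, keep_mass=0.6):
    """Boolean mask of the fewest patches whose attention accounts for `keep_mass` of the total."""
    flat = att_map.ravel()
    order = np.argsort(flat)[::-1]                     # patches by decreasing attention
    cum = np.cumsum(flat[order]) / flat.sum()
    n_keep = int(np.searchsorted(cum, keep_mass)) + 1  # first index reaching keep_mass, inclusive
    mask = np.zeros(flat.shape, dtype=bool)
    mask[order[:n_keep]] = True
    return mask.reshape(att_map.shape)

# Usage with a random row-stochastic attention tensor: 6 heads, 4 x 4 patch grid.
rng = np.random.default_rng(0)
logits = rng.standard_normal((6, 17, 17))
attn = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
maps = cls_attention_maps(attn, (4, 4))
masks = np.stack([mass_threshold_mask(m) for m in maps])   # one mask per head
print(masks.shape)                                          # (6, 4, 4)
```

Producing one mask per head, rather than averaging heads first, keeps each head's map available for inspection on its own.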

Another remarkable property is that attention heads in CRATE automatically carry semantic meaning, which suggests that CRATE's classifications may be interpretable post hoc. Below, we visualize the outputs of some attention heads across several images of different animals, showing that some attention heads correspond to different parts of the animal, and that this correspondence is consistent across different animals and different classes of animals.

Autoencoding

We extend CRATE to autoencoding using the following pipeline. We construct each decoder layer \(g^{L - \ell}\) in a diffusion/optimal transport-inspired way: if we think of the encoder layer \(f^{\ell}\) as transporting the probability mass of its input distribution, then \(g^{L - \ell}\) is constructed to be an approximate inverse of this transport map. The full encoder and decoder layers are given below. This variant of the CRATE architecture achieves competitive performance on the masked autoencoding task, as the examples below show. In addition, it exhibits the same emergent properties (shown above) as the classification-trained CRATE.
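The idea that a decoder layer approximately inverts an encoder layer can be illustrated with a toy example. In the sketch below (an illustration only, not CRATE's actual decoder layer), a hypothetical encoder step moves tokens a fraction `eta` of the way toward a subspace, and the decoder step applies the opposite update evaluated at its own input. Because the reverse update is not evaluated at the original point, the round trip is exact only up to an error of order `eta**2`.

```python
import numpy as np

def perp_projector(U):
    """Projector onto the orthogonal complement of span(U), for orthonormal U of shape (d, p)."""
    return np.eye(U.shape[0]) - U @ U.T

def encoder_step(Z, U, eta):
    """Toy compression: move each token (column) toward span(U) by a fraction eta."""
    return Z - eta * perp_projector(U) @ Z

def decoder_step(Z, U, eta):
    """Toy approximate inverse: apply the opposite update, evaluated at the decoder's input."""
    return Z + eta * perp_projector(U) @ Z

rng = np.random.default_rng(0)
U, _ = np.linalg.qr(rng.standard_normal((8, 3)))
Z = rng.standard_normal((8, 20))
for eta in (0.1, 0.01):
    err = np.linalg.norm(decoder_step(encoder_step(Z, U, eta), U, eta) - Z)
    print(eta, err)    # error shrinks roughly like eta**2
```

Here the round trip equals \((\boldsymbol{I} + \eta\boldsymbol{P}_{\perp})(\boldsymbol{I} - \eta\boldsymbol{P}_{\perp}) = \boldsymbol{I} - \eta^{2}\boldsymbol{P}_{\perp}\), so smaller steps give a closer inverse.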

Theoretical Principles

We derive the encoder architecture via unrolled optimization of the sparse rate reduction objective. The representations that optimize the sparse rate reduction are compressed and sparse, as depicted in the figure below, where they are realized by some encoder \(f\). In CRATE, the compression operator \(\boldsymbol{Z}^{\ell} \to \boldsymbol{Z}^{\ell + 1/2}\) and the sparsification operator \(\boldsymbol{Z}^{\ell + 1/2} \to \boldsymbol{Z}^{\ell + 1}\) are approximate (proximal) gradient steps on different parts of the sparse rate reduction objective.
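The sparsification step can be read as proximal gradient descent. A minimal sketch, assuming the nonnegative, \(\ell^{1}\)-relaxed sparse coding problem \(\min_{\boldsymbol{Z} \geq 0} \frac{1}{2}\|\boldsymbol{X} - \boldsymbol{D}\boldsymbol{Z}\|_{F}^{2} + \lambda\|\boldsymbol{Z}\|_{1}\) with a dictionary `D`: each iteration is a gradient step on the quadratic term followed by the proximal map of the penalty, which is a shifted ReLU. A single such step with both `X` and the starting point equal to \(\boldsymbol{Z}^{\ell + 1/2}\) has the form of the ISTA block; here the step is iterated to show the objective decreasing. The data, `lam`, and the step size are illustrative.

```python
import numpy as np

def lasso_objective(Z, X, D, lam):
    """0.5 * ||X - D Z||_F^2 + lam * ||Z||_1."""
    return 0.5 * np.sum((X - D @ Z) ** 2) + lam * np.abs(Z).sum()

def ista_step(Z, X, D, lam, eta):
    """Gradient step on the quadratic term, then the prox of lam*||.||_1 plus Z >= 0 (a shifted ReLU)."""
    grad = -D.T @ (X - D @ Z)
    return np.maximum(Z - eta * grad - eta * lam, 0.0)

rng = np.random.default_rng(0)
d, N, lam = 8, 5, 0.5
D = rng.standard_normal((d, d))
X = np.abs(rng.standard_normal((d, N)))
eta = 1.0 / np.linalg.norm(D, 2) ** 2     # 1 / Lipschitz constant of the gradient
Z = X.copy()                               # start at the input, as in a CRATE layer
history = [lasso_objective(Z, X, D, lam)]
for _ in range(50):
    Z = ista_step(Z, X, D, lam, eta)
    history.append(lasso_objective(Z, X, D, lam))
print(history[0], history[-1])             # the objective decreases
```

With step size \(1/L\), where \(L\) is the largest squared singular value of `D`, proximal gradient descent never increases the objective, which is why a single step already moves the representation toward a sparser code.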

To derive the decoder architecture, we propose a novel framework of structured denoising-diffusion, which is analogous to the common (ordinary) denoising-diffusion framework widely used for generative modeling of image data. Our framework relies on a quantitative connection between the compression operator and the score function (as used in denoising-diffusion models), as shown in the figure below. The encoder and decoder are derived by discretizing the structured denoising and diffusion processes, respectively. Importantly, the encoder derived from unrolled optimization and the encoder derived from structured denoising have the same architecture, which is described above.
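A small numerical check illustrates the kind of connection between denoising and compression involved. Under the simplifying assumption of a single subspace (the CRATE model uses a mixture of subspaces), with clean data \(\boldsymbol{x} = \boldsymbol{U}\boldsymbol{a}\), \(\boldsymbol{a} \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})\), and noisy observations \(\boldsymbol{z} = \boldsymbol{x} + \sigma\boldsymbol{w}\) with \(\boldsymbol{w} \sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})\), Tweedie's formula \(\mathbb{E}[\boldsymbol{x} \mid \boldsymbol{z}] = \boldsymbol{z} + \sigma^{2}\nabla\log p_{\sigma}(\boldsymbol{z})\) turns the score into a denoiser. That denoiser works out to \(\frac{1}{1 + \sigma^{2}}\boldsymbol{U}\boldsymbol{U}^{\top}\boldsymbol{z}\), a shrunken projection onto the subspace, i.e., a compression step. The dimensions and \(\sigma\) below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d, p, sigma = 6, 2, 0.5
U, _ = np.linalg.qr(rng.standard_normal((d, p)))   # orthonormal basis, shape (d, p)
Sigma = U @ U.T + sigma**2 * np.eye(d)              # covariance of z = U a + sigma w
z = rng.standard_normal(d)

score = -np.linalg.solve(Sigma, z)                  # grad log p_sigma(z) for N(0, Sigma)
denoised = z + sigma**2 * score                     # Tweedie: E[x | z]
projection = U @ (U.T @ z) / (1 + sigma**2)         # shrunken projection onto span(U)
print(np.max(np.abs(denoised - projection)))        # floating-point error only
```

The agreement follows from \(\boldsymbol{\Sigma}^{-1} = \frac{1}{\sigma^{2}}(\boldsymbol{I} - \boldsymbol{U}\boldsymbol{U}^{\top}) + \frac{1}{1 + \sigma^{2}}\boldsymbol{U}\boldsymbol{U}^{\top}\) for orthonormal \(\boldsymbol{U}\).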

Acknowledgements

This work is partially supported by ONR grant N00014-22-1-2102 and the joint Simons Foundation-NSF DMS grant 2031899.

This website template was adapted from Brent Yi's project page for TILTED.