Authors:
(1) Bobby He, Department of Computer Science, ETH Zurich (Correspondence to: [email protected].);
(2) Thomas Hofmann, Department of Computer Science, ETH Zurich.
Simplifying Transformer Blocks
Simplifying deep NNs by removing block components has received considerable attention, both in transformers and in other architectures. In these works, signal propagation theory often acts as inspiration.
It has been shown that judicious use of weight initialisations and architectural tools, like skip connections and normalisation layers, can mitigate signal propagation degeneracies and improve the trainability of deep NNs. Such considerations have motivated principled modifications that yield simpler architectures. De & Smith (2020) show that an implicit mechanism of Pre-LN skip connections is to downweight the residual branch relative to the skip branch, leading to better signal propagation. They also show that explicitly downweighting the residual branch allows normalisation layers to be removed without affecting performance. The idea of downweighting residuals for improved signal propagation and trainability has been studied extensively in the literature (Zhang et al., 2018; Hanin & Rolnick, 2018; Tarnowski et al., 2019; Zhang et al., 2019; Arpit et al., 2019; Xu et al., 2020; Bachlechner et al., 2021; Touvron et al., 2021; Hayou et al., 2021; Hayou & Yang, 2023; Martens et al., 2021; Davis et al., 2021; Noci et al., 2022; Wang et al., 2022a; Huang et al., 2020; Wang et al., 2022b).
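To make the idea of explicitly downweighting the residual branch concrete, the following is a minimal sketch and not the construction used in any of the cited works. It assumes a toy block of the form x + beta * relu(W x) with He-style Gaussian weights; the names `residual_stack`, `make_weights`, and `beta` are illustrative choices, not taken from the literature. It compares how the activation norm evolves through a deep stack for a small and a large residual weight.

```python
import math
import random


def matvec(weight, x):
    """Multiply a square matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def relu(x):
    return [max(0.0, v) for v in x]


def norm(x):
    return math.sqrt(sum(v * v for v in x))


def make_weights(depth, width, seed=0):
    """Draw `depth` random width-by-width matrices.

    The He-style standard deviation sqrt(2 / width) is an assumption
    made for this sketch only.
    """
    rng = random.Random(seed)
    std = math.sqrt(2.0 / width)
    return [
        [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        for _ in range(depth)
    ]


def residual_stack(x, weights, beta):
    """Apply x <- x + beta * relu(W x) once per weight matrix.

    `beta` scales the residual branch relative to the skip branch;
    beta = 0 reduces every block to the identity map.
    """
    for w in weights:
        branch = relu(matvec(w, x))
        x = [xi + beta * bi for xi, bi in zip(x, branch)]
    return x


weights = make_weights(depth=20, width=16, seed=0)
x0 = [1.0] * 16

# Compare output norms for a downweighted and an unweighted residual branch.
for beta in (0.1, 1.0):
    out = residual_stack(x0, weights, beta)
    print(f"beta={beta}: input norm {norm(x0):.2f}, output norm {norm(out):.2f}")
```

In this toy setting the unweighted stack (beta = 1) lets the activation norm grow rapidly with depth, while the downweighted stack stays much closer to the input scale. This only illustrates the qualitative effect; it does not reproduce the analysis of De & Smith (2020).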
For skip connections (He et al., 2016), it has been shown that transforming non-linear activation functions in MLPs and CNNs to be more linear according to a given deep architecture can enable good signal propagation even without skip connections (Martens et al., 2021; Zhang et al., 2022; Li et al., 2022). He et al. (2023) apply similar considerations to the self-attention mechanism, where the key insight is that attention matrices need to be more identity-like in order to prevent signal degradation in skipless transformers. However, these works find that skipless architectures suffer from significant losses in training speed compared to their residual counterparts when using standard optimisers like SGD or Adam. Such differences were not observed with stronger optimisers like K-FAC (Martens & Grosse, 2015) on CNNs, and the inability to explain these training speed differences highlights a current limitation of signal propagation theory. Ding et al. (2021; 2023) design a CNN, RepVGG, that can be trained like a residual architecture for fast per-update convergence, but reparameterised to be skipless at test time for significantly higher inference throughput. This reparameterisation is related to our considerations of value and projection parameters in Sec. 4.
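The reparameterisation idea can be illustrated with a simplified linear analogue. The sketch below is not RepVGG itself, which merges convolutional branches and normalisation layers; it only assumes a single linear layer with an identity skip, y = W x + x, and shows that the skip can be folded into the weights as W' = W + I so that the test-time layer is skipless. The function names `residual_forward` and `fold_skip` are hypothetical.

```python
def matvec(weight, x):
    """Multiply a square matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def residual_forward(weight, x):
    """Training-time form: linear branch plus identity skip, y = W x + x."""
    return [b + xi for b, xi in zip(matvec(weight, x), x)]


def fold_skip(weight):
    """Merge the identity skip into the weights: W' = W + I."""
    n = len(weight)
    return [
        [weight[i][j] + (1.0 if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]


weight = [[0.5, -1.0], [2.0, 0.25]]
x = [3.0, -2.0]

with_skip = residual_forward(weight, x)   # uses the explicit skip connection
skipless = matvec(fold_skip(weight), x)   # single matrix multiply, no skip

print(with_skip, skipless)  # prints [6.5, 3.5] [6.5, 3.5]
```

The two forms compute the same function, so the skip connection affects how the layer is parameterised during training rather than what it can represent at test time.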
Many works have considered simplifications or improvements specific to the transformer architecture. Most relevant to our work is the parallel block of Wang & Komatsuzaki (2021) (pictured in Fig. 1, bottom right), which computes the MLP and attention sub-blocks in parallel for efficiency gains, with minimal performance loss. Trockman & Kolter (2023) observe that the product of value and projection parameters often has a large identity component in trained transformers, and design an initialisation mimicking this to improve performance in standard transformers on small datasets. We find these matrices can be fixed to the identity without loss of performance, which removes them from our simplified architecture. Other works have considered reducing the frequency of MLP sub-blocks (Sridhar et al., 2022; Pires et al., 2023) or efficient replacements for softmax attention (Katharopoulos et al., 2020; Schlag et al., 2021; Choromanski et al., 2021). Sukhbaatar et al. (2019) remove the MLP by integrating it into the attention sub-block, augmented with persistent memory