Provably learning a multi-head attention layer

Sitan Chen
Harvard University
John A. Paulson School of Engineering and Applied Sciences

Despite the widespread empirical success of transformers, little is known about their learnability from a computational perspective. In practice, these models are trained with SGD on a certain next-token prediction objective, but in theory it remains open even to prove that such functions can be learned efficiently at all. In this work, we give the first nontrivial provable algorithms and computational lower bounds for this problem. Our results apply in a realizable setting where one is given random sequence-to-sequence pairs generated by an unknown multi-head attention layer. Our algorithm, which centers on using examples to sculpt a convex body containing the unknown parameters, is a significant departure from existing provable algorithms for learning multi-layer perceptrons, which predominantly exploit fine-grained algebraic and rotation-invariance properties of the input distribution.
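To make the learning setup concrete, the sketch below (not from the paper) generates the kind of random sequence-to-sequence training pairs described above: each example is (X, F(X)), where F is a hidden multi-head attention layer of the form F(X) = sum_h softmax(X W_h X^T) X V_h. The number of heads, sequence length, embedding dimension, and the Gaussian input distribution are illustrative assumptions, not choices taken from the talk.

```python
# Minimal sketch of the realizable data-generating process: labels come from
# an unknown multi-head attention layer applied to random input sequences.
# All dimensions and the input distribution below are assumptions for illustration.
import numpy as np

rng = np.random.default_rng(0)
m, k, d = 3, 8, 16  # heads, sequence length, embedding dimension (assumed)

# Hidden parameters of the unknown layer: one attention matrix W_h and one
# value matrix V_h per head.
W = rng.standard_normal((m, d, d)) / np.sqrt(d)
V = rng.standard_normal((m, d, d)) / np.sqrt(d)

def softmax(A, axis=-1):
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def attention_layer(X):
    """F(X) = sum_h softmax(X W_h X^T) X V_h, mapping a k x d sequence to k x d."""
    return sum(softmax(X @ W[h] @ X.T) @ X @ V[h] for h in range(m))

def sample_example():
    """One random sequence-to-sequence training pair (X, F(X))."""
    X = rng.standard_normal((k, d))  # assumed input distribution
    return X, attention_layer(X)

X, Y = sample_example()
print(X.shape, Y.shape)  # (8, 16) (8, 16)
```

A learner in this setting only sees such (X, Y) pairs and must recover (or predict as well as) the hidden layer; the matrices W and V above play the role of the unknown parameters that the algorithm's convex body is meant to contain.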
