One-Layer Transformer Provably Learns One-Nearest Neighbor In Context

Zihao Li, Yuan Cao, Cheng Gao, Yihan He, Han Liu, Jason M. Klusowski, Jianqing Fan, Mengdi Wang

Research output: Contribution to journal › Conference article › peer-review

Abstract

Transformers have achieved great success in recent years. Interestingly, they have shown particularly strong in-context learning capability: even without fine-tuning, they can solve unseen tasks well purely from task-specific prompts. In this paper, we study the capability of one-layer transformers to learn one of the most classical nonparametric estimators, the one-nearest-neighbor prediction rule. Under a theoretical framework in which the prompt contains a sequence of labeled training data followed by unlabeled test data, we show that, although the loss function is nonconvex, a single softmax attention layer trained with gradient descent can successfully learn to behave like a one-nearest-neighbor classifier. Our result gives a concrete example of how transformers can be trained to implement nonparametric machine learning algorithms, and sheds light on the role of softmax attention in transformer models.
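As a quick illustration of the mechanism the abstract describes (a sketch, not the paper's actual construction), the snippet below shows how a single softmax attention head over a prompt of labeled examples can approximate one-nearest-neighbor prediction. The negative-squared-distance score and the inverse-temperature parameter beta are illustrative assumptions introduced here; they are not taken from the paper.

    import numpy as np

    def softmax(z):
        # Numerically stable softmax over attention scores.
        z = z - z.max()
        e = np.exp(z)
        return e / e.sum()

    def attention_prediction(X_train, y_train, x_query, beta):
        # One softmax attention head over the prompt: the weight on
        # training point i is proportional to exp(-beta * ||x_query - x_i||^2),
        # so the prediction is a convex combination of the prompt labels.
        scores = -beta * np.sum((X_train - x_query) ** 2, axis=1)
        return softmax(scores) @ y_train

    rng = np.random.default_rng(0)
    X = rng.normal(size=(20, 5))               # labeled prompt inputs
    y = rng.choice([-1.0, 1.0], size=20)       # prompt labels
    xq = rng.normal(size=5)                    # unlabeled test input

    # Ground-truth one-nearest-neighbor label for comparison.
    nn_label = y[np.argmin(np.sum((X - xq) ** 2, axis=1))]
    for beta in [0.1, 1.0, 100.0]:
        print(beta, attention_prediction(X, y, xq, beta), nn_label)

As beta grows, the softmax weights concentrate on the nearest training point and the prediction converges to the one-nearest-neighbor label; a small beta instead yields a smooth, kernel-regression-style average. The paper's result concerns when gradient descent on the attention parameters provably reaches such a sharp, nearest-neighbor-like regime despite the nonconvex loss.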

Original language: English (US)
Journal: Advances in Neural Information Processing Systems
Volume: 37
State: Published - 2024
Event: 38th Conference on Neural Information Processing Systems, NeurIPS 2024 - Vancouver, Canada
Duration: Dec 9, 2024 - Dec 15, 2024

Funding

We thank the anonymous reviewers for their helpful comments. Yuan Cao is partially supported by NSFC 12301657 and Hong Kong RGC-ECS 27308624. Han Liu's research is partially supported by NIH grant R01LM01372201. Jason M. Klusowski was supported in part by the National Science Foundation through CAREER DMS-2239448, DMS-2054808, and HDR TRIPODS CCF-1934924. Jianqing Fan's research was partially supported by NSF grants DMS-2210833 and DMS-2053832, and ONR grant N00014-22-1-2340. Mengdi Wang acknowledges support from NSF IIS-2107304, NSF CPS-2312093, ONR 1006977, and Genmab.

ASJC Scopus subject areas

  • Computer Networks and Communications
  • Information Systems
  • Signal Processing
