Attention Flash
Ce que vous saurez dans 3 minutes
- Pourquoi votre GPU passe plus de temps à “attendre des données” qu’à calculer lors de l’entraînement des LLM.
- Le concept de IO-Awareness et comment FlashAttention utilise le Tiling et la Recomputation pour vaincre le goulot d’étranglement mémoire.
- L’impact concret sur la longueur de contexte (passer de 8k à 128k tokens).
1. Comprendre
Pour comprendre FlashAttention, il faut comprendre l’architecture d’un GPU. C’est une histoire de plomberie.
Le Problème : HBM vs SRAM
Un GPU moderne (comme le A100 ou H100) possède deux types de mémoire :
- HBM (High Bandwidth Memory) : Enorme (40GB - 80GB), mais “lente”. C’est le réservoir principal.
- SRAM (Static RAM) : Minuscule (192KB par processeur), mais ultra-rapide. C’est l’établi du processeur.
Dans l’Attention standard, le GPU passe son temps à faire des allers-retours entre le réservoir (HBM) et l’établi (SRAM) pour calculer la matrice d’attention . Pour une séquence longue, ces transferts (IO) deviennent le goulot d’étranglement principal. Le GPU “famine” (compute bound vs memory bound).
La Solution : Tiling (Carrelage)
FlashAttention découpe les matrices en petits blocs qui tiennent entiers dans la SRAM. Il effectue tout le calcul (Softmax inclus) dans la SRAM sans jamais renvoyer les résultats intermédiaires en HBM.
“Nous calculons l’attention exacte, mais nous le faisons en étant conscients des IO (Input/Output).” — Tri Dao, Auteur de FlashAttention
Visualisation de l’Architecture Mémoire
flowchart TB
subgraph GPU["GPU NVIDIA (A100)"]
direction TB
HBM[("MO (HBM)\nGrand & Lent\n40-80 GB")]
subgraph ComputeUnits["Streaming Multiprocessors (SMs)"]
SRAM1["M1 (SRAM)\nPetit & Rapide\n192 KB"]
Core1["Tensor Core"]
SRAM1 <--> Core1
end
end
%% Standard Attention Flow
HBM -- "1. Lit Q, K (LENT)" --> SRAM1
SRAM1 -- "2. Ecrit S (LENT)" --> HBM
HBM -- "3. Lit S, V (LENT)" --> SRAM1
SRAM1 -- "4. Ecrit O (LENT)" --> HBM
linkStyle 2,3,4,5 stroke:red,stroke-width:2px,stroke-dasharray: 5 5;
%% FlashAttention Note
note[/"FlashAttention : \nTout reste dans la SRAM.\nZéro écriture intermédiaire en HBM."/]
style note fill:#d4edda,stroke:#28a745
SRAM1 -.- note
2. Appliquer
FlashAttention est aujourd’hui le moteur par défaut de la plupart des grands frameworks.
Utilisation avec PyTorch 2.0+
Depuis PyTorch 2.0, FlashAttention est intégré nativement via scaled_dot_product_attention (SDPA).
import torchimport torch.nn.functional as F
# Création des tenseurs Q, K, V (Batch, Heads, SeqLen, Dim)# FlashAttention nécessite souvent float16 ou bfloat16 sur GPUq = torch.randn(1, 32, 4096, 64, device="cuda", dtype=torch.float16)k = torch.randn(1, 32, 4096, 64, device="cuda", dtype=torch.float16)v = torch.randn(1, 32, 4096, 64, device="cuda", dtype=torch.float16)
# L'appel magique# PyTorch choisira automatiquement le kernel le plus rapide (FlashAttention v2 si dispo)with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): output = F.scaled_dot_product_attention(q, k, v)
print("Attention calculée avec succès via FlashAttention Core.")| Séquence | Attention Standard | FlashAttention v2 | Gain |
|---|---|---|---|
| 2k | 10 ms | 5 ms | x2 |
| 8k | 120 ms | 20 ms | x6 |
| 16k | OOM (Out of Memory) | 50 ms | Infini |
Versions et Évolution
- FlashAttention-1 (2022) : Introduit le Tiling et la Recomputation.
- FlashAttention-2 (2023) : Optimise le parallélisme sur la dimension de la séquence et réduit les ops non-matmul. Jusqu’à 2x plus rapide que la v1.
- FlashAttention-3 (2024) : Spécialisé pour l’architecture Hopper (H100) utilisant les instructions TMA (Tensor Memory Accelerator).
3. Aller plus loin
Recomputation (Gradient Checkpointing)
L’astuce la plus contre-intuitive de FlashAttention concerne la Backpropagation (l’apprentissage). Normalement, on stocke la matrice d’attention (énorme) pour calculer les gradients. FlashAttention ne la stocke pas. À la place, il la recalcule à la volée lors de la passe arrière. C’est contre-intuitif : faire plus de calculs (recalculer) pour aller plus vite ? Oui, parce que le calcul (sur SRAM) est tellement plus rapide que l’accès mémoire (HBM) que c’est gagnant.
Limites
FlashAttention est exact (pas d’approximation), mais :
- Il est complexe à implémenter (CUDA pur).
- Il nécessite des GPU récents (Ampere A100 min pour v2 optimale).
- Il est parfois moins flexible pour des masques d’attention exotiques (bien que FA2 supporte ALiBi et Sliding Window).
Questions Fréquentes
Est-ce que FlashAttention change le résultat du modèle ?
Non. C’est une optimisation exacte. Contrairement à “Sparse Attention” ou “Linear Attention” qui sont des approximations, Attention(Q,K,V) avec FlashAttention donne le même résultat numérique (aux erreurs d’arrondi float16 près) que l’attention standard.
Pourquoi ne pas augmenter la taille de la SRAM ?
La SRAM coûte extrêmement cher en silicium et en énergie. Les architectes de puces (Nvidia) doivent faire un compromis. C’est au logiciel (FlashAttention) de s’adapter au matériel.
Notions Liées (Spider Web)
- Composant : Mecanisme d’Attention
- Matériel : GPU
- Optimisation : Quantization
Ressources Externes
- Papier Original : FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- GitHub : Dao-AILab/flash-attention