Aller au contenu

Attention Flash

Ce que vous saurez dans 3 minutes

  • Pourquoi votre GPU passe plus de temps à “attendre des données” qu’à calculer lors de l’entraînement des LLM.
  • Le concept de IO-Awareness et comment FlashAttention utilise le Tiling et la Recomputation pour vaincre le goulot d’étranglement mémoire.
  • L’impact concret sur la longueur de contexte (passer de 8k à 128k tokens).

1. Comprendre

Pour comprendre FlashAttention, il faut comprendre l’architecture d’un GPU. C’est une histoire de plomberie.

Le Problème : HBM vs SRAM

Un GPU moderne (comme le A100 ou H100) possède deux types de mémoire :

  • HBM (High Bandwidth Memory) : Enorme (40GB - 80GB), mais “lente”. C’est le réservoir principal.
  • SRAM (Static RAM) : Minuscule (192KB par processeur), mais ultra-rapide. C’est l’établi du processeur.

Dans l’Attention standard, le GPU passe son temps à faire des allers-retours entre le réservoir (HBM) et l’établi (SRAM) pour calculer la matrice d’attention N×NN \times N. Pour une séquence longue, ces transferts (IO) deviennent le goulot d’étranglement principal. Le GPU “famine” (compute bound vs memory bound).

La Solution : Tiling (Carrelage)

FlashAttention découpe les matrices Q,K,VQ, K, V en petits blocs qui tiennent entiers dans la SRAM. Il effectue tout le calcul (Softmax inclus) dans la SRAM sans jamais renvoyer les résultats intermédiaires en HBM.

“Nous calculons l’attention exacte, mais nous le faisons en étant conscients des IO (Input/Output).” — Tri Dao, Auteur de FlashAttention

Visualisation de l’Architecture Mémoire

flowchart TB
    subgraph GPU["GPU NVIDIA (A100)"]
        direction TB
        HBM[("MO (HBM)\nGrand & Lent\n40-80 GB")]
        
        subgraph ComputeUnits["Streaming Multiprocessors (SMs)"]
            SRAM1["M1 (SRAM)\nPetit & Rapide\n192 KB"]
            Core1["Tensor Core"]
            SRAM1 <--> Core1
        end
    end

    %% Standard Attention Flow
    HBM -- "1. Lit Q, K (LENT)" --> SRAM1
    SRAM1 -- "2. Ecrit S (LENT)" --> HBM
    HBM -- "3. Lit S, V (LENT)" --> SRAM1
    SRAM1 -- "4. Ecrit O (LENT)" --> HBM
    
    linkStyle 2,3,4,5 stroke:red,stroke-width:2px,stroke-dasharray: 5 5;

    %% FlashAttention Note
    note[/"FlashAttention : \nTout reste dans la SRAM.\nZéro écriture intermédiaire en HBM."/]
    style note fill:#d4edda,stroke:#28a745
    
    SRAM1 -.- note

2. Appliquer

FlashAttention est aujourd’hui le moteur par défaut de la plupart des grands frameworks.

Utilisation avec PyTorch 2.0+

Depuis PyTorch 2.0, FlashAttention est intégré nativement via scaled_dot_product_attention (SDPA).

import torch
import torch.nn.functional as F
# Création des tenseurs Q, K, V (Batch, Heads, SeqLen, Dim)
# FlashAttention nécessite souvent float16 ou bfloat16 sur GPU
q = torch.randn(1, 32, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 32, 4096, 64, device="cuda", dtype=torch.float16)
# L'appel magique
# PyTorch choisira automatiquement le kernel le plus rapide (FlashAttention v2 si dispo)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
output = F.scaled_dot_product_attention(q, k, v)
print("Attention calculée avec succès via FlashAttention Core.")

Versions et Évolution

  • FlashAttention-1 (2022) : Introduit le Tiling et la Recomputation.
  • FlashAttention-2 (2023) : Optimise le parallélisme sur la dimension de la séquence et réduit les ops non-matmul. Jusqu’à 2x plus rapide que la v1.
  • FlashAttention-3 (2024) : Spécialisé pour l’architecture Hopper (H100) utilisant les instructions TMA (Tensor Memory Accelerator).

3. Aller plus loin

Recomputation (Gradient Checkpointing)

L’astuce la plus contre-intuitive de FlashAttention concerne la Backpropagation (l’apprentissage). Normalement, on stocke la matrice d’attention (énorme) pour calculer les gradients. FlashAttention ne la stocke pas. À la place, il la recalcule à la volée lors de la passe arrière. C’est contre-intuitif : faire plus de calculs (recalculer) pour aller plus vite ? Oui, parce que le calcul (sur SRAM) est tellement plus rapide que l’accès mémoire (HBM) que c’est gagnant.

Limites

FlashAttention est exact (pas d’approximation), mais :

  1. Il est complexe à implémenter (CUDA pur).
  2. Il nécessite des GPU récents (Ampere A100 min pour v2 optimale).
  3. Il est parfois moins flexible pour des masques d’attention exotiques (bien que FA2 supporte ALiBi et Sliding Window).

Questions Fréquentes

Est-ce que FlashAttention change le résultat du modèle ?

Non. C’est une optimisation exacte. Contrairement à “Sparse Attention” ou “Linear Attention” qui sont des approximations, Attention(Q,K,V) avec FlashAttention donne le même résultat numérique (aux erreurs d’arrondi float16 près) que l’attention standard.

Pourquoi ne pas augmenter la taille de la SRAM ?

La SRAM coûte extrêmement cher en silicium et en énergie. Les architectes de puces (Nvidia) doivent faire un compromis. C’est au logiciel (FlashAttention) de s’adapter au matériel.


Notions Liées (Spider Web)

Ressources Externes