Kernels CUDA / Triton (Checkpoint C3a)
escribir un kernel fusionado en Triton que gane al baseline de PyTorch en tu 5090 en una operación memory-bound, perfilarlo, y explicar por qué gana. Esto te da rayos-X sobre todo lo demás: cuando un modelo va lento, sabrás si es el kernel, la memoria o el lanzamiento.
A.1 Por qué Triton y no CUDA C++ directo
CUDA C++ te hace razonar a nivel de hilo (qué hace cada thread, shared memory a mano, sincronización). Triton (DSL de OpenAI, en Python) te hace razonar a nivel de tile (qué hace cada bloque sobre un trozo de datos) y el compilador se encarga de vectorización, shared memory y scheduling. Resultado: escribes kernels correctos mucho más rápido, con rendimiento dentro del 10–20% del CUDA hecho a mano para ops memory-bound. Además, Triton es el backend que genera torch.compile — entenderlo es entender qué hace tu PyTorch por debajo.
Aprende CUDA C++ también (PMPP + GPU MODE lectures 1–4) para los fundamentos (jerarquía de memoria, coalescing, occupancy), pero escribe tus kernels de producción en Triton.
A.2 Caveat Blackwell honesto (define tu checkpoint)
En la RTX 5090 (SM_120): Triton corre correctamente y el compilador soporta Blackwell, pero SM_120 carece del subsistema TMEM que habilita las optimizaciones de kernel persistente de los Blackwell de datacenter (SM_100/B200). Implicación práctica:
- No persigas batir a cuBLAS en GEMM denso en la 5090: cuBLAS está hiperoptimizado y tú no tienes las features SM_100. Perderás.
- Persigue ganar en operaciones memory-bound fusionables (softmax, layernorm/RMSNorm, fused cross-entropy, dropout+activación): ahí la fusión reduce tráfico a HBM y ganas a PyTorch porque PyTorch lanza varios kernels que van y vuelven de memoria. Este es tu terreno de victoria y el objetivo de C3a.
A.3 El modelo mental: por qué un fused softmax gana
Softmax naïve en PyTorch lee el input tres veces desde HBM: una para el max (estabilidad numérica), una para exp, una para la suma; y escribe intermedios entre pasos. Un kernel fusionado hace los tres pasos en shared memory en una sola pasada → reduce el tráfico HBM ~3×. Como softmax es memory-bound (N0·L1), menos tráfico = más rápido. Esa es toda la magia.
El truco numérico clave es el online softmax: computar max y suma de exponenciales iterativamente en una sola pasada (en vez de dos), manteniendo un running-max y reescalando. Es la misma idea que hace posible FlashAttention.
A.4 Laboratorio A.1 — Fused softmax en Triton (código completo)
1# lab_n3a_softmax.py — fused softmax en Triton, con autotune y comparación vs PyTorch
2import torch, triton, triton.language as tl
3
4@triton.autotune(
5 configs=[triton.Config({"BLOCK_SIZE": bs}, num_warps=nw)
6 for bs in (256, 512, 1024, 2048, 4096) for nw in (2, 4, 8, 16)],
7 key=["n_cols"], # re-tunea cuando cambia el nº de columnas
8)
9@triton.jit
10def softmax_kernel(in_ptr, out_ptr, in_row_stride, out_row_stride,
11 n_cols, BLOCK_SIZE: tl.constexpr):
12 row = tl.program_id(0) # cada "programa" procesa una fila
13 col = tl.arange(0, BLOCK_SIZE) # offsets de columna del tile
14 mask = col < n_cols # evita leer fuera de la fila
15 x = tl.load(in_ptr + row * in_row_stride + col, mask=mask, other=-float("inf"))
16 x = x - tl.max(x, axis=0) # estabilidad numérica (resta el max)
17 num = tl.exp(x)
18 denom = tl.sum(num, axis=0)
19 y = num / denom
20 tl.store(out_ptr + row * out_row_stride + col, y, mask=mask)
21
22def softmax_triton(x: torch.Tensor) -> torch.Tensor:
23 n_rows, n_cols = x.shape
24 out = torch.empty_like(x)
25 softmax_kernel[(n_rows,)](x, out, x.stride(0), out.stride(0), n_cols) # grid = 1 prog por fila
26 return out
27
28if __name__ == "__main__":
29 x = torch.randn(8192, 4096, device="cuda", dtype=torch.float32)
30 # correctness
31 assert torch.allclose(softmax_triton(x), torch.softmax(x, dim=-1), atol=1e-5)
32
33 def bench(fn):
34 for _ in range(10): fn(x) # warmup (N0·L1)
35 torch.cuda.synchronize()
36 import time; t0 = time.perf_counter()
37 for _ in range(100): fn(x)
38 torch.cuda.synchronize()
39 dt = (time.perf_counter() - t0) / 100
40 bytes_rw = x.numel() * x.element_size() * 2 # leer + escribir
41 return dt * 1e6, bytes_rw / dt / 1e9 # us, GB/s efectivos
42 us_t, gbps_t = bench(softmax_triton)
43 us_p, gbps_p = bench(lambda z: torch.softmax(z, dim=-1))
44 print(f"Triton : {us_t:7.1f} us {gbps_t:7.0f} GB/s")
45 print(f"PyTorch: {us_p:7.1f} us {gbps_p:7.0f} GB/s")
46 print(f"speedup: {us_p/us_t:.2f}x")Líneas no triviales explicadas:
@triton.autotune(..., key=["n_cols"]): Triton prueba todas las configs (BLOCK_SIZE × num_warps) la primera vez para cadan_cols, cachea la mejor y la reutiliza. En producción, fijas la ganadora para evitar el coste de tuning.BLOCK_SIZE: tl.constexpr: debe ser constante en tiempo de compilación (Triton genera código especializado por valor). Por eso es un parámetro especial, no un argumento normal.tl.program_id(0): el índice del "programa" (bloque) en la grid; aquí cada programa procesa una fila entera → softmax por filas.mask = col < n_cols+other=-inf: si la fila no es múltiplo de BLOCK_SIZE, las columnas sobrantes se enmascaran; cargar-infhace que suexpsea 0 y no contaminen el softmax. Olvidar la máscara = lectura fuera de límites = resultados corruptos.x - tl.max(x): estabilidad numérica (evitaexpde números grandes → inf). Imprescindible.- El cálculo de
bytes_rwy GB/s conecta con el roofline: estás midiendo cuánto te acercas al ancho de banda de la 5090 (N0·L1).
Qué deberías ver: el kernel Triton iguala o supera a torch.softmax y se acerca más al ancho de banda pico, porque fusiona los pasos. Si no gana, revisa BLOCK_SIZE (autotune) y que estás midiendo bien (warmup + synchronize).
A.5 Laboratorio A.2 — Profiling con Nsight Compute
Ganar no basta; tienes que saber por qué. Perfila el kernel:
1# perfila el kernel Triton y mira occupancy, throughput de memoria, etc.
2ncu --set full --kernel-name softmax_kernel -o softmax_profile python lab_n3a_softmax.py
3ncu-ui softmax_profile.ncu-rep # GUI; o lee el resumen en terminalBusca: Memory Throughput (¿cerca del pico?), Achieved Occupancy (¿usas bien los SM?), DRAM vs L2 traffic (¿la fusión redujo el tráfico a DRAM?). Para tu lab notebook: explica el speedup en términos de estas métricas, no como "fue más rápido".
A.6 Progresión de kernels (haz estos en orden)
- Vector add (en CUDA C++ con
load_inliney en Triton) — el "hola mundo", para el flujo de compilar/lanzar/perfilar. - Fused softmax (A.1) — primera victoria memory-bound.
- Fused RMSNorm — el normalizador de los LLM modernos; mismo patrón.
- Fused cross-entropy — el de Unsloth; combina logits→loss sin materializar todo.
- (stretch) un tile de atención estilo FlashAttention (online softmax + tiling de K,V) — el jefe final del track.
Recursos: GPU MODE lectures (1: custom kernels+profiling; 2–4: CUDA básico; 12: Flash Attention), Triton-Puzzles, los tutoriales oficiales de Triton, CS336 L6.
A.7 CHECKPOINT C3a — criterio de aprobado
- Un kernel fusionado en Triton (softmax, RMSNorm o cross-entropy) que gana al baseline de PyTorch en tu 5090, con
allcloseverificando correctness. - Benchmark reproducible (warmup + synchronize) con speedup y GB/s efectivos.
- Explicación del por qué gana, respaldada con métricas de Nsight (reducción de tráfico DRAM por fusión, occupancy).
- Entiendes y puedes explicar el caveat SM_120/TMEM (A.2): por qué eliges una op memory-bound y no GEMM.
Rúbrica: Nivel 3 si ganas y lo explicas; Nivel 4 si contribuyes un kernel a un repo (Liger-Kernel, etc.) o compites en el leaderboard de GPU MODE (kernelbot/popcorn-cli). Bonus que cruza con el spine: notebook Unsloth "Automatic Kernel Creation with RL" — entrenar un modelo que escribe kernels.
A.8 Ejercicios
E1. Implementa el fused RMSNorm en Triton y compáralo con torch.nn.RMSNorm. ¿Cuánto tráfico DRAM ahorras (Nsight)?
E2. Quita la máscara del softmax y usa una matriz cuyo n_cols no sea potencia de 2. Observa el resultado corrupto. (Aprender el fallo > evitarlo.)
E3. Intenta batir a cuBLAS en un GEMM y documenta por qué pierdes en la 5090 (conecta con A.2). Entender por qué pierdes es tan valioso como ganar.
A.9 Trampas comunes
- Medir sin warmup/synchronize (N0·L1).
- Olvidar la máscara → lectura fuera de límites.
- Perseguir GEMM en SM_120 → frustración; ve a memory-bound.
- No perfilar → "es más rápido" sin saber por qué = Nivel 2, no 3.
A.10 Referencias
- PMPP 4ª ed.; GPU MODE (lectures, Triton-Puzzles, resource-stream); tutoriales oficiales de Triton; CS336 L6; "Learning Triton One Kernel at a Time" (TDS). Caveat Blackwell SM_120/TMEM: Spheron "Triton Kernel on GPU Cloud 2026".