DIAMOND (diffusion WM) + DreamerV3 (MBRL) · Checkpoint C4·b
entrenar un agente dentro de un world model y reproducir un resultado de paper. Verás las dos grandes formas de model-based RL: el world model generativo en píxeles (DIAMOND, difusión) y el world model latente recurrente (DreamerV3, RSSM). Cierra C4. Conecta directamente con tu spine de RL (Nivel 2).
3.1La idea de "soñar para aprender"
Model-based RL: en vez de aprender solo de la experiencia real (cara, lenta), el agente aprende un world model que predice (estado, acción) → siguiente estado + recompensa, y luego entrena su política dentro del modelo (en trayectorias "imaginadas"). Ventaja brutal: eficiencia de muestras — Atari 100k da solo 100k acciones (~2 h de juego humano), donde los métodos sin modelo necesitan 50M pasos (500× más).
Dos formas de construir ese world model:
- DIAMOND: el world model es un modelo de difusión que predice el siguiente frame en píxeles, condicionado a frames y acciones pasadas. Preservar detalle visual importa para RL (un píxel pequeño puede ser la bala que te mata). Loss de score-matching (denoising).
- DreamerV3: el world model es un RSSM que codifica observaciones en estados latentes estocásticos y predice futuros latentes + recompensa + continuación. Planifica en latente (más compacto y rápido).
3.2DIAMOND: teoría mínima
El world model de DIAMOND es un denoiser D_θ que, dado el frame ruidoso s_{t+1}^τ a paso de difusión τ y el historial (s_{t:t-h}, a_{t:t-h}), predice el frame limpio:
L(θ) = || D_θ(s_{t+1}^τ | s_{t:t-h}, a_{t:t-h}) − s_{t+1} ||
En inferencia, generas el siguiente frame resolviendo el proceso inverso de difusión (N3·C). El agente (un actor-critic) se entrena con RL sobre las trayectorias que el world model imagina. Al ser difusión, captura la multimodalidad del entorno (varios futuros posibles).
Datos de tamaño (clave para ti): el world model de Atari es ~4.4M parámetros, usa ~12 GB de VRAM y entrenar un juego×seed lleva ~días en una GPU de consumo. En tu 5090 (más rápida que la 4090 del paper) es perfectamente abordable para 1-2 juegos.
3.3Laboratorio L3.1 — Entrenar DIAMOND en Atari y reproducir el HNS
1# Repo oficial
2git clone https://github.com/eloialonso/diamond && cd diamond
3uv venv && source .venv/bin/activate
4uv pip install -r requirements.txt
5
6# Entrena en un juego de Atari 100k (Breakout es un buen primer objetivo)
7python src/main.py env.train.id=BreakoutNoFrameskip-v4 \
8 common.device=cuda \
9 wandb.mode=online
10# Esto: (1) recolecta experiencia, (2) entrena el diffusion world model,
11# (3) entrena el actor-critic DENTRO del world model, en bucle.Qué observar y cómo validar el checkpoint:
- El Human-Normalized Score (HNS):
(score_agente − score_aleatorio) / (score_humano − score_aleatorio). HNS=1.0 = nivel humano. DIAMOND reporta media 1.46 en los 26 juegos; por juego varía. Tu objetivo: igualar el HNS del paper en el juego que elijas (mira la tabla del paper para el target de Breakout). - Usa varias seeds (el paper usa 5): RL tiene varianza alta; reporta media±std, no un run afortunado.
1# lab_n4l3_hns.py — calcula el HNS de tu agente entrenado
2RANDOM = {"Breakout": 1.7, "Pong": -20.7} # scores de referencia (del paper/ALE)
3HUMAN = {"Breakout": 30.5, "Pong": 14.6}
4def hns(game, agent_score):
5 return (agent_score - RANDOM[game]) / (HUMAN[game] - RANDOM[game])
6# Evalúa tu checkpoint en N episodios, promedia el score, y compara hns(...) con el paper.3.4Laboratorio L3.2 — "Jugar" dentro del world model
DIAMOND incluye un modo donde tú juegas dentro del world model entrenado (el modelo es el juego). Es la demostración visceral de qué es un world model generativo:
1python src/play.py # carga el world model entrenado y te deja jugar DENTRO de él
2# Observa: el modelo "alucina" el juego frame a frame respondiendo a tus teclas.
3# Donde el world model es débil, verás artefactos -> intuición sobre sus límites.Esto no es un extra: ver el mundo soñado te enseña dónde el modelo entiende la dinámica y dónde la inventa, que es exactamente lo que limita el RL que se entrena dentro.
3.5DreamerV3: la alternativa latente (RSSM)
DreamerV3 es el otro gran enfoque y vale la pena entenderlo aunque no lo entrenes a fondo:
- RSSM: codifica cada observación en un estado latente con parte determinista (recurrente) y parte estocástica; predice el siguiente latente dado la acción. Planifica en este espacio compacto.
- Predice además recompensa y continuación (¿episodio sigue?), y reconstruye observaciones para que el latente sea informativo.
- Logros: primer agente en conseguir diamantes en Minecraft desde cero sin demos; hiperparámetros fijos funcionan en 150+ tareas (robustez rara en RL); escala 12M–200M params.
1# DreamerV3 (repo oficial danijar/dreamerv3) en un entorno de control/Atari
2git clone https://github.com/danijar/dreamerv3 && cd dreamerv3
3uv pip install -r requirements.txt
4python dreamerv3/main.py --configs atari --task atari_breakout --logdir ./logdir/run1
5# size200m para el modelo de 200M; tamaños menores entrenan más rápido en la 5090Cuándo cada uno: DIAMOND si los detalles visuales importan (entornos visualmente ricos) y quieres difusión; DreamerV3 si quieres un MBRL maduro, robusto y eficiente en latente, o entornos de control/long-horizon. Ambos caben en tu 5090 a escala Atari/control.
3.6CHECKPOINT C4(b) — criterio de aprobado (cierra C4)
- Entrenas DIAMOND en ≥1 juego de Atari 100k y reproduces (igualas) el HNS del paper para ese juego, con varias seeds (media±std).
- Juegas dentro del world model (L3.2) y sabes describir dónde el modelo captura bien la dinámica y dónde falla.
- Sabes explicar la diferencia DIAMOND (difusión, píxeles) vs DreamerV3 (RSSM, latente) y cuándo elegir cada uno.
Rúbrica: Nivel 3 si reproduces el HNS de un juego; Nivel 4 si reproduces varios, o entrenas el world model sobre un dataset propio (estilo CS:GO de DIAMOND) creando tu "motor neuronal".
Combinado con C4(a), esto cierra el Checkpoint C4 del nivel.
3.7Ejercicios
E1. Entrena DIAMOND en Breakout y en Pong. ¿En cuál te acercas más al HNS del paper? ¿Por qué uno es más difícil de modelar que otro?
E2. Reduce el nº de pasos de difusión en la generación de frames. ¿Cómo afecta a la calidad del "sueño" y al rendimiento final del agente? (Conecta con DDIM/consistency de N3·C.)
E3. Entrena DreamerV3 size12m vs size200m en el mismo task. ¿Cuánto mejora el escalado? Relaciónalo con los scaling experiments del paper.
3.8Trampas comunes
- Reportar un solo seed → RL tiene varianza alta; usa media±std.
- Comparar tu score crudo en vez del HNS normalizado.
- Esperar que el world model sea perfecto: sus artefactos limitan el RL (por eso L3.2 importa).
- Querer entrenar DIAMOND CS:GO (381M) entero en local → eso es cloud.
3.9Referencias
- DIAMOND (Alonso et al., NeurIPS 2024; repo eloialonso/diamond; diamond-wm.github.io). DreamerV3 (Hafner et al., Nature 2025; repo danijar/dreamerv3). Atari 100k benchmark (Kaiser et al. 2019).