Deep RL — CartPole Lab
Live demo and image generation scripts for Module 23 (Reinforcement Learning). Trains three methods — SARSA, Q-learning, DQN — on CartPole-v1 and lets you watch them learn in real time.
Prerequisites
- Python 3.9 or later
- No GPU needed — everything runs on CPU
- A working display for the live window (X11 or Wayland on Linux, native on macOS/Windows)
Setup (one time)
# 1. Navigate here
cd clase/23_reinforcement_learning/deep_rl
# 2. Create the virtual environment and install dependencies (~2-5 min first time)
./setup.sh
# 3. Activate the environment
source .venv/bin/activate
# You should see (.venv) in your prompt. You're ready.
To deactivate when you’re done:
deactivate
What gets installed
| Package | Version | Purpose |
|---|---|---|
torch |
2.5.1 (CPU) | Neural network for DQN |
gymnasium[classic-control] |
1.0.0 | CartPole-v1 environment |
numpy |
1.26.4 | Arrays and discretization |
matplotlib |
3.9.4 | Live window and plot generation |
tqdm |
4.67.1 | Terminal progress bars |
Live demo: demo_cartpole.py
Quick start
# Watch DQN learn (recommended first run)
python demo_cartpole.py --method dqn
# Compare all 3 methods, then see a final summary plot
python demo_cartpole.py --compare
All flags
| Flag | Options | Default | Description |
|---|---|---|---|
--method |
sarsa qlearning dqn |
dqn |
Algorithm to train |
--compare |
— | off | Run all 3 methods sequentially, show final comparison |
--episodes |
any integer | 500 |
Number of training episodes |
--speed |
fast normal slow |
normal |
How often to update the animation window |
Speed controls how many training steps occur between window refreshes:
| Speed | Steps per frame | Best for |
|---|---|---|
slow |
1 | Watching every single action |
normal |
5 | Balanced — smooth animation |
fast |
20 | Getting to convergence quickly, less animation detail |
Common commands
# Watch DQN converge over 500 episodes (default)
python demo_cartpole.py --method dqn
# Faster run — 200 episodes, update window every 20 steps
python demo_cartpole.py --method dqn --episodes 200 --speed fast
# Watch a tabular method (SARSA or Q-learning)
python demo_cartpole.py --method sarsa
python demo_cartpole.py --method qlearning
# Slow motion — see every step clearly
python demo_cartpole.py --method dqn --episodes 100 --speed slow
# Compare all 3 methods with 300 episodes each
python demo_cartpole.py --compare --episodes 300
# Quick comparison (150 episodes per method, ~2-3 min total)
python demo_cartpole.py --compare --episodes 150
The window layout
When you run a single method, a window opens with 4 panels and a status bar:
┌─────────────────────────┬──────────────────────────┐
│ │ │
│ CartPole animation │ Reward per episode │
│ (live rendering) │ + 50-ep rolling avg │
│ │ + solved line (475) │
├─────────────────────────┼──────────────────────────┤
│ │ │
│ DQN: MSE loss curve │ DQN: Q-value │
│ Tabular: ε decay │ distribution │
│ │ Tabular: Q-table │
│ │ heatmap │
├─────────────────────────┴──────────────────────────┤
│ Status bar: episode | ε | avg reward | loss │
└────────────────────────────────────────────────────┘
Top-left — CartPole animation: The physical simulation rendered in real time. Early on (ε high) the agent acts randomly and the pole falls immediately. As ε decays and the agent learns, the cart starts making smooth corrective movements to keep the pole up.
Top-right — Reward per episode: Each episode ends when the pole falls or reaches 500 steps. The thin line is the raw per-episode reward; the thick line is the 50-episode rolling average. The dashed line at 475 is the “solved” threshold. You want the thick line to cross that dashed line.
Bottom-left — Loss (DQN) or ε decay (tabular):
- DQN: MSE loss between the network’s prediction and the TD target. Starts high (network predicts random values), decreases as the network learns. A loss that doesn’t decrease suggests the learning rate is too high.
- Tabular methods: ε decays from 1.0 (pure exploration) to 0.05 (mostly greedy). The shape is always exponential — this panel confirms the schedule is working.
Bottom-right — Q-value distribution (DQN) or Q-table heatmap (tabular):
- DQN: Histogram of Q-values sampled from recent states. Early training: narrow cluster near 0. After learning: spread out with higher mean as the network correctly estimates long-horizon returns.
- Tabular methods: Heatmap showing max-Q averaged over the pole-angle × angular-velocity dimensions. Cells that are visited more often get higher values. After convergence you’ll see a pattern — near-zero angle states have high value (good), extreme-angle states have low value (about to fall).
Status bar: Shows at a glance: current episode, current ε, 50-episode average reward, and (for DQN) the most recent loss value.
What to watch for
With DQN (--method dqn):
- Episodes 1-50: Pole falls in ~10-20 steps. Animation looks chaotic.
- Episodes 50-150: Buffer fills up, loss starts dropping, reward slowly climbs.
- Episodes 150-300: ε is now below 0.3 — you can see the agent start correcting. Reward spikes up.
- Episodes 300-500: Rolling average crosses 475. Episodes take longer and longer (more steps = more reward = more time to render).
The per-episode time in the terminal is itself a signal: if episodes are taking 2-4 seconds each, the pole is surviving for hundreds of steps — the agent is solving it.
With tabular methods (--method sarsa/qlearning):
The reward curve climbs a bit early (the agent learns to avoid the worst actions) but then plateaus around 50-80. The discretization loses too much information — nearby continuous states get treated as unrelated. The Q-table heatmap shows which state regions got visited, but the agent can never generalize across bins.
With --compare:
The 3 methods run sequentially in the terminal (with tqdm progress bars). At the end, a single window shows all 3 convergence curves overlaid. The DQN line should climb clearly above the two flat tabular lines.
Headless environments (no display)
If you’re on a server without a display, the script detects this automatically:
No display found. Summary saved to demo_results.png
It trains the model, then saves a PNG with the reward curve. All training logic is identical.
Image generation: lab_deep_rl.py
Generates the 6 static images used in the course pages (08–13).
python lab_deep_rl.py
Output goes to ../images/. Expected runtime: ~5-8 minutes on a modern CPU.
| Image | File | What it shows |
|---|---|---|
| 08 | 08_dqn_architecture.png |
Network diagram: 4→64→64→2 with ReLU |
| 09 | 09_experience_replay.png |
Replay buffer: store arrow + random sample arrow |
| 10 | 10_target_network.png |
Online vs frozen target network, copy arrow, TD equation |
| 11 | 11_convergence_comparison.png |
All 3 methods over 500 episodes with rolling averages |
| 12 | 12_loss_curve.png |
DQN MSE loss vs training step with buffer-full annotation |
| 13 | 13_cartpole_frames.png |
4 frames from a solved episode (pole stays vertical) |
Files in this directory
deep_rl/
├── README.md ← this file
├── setup.sh ← creates .venv/ and installs dependencies
├── requirements.txt ← pinned versions (CPU torch, gymnasium, numpy, matplotlib, tqdm)
├── demo_cartpole.py ← live interactive demo
├── lab_deep_rl.py ← static image generation for course pages
└── .venv/ ← virtual environment (gitignored, created by setup.sh)
Troubleshooting
./setup.sh: Permission denied
chmod +x setup.sh && ./setup.sh
ModuleNotFoundError: No module named 'gymnasium'
The venv is not active. Run source .venv/bin/activate first.
Window doesn’t open (Linux)
Check that $DISPLAY is set: echo $DISPLAY. If empty, you’re in a headless session — the script will fall back to saving a PNG automatically.
Window opens but animation is blank/gray The first few episodes can have blank frames if CartPole terminates in 1 step (the very first reset). This clears up after episode 2-3.
Training is very slow
With --speed slow every step triggers a window refresh (matplotlib pause(0.001)), which is expensive. Use --speed fast if you mainly care about the learning curves, not the animation.
UserWarning: This figure includes Axes that are not compatible with tight_layout
Harmless — the colorbar for the Q-table heatmap causes this. Doesn’t affect output.
Want to reproduce exact results from the course images
Use the same seeds as lab_deep_rl.py:
import numpy, random, torch
numpy.random.seed(42); torch.manual_seed(42); random.seed(42)