Tutorial 0: SGD Step Control with Batch JVP rho¶
This is the intended starting point for a new reader.
The goal is narrow: understand what changes when we take a standard SGD step and add a local rho-based controller.
In this notebook we compare:
- fixed SGD with three learning rates,
- rho-capped SGD with the same three learning rates,
- one rho-set SGD run driven entirely by local geometry.
We use the sklearn digits dataset as a lightweight stand-in for MNIST so the notebook stays fast and interactive.
from pathlib import Path
import importlib.util
import subprocess
import sys
def in_colab():
    """Report whether this notebook is executing inside Google Colab."""
    try:
        import google.colab  # type: ignore  # only importable on Colab
    except ImportError:
        return False
    return True
def find_repo_root():
    """Locate the ghosts-of-softmax checkout, cloning it on Colab if needed.

    Walks from the current working directory upward, looking for the repo
    layout: a ``src/ghosts`` package next to a ``tutorials`` or
    ``experiments`` directory.  On Colab the repo is shallow-cloned into
    /content instead.  Raises RuntimeError when neither strategy works.
    """
    here = Path.cwd().resolve()
    for candidate in (here, *here.parents):
        has_pkg = (candidate / "src" / "ghosts").exists()
        has_notebooks = (candidate / "tutorials").exists() or (candidate / "experiments").exists()
        if has_pkg and has_notebooks:
            return candidate
    if in_colab():
        clone_dir = Path('/content/ghosts-of-softmax')
        if not clone_dir.exists():
            clone_cmd = [
                'git', 'clone', '--depth', '1',
                'https://github.com/piyush314/ghosts-of-softmax.git',
                str(clone_dir),
            ]
            subprocess.run(clone_cmd, check=True)
        return clone_dir
    raise RuntimeError(
        'Run this notebook from inside the ghosts-of-softmax repository, '
        'or open it in Google Colab so the setup cell can clone the repo automatically.'
    )
# Make the repo's src/ package importable so `from ghosts...` works below.
REPO = find_repo_root()
SRC = REPO / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
def load_module(name, relative_path):
    """Import a repo file by path and register it in sys.modules as *name*.

    The file's parent directory is also added to sys.path so the loaded
    module can resolve its own sibling imports.
    """
    path = REPO / relative_path
    parent_dir = str(path.parent)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    spec = importlib.util.spec_from_file_location(name, path)
    loaded = importlib.util.module_from_spec(spec)
    sys.modules[name] = loaded
    spec.loader.exec_module(loaded)
    return loaded
# All notebook artifacts go under /tmp so the repo checkout stays clean.
OUTPUT_ROOT = Path('/tmp/ghosts-of-softmax-notebooks')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax Notebook outputs: /tmp/ghosts-of-softmax-notebooks
1. Setup¶
The default path is deliberately simple:
- one architecture (MLP),
- one seed,
- three base learning rates spanning safe, borderline, and unstable behavior,
- one optional multiseed block at the end.
The purpose of the first pass is not to exhaust the design space. It is to make the controller's effect visually obvious.
import math
import random
from dataclasses import dataclass
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from ghosts.plotting import add_end_labels, add_subtitle, apply_plot_style, finish_figure, format_percent_axis
def display_table(rows, columns=None, formats=None):
    """Render a list of dict rows as a minimal inline HTML table.

    columns defaults to the keys of the first row; formats maps a column
    name to a callable applied to that column's values before rendering.
    """
    if not rows:
        print("No rows to display.")
        return
    columns = list(rows[0].keys()) if columns is None else columns
    formats = formats or {}
    html = ['<table style="border-collapse:collapse">', '<thead><tr>']
    for col in columns:
        html.append(f'<th style="text-align:left;padding:4px 8px;border-bottom:1px solid #ccc">{col}</th>')
    html.append('</tr></thead><tbody>')
    for row in rows:
        html.append('<tr>')
        for col in columns:
            cell = row.get(col, '')
            if col in formats:
                cell = formats[col](cell)
            html.append(f'<td style="padding:4px 8px">{cell}</td>')
        html.append('</tr>')
    html.append('</tbody></table>')
    display(HTML(''.join(html)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from torch.func import functional_call, jvp
from torch.utils.data import DataLoader, TensorDataset
torch.set_num_threads(1)  # single-threaded CPU keeps the runs reproducible and laptop-friendly
DEVICE = torch.device("cpu")
ARCH_NAME = "MLP"
# Three base learning rates spanning safe, borderline, and unstable SGD behavior.
LOW_LR = 0.05
MID_LR = 0.50
HIGH_LR = 5.00
LRS = [LOW_LR, MID_LR, HIGH_LR]
EPOCHS = 24
BATCH_SIZE = 128
TARGET_R = 1.0  # target normalized step size r = tau / rho for the controller
SEED = 7
RUN_MULTI_SEED = False  # flip to True for the slower multiseed check in section 8
MULTI_SEEDS = [0, 1, 2, 3, 4]
# One color per optimizer mode, plus neutral greys for annotations.
PALETTE = {
    "sgd": "#E3120B",
    "rho_capped": "#006BA2",
    "rho_set": "#00843D",
    "dark": "#3D3D3D",
    "mid": "#767676",
    "light": "#D0D0D0",
}
def set_seed(seed: int):
    """Seed the Python, NumPy, and torch RNGs for a reproducible run."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def build_digits_loaders(seed: int, batch_size: int = 128):
    """Create train/test DataLoaders for the sklearn digits dataset.

    Pixel values are scaled from [0, 16] into [0, 1]; the 75/25 split is
    stratified by class and seeded so each run with the same seed sees the
    identical partition.
    """
    digits = load_digits()
    features = digits.data.astype(np.float32) / 16.0
    labels = digits.target.astype(np.int64)
    tr_x, te_x, tr_y, te_y = train_test_split(
        features, labels, test_size=0.25, random_state=seed, stratify=labels
    )
    train_loader = DataLoader(
        TensorDataset(torch.tensor(tr_x), torch.tensor(tr_y)),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        TensorDataset(torch.tensor(te_x), torch.tensor(te_y)),
        batch_size=256,
        shuffle=False,
    )
    return train_loader, test_loader
# Build the single-seed data split used by sections 4-7.
train_loader, test_loader = build_digits_loaders(SEED, BATCH_SIZE)
len(train_loader), len(test_loader)  # cell output: (number of train batches, number of test batches)
(11, 2)
2. Model¶
We use one small MLP. That keeps the story focused on the controller rather than on architecture differences.
Later tutorials can revisit the same idea for Adam, SGD with momentum, and other model families.
class MLP(nn.Module):
    """Three-layer GELU MLP mapping the 64 digit pixels to 10 class logits."""

    def __init__(self, width=128):
        super().__init__()
        self.fc1 = nn.Linear(64, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 10)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        hidden = F.gelu(self.fc2(hidden))
        return self.fc3(hidden)
def make_model():
    """Instantiate a fresh default-width MLP on the target device."""
    model = MLP()
    return model.to(DEVICE)
make_model()
MLP( (fc1): Linear(in_features=64, out_features=128, bias=True) (fc2): Linear(in_features=128, out_features=128, bias=True) (fc3): Linear(in_features=128, out_features=10, bias=True) )
3. Batch JVP estimate of rho¶
For each training batch we do four things:
- compute the gradient,
- normalize it into a direction,
- push that direction through the logits with a JVP,
- convert the largest directional logit spread into
rho = pi / Delta_a.
The key conceptual point is this: fixed SGD and rho-capped SGD start from the same direction. The controller only changes the proposed step length when that step would be too large relative to the local radius.
Direct comparison: plain SGD versus JVP-based rho control¶
A plain SGD step uses only the gradient norm:
loss.backward()
grads, grad_norm = grad_dict(model)
tau_raw = base_lr * grad_norm
eta_eff = base_lr
The JVP-based controller keeps the same gradient direction, but asks one extra question: how far can we move along that direction before we leave the local convergence radius?
direction, grad_norm = unit_direction_from_grads(grads)
rho_batch = batch_rho_jvp(model, xb, direction)
tau_cap = target_r * rho_batch
tau = min(tau_raw, tau_cap)
eta_eff = tau / max(grad_norm, 1e-12)
The batch JVP itself is computed by pushing the normalized step direction through the logits:
def batch_rho_jvp(model, xb, direction):
params = {name: param.detach() for name, param in model.named_parameters()}
tangents = {name: direction[name] for name in params}
_, jvp_out = jvp(lambda p: functional_call(model, p, (xb,)), (params,), (tangents,))
spread = (jvp_out.max(dim=1).values - jvp_out.min(dim=1).values).amax().item()
return math.pi / max(spread, 1e-12)
So the only difference from plain SGD is this: we compute a local directional radius rho_batch, then shorten the SGD proposal if its raw step length tau_raw is too large.
# NOTE(review): run_training below returns a plain dict of lists, not
# EpochStats instances — this dataclass documents the per-epoch schema but
# appears unused in this notebook; confirm before removing.
@dataclass
class EpochStats:
    # Per-epoch training diagnostics.
    train_loss: float  # mean batch loss over the epoch
    test_acc: float  # test accuracy measured after the epoch
    mean_r: float  # mean normalized step size tau / rho over the epoch
    max_r: float  # largest normalized step size seen in the epoch
    mean_eta: float  # mean effective learning rate over the epoch
    cap_fraction: float  # fraction of batches where the cap shortened the step
def evaluate(model: nn.Module, loader: DataLoader):
    """Return (mean cross-entropy loss, accuracy) of *model* over *loader*.

    Runs in eval mode with gradients disabled; both metrics are averaged
    over every example in the loader.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb)
            # Sum-reduced loss so the final division gives a true per-example mean.
            loss_sum += F.cross_entropy(logits, yb, reduction="sum").item()
            n_correct += (logits.argmax(dim=1) == yb).sum().item()
            n_seen += len(yb)
    return loss_sum / n_seen, n_correct / n_seen
def grad_dict(model: nn.Module):
    """Snapshot parameter gradients and their global L2 norm.

    Returns (dict mapping parameter name -> detached gradient clone, float
    norm).  Parameters without a gradient are skipped; the squared sum is
    floored at 1e-12 so the norm never hits exactly zero.
    """
    snapshot = {
        name: param.grad.detach().clone()
        for name, param in model.named_parameters()
        if param.grad is not None
    }
    sq_sum = sum(float((g * g).sum().item()) for g in snapshot.values())
    return snapshot, math.sqrt(max(sq_sum, 1e-12))
def unit_direction_from_grads(grads):
    """Turn raw gradients into a unit-norm descent (negative-gradient) direction.

    Returns (dict mapping name -> normalized negative gradient, float norm);
    the squared sum is floored at 1e-12 before the square root.
    """
    sq_sum = sum(float((g * g).sum().item()) for g in grads.values())
    norm = math.sqrt(max(sq_sum, 1e-12))
    direction = {name: g / -norm for name, g in grads.items()}
    return direction, norm
def batch_rho_jvp(model: nn.Module, xb: torch.Tensor, direction):
    """Estimate the local radius rho = pi / Delta_a along *direction*.

    Pushes the step direction through the logits with a single JVP, takes
    the largest per-example directional logit spread over the batch, and
    divides pi by it (spread floored at 1e-12 to avoid a zero divide).
    """
    params = {name: p.detach() for name, p in model.named_parameters()}

    def logits_of(pdict):
        return functional_call(model, pdict, (xb,))

    _, jvp_logits = jvp(logits_of, (params,), (direction,))
    per_example_spread = jvp_logits.max(dim=1).values - jvp_logits.min(dim=1).values
    delta_a = float(per_example_spread.max().item())
    return math.pi / max(delta_a, 1e-12)
def apply_sgd_update(model: nn.Module, eta_eff: float):
    """In-place SGD step: param <- param - eta_eff * grad (grad-less params skipped)."""
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                continue
            p.add_(p.grad, alpha=-eta_eff)
def run_training(mode: str, base_lr: float | None, seed: int, target_r: float = 1.0):
    """Train one MLP on digits under one of three step-control modes.

    Parameters:
        mode: "sgd" (fixed learning rate), "rho_capped" (SGD step shortened
            whenever it exceeds target_r * rho), or "rho_set" (step length
            set directly to target_r * rho).
        base_lr: base learning rate for "sgd"/"rho_capped"; None for "rho_set".
        seed: seeds the RNGs and the train/test split.
        target_r: target normalized step size r = tau / rho.

    Returns:
        dict of per-epoch lists: train_loss, test_acc, mean_r, max_r,
        mean_eta, cap_fraction.
    """
    assert mode in {"sgd", "rho_capped", "rho_set"}
    set_seed(seed)
    train_loader, test_loader = build_digits_loaders(seed, BATCH_SIZE)
    model = make_model()
    history = {key: [] for key in [
        "train_loss", "test_acc", "mean_r", "max_r", "mean_eta", "cap_fraction"
    ]}
    for epoch in range(EPOCHS):
        model.train()
        batch_losses = []
        batch_r = []
        batch_eta = []
        batch_capped = []
        for xb, yb in train_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            model.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward()
            # Fix: the norm returned by grad_dict was previously bound to
            # grad_norm and then immediately overwritten by the identical norm
            # from unit_direction_from_grads — a dead, shadowed value.  The
            # redundant duplicate norm reduction inside
            # unit_direction_from_grads is kept for tutorial clarity.
            grads, _ = grad_dict(model)
            direction, grad_norm = unit_direction_from_grads(grads)
            rho_batch = batch_rho_jvp(model, xb, direction)
            # tau is the proposed step length in parameter space.
            tau_raw = 0.0 if base_lr is None else base_lr * grad_norm
            tau_cap = target_r * rho_batch
            if mode == "sgd":
                tau = tau_raw
                eta_eff = float(base_lr)
                capped = 0.0
            elif mode == "rho_capped":
                # Keep the SGD proposal unless it overshoots the local radius.
                tau = min(tau_raw, tau_cap)
                eta_eff = tau / max(grad_norm, 1e-12)
                capped = float(tau_raw > tau_cap)
            else:
                # rho_set: the step length comes entirely from local geometry.
                tau = tau_cap
                eta_eff = tau / max(grad_norm, 1e-12)
                capped = 1.0
            r = tau / max(rho_batch, 1e-12)  # normalized step size actually taken
            apply_sgd_update(model, eta_eff)
            batch_losses.append(float(loss.item()))
            batch_r.append(r)
            batch_eta.append(eta_eff)
            batch_capped.append(capped)
        _, test_acc = evaluate(model, test_loader)
        history["train_loss"].append(float(np.mean(batch_losses)))
        history["test_acc"].append(float(test_acc))
        history["mean_r"].append(float(np.mean(batch_r)))
        history["max_r"].append(float(np.max(batch_r)))
        history["mean_eta"].append(float(np.mean(batch_eta)))
        history["cap_fraction"].append(float(np.mean(batch_capped)))
    return history
# The seven runs compared throughout: three fixed-SGD baselines, three
# rho-capped counterparts at the same base LRs, and one rho-set run that
# has no base learning rate at all (base_lr=None).
RUN_SPECS = [
    {"key": "sgd_low", "mode": "sgd", "base_lr": LOW_LR, "label": f"fixed SGD (LR={LOW_LR})"},
    {"key": "sgd_mid", "mode": "sgd", "base_lr": MID_LR, "label": f"fixed SGD (LR={MID_LR})"},
    {"key": "sgd_high", "mode": "sgd", "base_lr": HIGH_LR, "label": f"fixed SGD (LR={HIGH_LR})"},
    {"key": "rho_capped_low", "mode": "rho_capped", "base_lr": LOW_LR, "label": f"rho-capped SGD (LR={LOW_LR})"},
    {"key": "rho_capped_mid", "mode": "rho_capped", "base_lr": MID_LR, "label": f"rho-capped SGD (LR={MID_LR})"},
    {"key": "rho_capped_high", "mode": "rho_capped", "base_lr": HIGH_LR, "label": f"rho-capped SGD (LR={HIGH_LR})"},
    {"key": "rho_set", "mode": "rho_set", "base_lr": None, "label": "rho-set SGD"},
]
def final_summary_row(spec, hist):
    """Condense one run spec and its history into a single summary-table row."""
    row = {
        "mode": spec["mode"],
        "label": spec["label"],
        "base_lr": spec["base_lr"],
    }
    row["final_acc"] = hist["test_acc"][-1]
    row["peak_r"] = max(hist["max_r"])
    row["final_mean_eta"] = hist["mean_eta"][-1]
    row["final_cap_fraction"] = hist["cap_fraction"][-1]
    return row
4. Single-seed run¶
We now run all seven curves on the same dataset split and the same model.
The only difference between fixed SGD and rho-capped SGD is the step length on batches where the raw SGD proposal is too aggressive.
The rho-set run is different: it does not start from a chosen base learning rate at all. Its step length is set directly by the local geometry.
# Run all seven configurations on the same seed, then summarize their endpoints.
single_seed = {}
for spec in RUN_SPECS:
    single_seed[spec["key"]] = run_training(spec["mode"], spec["base_lr"], SEED, target_r=TARGET_R)
summary_rows = [final_summary_row(spec, single_seed[spec["key"]]) for spec in RUN_SPECS]
display_table(
    summary_rows,
    columns=["label", "base_lr", "final_acc", "peak_r", "final_mean_eta", "final_cap_fraction"],
    formats={
        # The rho-set run has no base LR, so render a dash instead of a number.
        "base_lr": lambda x: "—" if x is None else f"{x:.2f}",
        "final_acc": lambda x: f"{float(x):.3f}",
        "peak_r": lambda x: f"{float(x):.3f}",
        "final_mean_eta": lambda x: f"{float(x):.4f}",
        "final_cap_fraction": lambda x: f"{float(x):.3f}",
    },
)
| label | base_lr | final_acc | peak_r | final_mean_eta | final_cap_fraction |
|---|---|---|---|---|---|
| fixed SGD (LR=0.05) | 0.05 | 0.660 | 0.027 | 0.0500 | 0.000 |
| fixed SGD (LR=0.5) | 0.50 | 0.953 | 18.287 | 0.5000 | 0.000 |
| fixed SGD (LR=5.0) | 5.00 | 0.100 | 405057744658765602491465728.000 | 5.0000 | 0.000 |
| rho-capped SGD (LR=0.05) | 0.05 | 0.660 | 0.027 | 0.0500 | 0.000 |
| rho-capped SGD (LR=0.5) | 0.50 | 0.953 | 1.000 | 0.1422 | 1.000 |
| rho-capped SGD (LR=5.0) | 5.00 | 0.962 | 1.000 | 0.1655 | 1.000 |
| rho-set SGD | — | 0.962 | 1.000 | 0.1601 | 1.000 |
5. First look: loss and accuracy¶
Start with the two outputs that are easiest to read:
- training loss on a log scale,
- test accuracy on a linear scale.
Reading key:
- color = optimizer mode,
- line style = learning-rate choice,
- markers help when curves overlap.
The main question is whether the controller changes the optimization trajectory in a visible way.
# Section 5 figure: color encodes optimizer mode, line style encodes the
# learning-rate choice, marker shape repeats the mode for overlap legibility.
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)
epochs = np.arange(1, EPOCHS + 1)
style_map = {
    "sgd_low": {"color": PALETTE["sgd"], "ls": "-", "marker": "o", "alpha": 0.82},
    "sgd_mid": {"color": PALETTE["sgd"], "ls": "--", "marker": "o", "alpha": 0.82},
    "sgd_high": {"color": PALETTE["sgd"], "ls": "-.", "marker": "o", "alpha": 0.82},
    "rho_capped_low": {"color": PALETTE["rho_capped"], "ls": "-", "marker": "s", "alpha": 0.82},
    "rho_capped_mid": {"color": PALETTE["rho_capped"], "ls": "--", "marker": "s", "alpha": 0.82},
    "rho_capped_high": {"color": PALETTE["rho_capped"], "ls": "-.", "marker": "s", "alpha": 0.82},
    "rho_set": {"color": PALETTE["rho_set"], "ls": "-", "marker": "^", "alpha": 0.90},
}
fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
loss_labels = []
acc_labels = []
for spec in RUN_SPECS:
    key = spec["key"]
    hist = single_seed[key]
    style = style_map[key]
    # Left panel: training loss on a log scale.
    axes[0].semilogy(
        epochs, hist["train_loss"],
        color=style["color"], ls=style["ls"], lw=2.2,
        marker=style["marker"], ms=4.5, markevery=2, alpha=style["alpha"]
    )
    # Right panel: test accuracy on a linear scale.
    axes[1].plot(
        epochs, hist["test_acc"],
        color=style["color"], ls=style["ls"], lw=2.2,
        marker=style["marker"], ms=4.5, markevery=2, alpha=style["alpha"]
    )
    # Bold end label for the rho-set curve so it stands out from the six others.
    weight = "bold" if key == "rho_set" else None
    loss_labels.append((hist["train_loss"][-1], spec["label"], style["color"], weight))
    acc_labels.append((hist["test_acc"][-1], spec["label"], style["color"], weight))
axes[0].set_title("Geometry control prevents the high-LR loss blow-up", loc="left", fontweight="bold")
add_subtitle(axes[0], "The same optimizer becomes stable once the step is capped by the local radius.", fontsize=9)
axes[1].set_title("The controller preserves accuracy when fixed SGD collapses", loc="left", fontweight="bold")
add_subtitle(axes[1], "Low LR stays safe; high LR only works once the step is shortened.", fontsize=9)
axes[0].set_ylabel("training loss")
axes[1].set_ylabel("test accuracy")
axes[1].set_ylim(0.0, 1.02)
format_percent_axis(axes[1], xmax=1.0)
for ax in axes:
    ax.set_xlabel("epoch")
    ax.grid(True, alpha=0.25)
add_end_labels(axes[0], epochs, loss_labels, fontsize=7)
add_end_labels(axes[1], epochs, acc_labels, fontsize=7)
fig.suptitle("Fixed SGD versus geometry-controlled SGD", y=0.99, fontsize=12, fontweight="bold")
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()
6. Diagnostic view: effective learning rate and normalized step size¶
Okay, now that the visible optimization behavior is clear, let's look at the controller quantities directly.
These two panels answer a more specific question:
- how does the effective learning rate change over training,
- and how does the normalized step size evolve for all seven curves?
Both panels use a log scale. We also keep the curves slightly transparent so overlapping trajectories remain visible.
# Section 6 figure: controller diagnostics for the same seven runs.
# Both panels are log-scaled; reuses `epochs` and `style_map` from section 5.
fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
eta_labels = []
r_labels = []
for spec in RUN_SPECS:
    key = spec["key"]
    hist = single_seed[key]
    style = style_map[key]
    # Left panel: mean effective learning rate per epoch.
    axes[0].semilogy(
        epochs, hist["mean_eta"],
        color=style["color"], ls=style["ls"], lw=2.1,
        marker=style["marker"], ms=4.0, markevery=2, alpha=0.72
    )
    # Right panel: largest normalized step size r per epoch.
    axes[1].semilogy(
        epochs, hist["max_r"],
        color=style["color"], ls=style["ls"], lw=2.1,
        marker=style["marker"], ms=4.0, markevery=2, alpha=0.72
    )
    weight = "bold" if key == "rho_set" else None
    eta_labels.append((hist["mean_eta"][-1], spec["label"], style["color"], weight))
    r_labels.append((hist["max_r"][-1], spec["label"], style["color"], weight))
axes[0].set_title("The controller changes the effective learning rate", loc="left", fontweight="bold")
add_subtitle(axes[0], "High-LR runs are throttled back toward the same geometric target.", fontsize=9)
axes[1].set_title("The normalized step stays near the same boundary", loc="left", fontweight="bold")
add_subtitle(axes[1], "rho-capped and rho-set runs keep max r near the target line.", fontsize=9)
axes[0].set_ylabel("effective learning rate")
axes[1].set_ylabel("max r per epoch")
# Dotted reference line marking the controller's target normalized step size.
axes[1].axhline(TARGET_R, color=PALETTE["dark"], ls=":", lw=1.2)
for ax in axes:
    ax.set_xlabel("epoch")
    ax.grid(True, alpha=0.25, which="both")
add_end_labels(axes[0], epochs, eta_labels, fontsize=7)
add_end_labels(axes[1], epochs, r_labels, fontsize=7)
fig.suptitle("Controller diagnostics", y=0.99, fontsize=12, fontweight="bold")
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()
7. Summary table¶
Now that the visual story is clear, here is the same comparison in tabular form.
You should read the columns as follows:
- final_acc: where the run ends,
- peak_r: the largest normalized step seen during training,
- final_mean_eta: the average effective learning rate in the last epoch,
- final_cap_fraction: what fraction of steps were actively shortened in the last epoch.
# NOTE(review): this cell re-runs all seven trainings from scratch instead of
# reusing `single_seed` from section 4.  Because every run is fully seeded the
# numbers come out identical, but it doubles the notebook's compute —
# presumably kept so this cell can execute standalone; confirm before simplifying.
single_seed = {}
for spec in RUN_SPECS:
    single_seed[spec["key"]] = run_training(spec["mode"], spec["base_lr"], SEED, target_r=TARGET_R)
summary_rows = [final_summary_row(spec, single_seed[spec["key"]]) for spec in RUN_SPECS]
display_table(
    summary_rows,
    columns=["label", "base_lr", "final_acc", "peak_r", "final_mean_eta", "final_cap_fraction"],
    formats={
        # The rho-set run has no base LR, so render a dash instead of a number.
        "base_lr": lambda x: "—" if x is None else f"{x:.2f}",
        "final_acc": lambda x: f"{float(x):.3f}",
        "peak_r": lambda x: f"{float(x):.3f}",
        "final_mean_eta": lambda x: f"{float(x):.4f}",
        "final_cap_fraction": lambda x: f"{float(x):.3f}",
    },
)
| label | base_lr | final_acc | peak_r | final_mean_eta | final_cap_fraction |
|---|---|---|---|---|---|
| fixed SGD (LR=0.05) | 0.05 | 0.660 | 0.027 | 0.0500 | 0.000 |
| fixed SGD (LR=0.5) | 0.50 | 0.953 | 18.287 | 0.5000 | 0.000 |
| fixed SGD (LR=5.0) | 5.00 | 0.100 | 405057744658765602491465728.000 | 5.0000 | 0.000 |
| rho-capped SGD (LR=0.05) | 0.05 | 0.660 | 0.027 | 0.0500 | 0.000 |
| rho-capped SGD (LR=0.5) | 0.50 | 0.953 | 1.000 | 0.1422 | 1.000 |
| rho-capped SGD (LR=5.0) | 5.00 | 0.962 | 1.000 | 0.1655 | 1.000 |
| rho-set SGD | — | 0.962 | 1.000 | 0.1601 | 1.000 |
8. Optional multiseed check¶
Once the single-seed story is clear, you can flip on RUN_MULTI_SEED and rerun this section.
That turns the notebook from a teaching example into a small reproducibility check.
# Optional robustness check: repeat every run over MULTI_SEEDS and report
# median / IQR statistics instead of a single-seed snapshot.
if RUN_MULTI_SEED:
    multi_rows = []
    for spec in RUN_SPECS:
        rows = []
        for seed in MULTI_SEEDS:
            hist = run_training(spec["mode"], spec["base_lr"], seed, target_r=TARGET_R)
            rows.append(final_summary_row(spec, hist))
        # Aggregate the per-seed endpoint metrics for this configuration.
        accs = [row["final_acc"] for row in rows]
        peak_rs = [row["peak_r"] for row in rows]
        caps = [row["final_cap_fraction"] for row in rows]
        multi_rows.append({
            "run": spec["label"],
            "base_lr": "—" if spec["base_lr"] is None else f"{spec['base_lr']:.2f}",
            "acc_median": float(np.median(accs)),
            "acc_iqr": float(np.percentile(accs, 75) - np.percentile(accs, 25)),
            "peak_r_median": float(np.median(peak_rs)),
            "cap_fraction_median": float(np.median(caps)),
        })
    display_table(
        multi_rows,
        columns=["run", "base_lr", "acc_median", "acc_iqr", "peak_r_median", "cap_fraction_median"],
        formats={
            "acc_median": lambda x: f"{float(x):.3f}",
            "acc_iqr": lambda x: f"{float(x):.3f}",
            "peak_r_median": lambda x: f"{float(x):.3f}",
            "cap_fraction_median": lambda x: f"{float(x):.3f}",
        },
    )
else:
    print("RUN_MULTI_SEED is False. Flip it to True if you want the slower multiseed summary.")
RUN_MULTI_SEED is False. Flip it to True if you want the slower multiseed summary.
Where to go next¶
After this notebook, the intended sequence is:
- a momentum-SGD controller tutorial,
- an Adam controller tutorial,
- then the more theory-facing notebooks.