Tutorial 0: SGD Step Control with Batch JVP rho¶

This is the intended starting point for a new reader.

The goal is narrow: understand what changes when we take a standard SGD step and add a local rho-based controller.

In this notebook we compare:

  • fixed SGD with three learning rates,
  • rho-capped SGD with the same three learning rates,
  • one rho-set SGD run driven entirely by local geometry.

We use the sklearn digits dataset as a lightweight stand-in for MNIST so the notebook stays fast and interactive.

In [1]:
from pathlib import Path
import importlib.util
import subprocess
import sys


def in_colab():
    try:
        import google.colab  # type: ignore
        return True
    except ImportError:
        return False


def find_repo_root():
    cwd = Path.cwd().resolve()
    for base in [cwd, *cwd.parents]:
        if (base / "src" / "ghosts").exists() and (
            (base / "tutorials").exists() or (base / "experiments").exists()
        ):
            return base

    if in_colab():
        repo = Path('/content/ghosts-of-softmax')
        if not repo.exists():
            subprocess.run(
                [
                    'git', 'clone', '--depth', '1',
                    'https://github.com/piyush314/ghosts-of-softmax.git',
                    str(repo),
                ],
                check=True,
            )
        return repo

    raise RuntimeError(
        'Run this notebook from inside the ghosts-of-softmax repository, '
        'or open it in Google Colab so the setup cell can clone the repo automatically.'
    )


REPO = find_repo_root()
SRC = REPO / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))


def load_module(name, relative_path):
    path = REPO / relative_path
    module_dir = str(path.parent)
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


OUTPUT_ROOT = Path('/tmp/ghosts-of-softmax-notebooks')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax
Notebook outputs: /tmp/ghosts-of-softmax-notebooks

1. Setup¶

The default path is deliberately simple:

  • one architecture (MLP),
  • one seed,
  • three base learning rates spanning safe, borderline, and unstable behavior,
  • one optional multiseed block at the end.

The purpose of the first pass is not to exhaust the design space. It is to make the controller's effect visually obvious.

In [2]:
import math
import random
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from ghosts.plotting import add_end_labels, add_subtitle, apply_plot_style, finish_figure, format_percent_axis


def display_table(rows, columns=None, formats=None):
    if not rows:
        print("No rows to display.")
        return
    if columns is None:
        columns = list(rows[0].keys())
    formats = formats or {}
    parts = ['<table style="border-collapse:collapse">', '<thead><tr>']
    for col in columns:
        parts.append(f'<th style="text-align:left;padding:4px 8px;border-bottom:1px solid #ccc">{col}</th>')
    parts.append('</tr></thead><tbody>')
    for row in rows:
        parts.append('<tr>')
        for col in columns:
            value = row.get(col, '')
            if col in formats:
                value = formats[col](value)
            parts.append(f'<td style="padding:4px 8px">{value}</td>')
        parts.append('</tr>')
    parts.append('</tbody></table>')
    display(HTML(''.join(parts)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from torch.func import functional_call, jvp
from torch.utils.data import DataLoader, TensorDataset

torch.set_num_threads(1)

DEVICE = torch.device("cpu")
ARCH_NAME = "MLP"
LOW_LR = 0.05
MID_LR = 0.50
HIGH_LR = 5.00
LRS = [LOW_LR, MID_LR, HIGH_LR]
EPOCHS = 24
BATCH_SIZE = 128
TARGET_R = 1.0
SEED = 7

RUN_MULTI_SEED = False
MULTI_SEEDS = [0, 1, 2, 3, 4]

PALETTE = {
    "sgd": "#E3120B",
    "rho_capped": "#006BA2",
    "rho_set": "#00843D",
    "dark": "#3D3D3D",
    "mid": "#767676",
    "light": "#D0D0D0",
}


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def build_digits_loaders(seed: int, batch_size: int = 128):
    digits = load_digits()
    X = digits.data.astype(np.float32) / 16.0
    y = digits.target.astype(np.int64)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=seed, stratify=y
    )
    train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_ds = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
    return train_loader, test_loader


train_loader, test_loader = build_digits_loaders(SEED, BATCH_SIZE)
len(train_loader), len(test_loader)
Out[2]:
(11, 2)
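The cell above reports 11 training batches and 2 test batches. A quick sanity check, assuming sklearn's `train_test_split` rounds a fractional test split up and the sklearn digits dataset has its usual 1797 samples:

```python
import math

# train_test_split takes ceil(0.25 * 1797) = 450 test samples,
# leaving 1347 for training; a DataLoader without drop_last yields
# ceil(n / batch_size) batches.
n_total = 1797
n_test = math.ceil(0.25 * n_total)        # 450
n_train = n_total - n_test                # 1347
train_batches = math.ceil(n_train / 128)  # 11
test_batches = math.ceil(n_test / 256)    # 2
print(train_batches, test_batches)
```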

2. Model¶

We use one small MLP. That keeps the story focused on the controller rather than on architecture differences.

Later tutorials can revisit the same idea for Adam, SGD with momentum, and other model families.

In [3]:
class MLP(nn.Module):
    def __init__(self, width=128):
        super().__init__()
        self.fc1 = nn.Linear(64, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 10)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        return self.fc3(x)


def make_model():
    return MLP().to(DEVICE)


make_model()
Out[3]:
MLP(
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)

3. Batch JVP estimate of rho¶

For each training batch we do four things:

  1. compute the gradient,
  2. normalize it into a direction,
  3. push that direction through the logits with a JVP,
  4. convert the largest directional logit spread into rho = pi / Delta_a.

The key conceptual point is this: fixed SGD and rho-capped SGD start from the same direction. The controller only changes the proposed step length when that step would be too large relative to the local radius.
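As a concrete illustration of step 4 and the cap (with made-up numbers, not values from any run): if a unit-length move along the gradient direction spreads some batch element's logits by Delta_a = 0.5, the local radius is rho = pi / 0.5 ≈ 6.28, and with target_r = 1.0 any proposed step shorter than that passes through uncapped.

```python
import math

# Hypothetical directional logit spread per unit step.
delta_a = 0.5
rho = math.pi / max(delta_a, 1e-12)   # local radius, ~6.283
target_r = 1.0
tau_cap = target_r * rho              # longest step the controller allows

tau_raw = 2.0                         # a proposed SGD step length
tau = min(tau_raw, tau_cap)           # under the cap, so left untouched
print(f"rho={rho:.3f}, cap={tau_cap:.3f}, tau={tau:.3f}")
```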

Direct comparison: plain SGD versus JVP-based rho control¶

A plain SGD step uses only the gradient norm:

loss.backward()
grads, grad_norm = grad_dict(model)
tau_raw = base_lr * grad_norm
eta_eff = base_lr

The JVP-based controller keeps the same gradient direction, but asks one extra question: how far can we move along that direction before we leave the local convergence radius?

direction, grad_norm = unit_direction_from_grads(grads)
rho_batch = batch_rho_jvp(model, xb, direction)
tau_cap = target_r * rho_batch
tau = min(tau_raw, tau_cap)
eta_eff = tau / max(grad_norm, 1e-12)

The batch JVP itself is computed by pushing the normalized step direction through the logits:

def batch_rho_jvp(model, xb, direction):
    params = {name: param.detach() for name, param in model.named_parameters()}
    tangents = {name: direction[name] for name in params}
    _, jvp_out = jvp(lambda p: functional_call(model, p, (xb,)), (params,), (tangents,))
    spread = (jvp_out.max(dim=1).values - jvp_out.min(dim=1).values).amax().item()
    return math.pi / max(spread, 1e-12)

So the only difference from plain SGD is this: we compute a local directional radius rho_batch, then shorten the SGD proposal if its raw step length tau_raw is too large.
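To see why the JVP gives the directional logit change, consider the simplest possible case: for a purely linear "model" with logits = X @ W, the JVP along a parameter direction D is exactly X @ D. The spread-then-rho computation can then be sketched in plain Python (toy numbers, not the notebook's MLP):

```python
import math

def matvec_rows(X, W):
    # logits[i][k] = sum_j X[i][j] * W[j][k]
    return [[sum(x[j] * W[j][k] for j in range(len(W)))
             for k in range(len(W[0]))] for x in X]

# Toy linear map: the JVP of X @ W along direction D is X @ D.
X = [[1.0, 0.0], [0.5, 0.5]]                # two inputs, two features
D = [[0.2, -0.2, 0.0], [0.0, 0.4, -0.4]]    # parameter direction, 2x3

dlogits = matvec_rows(X, D)                 # directional logit change
spread = max(max(row) - min(row) for row in dlogits)
rho = math.pi / max(spread, 1e-12)
print(f"Delta_a={spread:.3f}, rho={rho:.3f}")
```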

In [4]:
@dataclass
class EpochStats:
    train_loss: float
    test_acc: float
    mean_r: float
    max_r: float
    mean_eta: float
    cap_fraction: float


def evaluate(model: nn.Module, loader: DataLoader):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            logits = model(xb)
            total_loss += F.cross_entropy(logits, yb, reduction="sum").item()
            correct += (logits.argmax(dim=1) == yb).sum().item()
            total += len(yb)
    return total_loss / total, correct / total


def grad_dict(model: nn.Module):
    grads = {}
    sq = 0.0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        g = param.grad.detach().clone()
        grads[name] = g
        sq += float((g * g).sum().item())
    norm = math.sqrt(max(sq, 1e-12))
    return grads, norm


def unit_direction_from_grads(grads):
    sq = sum(float((g * g).sum().item()) for g in grads.values())
    norm = math.sqrt(max(sq, 1e-12))
    return {name: -g / norm for name, g in grads.items()}, norm


def batch_rho_jvp(model: nn.Module, xb: torch.Tensor, direction):
    params = {name: param.detach() for name, param in model.named_parameters()}

    def f(pdict):
        return functional_call(model, pdict, (xb,))

    _, dlogits = jvp(f, (params,), (direction,))
    spread = dlogits.max(dim=1).values - dlogits.min(dim=1).values
    delta_a = float(spread.max().item())
    rho = math.pi / max(delta_a, 1e-12)
    return rho


def apply_sgd_update(model: nn.Module, eta_eff: float):
    with torch.no_grad():
        for param in model.parameters():
            if param.grad is not None:
                param.add_(param.grad, alpha=-eta_eff)


def run_training(mode: str, base_lr: float | None, seed: int, target_r: float = 1.0):
    assert mode in {"sgd", "rho_capped", "rho_set"}
    set_seed(seed)
    train_loader, test_loader = build_digits_loaders(seed, BATCH_SIZE)
    model = make_model()

    history = {key: [] for key in [
        "train_loss", "test_acc", "mean_r", "max_r", "mean_eta", "cap_fraction"
    ]}

    for epoch in range(EPOCHS):
        model.train()
        batch_losses = []
        batch_r = []
        batch_eta = []
        batch_capped = []

        for xb, yb in train_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)

            model.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward()

            grads, _ = grad_dict(model)  # norm is recomputed below
            direction, grad_norm = unit_direction_from_grads(grads)
            rho_batch = batch_rho_jvp(model, xb, direction)

            tau_raw = 0.0 if base_lr is None else base_lr * grad_norm
            tau_cap = target_r * rho_batch

            if mode == "sgd":
                tau = tau_raw
                eta_eff = float(base_lr)
                capped = 0.0
            elif mode == "rho_capped":
                tau = min(tau_raw, tau_cap)
                eta_eff = tau / max(grad_norm, 1e-12)
                capped = float(tau_raw > tau_cap)
            else:
                tau = tau_cap
                eta_eff = tau / max(grad_norm, 1e-12)
                capped = 1.0

            r = tau / max(rho_batch, 1e-12)
            apply_sgd_update(model, eta_eff)

            batch_losses.append(float(loss.item()))
            batch_r.append(r)
            batch_eta.append(eta_eff)
            batch_capped.append(capped)

        _, test_acc = evaluate(model, test_loader)
        history["train_loss"].append(float(np.mean(batch_losses)))
        history["test_acc"].append(float(test_acc))
        history["mean_r"].append(float(np.mean(batch_r)))
        history["max_r"].append(float(np.max(batch_r)))
        history["mean_eta"].append(float(np.mean(batch_eta)))
        history["cap_fraction"].append(float(np.mean(batch_capped)))

    return history


RUN_SPECS = [
    {"key": "sgd_low", "mode": "sgd", "base_lr": LOW_LR, "label": f"fixed SGD (LR={LOW_LR})"},
    {"key": "sgd_mid", "mode": "sgd", "base_lr": MID_LR, "label": f"fixed SGD (LR={MID_LR})"},
    {"key": "sgd_high", "mode": "sgd", "base_lr": HIGH_LR, "label": f"fixed SGD (LR={HIGH_LR})"},
    {"key": "rho_capped_low", "mode": "rho_capped", "base_lr": LOW_LR, "label": f"rho-capped SGD (LR={LOW_LR})"},
    {"key": "rho_capped_mid", "mode": "rho_capped", "base_lr": MID_LR, "label": f"rho-capped SGD (LR={MID_LR})"},
    {"key": "rho_capped_high", "mode": "rho_capped", "base_lr": HIGH_LR, "label": f"rho-capped SGD (LR={HIGH_LR})"},
    {"key": "rho_set", "mode": "rho_set", "base_lr": None, "label": "rho-set SGD"},
]


def final_summary_row(spec, hist):
    return {
        "mode": spec["mode"],
        "label": spec["label"],
        "base_lr": spec["base_lr"],
        "final_acc": hist["test_acc"][-1],
        "peak_r": max(hist["max_r"]),
        "final_mean_eta": hist["mean_eta"][-1],
        "final_cap_fraction": hist["cap_fraction"][-1],
    }

4. Single-seed run¶

We now run all seven curves on the same dataset split and the same model.

The only difference between fixed SGD and rho-capped SGD is the step length on batches where the raw SGD proposal is too aggressive.

The rho-set run is different: it does not start from a chosen base learning rate at all. Its step length is set directly by the local geometry.
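Numerically, the rho-set step works out like this (hypothetical values for `grad_norm` and `rho_batch`; the formulas are the ones from `run_training`):

```python
# Hypothetical per-batch quantities.
grad_norm = 2.0        # ||g|| from the gradient
rho_batch = 0.8        # local radius from the batch JVP
target_r = 1.0

tau = target_r * rho_batch               # step length fixed by geometry
eta_eff = tau / max(grad_norm, 1e-12)    # effective LR for this batch
print(f"tau={tau:.3f}, eta_eff={eta_eff:.3f}")
```

No base learning rate appears anywhere: a large gradient simply gets a proportionally smaller effective learning rate so the step length stays at `target_r * rho_batch`.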

In [5]:
single_seed = {}
for spec in RUN_SPECS:
    single_seed[spec["key"]] = run_training(spec["mode"], spec["base_lr"], SEED, target_r=TARGET_R)

summary_rows = [final_summary_row(spec, single_seed[spec["key"]]) for spec in RUN_SPECS]
display_table(
    summary_rows,
    columns=["label", "base_lr", "final_acc", "peak_r", "final_mean_eta", "final_cap_fraction"],
    formats={
        "base_lr": lambda x: "—" if x is None else f"{x:.2f}",
        "final_acc": lambda x: f"{float(x):.3f}",
        "peak_r": lambda x: f"{float(x):.3f}",
        "final_mean_eta": lambda x: f"{float(x):.4f}",
        "final_cap_fraction": lambda x: f"{float(x):.3f}",
    },
)
label                      base_lr   final_acc   peak_r      final_mean_eta   final_cap_fraction
fixed SGD (LR=0.05)        0.05      0.660       0.027       0.0500           0.000
fixed SGD (LR=0.5)         0.50      0.953       18.287      0.5000           0.000
fixed SGD (LR=5.0)         5.00      0.100       4.05e+26    5.0000           0.000
rho-capped SGD (LR=0.05)   0.05      0.660       0.027       0.0500           0.000
rho-capped SGD (LR=0.5)    0.50      0.953       1.000       0.1422           1.000
rho-capped SGD (LR=5.0)    5.00      0.962       1.000       0.1655           1.000
rho-set SGD                —         0.962       1.000       0.1601           1.000

5. First look: loss and accuracy¶

Start with the two outputs that are easiest to read:

  • training loss on a log scale,
  • test accuracy on a linear scale.

Reading key:

  • color = optimizer mode,
  • line style = learning-rate choice,
  • markers help when curves overlap.

The main question is whether the controller changes the optimization trajectory in a visible way.

In [6]:
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)

epochs = np.arange(1, EPOCHS + 1)
style_map = {
    "sgd_low": {"color": PALETTE["sgd"], "ls": "-", "marker": "o", "alpha": 0.82},
    "sgd_mid": {"color": PALETTE["sgd"], "ls": "--", "marker": "o", "alpha": 0.82},
    "sgd_high": {"color": PALETTE["sgd"], "ls": "-.", "marker": "o", "alpha": 0.82},
    "rho_capped_low": {"color": PALETTE["rho_capped"], "ls": "-", "marker": "s", "alpha": 0.82},
    "rho_capped_mid": {"color": PALETTE["rho_capped"], "ls": "--", "marker": "s", "alpha": 0.82},
    "rho_capped_high": {"color": PALETTE["rho_capped"], "ls": "-.", "marker": "s", "alpha": 0.82},
    "rho_set": {"color": PALETTE["rho_set"], "ls": "-", "marker": "^", "alpha": 0.90},
}

fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
loss_labels = []
acc_labels = []

for spec in RUN_SPECS:
    key = spec["key"]
    hist = single_seed[key]
    style = style_map[key]
    axes[0].semilogy(
        epochs, hist["train_loss"],
        color=style["color"], ls=style["ls"], lw=2.2,
        marker=style["marker"], ms=4.5, markevery=2, alpha=style["alpha"]
    )
    axes[1].plot(
        epochs, hist["test_acc"],
        color=style["color"], ls=style["ls"], lw=2.2,
        marker=style["marker"], ms=4.5, markevery=2, alpha=style["alpha"]
    )
    weight = "bold" if key == "rho_set" else None
    loss_labels.append((hist["train_loss"][-1], spec["label"], style["color"], weight))
    acc_labels.append((hist["test_acc"][-1], spec["label"], style["color"], weight))

axes[0].set_title("Geometry control prevents the high-LR loss blow-up", loc="left", fontweight="bold")
add_subtitle(axes[0], "The same optimizer becomes stable once the step is capped by the local radius.", fontsize=9)
axes[1].set_title("The controller preserves accuracy when fixed SGD collapses", loc="left", fontweight="bold")
add_subtitle(axes[1], "Low LR stays safe; high LR only works once the step is shortened.", fontsize=9)
axes[0].set_ylabel("training loss")
axes[1].set_ylabel("test accuracy")
axes[1].set_ylim(0.0, 1.02)
format_percent_axis(axes[1], xmax=1.0)

for ax in axes:
    ax.set_xlabel("epoch")
    ax.grid(True, alpha=0.25)

add_end_labels(axes[0], epochs, loss_labels, fontsize=7)
add_end_labels(axes[1], epochs, acc_labels, fontsize=7)
fig.suptitle("Fixed SGD versus geometry-controlled SGD", y=0.99, fontsize=12, fontweight="bold")
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()
[Figure: training loss (log scale) and test accuracy per epoch for all seven runs.]

6. Diagnostic view: effective learning rate and normalized step size¶

Now that the visible optimization behavior is clear, we can look at the controller quantities directly.

These two panels answer a more specific question:

  • how does the effective learning rate change over training,
  • and how does the normalized step size evolve for all seven curves?

Both panels use a log scale. We also keep the curves slightly transparent so overlapping trajectories remain visible.

In [7]:
fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
eta_labels = []
r_labels = []

for spec in RUN_SPECS:
    key = spec["key"]
    hist = single_seed[key]
    style = style_map[key]
    axes[0].semilogy(
        epochs, hist["mean_eta"],
        color=style["color"], ls=style["ls"], lw=2.1,
        marker=style["marker"], ms=4.0, markevery=2, alpha=0.72
    )
    axes[1].semilogy(
        epochs, hist["max_r"],
        color=style["color"], ls=style["ls"], lw=2.1,
        marker=style["marker"], ms=4.0, markevery=2, alpha=0.72
    )
    weight = "bold" if key == "rho_set" else None
    eta_labels.append((hist["mean_eta"][-1], spec["label"], style["color"], weight))
    r_labels.append((hist["max_r"][-1], spec["label"], style["color"], weight))

axes[0].set_title("The controller changes the effective learning rate", loc="left", fontweight="bold")
add_subtitle(axes[0], "High-LR runs are throttled back toward the same geometric target.", fontsize=9)
axes[1].set_title("The normalized step stays near the same boundary", loc="left", fontweight="bold")
add_subtitle(axes[1], "rho-capped and rho-set runs keep max r near the target line.", fontsize=9)
axes[0].set_ylabel("effective learning rate")
axes[1].set_ylabel("max r per epoch")
axes[1].axhline(TARGET_R, color=PALETTE["dark"], ls=":", lw=1.2)

for ax in axes:
    ax.set_xlabel("epoch")
    ax.grid(True, alpha=0.25, which="both")

add_end_labels(axes[0], epochs, eta_labels, fontsize=7)
add_end_labels(axes[1], epochs, r_labels, fontsize=7)
fig.suptitle("Controller diagnostics", y=0.99, fontsize=12, fontweight="bold")
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()
[Figure: mean effective learning rate and max r per epoch (both log scale), with the target-r reference line.]

7. Summary table¶

Now that the visual story is clear, here is the same comparison in tabular form.

You should read the columns as follows:

  • final_acc: where the run ends,
  • peak_r: the largest normalized step seen during training,
  • final_mean_eta: the average effective learning rate in the last epoch,
  • final_cap_fraction: what fraction of steps were actively shortened in the last epoch.
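As a small illustration of how `final_cap_fraction` is produced (toy flags, standing in for the per-batch `capped` values that `run_training` averages over the last epoch):

```python
# Hypothetical per-batch "capped" flags for one epoch of 11 batches:
# 1.0 where tau_raw exceeded tau_cap and the step was shortened.
batch_capped = [1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
cap_fraction = sum(batch_capped) / len(batch_capped)  # what np.mean computes
print(f"{cap_fraction:.3f}")
```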
In [8]:
# Reuse the single-seed histories computed in section 4; no need to retrain.

summary_rows = [final_summary_row(spec, single_seed[spec["key"]]) for spec in RUN_SPECS]
display_table(
    summary_rows,
    columns=["label", "base_lr", "final_acc", "peak_r", "final_mean_eta", "final_cap_fraction"],
    formats={
        "base_lr": lambda x: "—" if x is None else f"{x:.2f}",
        "final_acc": lambda x: f"{float(x):.3f}",
        "peak_r": lambda x: f"{float(x):.3f}",
        "final_mean_eta": lambda x: f"{float(x):.4f}",
        "final_cap_fraction": lambda x: f"{float(x):.3f}",
    },
)
label                      base_lr   final_acc   peak_r      final_mean_eta   final_cap_fraction
fixed SGD (LR=0.05)        0.05      0.660       0.027       0.0500           0.000
fixed SGD (LR=0.5)         0.50      0.953       18.287      0.5000           0.000
fixed SGD (LR=5.0)         5.00      0.100       4.05e+26    5.0000           0.000
rho-capped SGD (LR=0.05)   0.05      0.660       0.027       0.0500           0.000
rho-capped SGD (LR=0.5)    0.50      0.953       1.000       0.1422           1.000
rho-capped SGD (LR=5.0)    5.00      0.962       1.000       0.1655           1.000
rho-set SGD                —         0.962       1.000       0.1601           1.000

8. Optional multiseed check¶

Once the single-seed story is clear, you can flip on RUN_MULTI_SEED and rerun this section.

That turns the notebook from a teaching example into a small reproducibility check.
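The median/IQR summary in the cell below boils down to the following sketch (toy accuracies; the `percentile` helper mimics NumPy's default linear interpolation, here an assumption made explicit in pure Python):

```python
from statistics import median

# Toy final accuracies across 5 seeds (hypothetical numbers).
accs = [0.950, 0.962, 0.958, 0.945, 0.960]
acc_median = median(accs)

def percentile(xs, q):
    # Linear interpolation between order statistics, as np.percentile
    # does by default.
    xs = sorted(xs)
    pos = (len(xs) - 1) * q / 100.0
    lo, hi = int(pos), min(int(pos) + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

acc_iqr = percentile(accs, 75) - percentile(accs, 25)
print(f"median={acc_median:.3f}, iqr={acc_iqr:.3f}")
```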

In [9]:
if RUN_MULTI_SEED:
    multi_rows = []
    for spec in RUN_SPECS:
        rows = []
        for seed in MULTI_SEEDS:
            hist = run_training(spec["mode"], spec["base_lr"], seed, target_r=TARGET_R)
            rows.append(final_summary_row(spec, hist))
        accs = [row["final_acc"] for row in rows]
        peak_rs = [row["peak_r"] for row in rows]
        caps = [row["final_cap_fraction"] for row in rows]
        multi_rows.append({
            "run": spec["label"],
            "base_lr": "—" if spec["base_lr"] is None else f"{spec['base_lr']:.2f}",
            "acc_median": float(np.median(accs)),
            "acc_iqr": float(np.percentile(accs, 75) - np.percentile(accs, 25)),
            "peak_r_median": float(np.median(peak_rs)),
            "cap_fraction_median": float(np.median(caps)),
        })
    display_table(
        multi_rows,
        columns=["run", "base_lr", "acc_median", "acc_iqr", "peak_r_median", "cap_fraction_median"],
        formats={
            "acc_median": lambda x: f"{float(x):.3f}",
            "acc_iqr": lambda x: f"{float(x):.3f}",
            "peak_r_median": lambda x: f"{float(x):.3f}",
            "cap_fraction_median": lambda x: f"{float(x):.3f}",
        },
    )
else:
    print("RUN_MULTI_SEED is False. Flip it to True if you want the slower multiseed summary.")
RUN_MULTI_SEED is False. Flip it to True if you want the slower multiseed summary.

Where to go next¶

After this notebook, the intended sequence is:

  1. a momentum-SGD controller tutorial,
  2. an Adam controller tutorial,
  3. then the more theory-facing notebooks.