Tutorial 2: Momentum SGD + rho versus Fixed-LR Momentum SGD¶

This notebook does for momentum SGD what Tutorial 1 did for Adam.

We compare two ways to run the same optimizer at the same base learning rates:

  • plain fixed-LR momentum SGD,
  • the exact directional rho controller.

The exact controller should not look at the raw gradient alone. It should look at the actual step that momentum SGD is about to take:

  • build the upcoming momentum update,
  • measure its norm,
  • compute rho_a(v) along that same direction,
  • rescale only if the proposed step is too long.

The baseline is ordinary momentum SGD with a fixed learning rate and no controller. The key question is whether the exact controller keeps the realized step ratio r = tau / rho_a near the target while fixed-LR momentum overshoots.

In [1]:
from pathlib import Path
import importlib.util
import subprocess
import sys


def in_colab():
    try:
        import google.colab  # type: ignore
        return True
    except ImportError:
        return False


def find_repo_root():
    cwd = Path.cwd().resolve()
    for base in [cwd, *cwd.parents]:
        if (base / "src" / "ghosts").exists() and (
            (base / "tutorials").exists() or (base / "experiments").exists()
        ):
            return base

    if in_colab():
        repo = Path('/content/ghosts-of-softmax')
        if not repo.exists():
            subprocess.run(
                [
                    'git', 'clone', '--depth', '1',
                    'https://github.com/piyush314/ghosts-of-softmax.git',
                    str(repo),
                ],
                check=True,
            )
        return repo

    raise RuntimeError(
        'Run this notebook from inside the ghosts-of-softmax repository, '
        'or open it in Google Colab so the setup cell can clone the repo automatically.'
    )


REPO = find_repo_root()
SRC = REPO / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))


def load_module(name, relative_path):
    path = REPO / relative_path
    module_dir = str(path.parent)
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


OUTPUT_ROOT = Path('/tmp/ghosts-of-softmax-notebooks')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax
Notebook outputs: /tmp/ghosts-of-softmax-notebooks

1. Setup¶

As before, we keep the tutorial small:

  • one architecture (MLP),
  • one seed,
  • three momentum-SGD learning rates,
  • one optional multiseed block.

The learning rates are chosen to show three regimes:

  • 0.05: safe,
  • 0.5: large enough that the exact controller should engage,
  • 2.0: deliberately extreme.
In [2]:
import math
import random

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from ghosts.plotting import add_end_labels, add_subtitle, apply_plot_style, finish_figure, format_percent_axis


def display_table(rows, columns=None, formats=None):
    if not rows:
        print("No rows to display.")
        return
    if columns is None:
        columns = list(rows[0].keys())
    formats = formats or {}
    parts = ['<table style="border-collapse:collapse">', '<thead><tr>']
    for col in columns:
        parts.append(f'<th style="text-align:left;padding:4px 8px;border-bottom:1px solid #ccc">{col}</th>')
    parts.append('</tr></thead><tbody>')
    for row in rows:
        parts.append('<tr>')
        for col in columns:
            value = row.get(col, '')
            if col in formats:
                value = formats[col](value)
            parts.append(f'<td style="padding:4px 8px">{value}</td>')
        parts.append('</tr>')
    parts.append('</tbody></table>')
    display(HTML(''.join(parts)))


import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from torch.func import functional_call, jvp
from torch.utils.data import DataLoader, TensorDataset

torch.set_num_threads(1)

DEVICE = torch.device("cpu")
LOW_LR = 0.05
MID_LR = 0.5
HIGH_LR = 2.0
LRS = [LOW_LR, MID_LR, HIGH_LR]
EPOCHS = 20
BATCH_SIZE = 128
TARGET_R = 1.0
MOMENTUM = 0.9
SEED = 7

RUN_MULTI_SEED = False
MULTI_SEEDS = [0, 1, 2, 3, 4]

PALETTE = {
    "exact": "#006BA2",
    "fixed": "#E3120B",
    "dark": "#3D3D3D",
}


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

2. Model and data¶

We again use the digits dataset and a small MLP so the reader can focus on the controller logic rather than on model complexity.

In [3]:
class MLP(nn.Module):
    def __init__(self, width=128):
        super().__init__()
        self.fc1 = nn.Linear(64, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, 10)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        return self.fc3(x)


def make_model():
    return MLP().to(DEVICE)


def build_digits_loaders(seed: int, batch_size: int = 128):
    digits = load_digits()
    X = digits.data.astype(np.float32) / 16.0
    y = digits.target.astype(np.int64)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=seed, stratify=y
    )
    train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_ds = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
    return train_loader, test_loader


train_loader, test_loader = build_digits_loaders(SEED, BATCH_SIZE)
len(train_loader), len(test_loader)
Out[3]:
(11, 2)

3. The exact momentum-SGD + $\rho$ controller¶

For momentum SGD, the correct direction is not the raw gradient. It is the next momentum update actually proposed by the optimizer.

In the simplest momentum case,

$$ b_t = \mu b_{t-1} + g_t, \qquad u_{\mathrm{unit}} = b_t, $$

so the unit-learning-rate proposal is the next momentum buffer itself.

The code in this notebook is slightly more general: it also handles weight decay, dampening, and optional Nesterov momentum. In all cases, the geometric controller uses the optimizer's actual proposed update u_unit.
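For reference, the general proposal assembled by momentum_unit_step, with weight decay $\lambda$, dampening $d$, and momentum $\mu$, is

$$ \tilde g_t = g_t + \lambda \theta_t, \qquad b_t = \mu b_{t-1} + (1 - d)\,\tilde g_t, \qquad u_{\mathrm{unit}} = \begin{cases} \tilde g_t + \mu b_t & \text{(Nesterov)}, \\ b_t & \text{(otherwise)}. \end{cases} $$

With $\lambda = d = 0$ and Nesterov disabled, this reduces to the simple case above.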

Once that proposal is available, the exact controller is

$$ \tau_{\mathrm{raw}} = \eta_{\mathrm{base}} \lVert u_{\mathrm{unit}} \rVert, \qquad v = -\frac{u_{\mathrm{unit}}}{\lVert u_{\mathrm{unit}} \rVert}, \qquad \rho_{\mathrm{batch}} = \rho_a(v), $$

$$ \eta_{\mathrm{eff}} = \min\!\left(\eta_{\mathrm{base}}, \frac{r_{\mathrm{target}}\,\rho_{\mathrm{batch}}}{\lVert u_{\mathrm{unit}} \rVert}\right). $$

So the optimizer keeps its momentum direction. Only the step length changes.

Code mapping for this notebook:

  • u_unit comes from momentum_unit_step(...)
  • ||u_unit|| is unit_norm
  • \rho_a(v) is computed by batch_rho_jvp(...)
  • eta_eff is written into SGD's parameter groups before opt.step()
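As a quick sanity check on the clamp formula, here is a minimal framework-free sketch (plain Python; the names eta_base, unit_norm, and rho_batch mirror the symbols above, and the numbers are made up for illustration):

```python
def effective_lr(eta_base, unit_norm, rho_batch, r_target=1.0):
    """Clamp the LR so the realized step tau = eta * ||u_unit||
    never exceeds r_target * rho_a along the proposal direction."""
    return min(eta_base, r_target * rho_batch / unit_norm)

# Safe regime: the clamp is inactive and the base LR passes through.
eta = effective_lr(eta_base=0.05, unit_norm=1.0, rho_batch=0.5)
assert eta == 0.05

# Aggressive regime: the proposed step (2.0 * 4.0 = 8.0) would exceed
# rho_a = 0.4, so the LR is cut to r_target * rho_a / ||u_unit|| = 0.1.
eta = effective_lr(eta_base=2.0, unit_norm=4.0, rho_batch=0.4)
assert eta == 0.1
```

This is exactly the `min(base_lr, TARGET_R * rho_a / unit_norm)` line used in run_training below.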

4. Fixed learning-rate momentum SGD¶

The baseline is ordinary momentum SGD with a fixed base learning rate and no rho controller.

opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=momentum)
opt.step()

This is the right comparison because it isolates what the controller changes:

  • the optimizer direction is still the momentum direction,
  • the exact controller only shortens the step when tau_raw > r_target * rho_a,
  • if the base learning rate is already safe, the exact controller behaves like fixed-LR momentum SGD.
In [4]:
def evaluate(model, loader):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_count = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            total_loss += float(loss.item()) * len(xb)
            total_correct += int((logits.argmax(dim=1) == yb).sum())
            total_count += len(xb)
    return total_loss / total_count, total_correct / total_count


def momentum_unit_step(model, opt):
    step_vecs = []
    sq = 0.0
    for group in opt.param_groups:
        momentum = float(group.get('momentum', 0.0))
        weight_decay = float(group.get('weight_decay', 0.0))
        dampening = float(group.get('dampening', 0.0))
        nesterov = bool(group.get('nesterov', False))

        for p in group['params']:
            if p.grad is None:
                z = torch.zeros(p.numel(), device=p.device, dtype=p.dtype)
                step_vecs.append(z)
                continue

            grad = p.grad.detach()
            if weight_decay != 0.0:
                grad = grad.add(p.detach(), alpha=weight_decay)

            if momentum != 0.0:
                state = opt.state[p]
                buf_prev = state.get('momentum_buffer', torch.zeros_like(p))
                buf_next = buf_prev * momentum + grad * (1.0 - dampening)
                update = grad + momentum * buf_next if nesterov else buf_next
            else:
                update = grad

            step_vecs.append(update.flatten())
            sq += float((update * update).sum().item())

    step_vec = torch.cat(step_vecs)
    norm = math.sqrt(sq)
    return step_vec, norm


def batch_rho_jvp(model, xb, v_flat):
    params = dict(model.named_parameters())
    tangents = {}
    offset = 0
    for name, p in params.items():
        numel = p.numel()
        tangents[name] = v_flat[offset:offset + numel].view_as(p)
        offset += numel

    def fwd(pdict):
        return functional_call(model, pdict, (xb,))

    was_training = model.training
    model.eval()
    _, dlogits = jvp(fwd, (params,), (tangents,))
    if was_training:
        model.train()

    spread = dlogits.max(dim=1).values - dlogits.min(dim=1).values
    delta_a = float(spread.max().item())
    return math.pi / max(delta_a, 1e-12)


def run_training(mode: str, base_lr: float, seed: int):
    set_seed(seed)
    train_loader, test_loader = build_digits_loaders(seed, BATCH_SIZE)
    model = make_model()
    opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=MOMENTUM)

    history = {
        'train_loss': [],
        'test_acc': [],
        'max_r': [],
        'mean_r': [],
        'mean_eff_lr': [],
        'mean_rho': [],
    }

    for _ in range(EPOCHS):
        model.train()
        batch_losses = []
        batch_r = []
        batch_eff_lr = []
        batch_rho = []

        for xb, yb in train_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward()

            unit_step_vec, unit_norm = momentum_unit_step(model, opt)
            if unit_norm < 1e-12:
                eff_lr = float(base_lr)
                for group in opt.param_groups:
                    group['lr'] = eff_lr
                opt.step()
                rho_a = float('inf')
                r = 0.0
            else:
                v_dir = -unit_step_vec / unit_norm
                rho_a = batch_rho_jvp(model, xb, v_dir)

                if mode == 'exact':
                    eff_lr = min(base_lr, TARGET_R * rho_a / unit_norm)
                elif mode == 'fixed':
                    eff_lr = float(base_lr)
                else:
                    raise ValueError(mode)

                for group in opt.param_groups:
                    group['lr'] = eff_lr
                opt.step()
                tau = eff_lr * unit_norm
                r = tau / rho_a if rho_a > 0 else float('inf')

            batch_losses.append(float(loss.item()))
            batch_r.append(float(r))
            batch_eff_lr.append(float(eff_lr))
            batch_rho.append(float(rho_a if math.isfinite(rho_a) else 0.0))

        _, test_acc = evaluate(model, test_loader)
        history['train_loss'].append(float(np.mean(batch_losses)))
        history['test_acc'].append(float(test_acc))
        history['max_r'].append(float(np.max(batch_r)))
        history['mean_r'].append(float(np.mean(batch_r)))
        history['mean_eff_lr'].append(float(np.mean(batch_eff_lr)))
        history['mean_rho'].append(float(np.mean(batch_rho)))

    return history


RUN_SPECS = [
    {'key': 'exact_low', 'mode': 'exact', 'base_lr': LOW_LR, 'label': f'exact momentum+rho (LR={LOW_LR})'},
    {'key': 'exact_mid', 'mode': 'exact', 'base_lr': MID_LR, 'label': f'exact momentum+rho (LR={MID_LR})'},
    {'key': 'exact_high', 'mode': 'exact', 'base_lr': HIGH_LR, 'label': f'exact momentum+rho (LR={HIGH_LR})'},
    {'key': 'fixed_low', 'mode': 'fixed', 'base_lr': LOW_LR, 'label': f'fixed-LR momentum SGD (LR={LOW_LR})'},
    {'key': 'fixed_mid', 'mode': 'fixed', 'base_lr': MID_LR, 'label': f'fixed-LR momentum SGD (LR={MID_LR})'},
    {'key': 'fixed_high', 'mode': 'fixed', 'base_lr': HIGH_LR, 'label': f'fixed-LR momentum SGD (LR={HIGH_LR})'},
]


def final_summary_row(spec, hist):
    return {
        'run': spec['label'],
        'final_acc': hist['test_acc'][-1],
        'peak_r': max(hist['max_r']),
        'final_mean_r': hist['mean_r'][-1],
        'final_mean_eff_lr': hist['mean_eff_lr'][-1],
        'final_mean_rho': hist['mean_rho'][-1],
    }

5. Single-seed run¶

We now run the exact controller and fixed-LR momentum SGD on the same split and model.

The exact controller should engage once the base learning rate is large enough that the momentum proposal would otherwise exceed the local radius.

In [5]:
single_seed = {}
for spec in RUN_SPECS:
    single_seed[spec['key']] = run_training(spec['mode'], spec['base_lr'], SEED)

summary_rows = [final_summary_row(spec, single_seed[spec['key']]) for spec in RUN_SPECS]
display_table(
    summary_rows,
    columns=['run', 'final_acc', 'peak_r', 'final_mean_r', 'final_mean_eff_lr', 'final_mean_rho'],
    formats={
        'final_acc': lambda x: f"{float(x):.3f}",
        'peak_r': lambda x: f"{float(x):.3f}",
        'final_mean_r': lambda x: f"{float(x):.3f}",
        'final_mean_eff_lr': lambda x: f"{float(x):.6f}",
        'final_mean_rho': lambda x: f"{float(x):.6f}",
    },
)
run                              | final_acc | peak_r             | final_mean_r | final_mean_eff_lr | final_mean_rho
exact momentum+rho (LR=0.05)     | 0.947     | 1.000              | 0.407        | 0.050000          | 0.104283
exact momentum+rho (LR=0.5)      | 0.978     | 1.000              | 1.000        | 0.113370          | 0.095027
exact momentum+rho (LR=2.0)      | 0.947     | 1.000              | 1.000        | 0.165223          | 0.125937
fixed-LR momentum SGD (LR=0.05)  | 0.947     | 1.128              | 0.408        | 0.050000          | 0.104073
fixed-LR momentum SGD (LR=0.5)   | 0.098     | 9972266370.698     | 0.017        | 0.500000          | 2.895905
fixed-LR momentum SGD (LR=2.0)   | 0.100     | 32254153843598.777 | 0.185        | 2.000000          | 2.923979

6. First look: loss and accuracy¶

Start with the same two outputs as the other controller tutorials:

  • training loss on a log scale,
  • test accuracy on a linear scale.

Reading key:

  • color = method (exact vs fixed LR),
  • line style = base learning rate,
  • markers help when curves overlap.
In [6]:
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)

epochs = np.arange(1, EPOCHS + 1)
style_map = {
    'exact_low': {'color': PALETTE['exact'], 'ls': '-', 'marker': 'o', 'alpha': 0.82},
    'exact_mid': {'color': PALETTE['exact'], 'ls': '--', 'marker': 'o', 'alpha': 0.82},
    'exact_high': {'color': PALETTE['exact'], 'ls': '-.', 'marker': 'o', 'alpha': 0.82},
    'fixed_low': {'color': PALETTE['fixed'], 'ls': '-', 'marker': 's', 'alpha': 0.82},
    'fixed_mid': {'color': PALETTE['fixed'], 'ls': '--', 'marker': 's', 'alpha': 0.82},
    'fixed_high': {'color': PALETTE['fixed'], 'ls': '-.', 'marker': 's', 'alpha': 0.82},
}

fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
loss_labels = []
acc_labels = []
for spec in RUN_SPECS:
    key = spec['key']
    hist = single_seed[key]
    style = style_map[key]
    axes[0].semilogy(
        epochs, hist['train_loss'],
        color=style['color'], ls=style['ls'], lw=2.2,
        marker=style['marker'], ms=4.5, markevery=2, alpha=style['alpha']
    )
    axes[1].plot(
        epochs, hist['test_acc'],
        color=style['color'], ls=style['ls'], lw=2.2,
        marker=style['marker'], ms=4.5, markevery=2, alpha=style['alpha']
    )
    weight = 'bold' if key.startswith('exact') else None
    loss_labels.append((hist['train_loss'][-1], spec['label'], style['color'], weight))
    acc_labels.append((hist['test_acc'][-1], spec['label'], style['color'], weight))

axes[0].set_title('Exact step control stabilizes momentum at aggressive LR', loc='left', fontweight='bold')
add_subtitle(axes[0], 'Both runs use momentum SGD; the controller only shortens the momentum proposal when needed.', fontsize=9)
axes[1].set_title('The exact controller keeps more of the final accuracy', loc='left', fontweight='bold')
add_subtitle(axes[1], 'Fixed-LR momentum overshoots once the momentum step exceeds the local radius.', fontsize=9)
axes[0].set_ylabel('training loss')
axes[1].set_ylabel('test accuracy')
axes[1].set_ylim(0.0, 1.02)
format_percent_axis(axes[1], xmax=1.0)
for ax in axes:
    ax.set_xlabel('epoch')
    ax.grid(True, alpha=0.25)

add_end_labels(axes[0], epochs, loss_labels, fontsize=7)
add_end_labels(axes[1], epochs, acc_labels, fontsize=7)
fig.suptitle('Exact momentum controller versus fixed-LR momentum SGD', y=0.99, fontsize=12, fontweight='bold')
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()

7. Diagnostic view: effective learning rate and realized step ratio¶

The left panel shows the effective learning rate. The right panel shows the realized normalized step ratio r = tau / rho_a measured along the actual momentum-step direction.

This is the panel where the exact controller should separate itself from fixed-LR momentum SGD.

In [7]:
fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
eta_labels = []
r_labels = []
for spec in RUN_SPECS:
    key = spec['key']
    hist = single_seed[key]
    style = style_map[key]
    axes[0].semilogy(
        epochs, hist['mean_eff_lr'],
        color=style['color'], ls=style['ls'], lw=2.1,
        marker=style['marker'], ms=4.0, markevery=2, alpha=0.72
    )
    axes[1].semilogy(
        epochs, hist['max_r'],
        color=style['color'], ls=style['ls'], lw=2.1,
        marker=style['marker'], ms=4.0, markevery=2, alpha=0.72
    )
    weight = 'bold' if key.startswith('exact') else None
    eta_labels.append((hist['mean_eff_lr'][-1], spec['label'], style['color'], weight))
    r_labels.append((hist['max_r'][-1], spec['label'], style['color'], weight))

axes[0].set_title('The controller adapts the effective LR to the local radius', loc='left', fontweight='bold')
add_subtitle(axes[0], 'At safe LR the exact and fixed runs coincide; at larger LR the controller backs off.', fontsize=9)
axes[1].set_title('Exact control keeps the realized step close to the target boundary', loc='left', fontweight='bold')
add_subtitle(axes[1], 'Fixed-LR momentum can still push r far above 1.', fontsize=9)
axes[0].set_ylabel('effective learning rate')
axes[1].set_ylabel('max r per epoch')
axes[1].axhline(TARGET_R, color=PALETTE['dark'], ls=':', lw=1.2)
for ax in axes:
    ax.set_xlabel('epoch')
    ax.grid(True, alpha=0.25, which='both')

add_end_labels(axes[0], epochs, eta_labels, fontsize=7)
add_end_labels(axes[1], epochs, r_labels, fontsize=7)
fig.suptitle('Momentum-controller diagnostics', y=0.99, fontsize=12, fontweight='bold')
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()

8. Summary table¶

The plot gives the visual story. The table below summarizes the same run numerically.

In [8]:
# Reuse the single-seed histories computed in Section 5 instead of retraining.

summary_rows = [final_summary_row(spec, single_seed[spec['key']]) for spec in RUN_SPECS]
display_table(
    summary_rows,
    columns=['run', 'final_acc', 'peak_r', 'final_mean_r', 'final_mean_eff_lr', 'final_mean_rho'],
    formats={
        'final_acc': lambda x: f"{float(x):.3f}",
        'peak_r': lambda x: f"{float(x):.3f}",
        'final_mean_r': lambda x: f"{float(x):.3f}",
        'final_mean_eff_lr': lambda x: f"{float(x):.6f}",
        'final_mean_rho': lambda x: f"{float(x):.6f}",
    },
)
run                              | final_acc | peak_r             | final_mean_r | final_mean_eff_lr | final_mean_rho
exact momentum+rho (LR=0.05)     | 0.947     | 1.000              | 0.407        | 0.050000          | 0.104283
exact momentum+rho (LR=0.5)      | 0.978     | 1.000              | 1.000        | 0.113370          | 0.095027
exact momentum+rho (LR=2.0)      | 0.947     | 1.000              | 1.000        | 0.165223          | 0.125937
fixed-LR momentum SGD (LR=0.05)  | 0.947     | 1.128              | 0.408        | 0.050000          | 0.104073
fixed-LR momentum SGD (LR=0.5)   | 0.098     | 9972266370.698     | 0.017        | 0.500000          | 2.895905
fixed-LR momentum SGD (LR=2.0)   | 0.100     | 32254153843598.777 | 0.185        | 2.000000          | 2.923979

9. Optional: multiseed check¶

If you want a more stable comparison, flip RUN_MULTI_SEED = True and rerun the next cell.

In [9]:
if RUN_MULTI_SEED:
    grouped_rows = []
    for spec in RUN_SPECS:
        rows = []
        for seed in MULTI_SEEDS:
            hist = run_training(spec['mode'], spec['base_lr'], seed)
            rows.append({
                'final_acc': hist['test_acc'][-1],
                'peak_r': max(hist['max_r']),
            })
        grouped_rows.append({
            'run': spec['label'],
            'final_acc_mean': float(np.mean([row['final_acc'] for row in rows])),
            'final_acc_std': float(np.std([row['final_acc'] for row in rows], ddof=1)) if len(rows) > 1 else 0.0,
            'peak_r_median': float(np.median([row['peak_r'] for row in rows])),
        })

    display_table(
        grouped_rows,
        columns=['run', 'final_acc_mean', 'final_acc_std', 'peak_r_median'],
        formats={
            'final_acc_mean': lambda x: f"{float(x):.3f}",
            'final_acc_std': lambda x: f"{float(x):.3f}",
            'peak_r_median': lambda x: f"{float(x):.3f}",
        },
    )
else:
    print('Set RUN_MULTI_SEED = True to execute the multiseed check.')
Set RUN_MULTI_SEED = True to execute the multiseed check.

10. What to remember¶

  • For momentum SGD, the exact controller must use the actual momentum proposal u_unit, not just the raw gradient.
  • Compute rho_a along the proposal direction v = -u_unit / ||u_unit||.
  • Set the effective learning rate to eta_eff = min(eta_base, r_target * rho_a / ||u_unit||), so the realized step length satisfies tau <= r_target * rho_a.
  • Relative to fixed-LR momentum SGD, the controller keeps the momentum direction and only shortens the step when the proposal is too long for the local radius.
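To make the recipe concrete outside of torch, here is a self-contained toy step in NumPy: a hypothetical rho_a callback stands in for batch_rho_jvp, and the momentum recursion matches the simple case b_t = mu * b_{t-1} + g_t (no weight decay, dampening, or Nesterov).

```python
import numpy as np


def controlled_momentum_step(theta, buf, grad, rho_a,
                             eta_base=2.0, mu=0.9, r_target=1.0):
    """One exact-controller step: build the momentum proposal, measure it,
    clamp the LR against the local radius, then apply the step."""
    buf = mu * buf + grad                      # next momentum buffer = u_unit
    unit_norm = np.linalg.norm(buf)
    v = -buf / unit_norm                       # unit direction of the proposed step
    eta = min(eta_base, r_target * rho_a(v) / unit_norm)
    return theta - eta * buf, buf, eta


theta = np.zeros(3)
buf = np.zeros(3)
grad = np.array([3.0, 0.0, 4.0])               # ||grad|| = 5

# Toy radius model for illustration: rho_a = 1.0 in every direction.
theta, buf, eta = controlled_momentum_step(theta, buf, grad, rho_a=lambda v: 1.0)
assert np.isclose(eta, 0.2)                    # clamp engages: eta = 1.0 / 5
assert np.isclose(np.linalg.norm(theta), 1.0)  # realized step length tau equals rho_a
```

With a large rho_a the min() would return eta_base unchanged, reproducing the safe regime where the exact runs coincide with fixed-LR momentum SGD.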