Tutorial 2: Momentum SGD + rho versus Fixed-LR Momentum SGD¶
This notebook does for momentum SGD what Tutorial 1 did for Adam.
We compare two ways to run the same optimizer at the same base learning rates:
- plain fixed-LR momentum SGD,
- the exact directional rho controller.
The exact controller should not look at the raw gradient alone. It should look at the actual step that momentum SGD is about to take:
- build the upcoming momentum update,
- measure its norm,
- compute rho_a(v) along that same direction,
- rescale only if the proposed step is too long.
The baseline is ordinary momentum SGD with a fixed learning rate and no controller. The key question is whether the exact controller keeps the realized step ratio r = tau / rho_a near the target while fixed-LR momentum overshoots.
from pathlib import Path
import importlib.util
import subprocess
import sys
def in_colab():
try:
import google.colab # type: ignore
return True
except ImportError:
return False
def find_repo_root():
cwd = Path.cwd().resolve()
for base in [cwd, *cwd.parents]:
if (base / "src" / "ghosts").exists() and (
(base / "tutorials").exists() or (base / "experiments").exists()
):
return base
if in_colab():
repo = Path('/content/ghosts-of-softmax')
if not repo.exists():
subprocess.run(
[
'git', 'clone', '--depth', '1',
'https://github.com/piyush314/ghosts-of-softmax.git',
str(repo),
],
check=True,
)
return repo
raise RuntimeError(
'Run this notebook from inside the ghosts-of-softmax repository, '
'or open it in Google Colab so the setup cell can clone the repo automatically.'
)
REPO = find_repo_root()
SRC = REPO / "src"
if str(SRC) not in sys.path:
sys.path.insert(0, str(SRC))
def load_module(name, relative_path):
path = REPO / relative_path
module_dir = str(path.parent)
if module_dir not in sys.path:
sys.path.insert(0, module_dir)
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module
spec.loader.exec_module(module)
return module
OUTPUT_ROOT = Path('/tmp/ghosts-of-softmax-notebooks')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax Notebook outputs: /tmp/ghosts-of-softmax-notebooks
1. Setup¶
As before, we keep the tutorial small:
- one architecture (MLP),
- one seed,
- three momentum-SGD learning rates,
- one optional multiseed block.
The learning rates are chosen to show three regimes:
- 0.05: safe,
- 0.5: large enough that the exact controller should engage,
- 2.0: deliberately extreme.
import math
import random
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from ghosts.plotting import add_end_labels, add_subtitle, apply_plot_style, finish_figure, format_percent_axis
def display_table(rows, columns=None, formats=None):
if not rows:
print("No rows to display.")
return
if columns is None:
columns = list(rows[0].keys())
formats = formats or {}
parts = ['<table style="border-collapse:collapse">', '<thead><tr>']
for col in columns:
parts.append(f'<th style="text-align:left;padding:4px 8px;border-bottom:1px solid #ccc">{col}</th>')
parts.append('</tr></thead><tbody>')
for row in rows:
parts.append('<tr>')
for col in columns:
value = row.get(col, '')
if col in formats:
value = formats[col](value)
parts.append(f'<td style="padding:4px 8px">{value}</td>')
parts.append('</tr>')
parts.append('</tbody></table>')
display(HTML(''.join(parts)))
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from torch.func import functional_call, jvp
from torch.utils.data import DataLoader, TensorDataset
torch.set_num_threads(1)
DEVICE = torch.device("cpu")
LOW_LR = 0.05
MID_LR = 0.5
HIGH_LR = 2.0
LRS = [LOW_LR, MID_LR, HIGH_LR]
EPOCHS = 20
BATCH_SIZE = 128
TARGET_R = 1.0
MOMENTUM = 0.9
SEED = 7
RUN_MULTI_SEED = False
MULTI_SEEDS = [0, 1, 2, 3, 4]
PALETTE = {
"exact": "#006BA2",
"fixed": "#E3120B",
"dark": "#3D3D3D",
}
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
2. Model and data¶
We again use the digits dataset and a small MLP so the reader can focus on the controller logic rather than on model complexity.
class MLP(nn.Module):
def __init__(self, width=128):
super().__init__()
self.fc1 = nn.Linear(64, width)
self.fc2 = nn.Linear(width, width)
self.fc3 = nn.Linear(width, 10)
def forward(self, x):
x = F.gelu(self.fc1(x))
x = F.gelu(self.fc2(x))
return self.fc3(x)
def make_model():
return MLP().to(DEVICE)
def build_digits_loaders(seed: int, batch_size: int = 128):
digits = load_digits()
X = digits.data.astype(np.float32) / 16.0
y = digits.target.astype(np.int64)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.25, random_state=seed, stratify=y
)
train_ds = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_ds = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=256, shuffle=False)
return train_loader, test_loader
train_loader, test_loader = build_digits_loaders(SEED, BATCH_SIZE)
len(train_loader), len(test_loader)
(11, 2)
3. The exact momentum-SGD + $\rho$ controller¶
For momentum SGD, the correct direction is not the raw gradient. It is the next momentum update actually proposed by the optimizer.
In the simplest momentum case,
$$ b_t = \mu b_{t-1} + g_t, \qquad u_{\mathrm{unit}} = b_t, $$
so the unit-learning-rate proposal is the next momentum buffer itself.
The code in this notebook is slightly more general: it also handles weight decay, dampening, and optional Nesterov momentum. In all cases, the geometric controller uses the optimizer's actual proposed update u_unit.
Once that proposal is available, the exact controller is
$$ \tau_{\mathrm{raw}} = \eta_{\mathrm{base}} \lVert u_{\mathrm{unit}} \rVert, \qquad v = -\frac{u_{\mathrm{unit}}}{\lVert u_{\mathrm{unit}} \rVert}, \qquad \rho_{\mathrm{batch}} = \rho_a(v), $$
$$ \eta_{\mathrm{eff}} = \min\!\left(\eta_{\mathrm{base}}, \frac{r_{\mathrm{target}}\,\rho_{\mathrm{batch}}}{\lVert u_{\mathrm{unit}} \rVert}\right). $$
So the optimizer keeps its momentum direction. Only the step length changes.
Code mapping for this notebook:
- u_unit comes from momentum_unit_step(...)
- ||u_unit|| is unit_norm
- rho_a(v) is computed by batch_rho_jvp(...)
- eta_eff is written into SGD's parameter groups before opt.step()
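The clamp itself can be checked with a tiny numeric sketch. The numbers below are illustrative, not taken from the runs in this notebook:

```python
# Hypothetical numbers to illustrate the exact controller's clamp.
eta_base = 0.5    # base learning rate
unit_norm = 4.0   # ||u_unit||, length of the unit-LR momentum proposal
rho_batch = 1.0   # rho_a(v) measured along v = -u_unit / ||u_unit||
r_target = 1.0

# tau_raw is the step length fixed-LR momentum SGD would realize.
tau_raw = eta_base * unit_norm  # 2.0, i.e. r = 2.0 > target

# The controller shortens the step just enough to hit the target ratio.
eta_eff = min(eta_base, r_target * rho_batch / unit_norm)  # 0.25
tau = eta_eff * unit_norm                                  # 1.0
print(tau_raw / rho_batch, tau / rho_batch)  # 2.0 1.0
```

When the base learning rate is already safe (say eta_base = 0.05 with the same proposal), the min leaves eta_eff = eta_base, which is why the exact and fixed runs coincide at the low learning rate.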
4. Fixed learning-rate momentum SGD¶
The baseline is ordinary momentum SGD with a fixed base learning rate and no rho controller.
opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=momentum)
opt.step()
This is the right comparison because it isolates what the controller changes:
- the optimizer direction is still the momentum direction,
- the exact controller only shortens the step when tau_raw > r_target * rho_a,
- if the base learning rate is already safe, the exact controller behaves like fixed-LR momentum SGD.
def evaluate(model, loader):
model.eval()
total_loss = 0.0
total_correct = 0
total_count = 0
with torch.no_grad():
for xb, yb in loader:
xb = xb.to(DEVICE)
yb = yb.to(DEVICE)
logits = model(xb)
loss = F.cross_entropy(logits, yb)
total_loss += float(loss.item()) * len(xb)
total_correct += int((logits.argmax(dim=1) == yb).sum())
total_count += len(xb)
return total_loss / total_count, total_correct / total_count
def momentum_unit_step(model, opt):
step_vecs = []
sq = 0.0
for group in opt.param_groups:
momentum = float(group.get('momentum', 0.0))
weight_decay = float(group.get('weight_decay', 0.0))
dampening = float(group.get('dampening', 0.0))
nesterov = bool(group.get('nesterov', False))
for p in group['params']:
if p.grad is None:
z = torch.zeros(p.numel(), device=p.device, dtype=p.dtype)
step_vecs.append(z)
continue
grad = p.grad.detach()
if weight_decay != 0.0:
grad = grad.add(p.detach(), alpha=weight_decay)
if momentum != 0.0:
state = opt.state[p]
buf_prev = state.get('momentum_buffer', torch.zeros_like(p))
buf_next = buf_prev * momentum + grad * (1.0 - dampening)
update = grad + momentum * buf_next if nesterov else buf_next
else:
update = grad
step_vecs.append(update.flatten())
sq += float((update * update).sum().item())
step_vec = torch.cat(step_vecs)
norm = math.sqrt(sq)
return step_vec, norm
def batch_rho_jvp(model, xb, v_flat):
params = dict(model.named_parameters())
tangents = {}
offset = 0
for name, p in params.items():
numel = p.numel()
tangents[name] = v_flat[offset:offset + numel].view_as(p)
offset += numel
def fwd(pdict):
return functional_call(model, pdict, (xb,))
was_training = model.training
model.eval()
_, dlogits = jvp(fwd, (params,), (tangents,))
if was_training:
model.train()
spread = dlogits.max(dim=1).values - dlogits.min(dim=1).values
delta_a = float(spread.max().item())
return math.pi / max(delta_a, 1e-12)
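For a linear model the JVP is available in closed form, which gives a quick way to sanity-check what batch_rho_jvp measures. The toy check below uses numpy only and is independent of the notebook's torch code; the shapes mirror the digits setup (64 inputs, 10 logits):

```python
import numpy as np

# Toy linear "model": logits = x @ W.T, so the JVP along tangent V is x @ V.T.
rng = np.random.default_rng(0)
W = rng.normal(size=(10, 64))
V = rng.normal(size=(10, 64))
V /= np.linalg.norm(V)            # unit direction in parameter space
xb = rng.normal(size=(5, 64))     # a small batch

dlogits = xb @ V.T                # exact directional derivative of the logits
spread = dlogits.max(axis=1) - dlogits.min(axis=1)
delta_a = spread.max()
rho_a = np.pi / max(delta_a, 1e-12)   # same formula as batch_rho_jvp

# Finite-difference check of the directional derivative.
eps = 1e-6
fd = ((xb @ (W + eps * V).T) - (xb @ (W - eps * V).T)) / (2 * eps)
assert np.allclose(fd, dlogits, atol=1e-6)
print(f"rho_a = {rho_a:.4f}")
```

The same finite-difference comparison works against torch.func.jvp on the MLP if you want to verify the notebook's implementation directly.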
def run_training(mode: str, base_lr: float, seed: int):
set_seed(seed)
train_loader, test_loader = build_digits_loaders(seed, BATCH_SIZE)
model = make_model()
opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=MOMENTUM)
history = {
'train_loss': [],
'test_acc': [],
'max_r': [],
'mean_r': [],
'mean_eff_lr': [],
'mean_rho': [],
}
for _ in range(EPOCHS):
model.train()
batch_losses = []
batch_r = []
batch_eff_lr = []
batch_rho = []
for xb, yb in train_loader:
xb = xb.to(DEVICE)
yb = yb.to(DEVICE)
opt.zero_grad(set_to_none=True)
logits = model(xb)
loss = F.cross_entropy(logits, yb)
loss.backward()
unit_step_vec, unit_norm = momentum_unit_step(model, opt)
if unit_norm < 1e-12:
eff_lr = float(base_lr)
for group in opt.param_groups:
group['lr'] = eff_lr
opt.step()
rho_a = float('inf')
r = 0.0
else:
v_dir = -unit_step_vec / unit_norm
rho_a = batch_rho_jvp(model, xb, v_dir)
if mode == 'exact':
eff_lr = min(base_lr, TARGET_R * rho_a / unit_norm)
elif mode == 'fixed':
eff_lr = float(base_lr)
else:
raise ValueError(mode)
for group in opt.param_groups:
group['lr'] = eff_lr
opt.step()
tau = eff_lr * unit_norm
r = tau / rho_a if rho_a > 0 else float('inf')
batch_losses.append(float(loss.item()))
batch_r.append(float(r))
batch_eff_lr.append(float(eff_lr))
batch_rho.append(float(rho_a if math.isfinite(rho_a) else 0.0))
_, test_acc = evaluate(model, test_loader)
history['train_loss'].append(float(np.mean(batch_losses)))
history['test_acc'].append(float(test_acc))
history['max_r'].append(float(np.max(batch_r)))
history['mean_r'].append(float(np.mean(batch_r)))
history['mean_eff_lr'].append(float(np.mean(batch_eff_lr)))
history['mean_rho'].append(float(np.mean(batch_rho)))
return history
RUN_SPECS = [
{'key': 'exact_low', 'mode': 'exact', 'base_lr': LOW_LR, 'label': f'exact momentum+rho (LR={LOW_LR})'},
{'key': 'exact_mid', 'mode': 'exact', 'base_lr': MID_LR, 'label': f'exact momentum+rho (LR={MID_LR})'},
{'key': 'exact_high', 'mode': 'exact', 'base_lr': HIGH_LR, 'label': f'exact momentum+rho (LR={HIGH_LR})'},
{'key': 'fixed_low', 'mode': 'fixed', 'base_lr': LOW_LR, 'label': f'fixed-LR momentum SGD (LR={LOW_LR})'},
{'key': 'fixed_mid', 'mode': 'fixed', 'base_lr': MID_LR, 'label': f'fixed-LR momentum SGD (LR={MID_LR})'},
{'key': 'fixed_high', 'mode': 'fixed', 'base_lr': HIGH_LR, 'label': f'fixed-LR momentum SGD (LR={HIGH_LR})'},
]
def final_summary_row(spec, hist):
return {
'run': spec['label'],
'final_acc': hist['test_acc'][-1],
'peak_r': max(hist['max_r']),
'final_mean_r': hist['mean_r'][-1],
'final_mean_eff_lr': hist['mean_eff_lr'][-1],
'final_mean_rho': hist['mean_rho'][-1],
}
5. Single-seed run¶
We now run the exact controller and fixed-LR momentum SGD on the same split and model.
The exact controller should engage once the base learning rate is large enough that the momentum proposal would otherwise exceed the local radius.
single_seed = {}
for spec in RUN_SPECS:
single_seed[spec['key']] = run_training(spec['mode'], spec['base_lr'], SEED)
summary_rows = [final_summary_row(spec, single_seed[spec['key']]) for spec in RUN_SPECS]
display_table(
summary_rows,
columns=['run', 'final_acc', 'peak_r', 'final_mean_r', 'final_mean_eff_lr', 'final_mean_rho'],
formats={
'final_acc': lambda x: f"{float(x):.3f}",
'peak_r': lambda x: f"{float(x):.3f}",
'final_mean_r': lambda x: f"{float(x):.3f}",
'final_mean_eff_lr': lambda x: f"{float(x):.6f}",
'final_mean_rho': lambda x: f"{float(x):.6f}",
},
)
| run | final_acc | peak_r | final_mean_r | final_mean_eff_lr | final_mean_rho |
|---|---|---|---|---|---|
| exact momentum+rho (LR=0.05) | 0.947 | 1.000 | 0.407 | 0.050000 | 0.104283 |
| exact momentum+rho (LR=0.5) | 0.978 | 1.000 | 1.000 | 0.113370 | 0.095027 |
| exact momentum+rho (LR=2.0) | 0.947 | 1.000 | 1.000 | 0.165223 | 0.125937 |
| fixed-LR momentum SGD (LR=0.05) | 0.947 | 1.128 | 0.408 | 0.050000 | 0.104073 |
| fixed-LR momentum SGD (LR=0.5) | 0.098 | 9972266370.698 | 0.017 | 0.500000 | 2.895905 |
| fixed-LR momentum SGD (LR=2.0) | 0.100 | 32254153843598.777 | 0.185 | 2.000000 | 2.923979 |
6. First look: loss and accuracy¶
Start with the same two outputs as the other controller tutorials:
- training loss on a log scale,
- test accuracy on a linear scale.
Reading key:
- color = method (exact vs fixed LR),
- line style = base learning rate,
- markers help when curves overlap.
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)
epochs = np.arange(1, EPOCHS + 1)
style_map = {
'exact_low': {'color': PALETTE['exact'], 'ls': '-', 'marker': 'o', 'alpha': 0.82},
'exact_mid': {'color': PALETTE['exact'], 'ls': '--', 'marker': 'o', 'alpha': 0.82},
'exact_high': {'color': PALETTE['exact'], 'ls': '-.', 'marker': 'o', 'alpha': 0.82},
'fixed_low': {'color': PALETTE['fixed'], 'ls': '-', 'marker': 's', 'alpha': 0.82},
'fixed_mid': {'color': PALETTE['fixed'], 'ls': '--', 'marker': 's', 'alpha': 0.82},
'fixed_high': {'color': PALETTE['fixed'], 'ls': '-.', 'marker': 's', 'alpha': 0.82},
}
fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
loss_labels = []
acc_labels = []
for spec in RUN_SPECS:
key = spec['key']
hist = single_seed[key]
style = style_map[key]
axes[0].semilogy(
epochs, hist['train_loss'],
color=style['color'], ls=style['ls'], lw=2.2,
marker=style['marker'], ms=4.5, markevery=2, alpha=style['alpha']
)
axes[1].plot(
epochs, hist['test_acc'],
color=style['color'], ls=style['ls'], lw=2.2,
marker=style['marker'], ms=4.5, markevery=2, alpha=style['alpha']
)
weight = 'bold' if key.startswith('exact') else None
loss_labels.append((hist['train_loss'][-1], spec['label'], style['color'], weight))
acc_labels.append((hist['test_acc'][-1], spec['label'], style['color'], weight))
axes[0].set_title('Exact step control stabilizes momentum at aggressive LR', loc='left', fontweight='bold')
add_subtitle(axes[0], 'Both runs use momentum SGD; the controller only shortens the momentum proposal when needed.', fontsize=9)
axes[1].set_title('The exact controller keeps more of the final accuracy', loc='left', fontweight='bold')
add_subtitle(axes[1], 'Fixed-LR momentum overshoots once the momentum step exceeds the local radius.', fontsize=9)
axes[0].set_ylabel('training loss')
axes[1].set_ylabel('test accuracy')
axes[1].set_ylim(0.0, 1.02)
format_percent_axis(axes[1], xmax=1.0)
for ax in axes:
ax.set_xlabel('epoch')
ax.grid(True, alpha=0.25)
add_end_labels(axes[0], epochs, loss_labels, fontsize=7)
add_end_labels(axes[1], epochs, acc_labels, fontsize=7)
fig.suptitle('Exact momentum controller versus fixed-LR momentum SGD', y=0.99, fontsize=12, fontweight='bold')
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()
7. Diagnostic view: effective learning rate and realized step ratio¶
The left panel shows the effective learning rate.
The right panel shows the realized normalized step ratio
r = tau / rho_a measured along the actual momentum-step direction.
This is the panel where the exact controller should separate itself from fixed-LR momentum SGD.
fig, axes = plt.subplots(1, 2, figsize=(11.8, 4.4), sharex=True)
eta_labels = []
r_labels = []
for spec in RUN_SPECS:
key = spec['key']
hist = single_seed[key]
style = style_map[key]
axes[0].semilogy(
epochs, hist['mean_eff_lr'],
color=style['color'], ls=style['ls'], lw=2.1,
marker=style['marker'], ms=4.0, markevery=2, alpha=0.72
)
axes[1].semilogy(
epochs, hist['max_r'],
color=style['color'], ls=style['ls'], lw=2.1,
marker=style['marker'], ms=4.0, markevery=2, alpha=0.72
)
weight = 'bold' if key.startswith('exact') else None
eta_labels.append((hist['mean_eff_lr'][-1], spec['label'], style['color'], weight))
r_labels.append((hist['max_r'][-1], spec['label'], style['color'], weight))
axes[0].set_title('The controller adapts the effective LR to the local radius', loc='left', fontweight='bold')
add_subtitle(axes[0], 'At safe LR the exact and fixed runs coincide; at larger LR the controller backs off.', fontsize=9)
axes[1].set_title('Exact control keeps the realized step close to the target boundary', loc='left', fontweight='bold')
add_subtitle(axes[1], 'Fixed-LR momentum can still push r far above 1.', fontsize=9)
axes[0].set_ylabel('effective learning rate')
axes[1].set_ylabel('max r per epoch')
axes[1].axhline(TARGET_R, color=PALETTE['dark'], ls=':', lw=1.2)
for ax in axes:
ax.set_xlabel('epoch')
ax.grid(True, alpha=0.25, which='both')
add_end_labels(axes[0], epochs, eta_labels, fontsize=7)
add_end_labels(axes[1], epochs, r_labels, fontsize=7)
fig.suptitle('Momentum-controller diagnostics', y=0.99, fontsize=12, fontweight='bold')
finish_figure(fig, rect=[0, 0, 1, 0.94])
plt.show()
8. Summary table¶
The plot gives the visual story. The table below summarizes the same run numerically.
# Reuse the single-seed histories computed in Section 5; no retraining needed.
summary_rows = [final_summary_row(spec, single_seed[spec['key']]) for spec in RUN_SPECS]
display_table(
summary_rows,
columns=['run', 'final_acc', 'peak_r', 'final_mean_r', 'final_mean_eff_lr', 'final_mean_rho'],
formats={
'final_acc': lambda x: f"{float(x):.3f}",
'peak_r': lambda x: f"{float(x):.3f}",
'final_mean_r': lambda x: f"{float(x):.3f}",
'final_mean_eff_lr': lambda x: f"{float(x):.6f}",
'final_mean_rho': lambda x: f"{float(x):.6f}",
},
)
| run | final_acc | peak_r | final_mean_r | final_mean_eff_lr | final_mean_rho |
|---|---|---|---|---|---|
| exact momentum+rho (LR=0.05) | 0.947 | 1.000 | 0.407 | 0.050000 | 0.104283 |
| exact momentum+rho (LR=0.5) | 0.978 | 1.000 | 1.000 | 0.113370 | 0.095027 |
| exact momentum+rho (LR=2.0) | 0.947 | 1.000 | 1.000 | 0.165223 | 0.125937 |
| fixed-LR momentum SGD (LR=0.05) | 0.947 | 1.128 | 0.408 | 0.050000 | 0.104073 |
| fixed-LR momentum SGD (LR=0.5) | 0.098 | 9972266370.698 | 0.017 | 0.500000 | 2.895905 |
| fixed-LR momentum SGD (LR=2.0) | 0.100 | 32254153843598.777 | 0.185 | 2.000000 | 2.923979 |
9. Optional: multiseed check¶
If you want a more stable comparison, flip RUN_MULTI_SEED = True and rerun the next cell.
if RUN_MULTI_SEED:
grouped_rows = []
for spec in RUN_SPECS:
rows = []
for seed in MULTI_SEEDS:
hist = run_training(spec['mode'], spec['base_lr'], seed)
rows.append({
'final_acc': hist['test_acc'][-1],
'peak_r': max(hist['max_r']),
})
grouped_rows.append({
'run': spec['label'],
'final_acc_mean': float(np.mean([row['final_acc'] for row in rows])),
'final_acc_std': float(np.std([row['final_acc'] for row in rows], ddof=1)) if len(rows) > 1 else 0.0,
'peak_r_median': float(np.median([row['peak_r'] for row in rows])),
})
display_table(
grouped_rows,
columns=['run', 'final_acc_mean', 'final_acc_std', 'peak_r_median'],
formats={
'final_acc_mean': lambda x: f"{float(x):.3f}",
'final_acc_std': lambda x: f"{float(x):.3f}",
'peak_r_median': lambda x: f"{float(x):.3f}",
},
)
else:
print('Set RUN_MULTI_SEED = True to execute the multiseed check.')
Set RUN_MULTI_SEED = True to execute the multiseed check.
10. What to remember¶
- For momentum SGD, the exact controller must use the actual momentum proposal u_unit, not just the raw gradient.
- Then compute rho_a along the proposal direction v = -u_unit / ||u_unit||.
- Then set the effective learning rate to eta_eff = min(eta_base, r_target * rho_a / ||u_unit||), so the realized step length satisfies tau <= r_target * rho_a.
- Relative to fixed-LR momentum SGD, the controller keeps the momentum direction and only shortens the step when the proposal is too long for the local radius.
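Putting the pieces together, one controlled momentum step can be sketched in a few lines. This is a plain-numpy reduction of run_training's inner loop, with rho_batch treated as a given measurement rather than recomputed via the JVP:

```python
import numpy as np

def controlled_momentum_step(w, grad, buf, rho_batch, eta_base=0.5,
                             mu=0.9, r_target=1.0):
    """One exact-controller step; rho_batch stands in for batch_rho_jvp."""
    buf_next = mu * buf + grad                    # next momentum buffer (u_unit)
    unit_norm = np.linalg.norm(buf_next)          # ||u_unit||
    eta_eff = min(eta_base, r_target * rho_batch / unit_norm)
    w_next = w - eta_eff * buf_next               # keep direction, clamp length
    return w_next, buf_next, eta_eff

w = np.zeros(3)
g = np.array([3.0, 0.0, 4.0])                     # ||g|| = 5
w, buf, eta = controlled_momentum_step(w, g, np.zeros(3), rho_batch=1.0)
tau = eta * np.linalg.norm(buf)
print(eta, tau)  # 0.2 1.0 -- step length clamped to r_target * rho_batch
```

With an empty momentum buffer the proposal is just the gradient (norm 5), so the controller shrinks eta from 0.5 to 0.2 and the realized step length tau lands exactly on the target boundary.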