Tutorial 3: Binary Radius
A self-contained introduction to the binary exact radius and the multiclass lower bound.
In [1]:
from pathlib import Path
import importlib.util
import subprocess
import sys
def in_colab():
    """Return True when this notebook is running inside Google Colab."""
    # Colab is detected by whether its bundled package can be imported.
    colab_available = True
    try:
        import google.colab  # type: ignore  # noqa: F401
    except ImportError:
        colab_available = False
    return colab_available
def find_repo_root():
    """Locate the ghosts-of-softmax checkout that contains this notebook.

    Walks from the current working directory upward looking for the repo
    layout (``src/ghosts`` plus ``tutorials`` or ``experiments``).  On Colab,
    falls back to cloning the repository into ``/content``.  Raises
    ``RuntimeError`` when neither strategy applies.
    """
    here = Path.cwd().resolve()
    for candidate in (here, *here.parents):
        has_package = (candidate / "src" / "ghosts").exists()
        has_notebook_dirs = (
            (candidate / "tutorials").exists() or (candidate / "experiments").exists()
        )
        if has_package and has_notebook_dirs:
            return candidate
    if in_colab():
        repo = Path('/content/ghosts-of-softmax')
        if not repo.exists():
            # Shallow clone keeps the Colab setup fast.
            clone_cmd = [
                'git', 'clone', '--depth', '1',
                'https://github.com/piyush314/ghosts-of-softmax.git',
                str(repo),
            ]
            subprocess.run(clone_cmd, check=True)
        return repo
    raise RuntimeError(
        'Run this notebook from inside the ghosts-of-softmax repository, '
        'or open it in Google Colab so the setup cell can clone the repo automatically.'
    )
# Resolve the repo root once, then make ``src`` importable so the
# ``ghosts`` package resolves in later cells.
REPO = find_repo_root()
SRC = REPO / "src"
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
def load_module(name, relative_path):
    """Import a single repo source file as a module registered under ``name``.

    ``relative_path`` is resolved against ``REPO``; the file's directory is
    added to ``sys.path`` so its own sibling imports keep working.
    """
    source_path = REPO / relative_path
    source_dir = str(source_path.parent)
    if source_dir not in sys.path:
        sys.path.insert(0, source_dir)
    spec = importlib.util.spec_from_file_location(name, source_path)
    module = importlib.util.module_from_spec(spec)
    # Register in sys.modules before executing, matching the standard
    # importlib recipe so imports performed during execution can resolve it.
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module
# Keep notebook artifacts under /tmp so the repository checkout stays clean.
OUTPUT_ROOT = Path('/tmp/ghosts-of-softmax-notebooks')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax Notebook outputs: /tmp/ghosts-of-softmax-notebooks
In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from ghosts.plotting import PALETTE, add_end_labels, add_subtitle, apply_plot_style, finish_figure
from ghosts.radii import compute_rho_out
PI = np.pi
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)

# Sweep the logit gap and evaluate both radius formulas.
deltas = np.linspace(0.1, 10, 200)
rho_lower = PI / deltas                          # multiclass lower bound: pi / Delta
rho_exact = np.sqrt(deltas**2 + PI**2) / deltas  # exact binary radius: sqrt(Delta^2 + pi^2) / Delta

fig, ax = plt.subplots(figsize=(6.5, 4.2))
ax.plot(deltas, rho_lower, color=PALETTE['red'], lw=2)
ax.plot(deltas, rho_exact, '--', color=PALETTE['blue'], lw=2)
ax.set_xlabel(r'Logit gap $\Delta$')
ax.set_ylabel(r'Convergence radius $\rho$')
ax.set_title('The exact binary radius always stays above the multiclass lower bound', loc='left', fontweight='bold')
add_subtitle(ax, 'As predictions sharpen, both radii shrink and the gap between them narrows.', fontsize=9)
ax.set_ylim(0, 5)
ax.grid(True, alpha=0.25)
add_end_labels(ax, deltas, [
    (rho_lower[-1], r'$\pi/\Delta$', PALETTE['red'], 'bold'),
    # Label now matches the formula plotted above; it previously read
    # $\sqrt{\delta^2+\pi^2}/\Delta_a$, which used notation ($\delta$,
    # $\Delta_a$) that appears nowhere else in this figure.
    (rho_exact[-1], r'$\sqrt{\Delta^2+\pi^2}/\Delta$', PALETTE['blue'], None),
], fontsize=8)
finish_figure(fig)
plt.show()
In [3]:
# Uniform logits have a zero gap, so the radius pi/Delta blows up.
# NOTE(review): the printed value (~pi * 1e6) suggests compute_rho_out clamps
# a zero gap to a tiny positive floor — confirm in ghosts.radii.
uniform = torch.tensor([[0.0, 0.0]])
print(f"Uniform: rho = {compute_rho_out(uniform, gap='maxmin', reduce='mean'):.2f}")
# A large gap gives a confident prediction and a small radius.
confident = torch.tensor([[0.0, 10.0]])
print(f"Confident: rho = {compute_rho_out(confident, gap='maxmin', reduce='mean'):.4f}")
# reduce='per_sample' returns one radius per batch row instead of the mean.
batch = torch.tensor([[0.0, 1.0], [0.0, 5.0], [0.0, 10.0]])
per_sample = compute_rho_out(batch, gap='maxmin', reduce='per_sample')
print(f"Per-sample rho: {per_sample}")
Uniform: rho = 3141592.65 Confident: rho = 0.3142 Per-sample rho: tensor([3.1416, 0.6283, 0.3142])
In [4]:
def ce_loss(z, target=0):
    """Binary cross-entropy loss as a function of the logit ``z``.

    Parameters
    ----------
    z : float or array-like
        Logit (score in favour of class 1).
    target : int, default 0
        True class label: 0 yields log(1 + exp(-z)), 1 yields log(1 + exp(z)).

    Returns
    -------
    float or ndarray
        The cross-entropy loss; broadcasts over array inputs.

    Notes
    -----
    Uses ``np.logaddexp`` rather than ``np.log1p(np.exp(...))``: the naive
    form overflows to ``inf`` (with a RuntimeWarning) once the exponent
    exceeds ~709, while ``logaddexp`` stays finite for any logit and is
    identical to the naive form elsewhere.
    """
    return np.logaddexp(0.0, -z) if target == 0 else np.logaddexp(0.0, z)
# Expansion point: logit gap z0, with the lower-bound radius rho = pi / z0.
z0 = 2.0
rho = PI / z0
# Sweep step sizes tau three radii to either side of the expansion point.
taus = np.linspace(-3 * rho, 3 * rho, 500)
losses = [ce_loss(z0 + t) for t in taus]
# Second-order Taylor coefficients of the CE loss at z0 (target = 0):
# value L0, first derivative sigma(z0) - 1, second derivative sigma*(1 - sigma).
L0 = ce_loss(z0)
sig = 1 / (1 + np.exp(-z0))
dL = -(1 - sig)
d2L = sig * (1 - sig)
taylor2 = L0 + dL * taus + 0.5 * d2L * taus**2
fig, ax = plt.subplots(figsize=(6.5, 4.2))
ax.plot(taus, losses, color=PALETTE['red'], lw=2)
ax.plot(taus, taylor2, '--', color=PALETTE['blue'], lw=1.8)
# Dotted verticals mark the convergence radius on both sides.
ax.axvline(-rho, color=PALETTE['mid_gray'], ls=':', alpha=0.7)
ax.axvline(rho, color=PALETTE['mid_gray'], ls=':', alpha=0.7)
ax.text(rho * 1.05, 2.65, fr'$\rho \approx {rho:.2f}$', color=PALETTE['mid_gray'], fontsize=9)
ax.set_xlabel(r'Step size $\tau$')
ax.set_ylabel('Loss')
ax.set_title('Outside the radius, the quadratic surrogate stops tracking the true loss', loc='left', fontweight='bold')
add_subtitle(ax, f'Here the logit gap is {z0:.1f}, so the lower-bound radius is about {rho:.2f}.', fontsize=9)
ax.set_ylim(-0.1, 3)
ax.grid(True, alpha=0.25)
add_end_labels(ax, taus, [
    (losses[-1], 'true loss', PALETTE['red'], 'bold'),
    (taylor2[-1], 'quadratic Taylor', PALETTE['blue'], None),
], fontsize=8)
finish_figure(fig)
plt.show()