Tutorial 3: Binary Radius

A self-contained introduction to the exact binary convergence radius and the multiclass lower bound on it.

In [1]:
import importlib.util
import subprocess
import sys
import tempfile
from pathlib import Path


def in_colab():
    """Report whether the kernel is running inside Google Colab.

    Detection is by import: ``google.colab`` is assumed to be importable only
    in a Colab runtime.

    Returns:
        True if ``google.colab`` imports, False on ``ImportError``.
    """
    try:
        import google.colab  # type: ignore  # noqa: F401
    except ImportError:
        return False
    else:
        return True


def find_repo_root():
    """Return the root directory of the ghosts-of-softmax checkout.

    Search order:
      1. The current working directory and each of its ancestors. The first one
         holding ``src/ghosts`` plus a ``tutorials/`` or ``experiments/``
         directory wins.
      2. Inside Google Colab, ``/content/ghosts-of-softmax``. It is
         shallow-cloned from GitHub first if that directory does not exist yet.

    Raises:
        RuntimeError: no checkout was found and we are not running in Colab.
        subprocess.CalledProcessError: the Colab ``git clone`` failed.
    """
    start = Path.cwd().resolve()

    def _is_checkout(candidate):
        # Needs the package directory AND at least one notebook/experiment directory.
        return (candidate / "src" / "ghosts").exists() and (
            (candidate / "tutorials").exists() or (candidate / "experiments").exists()
        )

    match = next((d for d in (start, *start.parents) if _is_checkout(d)), None)
    if match is not None:
        return match

    if not in_colab():
        raise RuntimeError(
            'Run this notebook from inside the ghosts-of-softmax repository, '
            'or open it in Google Colab so the setup cell can clone the repo automatically.'
        )

    colab_checkout = Path('/content/ghosts-of-softmax')
    if not colab_checkout.exists():
        clone_cmd = [
            'git', 'clone', '--depth', '1',
            'https://github.com/piyush314/ghosts-of-softmax.git',
            str(colab_checkout),
        ]
        subprocess.run(clone_cmd, check=True)
    return colab_checkout


# Resolve the checkout once, then expose its src/ directory on the import path
# (at most once) so `ghosts.*` imports resolve to this checkout without installing it.
REPO = find_repo_root()
SRC = REPO / "src"
if not any(entry == str(SRC) for entry in sys.path):
    sys.path.insert(0, str(SRC))


def load_module(name, relative_path, root=None):
    """Import a Python source file by path and register it as ``sys.modules[name]``.

    Args:
        name: key under which the module is stored in ``sys.modules``.
        relative_path: file location joined onto the base directory. Per
            pathlib semantics, an absolute path replaces the base entirely.
        root: optional base directory. Defaults to the repository root ``REPO``.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: importlib cannot build a loader for the path (for
            example, the file does not have a Python source suffix).
        Exception: whatever the module itself raises while executing. In that
            case ``sys.modules[name]`` is restored to what it was before the call.
    """
    base = REPO if root is None else Path(root)
    path = base / relative_path

    # Put the file's own directory on sys.path so its sibling imports resolve.
    module_dir = str(path.parent)
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)

    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        # Without this, module_from_spec(None) fails with an unhelpful AttributeError.
        raise ImportError(f"Cannot load {path} as a Python module", name=name, path=str(path))
    module = importlib.util.module_from_spec(spec)

    # Register before executing (as in the importlib docs recipe) so the module
    # can be found by name while its top-level code runs.
    missing = object()
    previous = sys.modules.get(name, missing)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind under `name`.
        if previous is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = previous
        raise
    return module


# Scratch directory for anything the notebook writes to disk. It is resolved
# through tempfile so it follows the platform's temp location (TMPDIR/TEMP/TMP
# or the OS default) instead of assuming a POSIX /tmp. On a default Linux setup
# this is still /tmp/ghosts-of-softmax-notebooks.
OUTPUT_ROOT = Path(tempfile.gettempdir()) / 'ghosts-of-softmax-notebooks'
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax
Notebook outputs: /tmp/ghosts-of-softmax-notebooks
In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from ghosts.plotting import PALETTE, add_end_labels, add_subtitle, apply_plot_style, finish_figure
from ghosts.radii import compute_rho_out

PI = np.pi
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)

# Sweep the logit gap Delta and evaluate both radius formulas on one grid.
deltas = np.linspace(0.1, 10, 200)
rho_lower = PI / deltas                          # multiclass lower bound: pi / Delta
rho_exact = np.sqrt(deltas**2 + PI**2) / deltas  # exact binary radius: sqrt(Delta^2 + pi^2) / Delta

# sqrt(Delta^2 + pi^2) > pi for every Delta > 0, so the exact radius lies strictly
# above the bound everywhere. As Delta grows, pi / Delta -> 0 while
# sqrt(1 + (pi / Delta)^2) -> 1, so the gap between the curves WIDENS
# (about 0.02 at Delta = 0.1, about 0.73 at Delta = 10).
assert np.all(rho_exact > rho_lower), "exact binary radius should dominate the lower bound"

fig, ax = plt.subplots(figsize=(6.5, 4.2))
ax.plot(deltas, rho_lower, color=PALETTE['red'], lw=2)
ax.plot(deltas, rho_exact, '--', color=PALETTE['blue'], lw=2)
ax.set_xlabel(r'Logit gap $\Delta$')
ax.set_ylabel(r'Convergence radius $\rho$')
ax.set_title('The exact binary radius always stays above the multiclass lower bound', loc='left', fontweight='bold')
add_subtitle(ax, 'As predictions sharpen, the lower bound keeps falling while the exact radius levels off near 1.', fontsize=9)
ax.set_ylim(0, 5)
ax.grid(True, alpha=0.25)
# Label each curve at its right-hand end with the formula it plots.
add_end_labels(ax, deltas, [
    (rho_lower[-1], r'$\pi/\Delta$', PALETTE['red'], 'bold'),
    (rho_exact[-1], r'$\sqrt{\Delta^2+\pi^2}/\Delta$', PALETTE['blue'], None),
], fontsize=8)
finish_figure(fig)
plt.show()
No description has been provided for this image
In [3]:
# Scalar radius for a uniform (zero-gap) prediction and for a confident one.
# NOTE(review): a zero gap should mean an unbounded radius. The printed value
# 3141592.65 equals pi / 1e-6, which suggests compute_rho_out floors the gap at
# a small epsilon. Confirm in ghosts.radii.
uniform = torch.tensor([[0.0, 0.0]])
confident = torch.tensor([[0.0, 10.0]])
for label, logits, fmt in (("Uniform", uniform, ".2f"), ("Confident", confident, ".4f")):
    rho_value = compute_rho_out(logits, gap='maxmin', reduce='mean')
    print(f"{label}: rho = {rho_value:{fmt}}")

# Per-sample radii shrink as the gap grows: pi/1, pi/5, pi/10.
batch = torch.tensor([[0.0, 1.0], [0.0, 5.0], [0.0, 10.0]])
per_sample = compute_rho_out(batch, gap='maxmin', reduce='per_sample')
print("Per-sample rho:", per_sample)
Uniform: rho = 3141592.65
Confident: rho = 0.3142
Per-sample rho: tensor([3.1416, 0.6283, 0.3142])
In [4]:
def ce_loss(z, target=0):
    """Binary cross-entropy written as a function of one logit difference.

    Suppose two logits differ by ``z = logit_0 - logit_1``. The cross-entropy
    is then ``log(1 + exp(-z))`` for class 0, and ``log(1 + exp(z))`` for any
    other ``target``.

    It is computed with ``np.logaddexp(0, x) = log(1 + exp(x))``. That stays
    finite where the naive ``np.log1p(np.exp(x))`` overflows to ``inf``
    (``x`` above roughly 709).

    Args:
        z: scalar or NumPy array of logit differences.
        target: 0 selects ``softplus(-z)``; any other value selects ``softplus(z)``.

    Returns:
        The loss, elementwise for array input.
    """
    return np.logaddexp(0.0, -z) if target == 0 else np.logaddexp(0.0, z)

# Expand the binary loss around a logit difference z0 and compare it with its
# second-order Taylor polynomial over a window of +/- 3 radii.
z0 = 2.0
rho = PI / z0  # gap-based radius: the same pi / gap rule compute_rho_out reports above

# NOTE(review): as a function of tau, ce_loss(z0 + tau) = log(1 + exp(-(z0 + tau)))
# has its nearest complex singularities at tau = -z0 +/- i*pi. Its Taylor series
# in tau therefore converges for |tau| < sqrt(z0**2 + pi**2) (about 3.72 here),
# not pi / z0 (about 1.57). By the formulas in the previous plot, pi / z0 is the
# lower bound for the perturbation z0 * (1 + tau). Confirm which
# parameterisation this tutorial intends.

taus = np.linspace(-3 * rho, 3 * rho, 500)
# ce_loss is composed of NumPy ufuncs, so the whole grid is evaluated in one call.
losses = ce_loss(z0 + taus)

# Second-order Taylor model of softplus(-z) at z0:
#   L'(z0) = -(1 - sigmoid(z0)),   L''(z0) = sigmoid(z0) * (1 - sigmoid(z0)).
L0 = ce_loss(z0)
sig = 1 / (1 + np.exp(-z0))
dL = -(1 - sig)
d2L = sig * (1 - sig)
taylor2 = L0 + dL * taus + 0.5 * d2L * taus**2

fig, ax = plt.subplots(figsize=(6.5, 4.2))
ax.plot(taus, losses, color=PALETTE['red'], lw=2)
ax.plot(taus, taylor2, '--', color=PALETTE['blue'], lw=1.8)
# Dotted markers at +/- rho delimit the region the title refers to.
ax.axvline(-rho, color=PALETTE['mid_gray'], ls=':', alpha=0.7)
ax.axvline(rho, color=PALETTE['mid_gray'], ls=':', alpha=0.7)
ax.text(rho * 1.05, 2.65, fr'$\rho \approx {rho:.2f}$', color=PALETTE['mid_gray'], fontsize=9)
ax.set_xlabel(r'Step size $\tau$')
ax.set_ylabel('Loss')
ax.set_title('Outside the radius, the quadratic surrogate stops tracking the true loss', loc='left', fontweight='bold')
add_subtitle(ax, f'Here the logit gap is {z0:.1f}, so the lower-bound radius is about {rho:.2f}.', fontsize=9)
ax.set_ylim(-0.1, 3)
ax.grid(True, alpha=0.25)
add_end_labels(ax, taus, [
    (losses[-1], 'true loss', PALETTE['red'], 'bold'),
    (taylor2[-1], 'quadratic Taylor', PALETTE['blue'], None),
], fontsize=8)
finish_figure(fig)
plt.show()
No description has been provided for this image