Tutorial 4: KL Bound¶

Explore the KL quadratic approximation and its cubic remainder bound.

In [1]:
from pathlib import Path
import importlib.util
import subprocess
import sys


def in_colab():
    """Return True when running inside Google Colab, else False."""
    try:
        import google.colab  # type: ignore  # noqa: F401  (presence check only)
    except ImportError:
        return False
    return True


def find_repo_root():
    """Locate the ghosts-of-softmax repository root.

    Walks from the current working directory upward, returning the first
    directory that looks like the repo (has ``src/ghosts`` plus either
    ``tutorials`` or ``experiments``). On Colab, falls back to cloning the
    repository into ``/content``.

    Raises
    ------
    RuntimeError
        If no repo root is found and we are not running on Colab.
    """
    here = Path.cwd().resolve()
    for candidate in (here, *here.parents):
        has_package = (candidate / "src" / "ghosts").exists()
        has_layout = (candidate / "tutorials").exists() or (candidate / "experiments").exists()
        if has_package and has_layout:
            return candidate

    if in_colab():
        repo = Path('/content/ghosts-of-softmax')
        if not repo.exists():
            clone_cmd = [
                'git', 'clone', '--depth', '1',
                'https://github.com/piyush314/ghosts-of-softmax.git',
                str(repo),
            ]
            subprocess.run(clone_cmd, check=True)
        return repo

    raise RuntimeError(
        'Run this notebook from inside the ghosts-of-softmax repository, '
        'or open it in Google Colab so the setup cell can clone the repo automatically.'
    )


# Make the in-repo `src` package importable; idempotent on cell re-run.
REPO = find_repo_root()
SRC = REPO / "src"
src_entry = str(SRC)
if src_entry not in sys.path:
    sys.path.insert(0, src_entry)


def load_module(name, relative_path):
    """Import a Python source file from a path relative to the repo root.

    Parameters
    ----------
    name : str
        Name to register the module under in ``sys.modules``.
    relative_path : str
        Path of the ``.py`` file, relative to ``REPO``.

    Returns
    -------
    module
        The freshly executed module object.

    Raises
    ------
    ImportError
        If no import spec can be created for ``path`` (e.g. the file does
        not exist or is not loadable Python source).
    """
    path = REPO / relative_path
    # Put the file's directory on sys.path so its sibling imports resolve.
    module_dir = str(path.parent)
    if module_dir not in sys.path:
        sys.path.insert(0, module_dir)
    spec = importlib.util.spec_from_file_location(name, path)
    # spec_from_file_location returns None (or a spec without a loader) for
    # missing/non-source files; fail with a clear error instead of an
    # AttributeError on `spec.loader` below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module


# Scratch directory for notebook artifacts; safe to re-create on every run.
OUTPUT_ROOT = Path('/tmp/ghosts-of-softmax-notebooks')
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
for label, value in (("Repo root", REPO), ("Notebook outputs", OUTPUT_ROOT)):
    print(f"{label}: {value}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax
Notebook outputs: /tmp/ghosts-of-softmax-notebooks
In [2]:
import numpy as np
import matplotlib.pyplot as plt
from ghosts.plotting import PALETTE, add_end_labels, add_subtitle, apply_plot_style, finish_figure
from ghosts.theory import (
    softmax, computeAttentionKL,
    computeVariance, computeSlopeSpread,
    klRemainderBound, verifyBound,
)

# A 3-way attention distribution and slope vector for the demonstration.
probs0 = softmax(np.array([[1.0, 2.0, 0.5]]))
slopes = np.array([[0.3, -0.5, 0.8]])
var0 = computeVariance(probs0, slopes)
spread = computeSlopeSpread(slopes)
apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)

# Exact KL, its quadratic Taylor term, and the cubic remainder bound on a tau grid.
tau_grid = np.linspace(0, 1.5, 200)
kl_exact = [computeAttentionKL(probs0, slopes, tau) for tau in tau_grid]
kl_quad = [0.5 * tau**2 * var0 for tau in tau_grid]
cubic_bound = [klRemainderBound(tau, spread) for tau in tau_grid]

fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12, 4.3))

# Left panel: exact KL against its second-order approximation.
ax_left.plot(tau_grid, kl_exact, color=PALETTE['red'], lw=2)
ax_left.plot(tau_grid, kl_quad, '--', color=PALETTE['blue'], lw=1.7)
ax_left.set_xlabel(r'$\tau$')
ax_left.set_ylabel('KL divergence')
ax_left.set_title('The quadratic term is accurate only close to zero', loc='left', fontweight='bold')
add_subtitle(ax_left, 'Farther out, the local quadratic model understates the true KL.', fontsize=9)
ax_left.grid(True, alpha=0.25)
add_end_labels(ax_left, tau_grid, [
    (kl_exact[-1], 'exact KL', PALETTE['red'], 'bold'),
    (kl_quad[-1], r'quadratic $\tau^2 \mathrm{Var}/2$', PALETTE['blue'], None),
], fontsize=8)

# Right panel: the remainder |exact - quadratic| under the cubic bound.
remainder = [abs(exact - quad) for exact, quad in zip(kl_exact, kl_quad)]
ax_right.plot(tau_grid, remainder, color=PALETTE['red'], lw=2)
ax_right.plot(tau_grid, cubic_bound, '--', color=PALETTE['blue'], lw=1.7)
ax_right.set_xlabel(r'$\tau$')
ax_right.set_ylabel('Magnitude')
ax_right.set_title('The cubic bound tracks the remainder from above', loc='left', fontweight='bold')
add_subtitle(ax_right, 'The sharp bound stays above the true error while keeping the same scale.', fontsize=9)
ax_right.grid(True, alpha=0.25)
add_end_labels(ax_right, tau_grid, [
    (remainder[-1], 'remainder', PALETTE['red'], 'bold'),
    (cubic_bound[-1], r'cubic bound', PALETTE['blue'], None),
], fontsize=8)

finish_figure(fig)
plt.show()
No description has been provided for this image
In [3]:
# Randomized check that the cubic remainder bound holds across many random
# distributions, slope vectors, and temperatures.
N_TRIALS = 100  # single source of truth for trial count (was duplicated in range() and the message)
rng = np.random.default_rng(42)  # fixed seed for reproducibility
ratios = []
for _ in range(N_TRIALS):
    n = rng.integers(2, 8)               # distribution size in [2, 7]
    z = rng.standard_normal((1, n))      # random logits
    a_vec = rng.standard_normal((1, n))  # random slope vector
    tau = rng.uniform(0.01, 0.5)         # small-to-moderate temperature
    result = verifyBound(softmax(z), a_vec, tau)
    ratios.append(result['ratio'])
    assert result['valid'], f"Bound violated! ratio={result['ratio']}"

# Plain string where there is nothing to interpolate; f-string for the stats.
print(f"All {N_TRIALS} tests passed.")
print(f"Tightness ratios: min={min(ratios):.4f}, max={max(ratios):.4f}, mean={np.mean(ratios):.4f}")
All 100 tests passed.
Tightness ratios: min=0.0027, max=0.9964, mean=0.3270
In [4]:
# Bernoulli extremum: the two-outcome distribution with p = (3 + sqrt(3)) / 6,
# where the tutorial probes how tight the cubic bound gets.
p = (3 + np.sqrt(3)) / 6
alpha0 = np.array([[p, 1 - p]])
a = np.array([[0.0, 1.0]])

taus = np.linspace(0.01, 0.3, 50)
# Remainder-to-bound ratio at each temperature; values near 1 mean the bound is tight.
ratios = [verifyBound(alpha0, a, tau)['ratio'] for tau in taus]

fig, ax = plt.subplots(figsize=(6.5, 4.2))
ax.plot(taus, ratios, color=PALETTE['red'], lw=2)
ax.axhline(1.0, color=PALETTE['blue'], ls='--', alpha=0.6)  # reference line: ratio == 1
ax.set_xlabel(r'$\tau$')
ax.set_ylabel('Remainder / Bound')
ax.set_title('At the Bernoulli extremum, the sharp bound is nearly tight', loc='left', fontweight='bold')
add_subtitle(ax, f'The extremal probability here is p={p:.4f}. Values near 1 mean the bound is tight.', fontsize=9)
ax.grid(True, alpha=0.25)
add_end_labels(ax, taus, [
    (ratios[-1], 'remainder / bound', PALETTE['red'], 'bold'),
    (1.0, 'bound = 1', PALETTE['blue'], None),
], fontsize=8)
finish_figure(fig)
plt.show()

print(f"Max ratio (tightness): {max(ratios):.4f}")
No description has been provided for this image
Max ratio (tightness): 1.0000