Tutorial 4: KL Bound
Explore the KL quadratic approximation and its cubic remainder bound.
In [1]:
from pathlib import Path
import importlib.util
import subprocess
import sys
def in_colab():
    """Detect whether this notebook is executing in Google Colab."""
    # The ``google.colab`` module is only importable inside the Colab
    # runtime, so a failed import is a reliable negative signal.
    try:
        import google.colab  # type: ignore  # noqa: F401
    except ImportError:
        return False
    return True
def find_repo_root():
    """Locate the ghosts-of-softmax checkout.

    Walks upward from the current working directory looking for the
    repository layout (``src/ghosts`` next to ``tutorials`` or
    ``experiments``); on Colab, falls back to cloning the repository
    into ``/content``.

    Returns:
        Path to the repository root.

    Raises:
        RuntimeError: when no checkout is found and we are not on Colab.
    """
    here = Path.cwd().resolve()
    for candidate in (here, *here.parents):
        has_package = (candidate / "src" / "ghosts").exists()
        has_content = (candidate / "tutorials").exists() or (candidate / "experiments").exists()
        if has_package and has_content:
            return candidate
    if in_colab():
        repo = Path('/content/ghosts-of-softmax')
        if not repo.exists():
            # Shallow clone keeps the download small inside Colab.
            clone_cmd = [
                'git', 'clone', '--depth', '1',
                'https://github.com/piyush314/ghosts-of-softmax.git',
                str(repo),
            ]
            subprocess.run(clone_cmd, check=True)
        return repo
    raise RuntimeError(
        'Run this notebook from inside the ghosts-of-softmax repository, '
        'or open it in Google Colab so the setup cell can clone the repo automatically.'
    )
# Resolve the repository root once and expose the package source directory.
REPO = find_repo_root()
SRC = REPO / "src"
# Make ``import ghosts`` work without installing the package.
src_entry = str(SRC)
if src_entry not in sys.path:
    sys.path.insert(0, src_entry)
def load_module(name, relative_path):
    """Import a module from a file inside the repository.

    The module's directory is prepended to ``sys.path`` so that any
    sibling imports it performs can resolve, and the loaded module is
    registered in ``sys.modules`` under ``name``.

    Args:
        name: Name to register the module under in ``sys.modules``.
        relative_path: Path of the module file, relative to ``REPO``.

    Returns:
        The freshly executed module object.
    """
    module_path = REPO / relative_path
    parent_dir = str(module_path.parent)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)
    spec = importlib.util.spec_from_file_location(name, module_path)
    loaded = importlib.util.module_from_spec(spec)
    sys.modules[name] = loaded
    spec.loader.exec_module(loaded)
    return loaded
# Scratch area for files generated by the tutorial notebooks.
OUTPUT_ROOT = Path('/tmp') / 'ghosts-of-softmax-notebooks'
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
print(f"Repo root: {REPO}")
print(f"Notebook outputs: {OUTPUT_ROOT}")
Repo root: /home/runner/work/ghosts-of-softmax/ghosts-of-softmax Notebook outputs: /tmp/ghosts-of-softmax-notebooks
In [2]:
import numpy as np
import matplotlib.pyplot as plt
from ghosts.plotting import PALETTE, add_end_labels, add_subtitle, apply_plot_style, finish_figure
from ghosts.theory import (
softmax, computeAttentionKL,
computeVariance, computeSlopeSpread,
klRemainderBound, verifyBound,
)
# A small three-way attention example: base distribution alpha0 and slopes a.
alpha0 = softmax(np.array([[1.0, 2.0, 0.5]]))
a = np.array([[0.3, -0.5, 0.8]])
var0 = computeVariance(alpha0, a)
delta = computeSlopeSpread(a)

apply_plot_style(font_size=10, title_size=12, label_size=10, tick_size=9)

# Sweep tau and compare the exact KL with its second-order (quadratic)
# approximation and the cubic remainder bound.
taus = np.linspace(0, 1.5, 200)
kl_exact = [computeAttentionKL(alpha0, a, t) for t in taus]
kl_quad = [0.5 * t ** 2 * var0 for t in taus]
bound = [klRemainderBound(t, delta) for t in taus]

fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12, 4.3))

# Left panel: exact KL against the local quadratic model.
ax_left.plot(taus, kl_exact, color=PALETTE['red'], lw=2)
ax_left.plot(taus, kl_quad, '--', color=PALETTE['blue'], lw=1.7)
ax_left.set_xlabel(r'$\tau$')
ax_left.set_ylabel('KL divergence')
ax_left.set_title('The quadratic term is accurate only close to zero', loc='left', fontweight='bold')
add_subtitle(ax_left, 'Farther out, the local quadratic model understates the true KL.', fontsize=9)
ax_left.grid(True, alpha=0.25)
add_end_labels(ax_left, taus, [
    (kl_exact[-1], 'exact KL', PALETTE['red'], 'bold'),
    (kl_quad[-1], r'quadratic $\tau^2 \mathrm{Var}/2$', PALETTE['blue'], None),
], fontsize=8)

# Right panel: the approximation error against the cubic upper bound.
remainder = [abs(exact - quad) for exact, quad in zip(kl_exact, kl_quad)]
ax_right.plot(taus, remainder, color=PALETTE['red'], lw=2)
ax_right.plot(taus, bound, '--', color=PALETTE['blue'], lw=1.7)
ax_right.set_xlabel(r'$\tau$')
ax_right.set_ylabel('Magnitude')
ax_right.set_title('The cubic bound tracks the remainder from above', loc='left', fontweight='bold')
add_subtitle(ax_right, 'The sharp bound stays above the true error while keeping the same scale.', fontsize=9)
ax_right.grid(True, alpha=0.25)
add_end_labels(ax_right, taus, [
    (remainder[-1], 'remainder', PALETTE['red'], 'bold'),
    (bound[-1], r'cubic bound', PALETTE['blue'], None),
], fontsize=8)

finish_figure(fig)
plt.show()
In [3]:
# Stress-test the bound on random problems: the remainder must never exceed
# the cubic bound, and the ratio (remainder / bound) measures tightness.
# Fix: name the trial count once instead of hard-coding 100 in both the
# loop and the success message (which was an f-string with no placeholder).
N_TRIALS = 100
rng = np.random.default_rng(42)  # seeded for reproducibility
ratios = []
for _ in range(N_TRIALS):
    n = rng.integers(2, 8)               # random number of attention slots
    z = rng.standard_normal((1, n))      # random logits
    a_vec = rng.standard_normal((1, n))  # random slopes
    tau = rng.uniform(0.01, 0.5)
    result = verifyBound(softmax(z), a_vec, tau)
    ratios.append(result['ratio'])
    assert result['valid'], f"Bound violated! ratio={result['ratio']}"
print(f"All {N_TRIALS} tests passed.")
print(f"Tightness ratios: min={min(ratios):.4f}, max={max(ratios):.4f}, mean={np.mean(ratios):.4f}")
All 100 tests passed. Tightness ratios: min=0.0027, max=0.9964, mean=0.3270
In [4]:
# Probe the extremal Bernoulli probability p = (3 + sqrt(3)) / 6, where the
# remainder/bound ratio gets close to 1 (i.e. the bound is nearly tight).
p = (3 + np.sqrt(3)) / 6
alpha0 = np.array([[p, 1 - p]])
a = np.array([[0.0, 1.0]])

# Ratio of the true remainder to the cubic bound across small tau values.
taus = np.linspace(0.01, 0.3, 50)
ratios = [verifyBound(alpha0, a, t)['ratio'] for t in taus]

fig, ax = plt.subplots(figsize=(6.5, 4.2))
ax.plot(taus, ratios, color=PALETTE['red'], lw=2)
ax.axhline(1.0, color=PALETTE['blue'], ls='--', alpha=0.6)
ax.set_xlabel(r'$\tau$')
ax.set_ylabel('Remainder / Bound')
ax.set_title('At the Bernoulli extremum, the sharp bound is nearly tight', loc='left', fontweight='bold')
add_subtitle(ax, f'The extremal probability here is p={p:.4f}. Values near 1 mean the bound is tight.', fontsize=9)
ax.grid(True, alpha=0.25)
add_end_labels(ax, taus, [
    (ratios[-1], 'remainder / bound', PALETTE['red'], 'bold'),
    (1.0, 'bound = 1', PALETTE['blue'], None),
], fontsize=8)
finish_figure(fig)
plt.show()

print(f"Max ratio (tightness): {max(ratios):.4f}")
Max ratio (tightness): 1.0000