-
Notifications
You must be signed in to change notification settings - Fork 7
Open
Description
Hi @josejg — when using `quantize=False` with `fused=True`, the Triton kernels `_triton_momentum_kernel` and `_triton_adam_kernel` incorrectly cast optimizer states to `PARAM_DTYPE` before storing them back to memory. Since the underlying storage tensors are fp32, this introduces an fp32 → bf16 → fp32 roundtrip every step, accumulating significant precision loss. :-(
How to reproduce it
import torch
from flashoptim import FlashLion, FlashAdam, FlashSGD
def _seed():
torch.manual_seed(42)
torch.cuda.manual_seed(42)
def make_model(hidden: int = 256) -> torch.nn.Module:
    """A tiny bf16 model on CUDA."""
    _seed()
    # Two bias-free square linear layers, moved to CUDA in bfloat16.
    layers = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(2)]
    model = torch.nn.Sequential(*layers)
    return model.to(device="cuda", dtype=torch.bfloat16)
def run_steps(opt, model, n_steps: int = 50) -> None:
    """Run n_steps of random-gradient optimiser updates."""
    _seed()
    params = list(model.parameters())
    for _ in range(n_steps):
        # Fresh random gradients each step; the fixed seed makes the
        # sequence identical across calls.
        for param in params:
            param.grad = torch.randn_like(param)
        opt.step()
        opt.zero_grad()
def extract_state_flat(opt) -> dict[str, torch.Tensor]:
    """Flatten optimizer state into ``{"param{i}/{key}": fp32 CPU tensor}``.

    States exposing a ``materialize()`` method (quantized wrappers) are
    dequantized first; non-floating-point tensors (e.g. integer step
    counters) are skipped.
    """
    out = {}
    for i, state in enumerate(opt.state.values()):
        for k, v in state.items():
            if hasattr(v, "materialize"):
                v = v.materialize()
            # BUG FIX: is_floating_point must be CALLED — the bare bound
            # method is always truthy, so the original filter let integer
            # tensors slip through.
            if isinstance(v, torch.Tensor) and v.is_floating_point():
                out[f"param{i}/{k}"] = v.detach().float().cpu().flatten()
    return out
def compare(name: str, opt_cls, opt_kwargs: dict, n_steps: int = 50) -> bool:
    """Run the same optimiser fused and unfused, then diff states and weights.

    Builds two identical models (same seed), trains one with ``fused=True``
    and one with ``fused=False`` (both with ``quantize=False``), and prints
    the max absolute / relative difference for each optimizer-state tensor
    and each parameter. Returns True if any difference exceeds 1e-6.
    """
    model_fused = make_model()
    opt_fused = opt_cls(
        model_fused.parameters(),
        fused=True,
        quantize=False,
        master_weight_bits=None,
        **opt_kwargs,
    )
    run_steps(opt_fused, model_fused, n_steps)
    states_fused = extract_state_flat(opt_fused)
    # Re-seeded identical model: any divergence below is due to the
    # fused/unfused code paths, not initialization or gradients.
    model_unfused = make_model()
    opt_unfused = opt_cls(
        model_unfused.parameters(),
        fused=False,
        quantize=False,
        master_weight_bits=None,
        **opt_kwargs,
    )
    run_steps(opt_unfused, model_unfused, n_steps)
    states_unfused = extract_state_flat(opt_unfused)
    print(f"\n{'='*60}")
    print(f" {name} | quantize=False | {n_steps} steps")
    print(f"{'='*60}")
    any_diff = False
    # Compare optimizer-state tensors key by key (skip keys only one side has).
    for key in sorted(states_fused.keys()):
        if key not in states_unfused:
            continue
        f = states_fused[key]
        u = states_unfused[key]
        max_abs_diff = (f - u).abs().max().item()
        # 1e-12 guards against division by zero for exactly-zero entries.
        rel_diff = ((f - u).abs() / (u.abs() + 1e-12)).max().item()
        is_bad = max_abs_diff > 1e-6
        tag = " 💀💀💀" if is_bad else " ✅✅✅"
        if is_bad:
            any_diff = True
        print(
            f" {key:30s} max|diff|={max_abs_diff:.3e} "
            f"max_rel={rel_diff:.3e}{tag}"
        )
    # Compare the resulting parameters themselves, in fp32 on CPU.
    p_fused = [p.float().cpu() for p in model_fused.parameters()]
    p_unfused = [p.float().cpu() for p in model_unfused.parameters()]
    for i, (pf, pu) in enumerate(zip(p_fused, p_unfused)):
        d = (pf - pu).abs().max().item()
        tag = " 💀💀💀" if d > 1e-6 else " ✅✅✅"
        if d > 1e-6:
            any_diff = True
        print(f" {'param' + str(i) + '/weight':30s} max|diff|={d:.3e}{tag}")
    return any_diff
if __name__ == "__main__":
    # NOTE: scrape artifact ("Reactions are currently unavailable") fused onto
    # the final closing paren made this file unparseable; removed here.
    results = {}
    results["FlashLion"] = compare(
        "FlashLion",
        FlashLion,
        dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=0.01),
    )
    results["FlashSGD"] = compare(
        "FlashSGD (momentum=0.9)",
        FlashSGD,
        dict(lr=0.01, momentum=0.9, weight_decay=0.01),
    )
    results["FlashAdam"] = compare(
        "FlashAdam",
        FlashAdam,
        dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01),
    )
    # Summarize and exit nonzero if any optimiser diverged, so this script
    # can be used directly in CI to confirm the fix.
    failed = sorted(name for name, bad in results.items() if bad)
    print(f"\nDiverged: {failed if failed else 'none'}")
    raise SystemExit(1 if failed else 0)
Metadata
Assignees
Labels
No labels