Skip to content

Fused kernels truncate optimizer states to PARAM_DTYPE when quantize=False #2

@Phoenix8215

Description

@Phoenix8215

Hi @josejg — when `quantize=False` is combined with `fused=True`, the Triton kernels `_triton_momentum_kernel` and `_triton_adam_kernel` incorrectly cast optimizer states to `PARAM_DTYPE` before storing them back to memory. Since the underlying storage tensors are fp32, this introduces an fp32→bf16→fp32 roundtrip every step, accumulating significant precision loss. :-(

How to see it

import torch

from flashoptim import FlashLion, FlashAdam, FlashSGD


def _seed():
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)


def make_model(hidden: int = 256) -> torch.nn.Module:
    """Build a deterministic two-layer bias-free linear stack in bf16 on CUDA."""
    _seed()
    layers = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(2)]
    model = torch.nn.Sequential(*layers)
    return model.to(device="cuda", dtype=torch.bfloat16)


def run_steps(opt, model, n_steps: int = 50):
    """Drive ``opt`` for ``n_steps`` updates, feeding fresh random gradients.

    The RNG is reseeded first so two calls with identically-initialized
    models see the same gradient sequence.
    """
    _seed()
    params = list(model.parameters())
    for _ in range(n_steps):
        for param in params:
            param.grad = torch.randn_like(param)
        opt.step()
        opt.zero_grad()


def extract_state_flat(opt) -> dict[str, torch.Tensor]:
    """Flatten every floating-point optimizer-state tensor into one dict.

    Keys are ``"param{i}/{state_key}"`` where ``i`` is the parameter's
    enumeration order within ``opt.state``.  States exposing a
    ``materialize()`` method (presumably quantized/lazy wrappers — confirm
    against the optimizer implementation) are realized first.  Each tensor
    is detached, upcast to fp32, moved to CPU, and flattened so values can
    be compared elementwise regardless of original device/dtype.
    """
    out: dict[str, torch.Tensor] = {}
    for i, (_param, state) in enumerate(opt.state.items()):
        for key, value in state.items():
            if hasattr(value, "materialize"):
                value = value.materialize()
            # BUG FIX: `value.is_floating_point` is a bound method and is
            # always truthy — it must be *called*.  Without the call,
            # integer-dtype tensors (e.g. step counters) leak into the
            # comparison output.
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                out[f"param{i}/{key}"] = value.detach().float().cpu().flatten()
    return out


def compare(name: str, opt_cls, opt_kwargs: dict, n_steps: int = 50):
    """Run fused and unfused variants of an optimizer identically, then report.

    Prints a per-state and per-parameter max-absolute-difference table and
    returns True if anything diverged by more than 1e-6.
    """
    def _train(use_fused: bool):
        # Identical model init + gradient stream for both variants.
        model = make_model()
        opt = opt_cls(
            model.parameters(),
            fused=use_fused,
            quantize=False,
            master_weight_bits=None,
            **opt_kwargs,
        )
        run_steps(opt, model, n_steps)
        return model, extract_state_flat(opt)

    model_fused, states_fused = _train(True)
    model_unfused, states_unfused = _train(False)

    bar = "=" * 60
    print("\n" + bar)
    print(f"  {name}  |  quantize=False  |  {n_steps} steps")
    print(bar)

    any_diff = False
    for key in sorted(states_fused):
        if key not in states_unfused:
            continue
        delta = (states_fused[key] - states_unfused[key]).abs()
        max_abs_diff = delta.max().item()
        rel_diff = (delta / (states_unfused[key].abs() + 1e-12)).max().item()

        is_bad = max_abs_diff > 1e-6
        any_diff = any_diff or is_bad
        tag = "  💀💀💀" if is_bad else "  ✅✅✅"
        print(
            f"  {key:30s}  max|diff|={max_abs_diff:.3e}  "
            f"max_rel={rel_diff:.3e}{tag}"
        )

    pairs = zip(model_fused.parameters(), model_unfused.parameters())
    for i, (pf, pu) in enumerate(pairs):
        d = (pf.float().cpu() - pu.float().cpu()).abs().max().item()
        is_bad = d > 1e-6
        any_diff = any_diff or is_bad
        tag = "  💀💀💀" if is_bad else "  ✅✅✅"
        label = "param" + str(i) + "/weight"
        print(f"  {label:30s}  max|diff|={d:.3e}{tag}")
    return any_diff


if __name__ == "__main__":
    results = {}

    results["FlashLion"] = compare(
        "FlashLion",
        FlashLion,
        dict(lr=1e-3, betas=(0.9, 0.99), weight_decay=0.01),
    )
    results["FlashSGD"] = compare(
        "FlashSGD (momentum=0.9)",
        FlashSGD,
        dict(lr=0.01, momentum=0.9, weight_decay=0.01),
    )
    results["FlashAdam"] = compare(
        "FlashAdam",
        FlashAdam,
        dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01),
    )

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions