[Cherry-Pick][RL] Support moe_topk_select using Paddle native operators and Add fused stack-transpose-quant for BlockWiseFP8 MoE weight quantization and swiglu-fp8-quant op for DeepGemmFusedMoE for training alignment (#6850)#6935

Open
DanielSun11 wants to merge 11 commits into PaddlePaddle:release/2.5 from DanielSun11:topk_cp_2.5

Conversation

@DanielSun11
Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but their meaning must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If no unit tests are added, state the reason in this PR.
  • Provide accuracy results.
  • If this PR targets a release branch, make sure it has first been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 19, 2026 08:35
@paddle-bot

paddle-bot bot commented Mar 19, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment

Pull request overview

This PR enhances the MoE TopK selection and FP8 (including UE8M0 scale) quantization paths: it introduces moe_topk_select selection logic based on Paddle native operators on the DeepGemmFusedMoE side, adds/aligns fused operator integration on the Fleet side (stack + transpose + fp8 quant, and fused swiglu + fp8 quant with routed prob), and supplements the corresponding unit tests and environment switches.

Changes:

  • Added moe_topk_select (Paddle native) to the DeepGemm MoE path, enabled via FD_USE_PHI_TOPK.
  • Added fp8_utils.fused_stack_transpose_quant and, in the Triton MoE UE8M0 weight-quantization path, switched to the Fleet fused kernel via FD_USE_FLEET_FP8_QUANT.
  • Added the FD_MOE_PROB_IN_ADVANCE / FD_USE_FLEET_FP8_QUANT / FD_USE_PHI_TOPK environment variables and supplemented deepgemm/fused-quant test cases.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
tests/operators/test_noaux_tc_redundant.py Adds a group-topk alignment test based on moe_topk_select
tests/layers/test_fp8_ue8m0.py Adds unit-test coverage for the Fleet fused stack+transpose+quant path and fused_stack_transpose_quant
tests/layers/test_deepgemm_fused_moe.py Adds DeepGemmFusedMoE multi-path (TP/EP, prob-in-advance, phi permute, etc.) alignment tests
fastdeploy/model_executor/layers/quantization/fp8_utils.py Adds the try_import of paddlefleet_ops and quantization helpers such as fused_stack_transpose_quant
fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py Adds a Fleet fused path (chunked processing) for UE8M0 weight quantization
fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py Introduces moe_topk_select and the fused swiglu+fp8 quant path under FD_MOE_PROB_IN_ADVANCE, and adjusts the unpermute combine switch
fastdeploy/envs.py Adds environment-variable switches for MoE TopK / Fleet FP8 quant / prob-in-advance
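As context for the moe_topk_select review comments below, the selection it performs can be sketched as a plain-Python reference: softmax over the expert logits per token, pick the top_k experts, and (optionally) renormalize the picked probabilities. The real moe_topk_select is composed of Paddle native operators; the function name and signature here are illustrative only:

```python
import math

def moe_topk_select_ref(logits, top_k, renormalize=True):
    """Reference top-k expert selection.

    logits: list of per-token expert scores (list of lists of floats).
    Returns (expert_ids, expert_probs), each a list per token.
    """
    out_ids, out_probs = [], []
    for row in logits:
        # Numerically stable softmax over this token's expert logits.
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        probs = [e / total for e in exps]
        # Indices of the top_k largest probabilities, highest first.
        ranked = sorted(range(len(row)), key=lambda i: probs[i], reverse=True)[:top_k]
        picked = [probs[i] for i in ranked]
        if renormalize:
            s = sum(picked)
            picked = [p / s for p in picked]
        out_ids.append(ranked)
        out_probs.append(picked)
    return out_ids, out_probs
```

A fused kernel computes the same result but avoids materializing the full softmax and sort on device.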

Comment on lines +194 to +197
# Weight must be fp8 and scale must cover all experts
self.assertEqual(layer.up_gate_proj_weight.dtype, paddle.float8_e4m3fn)
self.assertEqual(layer.up_gate_proj_weight_scale_inv.shape[0], num_experts)

Comment on lines +252 to +270
    # Helper: build a minimal fake paddlefleet_ops namespace
    # ------------------------------------------------------------------
    def _fake_paddlefleet_ops(self, *, has_op=True, use_pow2_scale_result=False, num_experts=4, out=128, inp=64):
        """Return a mock object that optionally exposes fuse_stack_transpose_fp8_quant."""
        fake_ops = mock.MagicMock()
        if has_op:
            stacked_w = paddle.zeros([num_experts, inp, out], dtype=paddle.float8_e4m3fn)
            scale = paddle.ones([num_experts * inp, out // 128 if out >= 128 else 1], dtype=paddle.float32)

            def fake_quant(expert_weight_list, use_pow2_scale, use_ue8m0_w, use_ue8m0_s):
                return stacked_w, scale

            fake_ops.fuse_stack_transpose_fp8_quant = fake_quant
        else:
            # Simulate that the attribute is absent
            del fake_ops.fuse_stack_transpose_fp8_quant
        return fake_ops

    # ------------------------------------------------------------------
Comment on lines +29 to +35
deep_ep_stub = types.ModuleType("fastdeploy.model_executor.layers.moe.ep.deep_ep")
deep_ep_stub.Buffer = types.SimpleNamespace(capture=lambda: object())
sys.modules["fastdeploy.model_executor.layers.moe.ep.deep_ep"] = deep_ep_stub

from fastdeploy.model_executor.layers.moe import (  # noqa: E402
    fused_moe_deepgemm_backend as backend,
)
Comment on lines +191 to +231
dtype = paddle.float8_e4m3fn  # current_platform.fp8_dtype() if dtype is None else dtype
assert x.shape[-1] % group_size == 0, (
    f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"

fp8_min, fp8_max = -224.0, 224.0  # get_fp8_min_max()

assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
    x_q = paddle.empty(x.shape, dtype=dtype)

# Note: paddle Tensor.shape is a list, so concatenate with a list (not a tuple).
shape = x.shape[:-1] + [x.shape[-1] // group_size]
x_s = paddle.empty(shape, dtype=paddle.float32)

# torch.ops._C.per_token_group_fp8_quant(
#     x.contiguous(), x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
# )
# return x_q, x_s
M = x.numel() // group_size
N = group_size
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_fp8[(M,)](
    x,
    x_q,
    x_s,
    group_size,
    x.shape[1],
    x.stride(0),
    eps,
    fp8_min=fp8_min,
    fp8_max=fp8_max,
    use_ue8m0=use_ue8m0,
    BLOCK=BLOCK,
    num_warps=num_warps,
    num_stages=num_stages,
)
Comment on lines +174 to +188
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dtype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
tma_aligned_scales: Outputs scales in TMA-aligned layout.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor.
Comment on lines +111 to 114
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
    ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant(
        ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=not disable_ue8m0_cast
    )
Comment on lines +460 to +463
if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
    ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant(
        ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0
    )
Comment on lines +1630 to +1633
scale_list = []
chunk_size = 64

for start_idx in range(0, num_expert, chunk_size):
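The chunked loop above bounds peak memory by quantizing at most `chunk_size` experts per fused-kernel call instead of stacking all expert weights at once. The index arithmetic can be sketched as follows (the helper name and `(start, end)` convention are hypothetical, for illustration only):

```python
def iter_expert_chunks(num_experts: int, chunk_size: int = 64):
    """Yield (start, end) half-open index ranges covering all experts,
    so expert weights can be processed in bounded-size batches."""
    for start in range(0, num_experts, chunk_size):
        yield start, min(start + chunk_size, num_experts)
```

Each `(start, end)` range would select a slice of the expert weight list for one fused stack+transpose+quant call, and the per-chunk scales are appended to `scale_list` and concatenated afterward.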
Comment on lines +44 to +61
def _deepgemm_available() -> bool:
    """Try to JIT-compile a minimal deepgemm kernel; return False on failure."""
    try:
        from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm

        lhs = paddle.zeros([128, 128], dtype="float8_e4m3fn")
        lhs_scale = paddle.ones([128, 1], dtype="float32")
        rhs = paddle.zeros([1, 128, 128], dtype="float8_e4m3fn")
        rhs_scale = paddle.ones([1, 1, 1], dtype="float32")
        out = paddle.empty([128, 128], dtype="bfloat16")
        m_indices = paddle.zeros([128], dtype="int32")
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((lhs, lhs_scale), (rhs, rhs_scale), out, m_indices)
        return True
    except Exception:
        return False


_DEEPGEMM_AVAILABLE = _deepgemm_available()
@codecov-commenter

Codecov Report

❌ Patch coverage is 54.78261% with 52 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.5@3d6ec49). Learn more about missing BASE report.

Files with missing lines Patch % Lines
..._executor/layers/moe/fused_moe_deepgemm_backend.py 42.22% 23 Missing and 3 partials ⚠️
...oy/model_executor/layers/quantization/fp8_utils.py 43.47% 23 Missing and 3 partials ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.5    #6935   +/-   ##
==============================================
  Coverage               ?   68.95%           
==============================================
  Files                  ?      389           
  Lines                  ?    53218           
  Branches               ?     8355           
==============================================
  Hits                   ?    36696           
  Misses                 ?    13862           
  Partials               ?     2660           
Flag Coverage Δ
GPU 68.95% <54.78%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants