[Cherry-Pick][RL] Support moe_topk_select using Paddle native operators and Add fused stack-transpose-quant for BlockWiseFP8 MoE weight quantization and swiglu-fp8-quant op for DeepGemmFusedMoE for training alignment (#6850) #6935
Open
DanielSun11 wants to merge 11 commits into PaddlePaddle:release/2.5
Conversation
Thanks for your contribution!
Contributor
Pull request overview
This PR enhances the MoE TopK selection and FP8 (including UE8M0-scale) quantization paths: it introduces moe_topk_select logic based on Paddle native operators on the DeepGemmFusedMoE side, adds and aligns the Fleet-side fused-operator integration (stack+transpose+fp8 quant, plus fused swiglu+fp8 quant with the routed prob), and supplies the corresponding unit tests and environment switches.
Changes:
- Added moe_topk_select (Paddle native) to the DeepGemm MoE path, enabled via FD_USE_PHI_TOPK.
- Added fp8_utils.fused_stack_transpose_quant; the UE8M0 quantized-weight handling in the Triton MoE backend can switch to the Fleet fused kernel via FD_USE_FLEET_FP8_QUANT.
- Added the FD_MOE_PROB_IN_ADVANCE / FD_USE_FLEET_FP8_QUANT / FD_USE_PHI_TOPK environment variables and deepgemm/fused-quant test cases.
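The switches above gate which kernel a path uses at runtime. A minimal, hypothetical sketch of that dispatch pattern (the real logic lives in fastdeploy/envs.py and the MoE backends; `env_flag` and `select_topk_impl` are illustrative names, not the PR's API):

```python
import os

def env_flag(name, default="0"):
    """Read a boolean-style environment switch ("0"/"1")."""
    return os.getenv(name, default) == "1"

def select_topk_impl():
    """Return the name of the TopK implementation to use (sketch only)."""
    if env_flag("FD_USE_PHI_TOPK"):
        return "moe_topk_select"   # Paddle native operators
    return "fused_topk_kernel"     # default fused path

os.environ["FD_USE_PHI_TOPK"] = "1"
print(select_topk_impl())  # → moe_topk_select
```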
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| tests/operators/test_noaux_tc_redundant.py | Adds a group-topk alignment test based on moe_topk_select |
| tests/layers/test_fp8_ue8m0.py | Adds unit-test coverage for the Fleet fused stack+transpose+quant path and fused_stack_transpose_quant |
| tests/layers/test_deepgemm_fused_moe.py | Adds DeepGemmFusedMoE multi-path (TP/EP, prob-in-advance, phi permute, etc.) alignment tests |
| fastdeploy/model_executor/layers/quantization/fp8_utils.py | Adds try_import, the paddlefleet_ops import, and quantization helpers such as fused_stack_transpose_quant |
| fastdeploy/model_executor/layers/moe/fused_moe_triton_backend.py | Adds a Fleet fused path (chunked processing) for UE8M0 weight quantization |
| fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py | Introduces moe_topk_select and a fused swiglu+fp8 quant path under FD_MOE_PROB_IN_ADVANCE; adjusts the unpermute combine switch |
| fastdeploy/envs.py | Adds environment switches for MoE TopK / Fleet FP8 quant / prob-in-advance |
Comment on lines +194 to +197

    # Weight must be fp8 and scale must cover all experts
    self.assertEqual(layer.up_gate_proj_weight.dtype, paddle.float8_e4m3fn)
    self.assertEqual(layer.up_gate_proj_weight_scale_inv.shape[0], num_experts)
Comment on lines +252 to +270

    # Helper: build a minimal fake paddlefleet_ops namespace
    # ------------------------------------------------------------------
    def _fake_paddlefleet_ops(self, *, has_op=True, use_pow2_scale_result=False, num_experts=4, out=128, inp=64):
        """Return a mock object that optionally exposes fuse_stack_transpose_fp8_quant."""
        fake_ops = mock.MagicMock()
        if has_op:
            stacked_w = paddle.zeros([num_experts, inp, out], dtype=paddle.float8_e4m3fn)
            scale = paddle.ones([num_experts * inp, out // 128 if out >= 128 else 1], dtype=paddle.float32)

            def fake_quant(expert_weight_list, use_pow2_scale, use_ue8m0_w, use_ue8m0_s):
                return stacked_w, scale

            fake_ops.fuse_stack_transpose_fp8_quant = fake_quant
        else:
            # Simulate that the attribute is absent
            del fake_ops.fuse_stack_transpose_fp8_quant
        return fake_ops

    # ------------------------------------------------------------------
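The `del fake_ops.fuse_stack_transpose_fp8_quant` line in the snippet relies on a `unittest.mock` behavior: deleting an attribute on a `MagicMock` makes subsequent access raise `AttributeError`, so `hasattr` checks in the code under test see the op as missing. A standalone illustration of just that mechanism:

```python
from unittest import mock

fake_ops = mock.MagicMock()

# By default a MagicMock fabricates any attribute on first access.
assert hasattr(fake_ops, "fuse_stack_transpose_fp8_quant")

# Deleting the attribute makes later access raise AttributeError,
# which is how the test simulates a build without the fused op.
del fake_ops.fuse_stack_transpose_fp8_quant
assert not hasattr(fake_ops, "fuse_stack_transpose_fp8_quant")
print("attribute absent as expected")
```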
Comment on lines +29 to +35

    deep_ep_stub = types.ModuleType("fastdeploy.model_executor.layers.moe.ep.deep_ep")
    deep_ep_stub.Buffer = types.SimpleNamespace(capture=lambda: object())
    sys.modules["fastdeploy.model_executor.layers.moe.ep.deep_ep"] = deep_ep_stub

    from fastdeploy.model_executor.layers.moe import (  # noqa: E402
        fused_moe_deepgemm_backend as backend,
    )
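The stub above uses the standard trick of registering a fake module in `sys.modules` before importing the code under test, so the heavy `deep_ep` dependency is never loaded. A dependency-free sketch of the pattern (`my_pkg.heavy_dep` is a made-up name for illustration):

```python
import importlib
import sys
import types

# Register a fake module before anything tries to import it.
stub = types.ModuleType("my_pkg.heavy_dep")
stub.Buffer = types.SimpleNamespace(capture=lambda: object())
sys.modules["my_pkg.heavy_dep"] = stub

# import_module checks sys.modules first, so it returns the stub
# without ever resolving the (nonexistent) real package.
heavy = importlib.import_module("my_pkg.heavy_dep")
print(heavy is stub)  # → True
```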
Comment on lines +191 to +231

    dtype = paddle.float8_e4m3fn  # current_platform.fp8_dtype() if dtype is None else dtype
    assert x.shape[-1] % group_size == 0, (
        f"the last dimension of `x` {x.shape[-1]} must be divisible by `group_size` {group_size}"
    )
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    fp8_min, fp8_max = -224.0, 224.0  # get_fp8_min_max()

    assert out_q is None or out_q.shape == x.shape
    x_q = out_q
    if x_q is None:
        x_q = paddle.empty(x.shape, dtype=dtype)

    shape = x.shape[:-1] + (x.shape[-1] // group_size,)
    x_s = paddle.empty(shape, dtype=paddle.float32)

    # torch.ops._C.per_token_group_fp8_quant(
    #     x.contiguous(), x_q, x_s, group_size, eps, fp8_min, fp8_max, use_ue8m0
    # )
    # return x_q, x_s
    M = x.numel() // group_size
    N = group_size
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_fp8[(M,)](
        x,
        x_q,
        x_s,
        group_size,
        x.shape[1],
        x.stride(0),
        eps,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
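The Triton kernel above implements per-token-group FP8 quantization: each contiguous group of `group_size` values shares one float32 scale chosen so the group's max-abs maps to the E4M3 maximum of 224. A plain NumPy reference of that math (illustrative only — the FP8 cast itself is elided, and `per_token_group_quant_ref` is not an API from this PR):

```python
import numpy as np

FP8_MAX = 224.0  # max finite magnitude of float8_e4m3fn (fnuz-style range used here)

def per_token_group_quant_ref(x, group_size, eps=1e-10):
    """Per-group symmetric quantization: scale = max-abs / FP8_MAX."""
    assert x.shape[-1] % group_size == 0
    g = x.reshape(x.shape[:-1] + (-1, group_size))   # split last dim into groups
    amax = np.abs(g).max(axis=-1, keepdims=True)
    scale = np.maximum(amax, eps) / FP8_MAX          # one fp32 scale per group
    q = np.clip(g / scale, -FP8_MAX, FP8_MAX)        # values that would be cast to fp8
    return q.reshape(x.shape), scale.squeeze(-1)

x = np.random.randn(4, 256).astype(np.float32)
q, s = per_token_group_quant_ref(x, group_size=128)
print(q.shape, s.shape)  # → (4, 256) (4, 2)
```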
Comment on lines +174 to +188

    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum to avoid dividing by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
            is supported for now.
        column_major_scales: Outputs scales in column-major layout.
        tma_aligned_scales: Outputs scales in TMA-aligned layout.
        out_q: Optional output tensor. If not provided, the function will create one.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor.
Comment on lines +111 to +114

    if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
        ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant(
            ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=not disable_ue8m0_cast
        )
Comment on lines +460 to +463

    if fastdeploy.envs.FD_MOE_PROB_IN_ADVANCE:
        ffn_in_x, ffn_in_x_scale_tensor = paddlefleet_ops.fuse_weighted_swiglu_fp8_quant(
            ffn_out, dst_weights, using_pow2_scaling=True, use_ue8m0=self.quant_config.deepgemm_scale_ue8m0
        )
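The fused call combines three steps: SwiGLU over the concatenated gate/up halves of `ffn_out`, weighting by the routed probability (`dst_weights`), and FP8 quantization. An unfused NumPy reference of the activation math only (layout and names are assumptions for illustration; quantization is elided):

```python
import numpy as np

def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def weighted_swiglu_ref(ffn_out, dst_weights):
    """SwiGLU on [gate | up] packed along the last axis, then scale each
    token by its routed probability (applied in advance of quantization)."""
    gate, up = np.split(ffn_out, 2, axis=-1)
    return silu(gate) * up * dst_weights[:, None]

ffn_out = np.random.randn(8, 512).astype(np.float32)
dst_weights = np.random.rand(8).astype(np.float32)
y = weighted_swiglu_ref(ffn_out, dst_weights)
print(y.shape)  # → (8, 256)
```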
Comment on lines +1630 to +1633

    scale_list = []
    chunk_size = 64

    for start_idx in range(0, num_expert, chunk_size):
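The loop above processes expert weights in chunks of 64 so the fused Fleet kernel never materializes all experts at once. The chunk boundaries it visits can be sketched as (helper name is illustrative, not from the PR):

```python
def expert_chunks(num_expert, chunk_size=64):
    """Half-open (start, end) ranges covering num_expert experts;
    the last chunk may be smaller than chunk_size."""
    return [(s, min(s + chunk_size, num_expert)) for s in range(0, num_expert, chunk_size)]

print(expert_chunks(160))  # → [(0, 64), (64, 128), (128, 160)]
```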
Comment on lines +44 to +61

    def _deepgemm_available() -> bool:
        """Try to JIT-compile a minimal deepgemm kernel; return False on failure."""
        try:
            from fastdeploy.model_executor.layers.quantization.fp8_utils import deep_gemm

            lhs = paddle.zeros([128, 128], dtype="float8_e4m3fn")
            lhs_scale = paddle.ones([128, 1], dtype="float32")
            rhs = paddle.zeros([1, 128, 128], dtype="float8_e4m3fn")
            rhs_scale = paddle.ones([1, 1, 1], dtype="float32")
            out = paddle.empty([128, 128], dtype="bfloat16")
            m_indices = paddle.zeros([128], dtype="int32")
            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((lhs, lhs_scale), (rhs, rhs_scale), out, m_indices)
            return True
        except Exception:
            return False

    _DEEPGEMM_AVAILABLE = _deepgemm_available()
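This probe follows a common skip-if-unavailable test pattern: run the smallest possible call once at import time, treat any exception as "unavailable", and cache the boolean for later `skipUnless` decorators. A dependency-free sketch of the pattern (names are illustrative):

```python
def _capability_available(probe):
    """Run `probe` once; treat any exception as 'feature unavailable'."""
    try:
        probe()
        return True
    except Exception:
        return False

def _failing_probe():
    raise RuntimeError("kernel failed to JIT")

AVAILABLE = _capability_available(lambda: 1 + 1)  # trivially succeeds
MISSING = _capability_available(_failing_probe)   # probe raises -> False

print(AVAILABLE, MISSING)  # → True False
```

Caching the result at module import (as `_DEEPGEMM_AVAILABLE` does) avoids re-running the JIT probe for every test case.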
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##           release/2.5    #6935   +/- ##
==============================================
  Coverage             ?   68.95%
==============================================
  Files                ?      389
  Lines                ?    53218
  Branches             ?     8355
==============================================
  Hits                 ?    36696
  Misses               ?    13862
  Partials             ?     2660

Flags with carried forward coverage won't be shown.
☔ View full report in Codecov by Sentry.
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
- Tag the PR title with one of: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a release branch PR, make sure the PR has been submitted to the develop branch first, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.