
@pwilkin
Collaborator

@pwilkin pwilkin commented Dec 13, 2025

This change adds a dedicated autoregressive version of delta-net that short-circuits all the recurrent computations for n_seq_tokens == 1. The end result is roughly a 40% bump in token generation speed.
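
For intuition, the single-token path boils down to one rank-1 update of the per-head recurrent state, with none of the chunked mask machinery. A rough standalone sketch of that step (plain C++; shapes, names and the exact gating/normalization are illustrative, not the ggml graph code from this PR):

```cpp
#include <cmath>
#include <vector>

// Intuition-only sketch of one gated delta-rule step for a single head and a
// single token (the n_seq_tokens == 1 path). Shapes, names and the exact
// gating/normalization are illustrative, not the ggml graph code from this PR.
static void delta_net_step(std::vector<float> & S,        // state, d_k x d_v, row-major
                           const std::vector<float> & q,  // d_k
                           const std::vector<float> & k,  // d_k (assumed already normalized)
                           const std::vector<float> & v,  // d_v
                           float g,                       // log-decay for this token
                           float beta,                    // update strength in (0, 1)
                           std::vector<float> & out,      // d_v
                           int d_k, int d_v) {
    const float decay = std::exp(g);

    // decay the state and read its current prediction for k: pred = S^T k
    std::vector<float> pred(d_v, 0.0f);
    for (int i = 0; i < d_k; ++i) {
        for (int j = 0; j < d_v; ++j) {
            S[i*d_v + j] *= decay;
            pred[j] += k[i] * S[i*d_v + j];
        }
    }

    // delta rule: write beta * (v - pred) along k into the state
    for (int i = 0; i < d_k; ++i) {
        for (int j = 0; j < d_v; ++j) {
            S[i*d_v + j] += beta * k[i] * (v[j] - pred[j]);
        }
    }

    // output for this token: out = S^T q
    for (int j = 0; j < d_v; ++j) {
        out[j] = 0.0f;
        for (int i = 0; i < d_k; ++i) {
            out[j] += q[i] * S[i*d_v + j];
        }
    }
}
```

With a single token there is nothing to mask or accumulate across positions, which is what the dedicated path skips.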

@jeffbolznv
Collaborator

before:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -r 10 --prio 1 -m c:\models\Qwen3-Next-80B-A3B-Instruct-Q2_K_L.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           pp512 |      3071.35 ± 18.75 |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           tg128 |         74.65 ± 0.39 |

build: 5266379bc (7387)

after:

Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -fa 1 -r 10 --prio 1 -m c:\models\Qwen3-Next-80B-A3B-Instruct-Q2_K_L.gguf
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           pp512 |      3075.92 ± 14.32 |
| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           tg128 |         94.36 ± 1.10 |

build: 4a494ab77 (7387)

@jacekpoplawski
Contributor

before:

ggml_cuda_init: found 3 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3next 80B.A3B Q6_K         |  61.20 GiB |    79.67 B | CUDA       |  99 |           pp512 |        740.56 ± 4.37 |
| qwen3next 80B.A3B Q6_K         |  61.20 GiB |    79.67 B | CUDA       |  99 |           tg128 |         43.33 ± 0.35 |

after:

ggml_cuda_init: found 3 CUDA devices:
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

| model                          |       size |     params | backend    | ngl |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
| qwen3next 80B.A3B Q6_K         |  61.20 GiB |    79.67 B | CUDA       |  99 |           pp512 |        739.84 ± 3.79 |
| qwen3next 80B.A3B Q6_K         |  61.20 GiB |    79.67 B | CUDA       |  99 |           tg128 |         51.40 ± 1.09 |

@IIIIIllllIIIIIlllll

Sadly, there are no changes on AI MAX+ 395 (ROCm 7.1.1 build, latest code).

master:

mark@MarkPC:~/llama.cpp/llama.cpp-master$  ./llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL.gguf -p 2048 -n 32 -ub 2048 -b 2048 -fa 1 -mmp 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        497.00 ± 2.64 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         17.04 ± 0.26 |

build: unknown (0)

this PR:

mark@MarkPC:~/llama.cpp-lean_mean_token_machine/build/bin$ ./llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL.gguf -p 2048 -n 32 -ub 2048 -b 2048 -fa 1 -mmp 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        493.22 ± 2.55 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         18.68 ± 0.01 |

build: unknown (0)

Edited:
Sorry, I didn't realize this was an optimization for CUDA, but I'm keeping the comment here for other AMD users to see.

@pwilkin
Collaborator Author

pwilkin commented Dec 13, 2025

Nah, this should be a general optimization. This means there are other bottlenecks in play in the ROCm implementation besides the slow delta-net.

Can you run inference with --verbose and GGML_SCHED_DEBUG=2 to dump the entire graph split?

@othermod

Sadly, there are no changes on AI MAX+ 395 (ROCm 7.1.1 build, latest code).


That looks like a 10% bump, right?

@IIIIIllllIIIIIlllll

Nah, this should be a general optimization. This means there are other bottlenecks in play in the ROCm implementation besides the slow delta-net.

Can you run inference with --verbose and GGML_SCHED_DEBUG=2 to dump the entire graph split?

@pwilkin Hopefully this log is what you need :)
qwen-next-bench.zip

@CISC
Collaborator

CISC left a comment


There's an excessive number of conts and asserts here, most of which I'm sure are unnecessary, but I think qwen3next needs a general cleanup of these anyway, so I'll leave that to you for a later stage.

@pwilkin
Collaborator Author

pwilkin commented Dec 13, 2025

@IIIIIllllIIIIIlllll can you do a bench for -ub 512,1024,2048? Shouldn't influence generation, but I'm wondering about the pp overhead of such huge graphs.

@CISC
Collaborator

CISC commented Dec 13, 2025

@pwilkin In case you're wondering, I think the ggml_cont_Xd of q/k/v/state is unnecessary (the latter may be beneficial, though hardly critical).
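
For illustration, the kind of change meant here (hypothetical tensor names, not the actual qwen3next graph code) is dropping the ggml_cont_* wrapper and handing the permuted view straight to the consuming op:

```cpp
#include "ggml.h"

// Hypothetical fragment, only to illustrate the suggestion above.
static ggml_tensor * make_q_view(ggml_context * ctx0, ggml_tensor * q) {
    // before: ggml_cont_3d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), head_dim, n_head, n_tokens)
    //         materializes a contiguous copy and adds a CONT node to the graph
    // after:  keep the permuted view and let the consumer read it directly
    return ggml_permute(ctx0, q, 0, 2, 1, 3);
}
```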

@IIIIIllllIIIIIlllll

IIIIIllllIIIIIlllll commented Dec 13, 2025

mark@MarkPC:~/llama.cpp-lean_mean_token_machine/build/bin$ ./llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL.gguf -p 2048 -n 32 -ub 512,1024,2048 -fa 1 -mmp 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |      512 |  1 |    0 |          pp2048 |        445.64 ± 1.74 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |      512 |  1 |    0 |            tg32 |         18.66 ± 0.03 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     1024 |  1 |    0 |          pp2048 |        469.62 ± 0.83 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     1024 |  1 |    0 |            tg32 |         18.71 ± 0.04 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        503.91 ± 0.32 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         18.70 ± 0.03 |

build: unknown (0)

@pwilkin
Is this enough, or should I run it again with --verbose & GGML_SCHED_DEBUG=2?

@mpapili

mpapili commented Dec 13, 2025

Adding some multi-GPU ROCm data with several experts offloaded to CPU:

Setup

Specs

CPU: Ryzen 9 3950X
Memory: 64GB DDR4-3000
GPU1: RX 6800 (16GB)
GPU2: RX 6800 (16GB)

Model

Qwen3-Next-80B-A3B-Thinking-Q4_K_S

Command

/llama.cpp/build/build/bin/llama-server --host 127.0.0.1 --jinja --min-p 0 --mlock --mmap -ncmoe 20 --port 44163 --repeat-penalty 1.05 --temp 0.5 --top-k 0.20 --top-p 0.95 --warmup --alias Qwen3-Next-80B-A3B-Thinking-Q4_K_S --ctx-size 75000 --cache-type-k q8_0 --cache-type-v q8_0 --flash-attn on --model /models/Qwen3-Next-80B-A3B-Thinking-Q4_K_S.gguf --n-gpu-layers 999 --threads 8 --tensor-split 67,33 --log-verbose

Results

ggml-org/main Branch

17.3 tokens/second

pwilkin:lean_mean_token_machine Branch

22.5 tokens/second

An increase of >5 tokens/second, i.e. ~30% faster token generation.

@github-actions github-actions bot added the `model` (Model specific) label Dec 13, 2025
@heislera763

Some 4x V100 32GB results w/ q8_0 gguf

master:

alexander@alexander-main:~/.llama-server$ llama-bench -m models/qwen_qwen3-next-80b-a3b-thinking-q8_0.gguf -fa 1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
  Device 0: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 1: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 2: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 3: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           pp512 |        363.55 ± 1.14 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           tg128 |         38.39 ± 0.08 |

build: 5266379bc (7387)

lean_mean_token_machine:

alexander@alexander-main:~/.llama-server$ llama-bench -m models/qwen_qwen3-next-80b-a3b-thinking-q8_0.gguf -fa 1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
  Device 0: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 1: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 2: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 3: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           pp512 |        360.25 ± 2.75 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           tg128 |         44.60 ± 0.03 |

build: 4a494ab77 (7387)

Before: 38.39 t/s
After: 44.60 t/s
Gain: 6.21 t/s (+16.2%)

@heislera763

I was feeling a bit bored and naively asked gemini-cli to make the changes CISC suggested. It seems consistently faster and the output looks coherent (I only did very brief testing). I do remember it breaking when it changed the sum_rows conts, though; I don't know whether any of the rest are needed.

cont/assert reduction:

alexander@alexander-main:~/dev$ ./llama.cpp/build/bin/llama-bench -m ~/.llama-server/models/qwen_qwen3-next-80b-a3b-thinking-q8_0.gguf -fa 1
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 CUDA devices:
  Device 0: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 1: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 2: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
  Device 3: Tesla V100-SXM2-32GB, compute capability 7.0, VMM: yes
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           pp512 |        370.85 ± 1.33 |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           tg128 |         45.79 ± 0.02 |

build: df1be3be (14)

Gain of 1.19 t/s over this commit (+2.67%) for a total gain of 7.4 t/s (+19.3%) over master

Patch file if you're interested: qwen3.patch

@CISC
Collaborator

CISC commented Dec 13, 2025

| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  78.98 GiB |    79.67 B | CUDA       |  99 |  1 |           pp512 |        370.85 ± 1.33 |

Nice little PP boost.

@Som-anon

anon@t480 ~/work/llama.cpp (git)-[remotes/origin/HEAD] % ./build/bin/llama-bench --model ~/models/Qwen-Next/Qwen3-Next-80B-A3B-Instruct-UD-Q5_K_XL-00001-of-00002.gguf
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3next 80B.A3B Q5_K - Medium |  52.86 GiB |    79.67 B | BLAS       |       4 |           pp512 |         13.15 ± 0.04 |
| qwen3next 80B.A3B Q5_K - Medium |  52.86 GiB |    79.67 B | BLAS       |       4 |           tg128 |          3.36 ± 0.01 |

build: 4d5ae24c0 (7386)
./build/bin/llama-bench --model   1923.22s user 22.33s system 425% cpu 7:37.05 total
anon@t480 ~/work/llama.cpp (git)-[remotes/origin/HEAD] % cd ~/work/llama.cpp-qwen.next
anon@t480 ~/work/llama.cpp-qwen.next (git)-[lean_mean_token_machine] % ./build/bin/llama-bench --model ~/models/Qwen-Next/Qwen3-Next-80B-A3B-Instruct-UD-Q5_K_XL-00001-of-00002.gguf
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3next 80B.A3B Q5_K - Medium |  52.86 GiB |    79.67 B | BLAS       |       4 |           pp512 |         13.14 ± 0.10 |
| qwen3next 80B.A3B Q5_K - Medium |  52.86 GiB |    79.67 B | BLAS       |       4 |           tg128 |          2.65 ± 0.00 |

build: 4a494ab7 (7387)
./build/bin/llama-bench --model   2143.68s user 8.83s system 450% cpu 7:58.32 total
anon@t480 ~/work/llama.cpp-qwen.next (git)-[lean_mean_token_machine] % cd ~/work/llama.cpp
anon@t480 ~/work/llama.cpp (git)-[remotes/origin/HEAD] % ./build/bin/llama-bench --model ~/models/Qwen-Next/Qwen__Qwen3-Next-80B-A3B-Instruct-Q5_K_S.gguf
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3next 80B.A3B Q5_K - Small |  51.21 GiB |    79.67 B | BLAS       |       4 |           pp512 |         12.98 ± 0.03 |
| qwen3next 80B.A3B Q5_K - Small |  51.21 GiB |    79.67 B | BLAS       |       4 |           tg128 |          3.42 ± 0.00 |

build: 4d5ae24c0 (7386)
./build/bin/llama-bench --model   1949.32s user 21.52s system 426% cpu 7:41.65 total
anon@t480 ~/work/llama.cpp (git)-[remotes/origin/HEAD] % cd ~/work/llama.cpp-qwen.next
anon@t480 ~/work/llama.cpp-qwen.next (git)-[lean_mean_token_machine] % ./build/bin/llama-bench --model ~/models/Qwen-Next/Qwen__Qwen3-Next-80B-A3B-Instruct-Q5_K_S.gguf
| model                          |       size |     params | backend    | threads |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | ------: | --------------: | -------------------: |
| qwen3next 80B.A3B Q5_K - Small |  51.21 GiB |    79.67 B | BLAS       |       4 |           pp512 |         13.11 ± 0.10 |
| qwen3next 80B.A3B Q5_K - Small |  51.21 GiB |    79.67 B | BLAS       |       4 |           tg128 |          2.67 ± 0.00 |

build: 4a494ab7 (7387)
./build/bin/llama-bench --model   2116.94s user 8.65s system 448% cpu 7:53.83 total

@jeffbolznv
Collaborator

Patch file if you're interested: qwen3.patch

Worth a few percent on my system:

| qwen3next 80B.A3B Q2_K - Medium |  27.23 GiB |    79.67 B | Vulkan     |  99 |  1 |           tg128 |         96.69 ± 0.66 |

The number of CONT ops for -p 0 -n 128 -r 1 decreases from 119196 to 86688, so still plenty to go.

@fuutott

fuutott commented Dec 13, 2025

D:\llama>d:/llama/latest/llama-bench.exe   -m d:\models\lmstudio-community\Qwen3-Next-80B-A3B-Instruct-GGUF\Qwen3-Next-80B-A3B-Instruct-Q4_K_M.gguf   
-p 512   -n 512   -b 1024   -ub 512   -ngl 99  -mmp 0   -fa 1   -o md   -r 3   -d 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes

| model                          |       size |     params | backend    | ngl | n_batch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q4_K - Medium |  45.08 GiB |    79.67 B | CUDA       |  99 |    1024 |  1 |    0 |           pp512 |      1747.04 ± 56.82 |
| qwen3next 80B.A3B Q4_K - Medium |  45.08 GiB |    79.67 B | CUDA       |  99 |    1024 |  1 |    0 |           tg512 |         21.44 ± 0.11 |

build: 5266379bc (7387)
D:\llama>d:/llama/llama.cpp/build/bin/llama-bench.exe   -m d:\models\lmstudio-community\Qwen3-Next-80B-A3B-Instruct-GGUF\Qwen3-Next-80B-A3B-Instruct-Q4_K_M.gguf   
-p 512   -n 512   -b 1024   -ub 512   -ngl 99  -mmp 0   -fa 1   -o md   -r 3   -d 0
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q4_K - Medium |  45.08 GiB |    79.67 B | CUDA       |  99 |    1024 |  1 |    0 |           pp512 |       1763.99 ± 4.62 |
| qwen3next 80B.A3B Q4_K - Medium |  45.08 GiB |    79.67 B | CUDA       |  99 |    1024 |  1 |    0 |           tg512 |         29.36 ± 0.71 |

build: 4a494ab7 (7387)

@IIIIIllllIIIIIlllll

IIIIIllllIIIIIlllll commented Dec 13, 2025

Please ignore my previous reply; those tests were run in a PuTTY terminal, and I don't know why the results were so bad.

It's really strange: changing -DGGML_HIP_ROCWMMA_FATTN to OFF significantly improved pp speed...
tg speed increased by about 9%.

Perhaps the AI MAX+ 395 has reached its performance limit (though that is questionable).

this PR, -DGGML_HIP_ROCWMMA_FATTN=OFF:

/home/mark/llama.cpp-lean_mean_token_machine/build/bin/llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        573.84 ± 1.12 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         26.57 ± 0.03 |

build: unknown (0)

this PR, -DGGML_HIP_ROCWMMA_FATTN=ON:

/home/mark/llama.cpp-lean_mean_token_machine/build/bin/llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        502.11 ± 0.80 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         26.47 ± 0.02 |

build: unknown (0)

master, -DGGML_HIP_ROCWMMA_FATTN=OFF:

/home/mark/llama.cpp/llama.cpp-master/llama-bench -m /home/mark/Models/Q8/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL/Qwen3-Next-80B-A3B-Instruct-UD-Q8_K_KL.gguf -r 5 -p 2048 -n 32 -b 2048 -ub 2048 -fa 1 -mmp 0

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon Graphics, gfx1151 (0x1151), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | n_ubatch | fa | mmap |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -------: | -: | ---: | --------------: | -------------------: |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |          pp2048 |        485.42 ± 0.55 |
| qwen3next 80B.A3B Q8_0         |  79.57 GiB |    79.67 B | ROCm       |  99 |     2048 |  1 |    0 |            tg32 |         24.14 ± 0.04 |

build: unknown (0)

Comment on lines +1016 to +1024
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
ggml_tensor * attn_out;
if (n_seq_tokens == 1) {
    attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else if (n_seq_tokens > CHUNK_SIZE) {
    attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
} else {
    attn_out = build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
}
Member


This is strongly discouraged. Instead of adding more branches, we have to figure out how to make the graph static. Start by simplifying the existing graphs and removing redundant ops.

Collaborator Author


But in this case we can't make the graph static, since the special branch here is the one where the decay mask computation doesn't happen at all: with n_seq_tokens == 1 it collapses to trivial transformations, so they can be optimized out.

I can probably remove the recurrent part now, since I'm not sure there's a realistic case for it; it'll be either chunking or autoregressive.
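
Concretely (my notation, not the exact form in the code): for a chunk of $T$ tokens the intra-chunk decay/causal mask is a lower-triangular $T \times T$ matrix of cumulative decays,

$$
M_{ij} =
\begin{cases}
\exp\!\left(\sum_{s=j+1}^{i} g_s\right) & i \ge j,\\
0 & i < j,
\end{cases}
$$

so for $T = 1$ it degenerates to the $1 \times 1$ matrix $[1]$ and the whole layer reduces to a single decayed rank-1 state update, which is what the autoregressive path computes directly.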

@Som-anon

Som-anon commented Dec 13, 2025


Is there any reason why it could have gotten slower for me (tg128 dropped from 3.36 to 2.65 t/s in the benchmarks I posted above)? I'm compiling it with

cmake -B build -DGGML_VULKAN=0 -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS && cmake --build build --config Release -j 12

@kiuckhuang

Got an interesting result on Win11 + RTX 5090: compiled with Vulkan support and forcing the vulkan0 device, pp512 is 60%+ faster and tg128 is 100%+ faster than CUDA0.

vulkan0:
$ llama-bench.exe -dev vulkan0 -fa 1 -ngl 99 -m Qwen3-Next-80B-A3B-Instruct-UD-Q2_K_XL.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes
load_backend: loaded CUDA backend from G:\ai\llama.cpp\build\bin\Release\ggml-cuda.dll
load_backend: loaded RPC backend from G:\ai\llama.cpp\build\bin\Release\ggml-rpc.dll
ggml_vulkan: Found 2 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
ggml_vulkan: 1 = Intel(R) UHD Graphics 770 (Intel Corporation) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none
load_backend: loaded Vulkan backend from G:\ai\llama.cpp\build\bin\Release\ggml-vulkan.dll
load_backend: loaded CPU backend from G:\ai\llama.cpp\build\bin\Release\ggml-cpu-alderlake.dll

| model                           |       size |     params | backend     | ngl | fa |     dev |            test |                  t/s |
| ------------------------------- | ---------: | ---------: | ----------- | --: | -: | ------- | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.31 GiB |    79.67 B | CUDA,Vulkan |  99 |  1 | Vulkan0 |           pp512 |       2841.57 ± 9.18 |
| qwen3next 80B.A3B Q2_K - Medium |  27.31 GiB |    79.67 B | CUDA,Vulkan |  99 |  1 | Vulkan0 |           tg128 |         90.77 ± 0.26 |

build: c00ff92 (7389)

cuda0:
$ llama-bench.exe -dev cuda0 -fa 1 -ngl 99 -m Qwen3-Next-80B-A3B-Instruct-UD-Q2_K_XL.gguf
ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
Device 0: NVIDIA GeForce RTX 5090, compute capability 12.0, VMM: yes
load_backend: loaded CUDA backend from G:\ai\llama.cpp\build\bin\Release\ggml-cuda.dll
load_backend: loaded RPC backend from G:\ai\llama.cpp\build\bin\Release\ggml-rpc.dll
ggml_vulkan: Found 2 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: NV_coopmat2
ggml_vulkan: 1 = Intel(R) UHD Graphics 770 (Intel Corporation) | uma: 1 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 32768 | int dot: 1 | matrix cores: none
load_backend: loaded Vulkan backend from G:\ai\llama.cpp\build\bin\Release\ggml-vulkan.dll
load_backend: loaded CPU backend from G:\ai\llama.cpp\build\bin\Release\ggml-cpu-alderlake.dll

| model                           |       size |     params | backend     | ngl | fa |     dev |            test |                  t/s |
| ------------------------------- | ---------: | ---------: | ----------- | --: | -: | ------- | --------------: | -------------------: |
| qwen3next 80B.A3B Q2_K - Medium |  27.31 GiB |    79.67 B | CUDA,Vulkan |  99 |  1 |   CUDA0 |           pp512 |      1699.91 ± 44.21 |
| qwen3next 80B.A3B Q2_K - Medium |  27.31 GiB |    79.67 B | CUDA,Vulkan |  99 |  1 |   CUDA0 |           tg128 |         38.16 ± 0.60 |

build: c00ff92 (7389)
