9 changes: 9 additions & 0 deletions src/models/models.h
@@ -469,6 +469,15 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
ggml_tensor * identity,
int il);

ggml_tensor * build_delta_net_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
int il);

ggml_tensor * build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
133 changes: 113 additions & 20 deletions src/models/qwen3next.cpp
@@ -120,15 +120,10 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(

GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

// TODO: can this ever be false?
const bool use_qk_l2norm = true;
const float eps_norm = hparams.f_norm_rms_eps;

if (use_qk_l2norm) {
const float eps_norm = hparams.f_norm_rms_eps;

q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
}
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);

const float scale = 1.0f / sqrtf(S_v);
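For reference, the q/k pre-processing that is now applied unconditionally is just a per-head L2 normalization followed by the 1/sqrt(S_v) scale shown above (the scale is applied to q only). A minimal standalone sketch, not the ggml code; exactly how ggml_l2_norm folds in eps is an assumption here:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Normalize one head vector to unit L2 norm, guarding against a near-zero norm.
static void l2_normalize(std::vector<float> & x, float eps) {
    float sum_sq = 0.0f;
    for (float v : x) {
        sum_sq += v * v;
    }
    const float inv = 1.0f / std::max(std::sqrt(sum_sq), eps); // eps placement is an assumption
    for (float & v : x) {
        v *= inv;
    }
}
```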

@@ -397,15 +392,10 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(

GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

// TODO: can this ever be false?
const bool use_qk_l2norm = true;
const float eps_norm = hparams.f_norm_rms_eps;

if (use_qk_l2norm) {
const float eps_norm = hparams.f_norm_rms_eps;

q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);
}
q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);

const float scale = 1.0f / sqrtf(S_v);

@@ -610,6 +600,104 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent(
return ggml_concat(ctx0, flat_output, flat_state, 0);
}

ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
int il) {
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(beta));
GGML_ASSERT(ggml_is_contiguous(state));

const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];

const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];

GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

const float eps_norm = hparams.f_norm_rms_eps;

q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);

const float scale = 1.0f / sqrtf(S_v);

q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);

cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);

state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

ggml_tensor * g_t = ggml_cont_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_cont_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
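// g and beta are per-head scalars; shaping them as [1, 1, H_k, n_seqs] lets ggml_mul
// broadcast them over the [S_v, S_v, H_v, n_seqs] state and the per-head vectors below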

// Apply exponential to g_t
g_t = ggml_exp(ctx0, g_t);

// Apply the gated delta rule for the single timestep
// last_recurrent_state = last_recurrent_state * g_t
state = ggml_mul(ctx0, state, g_t);

// kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * k_t_unsqueezed = ggml_cont_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
// we need to sum over dim=-2, so we transpose, sum, then transpose again
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
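// ggml_sum_rows reduces along dim 0, so the transpose/sum/transpose above reduces
// dim 1 (the key index) instead, leaving kv_mem with shape [S_v, 1, H_v, n_seqs]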

// v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
ggml_tensor * v_t = ggml_cont_4d(ctx0, v, S_v, 1, H_v, n_seqs);
// delta = (v_t - kv_mem) * beta_t
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);

// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
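// this is the rank-1 (outer product) update: k is repeated along dim 0 while delta
// broadcasts along dim 1, so the element at (dim0 = i, dim1 = j) is delta[i] * k[j]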
state = ggml_add(ctx0, state, k_t_delta);

// Compute the attention output
// core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * q_t_unsqueezed = ggml_cont_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
// again, since it's over dim = -2, transpose, sum, transpose back
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));

// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);

// flatten output; no need to permute since n_tokens is 1, so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] have the same memory layout
ggml_tensor * flat_output = ggml_cont_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs);

return ggml_concat(ctx0, flat_output, flat_state, 0);
}
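For readers following the graph above, here is a hedged per-head reference sketch of the single-timestep gated delta rule it implements, written on flat C++ arrays rather than ggml tensors. Variable names are illustrative; q and k are assumed already L2-normalized, q scaled, g exponentiated, and beta passed through the sigmoid, as in the code above.

```cpp
#include <vector>

// One gated-delta-rule step for a single head of dimension S.
// state has S*S elements, indexed state[i*S + j] with i = value index, j = key index.
// g is the per-head decay (already exp()'d), beta the per-head update gate.
static std::vector<float> delta_net_step(std::vector<float> & state,
                                         const std::vector<float> & q,
                                         const std::vector<float> & k,
                                         const std::vector<float> & v,
                                         float g, float beta, int S) {
    // state = state * g_t
    for (float & s : state) {
        s *= g;
    }
    // kv_mem[i] = sum_j state[i][j] * k[j]
    std::vector<float> kv_mem(S, 0.0f);
    for (int i = 0; i < S; ++i) {
        for (int j = 0; j < S; ++j) {
            kv_mem[i] += state[i*S + j] * k[j];
        }
    }
    // delta = (v - kv_mem) * beta; state += delta outer-product k (rank-1 update)
    for (int i = 0; i < S; ++i) {
        const float delta = (v[i] - kv_mem[i]) * beta;
        for (int j = 0; j < S; ++j) {
            state[i*S + j] += delta * k[j];
        }
    }
    // out[i] = sum_j state[i][j] * q[j]
    std::vector<float> out(S, 0.0f);
    for (int i = 0; i < S; ++i) {
        for (int j = 0; j < S; ++j) {
            out[i] += state[i*S + j] * q[j];
        }
    }
    return out;
}
```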

ggml_tensor * llm_build_qwen3next::build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
@@ -925,10 +1013,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(k_conv, "k_conv_predelta", il);
cb(v_conv, "v_conv_predelta", il);

// Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ?
build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) :
build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
// Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
ggml_tensor * attn_out;
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else if (n_seq_tokens > CHUNK_SIZE) {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
} else {
attn_out = build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il);
}
Comment on lines +1016 to +1024
Member:

This is highly not recommended. Instead of adding more branches, we have to figure out how to make the graph static. Start with simplifying the existing graphs by removing redundant ops.

Collaborator Author:

But in this case we can't make the graph static: in the n_seq_tokens == 1 branch the decay mask computation doesn't happen at all, since everything collapses to trivial transformations that can be optimized out.

I can probably remove the recurrent part now, since I'm not sure there's a realistic case for it; it'll be either chunking or autoregressive.

cb(attn_out, "attn_out", il);

// The tensors were concatenated 1d, so we need to extract them 1d as well
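The collapsed lines that follow split this concatenation back apart. A hypothetical sketch of how that could look with ggml_view_1d; variable names are illustrative and the actual code is collapsed in this diff:

```cpp
// The first S_v*H_v*n_seq_tokens*n_seqs elements are the attention output,
// the remaining S_v*S_v*H_v*n_seqs elements are the updated recurrent state.
const int64_t n_out = S_v * H_v * n_seq_tokens * n_seqs;
ggml_tensor * output    = ggml_view_1d(ctx0, attn_out, n_out, 0);
ggml_tensor * new_state = ggml_view_1d(ctx0, attn_out, S_v * S_v * H_v * n_seqs,
                                        n_out * ggml_element_size(attn_out));
```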