Unified delta net handling for Qwen3Next and Kimi Linear models #18792
Conversation
Just a note: while working on #18683, I have been thinking about whether g should be pre-broadcast to [S_k, H_v, n_tokens, n_seqs] before entering this function (to make it the same shape as q and k). A broadcast should be fast and shouldn't hurt performance much.
We could play around with that idea, or you can reshape it to [1, n_tokens, H_k, n_seqs] as I suggested in the following comments.
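For reference, a rough sketch of the two layouts being discussed (helper names are hypothetical, and the GDA path assumes H_k == H_v):

// Option A: pre-broadcast g [H_v, n_tokens, n_seqs] -> [S_k, H_v, n_tokens, n_seqs],
// i.e. the same shape as q/k, so the shared code needs no per-shape branches.
static ggml_tensor * g_broadcast_like_k(ggml_context * ctx0, ggml_tensor * g, ggml_tensor * k) {
    ggml_tensor * g4 = ggml_reshape_4d(ctx0, g, 1, g->ne[0], g->ne[1], g->ne[2]);
    return ggml_repeat(ctx0, g4, k); // materialize the broadcast along dim 0
}

// Option B: keep g small but move heads to dim 2: [H_v, n_tokens, n_seqs] -> [1, n_tokens, H_k, n_seqs]
static ggml_tensor * g_to_tokens_heads(ggml_context * ctx0, ggml_tensor * g,
                                       int64_t n_tokens, int64_t H_k, int64_t n_seqs) {
    ggml_tensor * gp = ggml_permute(ctx0, g, 2, 1, 3, 0); // view: [1, n_tokens, H_v, n_seqs]
    return ggml_cont_4d(ctx0, gp, 1, n_tokens, H_k, n_seqs);
}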
I think the file name should be graph-context-delta.cpp to match the graph-context-mamba.cpp naming
src/models/delta.cpp
Outdated
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
} else {
    // GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs]
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
I think if g is reshaped to [1, n_tokens, H_k, n_seqs], then a large part of the logic below can be reused between KDA and GDA (see comments below)
src/models/delta.cpp
Outdated
    g = ggml_pad(ctx0, g, 0, pad, 0, 0);
} else {
    // GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0
    g = ggml_pad(ctx0, g, pad, 0, 0, 0);
First off, I think this branch can be removed if g has shape [1, n_tokens], so we pad along dim 1.
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

// Reshape g for chunks
ggml_tensor * g_cumsum;
Suggested change:
ggml_tensor * g_cumsum;
ggml_tensor * g_cumsum_t;
Since we need both versions, it can be a good idea to get the transposed version right here.
For the GDA branch, a transpose will be a simple reshape as the first dim is [1, n_tokens], so no need for ggml_cont
In other words, given a tensor A with shape: [n, 1, ...], then A.view(1, n, ...) == A^T
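A minimal sketch of that trick in ggml terms (helper name is mine):

// When one of the first two dims is 1, swapping them does not change the memory layout,
// so a plain reshape of a contiguous tensor already is the transpose and no ggml_cont is needed.
static ggml_tensor * transpose_via_reshape(ggml_context * ctx0, ggml_tensor * a) {
    // a: [n, 1, ne2, ne3] (or [1, n, ne2, ne3]), contiguous -> swap dims 0 and 1 as a view
    return ggml_reshape_4d(ctx0, a, a->ne[1], a->ne[0], a->ne[2], a->ne[3]);
}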
src/models/delta.cpp
Outdated
// Cumsum along chunk_size dimension (ne[1])
// GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs]
g_cumsum = ggml_cumsum(ctx0, g);
Just a quick note-to-self: we probably need to support a column-wise version of ggml_cumsum, which should eliminate some of these transposes in the future. Another idea: support non-contiguous tensors in ggml_cumsum.
I think it should be like cumsum in PyTorch, where you can specify which dimension to cumsum over.
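For context, the workaround the current code uses (and that a dim argument à la torch.cumsum(x, dim=...) would make unnecessary) looks roughly like this:

// Cumsum along dim 1 with today's ggml_cumsum (which reduces along ne[0]):
// transpose, make contiguous, cumsum, transpose back -- two extra copies per call.
static ggml_tensor * cumsum_dim1(ggml_context * ctx0, ggml_tensor * x) {
    ggml_tensor * xt = ggml_cont(ctx0, ggml_transpose(ctx0, x)); // swap dims 0 and 1
    ggml_tensor * cs = ggml_cumsum(ctx0, xt);                    // cumsum along the new dim 0
    return ggml_cont(ctx0, ggml_transpose(ctx0, cs));            // swap back
}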
src/models/delta.cpp
Outdated
// GDA: Use decay mask approach (g broadcasts over K dimension)
// g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
ggml_tensor * gcs_i = g_cumsum;
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
this gcs_j should be equivalent to g_cumsum_t (or just g_cumsum, depending on what shape of g you consider to be the transposed version)
then g_exp_pos = ggml_exp(ctx0, g_cumsum_t) can be computed directly here
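In other words, for the GDA shape something like this (names taken from the diff above; which layout counts as the transposed one is a matter of convention):

ggml_tensor * g_cumsum_t = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j      = g_cumsum_t;                     // same tensor, no extra cont/transpose
ggml_tensor * g_exp_pos  = ggml_exp(ctx0, g_cumsum_t);     // reuse later instead of re-deriving gexp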
src/models/delta.cpp
Outdated
if (is_kda) {
    // KDA: Reuse g_exp_pos computed earlier
    gexp = g_exp_pos;
} else {
    // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
    gexp = ggml_exp(ctx0, g_cumsum_t);
}
this can be removed when you apply my last trick above
    ggml_tensor * g_diff = ggml_sub(ctx0, g_last_broadcast, g_cumsum);
    g_diff_exp = ggml_exp(ctx0, g_diff);
} else {
    // GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
I'm not 100% sure, but seems like this can be removed too, as we now have both g_cumsum and g_cumsum_t that you can play with
} else {
    // GDA: g_last_exp [1, 1, n_chunks, H_k * n_seqs]
    // Broadcasts over both K and V dimensions
    gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
We can avoid this branching if g_last_exp is already broadcast.
Thanks for your refactoring effort. I think my kda_autoregressive is better implemented, as I used mul_mat to replace sum_rows. If we refactor, the new function should be based on kda_autoregressive.
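As a rough illustration of the sum_rows -> mul_mat substitution (not necessarily the exact change in kda_autoregressive): for a single-token right-hand side, the broadcast-multiply-then-reduce pattern and a plain mat-vec product compute the same numbers, the latter in one kernel and without the intermediate tensor.

// Variant A: elementwise multiply with broadcasting, then reduce along dim 0
static ggml_tensor * dot_via_sum_rows(ggml_context * ctx0, ggml_tensor * a, ggml_tensor * v) {
    // a: [n, m, ...], v: [n, 1, ...] -> result [1, m, ...]
    return ggml_sum_rows(ctx0, ggml_mul(ctx0, a, v));
}

// Variant B: the same reduction as a single mat-vec product -> result [m, 1, ...]
static ggml_tensor * dot_via_mul_mat(ggml_context * ctx0, ggml_tensor * a, ggml_tensor * v) {
    return ggml_mul_mat(ctx0, a, v);
}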
@ymcki indeed your version is better :) there's like another 4% performance gain on autoregressive passes in Qwen3Next.
This code in the chunking function will cause overflow without clamping. You either have to clamp, or you have to use my mul_mat trick for an exact solution. My mul_mat trick:
@ymcki aight, migrated the KDA branch to use decay mask as well.
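For reference on the overflow point: both fixes bound the exponent that reaches ggml_exp. Clamping caps it directly, while, as far as I can tell, the decay-mask form exponentiates only the masked differences g_cumsum[i] - g_cumsum[j] for i >= j, which are non-positive for log-decay g and so stay in (0, 1]. A sketch of the clamp variant only (the helper and the cap value are illustrative, not from the PR):

static ggml_tensor * exp_clamped(ggml_context * ctx0, ggml_tensor * x, float cap) {
    // bound the exponent so ggml_exp cannot produce inf in f32; cap ~80 would be a typical choice
    return ggml_exp(ctx0, ggml_clamp(ctx0, x, -cap, cap));
}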
I think my Kimi Linear PR is almost done, so I can start working on refactoring now. Do we want to do the refactoring along with block matrix multiplication? The idea is that since we don't care about the upper triangle in Akk and Aqk, we can take bigger blocks and divide them into blocks of chunk size 64. For example, if we handle n_seq_tokens > 192, we can pad to 256 and then break it down into a 4x4 grid of 64x64 blocks. Then we only need to do mul_mat on 10/16 blocks and apply diag_mask only on the diagonal blocks, i.e. 4/16 blocks. If we only do the refactoring, then maybe only Kimi will be a few % faster. If we include block mul_mat, then both Qwen3Next and Kimi will see significant gains.
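A quick sanity check of those block counts (standalone bookkeeping only; the real change would issue one mul_mat per kept block):

#include <cstdio>

int main() {
    const int nb = 256 / 64; // 4x4 grid of 64x64 blocks for a padded 256-token chunk
    int computed = 0, diag_masked = 0;
    for (int i = 0; i < nb; ++i) {
        for (int j = 0; j < nb; ++j) {
            if (i > j)  { computed++; }                  // strictly lower block: plain mul_mat
            if (i == j) { computed++; diag_masked++; }   // diagonal block: mul_mat + diag mask
            // i < j: upper-triangle block of Akk/Aqk, skipped entirely
        }
    }
    printf("blocks computed: %d/%d, diag-masked: %d/%d\n", computed, nb * nb, diag_masked, nb * nb);
    // prints: blocks computed: 10/16, diag-masked: 4/16
    return 0;
}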
@ymcki Sure, can try, sounds like a good idea at least in theory, let's see what we can get out of this in practice.
Refactoring in preparation for #18755
Tested on CUDA - no performance regressions compared to @ngxson's optimized version.
AI Usage: yes. Opus 4.5.