Skip to content

Conversation

@pwilkin
Copy link
Collaborator

@pwilkin pwilkin commented Jan 12, 2026

Refactoring in preparation for #18755

Tested on CUDA - no performance regressions compared to @ngxson's optimized version.

AI Usage: yes. Opus 4.5.

@pwilkin pwilkin requested review from CISC, ggerganov and ngxson and removed request for CISC and ggerganov January 12, 2026 19:06
Copy link
Collaborator

@ngxson ngxson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just note that while working on #18683, I have been thinking about whether g sound be pre-broadcasted to [S_k, H_v, n_tokens, n_seqs] before entering this function (to make it the same shape as q and k). A broadcast should be fast, shouldn't hurt much performance

Probably we can play around with that idea, or you can reshape it to [1, n_tokens, H_k, n_seqs] as I suggested in the following comments

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the file name should be graph-context-delta.cpp to match the graph-context-mamba.cpp naming

g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
} else {
// GDA: g [H_v, n_tokens, n_seqs] -> [n_tokens, 1, H_k, n_seqs]
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if g is reshaped to [1, n_tokens, H_k, n_seqs], then a large part of the logic below can be reused between KDA and GDA (see comments below)

g = ggml_pad(ctx0, g, 0, pad, 0, 0);
} else {
// GDA: g shape [n_tokens, 1, H_k, n_seqs] -> pad along dim 0
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first of, I think this branch can be removed if g shape: [1, n_tokens], so we pad along dim 1

beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

// Reshape g for chunks
ggml_tensor * g_cumsum;
Copy link
Collaborator

@ngxson ngxson Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ggml_tensor * g_cumsum;
ggml_tensor * g_cumsum;
ggml_tensor * g_cumsum_t;

Since we need both versions, it can be a good idea to get the transposed version right here.

For the GDA branch, a transpose will be a simple reshape as the first dim is [1, n_tokens], so no need for ggml_cont

In other words, given a tensor A with shape: [n, 1, ...], then A.view(1, n, ...) == A^T

// Cumsum along chunk_size dimension (ne[1])
// GGML cumsum operates on ne[0], so we need to transpose, cumsum, transpose back
g = ggml_cont(ctx0, ggml_transpose(ctx0, g)); // [chunk_size, S_k, n_chunks, H_k * n_seqs]
g_cumsum = ggml_cumsum(ctx0, g);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a quick note-to-self, but probably we need to support ggml_cumsum column-wise version, that should eliminate some transposes in the future. Or another idea, support non-cont tensors in ggml_cumsum

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be like the cumsum in pytorch that you can specify which dimension to cumsum.

// GDA: Use decay mask approach (g broadcasts over K dimension)
// g_cumsum [chunk_size, 1, n_chunks, H_v * n_seqs]
ggml_tensor * gcs_i = g_cumsum;
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
Copy link
Collaborator

@ngxson ngxson Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this gcs_j should be equivalent to g_cumsum_t (or just g_cumsum, depending on what shape of g you consider to be the transposed version)

then g_exp_pos = ggml_exp(ctx0, g_cumsum_t) can be computed directly here

Comment on lines 251 to 258
if (is_kda) {
// KDA: Reuse g_exp_pos computed earlier
gexp = g_exp_pos;
} else {
// GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
gexp = ggml_exp(ctx0, g_cumsum_t);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be removed when you apply my last trick above

ggml_tensor * g_diff = ggml_sub(ctx0, g_last_broadcast, g_cumsum);
g_diff_exp = ggml_exp(ctx0, g_diff);
} else {
// GDA: g_cumsum [chunk_size, 1, n_chunks, H_k * n_seqs]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure, but seems like this can be removed too, as we now have both g_cumsum and g_cumsum_t that you can play with

} else {
// GDA: g_last_exp [1, 1, n_chunks, H_k * n_seqs]
// Broadcasts over both K and V dimensions
gexp_last_chunk = ggml_reshape_4d(ctx0, gexp_last_chunk,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can avoid this branching if g_last_exp is already boardcasted

@github-actions github-actions bot added the model Model specific label Jan 12, 2026
@ymcki
Copy link
Contributor

ymcki commented Jan 12, 2026

Thanks for your refactoring effort. I think my kda_autoregressive is better implemented as I used mul_mat to replace sum_rows. If we refactor, the new function should be based on kda_autoregressive.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 13, 2026

@ymcki indeed your version is better :) there's like another 4% performance gain on autoregressive passes in Qwen3Next.

@ymcki
Copy link
Contributor

ymcki commented Jan 13, 2026

This code in the chunking function will cause overflow without clamping. You either have to clamp or you have to use my mul_mat trick for exact solution.

        g_exp_pos = ggml_exp(ctx0, g_cumsum);
        g_exp_neg = ggml_exp(ctx0, ggml_neg(ctx0, g_cumsum));
        ggml_tensor * k_pos_beta = ggml_mul(ctx0, k_beta, g_exp_pos);
        ggml_tensor * k_neg = ggml_mul(ctx0, k, g_exp_neg);
        k_decay = ggml_mul_mat(ctx0, k_pos_beta, k_neg);

My mul_mat trick:

    const int64_t CHB = n_chunks * H_k * n_seqs;
    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB]

    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
    // decay_mask [chunk_size,chunk_size,S_k,CHB]
    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
    cb(decay_mask, "decay_mask", il);

    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    cb(decay_mask, "decay_masked", il);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);

    ggml_tensor * k_i = ggml_cont(ctx0, ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB));
    ggml_tensor * k_j = ggml_cont(ctx0, ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB));
    ggml_tensor * q_i = ggml_cont(ctx0, ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB));

    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);

    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 13, 2026

@ymcki aight, migrated the KDA branch to use decay mask as well.

@ymcki
Copy link
Contributor

ymcki commented Jan 16, 2026

I think my Kimi Linear PR is almost done, so I can start working on refactoring now.

Do we want to do refactoring along with block matrix multiplication?

The idea is that since we don't care about the upper triangle in Akk and Aqk, so we can take bigger blocks and divide them into chunk size of 64 blocks. For example, if we handle n_seq_tokens >192, then we can pad it to 256 and then break it down to 4x4 64x64 blocks. Then we only need to do mul_mat on 10/16 blocks and apply diag_mask only on the diagonal blocks, ie 4/16 blocks.

If we only do refactoring, then maybe only Kimi will be a few % faster. If we include block mul_mat, then both Qwen3Next and Kimi will see significant gain.

@pwilkin
Copy link
Collaborator Author

pwilkin commented Jan 16, 2026

@ymcki Sure, can try, sounds like a good idea at least in theory, let's see what we can get out of this in practice.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

model Model specific

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants