
[not for land yet]: improve cuda graph support for Qwen-Image #13263

Open
vkuzo wants to merge 1 commit into huggingface:main from vkuzo:20260312_qwen_image_cuda_graphs

Conversation

@vkuzo
Contributor

@vkuzo vkuzo commented Mar 12, 2026

Summary:

Very brief writeup as I'm about to head out for the day:

  1. we want to enable cuda graphs for qwen-image + nvfp4 at small batch sizes, because without cuda graphs we are bottlenecked on CPU ops
  2. to make cuda graphs work, we need to change the modeling code a bit to match the cuda graph requirements

There is a cleaner way to make this change repo-wide without touching each model's modeling code; for now this
is just a quick hack to demonstrate the performance and accuracy win.
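As a rough sketch of the kind of modeling-code change involved (illustrative only, not the actual diff in this PR): cuda graph capture under `torch.compile(mode="reduce-overhead")` can be tripped up by forward inputs that alias caller-owned buffers, and cloning those inputs at the top of `forward()` is one way to satisfy the capture requirements. `ToyBlock` below is a hypothetical stand-in, not the Qwen-Image transformer:

```python
import torch

class ToyBlock(torch.nn.Module):
    """Hypothetical stand-in for a transformer block (not diffusers code)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Cloning breaks aliasing with the caller's buffer, which is the kind
        # of change cuda graph capture under reduce-overhead may require.
        hidden_states = hidden_states.clone()
        return self.proj(hidden_states)

block = ToyBlock()
out = block(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 8])
```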

Test Plan:

Use a modified version of @sayakpaul's script: https://gist.github.com/vkuzo/acac22c62404c89db2dcf195a64543db

Then run it and observe the nvfp4 + bsz 1 time on Qwen-Image improve by ~1.6x, from 9.5s to 5.9s:

```
// baseline

(pt_nightly) dev@gpu-dev-6c281422:~/tmp$ python 20260212_diffuser_nvfp4.py --compile True --torch_compile_mode reduce-overhead
...
======================================================================
SUMMARY
======================================================================
Quantization: None
Compile: True
Batch size: 1
Latency: 7.461s
Peak Memory: 62.21 GB

// nvfp4 dynamic, torch.compile default

(pt_nightly) dev@gpu-dev-6c281422:~/tmp$ python 20260212_diffuser_nvfp4.py --compile True --quant dynamic --use_filter_fn True
...
======================================================================
SUMMARY
======================================================================
Quantization: dynamic
Compile: True
Batch size: 1
Latency: 9.536s
Peak Memory: 52.45 GB
======================================================================

// nvfp4 dynamic, torch.compile reduce-overhead (for cuda graphs)

(pt_nightly) dev@gpu-dev-6c281422:~/tmp$ python 20260212_diffuser_nvfp4.py --compile True --quant dynamic --use_filter_fn True --torch_compile_mode reduce-overhead
...
======================================================================
SUMMARY
======================================================================
Quantization: dynamic
Compile: True
Batch size: 1
Latency: 5.936s
Peak Memory: 52.45 GB
======================================================================
```
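For reference, the ~1.6x claim follows directly from the two nvfp4 latencies above:

```python
# Latencies (seconds) copied from the SUMMARY blocks above
nvfp4_default = 9.536          # torch.compile, default mode
nvfp4_reduce_overhead = 5.936  # reduce-overhead (cuda graphs)

speedup = nvfp4_default / nvfp4_reduce_overhead
print(f"{speedup:.2f}x")  # 1.61x
```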


@sayakpaul
Member

Thanks for this PR! Do we know how clone() helps the NVFP4 case but not the others like BF16?

> There is a cleaner way to do this change repo-wide without having to change each model's modeling code,

What would you recommend for this? Should hidden_states and encoder_hidden_states enter forward() already cloned?

@sayakpaul sayakpaul requested a review from yiyixuxu March 13, 2026 02:08
@sayakpaul
Member

If we want to keep the modeling code unchanged, the following could be another approach, I guess?

```
def _clone_inputs_hook(module, args, kwargs):
    # Clone tensor inputs so forward() never aliases caller-owned buffers
    args = tuple(a.clone() if isinstance(a, torch.Tensor) else a for a in args)
    kwargs = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return args, kwargs

transformer.register_forward_pre_hook(_clone_inputs_hook, with_kwargs=True)
```
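A quick sanity check that such a pre-hook really swaps in cloned tensors before `forward()` runs (toy `Linear` module; the second hook only records what `forward()` would receive, since pre-hooks run in registration order):

```python
import torch

def _clone_inputs_hook(module, args, kwargs):
    # Clone every tensor input before it reaches forward()
    args = tuple(a.clone() if isinstance(a, torch.Tensor) else a for a in args)
    kwargs = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    return args, kwargs

seen = {}

def _record_inputs_hook(module, args, kwargs):
    # Runs after the clone hook, so it sees what forward() sees
    seen["input"] = args[0]
    return None

lin = torch.nn.Linear(4, 4)
lin.register_forward_pre_hook(_clone_inputs_hook, with_kwargs=True)
lin.register_forward_pre_hook(_record_inputs_hook, with_kwargs=True)

x = torch.randn(2, 4)
lin(x)
# forward() saw a clone, not the caller's buffer
print(seen["input"].data_ptr() != x.data_ptr())  # True
```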
