Add JiT model and pipeline #13098
|
@claude-2-serge could you do a review? |
🤗 Serge says:
This PR adds the JiT (Just image Transformer) model and pipeline for pixel-space diffusion. While the overall structure is reasonable and the math in the pipeline's velocity conversion checks out, there are several significant issues that need to be addressed before merging.
Correctness
- Manual attention instead of F.scaled_dot_product_attention: JiTAttention implements attention manually with q @ k.transpose(-2, -1) * scale followed by softmax. Virtually every transformer model in diffusers uses F.scaled_dot_product_attention (SDPA), which enables flash attention and memory-efficient attention and is the expected pattern. This is a blocking issue: the manual implementation will be significantly slower and use more memory at realistic resolutions.
- Unused import in __init__.py: is_transformers_available is imported in src/diffusers/pipelines/jit/__init__.py but never used. JiT doesn't depend on transformers.
- latent_model_input variable is misleading: The pipeline operates in pixel space (no VAE), but the variable is named latent_model_input. This is confusing; it should be something like sample or image.
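To make the SDPA point concrete, here is a minimal sketch comparing a manual attention computation of the kind described above with F.scaled_dot_product_attention. The tensor shapes and the manual_attention helper are hypothetical and are not taken from the PR; the sketch only illustrates that the two paths should agree numerically when no mask or dropout is involved.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes (not from the PR): batch 2, 4 heads, 16 tokens, head dim 8.
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))


def manual_attention(q, k, v):
    # Hand-written attention: scaled scores, softmax over keys, weighted sum of values.
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return attn @ v


manual_out = manual_attention(q, k, v)
# SDPA uses the same default scale (1 / sqrt(head_dim)) and can dispatch to fused kernels.
sdpa_out = F.scaled_dot_product_attention(q, k, v)

max_abs_diff = (manual_out - sdpa_out).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The manual path materializes the full token-by-token attention matrix, which is the memory cost the review is concerned about at realistic resolutions.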
Style / Conventions
- Missing class docstring on JiTTransformer2DModel: Other transformer models in diffusers (e.g., DiTTransformer2DModel) have docstrings documenting all parameters. This model has none.
- Old gradient checkpointing pattern: The create_custom_forward closure pattern is outdated. Modern diffusers models pass the module directly to torch.utils.checkpoint.checkpoint.
- numpy used for positional embeddings: The get_2d_sincos_pos_embed family of functions uses numpy. While functional, pure-torch implementations are preferred in newer models.
Dead Code Analysis
- Under the default config (in_context_len=32, in_context_start=4), the feat_rope (without in-context tokens) is used for blocks 0–3, and feat_rope_incontext for blocks 4+. Both RoPE modules are exercised.
- JiTRotaryEmbedding accepts freqs_for options "pixel" and "constant" but only "lang" (the default) is ever used. These branches appear to be dead code copied from the reference implementation.
- JiTLabelEmbedder allocates num_classes + 1 embeddings; the +1 slot is used for CFG null class in the pipeline, so this is exercised.
Tests
- Pipeline tests skip test_model_cpu_offload_forward_pass and test_cpu_offload_forward_pass_twice with a reasonable justification (single-model pipeline).
- Missing tests/pipelines/jit/__init__.py file (needed for test discovery in some configurations).
16 LLM turns · 22 tool calls · 124.5s · 459763 in / 6089 out tokens
|
@sayakpaul I have tested the code. A big issue is that currently we use |
|
Would you like to open a PR to https://github.com/AlanPonnachan/diffusers/tree/add-jit-diffusion? @AlanPonnachan would that be okay with you? This way, it stays collaborative :) |
|
That’s okay, thank you. In the meantime, I’ll try to resolve these issues |
I think there may be a misunderstanding here regarding the math, as FlowMatchEulerDiscreteScheduler is mathematically compatible with JiT’s Euler ODE formulation. I have gone through the paper and official code.
1. The Forward Process (Linear Interpolation)
t = self.sample_t(...)
e = torch.randn_like(x) * self.noise_scale
z = t * x + (1 - t) * e
2. Vector Field Conversion
x_cond = self.net(z, t.flatten(), labels)
v_cond = (x_cond - z) / (1.0 - t).clamp_min(self.t_eps)
In this PR (https://github.com/AlanPonnachan/diffusers/blob/26b276065f3cfcbfa75a62971ac3999bfd8d69c3/src/diffusers/pipelines/jit/pipeline_jit.py#L138-L147), we do exactly the same algebraic conversion, including the clamping, so the scheduler can perform an Euler update with the supplied vector field:
# Predict x
noise_pred_x = self.transformer(
    model_input, timestep=timesteps_tensor, class_labels=class_labels_input
).sample
# Compute velocity v = (x - z) / (1 - t) = (x - z) / sigma
sigma_clamped = max(sigma.item(), 1e-5)
# The scheduler expects (z - x) / sigma to move towards x when integrating with negative dt
v_pred_all = (model_input - noise_pred_x) / sigma_clamped
So while the scheduler uses different parameterization conventions internally, the resulting update reproduces the same Euler update direction under the scheduler’s descending sigma convention.
Official JiT defaults to Heun. Should we also expose a Heun scheduler? Let me know if I am missing something! |
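The sign argument above can be checked with scalars. The sketch below is a toy, not the pipeline code: it assumes sigma = 1 - t, assumes the scheduler's Euler update has the form z + (sigma_next - sigma) * v, and replaces the network with an oracle that always returns the clean value. The function names official_step and scheduler_style_step are hypothetical.

```python
def official_step(z, x_pred, t, t_next, t_eps=1e-5):
    # Reference-style convention: t increases from 0 (noise) to 1 (data).
    v = (x_pred - z) / max(1.0 - t, t_eps)
    return z + (t_next - t) * v


def scheduler_style_step(z, x_pred, sigma, sigma_next, eps=1e-5):
    # Scheduler-style convention (assumed): sigma decreases from 1 to 0,
    # so the step size is negative and the velocity sign is flipped to compensate.
    v = (z - x_pred) / max(sigma, eps)
    return z + (sigma_next - sigma) * v


x_target = 0.8       # toy "clean" value returned by the oracle x-predictor
z_a = z_b = -1.5     # same starting noise sample for both conventions
num_steps = 10
ts = [i / num_steps for i in range(num_steps + 1)]

for t, t_next in zip(ts[:-1], ts[1:]):
    z_a = official_step(z_a, x_target, t, t_next)
    z_b = scheduler_style_step(z_b, x_target, 1.0 - t, 1.0 - t_next)

print(z_a, z_b)
```

Under these assumptions the two negations cancel, so both trajectories coincide up to floating-point rounding and land on the oracle's clean value.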
|
@AlanPonnachan I used FlowMatchEulerDiscreteScheduler for sampling with my converted ckpt at https://huggingface.co/BiliSakura/JiT-diffusers, and the result is not a desirable image. I also tried FlowMatchHeunDiscreteScheduler, and that result is not desirable either. Maybe there is something I missed in how these schedulers are parameterized. Can you test inference with my diffusers-style ckpt at https://huggingface.co/BiliSakura/JiT-diffusers? I already simplified the JiTTransformer2D implementation to match yours (and already fixed the issues mentioned above by the GitHub bot), but it still uses the source-code scheduler implementation (not the diffusers built-in). I am not very familiar with diffusion schedulers and sampling, so I may have made mistakes. |
It works now! I set t_ep=5e-2 and shift=4 on FlowMatchHeunDiscreteScheduler, testing on my https://huggingface.co/BiliSakura/JiT-diffusers with the JiTTransformer2D proposed by @AlanPonnachan. So far, my HF JiT model repo is self-contained and works as a standard custom diffusion pipeline. I can assist if JiT is implemented in the main branch of diffusers. |
|
Another bug found: if we set num_inference_steps to a large number, such as over 250, with FlowMatchHeunDiscreteScheduler, the resulting image looks broken. It is caused by running in bf16; I solved it by using fp32. |
|
pushed a small fix for the bad-images thing @AlanPonnachan @Bili-Sakura. it wasn't the scheduler or the math (those are fine in fp32), it's bf16: JiT predicts clean the official repo keeps the sample in fp32 and only runs the model under bf16 autocast, so i did the diffusers version of that: keep verified on your B-16 ckpt: bf16 at 300 steps is clean now, and the model/pipeline tests + torch.compile pass. the model-side bits serge flagged (sdpa, docstring, dead |
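A toy illustration of why many small steps can go wrong when the sample itself is held in bf16. This is not the pipeline code: the integrate helper, the unit velocity, and the starting value are all hypothetical. The only thing varied is the dtype of the running sample, while the stand-in "model output" is bf16 in both cases.

```python
import torch


def integrate(sample_dtype, num_steps=300):
    # The running sample is kept in `sample_dtype`; the true answer is 1.0 + 1.0 = 2.0.
    sample = torch.tensor(1.0, dtype=sample_dtype)
    dt = 1.0 / num_steps
    for _ in range(num_steps):
        # Stand-in for a model that ran in bf16 and returned a unit velocity.
        velocity = torch.tensor(1.0, dtype=torch.bfloat16)
        sample = sample + dt * velocity.to(sample_dtype)
    return sample.float().item()


fp32_result = integrate(torch.float32)
bf16_result = integrate(torch.bfloat16)
print(f"fp32 sample: {fp32_result:.4f}, bf16 sample: {bf16_result:.4f}")
```

With 300 steps each increment is smaller than half the bf16 spacing near 1.0, so the bf16 sample fails to accumulate the updates, while the fp32 sample ends close to 2.0. Keeping the sample in fp32 and casting only around the model call avoids this in the toy setting.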
|
also pushed the model-side cleanups @AlanPonnachan: moved verified on the B-16 ckpt: same image as before (parity held), model+pipeline tests pass (70), and |
|
Thank you @kashif. I have pushed changes that use torch implementations of the positional embeddings instead of numpy, as recommended by Serge. |
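For readers unfamiliar with the pattern, a pure-torch sin-cos positional embedding can be sketched as below. This is an illustrative sketch, not the code pushed to the PR: the function names sincos_1d and sincos_2d, the float32 dtype, and the height-then-width concatenation order are assumptions and may not match the reference implementation.

```python
import torch


def sincos_1d(embed_dim, positions):
    # Frequencies follow the usual 1 / 10000^(2i / d) schedule.
    omega = torch.arange(embed_dim // 2, dtype=torch.float32) / (embed_dim / 2)
    omega = 1.0 / (10000 ** omega)
    angles = positions.reshape(-1).to(torch.float32)[:, None] * omega[None, :]
    # Shape: (num_positions, embed_dim), sines first, then cosines.
    return torch.cat([angles.sin(), angles.cos()], dim=1)


def sincos_2d(embed_dim, grid_size):
    # Half of the channels encode the row index, the other half the column index.
    coords = torch.arange(grid_size, dtype=torch.float32)
    grid_h, grid_w = torch.meshgrid(coords, coords, indexing="ij")
    emb_h = sincos_1d(embed_dim // 2, grid_h)
    emb_w = sincos_1d(embed_dim // 2, grid_w)
    return torch.cat([emb_h, emb_w], dim=1)


pos_embed = sincos_2d(embed_dim=16, grid_size=4)
print(pos_embed.shape)
```

Because everything stays in torch, the table can be built directly on the model's device without a numpy round trip.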
|
@claude could you do a review? |
|
Claude finished @sayakpaul's task in 3m 16s
Review of JiT Model and Pipeline
Overall the implementation is clean and well-structured after the recent rounds of fixes (SDPA via
Blocking
1.
omega = torch.arange(embed_dim // 2, dtype=torch.float64)
Per models.md gotcha #5,
2. Missing
Every transformer in diffusers declares
_no_split_modules = ["JiTBlock"]
3. Missing
Needed for
_repeated_blocks = ["JiTBlock"]
4. Gradient checkpointing guard uses
# current
if self.training and self.gradient_checkpointing:
# should be (matches Flux, Wan, QwenImage, etc.)
if torch.is_grad_enabled() and self.gradient_checkpointing:
Per models.md: "add
Non-blocking
5. Unnecessary PyTorch 2.0 check
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("JiTAttnProcessor requires PyTorch 2.0...")
Diffusers' minimum PyTorch version already requires 2.0+. No other model processor in the repo does this check. Can be removed.
6. This 19-line function is called exactly once (line 72) to concatenate two tensors. Per AGENTS.md: "Prefer inlining small helper/utility functions." The single call site is:
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
Which is just
7.
sigma = sigmas[i].to(torch.float32)
Accessing sigmas by the loop variable
8.
Suggestions / Dead Code Analysis
I traced the call path from
Under the default config, no dead methods or unused layers were found. All code paths are reachable.
Summary: 4 blocking items (float64, two missing class attributes, gradient checkpointing guard) and 4 non-blocking suggestions. The model architecture, attention pattern, pipeline math, and test coverage all look good. |
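The gradient checkpointing guard from item 4 can be sketched on a toy module. TinyBlockStack and its call counter are hypothetical and exist only for illustration; diffusers models route this through their own helper rather than calling torch.utils.checkpoint.checkpoint inline, so treat this as a sketch of the guard condition only.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyBlockStack(nn.Module):
    # Hypothetical stand-in for a transformer with repeated blocks.
    def __init__(self, dim=8, depth=2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
        self.gradient_checkpointing = True
        self.checkpointed_calls = 0  # counts how often the checkpoint path is taken

    def forward(self, x):
        for block in self.blocks:
            # Guard on whether autograd is recording, not on train/eval mode.
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                self.checkpointed_calls += 1
                # The module is passed directly; no custom-forward closure is needed.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


model = TinyBlockStack().eval()  # eval mode, yet gradients are still wanted
x = torch.randn(4, 8, requires_grad=True)
model(x).sum().backward()
calls_with_grad = model.checkpointed_calls

with torch.no_grad():  # inference: the checkpoint path is skipped
    model(x)
calls_after_no_grad = model.checkpointed_calls
print(calls_with_grad, calls_after_no_grad)
```

A guard based on self.training would skip checkpointing in the first call here, because the module is in eval mode even though a backward pass follows.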
Fixes #13070
Description
This PR adds the JiT (Just image Transformer) model and pipeline for pixel-space diffusion, as requested in issue #13070. JiT is a diffusion transformer that operates directly on pixel patches without a VAE, using Flow Matching for generation.
Reference: JiT: Just image Transformer for Pixel-space Diffusion
Changes:
- JiTTransformer2DModel in src/diffusers/models/transformers/jit_transformer_2d.py.
- JiTPipeline in src/diffusers/pipelines/jit/pipeline_jit.py.
- tests/models/transformers/test_models_jit_transformer_2d.py.
- tests/pipelines/jit/test_jit.py.
The implementation includes support for:
- FlowMatchEulerDiscreteScheduler compatibility.
Who can review?
@sayakpaul @kashif