Skip to content

Add JiT model and pipeline#13098

Open
AlanPonnachan wants to merge 9 commits into
huggingface:mainfrom
AlanPonnachan:add-jit-diffusion
Open

Add JiT model and pipeline#13098
AlanPonnachan wants to merge 9 commits into
huggingface:mainfrom
AlanPonnachan:add-jit-diffusion

Conversation

@AlanPonnachan
Copy link
Copy Markdown
Contributor

Fixes #13070

Description

This PR adds the JiT (Just image Transformer) model and pipeline for pixel-space diffusion, as requested in issue #13070. JiT is a diffusion transformer that operates directly on pixel patches without a VAE, using Flow Matching for generation.

Reference: JiT: Just image Transformer for Pixel-space Diffusion

Changes:

  • Added JiTTransformer2DModel in src/diffusers/models/transformers/jit_transformer_2d.py.
  • Added JiTPipeline in src/diffusers/pipelines/jit/pipeline_jit.py.
  • Added model tests in tests/models/transformers/test_models_jit_transformer_2d.py.
  • Added pipeline tests in tests/pipelines/jit/test_jit.py.

The implementation includes support for:

  • 2D Sin-Cos positional embeddings.
  • In-context conditioning (appending class labels as tokens).
  • Custom Rotary Positional Embeddings (RoPE) adapted for 2D image grids.
  • FlowMatchEulerDiscreteScheduler compatibility.

Before submitting

Who can review?

@sayakpaul @kashif

@sayakpaul sayakpaul requested a review from kashif February 8, 2026 08:30
@sayakpaul
Copy link
Copy Markdown
Member

@claude-2-serge could you do a review?

Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤗 Serge says:

This PR adds the JiT (Just image Transformer) model and pipeline for pixel-space diffusion. While the overall structure is reasonable and the math in the pipeline's velocity conversion checks out, there are several significant issues that need to be addressed before merging.

Correctness

  • Manual attention instead of F.scaled_dot_product_attention: JiTAttention implements attention manually with q @ k.transpose(-2, -1) * scale followed by softmax. Virtually every transformer model in diffusers uses F.scaled_dot_product_attention (SDPA), which enables flash attention, memory-efficient attention, and is the expected pattern. This is a blocking issue — the manual implementation will be significantly slower and use more memory at realistic resolutions.

  • Unused import in __init__.py: is_transformers_available is imported in src/diffusers/pipelines/jit/__init__.py but never used. JiT doesn't depend on transformers.

  • latent_model_input variable is misleading: The pipeline operates in pixel space (no VAE), but the variable is named latent_model_input. This is confusing — it should be something like sample or image.

Style / Conventions

  • Missing class docstring on JiTTransformer2DModel: Other transformer models in diffusers (e.g., DiTTransformer2DModel) have docstrings documenting all parameters. This model has none.

  • Old gradient checkpointing pattern: The create_custom_forward closure pattern is outdated. Modern diffusers models pass the module directly to torch.utils.checkpoint.checkpoint.

  • numpy used for positional embeddings: The get_2d_sincos_pos_embed family of functions uses numpy. While functional, pure-torch implementations are preferred in newer models.

Dead Code Analysis

  • Under the default config (in_context_len=32, in_context_start=4), the feat_rope (without in-context tokens) is used for blocks 0–3, and feat_rope_incontext for blocks 4+. Both RoPE modules are exercised.
  • JiTRotaryEmbedding accepts freqs_for options "pixel" and "constant" but only "lang" (the default) is ever used. These branches appear to be dead code copied from the reference implementation.
  • JiTLabelEmbedder allocates num_classes + 1 embeddings; the +1 slot is used for CFG null class in the pipeline, so this is exercised.

Tests

  • Pipeline tests skip test_model_cpu_offload_forward_pass and test_cpu_offload_forward_pass_twice with a reasonable justification (single-model pipeline).
  • Missing tests/pipelines/jit/__init__.py file (needed for test discovery in some configurations).

16 LLM turns · 22 tool calls · 124.5s · 459763 in / 6089 out tokens

Comment thread src/diffusers/models/transformers/jit_transformer_2d.py Outdated
Comment thread src/diffusers/models/transformers/jit_transformer_2d.py Outdated
Comment thread src/diffusers/models/transformers/jit_transformer_2d.py Outdated
Comment thread src/diffusers/models/transformers/jit_transformer_2d.py Outdated
Comment thread src/diffusers/pipelines/jit/pipeline_jit.py Outdated
Comment thread src/diffusers/pipelines/jit/__init__.py Outdated
Comment thread src/diffusers/pipelines/jit/pipeline_jit.py Outdated
Comment thread src/diffusers/pipelines/jit/pipeline_jit.py Outdated
Comment thread src/diffusers/models/transformers/jit_transformer_2d.py Outdated
Comment thread src/diffusers/pipelines/jit/pipeline_jit.py Outdated
@Bili-Sakura
Copy link
Copy Markdown

@sayakpaul I have test the code, a big issue is that currently we use FlowMatchEulerDiscreteScheduler scheduler which not align with JiT source paper.

@sayakpaul
Copy link
Copy Markdown
Member

Would you like to open a PR to https://github.com/AlanPonnachan/diffusers/tree/add-jit-diffusion? @AlanPonnachan would that be okay with you?

This way, it stays collaborative :)

@AlanPonnachan
Copy link
Copy Markdown
Contributor Author

That’s okay, thank you. In the meantime, I’ll try to resolve these issues

@AlanPonnachan
Copy link
Copy Markdown
Contributor Author

@sayakpaul I have test the code, a big issue is that currently we use FlowMatchEulerDiscreteScheduler scheduler which not align with JiT source paper.

I think there may be a misunderstanding here regarding the math, as FlowMatchEulerDiscreteScheduler is mathematically compatible with JiT’s Euler ODE formulation.

I have gone through the paper and official code.

1. The Forward Process (Linear Interpolation)
In https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/denoiser.py#L52-L55, JiT defines its noise interpolation as a straight line:

t = self.sample_t(...)
e = torch.randn_like(x) * self.noise_scale
z = t * x + (1 - t) * e

2. Vector Field Conversion
In https://github.com/LTH14/JiT/blob/cbc743a2ada5e9762697da2c83f8c4f8379e8c17/denoiser.py#L93-L94, JiT repo manually converts the predicted $x_0$ into a vector field for the Euler step, importantly clamping the denominator:

x_cond = self.net(z, t.flatten(), labels)
v_cond = (x_cond - z) / (1.0 - t).clamp_min(self.t_eps)

In this PR https://github.com/AlanPonnachan/diffusers/blob/26b276065f3cfcbfa75a62971ac3999bfd8d69c3/src/diffusers/pipelines/jit/pipeline_jit.py#L138-L147, we do exactly the same algebraic conversion—including the clamping—so the scheduler can perform an Euler update with the supplied vector field:

# Predict x
noise_pred_x = self.transformer(
    model_input, timestep=timesteps_tensor, class_labels=class_labels_input
).sample

# Compute velocity v = (x - z) / (1 - t) = (x - z) / sigma
sigma_clamped = max(sigma.item(), 1e-5)

# The scheduler expects (z - x) / sigma to move towards x when integrating with negative dt
v_pred_all = (model_input - noise_pred_x) / sigma_clamped

So while the scheduler uses different parameterization conventions internally, the resulting update reproduces the same Euler update direction under the scheduler’s descending sigma convention

Official JiT defaults to Heun. Should we also expose a Heun scheduler ?

Let me know if something I am missing!

@Bili-Sakura
Copy link
Copy Markdown

Bili-Sakura commented May 22, 2026

@AlanPonnachan I use FlowMatchEulerDiscreteScheduler for sampling with my converted ckpt as https://huggingface.co/BiliSakura/JiT-diffusers , the result is not a desirable image. I seems to try FlowMatchHeunDiscreteScheduler as well which is also not desirable. Maybe there is something I missed in parameterized in these scheduler.

Can you try test inference with my diffusers style ckpt at https://huggingface.co/BiliSakura/JiT-diffusers. I already simplified the JiTTransformer2D implementation the same as yours (and already fix the issue above mentioned by GitHub bot) but still with source code scheduler implementation (not diffusers built-in).

I am not quite familiar with diffusion scheduler and sampling. I may make mistakes.

@Bili-Sakura
Copy link
Copy Markdown

Bili-Sakura commented May 23, 2026

@AlanPonnachan I use FlowMatchEulerDiscreteScheduler for sampling with my converted ckpt as https://huggingface.co/BiliSakura/JiT-diffusers , the result is not a desirable image. I seems to try FlowMatchHeunDiscreteScheduler as well which is also not desirable. Maybe there is something I missed in parameterized in these scheduler.

Can you try test inference with my diffusers style ckpt at https://huggingface.co/BiliSakura/JiT-diffusers. I already simplified the JiTTransformer2D implementation the same as yours (and already fix the issue above mentioned by GitHub bot) but still with source code scheduler implementation (not diffusers built-in).

I am not quite familiar with diffusion scheduler and sampling. I may make mistakes.

@AlanPonnachan @sayakpaul

It works right! I set t_ep=5e-2 and shift=4 to FlowMatchHeunDiscreteScheduler testing on my https://huggingface.co/BiliSakura/JiT-diffusers with the JiTTransformer2D proposed by (@AlanPonnachan).

So far, my hf JiT model repo is self-contained and working as a standard custom diffusion pipeline. I can asssist if the JiT is implemented in the main branch of diffusers.

@Bili-Sakura
Copy link
Copy Markdown

Bili-Sakura commented May 29, 2026

Another bug founded, if we set num_inference_steps to large number like over 250 with FlowMatchHeunScheduler, the result image would look broken.
This maybe the inward bug of JiT itself, not our implementation.
—-

I solve it by using fp32 it is caused by running with bf16

@kashif
Copy link
Copy Markdown
Contributor

kashif commented Jun 3, 2026

pushed a small fix for the bad-images thing @AlanPonnachan @Bili-Sakura. it wasn't the scheduler or the math (those are fine in fp32), it's bf16: JiT predicts clean x and the (x - z)/(1 - t) velocity is precision-sensitive, so accumulating the sample in bf16 over many steps decays into salt-and-pepper noise (worse with more steps, which matches what you were seeing).

the official repo keeps the sample in fp32 and only runs the model under bf16 autocast, so i did the diffusers version of that: keep sample in fp32, cast to transformer.dtype only for the forward, and do the velocity + scheduler step in fp32. also clamped the 1/(1-t) denom at 0.05 (JiT's t_eps) and added a noise_scale arg for 512/1024.

verified on your B-16 ckpt: bf16 at 300 steps is clean now, and the model/pipeline tests + torch.compile pass. the model-side bits serge flagged (sdpa, docstring, dead freqs_for) are still worth doing separately.

@kashif
Copy link
Copy Markdown
Contributor

kashif commented Jun 3, 2026

also pushed the model-side cleanups @AlanPonnachan: moved JiTAttention to a proper JiTAttnProcessor going through dispatch_attention_fn (so attention backends work), added AttentionMixin + a class docstring, swapped the old gradient-checkpoint closure for self._gradient_checkpointing_func, and dropped the dead freqs_for pixel/constant branches. left the einsum unpatchify as-is since serge marked it a soft-pass and DiT/SD3 do the same.

verified on the B-16 ckpt: same image as before (parity held), model+pipeline tests pass (70), and torch.compile(fullgraph=True) works. should cover serge's comments now.

@AlanPonnachan
Copy link
Copy Markdown
Contributor Author

Thank you @kashif. I have pushed changes for torch implementations for positional embeddings instead of numpy as recommended by serge

@sayakpaul sayakpaul requested a review from dg845 June 5, 2026 07:58
@sayakpaul
Copy link
Copy Markdown
Member

@claude could you do a review?

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 5, 2026

Claude finished @sayakpaul's task in 3m 16s —— View job


Review of JiT Model and Pipeline

  • Read review guideline files (AGENTS.md, models.md, pipelines.md, modular.md)
  • Read the model implementation (jit_transformer_2d.py)
  • Read the pipeline implementation (pipeline_jit.py)
  • Read the pipeline __init__.py
  • Read the __init__.py registration files
  • Read tests
  • Trace call path from pipeline to model
  • Write review

Overall the implementation is clean and well-structured after the recent rounds of fixes (SDPA via dispatch_attention_fn, bf16 precision handling, gradient checkpointing modernization, numpy→torch for positional embeddings). A few items remain.

Blocking

1. torch.float64 in positional embedding helperjit_transformer_2d.py:318

omega = torch.arange(embed_dim // 2, dtype=torch.float64)

Per models.md gotcha #5, torch.float64 breaks on MPS, NPU, and Neuron backends. Although this runs at init time and the buffer is stored as float32, model construction can happen on any device. Should use torch.float32 — the precision difference in sinusoidal position encodings is immaterial.

Fix this →

2. Missing _no_split_modules class attributejit_transformer_2d.py:331

Every transformer in diffusers declares _no_split_modules. Without it, from_pretrained(..., device_map="auto") raises ValueError. Add:

_no_split_modules = ["JiTBlock"]

3. Missing _repeated_blocks class attributejit_transformer_2d.py:331

Needed for model.compile_repeated_blocks() support. Add:

_repeated_blocks = ["JiTBlock"]

4. Gradient checkpointing guard uses self.training instead of torch.is_grad_enabled()jit_transformer_2d.py:466

# current
if self.training and self.gradient_checkpointing:

# should be (matches Flux, Wan, QwenImage, etc.)
if torch.is_grad_enabled() and self.gradient_checkpointing:

Per models.md: "add if torch.is_grad_enabled() and self.gradient_checkpointing: branches." Every reference transformer in the repo uses torch.is_grad_enabled(), not self.training.

Non-blocking

5. Unnecessary PyTorch 2.0 checkjit_transformer_2d.py:178-179

if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("JiTAttnProcessor requires PyTorch 2.0...")

Diffusers' minimum PyTorch version already requires 2.0+. No other model processor in the repo does this check. Can be removed.

6. broadcat utility is overly generaljit_transformer_2d.py:33-51

This 19-line function is called exactly once (line 72) to concatenate two tensors. Per AGENTS.md: "Prefer inlining small helper/utility functions." The single call site is:

freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

Which is just torch.cat((freqs[:, None, :].expand(...), freqs[None, :, :].expand(...)), dim=-1). Consider inlining or at minimum removing the general error-checking overhead.

7. sigmas[i] access by loop indexpipeline_jit.py:129

sigma = sigmas[i].to(torch.float32)

Accessing sigmas by the loop variable i assumes the loop index always matches the scheduler's sigma array layout. FlowMatchEulerDiscreteScheduler appends a terminal 0 to sigmas, so this works today for indices 0..N-1, but it's fragile. Consider deriving sigma from the timestep: sigma = t / self.scheduler.config.num_train_timesteps or using self.scheduler.sigmas[self.scheduler.step_index] which is scheduler-state-aware.

8. JiTPatchEmbed stores unused attributesjit_transformer_2d.py:109-110

self.img_size and self.patch_size are stored but never read after __init__. Minor dead storage — conventional in vision models so not blocking, but worth noting.

Suggestions / Dead Code Analysis

I traced the call path from JiTPipeline.__call__ through the model's forward:

Component Status
feat_rope (blocks 0 to in_context_start-1) ✅ Exercised
feat_rope_incontext (blocks in_context_start+) ✅ Exercised
JiTLabelEmbedder null-class slot (num_classes + 1) ✅ Used for CFG
attention_dropout / dropout on middle-quarter blocks ✅ Active at training time
in_context_posemb parameter ✅ Added at block in_context_start
All forward parameters (hidden_states, timestep, class_labels) ✅ All passed by pipeline
JiTPatchEmbed.img_size, .patch_size attributes ⚠️ Stored but never read (see #8)

Under the default config, no dead methods or unused layers were found. All code paths are reachable.


Summary: 4 blocking items (float64, two missing class attributes, gradient checkpointing guard) and 4 non-blocking suggestions. The model architecture, attention pattern, pipeline math, and test coverage all look good.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add JiT: Just image Transformer (JiT) for Pixel-space Diffusion

4 participants