Skip to content

Commit b9d46ca

Browse files
committed
Harden RAE DiT conversion and pipeline helpers
1 parent afc2db7 commit b9d46ca

4 files changed

Lines changed: 84 additions & 7 deletions

File tree

scripts/convert_rae_stage2_to_diffusers.py

Lines changed: 21 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -184,6 +184,14 @@ def build_scheduler_config(config: dict[str, Any]) -> tuple[FlowMatchEulerDiscre
184184
misc = _resolve_section(config, "misc")
185185

186186
transport_params = transport.get("params", {})
187+
path_type = str(transport_params.get("path_type", "Linear"))
188+
prediction = str(transport_params.get("prediction", "velocity"))
189+
if path_type.lower() != "linear" or prediction.lower() != "velocity":
190+
raise ValueError(
191+
"Only `transport.params.path_type=Linear` with `transport.params.prediction=velocity` is "
192+
"supported by this converter because it always saves a `FlowMatchEulerDiscreteScheduler`."
193+
)
194+
187195
latent_size = misc.get("latent_size", None)
188196
if latent_size is None:
189197
raise KeyError("Config must define `misc.latent_size` for scheduler conversion.")
@@ -200,8 +208,8 @@ def build_scheduler_config(config: dict[str, Any]) -> tuple[FlowMatchEulerDiscre
200208
metadata = {
201209
"num_train_timesteps": scheduler.config.num_train_timesteps,
202210
"shift": scheduler.config.shift,
203-
"path_type": transport_params.get("path_type", "Linear"),
204-
"prediction": transport_params.get("prediction", "velocity"),
211+
"path_type": path_type,
212+
"prediction": prediction,
205213
"time_dist_type": transport_params.get("time_dist_type", "uniform"),
206214
}
207215
return scheduler, metadata
@@ -307,20 +315,27 @@ def write_metadata(output_path: Path, metadata: dict[str, Any]) -> None:
307315

308316

309317
def resolve_input_path(accessor: RepoAccessor, path: str) -> Path:
318+
expanded_path = Path(path).expanduser()
319+
if expanded_path.is_absolute():
320+
if expanded_path.is_file():
321+
return expanded_path
322+
raise FileNotFoundError(f"Absolute path does not exist: {expanded_path}")
323+
310324
candidates = [path]
311325
if path.startswith("models/"):
312326
candidates.append(path[len("models/") :])
313327

314328
for candidate in candidates:
315-
local_path = Path(candidate)
316-
if local_path.is_file():
317-
return local_path
318-
319329
try:
320330
return accessor.fetch(candidate)
321331
except FileNotFoundError:
322332
continue
323333

334+
for candidate in candidates:
335+
local_path = Path(candidate).expanduser()
336+
if local_path.is_file():
337+
return local_path
338+
324339
raise FileNotFoundError(f"Could not resolve `{path}` from `{accessor.repo_or_path}`.")
325340

326341

src/diffusers/pipelines/rae_dit/pipeline_rae_dit.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -249,7 +249,7 @@ def __call__(
249249
if output_type == "latent":
250250
output = latents
251251
else:
252-
images = self.vae.decode(latents).sample.clamp(0, 1)
252+
images = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample.clamp(0, 1)
253253
output = self.image_processor.postprocess(images, output_type=output_type)
254254

255255
self.maybe_free_model_hooks()

tests/others/test_rae_dit_conversion.py

Lines changed: 48 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -13,12 +13,18 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import os
1617
import tempfile
18+
from pathlib import Path
1719

20+
import pytest
1821
import torch
1922

2023
from diffusers import AutoencoderRAE
2124
from scripts.convert_rae_stage2_to_diffusers import (
25+
RepoAccessor,
26+
build_scheduler_config,
27+
resolve_input_path,
2228
translate_transformer_state_dict,
2329
unwrap_state_dict,
2430
)
@@ -78,6 +84,48 @@ def test_translate_transformer_state_dict_maps_gelu_keys():
7884
assert torch.equal(translated["blocks.0.mlp.net.2.weight"], fc2_weight)
7985

8086

87+
def test_build_scheduler_config_rejects_non_linear_or_non_velocity_transport():
88+
with pytest.raises(ValueError):
89+
build_scheduler_config(
90+
{
91+
"transport": {"params": {"path_type": "VP", "prediction": "velocity"}},
92+
"misc": {"latent_size": [768, 16, 16]},
93+
}
94+
)
95+
96+
with pytest.raises(ValueError):
97+
build_scheduler_config(
98+
{
99+
"transport": {"params": {"path_type": "Linear", "prediction": "epsilon"}},
100+
"misc": {"latent_size": [768, 16, 16]},
101+
}
102+
)
103+
104+
105+
def test_resolve_input_path_prefers_repo_accessor_for_relative_paths():
106+
original_cwd = Path.cwd()
107+
108+
with tempfile.TemporaryDirectory() as repo_tmpdir, tempfile.TemporaryDirectory() as cwd_tmpdir:
109+
repo_root = Path(repo_tmpdir)
110+
cwd_root = Path(cwd_tmpdir)
111+
112+
repo_config = repo_root / "configs" / "sample.yaml"
113+
repo_config.parent.mkdir(parents=True, exist_ok=True)
114+
repo_config.write_text("repo: true\n", encoding="utf-8")
115+
116+
cwd_config = cwd_root / "configs" / "sample.yaml"
117+
cwd_config.parent.mkdir(parents=True, exist_ok=True)
118+
cwd_config.write_text("cwd: true\n", encoding="utf-8")
119+
120+
os.chdir(cwd_root)
121+
try:
122+
resolved = resolve_input_path(RepoAccessor(str(repo_root)), "configs/sample.yaml")
123+
finally:
124+
os.chdir(original_cwd)
125+
126+
assert resolved == repo_config
127+
128+
81129
def test_autoencoder_rae_from_pretrained_loads_local_checkpoint():
82130
model = AutoencoderRAE(
83131
encoder_type="mae",

tests/pipelines/rae_dit/test_pipeline_rae_dit.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -188,6 +188,20 @@ def test_inference(self):
188188
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
189189
self.assertLessEqual(max_diff, 1e-4)
190190

191+
def test_inference_casts_latents_to_vae_dtype_before_decode(self):
192+
components = self.get_dummy_components()
193+
components["vae"] = components["vae"].to(dtype=torch.float64)
194+
pipe = self.pipeline_class(**components).to("cpu")
195+
pipe.set_progress_bar_config(disable=None)
196+
197+
inputs = self.get_dummy_inputs("cpu")
198+
inputs["output_type"] = "pt"
199+
200+
images = pipe(**inputs).images
201+
202+
self.assertEqual(images.shape, (1, 3, 4, 4))
203+
self.assertTrue(torch.isfinite(images).all().item())
204+
191205
def test_inference_classifier_free_guidance(self):
192206
pipe = self.pipeline_class(**self.get_dummy_components()).to("cpu")
193207
pipe.set_progress_bar_config(disable=None)

0 commit comments

Comments (0)