diff --git a/examples/dreambooth/README_ideogram4.md b/examples/dreambooth/README_ideogram4.md new file mode 100644 index 000000000000..de5fefe17171 --- /dev/null +++ b/examples/dreambooth/README_ideogram4.md @@ -0,0 +1,134 @@ +# DreamBooth training example for Ideogram 4 + +[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept. +[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with a fraction of the learnable parameters. + +`train_dreambooth_lora_ideogram4.py` shows how to implement LoRA DreamBooth training for [Ideogram 4](https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/ideogram4.md). + +> [!NOTE] +> **About the model** +> +> Ideogram 4 is a flow-matching text-to-image model with a few characteristics that are relevant for training: +> - It uses **two** transformers β€” a text-conditional `transformer` and an `unconditional_transformer` blended at inference via asymmetric classifier-free guidance. This trainer adds LoRA to the **conditional `transformer` only**; the unconditional one stays frozen. +> - Text conditioning comes from a **Qwen3-VL** multimodal text encoder (a fixed set of decoder layers is concatenated into the per-token features). + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/dreambooth` folder and run + +```bash +pip install -r requirements_ideogram4.txt +``` + +Initialize an [πŸ€— Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config default +``` + +We use the PEFT library as the backend for LoRA training; make sure `peft>=0.11.1` is installed. + +### Quantized (nf4) base β€” QLoRA + +Ideogram 4 is a large model, so a pre-quantized **nf4** checkpoint (`bitsandbytes`) is a convenient base for low-memory LoRA training. When the base checkpoint is already quantized, the trainer detects it automatically β€” you do **not** need to pass `--bnb_quantization_config_path` (that flag is for quantizing a full-precision checkpoint on the fly). The LoRA adapter is trained on top of the frozen 4-bit base (QLoRA) and saved in full precision. + +## Prompt format + +> [!IMPORTANT] +> Ideogram 4 is trained on structured **JSON captions** β€” a single-line JSON object that exhaustively describes the image β€” rather than free-form text. Plain text works, but the model understands the JSON structure natively, so captions in the schema generally train and generate best. + +A caption is a JSON object; commonly used fields (see the upstream [ideogram-oss/ideogram4](https://github.com/ideogram-oss/ideogram4) prompt docs for the full schema) include: +- `high_level_description` β€” a one-line summary of the whole image. +- `compositional_deconstruction` β€” spatial layout, with a `background` string and an `elements` array; each element has a `type` (e.g. `"obj"`, `"text"`) and a `desc`. +- `colour_palette` β€” an array of hex colors to steer the image's color scheme. +- `bbox` β€” bounding-box coordinates for explicit placement of subjects, text, and background regions. + +For best results, make each training caption describe its image as exhaustively as the schema allows. + +For `--caption_column` / `--instance_prompt` (and at inference): +- **Recommended:** provide captions already in Ideogram 4's JSON caption schema. +- Or pass `--upsample_prompt` to rewrite free-form captions into the JSON schema during caching. This loads the prompt-enhancer LM head (`--prompt_enhancer_head_id`, default [`diffusers/qwen3-vl-8b-instruct-lm-head`](https://huggingface.co/diffusers/qwen3-vl-8b-instruct-lm-head)) as the pipeline's `prompt_enhancer_head`; install `outlines` for schema-constrained output. +- At inference, pass a short prompt with `prompt_upsampling=True` to rewrite it into the schema. + +## Training example + +For this example we use the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset: + +```bash +export MODEL_NAME="ideogram-ai/ideogram-v4" +export OUTPUT_DIR="trained-ideogram4-lora" +# Ideogram 4 expects a structured JSON caption (see "Prompt format" above). +export INSTANCE_PROMPT='{"high_level_description":"A puppy in a soft yarn-art style","compositional_deconstruction":{"background":"a plain cream studio backdrop","elements":[{"type":"obj","desc":"a small fluffy puppy crocheted from multicolored yarn, sitting upright and facing the viewer"}]}}' + +accelerate launch train_dreambooth_lora_ideogram4.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name="Norod78/Yarn-art-style" \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="$INSTANCE_PROMPT" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --rank=16 \ + --optimizer="adamw" \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --mixed_precision="bf16" \ + --seed="0" +``` + +To track training with Weights & Biases add `--report_to="wandb"`, and to add periodic samples add `--validation_prompt="$INSTANCE_PROMPT" --validation_epochs=25` (a JSON caption, like the training prompt). + +> [!NOTE] +> By default the LoRA weights are saved locally to `--output_dir`. To upload them to the Hub, add `--push_to_hub` (and `--hub_model_id`). Keep private datasets/LoRAs in private repos. + +## Memory optimizations + +Many of these can be combined: + +- `--cache_latents` β€” pre-encode images with the VAE, then free it. +- `--offload` β€” offload the VAE / text encoder to CPU when not in use. +- `--gradient_accumulation_steps` β€” accumulate gradients to use a smaller effective batch. +- `--gradient_checkpointing` β€” recompute activations in the backward pass to save memory (slower). +- `--use_8bit_adam` β€” 8-bit AdamW optimizer (`bitsandbytes`); only applies to the `adamw` optimizer. +- `--resolution` β€” lower the training resolution (images are resized/cropped to this). Must be a multiple of 16; Ideogram 4 supports 256–2048. +- `--rank` β€” lower the LoRA rank for fewer trainable parameters. + +### Precision of saved LoRA layers + +By default the trained LoRA layers are saved in the training precision (e.g. `bf16` with `--mixed_precision="bf16"`). Pass `--upcast_before_saving` to save them in `float32` instead. + +## Inference + +After training, load the base pipeline and your LoRA: + +```python +import torch +from diffusers import Ideogram4Pipeline + +pipeline = Ideogram4Pipeline.from_pretrained("ideogram-ai/ideogram-v4", torch_dtype=torch.bfloat16) +pipeline.to("cuda") +pipeline.load_lora_weights("trained-ideogram4-lora", weight_name="pytorch_lora_weights.safetensors") + +# Ideogram 4 expects a structured JSON caption (or pass a short prompt with prompt_upsampling=True). +prompt = '{"high_level_description":"A puppy in a soft yarn-art style","compositional_deconstruction":{"background":"a plain cream studio backdrop","elements":[{"type":"obj","desc":"a small fluffy puppy crocheted from multicolored yarn, sitting upright and facing the viewer"}]}}' +image = pipeline(prompt, height=1024, width=1024).images[0] +image.save("ideogram4_lora.png") +``` + +Ideogram 4 uses a guidance *schedule* by default; to use a constant scale instead, pass `guidance_scale=, guidance_schedule=None` (exactly one of the two must be set, and a `guidance_schedule` must have length `num_inference_steps`). diff --git a/examples/dreambooth/requirements_ideogram4.txt b/examples/dreambooth/requirements_ideogram4.txt new file mode 100644 index 000000000000..24b2084e4bc5 --- /dev/null +++ b/examples/dreambooth/requirements_ideogram4.txt @@ -0,0 +1,9 @@ +accelerate>=1.13.0 +torchvision +transformers>=5.6 +ftfy +tensorboard +Jinja2 +peft>=0.18.1 +sentencepiece +bitsandbytes diff --git a/examples/dreambooth/train_dreambooth_lora_ideogram4.py b/examples/dreambooth/train_dreambooth_lora_ideogram4.py new file mode 100644 index 000000000000..18916669f611 --- /dev/null +++ b/examples/dreambooth/train_dreambooth_lora_ideogram4.py @@ -0,0 +1,2122 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "diffusers @ git+https://github.com/huggingface/diffusers.git", +# "torch>=2.0.0", +# "accelerate>=0.31.0", +# "transformers>=4.41.2", +# "ftfy", +# "tensorboard", +# "Jinja2", +# "peft>=0.11.1", +# "sentencepiece", +# "torchvision", +# "datasets", +# "bitsandbytes", +# "prodigyopt", +# ] +# /// + +import argparse +import itertools +import json +import logging +import math +import os +import random +import shutil +import warnings +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torch.utils.data.sampler import BatchSampler +from torchvision import transforms +from torchvision.transforms import functional as TF +from tqdm.auto import tqdm + +import diffusers +from diffusers import ( + AutoencoderKLFlux2, + BitsAndBytesConfig, + Ideogram4Pipeline, + Ideogram4Transformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.pipelines.ideogram4.pipeline_ideogram4 import _resolution_aware_mu +from diffusers.training_utils import ( + _collate_lora_metadata, + _to_cpu_contiguous, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + find_nearest_bucket, + free_memory, + get_fsdp_kwargs_from_accelerator, + offload_models, + parse_buckets_string, + wrap_with_fsdp, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_torch_npu_available +from diffusers.utils.torch_utils import is_compiled_module + + +if getattr(torch, "distributed", None) is not None: + import torch.distributed as dist + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.39.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, + quant_training=None, +): + widget_dict = [] + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + + model_description = f""" +# Ideogram4 DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Ideogram4 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_ideogram4.md). + +Quant training? {quant_training} + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import Ideogram4Pipeline +import torch + +pipeline = Ideogram4Pipeline.from_pretrained("{base_model}", torch_dtype=torch.bfloat16) +pipeline.to('cuda') +pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms of the base model [{base_model}](https://huggingface.co/{base_model}). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "ideogram4", + "ideogram4-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + epoch, + torch_dtype, + is_final_validation=False, +): + args.num_validation_images = args.num_validation_images if args.num_validation_images else 1 + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # The base may be quantized (nf4): casting a quantized pipeline to a dtype is unsupported, and + # from_pretrained already loads the components in the right dtype, so we skip the dtype cast here. + # We move the whole pipeline to device rather than enable_model_cpu_offload(): the pipeline's + # encode_prompt reaches into the text-encoder submodules directly, so the offload hooks never fire + # and leave the encoder on CPU (device-mismatch on token embedding). + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None + # No autocast: the pipeline already runs natively in the right dtype (e.g. bf16), and wrapping it in + # torch.autocast corrupts the outputs (fp16 -> NaN images, bf16 -> gray/noisy images). Verified + # empirically on the nf4 checkpoint: identical generations are clean without autocast and garbage + # under autocast of either dtype. + autocast_ctx = nullcontext() + + # Intermediate validation runs the LIVE training transformer, which can carry float32 params that + # other trainers paper over with autocast (unusable here, see above): + # - trainable LoRA params (peft creates fp32 adapters on quantized base layers), and + # - biases of quantized Linear4bit modules that received fp32 activations during training + # (bitsandbytes mutates `bias.data` to the input dtype in its forward). + # Any fp32 param also makes the model's `.dtype` property report fp32, steering the pipeline's + # internal casts to fp32 and crashing attention with mixed dtypes. So cast every fp32 param to the + # inference dtype for generation, and restore the trainable (LoRA) ones to fp32 afterwards β€” the + # Parameter objects are preserved, so optimizer state and references are unaffected. The bnb biases + # were loaded in the inference dtype to begin with, so they need no restore. + restore_fp32 = False + if not is_final_validation and torch_dtype != torch.float32: + for param in pipeline.transformer.parameters(): + if param.dtype == torch.float32: + param.data = param.data.to(torch_dtype) + restore_fp32 = restore_fp32 or param.requires_grad + + # accelerate's mixed-precision wrapper makes the live transformer's forward run under + # torch.autocast, which corrupts Ideogram4 generations just like an explicit autocast (see + # above). Strip the wrapper for validation and restore it afterwards so training still runs + # under the accelerate-managed autocast. + saved_wrapped_forward = None + transformer_module = pipeline.transformer + if not is_final_validation and "_original_forward" in transformer_module.__dict__: + saved_wrapped_forward = transformer_module.__dict__["forward"] + saved_original_forward = transformer_module.__dict__["_original_forward"] + accelerator.unwrap_model(transformer_module, keep_fp32_wrapper=False) + + images = [] + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=args.validation_prompt, + height=args.resolution, + width=args.resolution, + generator=generator, + ).images[0] + images.append(image) + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + if saved_wrapped_forward is not None: + transformer_module.forward = saved_wrapped_forward + transformer_module._original_forward = saved_original_forward + + if restore_fp32: + cast_training_params(pipeline.transformer, dtype=torch.float32) + + del pipeline + free_memory() + + return images + + +def module_filter_fn(mod: torch.nn.Module, fqn: str): + # Keep precision-sensitive modules in higher precision: the final output projection + # (final_layer.linear) and the modules Ideogram4Transformer2DModel flags in + # `_skip_layerwise_casting_patterns` (timestep embedding, AdaLN projection, image-indicator embed). + skip_patterns = ("final_layer.linear", "t_embedding", "adaln_proj", "embed_image_indicator") + if any(pattern in fqn for pattern in skip_patterns): + return False + # don't convert linear modules with weight dimensions not divisible by 16 + if isinstance(mod, torch.nn.Linear): + if mod.in_features % 16 != 0 or mod.out_features % 16 != 0: + return False + return True + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--bnb_quantization_config_path", + type=str, + default=None, + help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.", + ) + parser.add_argument( + "--do_fp8_training", + action="store_true", + help="if we are doing FP8 training.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that πŸ€— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + parser.add_argument( + "--upsample_prompt", + action="store_true", + help=( + "Rewrite each caption into Ideogram4's structured JSON caption (via the pipeline's prompt enhancer) " + "before encoding/caching. Loads the prompt-enhancer LM head (see --prompt_enhancer_head_id)." + ), + ) + parser.add_argument( + "--prompt_enhancer_head_id", + type=str, + default="diffusers/qwen3-vl-8b-instruct-lm-head", + help=( + "Repo id of the Ideogram4 prompt-enhancer LM head, loaded as the pipeline's `prompt_enhancer_head` " + "component when --upsample_prompt is set." + ), + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=2048, + help=( + "Maximum text token length for prompt encoding (Ideogram 4 uses a Qwen3-VL text encoder; the pipeline " + "default is 2048). Upsampled JSON captions can be long, so keep this large when using --upsample_prompt." + ), + ) + # parser.add_argument( + # "--text_encoder_out_layers", + # type=int, + # nargs="+", + # default=[10, 20, 30], + # help="Text encoder hidden layers to compute the final text embeddings.", + # ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--skip_final_inference", + default=False, + action="store_true", + help="Whether to skip the final inference step with loaded lora weights upon training completion. This will run intermediate validation inference if `validation_prompt` is provided. Specify to reduce memory.", + ) + parser.add_argument( + "--final_validation_prompt", + type=str, + default=None, + help="A prompt that is used during a final validation to verify that the model is learning. Ignored if `--validation_prompt` is provided.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=4, + help="LoRA alpha to be used for additional scaling.", + ) + parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers") + + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--aspect_ratio_buckets", + type=str, + default=None, + help=( + "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. " + "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'" + "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=( + 'Defaults to "logit_normal" to match the Ideogram4 schedule (resolution-aware mean + `--logit_std`). ' + "The sampled density is used directly as the flow-matching sigma." + ), + ) + parser.add_argument( + "--logit_mean", + type=float, + default=0.0, + help="Base mean of the logit-normal sigma schedule, before the resolution-aware shift " + "(0.5*log(pixels/512**2)). Matches the pipeline's `mu`.", + ) + parser.add_argument( + "--logit_std", + type=float, + default=1.5, + help="Std of the logit-normal sigma schedule. Defaults to 1.5 to match the Ideogram4 pipeline.", + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--upcast_before_saving", + action="store_true", + default=False, + help=( + "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). " + "Defaults to precision dtype used for training to save memory" + ), + ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoder to CPU when they are not used.", + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU") + parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + if args.do_fp8_training and args.bnb_quantization_config_path: + raise ValueError("Both `do_fp8_training` and `bnb_quantization_config_path` cannot be passed.") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + buckets=None, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + self.buckets = buckets + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + for i, image in enumerate(self.instance_images): + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + + width, height = image.size + + # Find the closest bucket + bucket_idx = find_nearest_bucket(height, width, self.buckets) + target_height, target_width = self.buckets[bucket_idx] + self.size = (target_height, target_width) + + # based on the bucket assignment, define the transformations + image = self.train_transform( + image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, + ) + self.pixel_values.append((image, bucket_idx)) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + example["bucket_idx"] = bucket_idx + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # custom prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + def train_transform(self, image, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + + # 2. Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + + return image + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class BucketBatchSampler(BatchSampler): + def __init__( + self, + dataset: DreamBoothDataset, + batch_size: int, + drop_last: bool = False, + shuffle_batches_each_epoch: bool = True, + ): + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last)) + + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle_batches_each_epoch = shuffle_batches_each_epoch + + # Group indices by bucket + self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))] + for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values): + self.bucket_indices[bucket_idx].append(idx) + + self.sampler_len = 0 + self.batches = [] + + # Pre-generate batches for each bucket + for indices_in_bucket in self.bucket_indices: + # Shuffle indices within the bucket + random.shuffle(indices_in_bucket) + # Create batches + for i in range(0, len(indices_in_bucket), self.batch_size): + batch = indices_in_bucket[i : i + self.batch_size] + if len(batch) < self.batch_size and self.drop_last: + continue # Skip partial batch if drop_last is True + self.batches.append(batch) + self.sampler_len += 1 # Count the number of batches + + if not self.shuffle_batches_each_epoch: + # Shuffle the precomputed batches once to mix buckets while keeping + # the order stable across epochs for step-indexed caches. + random.shuffle(self.batches) + + def __iter__(self): + if self.shuffle_batches_each_epoch: + random.shuffle(self.batches) + for batch in self.batches: + yield batch + + def __len__(self): + return self.sampler_len + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `hf auth login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + if args.do_fp8_training: + from torchao.float8 import Float8LinearConfig, convert_to_float8_training + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + + pipeline = Ideogram4Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype): + images = pipeline(prompt=example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Prompt upsampling needs the optional LM head grafted onto the (shared) text encoder; load it up front so the + # text-encoding pipeline can rewrite captions into Ideogram4's structured JSON schema during caching. + prompt_enhancer_head = None + if args.upsample_prompt: + try: + from diffusers import Ideogram4PromptEnhancerHead + except ImportError: + raise ValueError( + "`--upsample_prompt` requires a diffusers with Ideogram4 prompt upsampling " + "(`Ideogram4PromptEnhancerHead`); please upgrade diffusers." + ) + prompt_enhancer_head = Ideogram4PromptEnhancerHead.from_pretrained(args.prompt_enhancer_head_id) + + # Initialize a text encoding pipeline and keep it to CPU for now. + text_encoding_pipeline = Ideogram4Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=None, + transformer=None, + unconditional_transformer=None, + scheduler=None, + prompt_enhancer_head=prompt_enhancer_head, + revision=args.revision, + ) + tokenizer = text_encoding_pipeline.tokenizer + text_encoder = text_encoding_pipeline.text_encoder + + if args.upsample_prompt and getattr(text_encoding_pipeline, "prompt_enhancer_head", None) is None: + raise ValueError( + "`--upsample_prompt` was set but the prompt-enhancer head failed to load; check " + "`--prompt_enhancer_head_id` and your diffusers version." + ) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if prompt_enhancer_head is not None: + # The head is grafted onto the text-encoder body and consumes its hidden states, so it must share the body's + # (floating) dtype β€” e.g. the bf16 a bnb-quantized text encoder outputs, regardless of mixed precision. + prompt_enhancer_head.to(dtype=text_encoder.dtype) + + # Load models + vae = AutoencoderKLFlux2.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + patch_size = 2 + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + # Pixel size of one packed image token (vae_scale_factor * patch_size). + patch_pixels = vae_scale_factor * patch_size + latents_bn_mean = vae.bn.running_mean.view(1, 1, -1).to(accelerator.device) + latents_bn_std = torch.sqrt(vae.bn.running_var + vae.config.batch_norm_eps).view(1, 1, -1).to(accelerator.device) + + quantization_config = None + if args.bnb_quantization_config_path is not None: + with open(args.bnb_quantization_config_path, "r") as f: + config_kwargs = json.load(f) + if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]: + config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype + quantization_config = BitsAndBytesConfig(**config_kwargs) + + transformer = Ideogram4Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=quantization_config, + torch_dtype=weight_dtype, + ) + if args.bnb_quantization_config_path is not None: + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + text_encoder.requires_grad_(False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + if args.enable_npu_flash_attention: + if is_torch_npu_available(): + logger.info("npu flash attention enabled.") + transformer.set_attention_backend("_native_npu") + else: + raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu device ") + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + to_kwargs = {"dtype": weight_dtype, "device": accelerator.device} if not args.offload else {"dtype": weight_dtype} + # flux vae is stable in bf16 so load it in weight_dtype to reduce memory + vae.to(**to_kwargs) + # we never offload the transformer to CPU, so we can just use the accelerator device. + # A pre-quantized checkpoint (e.g. nf4) is already quantized even without + # --bnb_quantization_config_path; casting a quantized model to a dtype is unsupported, so only + # set dtype for an unquantized transformer. + transformer_is_quantized = args.bnb_quantization_config_path is not None or getattr( + transformer, "is_quantized", False + ) + transformer_to_kwargs = ( + {"device": accelerator.device} + if transformer_is_quantized + else {"device": accelerator.device, "dtype": weight_dtype} + ) + + is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None + if not is_fsdp: + transformer.to(**transformer_to_kwargs) + + if args.do_fp8_training: + convert_to_float8_training( + transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True) + ) + + # The nf4 checkpoint ships a pre-quantized text encoder too; casting a bitsandbytes model to a + # dtype is unsupported, so only move it to device (skip the dtype) when it is quantized. + text_encoder_is_quantized = ( + getattr(text_encoder, "is_quantized", False) + or getattr(text_encoder, "is_loaded_in_4bit", False) + or getattr(text_encoder, "is_loaded_in_8bit", False) + ) + if text_encoder_is_quantized: + if not args.offload: + text_encoder.to(device=accelerator.device) + else: + text_encoder.to(**to_kwargs) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + # Default: LoRA on each Ideogram4 block's attention projections (to_q/to_k/to_v/to_out.0) and + # feed-forward (w1/w2/w3). + target_modules = ["to_q", "to_k", "to_v", "to_out.0", "w1", "w2", "w3"] + + # now we will add new LoRA weights the transformer layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + transformer_cls = type(unwrap_model(transformer)) + + # 1) Validate and pick the transformer model + modules_to_save: dict[str, Any] = {} + transformer_model = None + + for model in models: + if isinstance(unwrap_model(model), transformer_cls): + transformer_model = model + modules_to_save["transformer"] = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer_model is None: + raise ValueError("No transformer model found in 'models'") + + # 2) Optionally gather FSDP state dict once + state_dict = accelerator.get_state_dict(model) if is_fsdp else None + + # 3) Only main process materializes the LoRA state dict + transformer_lora_layers_to_save = None + if accelerator.is_main_process: + peft_kwargs = {} + if is_fsdp: + peft_kwargs["state_dict"] = state_dict + + transformer_lora_layers_to_save = get_peft_model_state_dict( + unwrap_model(transformer_model) if is_fsdp else transformer_model, + **peft_kwargs, + ) + + if is_fsdp: + transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save) + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + Ideogram4Pipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + **_collate_lora_metadata(modules_to_save), + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not is_fsdp: + while len(models) > 0: + model = models.pop() + + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + transformer_ = unwrap_model(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + transformer_ = Ideogram4Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + ) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = Ideogram4Pipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + if args.aspect_ratio_buckets is not None: + buckets = parse_buckets_string(args.aspect_ratio_buckets) + else: + buckets = [(args.resolution, args.resolution)] + logger.info(f"Using parsed aspect ratio buckets: {buckets}") + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + buckets=buckets, + ) + has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts + batch_sampler = BucketBatchSampler( + train_dataset, + batch_size=args.train_batch_size, + drop_last=True, + shuffle_batches_each_epoch=not has_step_indexed_caches, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + def compute_text_embeddings(prompt, grid_h, grid_w, text_encoding_pipeline): + # encode_prompt builds the full packed conditioning for the given image grid and returns + # (prompt_embeds, position_ids, segment_ids, indicator) ready to feed the transformer. + with torch.no_grad(): + if args.upsample_prompt: + # Ideogram4 is conditioned on a structured JSON caption; rewrite the caption(s) into that schema + # (matching the image aspect ratio) before encoding. Done once here, then cached. + # The enhancer grafts the LM head onto the (shared) text-encoder body, so the head must sit on the + # same device as that body for the call (the body may already be on GPU β€” e.g. a bnb-quantized text + # encoder β€” even when offloading is off). + with offload_models( + text_encoding_pipeline.prompt_enhancer_head, + device=text_encoding_pipeline.text_encoder.device, + offload=True, + ): + prompt = text_encoding_pipeline.upsample_prompt( + prompt, height=grid_h * patch_pixels, width=grid_w * patch_pixels + ) + prompt_embeds, position_ids, segment_ids, indicator = text_encoding_pipeline.encode_prompt( + prompt=prompt, + grid_h=grid_h, + grid_w=grid_w, + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + ) + return prompt_embeds, position_ids, segment_ids, indicator + + def patchify_latents(z, patch_size): + """Inverse of Ideogram4Pipeline._decode (w/o the bn de-normalization): pack a VAE latent + (B, ae_channels, H, W) into the transformer's token layout (B, grid_h*grid_w, ae*patch**2).""" + batch_size = z.shape[0] + patch = patch_size + # z shape: (batch_size, ae_channels, H, W) + + ae_channels = z.shape[1] + grid_h = z.shape[2] // patch + grid_w = z.shape[3] // patch + + # Patch: inverse of z.view(batch, ae_ch, grid_h*patch, grid_w*patch) + z = z.view(batch_size, ae_channels, grid_h, patch, grid_w, patch) + # Inverse of the pipeline decode's permute(0, 5, 1, 3, 2, 4): the packed channel layout is + # (p_h, p_w, ae) β€” ae varies fastest β€” so ae must be the LAST dim before flattening. + z = z.permute(0, 2, 4, 3, 5, 1).contiguous() + # Pack channels and flatten the grid to a token sequence: (B, grid_h*grid_w, patch*patch*ae) + z = z.view(batch_size, grid_h * grid_w, patch * patch * ae_channels) + + return z + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + # The packed conditioning depends on the image grid, so for the single instance prompt we use the + # fixed training-resolution grid (single-bucket training). + static_grid_h = static_grid_w = args.resolution // patch_pixels + if not train_dataset.custom_instance_prompts: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + instance_bundle = compute_text_embeddings( + args.instance_prompt, static_grid_h, static_grid_w, text_encoding_pipeline + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + class_bundle = compute_text_embeddings( + args.class_prompt, static_grid_h, static_grid_w, text_encoding_pipeline + ) + + # Init FSDP for text encoder + if args.fsdp_text_encoder: + fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator) + text_encoder_fsdp = wrap_with_fsdp( + model=text_encoding_pipeline.text_encoder, + device=accelerator.device, + offload=args.offload, + limit_all_gathers=True, + use_orig_params=True, + fsdp_kwargs=fsdp_kwargs, + ) + + text_encoding_pipeline.text_encoder = text_encoder_fsdp + dist.barrier() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + if not train_dataset.custom_instance_prompts: + static_bundle = instance_bundle + if args.with_prior_preservation: + # concat instance+class along the batch dim for each of (embeds, position_ids, segment_ids, indicator) + static_bundle = tuple(torch.cat([inst, cls], dim=0) for inst, cls in zip(instance_bundle, class_bundle)) + + # if cache_latents is set to True, we encode images to latents and store them. + # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided + # we encode them in advance as well. + if precompute_latents: + bundle_cache = [] + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + if args.cache_latents: + with offload_models(vae, device=accelerator.device, offload=args.offload): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=vae.dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if train_dataset.custom_instance_prompts: + # grid for this batch's actual image size (drives the packed layout); each batch is a + # single bucket, so this is correct under aspect-ratio bucketing. + grid_h = batch["pixel_values"].shape[-2] // patch_pixels + grid_w = batch["pixel_values"].shape[-1] // patch_pixels + if args.fsdp_text_encoder: + bundle = compute_text_embeddings(batch["prompts"], grid_h, grid_w, text_encoding_pipeline) + else: + with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload): + bundle = compute_text_embeddings(batch["prompts"], grid_h, grid_w, text_encoding_pipeline) + bundle_cache.append(bundle) + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + if args.cache_latents: + vae = vae.to("cpu") + del vae + + # With --upsample_prompt the LoRA only ever sees long structured (JSON) captions, so a plain + # validation prompt is out-of-distribution for it and degrades misleadingly as training + # progresses. Upsample the validation prompt once here (while the enhancer is still loaded) so + # validation matches the training conditioning distribution. + if args.upsample_prompt and args.validation_prompt is not None: + with offload_models( + text_encoding_pipeline.prompt_enhancer_head, + device=text_encoding_pipeline.text_encoder.device, + offload=True, + ): + upsampled = text_encoding_pipeline.upsample_prompt( + args.validation_prompt, height=args.resolution, width=args.resolution + ) + # upsample_prompt returns a list for a single string input; keep validation_prompt a + # str so downstream (tracker hparams, encode_prompt) handles it as a single prompt. + args.validation_prompt = upsampled[0] if isinstance(upsampled, list) else upsampled + logger.info(f"Upsampled validation prompt: {args.validation_prompt}") + + # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624 + text_encoding_pipeline = text_encoding_pipeline.to("cpu") + del text_encoder, tokenizer + free_memory() + + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-ideogram4-lora" + args_cp = vars(args).copy() + # args_cp["text_encoder_out_layers"] = str(args_cp["text_encoder_out_layers"]) + accelerator.init_trackers(tracker_name, config=args_cp) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + prompts = batch["prompts"] + + with accelerator.accumulate(models_to_accumulate): + if train_dataset.custom_instance_prompts: + prompt_embeds, position_ids, segment_ids, indicator = bundle_cache[step] + else: + # collate_fn orders prior-preservation batches as [inst1..instB, class1..classB]; + # repeat each static entry along dim 0 to match. Read from the stable `static_bundle`. + num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) + prompt_embeds, position_ids, segment_ids, indicator = ( + t.repeat_interleave(num_repeat_elements, dim=0) for t in static_bundle + ) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].mode() + else: + with offload_models(vae, device=accelerator.device, offload=args.offload): + pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.mode() + + # Pack the VAE latent into the transformer's token layout, then bn-normalize. + # The packed conditioning (prompt_embeds + position/segment/indicator) was built for this + # same grid by encode_prompt, so the sequence lengths line up. + image_height = model_input.shape[-2] * vae_scale_factor + image_width = model_input.shape[-1] * vae_scale_factor + model_input = patchify_latents(model_input, patch_size) # (B, grid_h*grid_w, latent_dim) + model_input = (model_input - latents_bn_mean) / latents_bn_std + + bsz, num_image_tokens, latent_dim = model_input.shape + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + + # Sample sigma per image. Ideogram4 uses a resolution-aware logit-normal schedule (defaults + # below), so we pass the pixel-count-shifted mean to compute_density_for_timestep_sampling and + # use its output directly as sigma -- unlike most scripts we do NOT index the FlowMatch + # scheduler grid, which carries a shift Ideogram4's schedule does not use. sigma in (0, 1). + sigmas = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=_resolution_aware_mu(image_height, image_width, base_mu=args.logit_mean), + logit_std=args.logit_std, + mode_scale=args.mode_scale, + device=accelerator.device, + ) + sigmas = sigmas.to(dtype=model_input.dtype).view(bsz, 1, 1) + + # Add noise according to flow matching: zt = (1 - sigma) * x0 + sigma * noise. + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Pack the noisy image tokens behind zeroed text positions to form the full sequence + # the transformer expects: [text padding (max_sequence_length) ++ image tokens]. + text_latent_padding = torch.zeros( + bsz, + args.max_sequence_length, + latent_dim, + dtype=noisy_model_input.dtype, + device=noisy_model_input.device, + ) + packed_hidden_states = torch.cat([text_latent_padding, noisy_model_input], dim=1) + + # Ideogram4 model time t = 1 - sigma (0=noise, 1=data), shaped (B,). + model_timestep = (1.0 - sigmas).reshape(bsz) + + # Predict the velocity; only the image positions carry a meaningful prediction. + model_pred = transformer( + hidden_states=packed_hidden_states, + timestep=model_timestep, + encoder_hidden_states=prompt_embeds, + position_ids=position_ids, + segment_ids=segment_ids, + indicator=indicator, + return_dict=False, + )[0] + model_pred = model_pred[:, args.max_sequence_length :] + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # ideogram predicts v=x0βˆ’noise + target = model_input - noise + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or is_fsdp: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + pipeline = Ideogram4Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + transformer=unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + epoch=epoch, + torch_dtype=weight_dtype, + ) + + del pipeline + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + + if is_fsdp: + transformer = unwrap_model(transformer) + state_dict = accelerator.get_state_dict(transformer) + if accelerator.is_main_process: + modules_to_save = {} + if is_fsdp: + if args.bnb_quantization_config_path is None: + if args.upcast_before_saving: + state_dict = { + k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + else: + state_dict = { + k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items() + } + + transformer_lora_layers = get_peft_model_state_dict( + transformer, + state_dict=state_dict, + ) + transformer_lora_layers = { + k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v + for k, v in transformer_lora_layers.items() + } + + else: + transformer = unwrap_model(transformer) + # Skip the dtype cast for a quantized base (passed config or a pre-quantized nf4 checkpoint): + # casting a bitsandbytes model is unsupported, and we only extract the (unquantized) LoRA layers anyway. + if args.bnb_quantization_config_path is None and not getattr(transformer, "is_quantized", False): + if args.upcast_before_saving: + transformer.to(torch.float32) + else: + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + modules_to_save["transformer"] = transformer + + Ideogram4Pipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + **_collate_lora_metadata(modules_to_save), + ) + + images = [] + run_validation = (args.validation_prompt and args.num_validation_images > 0) or (args.final_validation_prompt) + should_run_final_inference = not args.skip_final_inference and run_validation + if should_run_final_inference: + pipeline = Ideogram4Pipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + epoch=epoch, + is_final_validation=True, + torch_dtype=weight_dtype, + ) + images = None + del pipeline + free_memory() + + validation_prompt = args.validation_prompt if args.validation_prompt else args.final_validation_prompt + quant_training = None + if args.do_fp8_training: + quant_training = "FP8 TorchAO" + elif args.bnb_quantization_config_path: + quant_training = "BitsandBytes" + save_model_card( + (args.hub_model_id or Path(args.output_dir).name) if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + instance_prompt=args.instance_prompt, + validation_prompt=validation_prompt, + repo_folder=args.output_dir, + quant_training=quant_training, + ) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 33eeba673a98..2eb1f5cc7a44 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -86,6 +86,7 @@ def text_encoder_attn_modules(text_encoder): "QwenImageLoraLoaderMixin", "ZImageLoraLoaderMixin", "Flux2LoraLoaderMixin", + "Ideogram4LoraLoaderMixin", "ErnieImageLoraLoaderMixin", "CosmosLoraLoaderMixin", ] @@ -128,6 +129,7 @@ def text_encoder_attn_modules(text_encoder): HeliosLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + Ideogram4LoraLoaderMixin, KandinskyLoraLoaderMixin, LoraLoaderMixin, LTX2LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 52b2aad174be..f4bec3f15e1b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -6018,6 +6018,206 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class Ideogram4LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Ideogram4Transformer2DModel`]. Specific to [`Ideogram4Pipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: str | os.PathLike, + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: dict | None = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: list[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class ErnieImageLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`]. diff --git a/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py b/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py index 61ba4fa43a62..0c11322c5c13 100644 --- a/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py +++ b/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py @@ -20,6 +20,7 @@ from transformers.masking_utils import create_causal_mask from ...image_processor import VaeImageProcessor +from ...loaders import Ideogram4LoraLoaderMixin from ...models.autoencoders import AutoencoderKLFlux2 from ...models.transformers.transformer_ideogram4 import ( IMAGE_POSITION_OFFSET, @@ -137,7 +138,7 @@ def _expand_tensor_to_effective_batch( return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by) -class Ideogram4Pipeline(DiffusionPipeline): +class Ideogram4Pipeline(DiffusionPipeline, Ideogram4LoraLoaderMixin): r""" Text-to-image pipeline for Ideogram4.