def _supports_torchao_safetensors() -> bool:
    """Tell whether TorchAO tensors can be written to / read from safetensors files.

    Returns:
        `True` only when TorchAO is available and its version is at least 0.16.0, which is the same
        condition under which this module imports TorchAO's safetensors flatten/unflatten helpers.
    """
    # Check availability first so the version is only queried when TorchAO is importable.
    if not is_torchao_available():
        return False
    return is_torchao_version(">=", "0.16.0")
@staticmethod + def _get_torchao_subset_metadata_for_unflatten(metadata): + tensor_names = metadata.get("tensor_names") + if tensor_names is None: + return None + + try: + tensor_names = json.loads(tensor_names) + except (TypeError, json.JSONDecodeError): + logger.warning("Could not parse TorchAO safetensors metadata for disk offloading; using full metadata.") + return None + + dotted_tensor_names = [name for name in tensor_names if "." in name] + if len(dotted_tensor_names) == 0: + return None + + return { + "tensor_names": json.dumps(dotted_tensor_names), + **{name: metadata[name] for name in dotted_tensor_names if name in metadata}, + } + def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -238,18 +276,78 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None) source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, default_stream) - def _check_disk_offload_torchao(self): - all_tensors = list(self.tensor_to_key.keys()) - has_torchao = any(_is_torchao_tensor(t) for t in all_tensors) - if has_torchao: + def _check_disk_offload_torchao_support(self): + if self._has_torchao_tensors and not _supports_torchao_safetensors(): raise ValueError( - "Disk offloading is not supported for TorchAO quantized tensors because safetensors " - "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not " - "setting `offload_to_disk_path`." + "Disk offloading TorchAO quantized tensors requires torchao >= 0.16.0 because older torchao " + "versions cannot serialize tensor subclasses with safetensors. Use memory offloading instead by " + "not setting `offload_to_disk_path`." 
def _get_torchao_disk_state_dict(self):
    """Prepare the tensors and metadata to write to disk for a group that contains TorchAO tensors.

    Every tensor tracked in `self.tensor_to_key` is moved to `self.offload_device`, renamed according to
    `self._get_torchao_disk_key_remap()` (the remap is also stored on `self._torchao_disk_key_remap`), and
    the result is passed through `flatten_tensor_state_dict`.

    Returns:
        A `(tensors_to_save, metadata)` pair. When `flatten_tensor_state_dict` returns a tuple it is
        unpacked into that pair; otherwise its return value is used as the tensors and the metadata is an
        empty dict.
    """
    offload_device = self.offload_device

    # TorchAO tensors are moved as a whole; for every other tensor only `.data` is moved.
    moved_by_key = {}
    for tensor, key in self.tensor_to_key.items():
        source = tensor if _is_torchao_tensor(tensor) else tensor.data
        moved_by_key[key] = source.to(offload_device)

    # TorchAO safetensors support expects logical parameter names and stores
    # tensor subclass internals plus reconstruction metadata separately.
    key_remap = self._get_torchao_disk_key_remap()
    self._torchao_disk_key_remap = key_remap
    renamed_tensors = {key_remap.get(key, key): tensor for key, tensor in moved_by_key.items()}

    flattened = flatten_tensor_state_dict(renamed_tensors)
    if not isinstance(flattened, tuple):
        return flattened, {}

    tensors_to_save, metadata = flattened
    return tensors_to_save, metadata
def _release_torchao_onload_tensors(self):
    """Point every tensor of the group at storage on `self.offload_device`.

    TorchAO tensors are moved to the offload device with `.to(...)` and the result is handed to
    `_swap_torchao_tensor` together with the original tensor (presumably swapping the contents in place;
    the helper is defined elsewhere in this module). All other tensors get their `.data` replaced by an
    uninitialized tensor created with `torch.empty_like` on the offload device.
    """
    target_device = self.offload_device
    for param in self.tensor_to_key:
        if not _is_torchao_tensor(param):
            # Only a same-shaped, uninitialized stand-in is kept for regular tensors.
            param.data = torch.empty_like(param.data, device=target_device)
            continue

        offloaded_copy = param.to(target_device)
        _swap_torchao_tensor(param, offloaded_copy)
to the target device - onload_device = ( - self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device - ) - loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) - for key, tensor_obj in self.key_to_tensor.items(): - tensor_obj.data = loaded_tensors[key] + for tensor_obj in self.tensor_to_key: + self._transfer_tensor_to_device( + tensor_obj, + loaded_tensors[self.tensor_to_key[tensor_obj]], + default_stream=None, + ) def _onload_from_memory(self): if self.stream is not None: @@ -292,7 +395,7 @@ def _onload_from_memory(self): self._process_tensors_from_modules(None) def _offload_to_disk(self): - self._check_disk_offload_torchao() + self._check_disk_offload_torchao_support() # TODO: we can potentially optimize this code path by checking if the _all_ the desired # safetensor files exist on the disk and if so, skip this step entirely, reducing IO @@ -301,15 +404,24 @@ def _offload_to_disk(self): # Check if the file has been saved in this session or if it already exists on disk. if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) - tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} - safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) + if self._has_torchao_tensors: + tensors_to_save, metadata = self._get_torchao_disk_state_dict() + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata) + else: + tensors_to_save = { + key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() + } + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True # We do this to free up the RAM which is still holding the up tensor data. 
@require_torchao_version_greater_or_equal("0.16.0")
def test_group_offload_to_disk(self):
    """Check leaf-level group offloading to disk for an int8 TorchAO-quantized dummy model.

    The model is run twice with disk offloading enabled, then released and rebuilt against the same
    offload directory and run once more. Each time the last nine output values must stay within a cosine
    similarity distance of 1e-3 from the expected slice.
    """
    quant_type = Int8DynamicActivationInt8WeightConfig(version=2)
    expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])

    def enable_disk_offload(model, disk_path):
        model.enable_group_offload(
            onload_device=torch_device,
            offload_type="leaf_level",
            offload_to_disk_path=disk_path,
        )

    def assert_output_matches(model, model_inputs):
        tail = model(**model_inputs)[0].flatten()[-9:].detach().float().cpu().numpy()
        self.assertTrue(numpy_cosine_similarity_distance(tail, expected_slice) < 1e-3)

    quantized_model = self.get_dummy_model(quant_type, torch_device)

    with tempfile.TemporaryDirectory() as offload_to_disk_path:
        enable_disk_offload(quantized_model, offload_to_disk_path)
        inputs = self.get_dummy_tensor_inputs(torch_device)

        # Two consecutive forward passes with the same offloaded model.
        for _ in range(2):
            assert_output_matches(quantized_model, inputs)

        del quantized_model
        gc.collect()
        backend_empty_cache(torch_device)

        # A freshly built model is pointed at the same directory while it still exists.
        quantized_model = self.get_dummy_model(quant_type, torch_device)
        enable_disk_offload(quantized_model, offload_to_disk_path)
        assert_output_matches(quantized_model, inputs)