Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 140 additions & 28 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import hashlib
import json
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, replace
Expand All @@ -21,8 +22,9 @@

import safetensors.torch
import torch
from safetensors import safe_open

from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available, is_torchao_version
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook

Expand All @@ -32,9 +34,22 @@
from accelerate.utils import send_to_device


if is_torchao_available():
if is_torchao_version(">=", "0.16.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao


logger = get_logger(__name__) # pylint: disable=invalid-name


def _supports_torchao_safetensors() -> bool:
return is_torchao_available() and is_torchao_version(">=", "0.16.0")


def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
if not is_torchao_available():
return False
Expand Down Expand Up @@ -162,6 +177,8 @@ def __init__(

self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
self._torchao_disk_key_remap: dict[str, str] = {}
self._has_torchao_tensors = any(_is_torchao_tensor(tensor) for tensor in self.tensor_to_key)
self.cpu_param_dict = {}
else:
self.cpu_param_dict = self._init_cpu_param_dict()
Expand All @@ -179,6 +196,27 @@ def _to_cpu(tensor, low_cpu_mem_usage):
t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
return t if low_cpu_mem_usage else t.pin_memory()

@staticmethod
def _get_torchao_subset_metadata_for_unflatten(metadata):
tensor_names = metadata.get("tensor_names")
if tensor_names is None:
return None

try:
tensor_names = json.loads(tensor_names)
except (TypeError, json.JSONDecodeError):
logger.warning("Could not parse TorchAO safetensors metadata for disk offloading; using full metadata.")
return None

dotted_tensor_names = [name for name in tensor_names if "." in name]
if len(dotted_tensor_names) == 0:
return None

return {
"tensor_names": json.dumps(dotted_tensor_names),
**{name: metadata[name] for name in dotted_tensor_names if name in metadata},
}

def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
Expand Down Expand Up @@ -238,18 +276,78 @@ def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None)
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, default_stream)

def _check_disk_offload_torchao(self):
all_tensors = list(self.tensor_to_key.keys())
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
if has_torchao:
def _check_disk_offload_torchao_support(self):
if self._has_torchao_tensors and not _supports_torchao_safetensors():
raise ValueError(
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
"setting `offload_to_disk_path`."
"Disk offloading TorchAO quantized tensors requires torchao >= 0.16.0 because older torchao "
"versions cannot serialize tensor subclasses with safetensors. Use memory offloading instead by "
"not setting `offload_to_disk_path`."
)

def _get_torchao_disk_state_dict(self):
tensors_to_save = {
key: (
tensor.to(self.offload_device) if _is_torchao_tensor(tensor) else tensor.data.to(self.offload_device)
)
for tensor, key in self.tensor_to_key.items()
}

# TorchAO safetensors support expects logical parameter names and stores
# tensor subclass internals plus reconstruction metadata separately.
metadata = {}
tensors_for_flatten = {}
self._torchao_disk_key_remap = self._get_torchao_disk_key_remap()
for key, tensor in tensors_to_save.items():
tensors_for_flatten[self._torchao_disk_key_remap.get(key, key)] = tensor

flattened_state_dict = flatten_tensor_state_dict(tensors_for_flatten)
if isinstance(flattened_state_dict, tuple):
tensors_to_save, metadata = flattened_state_dict
else:
tensors_to_save = flattened_state_dict

return tensors_to_save, metadata

def _get_torchao_disk_key_remap(self):
return {
key: f"{key}.weight"
for tensor, key in self.tensor_to_key.items()
if _is_torchao_tensor(tensor) and "." not in key
}

def _load_torchao_disk_state_dict(self, device):
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)

with safe_open(self.safetensors_file_path, framework="pt") as f:
metadata = f.metadata() or {}

if is_metadata_torchao(metadata):
metadata = self._get_torchao_subset_metadata_for_unflatten(metadata) or metadata
try:
reconstructed_state_dict, leftover_state_dict = unflatten_tensor_state_dict(loaded_tensors, metadata)
loaded_tensors = {**leftover_state_dict, **reconstructed_state_dict}
except Exception as error:
raise RuntimeError("Failed to reconstruct TorchAO tensors from disk offload safetensors.") from error

# Support legacy in-memory tensor keys used by GroupOffloading when
# flattening introduced dot-based names to satisfy TorchAO's safetensors API.
self._torchao_disk_key_remap = self._get_torchao_disk_key_remap()
for original_key, flattened_key in self._torchao_disk_key_remap.items():
if original_key not in loaded_tensors and flattened_key in loaded_tensors:
loaded_tensors[original_key] = loaded_tensors.pop(flattened_key)

return loaded_tensors

def _release_torchao_onload_tensors(self):
for tensor_obj in self.tensor_to_key.keys():
if _is_torchao_tensor(tensor_obj):
placeholder = tensor_obj.to(self.offload_device)
_swap_torchao_tensor(tensor_obj, placeholder)
else:
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

def _onload_from_disk(self):
self._check_disk_offload_torchao()
self._check_disk_offload_torchao_support()

if self.stream is not None:
# Wait for previous Host->Device transfer to complete
Expand All @@ -259,22 +357,27 @@ def _onload_from_disk(self):
current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None

with context:
device = str(self.onload_device) if self.stream is None else "cpu"
loaded_tensors = (
self._load_torchao_disk_state_dict(device=device)
if self._has_torchao_tensors
else safetensors.torch.load_file(self.safetensors_file_path, device=device)
)

if self.stream is not None:
# Load to CPU first, pin memory, then async copy to the target device
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
tensor_obj.data.record_stream(current_stream)
pinned_memory = {
tensor_obj: loaded_tensors[self.tensor_to_key[tensor_obj]].pin_memory()
for tensor_obj in self.tensor_to_key
}
for tensor_obj, pinned_tensor in pinned_memory.items():
self._transfer_tensor_to_device(tensor_obj, pinned_tensor, current_stream)
else:
# Load directly to the target device
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
for tensor_obj in self.tensor_to_key:
self._transfer_tensor_to_device(
tensor_obj,
loaded_tensors[self.tensor_to_key[tensor_obj]],
default_stream=None,
)

def _onload_from_memory(self):
if self.stream is not None:
Expand All @@ -292,7 +395,7 @@ def _onload_from_memory(self):
self._process_tensors_from_modules(None)

def _offload_to_disk(self):
self._check_disk_offload_torchao()
self._check_disk_offload_torchao_support()

# TODO: we can potentially optimize this code path by checking if the _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
Expand All @@ -301,15 +404,24 @@ def _offload_to_disk(self):
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
if self._has_torchao_tensors:
tensors_to_save, metadata = self._get_torchao_disk_state_dict()
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path, metadata=metadata)
else:
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True

# We do this to free up the RAM which is still holding the up tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
if self._has_torchao_tensors:
self._release_torchao_onload_tensors()
else:
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

def _offload_to_memory(self):
if self.stream is not None:
Expand Down
41 changes: 41 additions & 0 deletions tests/quantization/torchao/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,47 @@ def _check_serialization_expected_slice(self, quant_type, expected_slice, device
self.assertTrue(isinstance(loaded_quantized_model.proj_out.weight, TorchAOBaseTensor))
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

@require_torchao_version_greater_or_equal("0.16.0")
def test_group_offload_to_disk(self):
quant_type = Int8DynamicActivationInt8WeightConfig(version=2)
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])

quantized_model = self.get_dummy_model(quant_type, torch_device)

with tempfile.TemporaryDirectory() as offload_to_disk_path:
quantized_model.enable_group_offload(
onload_device=torch_device,
offload_type="leaf_level",
offload_to_disk_path=offload_to_disk_path,
)

inputs = self.get_dummy_tensor_inputs(torch_device)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

output = quantized_model(**inputs)[0]
output_slice_2 = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(numpy_cosine_similarity_distance(output_slice_2, expected_slice) < 1e-3)

del quantized_model
gc.collect()
backend_empty_cache(torch_device)

quantized_model = self.get_dummy_model(quant_type, torch_device)
quantized_model.enable_group_offload(
onload_device=torch_device,
offload_type="leaf_level",
offload_to_disk_path=offload_to_disk_path,
)

output = quantized_model(**inputs)[0]
output_slice_3 = output.flatten()[-9:].detach().float().cpu().numpy()

self.assertTrue(numpy_cosine_similarity_distance(output_slice_3, expected_slice) < 1e-3)

def test_int_a8w8_accelerator(self):
quant_type = Int8DynamicActivationInt8WeightConfig()
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
Expand Down
Loading