Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 30 additions & 23 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@
from tqdm import tqdm

from modelopt.torch.opt.searcher import ForwardLoop
from modelopt.torch.quantization.utils import LayerActivationCollector
from modelopt.torch.quantization.utils.activation_collector import LayerActivationCollector
from modelopt.torch.utils import print_rank_0
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import (
bind_forward_method,
get_decoder_layers,
unpatch_forward_method,
)
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
from modelopt.torch.utils.perf import get_used_gpu_mem_fraction

from .calib import MseCalibrator, NVFP4MSECalibrator
Expand Down Expand Up @@ -1848,31 +1844,42 @@ def sequential_calibrate(
calib_func: Callable,
**calib_kwargs,
):
"""Sequential calibration - a sequential layer-by-layer calibration algorithm."""
"""Sequential calibration - a sequential layer-by-layer calibration algorithm.

Runs the full model forward per layer but patches decoder layers with a
skip / run / capture strategy so that inter-layer logic in parent modules
(e.g. mask construction) executes naturally without model-specific hooks.
"""
Comment thread
sugunav14 marked this conversation as resolved.
if forward_loop is None:
raise ValueError("forward_loop must not be None for sequential calibration.")
raise ValueError(
"forward_loop must not be None for sequential calibration. "
"Please provide a valid forward_loop callable."
)

transformer_layers = get_decoder_layers(model)
if transformer_layers is None:
transformer_layers = LayerActivationCollector.get_decoder_layers(model)
if transformer_layers is None or len(transformer_layers) == 0:
raise ValueError(
"Could not find transformer layers in model'. "
"Could not find transformer layers in model. "
"Sequential calibration requires a model with identifiable transformer layers."
)

print_rank_0(f"Sequential calibration: Found {len(transformer_layers)} transformer layers")

gettr = LayerActivationCollector(model)
input_getter = LayerActivationCollector(model)
input_getter._patch_all_layers(decoder_layers=transformer_layers)

for layer in transformer_layers:
# Get updated input activations to the current layer
layer_inputs = gettr.get_input_activations(layer, forward_loop)
try:
for layer_idx, layer in enumerate(transformer_layers):
print_rank_0(f"Calibrating layer {layer_idx + 1}/{len(transformer_layers)}")
layer_inputs = input_getter.get_input_activations(layer, forward_loop)

# Define a forward loop for the current layer
def _layer_forward_loop(m, _inputs=layer_inputs):
for args, kwargs_input in _inputs:
m(*args, **kwargs_input)
def _layer_forward_loop(m, _inputs=layer_inputs):
for args, kwargs_input in _inputs:
m(*args, **kwargs_input)

# Call calibration function
calib_func(layer, _layer_forward_loop, **calib_kwargs)
del layer_inputs
torch.cuda.empty_cache()
calib_func(layer, _layer_forward_loop, **calib_kwargs)
Comment thread
sugunav14 marked this conversation as resolved.

del layer_inputs
torch.cuda.empty_cache()
finally:
input_getter._unpatch_all_layers()
48 changes: 48 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ..nn.modules.quant_linear import _QuantLinear
from ..triton import IS_AVAILABLE as IS_TRITON_AVAILABLE
from ..utils import replace_function, sync_moe_expert_amax
from ..utils.activation_collector import LayerActivationCollector
from .attention import register_attention_for_kv_quant
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear, _QuantFunctionalMixin

Expand Down Expand Up @@ -1367,6 +1368,42 @@ def _is_supported_hf_model(model):
return isinstance(model, tuple(supported_models))


def is_nemotron_h_model(model: nn.Module) -> bool:
    """Check whether *model* exposes Nemotron-H style decoder layers.

    Delegates to :func:`get_nemotron_h_decoder_layers`; the model counts as
    Nemotron-H exactly when that discoverer finds a layer list.
    """
    layers = get_nemotron_h_decoder_layers(model)
    return layers is not None


def get_nemotron_h_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
    """Return the decoder layers of a Nemotron-H style model, or ``None``.

    A model qualifies when it is a supported HF model whose ``backbone.layers``
    container is non-empty and whose first layer carries a ``block_type``
    attribute (presumably the Nemotron-H block marker — confirm against the
    upstream model definition).
    """
    if not _is_supported_hf_model(model):
        return None

    backbone = getattr(model, "backbone", None)
    candidate = getattr(backbone, "layers", None) if backbone is not None else None
    # Non-empty layer list whose first entry exposes ``block_type`` → match.
    if candidate and hasattr(candidate[0], "block_type"):
        return candidate
    return None


def is_homogeneous_hf_model(model: nn.Module) -> bool:
    """Check whether *model* has a non-empty, single-class decoder-layer stack.

    Nemotron-H models are rejected up front; otherwise the model qualifies
    when every discovered decoder layer is an instance of one and the same
    class.
    """
    if is_nemotron_h_model(model):
        return False

    layers = get_homogeneous_hf_decoder_layers(model)
    # Covers both "no layers found" (None) and an empty layer container.
    if not layers:
        return False

    first_cls = type(layers[0])
    return all(type(layer) is first_cls for layer in layers)


def get_homogeneous_hf_decoder_layers(model: nn.Module) -> nn.ModuleList | None:
    """Return ``model.model.layers`` for a supported HF model, else ``None``.

    Generic discoverer for the common HF transformer layout where decoder
    layers live under the nested ``model.model.layers`` attribute.
    """
    if not _is_supported_hf_model(model):
        return None

    inner = getattr(model, "model", None)
    return getattr(inner, "layers", None) if inner is not None else None


@contextmanager
def setup_model_for_gradient_checkpointing(model: nn.Module):
use_cache = None
Expand Down Expand Up @@ -1420,6 +1457,17 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
_is_param_grad_enabled_for_auto_quantize,
)

# Order matters: more specific predicates must be registered first because
# the first matching entry wins. Nemotron-H must precede the generic
# homogeneous HF discoverer (which explicitly rejects Nemotron-H).
LayerActivationCollector.register_decoder_layer_support(
is_nemotron_h_model, get_nemotron_h_decoder_layers
)

LayerActivationCollector.register_decoder_layer_support(
is_homogeneous_hf_model, get_homogeneous_hf_decoder_layers
)
Comment thread
sugunav14 marked this conversation as resolved.

CUSTOM_MODEL_PLUGINS.update(
[
register_falcon_linears_on_the_fly,
Expand Down
35 changes: 35 additions & 0 deletions modelopt/torch/quantization/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: F405
"""Quantization utilities."""

from .activation_collector import LayerActivationCollector
from .core_utils import *

# ``LayerActivationCollector`` is listed in ``__all__`` so the explicit import
# above is a deliberate re-export (and part of the star-import surface) rather
# than an unused import.
__all__ = [
    "EXPORT_MODE",
    "LayerActivationCollector",
    "convert_quantization_axis_to_reduce_axis",
    "export_torch_mode",
    "is_quantized",
    "is_quantized_column_parallel_linear",
    "is_quantized_linear",
    "is_quantized_row_parallel_linear",
    "reduce_amax",
    "reduce_sum",
    "replace_function",
    "update_quant_cfg_with_kv_cache_quant",
    "weight_attr_names",
]
Loading
Loading