Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 0 additions & 39 deletions modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Any

import torch
import yaml
from megatron.core import dist_checkpointing, mpu
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME
Expand All @@ -36,21 +35,6 @@

SUPPORTED_WRAPPERS[Float16Module] = "module"

DROP_SUBSTRINGS = [
"fp4",
"fp8",
"tp_",
"parallel",
"cuda_graph",
"init_",
"cpu",
"recompute",
"inference",
"pipeline",
"comm",
"batch",
]


def remove_per_module_state(
modelopt_state: dict[str, Any],
Expand Down Expand Up @@ -138,29 +122,6 @@ def save_sharded_modelopt_state(
sharded_strategy: configures sharded tensors saving behavior and backend
prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
"""

def _parse_transformer_config(transformer_config: dict) -> dict:
config = {}

for k, v in transformer_config.items():
if any(substring in k for substring in DROP_SUBSTRINGS):
continue
if isinstance(v, (bool, int, str)):
config[k] = v
else:
config[k] = str(v)

return config

# Save own version of run config, if not already saved by the framework.
if dist.is_master() and not os.path.exists(f"{checkpoint_name}/run_config.yaml"):
run_config_name = f"{checkpoint_name}/modelopt_run_config.yaml"
# We avoid deepcopy since some attributes in Megatron-Bridge config cannot be deepcopied.
config_dict = _parse_transformer_config(model[0].config.__dict__)
config_dict["nvidia_modelopt_version"] = modelopt.__version__
with open(run_config_name, "w") as f:
yaml.dump(config_dict, f, default_flow_style=False)

if not mto.ModeloptStateManager.is_converted(model[0]):
return
if len(model) > 1:
Expand Down
Loading