
Fix checkpoint export errors for the Dream model #311

Merged: 12 commits, Jun 25, 2025
3 changes: 3 additions & 0 deletions fast_llm/engine/checkpoint/huggingface.py
@@ -156,3 +156,6 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
shutil.copy(self.configuration_file, config.path)
if self.generation_utils_file:
shutil.copy(self.generation_utils_file, config.path)
gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json"
if gen_config.exists():
shutil.copy(gen_config, config.path)
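
For reference, a minimal sketch of how the copy logic reads after this change (attribute names are taken from the hunk above; the surrounding lines and exact indentation are assumptions, not the verbatim method):

```python
import pathlib
import shutil

def _copy_modeling_files(self, config) -> None:
    # Existing behaviour: copy the custom configuration and generation-utils
    # files next to the exported Hugging Face checkpoint (config.path).
    shutil.copy(self.configuration_file, config.path)
    if self.generation_utils_file:
        shutil.copy(self.generation_utils_file, config.path)
        # New in this PR: also ship generation_config.json when it exists next
        # to the generation utils, so generation defaults (bos/eos/pad token
        # ids) travel with the export.
        gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json"
        if gen_config.exists():
            shutil.copy(gen_config, config.path)
```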
92 changes: 17 additions & 75 deletions fast_llm/models/gpt/conversion.py
@@ -118,13 +118,15 @@ def import_weight(
class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):
_model: GPTModel
_model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig
architecture: typing.ClassVar[list[str]]
"""
Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral)
"""

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=cls.architecture),
ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False),
RenameParamConverter(
fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),)
@@ -320,8 +322,8 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["Starcoder2ForCausalLM"]
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "rotary", "type"),),
fast_llm_value=DefaultRotaryConfig.dynamic_type_name,
@@ -447,8 +449,8 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler)

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["LlamaForCausalLM"]
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["LlamaForCausalLM"]),
# TODO: Llama supports biases
ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
@@ -499,8 +501,8 @@ class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["Qwen2ForCausalLM"]
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),),
fast_llm_value="rms_norm",
@@ -545,8 +547,8 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["MistralForCausalLM"]
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MistralForCausalLM"]),
IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None),
]

@@ -569,8 +571,8 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["MixtralForCausalLM"]
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk
),
@@ -613,8 +615,8 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["MTPLlamaForCausalLM"]
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MTPLlamaForCausalLM"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
@@ -685,7 +687,12 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:
return converters


class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler):
class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen2HuggingfaceCheckpointHandler):
"""
Handler for DiffusionDream Huggingface checkpoints.
Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin),
but overrides _create_config_converters to update architectures and auto_map.
"""

from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream

@@ -697,33 +704,8 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["DreamModel"]
return super()._create_config_converters() + [
# From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
),
RenameParamConverter(
fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
),
ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
),
RopeScalingParamConverter(
fast_llm_names=(
("transformer", "rotary", "type"),
("transformer", "rotary", "scale_factor"),
("transformer", "rotary", "low_frequency_factor"),
("transformer", "rotary", "high_frequency_factor"),
("transformer", "rotary", "original_context_length"),
("transformer", "rotary", "attention_factor"),
("transformer", "rotary", "beta_fast"),
("transformer", "rotary", "beta_slow"),
),
export_names=(("rope_scaling",),),
),
IgnoreImportQwen2SlidingWindowParamsConverter(),
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
@@ -733,26 +715,8 @@ def _create_config_converters(cls) -> list[ParamConverter]:
),
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
# From Qwen2HuggingfaceCheckpointHandler
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler):
class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlamaHuggingfaceCheckpointHandler):

from fast_llm.models.gpt.external.diffusion_llama import (
configuration_diffusion_llama,
@@ -768,12 +732,8 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
cls.architecture = ["DiffusionLlamaModel"]
return super()._create_config_converters() + [
# From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama
# TODO: Llama supports biases
ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]),
ConstantExportParamConverter(
export_names=(("auto_map",),),
export_value={
@@ -789,24 +749,6 @@ def _create_config_converters(cls) -> list[ParamConverter]:
# ),
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
# From LlamaHuggingfaceCheckpointHandler
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


class AutoGPTHuggingfaceCheckpointHandler(
AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +17,6 @@
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging


logger = logging.get_logger(__name__)


@@ -47,7 +45,7 @@ def __init__(
max_window_layers=28,
attention_dropout=0.0,
mask_token_id=151666,
pad_token_id=151643, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None,
pad_token_id=None,  # with vocab_size set to 8192 for test cases, the original default (pad_token_id=151643) would fail the Embedding layer check
**kwargs,
):
self.vocab_size = vocab_size
@@ -77,7 +75,7 @@ def __init__(
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
fast_llm/models/gpt/external/diffusion_dream/generation_config.json (new file)
@@ -0,0 +1,7 @@
{
"_from_model_config": true,
"bos_token_id": 151643,
"eos_token_id": 151643,
"pad_token_id": 151643,
"transformers_version": "4.46.2"
}
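
A quick way to sanity-check that this file actually ships with an export is to read it back with transformers (an illustrative check, not part of this PR; replace the path with wherever the checkpoint was saved):

```python
from transformers import GenerationConfig

# The ids should match the generation_config.json added above.
gen_cfg = GenerationConfig.from_pretrained("path/to/exported/checkpoint")
assert gen_cfg.bos_token_id == 151643
assert gen_cfg.eos_token_id == 151643
assert gen_cfg.pad_token_id == 151643
```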
fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -20,7 +19,7 @@
"""LLaMA model configuration"""

import math
from typing import Optional, Tuple
from typing import Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import is_torch_available, logging
@@ -30,13 +29,14 @@
if is_torch_available():
import torch


# Update yarn implementation for RoPE (Taken from Llama but updated to use original_max_position_embeddings)
def _compute_default_rope_parameters(
config: Optional[PretrainedConfig] = None,
device: Optional["torch.device"] = None,
seq_len: Optional[int] = None,
**rope_kwargs,
) -> Tuple["torch.Tensor", float]:
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies according to the original RoPE implementation
Args:
@@ -72,9 +72,10 @@ def _compute_default_rope_parameters(
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
return inv_freq, attention_factor


def _compute_yarn_parameters(
config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies with NTK scaling. Please refer to the
[original paper](https://arxiv.org/abs/2309.00071)
@@ -101,7 +102,7 @@ def _compute_yarn_parameters(
partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
dim = int(head_dim * partial_rotary_factor)

# Apriel: Use original_max_position_embeddings instead of max_position_embeddings
max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
factor = config.rope_scaling["factor"]
@@ -152,14 +153,14 @@ def linear_ramp_factor(min, max, dim):

return inv_freq, attention_factor


def _check_received_keys(
rope_type: str,
received_keys: set,
required_keys: set,
optional_keys: Optional[set] = None,
ignore_keys: Optional[set] = None,
):

"""Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
# BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
if "type" in received_keys:
@@ -189,6 +190,7 @@ def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Opt
received_keys = set(rope_scaling.keys())
_check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
rope_scaling = config.rope_scaling
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
@@ -218,6 +220,8 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
)


# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
@@ -232,6 +236,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
"yarn": _validate_yarn_parameters,
}


def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
"""
Validate the RoPE config arguments, given a `PretrainedConfig` object
@@ -250,6 +255,7 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set]
f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
)


class DiffusionLlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
@@ -397,7 +403,7 @@ def __init__(
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=False, # cache not implemented in diffusion
use_cache=False, # cache not implemented in diffusion
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
@@ -409,7 +415,7 @@ def __init__(
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
# mask_token_id= TODO: add the mask_token_id we will be using,
mask_token_id=131072, # TODO: add the mask_token_id we will be using,
**kwargs,
):
self.vocab_size = vocab_size
@@ -435,6 +441,8 @@ def __init__(
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.mask_token_id = mask_token_id
self.pad_token_id = pad_token_id
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
@@ -450,4 +458,5 @@ def __init__(
)
# TODO: self.mask_token_id = mask_token_id

__all__ = ["LlamaConfig"]

__all__ = ["DiffusionLlamaConfig"]