Commit 37a4944

Fix checkpoint export errors for the Dream model (#311)
1 parent 0f77750 commit 37a4944

File tree: 7 files changed, +111 / -789 lines changed

fast_llm/engine/checkpoint/huggingface.py

Lines changed: 3 additions & 0 deletions
@@ -156,3 +156,6 @@ def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
         shutil.copy(self.configuration_file, config.path)
         if self.generation_utils_file:
             shutil.copy(self.generation_utils_file, config.path)
+            gen_config = pathlib.Path(self.generation_utils_file).parent / "generation_config.json"
+            if gen_config.exists():
+                shutil.copy(gen_config, config.path)
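For reference, the new copy step only adds a conditional lookup next to the custom generation utilities. A minimal standalone sketch of the same logic (the helper name and arguments here are illustrative, not part of the handler):

```python
import pathlib
import shutil


def copy_generation_config(generation_utils_file: str, export_dir: str) -> None:
    # Mirror of the added lines: look for a generation_config.json sitting next to
    # the custom generation_utils module and ship it with the exported checkpoint.
    gen_config = pathlib.Path(generation_utils_file).parent / "generation_config.json"
    if gen_config.exists():
        shutil.copy(gen_config, export_dir)
```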

fast_llm/models/gpt/conversion.py

Lines changed: 17 additions & 75 deletions
@@ -118,13 +118,15 @@ def import_weight(
 class CommonHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):
     _model: GPTModel
     _model_class: typing.ClassVar[FastLLMModelConfig] = GPTModelConfig
+    architecture: typing.ClassVar[str]
     """
     Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral)
     """
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
         return super()._create_config_converters() + [
+            ConstantExportParamConverter(export_names=(("architectures",),), export_value=[cls.architecture]),
             ConstantImportParamConverter(fast_llm_names=(("use_position_embeddings",),), fast_llm_value=False),
             RenameParamConverter(
                 fast_llm_names=(("transformer", "rotary", "theta"),), export_names=(("rope_theta",),)
@@ -320,8 +322,8 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "Starcoder2ForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Starcoder2ForCausalLM"]),
             ConstantImportParamConverter(
                 fast_llm_names=(("transformer", "rotary", "type"),),
                 fast_llm_value=DefaultRotaryConfig.dynamic_type_name,
@@ -447,8 +449,8 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler)
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "LlamaForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["LlamaForCausalLM"]),
             # TODO: Llama supports biases
             ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
             ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
@@ -499,8 +501,8 @@ class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "Qwen2ForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
             ConstantImportParamConverter(
                 fast_llm_names=(("transformer", "normalization", "type"),),
                 fast_llm_value="rms_norm",
@@ -545,8 +547,8 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "MistralForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MistralForCausalLM"]),
             IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None),
         ]
 
@@ -569,8 +571,8 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "MixtralForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MixtralForCausalLM"]),
             ConstantImportParamConverter(
                 fast_llm_names=(("transformer", "expert_routing_type"),), fast_llm_value=RoutingType.topk
             ),
@@ -613,8 +615,8 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "MTPLlamaForCausalLM"
         return super()._create_config_converters() + [
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["MTPLlamaForCausalLM"]),
             ConstantExportParamConverter(
                 export_names=(("auto_map",),),
                 export_value={
@@ -685,7 +687,12 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:
         return converters
 
 
-class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonHuggingfaceCheckpointHandler):
+class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Qwen2HuggingfaceCheckpointHandler):
+    """
+    Handler for DiffusionDream Huggingface checkpoints.
+    Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin),
+    but overrides _create_config_converters to update architectures and auto_map.
+    """
 
     from fast_llm.models.gpt.external.diffusion_dream import configuration_dream, generation_utils, modeling_dream
 
@@ -697,33 +704,8 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "DreamModel"
         return super()._create_config_converters() + [
-            # From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream
-            ConstantImportParamConverter(
-                fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
-            ),
-            RenameParamConverter(
-                fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
-            ),
-            ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
-            ConstantImportParamConverter(
-                fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
-            ),
-            RopeScalingParamConverter(
-                fast_llm_names=(
-                    ("transformer", "rotary", "type"),
-                    ("transformer", "rotary", "scale_factor"),
-                    ("transformer", "rotary", "low_frequency_factor"),
-                    ("transformer", "rotary", "high_frequency_factor"),
-                    ("transformer", "rotary", "original_context_length"),
-                    ("transformer", "rotary", "attention_factor"),
-                    ("transformer", "rotary", "beta_fast"),
-                    ("transformer", "rotary", "beta_slow"),
-                ),
-                export_names=(("rope_scaling",),),
-            ),
-            IgnoreImportQwen2SlidingWindowParamsConverter(),
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DreamModel"]),
             ConstantExportParamConverter(
                 export_names=(("auto_map",),),
                 export_value={
@@ -733,26 +715,8 @@ def _create_config_converters(cls) -> list[ParamConverter]:
             ),
         ]
 
-    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
-        # From Qwen2HuggingfaceCheckpointHandler
-        transformer_config: TransformerConfig = self._model.config.base_model.transformer
-        return [
-            *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_1",
-                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
-                transformer_config.add_mlp_bias,
-                SplitWeightConverter,
-            ),
-            *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_2",
-                f"{hf_prefix}.mlp.down_proj",
-                transformer_config.add_mlp_bias,
-                MLPLayer2Converter,
-            ),
-        ]
 
-
-class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlamaHuggingfaceCheckpointHandler):
+class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, LlamaHuggingfaceCheckpointHandler):
 
     from fast_llm.models.gpt.external.diffusion_llama import (
         configuration_diffusion_llama,
@@ -768,12 +732,8 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm
 
     @classmethod
     def _create_config_converters(cls) -> list[ParamConverter]:
+        cls.architecture = "DiffusionLlamaModel"
         return super()._create_config_converters() + [
-            # From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama
-            # TODO: Llama supports biases
-            ConstantExportParamConverter(export_names=(("attention_bias",),), export_value=False),
-            ConstantExportParamConverter(export_names=(("mlp_bias",),), export_value=False),
-            ConstantExportParamConverter(export_names=(("architectures",),), export_value=["DiffusionLlamaModel"]),
             ConstantExportParamConverter(
                 export_names=(("auto_map",),),
                 export_value={
@@ -789,24 +749,6 @@ def _create_config_converters(cls) -> list[ParamConverter]:
             # ),
         ]
 
-    def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
-        # From LlamaHuggingfaceCheckpointHandler
-        transformer_config: TransformerConfig = self._model.config.base_model.transformer
-        return [
-            *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_1",
-                (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
-                transformer_config.add_mlp_bias,
-                SplitWeightConverter,
-            ),
-            *self._get_weight_and_bias_converters(
-                f"{fast_llm_prefix}.mlp.layer_2",
-                f"{hf_prefix}.mlp.down_proj",
-                transformer_config.add_mlp_bias,
-                MLPLayer2Converter,
-            ),
-        ]
-
 
 class AutoGPTHuggingfaceCheckpointHandler(
     AutoStateDictCheckpointHandler, HuggingfaceStateDictCheckpointHandler, abc.ABC
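Taken together, the conversion changes replace the per-handler `architectures` converters with a single `architecture` class attribute consumed by the common handler, and the Dream / DiffusionLlama handlers now inherit the Qwen2 / Llama converters instead of duplicating them. A stripped-down sketch of the pattern (stand-in class names, not the real Fast-LLM classes):

```python
import typing


class ConstantExportParamConverter:
    # Stand-in for the real converter: records a constant value to export.
    def __init__(self, export_names, export_value):
        self.export_names, self.export_value = export_names, export_value


class CommonHandler:
    # Each concrete handler declares which HF "architectures" entry it exports.
    architecture: typing.ClassVar[str]

    @classmethod
    def _create_config_converters(cls) -> list:
        # Emitted once here, so subclasses no longer add their own architectures converter.
        return [ConstantExportParamConverter((("architectures",),), [cls.architecture])]


class Qwen2Handler(CommonHandler):
    @classmethod
    def _create_config_converters(cls) -> list:
        cls.architecture = "Qwen2ForCausalLM"
        return super()._create_config_converters()  # + Qwen2-specific converters


print(Qwen2Handler._create_config_converters()[0].export_value)  # ['Qwen2ForCausalLM']
```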

fast_llm/models/gpt/external/diffusion_dream/configuration_dream.py

Lines changed: 2 additions & 4 deletions
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +17,6 @@
 from transformers.modeling_rope_utils import rope_config_validation
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -47,7 +45,7 @@ def __init__(
         max_window_layers=28,
         attention_dropout=0.0,
         mask_token_id=151666,
-        pad_token_id=151643, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=None,
+        pad_token_id=None, # vocab_size is set to 8192 for test cases this would fail on Embedding layer check: # pad_token_id=151643,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -77,7 +75,7 @@ def __init__(
         if self.rope_scaling is not None and "type" in self.rope_scaling:
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)
-
+
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
fast_llm/models/gpt/external/diffusion_dream/generation_config.json (new file)

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+{
+  "_from_model_config": true,
+  "bos_token_id": 151643,
+  "eos_token_id": 151643,
+  "pad_token_id": 151643,
+  "transformers_version": "4.46.2"
+}
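This new generation_config.json (copied into exports by the `_copy_modeling_files` change above) carries the Dream token ids, so generation still sees `pad_token_id=151643` even though the config default flipped to `None`. A quick check on an exported checkpoint (the directory path is illustrative):

```python
from transformers import GenerationConfig

# Load the generation defaults shipped alongside the exported Dream checkpoint.
gen_cfg = GenerationConfig.from_pretrained("./exported_dream_checkpoint")
assert gen_cfg.bos_token_id == gen_cfg.eos_token_id == gen_cfg.pad_token_id == 151643
```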

fast_llm/models/gpt/external/diffusion_llama/configuration_diffusion_llama.py

Lines changed: 18 additions & 9 deletions
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -20,7 +19,7 @@
 """LLaMA model configuration"""
 
 import math
-from typing import Optional, Tuple
+from typing import Optional
 
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import is_torch_available, logging
@@ -30,13 +29,14 @@
 if is_torch_available():
     import torch
 
+
 # Update yarn implementation for RoPE (Taken from Llama but updated to use original_max_position_embeddings)
 def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None,
     device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
     **rope_kwargs,
-) -> Tuple["torch.Tensor", float]:
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies according to the original RoPE implementation
     Args:
@@ -72,9 +72,10 @@ def _compute_default_rope_parameters(
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
     return inv_freq, attention_factor
 
+
 def _compute_yarn_parameters(
     config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
-) -> Tuple["torch.Tensor", float]:
+) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies with NTK scaling. Please refer to the
     [original paper](https://arxiv.org/abs/2309.00071)
@@ -101,7 +102,7 @@ def _compute_yarn_parameters(
     partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
     dim = int(head_dim * partial_rotary_factor)
-
+
     # Apriel: Use original max_position_embeddings instead of max_position_embeddings
     max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
     factor = config.rope_scaling["factor"]
@@ -152,14 +153,14 @@ def linear_ramp_factor(min, max, dim):
 
     return inv_freq, attention_factor
 
+
 def _check_received_keys(
     rope_type: str,
     received_keys: set,
     required_keys: set,
     optional_keys: Optional[set] = None,
     ignore_keys: Optional[set] = None,
 ):
-
     """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
     # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
     if "type" in received_keys:
@@ -189,6 +190,7 @@ def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Opt
     received_keys = set(rope_scaling.keys())
     _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
 
+
 def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     rope_scaling = config.rope_scaling
     rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
@@ -218,6 +220,8 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
             f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
             f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
         )
+
+
 # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
 # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
 # parameterizations, as long as the callable has the same signature.
@@ -232,6 +236,7 @@ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[se
     "yarn": _validate_yarn_parameters,
 }
 
+
 def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
     """
     Validate the RoPE config arguments, given a `PretrainedConfig` object
@@ -250,6 +255,7 @@ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set]
             f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
         )
 
+
 class DiffusionLlamaConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
@@ -397,7 +403,7 @@ def __init__(
         max_position_embeddings=2048,
         initializer_range=0.02,
         rms_norm_eps=1e-6,
-        use_cache=False, # cache not implemented in diffusion
+        use_cache=False,  # cache not implemented in diffusion
         pad_token_id=None,
         bos_token_id=1,
         eos_token_id=2,
@@ -409,7 +415,7 @@ def __init__(
         attention_dropout=0.0,
         mlp_bias=False,
         head_dim=None,
-        # mask_token_id= TODO: add the mask_token_id we will be using,
+        mask_token_id=131072, # TODO: add the mask_token_id we will be using,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -435,6 +441,8 @@ def __init__(
         self.attention_dropout = attention_dropout
         self.mlp_bias = mlp_bias
         self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
+        self.mask_token_id = mask_token_id
+        self.pad_token_id = pad_token_id
         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, copy it it to 'rope_type'.
         if self.rope_scaling is not None and "type" in self.rope_scaling:
@@ -450,4 +458,5 @@ def __init__(
         )
         # TODO: self.mask_token_id = mask_token_id
 
-__all__ = ["LlamaConfig"]
+
+__all__ = ["DiffusionLlamaConfig"]
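With `mask_token_id` and `pad_token_id` now stored on the config and `__all__` corrected to `DiffusionLlamaConfig`, a default instantiation exposes both attributes. An illustrative check, assuming the remaining constructor defaults mirror transformers' `LlamaConfig` (e.g. `rope_scaling=None`):

```python
from fast_llm.models.gpt.external.diffusion_llama.configuration_diffusion_llama import DiffusionLlamaConfig

config = DiffusionLlamaConfig()
print(config.mask_token_id)  # 131072, the new default from this commit
print(config.pad_token_id)   # None unless passed explicitly
```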
