@@ -118,13 +118,15 @@ def import_weight(
118
118
class CommonHuggingfaceCheckpointHandler (HuggingfaceStateDictCheckpointHandler ):
119
119
_model : GPTModel
120
120
_model_class : typing .ClassVar [FastLLMModelConfig ] = GPTModelConfig
121
+ architecture : typing .ClassVar [str ]
121
122
"""
122
123
Common converter for llama-based huggingface models (llama, starcoder2, mistral, mixtral)
123
124
"""
124
125
125
126
@classmethod
126
127
def _create_config_converters (cls ) -> list [ParamConverter ]:
127
128
return super ()._create_config_converters () + [
129
+ ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = [cls .architecture ]),
128
130
ConstantImportParamConverter (fast_llm_names = (("use_position_embeddings" ,),), fast_llm_value = False ),
129
131
RenameParamConverter (
130
132
fast_llm_names = (("transformer" , "rotary" , "theta" ),), export_names = (("rope_theta" ,),)
@@ -320,8 +322,8 @@ class Starcoder2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler)
320
322
321
323
@classmethod
322
324
def _create_config_converters (cls ) -> list [ParamConverter ]:
325
+ cls .architecture = "Starcoder2ForCausalLM"
323
326
return super ()._create_config_converters () + [
324
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["Starcoder2ForCausalLM" ]),
325
327
ConstantImportParamConverter (
326
328
fast_llm_names = (("transformer" , "rotary" , "type" ),),
327
329
fast_llm_value = DefaultRotaryConfig .dynamic_type_name ,
@@ -447,8 +449,8 @@ class LlamaHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandler)
447
449
448
450
@classmethod
449
451
def _create_config_converters (cls ) -> list [ParamConverter ]:
452
+ cls .architecture = "LlamaForCausalLM"
450
453
return super ()._create_config_converters () + [
451
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["LlamaForCausalLM" ]),
452
454
# TODO: Llama supports biases
453
455
ConstantExportParamConverter (export_names = (("attention_bias" ,),), export_value = False ),
454
456
ConstantExportParamConverter (export_names = (("mlp_bias" ,),), export_value = False ),
@@ -499,8 +501,8 @@ class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
499
501
500
502
@classmethod
501
503
def _create_config_converters (cls ) -> list [ParamConverter ]:
504
+ cls .architecture = "Qwen2ForCausalLM"
502
505
return super ()._create_config_converters () + [
503
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["Qwen2ForCausalLM" ]),
504
506
ConstantImportParamConverter (
505
507
fast_llm_names = (("transformer" , "normalization" , "type" ),),
506
508
fast_llm_value = "rms_norm" ,
@@ -545,8 +547,8 @@ class MistralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
545
547
546
548
@classmethod
547
549
def _create_config_converters (cls ) -> list [ParamConverter ]:
550
+ cls .architecture = "MistralForCausalLM"
548
551
return super ()._create_config_converters () + [
549
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["MistralForCausalLM" ]),
550
552
IgnoreImportParamConverter (export_names = (("sliding_window" ,),), ignore_export_value = None ),
551
553
]
552
554
@@ -569,8 +571,8 @@ class MixtralHuggingfaceCheckpointHandler(CommonLlamaHuggingfaceCheckpointHandle
569
571
570
572
@classmethod
571
573
def _create_config_converters (cls ) -> list [ParamConverter ]:
574
+ cls .architecture = "MixtralForCausalLM"
572
575
return super ()._create_config_converters () + [
573
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["MixtralForCausalLM" ]),
574
576
ConstantImportParamConverter (
575
577
fast_llm_names = (("transformer" , "expert_routing_type" ),), fast_llm_value = RoutingType .topk
576
578
),
@@ -613,8 +615,8 @@ class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, CommonLlam
613
615
614
616
@classmethod
615
617
def _create_config_converters (cls ) -> list [ParamConverter ]:
618
+ cls .architecture = "MTPLlamaForCausalLM"
616
619
return super ()._create_config_converters () + [
617
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["MTPLlamaForCausalLM" ]),
618
620
ConstantExportParamConverter (
619
621
export_names = (("auto_map" ,),),
620
622
export_value = {
@@ -685,7 +687,12 @@ def _create_lm_head_converters(self) -> list[WeightConverter]:
685
687
return converters
686
688
687
689
688
- class DiffusionDreamHuggingfaceCheckpointHandler (CustomModelingExportMixin , CommonHuggingfaceCheckpointHandler ):
690
+ class DiffusionDreamHuggingfaceCheckpointHandler (CustomModelingExportMixin , Qwen2HuggingfaceCheckpointHandler ):
691
+ """
692
+ Handler for DiffusionDream Huggingface checkpoints.
693
+ Inherits from Qwen2HuggingfaceCheckpointHandler (and CustomModelingExportMixin),
694
+ but overrides _create_config_converters to update architectures and auto_map.
695
+ """
689
696
690
697
from fast_llm .models .gpt .external .diffusion_dream import configuration_dream , generation_utils , modeling_dream
691
698
@@ -697,33 +704,8 @@ class DiffusionDreamHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm
697
704
698
705
@classmethod
699
706
def _create_config_converters (cls ) -> list [ParamConverter ]:
707
+ cls .architecture = "DreamModel"
700
708
return super ()._create_config_converters () + [
701
- # From Qwen2HuggingfaceCheckpointHandler - Change architectures to DiffusionDream
702
- ConstantImportParamConverter (
703
- fast_llm_names = (("transformer" , "normalization" , "type" ),), fast_llm_value = NormalizationType .rms_norm
704
- ),
705
- RenameParamConverter (
706
- fast_llm_names = (("transformer" , "normalization" , "epsilon" ),), export_names = (("rms_norm_eps" ,),)
707
- ),
708
- ConstantImportParamConverter (fast_llm_names = (("transformer" , "gated" ),), fast_llm_value = True ),
709
- ConstantImportParamConverter (
710
- fast_llm_names = (("transformer" , "add_linear_biases" ),), fast_llm_value = "only_attn_qkv"
711
- ),
712
- RopeScalingParamConverter (
713
- fast_llm_names = (
714
- ("transformer" , "rotary" , "type" ),
715
- ("transformer" , "rotary" , "scale_factor" ),
716
- ("transformer" , "rotary" , "low_frequency_factor" ),
717
- ("transformer" , "rotary" , "high_frequency_factor" ),
718
- ("transformer" , "rotary" , "original_context_length" ),
719
- ("transformer" , "rotary" , "attention_factor" ),
720
- ("transformer" , "rotary" , "beta_fast" ),
721
- ("transformer" , "rotary" , "beta_slow" ),
722
- ),
723
- export_names = (("rope_scaling" ,),),
724
- ),
725
- IgnoreImportQwen2SlidingWindowParamsConverter (),
726
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["DreamModel" ]),
727
709
ConstantExportParamConverter (
728
710
export_names = (("auto_map" ,),),
729
711
export_value = {
@@ -733,26 +715,8 @@ def _create_config_converters(cls) -> list[ParamConverter]:
733
715
),
734
716
]
735
717
736
- def _get_mlp_converters (self , fast_llm_prefix : str , hf_prefix : str ) -> list [WeightConverter ]:
737
- # From Qwen2HuggingfaceCheckpointHandler
738
- transformer_config : TransformerConfig = self ._model .config .base_model .transformer
739
- return [
740
- * self ._get_weight_and_bias_converters (
741
- f"{ fast_llm_prefix } .mlp.layer_1" ,
742
- (f"{ hf_prefix } .mlp.gate_proj" , f"{ hf_prefix } .mlp.up_proj" ),
743
- transformer_config .add_mlp_bias ,
744
- SplitWeightConverter ,
745
- ),
746
- * self ._get_weight_and_bias_converters (
747
- f"{ fast_llm_prefix } .mlp.layer_2" ,
748
- f"{ hf_prefix } .mlp.down_proj" ,
749
- transformer_config .add_mlp_bias ,
750
- MLPLayer2Converter ,
751
- ),
752
- ]
753
718
754
-
755
- class DiffusionLlamaHuggingfaceCheckpointHandler (CustomModelingExportMixin , CommonLlamaHuggingfaceCheckpointHandler ):
719
+ class DiffusionLlamaHuggingfaceCheckpointHandler (CustomModelingExportMixin , LlamaHuggingfaceCheckpointHandler ):
756
720
757
721
from fast_llm .models .gpt .external .diffusion_llama import (
758
722
configuration_diffusion_llama ,
@@ -768,12 +732,8 @@ class DiffusionLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, Comm
768
732
769
733
@classmethod
770
734
def _create_config_converters (cls ) -> list [ParamConverter ]:
735
+ cls .architecture = "DiffusionLlamaModel"
771
736
return super ()._create_config_converters () + [
772
- # From LlamaHuggingfaceCheckpointHandler - Update architectures to DiffusionLlama
773
- # TODO: Llama supports biases
774
- ConstantExportParamConverter (export_names = (("attention_bias" ,),), export_value = False ),
775
- ConstantExportParamConverter (export_names = (("mlp_bias" ,),), export_value = False ),
776
- ConstantExportParamConverter (export_names = (("architectures" ,),), export_value = ["DiffusionLlamaModel" ]),
777
737
ConstantExportParamConverter (
778
738
export_names = (("auto_map" ,),),
779
739
export_value = {
@@ -789,24 +749,6 @@ def _create_config_converters(cls) -> list[ParamConverter]:
789
749
# ),
790
750
]
791
751
792
- def _get_mlp_converters (self , fast_llm_prefix : str , hf_prefix : str ) -> list [WeightConverter ]:
793
- # From LlamaHuggingfaceCheckpointHandler
794
- transformer_config : TransformerConfig = self ._model .config .base_model .transformer
795
- return [
796
- * self ._get_weight_and_bias_converters (
797
- f"{ fast_llm_prefix } .mlp.layer_1" ,
798
- (f"{ hf_prefix } .mlp.gate_proj" , f"{ hf_prefix } .mlp.up_proj" ),
799
- transformer_config .add_mlp_bias ,
800
- SplitWeightConverter ,
801
- ),
802
- * self ._get_weight_and_bias_converters (
803
- f"{ fast_llm_prefix } .mlp.layer_2" ,
804
- f"{ hf_prefix } .mlp.down_proj" ,
805
- transformer_config .add_mlp_bias ,
806
- MLPLayer2Converter ,
807
- ),
808
- ]
809
-
810
752
811
753
class AutoGPTHuggingfaceCheckpointHandler (
812
754
AutoStateDictCheckpointHandler , HuggingfaceStateDictCheckpointHandler , abc .ABC
0 commit comments