From cac0b8cd93eeabfa5f2f892af9fca8aa1a5d1dd5 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 24 Mar 2025 13:39:56 -0400 Subject: [PATCH 01/16] save peft Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index f36fbf4c3..d346fabcb 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -14,6 +14,7 @@ # Standard from dataclasses import dataclass +from peft import PeftModel import os # Third Party @@ -113,9 +114,13 @@ def checkpoint(checkpoint_dir, save_dir): os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME), ) # Save model config files - self.trainer.model.config.save_pretrained( - hf_converted_output_dir - ) + if isinstance(self.trainer.model, PeftModel): + # Save PEFT adapter configuration + PeftModel.save_pretrained(hf_converted_output_dir) + else: + self.trainer.model.config.save_pretrained( + hf_converted_output_dir + ) except Exception as e: raise ValueError( From c5224296bcd98b71c8fb249477e88175a339f1ea Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 24 Mar 2025 14:03:04 -0400 Subject: [PATCH 02/16] fix: model Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index d346fabcb..f370b9805 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -116,7 +116,7 @@ def checkpoint(checkpoint_dir, save_dir): # Save model config files if isinstance(self.trainer.model, PeftModel): # Save PEFT adapter configuration - PeftModel.save_pretrained(hf_converted_output_dir) + PeftModel.save_pretrained(self.trainer.model, hf_converted_output_dir) else: self.trainer.model.config.save_pretrained( hf_converted_output_dir From 481dde627e26ba85d51361c2966ab135fb9ab327 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 1 Apr 2025 14:30:42 -0400 Subject: [PATCH 03/16] post process hf converted dir Signed-off-by: Will Johnson --- build/accelerate_launch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index 6cbc7d252..f1472f534 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -145,6 +145,13 @@ def main(): post_process_vLLM_adapters_new_tokens( save_model_dir, save_model_dir, num_added_tokens ) + hf_converted_checkpoint = os.path.join(save_model_dir, "hf_converted_checkpoint") + if os.path.exists( + os.path.join(hf_converted_checkpoint, "adapter_model.safetensors") + ): + post_process_vLLM_adapters_new_tokens( + hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens + ) if ( os.path.exists(os.path.join(output_dir, "added_tokens_info.json")) From 397c9ba1d42c50bfe09b58afedfa04becf8474cc Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 7 Apr 2025 10:26:55 -0400 Subject: [PATCH 04/16] fix: convert hf converted checkpoint Signed-off-by: Will Johnson --- build/accelerate_launch.py | 42 ++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index f1472f534..6dcb282dd 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -145,13 +145,16 @@ def main(): 
post_process_vLLM_adapters_new_tokens( save_model_dir, save_model_dir, num_added_tokens ) - hf_converted_checkpoint = os.path.join(save_model_dir, "hf_converted_checkpoint") - if os.path.exists( - os.path.join(hf_converted_checkpoint, "adapter_model.safetensors") - ): - post_process_vLLM_adapters_new_tokens( - hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens - ) + + hf_converted_checkpoint = os.path.join( + save_model_dir, "hf_converted_checkpoint" + ) + if os.path.exists( + os.path.join(hf_converted_checkpoint, "adapter_model.safetensors") + ): + post_process_vLLM_adapters_new_tokens( + hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens + ) if ( os.path.exists(os.path.join(output_dir, "added_tokens_info.json")) @@ -166,11 +169,28 @@ def main(): for _, dirs, _ in os.walk(output_dir, topdown=False): for name in dirs: if "checkpoint-" in name.lower(): - post_process_vLLM_adapters_new_tokens( - os.path.join(output_dir, name), - os.path.join(output_dir, name), - num_added_tokens, + checkpoint_dir = os.path.join(output_dir, name) + if os.path.exists( + os.path.join(checkpoint_dir, "adapter_model.safetensors") + ): + post_process_vLLM_adapters_new_tokens( + checkpoint_dir, + checkpoint_dir, + num_added_tokens, + ) + hf_converted_checkpoint = os.path.join( + checkpoint_dir, "hf_converted_checkpoint" ) + if os.path.exists( + os.path.join( + hf_converted_checkpoint, "adapter_model.safetensors" + ) + ): + post_process_vLLM_adapters_new_tokens( + hf_converted_checkpoint, + hf_converted_checkpoint, + num_added_tokens, + ) else: logging.warning( "Failed to post-process: file added_tokens_info.json not in path %s", From 79dec24d030c22b2a5bae14653b26345abc91223 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 7 Apr 2025 13:13:33 -0400 Subject: [PATCH 05/16] lora config Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index f370b9805..04425d6fa 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -14,10 +14,10 @@ # Standard from dataclasses import dataclass -from peft import PeftModel import os # Third Party +from peft import LoraModel, PeftModel from transformers import ( Trainer, TrainerCallback, @@ -114,9 +114,10 @@ def checkpoint(checkpoint_dir, save_dir): os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME), ) # Save model config files - if isinstance(self.trainer.model, PeftModel): + if isinstance(self.trainer.model._fsdp_wrapped_module.base_model, LoraModel): # Save PEFT adapter configuration - PeftModel.save_pretrained(self.trainer.model, hf_converted_output_dir) + self.trainer.model._fsdp_wrapped_module.base_model.save_pretrained(hf_converted_output_dir) + else: self.trainer.model.config.save_pretrained( hf_converted_output_dir From 3103720afcfcbe10b0044b78b70d4f0b758488f5 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 7 Apr 2025 16:20:04 -0400 Subject: [PATCH 06/16] save adapter config Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 04425d6fa..7ed501b7b 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -15,6 
+15,7 @@ # Standard from dataclasses import dataclass import os +import json # Third Party from peft import LoraModel, PeftModel @@ -113,10 +114,18 @@ def checkpoint(checkpoint_dir, save_dir): args, os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME), ) - # Save model config files - if isinstance(self.trainer.model._fsdp_wrapped_module.base_model, LoraModel): - # Save PEFT adapter configuration - self.trainer.model._fsdp_wrapped_module.base_model.save_pretrained(hf_converted_output_dir) + + # Unwrap FSDP module + model = self.trainer.model + if hasattr(model, "module"): + model = model.module + + if model.peft_config: + lora_config = model.peft_config["default"] + config_dict = lora_config.to_dict() + config_dict['target_modules'] = sorted(list(config_dict['target_modules'])) + with open(os.path.join(hf_converted_output_dir,"adapter_config.json"), "w") as f: + json.dump(config_dict, f, indent=2) else: self.trainer.model.config.save_pretrained( From b61cbde6d1986318f197be91eda7ae17356b20e4 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 7 Apr 2025 16:31:22 -0400 Subject: [PATCH 07/16] fmt + comments Signed-off-by: Will Johnson --- build/accelerate_launch.py | 3 +++ .../config/acceleration_configs/fast_moe.py | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py index 6dcb282dd..bea6d032b 100644 --- a/build/accelerate_launch.py +++ b/build/accelerate_launch.py @@ -146,6 +146,7 @@ def main(): save_model_dir, save_model_dir, num_added_tokens ) + # In case of ScatterMoE LoRa hf_converted_checkpoint = os.path.join( save_model_dir, "hf_converted_checkpoint" ) @@ -178,6 +179,8 @@ def main(): checkpoint_dir, num_added_tokens, ) + + # In case of ScatterMoE LoRa hf_converted_checkpoint = os.path.join( checkpoint_dir, "hf_converted_checkpoint" ) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 7ed501b7b..7573dd7ff 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -14,11 +14,10 @@ # Standard from dataclasses import dataclass -import os import json +import os # Third Party -from peft import LoraModel, PeftModel from transformers import ( Trainer, TrainerCallback, @@ -123,14 +122,20 @@ def checkpoint(checkpoint_dir, save_dir): if model.peft_config: lora_config = model.peft_config["default"] config_dict = lora_config.to_dict() - config_dict['target_modules'] = sorted(list(config_dict['target_modules'])) - with open(os.path.join(hf_converted_output_dir,"adapter_config.json"), "w") as f: + config_dict["target_modules"] = sorted( + list(config_dict["target_modules"]) + ) + with open( + os.path.join( + hf_converted_output_dir, "adapter_config.json" + ), + "w", + encoding="utf-8" + ) as f: json.dump(config_dict, f, indent=2) else: - self.trainer.model.config.save_pretrained( - hf_converted_output_dir - ) + model.config.save_pretrained(hf_converted_output_dir) except Exception as e: raise ValueError( From c12be0ef2a69798573d47fc911f480942015eebc Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 8 Apr 2025 13:02:21 -0400 Subject: [PATCH 08/16] fix: add input linear and output linear to target modules Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 7573dd7ff..142d7655b 100644 --- 
a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -125,6 +125,8 @@ def checkpoint(checkpoint_dir, save_dir): config_dict["target_modules"] = sorted( list(config_dict["target_modules"]) ) + if "router" in config_dict["target_modules"]: + config_dict["target_modules"].append("input_linear, output_linear") with open( os.path.join( hf_converted_output_dir, "adapter_config.json" From 123c2d481ae77e0974ce2f2adf02ed1010445c87 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 8 Apr 2025 14:09:05 -0400 Subject: [PATCH 09/16] fix: extend instead of append Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 142d7655b..94507855c 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -126,7 +126,7 @@ def checkpoint(checkpoint_dir, save_dir): list(config_dict["target_modules"]) ) if "router" in config_dict["target_modules"]: - config_dict["target_modules"].append("input_linear, output_linear") + config_dict["target_modules"].extend(["input_linear", "output_linear"]) with open( os.path.join( hf_converted_output_dir, "adapter_config.json" From f68500b64f320cffb7883a767743f4f3e3837152 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 8 Apr 2025 15:40:33 -0400 Subject: [PATCH 10/16] fix: if hasattr peft config Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 94507855c..a2258d35c 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -119,7 +119,7 @@ def checkpoint(checkpoint_dir, save_dir): if hasattr(model, "module"): model = model.module - if model.peft_config: + if hasattr(model, "peft_config"): lora_config = model.peft_config["default"] config_dict = lora_config.to_dict() config_dict["target_modules"] = sorted( From 55ec4b505365161c491929e8383f670f6f01ddc6 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 9 Apr 2025 14:46:07 -0400 Subject: [PATCH 11/16] fix: remove unneeded target modules Signed-off-by: Will Johnson --- tuning/config/acceleration_configs/fast_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index a2258d35c..40d1e286a 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -125,8 +125,6 @@ def checkpoint(checkpoint_dir, save_dir): config_dict["target_modules"] = sorted( list(config_dict["target_modules"]) ) - if "router" in config_dict["target_modules"]: - config_dict["target_modules"].extend(["input_linear", "output_linear"]) with open( os.path.join( hf_converted_output_dir, "adapter_config.json" From 23623494c3750cb8e30a541ff5e765afd062c178 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Thu, 10 Apr 2025 09:34:05 -0400 Subject: [PATCH 12/16] lint + fmt Signed-off-by: Will Johnson --- .pylintrc | 2 +- tuning/config/acceleration_configs/fast_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pylintrc b/.pylintrc index 41f7e4e73..5e9f356b9 100644 --- a/.pylintrc +++ b/.pylintrc @@ -475,7 +475,7 @@ notes-rgx= [REFACTORING] # 
Maximum number of nested blocks for function / method body -max-nested-blocks=5 +max-nested-blocks=6 # Complete name of functions that never returns. When checking for # inconsistent-return-statements if a never returning function is called then diff --git a/tuning/config/acceleration_configs/fast_moe.py b/tuning/config/acceleration_configs/fast_moe.py index 97eb214cd..37602daf1 100644 --- a/tuning/config/acceleration_configs/fast_moe.py +++ b/tuning/config/acceleration_configs/fast_moe.py @@ -139,7 +139,7 @@ def checkpoint(checkpoint_dir, save_dir): hf_converted_output_dir, "adapter_config.json" ), "w", - encoding="utf-8" + encoding="utf-8", ) as f: json.dump(config_dict, f, indent=2) From a848a9b45dcbed7baf4f58aa0fcd34ce924e6c00 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Fri, 11 Apr 2025 16:09:46 -0400 Subject: [PATCH 13/16] docs Signed-off-by: Will Johnson --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index e503a7d63..109985383 100644 --- a/README.md +++ b/README.md @@ -902,6 +902,10 @@ Notes: - When a boolean is passed, the expert parallel degree defaults to 1 and further the behaviour would be as follows: - if True, it is Scatter MoE Kernels with experts sharded based on the top level sharding protocol (e.g. FSDP). - if False, Scatter MoE Kernels with complete replication of experts across ranks. + - lora tuning with ScatterMoE is supported, but because of inference restrictions on vLLM/vanilla PEFT, experts should not be trained as `target_modules` for models being tuned with ScatterMoE. Users have control over which `target_modules` they wish to train: + - Passing `all-linear` to adapter layers will include the router, which is a linear layer, and all attn layers. This **will not** train the expert layers. + - To train only attention layers, specify target modules specifically (i.e `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj"]`). + - To train expert layers, specify `input_linear` and `output_linear` in target modules along with `router` (i.e `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj", "router", "input_linear", "output_linear"]`). If you specify these layers, inference with vLLM/vanilla HF PEFT **is not possible**. - `world_size` must be divisible by the `ep_degree` - `number of experts` in the MoE module must be divisible by the `ep_degree` - Running fast moe modifies the state dict of the model, and must be post-processed which happens automatically and the converted checkpoint can be found at `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, we can perform similar option manually through [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script. 
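Across patches 05–12 the adapter-config logic in `fast_moe.py` changes several times, and the state it converges on is easier to read outside the hunk context. The sketch below restates that final logic for illustration only: the standalone function name is invented for the sketch, and the surrounding try/except and training-args saving from the real `checkpoint` hook are omitted.

```python
# Condensed, illustrative sketch of the adapter-config handling that the
# preceding patches converge on; see the diffs above for the authoritative code.
import json
import os


def save_config_for_converted_checkpoint(trainer, hf_converted_output_dir):
    # Unwrap the FSDP wrapper (if any) to reach the underlying model
    model = trainer.model
    if hasattr(model, "module"):
        model = model.module

    if hasattr(model, "peft_config"):
        # LoRA run: write adapter_config.json next to the converted weights so
        # the hf_converted_checkpoint can be loaded as a regular PEFT adapter.
        lora_config = model.peft_config["default"]
        config_dict = lora_config.to_dict()
        # target_modules is stored as a set; turn it into a sorted list so it
        # is JSON-serializable and deterministic across runs.
        config_dict["target_modules"] = sorted(list(config_dict["target_modules"]))
        with open(
            os.path.join(hf_converted_output_dir, "adapter_config.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(config_dict, f, indent=2)
    else:
        # Full fine-tuning run: saving the model config is sufficient.
        model.config.save_pretrained(hf_converted_output_dir)
```

The `sorted(list(...))` step is what makes the dump work: `LoraConfig` normalizes `target_modules` to a set, which `json.dump` cannot serialize and whose iteration order is not stable.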
From 42c420c3028713899e5a03bde6f72903478869e6 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Mon, 14 Apr 2025 16:50:13 -0400 Subject: [PATCH 14/16] test: lora for scattermoe Signed-off-by: Will Johnson --- tests/test_sft_trainer.py | 42 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 664c67ad7..c4dc8a5ed 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1447,6 +1447,44 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree): ) +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="moe"), + reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", +) +@pytest.mark.parametrize( + "dataset_path", + [ + TWITTER_COMPLAINTS_DATA_JSONL, + ], +) +def test_run_moe_lora_and_inference(dataset_path): + """Check if we can finetune a moe model and check if hf checkpoint is created""" + with tempfile.TemporaryDirectory() as tempdir: + data_args = copy.deepcopy(DATA_ARGS) + data_args.training_data_path = dataset_path + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base" + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + lora_args = copy.deepcopy(PEFT_LORA_ARGS) + lora_args.r = 16 + lora_args.target_modules = ["q_proj", "v_proj", "o_proj", "k_proj"] # Router doesn't work with LoRA test inference + fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=False)) + sft_trainer.train( + model_args, + data_args, + train_args, + lora_args, + fast_moe_config=fast_moe_config, + ) + _test_run_inference( + checkpoint_path=os.path.join( + _get_checkpoint_path(tempdir), "hf_converted_checkpoint" + ), + base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base" + ) + + @pytest.mark.skipif( not is_fms_accelerate_available(plugins="moe"), reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", @@ -1485,9 +1523,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir): _validate_training(tempdir) -def _test_run_inference(checkpoint_path): +def _test_run_inference(checkpoint_path, base_model_name_or_path=None): # Load the model - loaded_model = TunedCausalLM.load(checkpoint_path) + loaded_model = TunedCausalLM.load(checkpoint_path, base_model_name_or_path) # Run inference on the text output_inference = loaded_model.run( From e3e7525db94d3ebcf83a327bc0f3b91ef04b83c1 Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Tue, 15 Apr 2025 09:29:54 -0400 Subject: [PATCH 15/16] fmt tests Signed-off-by: Will Johnson --- tests/test_sft_trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index c4dc8a5ed..e97e51383 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1468,7 +1468,12 @@ def test_run_moe_lora_and_inference(dataset_path): train_args.output_dir = tempdir lora_args = copy.deepcopy(PEFT_LORA_ARGS) lora_args.r = 16 - lora_args.target_modules = ["q_proj", "v_proj", "o_proj", "k_proj"] # Router doesn't work with LoRA test inference + lora_args.target_modules = [ + "q_proj", + "v_proj", + "o_proj", + "k_proj", + ] # Router doesn't work with LoRA test inference fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=False)) sft_trainer.train( model_args, @@ -1481,7 +1486,7 @@ def test_run_moe_lora_and_inference(dataset_path): checkpoint_path=os.path.join( _get_checkpoint_path(tempdir), 
"hf_converted_checkpoint" ), - base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base" + base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base", ) From 844965959b86c31997e103ee397d8a112549953f Mon Sep 17 00:00:00 2001 From: Will Johnson Date: Wed, 16 Apr 2025 11:48:59 -0400 Subject: [PATCH 16/16] docs: notes on restrictions Signed-off-by: Will Johnson --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 109985383..51876ef6b 100644 --- a/README.md +++ b/README.md @@ -906,7 +906,9 @@ Notes: - Passing `all-linear` to adapter layers will include the router, which is a linear layer, and all attn layers. This **will not** train the expert layers. - To train only attention layers, specify target modules specifically (i.e `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj"]`). - To train expert layers, specify `input_linear` and `output_linear` in target modules along with `router` (i.e `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj", "router", "input_linear", "output_linear"]`). If you specify these layers, inference with vLLM/vanilla HF PEFT **is not possible**. - - `world_size` must be divisible by the `ep_degree` + - When lora tuning with ScatterMoE, the values `--fast_moe 1` or `--fast_moe True` are not expected to work, as FSDP must be enabled when lora tuning. Run either `--fast_moe False` or `--fast-moe x>1`. + - When lora tuning with ScatterMoE, `--r` must be set to 16 or greater. + - `world_size` must be divisible by the `--ep_degree` - `number of experts` in the MoE module must be divisible by the `ep_degree` - Running fast moe modifies the state dict of the model, and must be post-processed which happens automatically and the converted checkpoint can be found at `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, we can perform similar option manually through [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script. - The typical usecase for this script is to run: