2 changes: 1 addition & 1 deletion .pylintrc
@@ -475,7 +475,7 @@ notes-rgx=
[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=5
max-nested-blocks=6

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
8 changes: 7 additions & 1 deletion README.md
@@ -902,7 +902,13 @@ Notes:
- When a boolean is passed, the expert parallel degree defaults to 1 and the behaviour is as follows:
- if True, Scatter MoE kernels are used with experts sharded according to the top-level sharding protocol (e.g. FSDP).
- if False, Scatter MoE kernels are used with complete replication of experts across ranks.
- `world_size` must be divisible by the `ep_degree`
- LoRA tuning with ScatterMoE is supported, but because of inference restrictions in vLLM/vanilla PEFT, experts should not be trained as `target_modules` for models tuned with ScatterMoE. Users have control over which `target_modules` they wish to train (see the sketch after the options below):
- Passing `all-linear` as `target_modules` will include the router, which is a linear layer, and all attention layers. This **will not** train the expert layers.
- To train only the attention layers, specify the target modules explicitly (i.e. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj"]`).
- To train the expert layers, specify `input_linear` and `output_linear` in the target modules along with `router` (i.e. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj", "router", "input_linear", "output_linear"]`). If you specify these layers, inference with vLLM/vanilla HF PEFT **is not possible**.
Suggested change
- To train the expert layers, specify `input_linear` and `output_linear` in the target modules along with `router` (i.e. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj", "router", "input_linear", "output_linear"]`). If you specify these layers, inference with vLLM/vanilla HF PEFT **is not possible**.
- To train the expert layers, specify `input_linear` and `output_linear` in the target modules along with `router` (i.e. `target_modules: ["q_proj", "v_proj", "o_proj", "k_proj", "router", "input_linear", "output_linear"]`). If you specify these layers, inference with vLLM/vanilla HF PEFT **is not currently supported**.
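
For illustration, here is a minimal sketch of the `target_modules` options above, expressed with PEFT's `LoraConfig` (this library's own CLI/YAML config may spell the arguments differently; the module names are the ones listed above):

```python
from peft import LoraConfig

# Attention-only LoRA: compatible with vLLM / vanilla HF PEFT inference.
attention_only = LoraConfig(
    r=16,  # with ScatterMoE, r must be 16 or greater
    target_modules=["q_proj", "v_proj", "o_proj", "k_proj"],
)

# Attention + router + expert layers: trains the experts, but the resulting
# adapter cannot currently be served with vLLM / vanilla HF PEFT.
with_experts = LoraConfig(
    r=16,
    target_modules=[
        "q_proj", "v_proj", "o_proj", "k_proj",
        "router", "input_linear", "output_linear",
    ],
)
```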

- When LoRA tuning with ScatterMoE, the values `--fast_moe 1` or `--fast_moe True` are not expected to work, as FSDP must be enabled when LoRA tuning. Run either `--fast_moe False` or `--fast_moe` with a degree greater than 1.
Didn't quite get your point here yet. `--fast_moe True` disables expert parallel; however, experts are sharded by FSDP, so FSDP is active here.

BTW, don't `--fast_moe 1` and `--fast_moe False` both have the same effect? In both settings, all experts are replicated and kept out of FSDP, while the other layers are under FSDP sharding.

Maybe if you are comfortable with a support matrix table, let's do that and pinpoint case by case.

- When LoRA tuning with ScatterMoE, `--r` must be set to 16 or greater.
- `world_size` must be divisible by the `--ep_degree`
- The number of experts in the MoE module must be divisible by the `ep_degree`
- Running fast moe modifies the model's state dict, which must be post-processed. This happens automatically, and the converted checkpoint can be found in the `hf_converted_checkpoint` folder within every saved checkpoint directory. Alternatively, the same conversion can be performed manually through the [checkpoint utils](https://github.com/foundation-model-stack/fms-acceleration/blob/main/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py) script.
- The typical use case for this script is to run:
38 changes: 34 additions & 4 deletions build/accelerate_launch.py
@@ -146,6 +146,17 @@ def main():
save_model_dir, save_model_dir, num_added_tokens
)

# In case of ScatterMoE LoRA, post-process the converted adapter checkpoint as well
hf_converted_checkpoint = os.path.join(
save_model_dir, "hf_converted_checkpoint"
)
if os.path.exists(
os.path.join(hf_converted_checkpoint, "adapter_model.safetensors")
):
post_process_vLLM_adapters_new_tokens(
hf_converted_checkpoint, hf_converted_checkpoint, num_added_tokens
)

if (
os.path.exists(os.path.join(output_dir, "added_tokens_info.json"))
and job_config.get("save_strategy") != "no"
@@ -159,11 +170,30 @@ def main():
for _, dirs, _ in os.walk(output_dir, topdown=False):
for name in dirs:
if "checkpoint-" in name.lower():
post_process_vLLM_adapters_new_tokens(
os.path.join(output_dir, name),
os.path.join(output_dir, name),
num_added_tokens,
checkpoint_dir = os.path.join(output_dir, name)
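# Post-process only if this checkpoint contains an adapter file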
if os.path.exists(
os.path.join(checkpoint_dir, "adapter_model.safetensors")
):
post_process_vLLM_adapters_new_tokens(
checkpoint_dir,
checkpoint_dir,
num_added_tokens,
)

# In case of ScatterMoE LoRA, post-process the converted adapter checkpoint as well
hf_converted_checkpoint = os.path.join(
checkpoint_dir, "hf_converted_checkpoint"
)
if os.path.exists(
os.path.join(
hf_converted_checkpoint, "adapter_model.safetensors"
)
):
post_process_vLLM_adapters_new_tokens(
hf_converted_checkpoint,
hf_converted_checkpoint,
num_added_tokens,
)
else:
logging.warning(
"Failed to post-process: file added_tokens_info.json not in path %s",
47 changes: 45 additions & 2 deletions tests/test_sft_trainer.py
@@ -1447,6 +1447,49 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
@pytest.mark.parametrize(
"dataset_path",
[
TWITTER_COMPLAINTS_DATA_JSONL,
],
)
def test_run_moe_lora_and_inference(dataset_path):
"""Check if we can finetune a moe model and check if hf checkpoint is created"""
with tempfile.TemporaryDirectory() as tempdir:
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = dataset_path
model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
lora_args = copy.deepcopy(PEFT_LORA_ARGS)
lora_args.r = 16
lora_args.target_modules = [
"q_proj",
"v_proj",
"o_proj",
"k_proj",
] # Router is excluded; it doesn't work with LoRA inference in this test
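# ep_degree=False: Scatter MoE kernels with experts fully replicated across ranks (no expert parallel)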
fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=False))
sft_trainer.train(
model_args,
data_args,
train_args,
lora_args,
fast_moe_config=fast_moe_config,
)
_test_run_inference(
checkpoint_path=os.path.join(
_get_checkpoint_path(tempdir), "hf_converted_checkpoint"
),
base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
@@ -1485,9 +1528,9 @@ def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
_validate_training(tempdir)


def _test_run_inference(checkpoint_path):
def _test_run_inference(checkpoint_path, base_model_name_or_path=None):
# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
loaded_model = TunedCausalLM.load(checkpoint_path, base_model_name_or_path)

# Run inference on the text
output_inference = loaded_model.run(
28 changes: 24 additions & 4 deletions tuning/config/acceleration_configs/fast_moe.py
@@ -16,6 +16,7 @@
from dataclasses import dataclass, field
from typing import Union
import argparse
import json
import os

# Third Party
@@ -121,10 +122,29 @@ def checkpoint(checkpoint_dir, save_dir):
args,
os.path.join(hf_converted_output_dir, TRAINING_ARGS_NAME),
)
# Save model config files
self.trainer.model.config.save_pretrained(
hf_converted_output_dir
)

# Unwrap FSDP module
model = self.trainer.model
if hasattr(model, "module"):
model = model.module

if hasattr(model, "peft_config"):
lora_config = model.peft_config["default"]
config_dict = lora_config.to_dict()
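# target_modules may be a set; store it as a sorted list so it serializes to JSON deterministically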
config_dict["target_modules"] = sorted(
list(config_dict["target_modules"])
)
with open(
os.path.join(
hf_converted_output_dir, "adapter_config.json"
),
"w",
encoding="utf-8",
) as f:
json.dump(config_dict, f, indent=2)

else:
model.config.save_pretrained(hf_converted_output_dir)

except Exception as e:
raise ValueError(