Hello, I loaded the pre-trained llava-llama3 SFT weights and fine-tuned them with LoRA, but I get an error when merging the LoRA weights back into the base model:
Training script:
deepspeed --master_port=$((RANDOM + 10000)) --include localhost:0,1,2,3 llava/train/train_mem.py \
--lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
--deepspeed ./scripts/zero2.json \
--model_name_or_path ./HuggingFace-Download-Accelerator/models--MBZUAI--LLaVA-Meta-Llama-3-8B-Instruct-FT \
--version llama3 \
--data_path train_data.json \
--image_folder .data/image \
--vision_tower clip-vit-large-patch14-336 \
--pretrain_mm_mlp_adapter ./HuggingFace-Download-Accelerator/models--MBZUAI--LLaVA-Meta-Llama-3-8B-Instruct-pretrain/mm_projector.bin \
...
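For context, my understanding is that --lora_enable/--lora_r/--lora_alpha correspond to a standard PEFT LoRA config roughly like the sketch below (target_modules is my assumption here; LLaVA's train.py actually discovers the language model's linear layers at runtime):

from peft import LoraConfig

lora_config = LoraConfig(
    r=128,            # --lora_r
    lora_alpha=256,   # --lora_alpha; effective scaling = alpha / r = 2.0
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed subset
    lora_dropout=0.05,  # assumed default
    task_type="CAUSAL_LM",
)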
LoRA merge script:
python scripts/merge_lora_weights.py \
--model-path ./checkpoints/llava-v1.5-8b-finetune-lora_loadfrom_FT \
--model-base ./HuggingFace-Download-Accelerator/models--MBZUAI--LLaVA-Meta-Llama-3-8B-Instruct-FT \
--save-model-path ./checkpoints/merge_llava-llama3-finetune-lora_loadfrom_FT
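As I understand it, the core operation behind merge_lora_weights.py is the usual PEFT merge. A minimal sketch of that operation, assuming the adapter was saved in standard PEFT format (this is not LLaVA's actual script, which goes through load_pretrained_model and also restores non-LoRA trainables such as the projector):

from transformers import AutoModelForCausalLM
from peft import PeftModel

base_path = "./HuggingFace-Download-Accelerator/models--MBZUAI--LLaVA-Meta-Llama-3-8B-Instruct-FT"
lora_path = "./checkpoints/llava-v1.5-8b-finetune-lora_loadfrom_FT"

base = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype="auto")
model = PeftModel.from_pretrained(base, lora_path)
model = model.merge_and_unload()  # fold the LoRA deltas into the base weights
model.save_pretrained("./checkpoints/merge_llava-llama3-finetune-lora_loadfrom_FT")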
Error:
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading LLaVA from base model...
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
Loading checkpoint shards: 0%| | 0/4 [00:10<?, ?it/s]
Traceback (most recent call last):
File "/LLaVA-pp/LLaVA/scripts/merge_lora_weights.py", line 24, in <module>
merge_lora(args)
File "/LLaVA-pp/LLaVA/scripts/merge_lora_weights.py", line 8, in merge_lora
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
File "/LLaVA-pp/LLaVA/llava/model/builder.py", line 64, in load_pretrained_model
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3682, in from_pretrained
) = cls._load_pretrained_model(
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4109, in _load_pretrained_model
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/transformers/modeling_utils.py", line 887, in _load_state_dict_into_meta_model
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 358, in set_module_tensor_to_device
raise ValueError(
ValueError: Trying to set a tensor of shape torch.Size([128257, 4096]) in "weight" (which has shape torch.Size([128256, 4096])), this look incorrect.
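The shapes in the error suggest a vocabulary mismatch: the LoRA checkpoint expects an embedding table with 128257 rows, while the base model ships 128256, i.e. one token was apparently added to the vocabulary during fine-tuning. A quick way to confirm is to compare vocab_size in the two config.json files (paths taken from my commands above):

import json, os

lora_path = "./checkpoints/llava-v1.5-8b-finetune-lora_loadfrom_FT"
base_path = "./HuggingFace-Download-Accelerator/models--MBZUAI--LLaVA-Meta-Llama-3-8B-Instruct-FT"

for p in (lora_path, base_path):
    with open(os.path.join(p, "config.json")) as f:
        print(p, "-> vocab_size =", json.load(f).get("vocab_size"))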
When I afterwards try to load the model from --save-model-path (which the failed merge never created, so transformers falls back to treating the local path as a Hub repo id), I get a second error:
/LLaVA-pp/LLaVA/llava/model/builder.py:54: UserWarning: There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
Traceback (most recent call last):
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/transformers/utils/hub.py", line 398, in cached_file
resolved_file = hf_hub_download(
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 111, in _inner_fn
validate_repo_id(arg_value)
File "/miniconda-3/envs/llava-llama/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 159, in validate_repo_id
raise HFValidationError(
huggingface_hub.utils._validators.HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': '/LLaVA-pp/LLaVA/checkpoints/merge_llava-llama3-finetune-lora_loadfrom_FT'. Use `repo_type` argument if needed.
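If the diagnosis above is right, one workaround would be to resize the base model's embeddings to the fine-tuned vocabulary before applying the LoRA weights. A hypothetical sketch (not LLaVA's builder code):

from transformers import AutoModelForCausalLM, AutoTokenizer

lora_path = "./checkpoints/llava-v1.5-8b-finetune-lora_loadfrom_FT"
base_path = "./HuggingFace-Download-Accelerator/models--MBZUAI--LLaVA-Meta-Llama-3-8B-Instruct-FT"

tokenizer = AutoTokenizer.from_pretrained(lora_path)  # carries the added token
model = AutoModelForCausalLM.from_pretrained(base_path)
if model.get_input_embeddings().weight.shape[0] != len(tokenizer):
    model.resize_token_embeddings(len(tokenizer))  # 128256 -> 128257

Is resizing like this the intended fix here, or am I missing a flag in the merge script?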