Describe the bug
When attempting to run the example script example_text_completion.py, I get the following error:
TypeError: ModelArgs.__init__() got an unexpected keyword argument 'use_scaled_rope'
Removing `"use_scaled_rope": true` from params.json fixes the error and allows the prompts to run.
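For convenience, the same manual edit can be scripted. This is just a sketch of the workaround described above; the path is an assumption based on the repro command below, and the key is simply dropped if present:

```python
# Sketch of the manual workaround above: strip "use_scaled_rope" from
# params.json so ModelArgs no longer receives the unknown keyword.
# The path is an assumption based on the repro command in this report.
import json

path = "Meta-Llama-3.1-8B/params.json"
with open(path) as f:
    params = json.load(f)
params.pop("use_scaled_rope", None)  # no-op if the key is already absent
with open(path, "w") as f:
    json.dump(params, f, indent=2)
```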
Minimal reproducible example
Running the following with the default downloaded params gives me the error.
torchrun --nproc_per_node 1 example_text_completion.py --ckpt_dir Meta-Llama-3.1-8B/ --tokenizer_path Meta-Llama-3.1-8B/tokenizer.model --max_seq_len 128 --max_batch_size 4
Default params.json for Meta-Llama-3.1-8B:
{"dim": 4096, "ffn_dim_multiplier": 1.3, "multiple_of": 1024, "n_heads": 32, "n_kv_heads": 8, "n_layers": 32, "norm_eps": 1e-05, "rope_theta": 500000.0, "use_scaled_rope": true, "vocab_size": 128256}
example_text_completion.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

from typing import List

import fire

from llama import Llama


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 128,
    max_gen_len: int = 64,
    max_batch_size: int = 4,
):
    """
    Examples to run with the pre-trained models (no fine-tuning). Prompts are
    usually in the form of an incomplete text prefix that the model can then try to complete.

    The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
    `max_gen_len` is needed because pre-trained models usually do not stop completions naturally.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    prompts: List[str] = [
        # For these prompts, the expected answer is the natural continuation of the prompt
        "I believe the meaning of life is",
        "Simply put, the theory of relativity states that ",
        """A brief message congratulating the team on the launch:

        Hi everyone,

        I just """,
        # Few shot prompt (providing a few examples before asking model to complete more);
        """Translate English to French:

        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
    ]

    results = generator.text_completion(
        prompts,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    for prompt, result in zip(prompts, results):
        print(prompt)
        print(f"> {result['generation']}")
        print("\n==================================\n")


if __name__ == "__main__":
    fire.Fire(main)
Output
> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
/home/andrew/llama3/llama/generation.py:94: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(ckpt_path, map_location="cpu")
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/andrew/llama3/example_text_completion.py", line 64, in <module>
[rank0]: fire.Fire(main)
[rank0]: File "/home/andrew/.local/lib/python3.10/site-packages/fire/core.py", line 143, in Fire
[rank0]: component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]: File "/home/andrew/.local/lib/python3.10/site-packages/fire/core.py", line 477, in _Fire
[rank0]: component, remaining_args = _CallAndUpdateTrace(
[rank0]: File "/home/andrew/.local/lib/python3.10/site-packages/fire/core.py", line 693, in _CallAndUpdateTrace
[rank0]: component = fn(*varargs, **kwargs)
[rank0]: File "/home/andrew/llama3/example_text_completion.py", line 27, in main
[rank0]: generator = Llama.build(
[rank0]: File "/home/andrew/llama3/llama/generation.py", line 98, in build
[rank0]: model_args: ModelArgs = ModelArgs(
[rank0]: TypeError: ModelArgs.__init__() got an unexpected keyword argument 'use_scaled_rope'
[rank0]:[W725 20:13:04.250669532 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
E0725 20:13:05.312000 132402189336576 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 4312) of binary: /usr/bin/python3
Traceback (most recent call last):
File "/usr/local/bin/torchrun", line 8, in <module>
sys.exit(main())
File "/home/andrew/.local/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
return f(*args, **kwargs)
File "/home/andrew/.local/lib/python3.10/site-packages/torch/distributed/run.py", line 901, in main
run(args)
File "/home/andrew/.local/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
elastic_launch(
File "/home/andrew/.local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "/home/andrew/.local/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
example_text_completion.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2024-07-25_20:13:05
host : patent-desktop
rank : 0 (local_rank: 0)
exitcode : 1 (pid: 4312)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
Runtime Environment
- Model: Meta-Llama-3.1-8B
- Using via huggingface?: no
- OS: Ubuntu 22.04
- GPU VRAM: 24GB
- Number of GPUs: 1
- GPU Make: NVIDIA
Additional context
Python 3.10.12
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01 Driver Version: 535.183.01 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA GeForce RTX 3090 Ti Off | 00000000:01:00.0 On | Off |
| 0% 46C P8 26W / 450W | 127MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 991 G /usr/lib/xorg/Xorg 108MiB |
| 0 N/A N/A 1083 G /usr/bin/gnome-shell 10MiB |
+---------------------------------------------------------------------------------------+