Commit 9fa47ad

Merge remote-tracking branch 'origin/main' into feature/ouro
2 parents 879a348 + d3dc2e3

49 files changed (+5135, −820 lines)

.github/workflows/pull_request.yml

Lines changed: 3 additions & 1 deletion
@@ -38,4 +38,6 @@ jobs:
       - name: Run tests
         shell: bash -l {0}
         run: |
-          python -m xmlrunner discover -v tests -o test-results/
+          curl -o test_data.zip -L https://github.com/ml-explore/mlx-lm/releases/download/test_data/test_data.zip
+          unzip test_data.zip
+          HF_HOME="." python -m xmlrunner discover -v tests -o test-results/

.github/workflows/release.yml

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ permissions:
 jobs:
 
   build_release:
-    if: github.repository == 'ml-explore/mlx'
+    if: github.repository == 'ml-explore/mlx-lm'
     runs-on: ubuntu-22.04
     permissions:
       id-token: write

README.md

Lines changed: 4 additions & 4 deletions
@@ -71,7 +71,7 @@ prompt = "Write a story about Einstein"
 
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, add_generation_prompt=True
+    messages, add_generation_prompt=True,
 )
 
 text = generate(model, tokenizer, prompt=prompt, verbose=True)
@@ -130,7 +130,7 @@ prompt = "Write a story about Einstein"
 
 messages = [{"role": "user", "content": prompt}]
 prompt = tokenizer.apply_chat_template(
-    messages, add_generation_prompt=True
+    messages, add_generation_prompt=True,
 )
 
 for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
@@ -170,7 +170,7 @@ mlx_lm.generate --help
 To quantize a model from the command line run:
 
 ```
-mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
+mlx_lm.convert --model mistralai/Mistral-7B-Instruct-v0.3 -q
 ```
 
 For more options run:
@@ -185,7 +185,7 @@ You can upload new models to Hugging Face by specifying `--upload-repo` to
 
 ```
 mlx_lm.convert \
-    --hf-path mistralai/Mistral-7B-Instruct-v0.3 \
+    --model mistralai/Mistral-7B-Instruct-v0.3 \
     -q \
     --upload-repo mlx-community/my-4bit-mistral
 ```
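
The README hunks above track the CLI rename from `--hf-path` to `--model`. As a non-authoritative illustration, here is one way to drive the renamed flag from Python via `subprocess`; the model id is simply the one already used in the README, and nothing in this sketch is part of the commit itself:

```python
# Minimal sketch: invoke the mlx_lm.convert console script with the new --model flag.
# Assumes mlx-lm is installed so that "mlx_lm.convert" is on PATH.
import subprocess

subprocess.run(
    [
        "mlx_lm.convert",
        "--model", "mistralai/Mistral-7B-Instruct-v0.3",  # local path or Hub id
        "-q",  # quantize with the default settings
    ],
    check=True,  # raise CalledProcessError if the conversion fails
)
```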

mlx_lm/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright © 2023-2025 Apple Inc.
 
-__version__ = "0.28.4"
+__version__ = "0.30.0"

mlx_lm/benchmark.py

Lines changed: 11 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 from mlx_lm import batch_generate, load, stream_generate
 from mlx_lm.generate import DEFAULT_MODEL
-from mlx_lm.utils import pipeline_load
+from mlx_lm.utils import pipeline_load, sharded_load
 
 
 def setup_arg_parser():
@@ -49,6 +49,11 @@ def setup_arg_parser():
         help="Number of timing trials",
         type=int,
     )
+    parser.add_argument(
+        "--pipeline",
+        action="store_true",
+        help="Use pipelining instead of tensor parallelism",
+    )
     return parser
 
 
@@ -59,6 +64,8 @@ def main():
 
     group = mx.distributed.init()
     rank = group.rank()
+    pipeline_group = group if args.pipeline else None
+    tensor_group = group if not args.pipeline else None
 
     def rprint(*args, **kwargs):
         if rank == 0:
@@ -67,7 +74,9 @@ def rprint(*args, **kwargs):
     model_path = args.model or DEFAULT_MODEL
 
     if group.size() > 1:
-        model, tokenizer, config = pipeline_load(args.model, return_config=True)
+        model, tokenizer, config = sharded_load(
+            args.model, pipeline_group, tensor_group, return_config=True
+        )
     else:
         model, tokenizer, config = load(
            args.model, return_config=True, tokenizer_config={"trust_remote_code": True}
mlx_lm/cache_prompt.py

Lines changed: 4 additions & 17 deletions
@@ -41,16 +41,6 @@ def setup_arg_parser():
         default=None,
         help="End of sequence token for tokenizer",
     )
-    parser.add_argument(
-        "--ignore-chat-template",
-        action="store_true",
-        help="Use the raw prompt without the tokenizer's chat template.",
-    )
-    parser.add_argument(
-        "--use-default-chat-template",
-        action="store_true",
-        help="Use the default chat template",
-    )
     parser.add_argument(
         "--max-kv-size",
         type=int,
@@ -107,14 +97,12 @@ def main():
 
     args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
 
-    if args.use_default_chat_template:
-        if tokenizer.chat_template is None:
-            tokenizer.chat_template = tokenizer.default_chat_template
-
-    if not args.ignore_chat_template and tokenizer.chat_template is not None:
+    if tokenizer.has_chat_template:
         messages = [{"role": "user", "content": args.prompt}]
         prompt = tokenizer.apply_chat_template(
-            messages, add_generation_prompt=False, continue_final_message=True
+            messages,
+            add_generation_prompt=False,
+            continue_final_message=True,
         )
 
     else:
@@ -153,7 +141,6 @@ def callback(processed, total_tokens):
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
-    metadata["chat_template"] = json.dumps(tokenizer.chat_template)
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
 
mlx_lm/chat.py

Lines changed: 39 additions & 17 deletions
@@ -7,7 +7,7 @@
 from .generate import stream_generate
 from .models.cache import make_prompt_cache
 from .sample_utils import make_sampler
-from .utils import load
+from .utils import load, sharded_load
 
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
@@ -79,35 +79,54 @@ def setup_arg_parser():
         default=None,
         help="System prompt to be used for the chat template",
     )
+    parser.add_argument(
+        "--pipeline",
+        action="store_true",
+        help="Use pipelining instead of tensor parallelism",
+    )
     return parser
 
 
 def main():
     parser = setup_arg_parser()
     args = parser.parse_args()
 
+    group = mx.distributed.init()
+    rank = group.rank()
+    pipeline_group = group if args.pipeline else None
+    tensor_group = group if not args.pipeline else None
+
+    def rprint(*args, **kwargs):
+        if rank == 0:
+            print(*args, **kwargs)
+
     if args.seed is not None:
         mx.random.seed(args.seed)
 
-    model, tokenizer = load(
-        args.model,
-        adapter_path=args.adapter_path,
-        tokenizer_config={
-            "trust_remote_code": True if args.trust_remote_code else None
-        },
-    )
+    if group.size() > 1:
+        if args.adapter_path:
+            parser.error("Adapters not supported in distributed mode")
+        model, tokenizer = sharded_load(args.model, pipeline_group, tensor_group)
+    else:
+        model, tokenizer = load(
+            args.model,
+            adapter_path=args.adapter_path,
+            tokenizer_config={
+                "trust_remote_code": True if args.trust_remote_code else None
+            },
+        )
 
     def print_help():
-        print("The command list:")
-        print("- 'q' to exit")
-        print("- 'r' to reset the chat")
-        print("- 'h' to display these commands")
+        rprint("The command list:")
+        rprint("- 'q' to exit")
+        rprint("- 'r' to reset the chat")
+        rprint("- 'h' to display these commands")
 
-    print(f"[INFO] Starting chat session with {args.model}.")
+    rprint(f"[INFO] Starting chat session with {args.model}.")
     print_help()
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
+        query = input(">> " if rank == 0 else "")
         if query == "q":
             break
         if query == "r":
@@ -120,7 +139,10 @@ def print_help():
         if args.system_prompt is not None:
             messages.append({"role": "system", "content": args.system_prompt})
         messages.append({"role": "user", "content": query})
-        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+        prompt = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+        )
         for response in stream_generate(
             model,
             tokenizer,
@@ -137,8 +159,8 @@ def print_help():
             ),
             prompt_cache=prompt_cache,
         ):
-            print(response.text, flush=True, end="")
-        print()
+            rprint(response.text, flush=True, end="")
+        rprint()
 
 
 if __name__ == "__main__":

mlx_lm/convert.py

Lines changed: 15 additions & 4 deletions
@@ -179,25 +179,36 @@ def configure_parser() -> argparse.ArgumentParser:
         description="Convert Hugging Face model to MLX format"
     )
 
-    parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
+    parser.add_argument(
+        "--hf-path",
+        "--model",
+        type=str,
+        help="Path to the model. This can be a local path or a Hugging Face Hub model identifier.",
+    )
     parser.add_argument(
         "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
     )
     parser.add_argument(
         "-q", "--quantize", help="Generate a quantized model.", action="store_true"
     )
     parser.add_argument(
-        "--q-group-size", help="Group size for quantization.", type=int, default=64
+        "--q-group-size",
+        help="Group size for quantization.",
+        type=int,
+        default=None,
     )
     parser.add_argument(
-        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
+        "--q-bits",
+        help="Bits per weight for quantization.",
+        type=int,
+        default=None,
    )
     parser.add_argument(
         "--q-mode",
         help="The quantization mode.",
         type=str,
         default="affine",
-        choices=["affine", "mxfp4"],
+        choices=["affine", "mxfp4", "nvfp4", "mxfp8"],
     )
     parser.add_argument(
         "--quant-predicate",
mlx_lm/evaluate.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from .generate import batch_generate
 from .models.cache import make_prompt_cache
 from .sample_utils import make_sampler
-from .utils import common_prefix_len, load
+from .utils import load
 
 DEFAULT_MAX_TOKENS = 8192
 

mlx_lm/examples/batch_generate_response.py

Lines changed: 22 additions & 3 deletions
@@ -26,7 +26,26 @@
 ]
 
 # Set `verbose=True` to see generation statistics
-result = batch_generate(model, tokenizer, prompts, verbose=False, max_tokens=128)
+result = batch_generate(
+    model, tokenizer, prompts, verbose=False, return_prompt_caches=True
+)
+print(result.texts[-1])
 
-# The returned result contains texts completions in the same order as prompts
-print(result.texts[0])
+prompts = [
+    "Could you summarize that?",
+    "And what about the sea?",
+    "Try again?",
+    "And Mt Olympus?",
+]
+prompts = [
+    tokenizer.apply_chat_template(
+        [{"role": "user", "content": p}],
+        add_generation_prompt=True,
+    )
+    for p in prompts
+]
+
+result = batch_generate(
+    model, tokenizer, prompts, verbose=False, prompt_caches=result.caches
+)
+print(result.texts[-1])