Skip to content

Commit 139f388

Browse files
committed
only relevant files
1 parent 2432cf4 commit 139f388

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

moe.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
"""Run a one-shot recipe on Qwen1.5-MoE and compare generated text.

The script loads the base checkpoint on ``cuda:0``, keeps an untouched deep
copy of it, applies the recipe through ``oneshot`` with a small calibration
set, and then prints the completion of one prompt from three models: the
untouched copy, the in-memory model after ``oneshot``, and the checkpoint
reloaded from the output directory onto ``cuda:1``.

NOTE(review): two CUDA devices are required as written (``cuda:0`` and
``cuda:1``).
"""
from sparseml.transformers import SparseAutoModelForCausalLM, SparseAutoTokenizer, oneshot
from copy import deepcopy
import torch

# Values that the original script repeated inline.
OUTPUT_DIR = "./output_one_shot"
MAX_LENGTH = 50
SEPARATOR = '----'


def _complete(lm, tok, encoded_prompt):
    """Generate from ``lm`` for ``encoded_prompt`` and decode the first sequence."""
    generated = lm.generate(**encoded_prompt, max_length=MAX_LENGTH)
    return tok.decode(generated[0])


model_name = "Qwen/Qwen1.5-MoE-A2.7B"

model = SparseAutoModelForCausalLM.from_pretrained(
    model_name, device_map="cuda:0", torch_dtype=torch.float16
)
# Snapshot taken before oneshot runs, used below as the "original model".
og_model = deepcopy(model)
tokenizer = SparseAutoTokenizer.from_pretrained(model_name)

dataset = "open-platypus"
recipe = "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml"

# Presumably oneshot writes the resulting checkpoint (and tokenizer files) to
# output_dir, since both are reloaded from that path below -- confirm against
# the sparseml version in use.
oneshot(
    model=model,
    dataset=dataset,
    overwrite_output_dir=True,
    output_dir=OUTPUT_DIR,
    recipe=recipe,
    num_calibration_samples=8,
)

prompt = "Why did the transformer cross the road?"

# 1) Untouched copy of the base model, base tokenizer.
prompt_tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)
print(SEPARATOR)
original_text = _complete(og_model, tokenizer, prompt_tokenized)
print(f"Output from the original model: {original_text}")
print(SEPARATOR)

# 2) In-memory model after oneshot, tokenizer reloaded from the output dir.
tokenizer = SparseAutoTokenizer.from_pretrained(OUTPUT_DIR)
prompt_tokenized = tokenizer(prompt, return_tensors="pt").to(model.device)
quantized_text = _complete(model, tokenizer, prompt_tokenized)
print(f"Output from the quantized model: {quantized_text}")
print(SEPARATOR)

# 3) Checkpoint reloaded from disk onto the second GPU; the encoded prompt is
#    moved to that model's device before generating.
model = SparseAutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR, device_map="cuda:1", torch_dtype=torch.float16
)
reloaded_text = _complete(model, tokenizer, prompt_tokenized.to(model.device))
print(f"Output from the quantized model (reloaded): {reloaded_text}")
print(SEPARATOR)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
"opencv-python<=4.6.0.66",
7979
]
8080
_transformers_deps = _pytorch_deps + [
81-
"transformers<4.40",
81+
"transformers<4.41",
8282
"datasets<2.19",
8383
"dvc",
8484
"scikit-learn",

0 commit comments

Comments
 (0)