@@ -244,7 +244,7 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
244
244
245
245
def _load_model (checkpoint_path , device , precision ):
246
246
checkpoint = torch .load (
247
- str (checkpoint_path ), mmap = True , weights_only = True , map_location = device
247
+ str (checkpoint_path ), mmap = True , weights_only = True , map_location = "cpu"
248
248
)
249
249
if "model" in checkpoint and "stories" in str (checkpoint_path ):
250
250
checkpoint = checkpoint ["model" ]
@@ -366,34 +366,24 @@ def ffn_or_attn_only(mod, fqn):
366
366
import os
367
367
import pwd
368
368
369
- from gemlite .core import GemLiteLinearTriton
369
+ import gemlite
370
+
371
+ gemlite .set_autotune ("max" )
372
+ config_file = f"/tmp/{ pwd .getpwuid (os .getuid ()).pw_gecos } _gemlite.json"
370
373
371
374
_quant_args = quantization .split ("-" )
372
- bit_width = int (_quant_args [- 2 ])
373
- group_size = None if _quant_args [- 1 ] == "None" else int (_quant_args [- 1 ])
374
- try :
375
- packing_bitwidth = int (_quant_args [- 3 ])
376
- except :
377
- # if only 2 inputs found, use default value
378
- packing_bitwidth = 32
375
+ bit_width = int (_quant_args [1 ])
376
+ group_size = None if _quant_args [2 ] == "None" else int (_quant_args [2 ])
377
+ mode = "dynamic" if _quant_args [3 ] == "dq" else "weight_only"
379
378
380
379
quantize_ (
381
380
model ,
382
- gemlite_uintx_weight_only (group_size , bit_width , packing_bitwidth ),
381
+ gemlite_uintx_weight_only (
382
+ bit_width = bit_width , group_size = group_size , mode = mode
383
+ ),
383
384
)
384
385
385
- # try to load gemlite kernel config
386
- try :
387
- GemLiteLinearTriton .load_config (
388
- f"/tmp/{ pwd .getpwuid (os .getuid ()).pw_gecos } _gemlite.json"
389
- )
390
- print (
391
- f"loaded gemlite kernel cache /tmp/{ pwd .getpwuid (os .getuid ()).pw_gecos } _gemlite.json"
392
- )
393
- except :
394
- print (
395
- f"unable to load gemlite kernel cache /tmp/{ pwd .getpwuid (os .getuid ()).pw_gecos } _gemlite.json"
396
- )
386
+ gemlite .load_config (config_file )
397
387
398
388
print ("running gemlite warmup" )
399
389
generate (
@@ -405,9 +395,8 @@ def ffn_or_attn_only(mod, fqn):
405
395
temperature = temperature ,
406
396
top_k = top_k ,
407
397
)
408
- GemLiteLinearTriton .cache_config (
409
- f"/tmp/{ pwd .getpwuid (os .getuid ()).pw_gecos } _gemlite.json"
410
- )
398
+ gemlite .cache_config (config_file )
399
+
411
400
if "int8wo" in quantization :
412
401
quantize_ (model , int8_weight_only ())
413
402
if "int8dq" in quantization :
0 commit comments