12 | 12 | import math
13 | 13 |
14 | 14 | # Third Party
15 |    | -from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_enter, stagger_leave
   | 15 | +from aiu_fms_testing_utils.utils import aiu_setup, warmup_model, stagger_region
16 | 16 | from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, local_rank, world_size
17 | 17 | import numpy as np
18 | 18 | import torch

@@ -464,38 +464,35 @@ def select_int8_module(

464 | 464 | dprint(f"data_type={default_dtype}")
465 | 465 | dprint("="*60 + "\n")
466 | 466 |
467 |     | -stagger_enter(args.stagger_load)
468 |     | -
469 |     | -model = get_model(
470 |     | -    args.architecture,
471 |     | -    args.variant,
472 |     | -    model_path=args.model_path,
473 |     | -    device_type="cpu" if is_aiu_backend else args.device_type,
474 |     | -    data_type=default_dtype,
475 |     | -    source=args.model_source,
476 |     | -    distributed_strategy=distr_param,
477 |     | -    group=dist.group.WORLD,
478 |     | -    linear_config=linear_config,
479 |     | -    fused_weights=fused_weights,
480 |     | -)
481 |     | -
482 |     | -if args.quantization in ["gptq", "int8"]:
483 |     | -    if rank == 0 and args.verbose > 0:
484 |     | -        dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
485 |     | -        dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
    | 467 | +with stagger_region(args.stagger_load) as _s:
    | 468 | +    model = get_model(
    | 469 | +        args.architecture,
    | 470 | +        args.variant,
    | 471 | +        model_path=args.model_path,
    | 472 | +        device_type="cpu" if is_aiu_backend else args.device_type,
    | 473 | +        data_type=default_dtype,
    | 474 | +        source=args.model_source,
    | 475 | +        distributed_strategy=distr_param,
    | 476 | +        group=dist.group.WORLD,
    | 477 | +        linear_config=linear_config,
    | 478 | +        fused_weights=fused_weights,
    | 479 | +    )
    | 480 | +
    | 481 | +    if args.quantization in ["gptq", "int8"]:
    | 482 | +        if rank == 0 and args.verbose > 0:
    | 483 | +            dprint("PARAMS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_parameters()))
    | 484 | +            dprint("BUFFERS:\n" + "\n".join(f"{k:60} {str(v.dtype):15} {str(v.device):10} {list(v.size())}" for k,v in model.named_buffers()))
    | 485 | +            dprint("="*60 + "\n")
    | 486 | +        if args.architecture == "llama":
    | 487 | +            dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
    | 488 | +        dprint(model)
486 | 489 |         dprint("="*60 + "\n")
487 |     | -    if args.architecture == "llama":
488 |     | -        dprint("[NOTE] In Llama models, it's OK for bias and rotary embeddings to be marked as unused keys.")
489 |     | -    dprint(model)
490 |     | -    dprint("="*60 + "\n")
491 |     | -
492 |     | -tokenizer = tokenizers.get_tokenizer(args.tokenizer)
493 |     | -model.eval()
494 |     | -torch.set_grad_enabled(False)
495 |     | -loading_model_time = time.time() - loading_model_time
496 |     | -dprint(f"loading complete, took {loading_model_time:.3f}s")
497 |     | -
498 |     | -stagger_leave(args.stagger_load)
    | 490 | +
    | 491 | +    tokenizer = tokenizers.get_tokenizer(args.tokenizer)
    | 492 | +    model.eval()
    | 493 | +    torch.set_grad_enabled(False)
    | 494 | +    loading_model_time = time.time() - loading_model_time
    | 495 | +    dprint(f"loading complete, took {loading_model_time:.3f}s")
499 | 496 |
500 | 497 | if args.compile:
501 | 498 |     dprint("compiling model")
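For context, this refactor replaces the paired `stagger_enter`/`stagger_leave` calls with a single `stagger_region` context manager, so the "leave" step runs even if model loading raises. The sketch below shows one way such a region can be built on `torch.distributed` barriers. It is an illustration only, not the actual helper in `aiu_fms_testing_utils.utils`: the explicit `rank`/`world_size` parameters and the placeholder yield value are assumptions for this sketch.

```python
# Illustrative sketch only -- the real stagger_region lives in
# aiu_fms_testing_utils.utils and may differ in detail.
import math
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def stagger_region(limit: int, rank: int, world_size: int):
    """Allow at most `limit` ranks inside the region at a time.

    Ranks are split into waves of size `limit`; wave k enters only after
    wave k-1 has left. limit <= 0 or limit >= world_size disables staggering.
    """
    staggered = dist.is_initialized() and 0 < limit < world_size
    waves = math.ceil(world_size / limit) if staggered else 1
    wave = rank // limit if staggered else 0
    # Entering: wait on one barrier per earlier wave. The k-th collective
    # barrier only completes once wave k-1 has finished the region and
    # started releasing its exit barriers.
    for _ in range(wave):
        dist.barrier()
    try:
        # The diff binds the yielded value to `_s`, so yield a placeholder.
        yield {}
    finally:
        # Leaving: release the remaining barriers so every rank performs
        # the same total number of collective calls (waves - 1), even if
        # the body raised.
        if staggered:
            for _ in range(waves - 1 - wave):
                dist.barrier()
```

With `--stagger_load 2` on eight ranks, for example, weights would load in four waves of two, capping how many processes read checkpoints (and hold peak host memory) at once; the `with` form makes that pairing of enter/leave impossible to miss, which the old free-function pair did not guarantee.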