 import random
 import math
 
+def stagger_enter(limit: int):
+    if limit > 0 and limit != world_size:
+        for _set in range(math.ceil(world_size / float(limit))):
+            if rank < (_set + 1) * limit:
+                break
+            torch.distributed.barrier()
+        dprint(f"Stagger: Enter (Set: {_set + 1} of {math.ceil(world_size / float(limit))})")
+
+def stagger_leave(limit: int):
+    if limit > 0 and limit != world_size:
+        for _set in range(math.ceil(world_size / float(limit))):
+            if rank >= (_set + 1) * limit:
+                continue
+            torch.distributed.barrier()
+        dprint(f"Stagger: All Complete")
+
 def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn=False, stagger_update_lazyhandle=0, **padding_kwargs):
     import torch_sendnn
     dprint("AIU warmup")
@@ -19,25 +35,15 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     if compile_dynamic_sendnn:
         max_new_tokens_warmup = 2
 
-    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
-        for _set in range(math.ceil(world_size / float(stagger_update_lazyhandle))):
-            if rank < (_set + 1) * stagger_update_lazyhandle:
-                break
-            torch.distributed.barrier()
-        dprint(f"Stagger update_lazyhandle: Begin (Set: {_set + 1} of {math.ceil(world_size / float(stagger_update_lazyhandle))})")
+    stagger_enter(stagger_update_lazyhandle)
 
     pt_compile_model_time = time.time()
     with torch_sendnn.warmup_mode():
         generate(model, input_ids, max_new_tokens=max_new_tokens_warmup, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
-    if stagger_update_lazyhandle > 0 and stagger_update_lazyhandle != world_size:
-        for _set in range(math.ceil(world_size / float(stagger_update_lazyhandle))):
-            if rank >= (_set + 1) * stagger_update_lazyhandle:
-                continue
-            torch.distributed.barrier()
-        dprint(f"Stagger update_lazyhandle: All Complete")
+    stagger_leave(stagger_update_lazyhandle)
 
 def ids_for_prompt(prompt, tokenizer):
     tokens = tokenizer.tokenize(prompt)
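The diff factors the barrier-based staggering that used to live inline in warmup_model into reusable stagger_enter / stagger_leave helpers: ranks are admitted into the expensive region in sets of `limit`, with one torch.distributed.barrier() per set on entry, and each finished rank keeps joining the remaining barriers on exit so later sets are released in order. Below is a minimal standalone sketch of the same pattern, assuming a single-node "gloo" process group launched with torchrun; heavy_init, STAGGER_LIMIT, the demo file name, and the plain print calls are hypothetical stand-ins and not part of this PR.

# stagger_demo.py -- minimal sketch of the stagger pattern (assumptions: "gloo"
# backend, launched with torchrun so rank/world-size env vars are provided).
import math
import os
import time

import torch.distributed as dist


def stagger_enter(limit: int, rank: int, world_size: int) -> None:
    # Each barrier releases the next set of `limit` ranks into the protected
    # region; a rank breaks out as soon as its own set is reached.
    if limit > 0 and limit != world_size:
        for _set in range(math.ceil(world_size / float(limit))):
            if rank < (_set + 1) * limit:
                break
            dist.barrier()
        print(f"rank {rank}: entered with set {_set + 1}")


def stagger_leave(limit: int, rank: int, world_size: int) -> None:
    # Mirror of stagger_enter: a finished rank joins one barrier per remaining
    # set, so every barrier is a full collective and the sets still waiting in
    # stagger_enter are released in order.
    if limit > 0 and limit != world_size:
        for _set in range(math.ceil(world_size / float(limit))):
            if rank >= (_set + 1) * limit:
                continue
            dist.barrier()
        print(f"rank {rank}: all sets complete")


def heavy_init(rank: int) -> None:
    # Placeholder for the memory- or host-heavy step being staggered
    # (in the PR, the warmup/compile section of warmup_model).
    time.sleep(1.0)
    print(f"rank {rank}: heavy_init done")


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # torchrun supplies the rendezvous env vars
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    limit = int(os.environ.get("STAGGER_LIMIT", "2"))  # ranks admitted per set
    stagger_enter(limit, rank, world_size)
    heavy_init(rank)
    stagger_leave(limit, rank, world_size)

    dist.destroy_process_group()

Run as, e.g., `STAGGER_LIMIT=2 torchrun --nproc_per_node=4 stagger_demo.py`: ranks 0-1 run heavy_init immediately while ranks 2-3 wait at the first barrier, and every rank participates in exactly ceil(world_size / limit) barriers, so no barrier is left short of participants and nothing deadlocks. In the PR itself the same helpers bracket the compile section via stagger_enter(stagger_update_lazyhandle) and stagger_leave(stagger_update_lazyhandle).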