@@ -1,15 +1,26 @@
-import torch
-import torch.nn as nn
-import time
-from fms.utils.tokenizers import BaseTokenizer
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+# Standard
 from typing import Optional, List, Tuple
-import os
-import requests
 import json
+import os
 import random
+import requests
+import time
+
+# Third Party
+from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from fms.utils.tokenizers import BaseTokenizer
+import torch
+import torch.nn as nn
 
-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn=False, use_cache: bool = True, **extra_kwargs):
+
+def warmup_model(
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    compile_dynamic_sendnn: bool = False,
+    use_cache: bool = True,
+    **extra_kwargs
+):
     import torch_sendnn
     attention_specific_kwargs = {}
     attn_name = extra_kwargs["attn_name"]
@@ -19,7 +30,7 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
         # TODO: Add a unified generation dependent on attn_type
         from fms.utils.generation import generate
         attention_specific_kwargs["contiguous_cache"] = True
-
+
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
 
@@ -31,12 +42,23 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     _max_new_tokens = 2
     # always warmup with batch size 2 when using attn_type=paged
     if "paged" in attn_name:
-        _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids, **extra_kwargs)
+        _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
+            input_ids,
+            **extra_kwargs,
+        )
 
     extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
 
     with torch_sendnn.warmup_mode():
-        generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, do_sample=False, use_cache=use_cache, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
+        generate(
+            model,
+            _warmup_input_ids,
+            max_new_tokens=_max_new_tokens,
+            do_sample=False,
+            use_cache=use_cache,
+            extra_kwargs=extra_kwargs,
+            **attention_specific_kwargs,
+        )
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
@@ -52,17 +74,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")
 
 def __sample_requests(
-    prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -82,16 +104,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-
 
+    return filtered_dataset
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -111,15 +131,22 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
-    return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
+
+    return __sample_requests(
+        dataset,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int,
-    tokenizer: BaseTokenizer,
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -128,10 +155,14 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
-    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
-    return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+    ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
+    return __sample_requests(
+        ds,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
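For reference, a minimal usage sketch of the public sampler reformatted in this diff. It assumes the helpers are importable from `aiu_fms_testing_utils.utils` (this file's apparent location, judging by the diff's own imports) and that `fms.utils.tokenizers.get_tokenizer` is available; the tokenizer and dataset paths are placeholders, not part of the change:

```python
# Illustrative sketch only; paths below are placeholders, not part of this diff.
from aiu_fms_testing_utils.utils import sample_squad_v2_qa_requests  # assumed export location
from fms.utils import tokenizers

tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")  # placeholder tokenizer path

# Draw 8 SQuAD v2 prompts whose tokenized length falls in [32, 64] tokens.
prompts = sample_squad_v2_qa_requests(
    dataset_path="/path/to/squad_v2_cache",  # placeholder cache dir for the dataset
    num_requests=8,
    tokenizer=tokenizer,
    prompt_length_min=32,
    prompt_length_max=64,
    seed=0,
)

# Each entry is a (prompt, prompt_len) tuple, per the List[Tuple[str, int]] return type.
for prompt, prompt_len in prompts:
    print(prompt_len, prompt[:40])
```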