@@ -6,8 +6,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-
-from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, print_failed_cases, \
+from torch import distributed as dist
+from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, get_default_validation_prefix, load_validation_information, print_failed_cases, \
     validate_level_0, GoldenTokenHook, top_k_loss_calculator
 from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests
 from fms.models import get_model
@@ -97,8 +97,27 @@
     default={},
     help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,..."
 )
+parser.add_argument(
+    "--distributed",
+    action="store_true",
+    help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)",
+)
+parser.add_argument(
+    "--skip_computation",
+    action="store_true",
+    help="Set this if the output is already assumed to be computed and would like to regenerate metrics without model loading or computation"
+)
+local_rank = int(os.getenv("LOCAL_RANK", 0))
+world_size = int(os.getenv("WORLD_SIZE", 1))
 args = parser.parse_args()
 
+if args.distributed:
+    dist.init_process_group()
+    distr_param = "tp"
+    torch.cuda.set_device(local_rank)
+else:
+    distr_param = None
+
 extra_get_model_kwargs = {}
 for a in args.extra_get_model_kwargs:
     a_split = a.split("=")
@@ -108,7 +127,7 @@
         extra_get_model_kwargs[a_split[0]] = a_split[1]
 
 # this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing.
-prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length-{args.min_pad_length}_dtype-{args.default_dtype}"
+prefix = get_default_validation_prefix(args.variant, args.max_new_tokens, args.batch_size, args.min_pad_length, args.default_dtype)
 if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")):
     print("skipping metric generation as it has already been done")
     exit(0)
@@ -129,31 +148,6 @@
 
 torch.set_grad_enabled(False)
 
-# prepare the cuda model
-cuda_model = get_model(
-    architecture=args.architecture,
-    variant=args.variant,
-    model_path=args.model_path,
-    device_type="cuda",
-    data_type=default_dtype,
-    **extra_get_model_kwargs,
-)
-
-cuda_model.eval()
-print("loaded cuda model")
-
-# prepare the cpu model (this is the reference)
-cpu_model = get_model(
-    architecture=args.architecture,
-    variant=args.variant,
-    model_path=args.model_path,
-    device_type="cpu",
-    data_type=torch.float32,
-    **extra_get_model_kwargs,
-)
-cpu_model.eval()
-print("loaded cpu model")
-
 def find_eos_index(reference_tokens, eos_token_id):
     result = []
     for sentence in reference_tokens:
@@ -181,92 +175,129 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     return input_ids, padding_kwargs
 
-ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
-
-# first test validation level 0
-cpu_validation_info = extract_validation_information(
-    cpu_model,
-    ids,
-    args.max_new_tokens,
-    LogitsExtractorHook(),
-    attn_algorithm="math",
-    **padding_kwargs
-)
-cpu_static_tokens = cpu_validation_info.get_info("tokens")
-print("extracted cpu validation information")
-
-eos_indexes = find_eos_index(cpu_static_tokens, tokenizer.eos_token_id)
-print(f"valid testing tokens per sequence: {eos_indexes}")
-
-# generate cpu validation info
-cuda_validation_info = extract_validation_information(
-    cuda_model,
-    ids.to("cuda"),
-    args.max_new_tokens,
-    None,
-    only_last_token=True,
-    **{k: v.to("cuda") for k,v in padding_kwargs.items()}
-)
-cuda_static_tokens = cuda_validation_info.get_info("tokens")
-failed_responses = validate_level_0(cpu_static_tokens, cuda_static_tokens)
-
-print("extracted cuda validation information level 0")
-if len(failed_responses) != 0:
-    print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)
-
 def write_csv(l, path, metric):
     with open(path, 'w') as f:
         f.write(f'{metric}\n')
         for t in l:
             f.write(f"{t[2].item()}\n")
         f.close()
 
+# prepare the cuda model
+if not args.skip_computation:
+    cuda_model = get_model(
+        architecture=args.architecture,
+        variant=args.variant,
+        model_path=args.model_path,
+        device_type="cuda",
+        data_type=default_dtype,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        **extra_get_model_kwargs,
+    )
+
+    cuda_model.eval()
+    print("loaded cuda model")
+
+    # prepare the cpu model (this is the reference)
+    cpu_model = get_model(
+        architecture=args.architecture,
+        variant=args.variant,
+        model_path=args.model_path,
+        device_type="cpu",
+        data_type=torch.float32,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        **extra_get_model_kwargs,
+    )
+    cpu_model.eval()
+    print("loaded cpu model")
+
+    ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
+
+    # first test validation level 0
+    cpu_validation_info = extract_validation_information(
+        cpu_model,
+        ids,
+        args.max_new_tokens,
+        LogitsExtractorHook(),
+        attn_algorithm="math",
+        **padding_kwargs
+    )
+    cpu_static_tokens = cpu_validation_info.get_info("tokens")
+    print("extracted cpu validation information")
+
+    eos_indexes = find_eos_index(cpu_static_tokens, tokenizer.eos_token_id)
+    print(f"valid testing tokens per sequence: {eos_indexes}")
+
+    # generate cpu validation info
+    cuda_validation_info = extract_validation_information(
+        cuda_model,
+        ids.to("cuda"),
+        args.max_new_tokens,
+        None,
+        only_last_token=True,
+        **{k: v.to("cuda") for k,v in padding_kwargs.items()}
+    )
+    cuda_static_tokens = cuda_validation_info.get_info("tokens")
+    failed_responses = validate_level_0(cpu_static_tokens, cuda_static_tokens)
+
+    print("extracted cuda validation information level 0")
+    if local_rank == 0:
+        if len(failed_responses) != 0:
+            print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)
+
 num_test_tokens_per_sequence = args.num_test_tokens_per_sequence
 if num_test_tokens_per_sequence is None:
     num_test_tokens_per_sequence = args.max_new_tokens
 
 cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
 prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
 prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
-diff_mean = lambda r, t: torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
+diff_mean = lambda r, t: torch.mean(torch.abs(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32)))
 
 prob_mean_metrics = []
 prob_std_metrics = []
 prob_diff_metrics = []
 prob_ce_loss_metrics = []
 
-prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
-
 for i in range(num_test_tokens_per_sequence // args.max_new_tokens):
-    ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
-
-    # only need to compute this once if we aren't generating more test data
-    if num_test_tokens_per_sequence > args.max_new_tokens:
-        cpu_validation_info = extract_validation_information(
-            cpu_model,
-            ids,
+    cpu_path = os.path.join(args.output_dir, f"{prefix}.cpu_validation_info.{i}.out")
+    cuda_path = os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out")
+    if os.path.exists(cpu_path) and os.path.exists(cuda_path):
+        print(f"found the logits at {cpu_path}, reusing")
+        cpu_validation_info = load_validation_information(cpu_path, "logits", args.batch_size, tokenizer)
+        cuda_validation_info = load_validation_information(cuda_path, "logits", args.batch_size, tokenizer)
+    elif not args.skip_computation:
+        ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
+
+        # only need to compute this once if we aren't generating more test data
+        if num_test_tokens_per_sequence > args.max_new_tokens:
+            cpu_validation_info = extract_validation_information(
+                cpu_model,
+                ids,
+                args.max_new_tokens,
+                LogitsExtractorHook(),
+                attn_algorithm="math",
+                **padding_kwargs
+            )
+
+        # generate aiu validation info
+        cuda_validation_info = extract_validation_information(
+            cuda_model,
+            ids.to("cuda"),
             args.max_new_tokens,
-            LogitsExtractorHook(),
-            attn_algorithm="math",
-            **padding_kwargs
+            GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
+            only_last_token=True,
+            **{k: v.to("cuda") for k,v in padding_kwargs.items()}
        )
-    eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id)
-
-    # generate aiu validation info
-    cuda_validation_info = extract_validation_information(
-        cuda_model,
-        ids.to("cuda"),
-        args.max_new_tokens,
-        GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
-        only_last_token=True,
-        **{k: v.to("cuda") for k,v in padding_kwargs.items()}
-    )
-
-    print("extracted cuda validation information level 1")
 
-    cpu_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cpu_validation_info.{i}.out"))
-    cuda_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out"))
+        print("extracted cuda validation information level 1")
 
+        if local_rank == 0:
+            cpu_validation_info.save(cpu_path)
+            cuda_validation_info.save(cuda_path)
+
+    eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id)
     level_1_metrics = capture_level_1_metrics(
         cpu_validation_info.get_info("logits"),
         cuda_validation_info.get_info("logits"),
@@ -295,7 +326,8 @@ def write_csv(l, path, metric):
     )
     prob_diff_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
 
-write_csv(prob_mean_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
-write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
-write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
-write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
+if local_rank == 0:
+    write_csv(prob_mean_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
+    write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
+    write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
+    write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
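
Note on the new --distributed path: dist.init_process_group() and the LOCAL_RANK/WORLD_SIZE lookups added above assume the script is started by a launcher that sets RANK, LOCAL_RANK and WORLD_SIZE for each process (for example, torchrun --nproc_per_node=<gpus> <script> --distributed ...); the exact launch command is not part of this change. Single-process runs can omit the flag and keep the previous non-distributed behavior (distr_param = None).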
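
Note on the diff_mean change: the metric now takes a mean absolute difference of the two softmax distributions. Because each softmax row sums to 1, the signed differences cancel to roughly zero no matter how far the two runs drift apart, so only the absolute form actually measures divergence. A minimal standalone sketch with made-up logits (not part of the patch) illustrating the effect:

import torch

torch.manual_seed(0)
r = torch.randn(4, 10)              # reference logits (e.g. the CPU float32 run)
t = r + 0.05 * torch.randn(4, 10)   # slightly perturbed logits (e.g. the CUDA run)

p_r = r.softmax(dim=1).to(dtype=torch.float32)
p_t = t.softmax(dim=1).to(dtype=torch.float32)

# rows of a softmax sum to 1, so signed differences cancel out
print(torch.mean(p_r - p_t))             # ~0 even when the distributions differ
print(torch.mean(torch.abs(p_r - p_t)))  # grows with the actual divergence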