Commit 942e9dd

Merge pull request #41 from foundation-model-stack/mean_diff_bug_fix
Fixed bug with validation threshold generated for mean diff
2 parents 4b051fe + b43dc97 commit 942e9dd

3 files changed, +152 -111 lines changed

aiu_fms_testing_utils/testing/validation.py

Lines changed: 3 additions & 0 deletions
@@ -88,6 +88,9 @@ def save(self, save_dir_path: str):
 
     def __len__(self):
         return len(self._validation_info_list)
+
+def get_default_validation_prefix(model_id: str, max_new_tokens: int, batch_size: int, seq_length: int, dtype: str):
+    return f"{model_id.replace('/', '--')}_max-new-tokens-{max_new_tokens}_batch-size-{batch_size}_seq-length-{seq_length}_dtype-{dtype}"
 
 
 def load_validation_information(validation_path, validation_files_type, batch_size, tokenizer=None):
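
Note: get_default_validation_prefix centralizes the result-file naming convention that scripts/generate_metrics.py previously rebuilt inline (once with a typo, see the removed seq-length f-string further down). A minimal usage sketch; the model id and shape values are invented for illustration:

    from aiu_fms_testing_utils.testing.validation import get_default_validation_prefix

    prefix = get_default_validation_prefix("ibm-granite/granite-7b-base", 128, 1, 64, "fp16")
    # "ibm-granite--granite-7b-base_max-new-tokens-128_batch-size-1_seq-length-64_dtype-fp16"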

scripts/generate_metrics.py

Lines changed: 125 additions & 93 deletions
@@ -6,8 +6,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-
-from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, print_failed_cases, \
+from torch import distributed as dist
+from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, get_default_validation_prefix, load_validation_information, print_failed_cases, \
     validate_level_0, GoldenTokenHook, top_k_loss_calculator
 from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests
 from fms.models import get_model
@@ -97,8 +97,27 @@
     default={},
     help="Use this to override model configuration values to get model. Example: --extra_get_model_kwargs nlayers=2,..."
 )
+parser.add_argument(
+    "--distributed",
+    action="store_true",
+    help="This is a distributed job (multiple instances run with RANK+WORLD_SIZE)",
+)
+parser.add_argument(
+    "--skip_computation",
+    action="store_true",
+    help="Set this if the output is already assumed to be computed and would like to regenerate metrics without model loading or computation"
+)
+local_rank = int(os.getenv("LOCAL_RANK", 0))
+world_size = int(os.getenv("WORLD_SIZE", 1))
 args = parser.parse_args()
 
+if args.distributed:
+    dist.init_process_group()
+    distr_param = "tp"
+    torch.cuda.set_device(local_rank)
+else:
+    distr_param = None
+
 extra_get_model_kwargs = {}
 for a in args.extra_get_model_kwargs:
     a_split = a.split("=")
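
Note: the new --distributed path only reads LOCAL_RANK and WORLD_SIZE from the environment and calls dist.init_process_group(), so it relies on an external launcher to start one process per device and set those variables. As a sketch (torchrun is an assumption, not something this commit prescribes, and the remaining script arguments are elided):

    torchrun --nproc_per_node=2 scripts/generate_metrics.py --distributed ...

Without --distributed, distr_param stays None and the script behaves as a single-process run.
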
@@ -108,7 +127,7 @@
         extra_get_model_kwargs[a_split[0]] = a_split[1]
 
 # this follows the same pattern of naming in test_shapes. This way we can save and re-use for quicker shape testing.
-prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length-{args.min_pad_length}_dtype-{args.default_dtype}"
+prefix = get_default_validation_prefix(args.variant, args.max_new_tokens, args.batch_size, args.min_pad_length, args.default_dtype)
 if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")):
     print("skipping metric generation as it has already been done")
     exit(0)
@@ -129,31 +148,6 @@
 
 torch.set_grad_enabled(False)
 
-# prepare the cuda model
-cuda_model = get_model(
-    architecture=args.architecture,
-    variant=args.variant,
-    model_path=args.model_path,
-    device_type="cuda",
-    data_type=default_dtype,
-    **extra_get_model_kwargs,
-)
-
-cuda_model.eval()
-print("loaded cuda model")
-
-# prepare the cpu model (this is the reference)
-cpu_model = get_model(
-    architecture=args.architecture,
-    variant=args.variant,
-    model_path=args.model_path,
-    device_type="cpu",
-    data_type=torch.float32,
-    **extra_get_model_kwargs,
-)
-cpu_model.eval()
-print("loaded cpu model")
-
 def find_eos_index(reference_tokens, eos_token_id):
     result = []
     for sentence in reference_tokens:
@@ -181,92 +175,129 @@ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
     input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
     return input_ids, padding_kwargs
 
-ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
-
-# first test validation level 0
-cpu_validation_info = extract_validation_information(
-    cpu_model,
-    ids,
-    args.max_new_tokens,
-    LogitsExtractorHook(),
-    attn_algorithm="math",
-    **padding_kwargs
-)
-cpu_static_tokens = cpu_validation_info.get_info("tokens")
-print("extracted cpu validation information")
-
-eos_indexes = find_eos_index(cpu_static_tokens, tokenizer.eos_token_id)
-print(f"valid testing tokens per sequence: {eos_indexes}")
-
-# generate cpu validation info
-cuda_validation_info = extract_validation_information(
-    cuda_model,
-    ids.to("cuda"),
-    args.max_new_tokens,
-    None,
-    only_last_token=True,
-    **{k: v.to("cuda") for k,v in padding_kwargs.items()}
-)
-cuda_static_tokens = cuda_validation_info.get_info("tokens")
-failed_responses = validate_level_0(cpu_static_tokens, cuda_static_tokens)
-
-print("extracted cuda validation information level 0")
-if len(failed_responses) != 0:
-    print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)
-
 def write_csv(l, path, metric):
     with open(path, 'w') as f:
         f.write(f'{metric}\n')
         for t in l:
             f.write(f"{t[2].item()}\n")
         f.close()
 
+# prepare the cuda model
+if not args.skip_computation:
+    cuda_model = get_model(
+        architecture=args.architecture,
+        variant=args.variant,
+        model_path=args.model_path,
+        device_type="cuda",
+        data_type=default_dtype,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        **extra_get_model_kwargs,
+    )
+
+    cuda_model.eval()
+    print("loaded cuda model")
+
+    # prepare the cpu model (this is the reference)
+    cpu_model = get_model(
+        architecture=args.architecture,
+        variant=args.variant,
+        model_path=args.model_path,
+        device_type="cpu",
+        data_type=torch.float32,
+        distributed_strategy=distr_param,
+        group=dist.group.WORLD,
+        **extra_get_model_kwargs,
+    )
+    cpu_model.eval()
+    print("loaded cpu model")
+
+    ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
+
+    # first test validation level 0
+    cpu_validation_info = extract_validation_information(
+        cpu_model,
+        ids,
+        args.max_new_tokens,
+        LogitsExtractorHook(),
+        attn_algorithm="math",
+        **padding_kwargs
+    )
+    cpu_static_tokens = cpu_validation_info.get_info("tokens")
+    print("extracted cpu validation information")
+
+    eos_indexes = find_eos_index(cpu_static_tokens, tokenizer.eos_token_id)
+    print(f"valid testing tokens per sequence: {eos_indexes}")
+
+    # generate cpu validation info
+    cuda_validation_info = extract_validation_information(
+        cuda_model,
+        ids.to("cuda"),
+        args.max_new_tokens,
+        None,
+        only_last_token=True,
+        **{k: v.to("cuda") for k,v in padding_kwargs.items()}
+    )
+    cuda_static_tokens = cuda_validation_info.get_info("tokens")
+    failed_responses = validate_level_0(cpu_static_tokens, cuda_static_tokens)
+
+    print("extracted cuda validation information level 0")
+    if local_rank == 0:
+        if len(failed_responses) != 0:
+            print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)
+
 num_test_tokens_per_sequence = args.num_test_tokens_per_sequence
 if num_test_tokens_per_sequence is None:
     num_test_tokens_per_sequence = args.max_new_tokens
 
 cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
 prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
 prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
-diff_mean = lambda r, t: torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
+diff_mean = lambda r, t: torch.mean(torch.abs(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32)))
 
 prob_mean_metrics = []
 prob_std_metrics = []
 prob_diff_metrics = []
 prob_ce_loss_metrics = []
 
-prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
-
 for i in range(num_test_tokens_per_sequence // args.max_new_tokens):
-    ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
-
-    # only need to compute this once if we aren't generating more test data
-    if num_test_tokens_per_sequence > args.max_new_tokens:
-        cpu_validation_info = extract_validation_information(
-            cpu_model,
-            ids,
+    cpu_path = os.path.join(args.output_dir, f"{prefix}.cpu_validation_info.{i}.out")
+    cuda_path = os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out")
+    if os.path.exists(cpu_path) and os.path.exists(cuda_path):
+        print(f"found the logits at {cpu_path}, reusing")
+        cpu_validation_info = load_validation_information(cpu_path, "logits", args.batch_size, tokenizer)
+        cuda_validation_info = load_validation_information(cuda_path, "logits", args.batch_size, tokenizer)
+    elif not args.skip_computation:
+        ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
+
+        # only need to compute this once if we aren't generating more test data
+        if num_test_tokens_per_sequence > args.max_new_tokens:
+            cpu_validation_info = extract_validation_information(
+                cpu_model,
+                ids,
+                args.max_new_tokens,
+                LogitsExtractorHook(),
+                attn_algorithm="math",
+                **padding_kwargs
+            )
+
+        # generate aiu validation info
+        cuda_validation_info = extract_validation_information(
+            cuda_model,
+            ids.to("cuda"),
             args.max_new_tokens,
-            LogitsExtractorHook(),
-            attn_algorithm="math",
-            **padding_kwargs
+            GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
+            only_last_token=True,
+            **{k: v.to("cuda") for k,v in padding_kwargs.items()}
        )
-    eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id)
-
-    # generate aiu validation info
-    cuda_validation_info = extract_validation_information(
-        cuda_model,
-        ids.to("cuda"),
-        args.max_new_tokens,
-        GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
-        only_last_token=True,
-        **{k: v.to("cuda") for k,v in padding_kwargs.items()}
-    )
-
-    print("extracted cuda validation information level 1")
 
-    cpu_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cpu_validation_info.{i}.out"))
-    cuda_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out"))
+        print("extracted cuda validation information level 1")
 
+        if local_rank == 0:
+            cpu_validation_info.save(cpu_path)
+            cuda_validation_info.save(cuda_path)
+
+    eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id)
     level_1_metrics = capture_level_1_metrics(
         cpu_validation_info.get_info("logits"),
         cuda_validation_info.get_info("logits"),
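
Note: the reworked loop first looks for previously saved validation info and only runs the models when it is missing (and never runs them under --skip_computation). A rough sketch of the file names it checks on iteration 0, with an illustrative output directory and prefix (the real script derives both from argparse and get_default_validation_prefix):

    import os

    output_dir = "/tmp/aiu_metrics"  # illustrative
    prefix = "ibm-granite--granite-7b-base_max-new-tokens-128_batch-size-1_seq-length-64_dtype-fp16"  # illustrative

    # when both files exist, the logits are reloaded with load_validation_information(...)
    # instead of being recomputed on the cpu and cuda models
    cpu_path = os.path.join(output_dir, f"{prefix}.cpu_validation_info.0.out")
    cuda_path = os.path.join(output_dir, f"{prefix}.cuda_validation_info.0.out")
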
@@ -295,7 +326,8 @@ def write_csv(l, path, metric):
     )
     prob_diff_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
 
-write_csv(prob_mean_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
-write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
-write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
-write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
+if local_rank == 0:
+    write_csv(prob_mean_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
+    write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
+    write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
+    write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
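
Note: the change behind the commit title appears to be the diff_mean lambda. Each row of a softmax sums to 1, so the signed differences between two probability rows cancel exactly and the old metric evaluates to roughly zero no matter how far apart the distributions are, which makes any mean-diff validation threshold generated from it uninformative. Wrapping the difference in torch.abs makes the metric track the magnitude of the disagreement. A minimal sketch with made-up logits:

    import torch

    # toy logits over a 2-token vocabulary at three positions, values invented for illustration
    r = torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # reference
    t = torch.tensor([[0.0, 2.0], [2.0, 0.0], [1.0, 1.0]])  # under test

    delta = r.softmax(dim=1) - t.softmax(dim=1)
    print(torch.mean(delta))             # ~0: signed differences cancel within each row
    print(torch.mean(torch.abs(delta)))  # ~0.51: actual magnitude of the disagreement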
