
Moves all prepare inputs methods to utils #77
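This PR replaces the three duplicated private __prepare_inputs helpers in scripts/generate_metrics.py, tests/models/test_decoders.py, and tests/models/test_encoders.py with a single shared prepare_inputs utility in aiu_fms_testing_utils/utils/__init__.py, which samples prompts from ShareGPT or SQuAD v2 and pads them to a fixed sequence length.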


Open · wants to merge 2 commits into base: main
41 changes: 41 additions & 0 deletions aiu_fms_testing_utils/utils/__init__.py
@@ -9,6 +9,7 @@
# Third Party
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from fms.utils.tokenizers import BaseTokenizer
from fms.utils.generation import pad_input_ids
import torch
import torch.nn as nn

@@ -166,3 +167,43 @@ def sample_squad_v2_qa_requests(
prompt_length_max,
seed,
)

def prepare_inputs(batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"):
"""
    Prepare padded input IDs and padding kwargs for a batch of sampled prompts.

    Args:
        batch_size (int): The number of prompts to sample for the batch.
        seq_length (int): The maximum length of the input sequence.
        tokenizer (BaseTokenizer): The tokenizer used to encode the sampled prompts.
        ds_path (str): The path to the dataset file.
        seed (int, optional): The random seed for reproducibility. Defaults to 0.
        ds_type (str, optional): The dataset type. Any value containing "sharegpt" selects the ShareGPT sampler; any other value falls back to the SQuAD v2 sampler. Defaults to "sharegpt".

Returns:
tuple: A tuple containing the input IDs and padding kwargs.
"""
if not "sharegpt" in ds_type:
prompts_and_sizes = sample_squad_v2_qa_requests(
ds_path,
batch_size,
tokenizer,
int(seq_length / 2),
seq_length,
seed,
)
else:
prompts_and_sizes = sample_sharegpt_requests(
ds_path,
batch_size,
tokenizer,
int(seq_length / 2),
seq_length,
seed,
)
prompt_list = []
for prompt, _ in prompts_and_sizes:
prompt_list.append(ids_for_prompt(prompt, tokenizer))

input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
return input_ids, padding_kwargs
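
For reference, a minimal usage sketch of the new helper (a sketch only, not part of the diff: the paths are placeholders, and the tokenizer lookup assumes the fms tokenizers helper already used by these scripts):

# Usage sketch only; dataset and tokenizer paths below are placeholders.
from fms.utils import tokenizers
from aiu_fms_testing_utils.utils import prepare_inputs

tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")

# Decoder-style batch sampled from ShareGPT (the default ds_type).
input_ids, padding_kwargs = prepare_inputs(
    batch_size=4,
    seq_length=64,
    tokenizer=tokenizer,
    ds_path="/path/to/share_gpt.json",
)

# Encoder-style batch sampled from SQuAD v2: any ds_type that does not
# contain "sharegpt" routes to sample_squad_v2_qa_requests.
input_ids, padding_kwargs = prepare_inputs(
    batch_size=4,
    seq_length=64,
    tokenizer=tokenizer,
    ds_path="/path/to/squad_v2",
    seed=1,
    ds_type="squad_v2",
)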
22 changes: 10 additions & 12 deletions scripts/generate_metrics.py
@@ -9,10 +9,9 @@
from torch import distributed as dist
from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, get_default_validation_prefix, load_validation_information, print_failed_cases, \
validate_level_0, GoldenTokenHook, top_k_loss_calculator
from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests
from aiu_fms_testing_utils.utils import prepare_inputs
from fms.models import get_model
from fms.utils import tokenizers
from fms.utils.generation import pad_input_ids

parser = argparse.ArgumentParser(
description="Script to determine a reasonable logits loss threshold when testing with aiu"
@@ -166,14 +165,6 @@ def filter_before_eos(l, filter_indexes):
filtered_results = [list(g)[:filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])]
return [item for sublist in filtered_results for item in sublist]

def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
prompts_and_sizes = sample_sharegpt_requests(args.sharegpt_path, batch_size, tokenizer, seq_length // 2, seq_length, seed)
prompt_list = []
for prompt, _ in prompts_and_sizes:
prompt_list.append(ids_for_prompt(prompt, tokenizer))

input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
return input_ids, padding_kwargs

def write_csv(l, path, metric):
with open(path, 'w') as f:
@@ -212,7 +203,10 @@ def write_csv(l, path, metric):
cpu_model.eval()
print("loaded cpu model")

ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
ids, padding_kwargs = prepare_inputs(batch_size=args.batch_size,
seq_length=args.min_pad_length,
tokenizer=tokenizer,
ds_path=args.sharegpt_path)

# first test validation level 0
cpu_validation_info = extract_validation_information(
@@ -268,7 +262,11 @@ def write_csv(l, path, metric):
cpu_validation_info = load_validation_information(cpu_path, "logits", args.batch_size, tokenizer)
cuda_validation_info = load_validation_information(cuda_path, "logits", args.batch_size, tokenizer)
elif not args.skip_computation:
ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
ids, padding_kwargs = prepare_inputs(batch_size=args.batch_size,
seq_length=args.min_pad_length,
tokenizer=tokenizer,
ds_path=args.sharegpt_path,
seed=i)

# only need to compute this once if we aren't generating more test data
if num_test_tokens_per_sequence > args.max_new_tokens:
31 changes: 11 additions & 20 deletions tests/models/test_decoders.py
@@ -21,6 +21,7 @@
warmup_model,
sample_sharegpt_requests,
ids_for_prompt,
    prepare_inputs,
)
import json
from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
@@ -261,23 +262,6 @@ def __maybe_get_gptq_kwargs(model_path):
return gptq_kwargs_aiu, gptq_kwargs_cpu


def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
prompts_and_sizes = sample_sharegpt_requests(
SHARE_GPT_DATASET_PATH,
batch_size,
tokenizer,
int(seq_length / 2),
seq_length,
seed,
)
prompt_list = []
for prompt, _ in prompts_and_sizes:
prompt_list.append(ids_for_prompt(prompt, tokenizer))

input_ids, extra_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
return input_ids, extra_kwargs


def __find_eos_index(reference_tokens, eos_token_id, seq_length, max_new_tokens):
result = []
for sentence in reference_tokens:
@@ -441,7 +425,10 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens, persi
)

# prepare input_ids
input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
input_ids, extra_kwargs = prepare_inputs(batch_size=batch_size,
seq_length=seq_length,
tokenizer=tokenizer,
ds_path=SHARE_GPT_DATASET_PATH)
extra_kwargs["attn_name"] = ATTN_NAME

# warmup aiu model
@@ -516,8 +503,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
for i in range(iters):
# for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
if i != 0:
input_ids, extra_kwargs = __prepare_inputs(
batch_size, seq_length, tokenizer, seed=i
input_ids, extra_kwargs = prepare_inputs(
batch_size=batch_size,
seq_length=seq_length,
tokenizer=tokenizer,
ds_path=SHARE_GPT_DATASET_PATH,
seed=i
)
extra_kwargs["attn_name"] = ATTN_NAME
cpu_validation_info = __load_validation_info(
25 changes: 12 additions & 13 deletions tests/models/test_encoders.py
@@ -2,10 +2,9 @@
from fms.utils import tokenizers
import pytest
from fms.models import get_model
from fms.utils.generation import pad_input_ids
import itertools
import torch
from aiu_fms_testing_utils.utils import ids_for_prompt, sample_squad_v2_qa_requests
from aiu_fms_testing_utils.utils import prepare_inputs
from aiu_fms_testing_utils.utils.aiu_setup import dprint
import os
import numpy as np
@@ -39,15 +38,6 @@
common_shapes = list(itertools.product(common_model_paths, common_batch_sizes, common_seq_lengths))


def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
prompts_and_sizes = sample_squad_v2_qa_requests(SQUAD_V2_DATASET_PATH, batch_size, tokenizer, int(seq_length / 2), seq_length, seed)
prompt_list = []
for prompt, _ in prompts_and_sizes:
prompt_list.append(ids_for_prompt(prompt, tokenizer))

input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length, is_causal_mask=False)
return input_ids, padding_kwargs

def __generate_diffs(model_params_1, model_params_2):
model_params_1.model.eval()
model_params_2.model.eval()
@@ -115,7 +105,11 @@ def test_common_shapes(model_path, batch_size, seq_length):
)

# prepare input_ids
input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
input_ids, padding_kwargs = prepare_inputs(batch_size=batch_size,
seq_length=seq_length,
tokenizer=tokenizer,
ds_path=SQUAD_V2_DATASET_PATH,
ds_type="squad_v2")

# warmup model
logits_getter_fn = lambda x: x if isinstance(x, torch.Tensor) else torch.cat(list(x), dim=-1)
@@ -126,7 +120,12 @@ def test_common_shapes(model_path, batch_size, seq_length):
diffs = []
for i in range(20):
# prepare input_ids
input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer, seed=i)
input_ids, padding_kwargs = prepare_inputs(batch_size=batch_size,
seq_length=seq_length,
tokenizer=tokenizer,
ds_path=SQUAD_V2_DATASET_PATH,
seed=i,
ds_type="squad_v2")

aiu_msp = ModelSignatureParams(
model,
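
One behavioral difference worth flagging: the encoder helper removed above padded with pad_input_ids(prompt_list, min_pad_length=seq_length, is_causal_mask=False), while the shared prepare_inputs always uses pad_input_ids' default (causal) masking. If the encoder tests depend on the non-causal padding kwargs, callers would need a way to restore that flag; a minimal sketch of such a wrapper, mirroring the removed helper (hypothetical, not part of this PR):

# Hypothetical wrapper, not in this PR: rebuilds the removed encoder helper
# so SQuAD v2 batches are padded with is_causal_mask=False again.
from fms.utils.generation import pad_input_ids
from aiu_fms_testing_utils.utils import ids_for_prompt, sample_squad_v2_qa_requests

def prepare_encoder_inputs(batch_size, seq_length, tokenizer, ds_path, seed=0):
    prompts_and_sizes = sample_squad_v2_qa_requests(
        ds_path, batch_size, tokenizer, seq_length // 2, seq_length, seed
    )
    prompt_list = [ids_for_prompt(prompt, tokenizer) for prompt, _ in prompts_and_sizes]
    # The shared prepare_inputs omits is_causal_mask=False; encoders set it here.
    return pad_input_ids(prompt_list, min_pad_length=seq_length, is_causal_mask=False)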