Commit d77e570

Merge pull request #49 from foundation-model-stack/handle_trained_micro_models

Handle trained micro models in test_decoders validation testing

2 parents 7d214a0 + 0b9084f

File tree

1 file changed (+51 −28 lines)

tests/models/test_decoders.py

Lines changed: 51 additions & 28 deletions
@@ -34,32 +34,51 @@
 except ImportError:
     GPTQ_ENABLED = False
 
+MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models")
+
 # Add models to test here
 LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
 GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"
 GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
 GRANITE_20B_CODE_INSTRUCT_8K = "ibm-granite/granite-20b-code-instruct-8k"
 LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
 
+micro_model_mapping = {
+    LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000"),
+    GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"),
+    # FIXME: Because this uses the same config as 3.2, re-using here, but should update
+    GRANITE_3p3_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"),
+    LLAMA_3p1_70B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000")
+}
+
 SHARE_GPT_DATASET_PATH = os.environ.get(
     "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
 )
 USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
 USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
-FORCE_VALIDATION_LEVEL_1 = os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
+
+FORCE_VALIDATION_LEVEL_1 = (
+    os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
+)
 skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
 validation_info_dir = os.environ.get(
     "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
 )
 common_model_paths = os.environ.get(
     "FMS_TEST_SHAPES_COMMON_MODEL_PATHS",
-    [LLAMA_3p1_8B_INSTRUCT, GRANITE_3p2_8B_INSTRUCT, GRANITE_3p3_8B_INSTRUCT, GRANITE_20B_CODE_INSTRUCT_8K, LLAMA_3p1_70B_INSTRUCT],
+    [
+        LLAMA_3p1_8B_INSTRUCT,
+        GRANITE_3p2_8B_INSTRUCT,
+        GRANITE_3p3_8B_INSTRUCT,
+        GRANITE_20B_CODE_INSTRUCT_8K,
+        LLAMA_3p1_70B_INSTRUCT,
+    ],
 )
 # for validation level 1, the default is a failure rate of 1%
 # set this environment variable if you would like to relax that threshold
 failure_rate_threshold = os.environ.get("FMS_TEST_SHAPES_FAILURE_THRESHOLD", 0.01)
 default_metrics_threshold = os.environ.get(
-    "FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, .001)
+    "FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, 0.001)
 )
 save_validation_info_outputs = (
     os.environ.get("FMS_TEST_SHAPES_SAVE_VALIDATION_INFO_OUTPUTS", "0") == "1"
@@ -85,7 +104,9 @@
85104

86105
# pass custom default metrics threshold as a comma separated str of floats <cross-entropy threshold>,<mean diff threshold>
87106
if isinstance(default_metrics_threshold, str):
88-
default_metrics_threshold = tuple([float(m) for m in default_metrics_threshold.split(",")])
107+
default_metrics_threshold = tuple(
108+
[float(m) for m in default_metrics_threshold.split(",")]
109+
)
89110

90111
# pass custom common batch sizes as a comma separated str of ints
91112
if isinstance(common_batch_sizes, str):
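As a worked example of the override handled above, setting FMS_TEST_SHAPES_METRICS_THRESHOLD to a comma-separated string yields a float tuple. The value below is made up for illustration.

# As if FMS_TEST_SHAPES_METRICS_THRESHOLD="2.5,0.0005" were set in the environment.
default_metrics_threshold = "2.5,0.0005"
if isinstance(default_metrics_threshold, str):
    default_metrics_threshold = tuple(
        [float(m) for m in default_metrics_threshold.split(",")]
    )
print(default_metrics_threshold)  # (2.5, 0.0005)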
@@ -123,22 +144,6 @@
 # if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
 # threshold key is (model_id, is_tiny_model)
 fail_thresholds = {
-    (LLAMA_3p1_8B_INSTRUCT, True): (
-        3.7392955756187423,
-        .001,  # FIXME: compute
-    ),
-    (GRANITE_3p2_8B_INSTRUCT, True): (
-        2.996668996810913,
-        .001,  # FIXME: compute
-    ),
-    (GRANITE_20B_CODE_INSTRUCT_8K, True): (
-        3.7392955756187423,  # FIXME: compute -- setting to micro llama 3.1 8b instruct
-        .001,  # FIXME: compute
-    ),
-    (LLAMA_3p1_70B_INSTRUCT, True): (
-        3.8235735702514626,
-        .001,  # FIXME: compute
-    ),
     (LLAMA_3p1_8B_INSTRUCT, False): (
         2.6994638133048965,
         0.00047589250549208347,
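For context, fail_thresholds is keyed by (model_id, is_tiny_model) tuples, so removing the (model_id, True) entries simply makes tiny-model lookups miss. A small sketch using the real full-model Llama 3.1 8B thresholds from this file and the default defined earlier:

default_metrics_threshold = (3.0, 0.001)
fail_thresholds = {
    ("meta-llama/Llama-3.1-8B-Instruct", False): (
        2.6994638133048965,
        0.00047589250549208347,
    ),
}

# Full-model run: the curated thresholds are found.
print(fail_thresholds.get(("meta-llama/Llama-3.1-8B-Instruct", False), default_metrics_threshold))
# Tiny-model run: no (model_id, True) entry remains; the refined fallback
# added later in this diff decides what to use instead.
print(fail_thresholds.get(("meta-llama/Llama-3.1-8B-Instruct", True), default_metrics_threshold))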
@@ -322,13 +327,18 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
     gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
     is_gptq = len(gptq_kwargs_aiu) != 0
 
-    if USE_MICRO_MODELS:
+    micro_model_path = micro_model_mapping.get(model_path, None)
+    if USE_MICRO_MODELS and micro_model_path is None:
+        dprint("using randomly initialized model")
         micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
     else:
+        dprint("using trained model")
         micro_model_kwargs = {"architecture": "hf_pretrained"}
 
     if not USE_MICRO_MODELS and os.path.exists(model_path):
         model_path_kwargs = {"model_path": model_path}
+    elif USE_MICRO_MODELS and micro_model_path is not None:
+        model_path_kwargs = {"model_path": micro_model_path}
     else:
         model_path_kwargs = {"variant": model_path}
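The selection above now distinguishes a randomly initialized micro model, a trained micro model, a local full checkpoint, and a named variant. A hypothetical restatement as a pure function (select_model_kwargs is not a helper in the file, just a sketch of the behavior):

import os

def select_model_kwargs(model_path, use_micro_models, micro_model_path):
    # Architecture: random 3-layer init only when no trained micro checkpoint exists.
    if use_micro_models and micro_model_path is None:
        kwargs = {"architecture": "hf_configured", "nlayers": 3}
    else:
        kwargs = {"architecture": "hf_pretrained"}

    # Path: full local checkpoint, trained micro checkpoint, or named variant.
    if not use_micro_models and os.path.exists(model_path):
        kwargs["model_path"] = model_path
    elif use_micro_models and micro_model_path is not None:
        kwargs["model_path"] = micro_model_path
    else:
        kwargs["variant"] = model_path
    return kwargs

print(select_model_kwargs("meta-llama/Llama-3.1-8B-Instruct", True, None))
# -> {'architecture': 'hf_configured', 'nlayers': 3, 'variant': 'meta-llama/Llama-3.1-8B-Instruct'}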

@@ -435,10 +445,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
         cross_entropy = torch.nn.CrossEntropyLoss()(
             r, t.softmax(dim=1).to(dtype=torch.float32)
         )
-        diff = torch.mean(torch.abs(
-            r.softmax(dim=1).to(dtype=torch.float32)
-            - t.softmax(dim=1).to(dtype=torch.float32)
-        ))
+        diff = torch.mean(
+            torch.abs(
+                r.softmax(dim=1).to(dtype=torch.float32)
+                - t.softmax(dim=1).to(dtype=torch.float32)
+            )
+        )
         return (cross_entropy, diff)
 
     iters = 1024 // max_new_tokens
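The reformatted metric can be exercised in isolation. A self-contained sketch, with random tensors standing in for the device-under-test and reference logits (shapes are illustrative):

import torch

def metric_calculator(r: torch.Tensor, t: torch.Tensor):
    # Cross-entropy of the candidate logits against the reference soft distribution.
    cross_entropy = torch.nn.CrossEntropyLoss()(
        r, t.softmax(dim=1).to(dtype=torch.float32)
    )
    # Mean absolute difference between the two probability distributions.
    diff = torch.mean(
        torch.abs(
            r.softmax(dim=1).to(dtype=torch.float32)
            - t.softmax(dim=1).to(dtype=torch.float32)
        )
    )
    return (cross_entropy, diff)

r = torch.randn(4, 32000)  # e.g. (batch, vocab) logits from the device under test
t = torch.randn(4, 32000)  # reference logits
ce, mean_diff = metric_calculator(r, t)
print(float(ce), float(mean_diff))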
@@ -506,9 +518,20 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     # only consider those metrics captured prior to the eos
     level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)
 
-    ce_threshold, diff_threshold = fail_thresholds.get(
-        (model_path, USE_MICRO_MODELS), default_metrics_threshold
-    )
+    # if we do not have real model weights, use a default_metrics_threshold
+    if USE_MICRO_MODELS and micro_model_path is None:
+        ce_threshold, diff_threshold = default_metrics_threshold
+    # if we have real weights, try and get the proper validation metrics threshold
+    else:
+        # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
+        if USE_MICRO_MODELS:
+            ce_threshold, diff_threshold = fail_thresholds.get(
+                (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold)
+            )
+        else:
+            ce_threshold, diff_threshold = fail_thresholds.get(
+                (model_path, False), default_metrics_threshold
+            )
 
     # get all failed responses for each metric
     ce_fail_responses = filter_failed_level_1_cases(
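The new branching condenses to a three-step fallback chain: defaults for random weights, micro then full then default thresholds for a trained micro model, and full then default otherwise. A hypothetical helper restating it (resolve_thresholds is not in the file):

def resolve_thresholds(model_path, use_micro_models, micro_model_path,
                       fail_thresholds, default_metrics_threshold):
    if use_micro_models and micro_model_path is None:
        # Random weights: per-model thresholds are meaningless, use the default.
        return default_metrics_threshold
    if use_micro_models:
        # Trained micro model: prefer micro thresholds, then full-model, then default.
        return fail_thresholds.get(
            (model_path, True),
            fail_thresholds.get((model_path, False), default_metrics_threshold),
        )
    # Full model: per-model thresholds, then default.
    return fail_thresholds.get((model_path, False), default_metrics_threshold)

print(resolve_thresholds("some-model", True, None, {}, (3.0, 0.001)))  # -> (3.0, 0.001)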
