Commit 1711ea7

Add cache tests back
1 parent cfef621 commit 1711ea7


tests/models/test_decoders.py

Lines changed: 137 additions & 0 deletions
@@ -24,6 +24,7 @@
 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os
 
 try:
@@ -795,6 +796,7 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
+    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -820,6 +822,12 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
     )
 
+    # Used only for cache tests; this is a closure taking no arguments that
+    # should assert that the cache for torch sendnn is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
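Note: the hunk above adds an optional verify_cache_state hook to _run_cpu_aiu_validation_test, a zero-argument callable invoked after level-0 validation so a cache test can assert that the torch sendnn cache is in the expected state. The standalone Python sketch below illustrates only the hook pattern; run_level_0 and its fake artifact are hypothetical stand-ins, not part of this test suite.

import os
import tempfile


def run_level_0(cache_dir, verify_cache_state=None):
    # Hypothetical stand-in for the real validation helper: pretend that
    # compilation writes one artifact into the cache directory.
    with open(os.path.join(cache_dir, "artifact_0"), "w") as f:
        f.write("compiled graph placeholder")
    # Mirror the new optional hook: call it only when supplied.
    if verify_cache_state is not None:
        verify_cache_state()


with tempfile.TemporaryDirectory() as cache_dir:
    def verify_populated():
        assert len(os.listdir(cache_dir)) == 1, "expected one cached artifact"

    run_level_0(cache_dir, verify_cache_state=verify_populated)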
@@ -841,6 +849,87 @@ def _run_cpu_aiu_validation_test(
     )
 
 
+def _reset_cache_settings(purge_cache_dir):
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+    # Ensure we start in a clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    from torch_sendnn.backends import cache
+
+    # Explicitly clear cache paths from the global torch sendnn graph;
+    # TODO: it would be better to add a helper in torch sendnn to do this
+    # explicitly
+    cache.cache = {}
+
+
+@pytest.fixture
+def use_cached_model():
+    """Configures the torch sendnn cache and runs the AIU model prior to test execution;
+    this is computationally expensive and should only be used in situations like testing
+    cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_miss,
+    )
+
+
+def _get_cache_test_params():
+    # NOTE: currently we always use granite 3.3 for the cache test;
+    # TODO: make this configurable as tests are refactored
+    model_path = GRANITE_3p3_8B_INSTRUCT
+    batch_size = COMMON_BATCH_SIZES[0]
+    seq_length = COMMON_SEQ_LENGTHS[0]
+    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+    return [model_path, batch_size, seq_length, max_new_tokens]
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
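Note: verify_cache_miss above and verify_cache_hit in the next hunk perform the same directory-count check against TORCH_SENDNN_CACHE_DIR and differ only in their failure message. Below is a hedged sketch of how that duplication could be factored into one helper; make_cache_state_check is a hypothetical name and is not part of this commit.

import os


def make_cache_state_check(expected_entries, failure_message):
    # Build a zero-argument closure suitable for verify_cache_state: it counts
    # the entries in TORCH_SENDNN_CACHE_DIR and compares the count against the
    # expected number of cached artifacts.
    def check():
        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
        cache_len = (
            len(os.listdir(cache_dir)) if cache_dir and os.path.isdir(cache_dir) else 0
        )
        assert cache_len == expected_entries, failure_message

    return check


# Usage mirroring the two closures in the diff (max_new_tokens assumed in scope):
# verify_cache_miss = make_cache_state_check(max_new_tokens, "cache directory not populated on cache miss")
# verify_cache_hit = make_cache_state_check(max_new_tokens, "cache miss occurred when hit was expected")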
@@ -879,3 +968,51 @@ def test_common_shapes(
         model,
         micro_model_path,
     )
+
+
+def test_cache(use_cached_model):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=False)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_hit,
+    )
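Note: taken together, the use_cached_model fixture performs the first run against a purged cache (asserting the cache-miss run populated the directory), and test_cache then repeats the same shape and asserts the populated cache is reused. The self-contained pytest sketch below mirrors that miss-then-hit structure with a stand-in compile step and a temporary cache directory; fake_compile and ARTIFACTS_PER_RUN are illustrative assumptions, not torch sendnn APIs.

import os
import shutil
import tempfile

import pytest

ARTIFACTS_PER_RUN = 4  # stand-in for max_new_tokens in the real test


def fake_compile(cache_dir):
    # Stand-in for AIU compilation: reuse cached artifacts when present,
    # otherwise write one file per generated token.
    if os.listdir(cache_dir):
        return "hit"
    for i in range(ARTIFACTS_PER_RUN):
        with open(os.path.join(cache_dir, f"graph_{i}"), "w") as f:
            f.write("cached graph placeholder")
    return "miss"


@pytest.fixture
def warm_cache():
    # First run populates the cache (the cache-miss phase handled by use_cached_model).
    cache_dir = tempfile.mkdtemp()
    assert fake_compile(cache_dir) == "miss"
    assert len(os.listdir(cache_dir)) == ARTIFACTS_PER_RUN
    yield cache_dir
    shutil.rmtree(cache_dir)


def test_cache_hit(warm_cache):
    # Second run must find the cache already populated (the cache-hit phase of test_cache).
    assert fake_compile(warm_cache) == "hit"
    assert len(os.listdir(warm_cache)) == ARTIFACTS_PER_RUN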
