 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os

 try:
@@ -795,6 +796,7 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
+    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -820,6 +822,12 @@ def _run_cpu_aiu_validation_test(
820
822
aiu_model ,
821
823
)
822
824
825
+ # Used only for cache tests; this is a nonparametric closure that
826
+ # should assert the cache for torch sendnn is in the correct state
827
+ # for this test
828
+ if verify_cache_state is not None :
829
+ verify_cache_state ()
830
+
823
831
# if level 0 fails validation, validate level 1
824
832
if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0 :
825
833
if failed_validation_level_0 :
@@ -841,6 +849,87 @@ def _run_cpu_aiu_validation_test(
     )


+def _reset_cache_settings(purge_cache_dir):
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+    # Ensure we start in a clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    from torch_sendnn.backends import cache
+
+    # Explicitly clear cache paths from the global torch_sendnn graph;
+    # TODO: it would be better to add a helper in torch_sendnn to do this
+    # explicitly
+    cache.cache = {}
+
+
+@pytest.fixture
+def use_cached_model():
+    """Configures the torch_sendnn cache and runs the AIU model prior to test execution;
+    this is computationally expensive and should only be used in situations like testing
+    cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_miss,
+    )
+
+
+def _get_cache_test_params():
+    # NOTE: currently we always use granite 3.3 for the cache test;
+    # TODO: make this configurable as tests are refactored
+    model_path = GRANITE_3p3_8B_INSTRUCT
+    batch_size = COMMON_BATCH_SIZES[0]
+    seq_length = COMMON_SEQ_LENGTHS[0]
+    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+    return [model_path, batch_size, seq_length, max_new_tokens]
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
@@ -879,3 +968,51 @@ def test_common_shapes(
         model,
         micro_model_path,
     )
+
+
+def test_cache(use_cached_model):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=False)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_hit,
+    )
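
For reference, a minimal sketch (not part of the diff) of how another test could supply its own verify_cache_state closure to _run_cpu_aiu_validation_test, mirroring verify_cache_miss / verify_cache_hit above; the helper name _example_verify and its expected_entries parameter are illustrative, not from the PR.

# Sketch only; assumed helper, not part of this change.
import os

def _example_verify(expected_entries):
    # Returns a no-argument closure that asserts on the torch_sendnn
    # cache directory, in the same style as verify_cache_miss/hit.
    def _check():
        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
        found = len(os.listdir(cache_dir)) if cache_dir and os.path.isdir(cache_dir) else 0
        assert found == expected_entries, "unexpected torch_sendnn cache state"
    return _check

# Usage inside a test body, e.g.:
# _run_cpu_aiu_validation_test(
#     model_path, batch_size, seq_length, max_new_tokens,
#     validation_model, model, micro_model_path,
#     verify_cache_state=_example_verify(expected_entries=max_new_tokens),
# )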