@@ -4683,10 +4683,22 @@ def check_hotswap(self, do_hotswap, ranks, alpha_scalings):
             output_after1 = model(inputs).logits
             assert torch.allclose(output1, output_after1, atol=tol, rtol=tol)
 
+            # we need to call forward a third time since cudagraphs are not recorded in the first call
+            if do_hotswap:
+                hotswap_adapter(model, os.path.join(tmp_dirname, "adapter0"), adapter_name="default")
+                output_after2 = model(inputs).logits
+                assert torch.allclose(output0, output_after2, atol=tol, rtol=tol)
+
     # it is important to check hotswapping small to large ranks and large to small ranks
     @pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
     def test_hotswapping_compiled_model_does_not_trigger_recompilation(self, ranks):
-        with torch._dynamo.config.patch(error_on_recompile=True):  # raise an error on recompilation
+        # here we set three configs to ensure that no recompilation or cudagraph re-recording occurs:
+        # 1. error_on_recompile: raise an error on recompilation
+        # 2. inline_inbuilt_nn_modules: needed to raise an error on static input address changes instead of re-recording
+        # 3. triton.cudagraph_support_input_mutation: same as above
+        dynamo_config_ctx = torch._dynamo.config.patch(error_on_recompile=True, inline_inbuilt_nn_modules=False)
+        inductor_config_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
+        with dynamo_config_ctx, inductor_config_ctx:
             self.check_hotswap(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)
 
     def test_no_hotswapping_compiled_model_triggers_recompilation(self):
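For context on the `error_on_recompile` flag the test relies on, here is a minimal, self-contained sketch (not part of the PR) of how `torch._dynamo.config.patch(error_on_recompile=True)` turns a silent recompilation into a hard error. The toy function `fn` and the tensor shapes are illustrative assumptions, not taken from the test suite.

```python
import torch

def fn(x):
    return x * 2

# dynamic=False pins the compiled graph to the input shapes seen so far,
# so feeding a new shape forces a recompilation
compiled_fn = torch.compile(fn, dynamic=False)

with torch._dynamo.config.patch(error_on_recompile=True):
    compiled_fn(torch.ones(4))      # first call: compiles the graph, no error
    compiled_fn(torch.ones(4))      # same shape: compiled graph is reused
    try:
        compiled_fn(torch.ones(8))  # different shape: would recompile, so an error is raised
    except Exception as exc:        # exact exception class depends on the torch version
        print(f"recompilation detected: {type(exc).__name__}")
```

In the test above, the same mechanism applies: `check_hotswap` runs forward passes on the compiled model before and after hotswapping, and any guard failure caused by the swap would surface as an error instead of a silent recompile or cudagraph re-record.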