
Commit bbc9f5d

FIX Avoid CUDA Graph re-record with hotswap (#2611)
1 parent d26f332 commit bbc9f5d

File tree

2 files changed: +16 -4 lines changed


src/peft/utils/hotswap.py

Lines changed: 3 additions & 3 deletions
@@ -480,7 +480,7 @@ def hotswap_adapter_from_state_dict(
             # either
             # - adapters had the same rank
             # - adapters were padded with prepare_model_for_compiled_hotswap and 2nd adapter was larger
-            old_val.data = new_val.data
+            old_val.data.copy_(new_val.data)
         else:
             # if 2nd adapter was smaller, ensure to fill up to adapter dimension and set the rest to zeros
             if old_val.dim() not in (2, 4):
@@ -492,10 +492,10 @@ def hotswap_adapter_from_state_dict(
             # Linear or Conv2d: the check for dim 0 or 1 works for both of these layer types
             if old_val.shape[0] > new_val.shape[0]:
                 old_val.data.fill_(0)
-                old_val.data[: new_val.shape[0]] = new_val.data
+                old_val.data[: new_val.shape[0]].copy_(new_val.data)
             elif old_val.shape[1] > new_val.shape[1]:
                 old_val.data.fill_(0)
-                old_val.data[:, : new_val.shape[1]] = new_val.data
+                old_val.data[:, : new_val.shape[1]].copy_(new_val.data)
             else:
                 raise ValueError(
                     f"Incompatible shapes found for LoRA weights {key}: {old_val.shape} vs {new_val.shape}. Please "

tests/test_gpu_examples.py

Lines changed: 13 additions & 1 deletion
@@ -4683,10 +4683,22 @@ def check_hotswap(self, do_hotswap, ranks, alpha_scalings):
             output_after1 = model(inputs).logits
             assert torch.allclose(output1, output_after1, atol=tol, rtol=tol)

+            # we need to call forward a third time since cudagraphs are not recorded in the first call
+            if do_hotswap:
+                hotswap_adapter(model, os.path.join(tmp_dirname, "adapter0"), adapter_name="default")
+                output_after2 = model(inputs).logits
+                assert torch.allclose(output0, output_after2, atol=tol, rtol=tol)
+
     # it is important to check hotswapping small to large ranks and large to small ranks
     @pytest.mark.parametrize("ranks", [(11, 11), (7, 13), (13, 7)])
     def test_hotswapping_compiled_model_does_not_trigger_recompilation(self, ranks):
-        with torch._dynamo.config.patch(error_on_recompile=True):  # raise an error on recompilation
+        # here we set three configs to ensure no recompilation or cudagraph re-record occurs:
+        # 1. error_on_recompile: raise an error on recompilation
+        # 2. inline_inbuilt_nn_modules: needed to raise an error on static input address changes instead of re-recording
+        # 3. triton.cudagraph_support_input_mutation: same as above
+        dynamo_config_ctx = torch._dynamo.config.patch(error_on_recompile=True, inline_inbuilt_nn_modules=False)
+        inductor_config_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
+        with dynamo_config_ctx, inductor_config_ctx:
             self.check_hotswap(do_hotswap=True, ranks=ranks, alpha_scalings=ranks)

     def test_no_hotswapping_compiled_model_triggers_recompilation(self):
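
For context, here is a hedged sketch of how these two config patches can be applied around a compiled forward pass so that a recompilation or cudagraph re-record surfaces as an error rather than a silent slowdown. The toy Linear model, the warmup loop, and the assumption of a CUDA device are illustrative and not taken from the PEFT test; the config names are those used in the diff above.

import torch

model = torch.nn.Linear(8, 8).cuda()
# "reduce-overhead" enables CUDA graphs in torch.compile
compiled_model = torch.compile(model, mode="reduce-overhead")
inputs = torch.randn(4, 8, device="cuda")

dynamo_ctx = torch._dynamo.config.patch(error_on_recompile=True, inline_inbuilt_nn_modules=False)
inductor_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
with dynamo_ctx, inductor_ctx:
    # run a few warmup iterations: cudagraphs are recorded only after the first calls,
    # so a re-record caused by a later weight swap would otherwise go unnoticed
    for _ in range(3):
        compiled_model(inputs)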
