Commit e46ce79

Update
[ghstack-poisoned]

1 parent 4428a6e commit e46ce79

File tree: 1 file changed (+2 −2)


sota-implementations/grpo/grpo_utils.py

Lines changed: 2 additions & 2 deletions
@@ -51,9 +51,9 @@ def get_train_model(
     max_memory = {}
     for i in range(torch.cuda.device_count()):
         if i in train_devices:
-            max_memory[f"cuda:{i}"] = "24GiB"  # Allow max memory for devices we want to use
+            max_memory[i] = "24GiB"  # Allow max memory for devices we want to use
         else:
-            max_memory[f"cuda:{i}"] = "0GiB"  # No memory for other devices
+            max_memory[i] = "0GiB"  # No memory for other devices
     max_memory["cpu"] = "24GiB"  # Allow CPU memory as fallback

     # Let HF handle distribution with max_memory
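The change swaps string keys like `"cuda:0"` for plain integer device indices, the key format that Hugging Face `accelerate`/`transformers` expect in a `max_memory` mapping (GPUs by integer index, CPU by the string `"cpu"`). A minimal sketch of the resulting dict-building logic, with `build_max_memory` as a hypothetical standalone helper (the real code lives inside `get_train_model` and queries `torch.cuda.device_count()`):

```python
def build_max_memory(train_devices, num_gpus):
    """Build a max_memory mapping for HF model loading.

    GPUs are keyed by integer index (the format accelerate expects);
    devices outside train_devices get a "0GiB" budget so HF places
    no weights on them. CPU is keyed by the string "cpu" as a fallback.
    """
    max_memory = {}
    for i in range(num_gpus):
        # Allow memory only on the devices we want to train on
        max_memory[i] = "24GiB" if i in train_devices else "0GiB"
    max_memory["cpu"] = "24GiB"  # CPU fallback budget
    return max_memory


# Example: train on GPU 0 of a 2-GPU machine
mapping = build_max_memory(train_devices={0}, num_gpus=2)
print(mapping)  # {0: '24GiB', 1: '0GiB', 'cpu': '24GiB'}
```

A mapping like this is typically passed as the `max_memory=` argument to `from_pretrained(..., device_map="auto")`, letting HF distribute the weights within the stated per-device budgets.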

0 commit comments