Skip to content

Commit 91a5e51

Browse files
committed
validate batch size before collecting
1 parent dac4cf0 commit 91a5e51

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

bergson/collection.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def callback(name: str, g: torch.Tensor, indices: list[int]):
8484
attention_cfgs=attention_cfgs,
8585
)
8686

87+
validate_batch_size(model, token_batch_size, collector)
88+
8789
# Allocate space ahead of time for the gradients
8890
grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
8991

@@ -247,3 +249,21 @@ def process_preconditioners(
247249
preconditioners_eigen[name] = (eigval, eigvec)
248250
if rank == 0:
249251
processor.preconditioners_eigen = preconditioners_eigen
252+
253+
254+
def validate_batch_size(
255+
model: PreTrainedModel,
256+
token_batch_size: int | None,
257+
collector: GradientCollector,
258+
):
259+
"""Validate that the specified token batch size fits on device."""
260+
if token_batch_size is None:
261+
return
262+
263+
random_tokens = torch.randint(
264+
0, 10, (1, token_batch_size), device=model.device, dtype=torch.long
265+
)
266+
with collector:
267+
loss = model(random_tokens).logits[0, 0, 0].float()
268+
loss.backward()
269+
model.zero_grad()

0 commit comments

Comments
 (0)