Skip to content

Commit d337f8e

Browse files
committed
Update docstring for model loading code
1 parent 96c75bf commit d337f8e

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/cube/inference/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,20 @@ def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
3737
return scfg
3838

3939

40-
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
40+
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
4141
"""
4242
Load a safetensors checkpoint into a PyTorch model.
43+
The model is updated in place.
4344
4445
Args:
4546
model: PyTorch model to load weights into
4647
ckpt_path: Path to the safetensors checkpoint file
4748
4849
Returns:
49-
The model with loaded weights
50+
None
5051
"""
51-
assert ckpt_path.endswith(".safetensors"), (
52-
f"Checkpoint path '{ckpt_path}' is not a safetensors file"
53-
)
52+
assert ckpt_path.endswith(
53+
".safetensors"
54+
), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
5455

5556
load_model(model, ckpt_path)
56-
57-
return model

0 commit comments

Comments
 (0)