Description
Checklist
- 1. If the issue you raised is not a feature request but a question, please raise a discussion at https://github.com/sgl-project/sgl-jax/discussions/new/choose instead. Otherwise, it will be closed.
- 2. Please use English; otherwise the issue will be closed.
Motivation
The current NN module implementation only annotates parameters with kernel_axes metadata. This requires fully initializing the memory buffer, which does not scale to large models like Grok, because (1) creating random weights up front incurs a long initialization time, and (2) it can cause OOM, since an entire buffer must always be allocated first.
An example from the LinearBase initialization method is shown below:
self.weight = nnx.Param(
    nnx.with_partitioning(nnx.initializers.normal(), kernel_axes)(
        rngs.params(), (input_size, output_size), params_dtype
    )
)
To resolve this issue, we need something similar to PyTorch's meta_device that defers initialization. In JAX, the equivalent is jax.ShapeDtypeStruct.
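As a rough sketch of the idea (the names and sizes here are illustrative, not the actual sgl-jax code), jax.eval_shape can trace an initializer abstractly and return a jax.ShapeDtypeStruct, so the weight's shape and dtype are known without ever allocating a device buffer:

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; real model dimensions come from the config.
input_size, output_size = 4096, 4096

def init_weight(key):
    # The same random-normal initialization LinearBase would perform eagerly.
    return jax.random.normal(key, (input_size, output_size), dtype=jnp.bfloat16)

# jax.eval_shape traces init_weight abstractly: no RNG work is done and
# no memory is allocated; only shape/dtype information is computed.
abstract_weight = jax.eval_shape(init_weight, jax.random.key(0))
print(abstract_weight)  # ShapeDtypeStruct(shape=(4096, 4096), dtype=bfloat16)
```

The real weights can then be materialized later, shard by shard, once the sharding is known, rather than allocating the full unsharded buffer at construction time.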
Upstream SGL does not have this issue because its model implementation is already sharded with Column/RowParallelLinear.
This is a must-have, high-priority feature. I'll modify the code and create a PR.
Related resources
No response