
[Feature] Deferred Weight Initialization #241

@chhzh123

Description


Motivation

The current NN module implementation eagerly initializes every parameter and merely attaches kernel_axes sharding metadata to it. Each parameter therefore requires a fully materialized memory buffer, which does not scale to large models like Grok: (1) generating the random weights up front makes initialization slow, and (2) allocating the entire buffer at once can cause OOM.

An example of the LinearBase initialization method is shown below:

        # This eagerly draws the full (input_size, output_size) buffer on the
        # default device; kernel_axes is only metadata attached afterwards.
        self.weight = nnx.Param(
            nnx.with_partitioning(nnx.initializers.normal(), kernel_axes)(
                rngs.params(), (input_size, output_size), params_dtype
            )
        )

To resolve this issue, we need something similar to PyTorch's meta device, which defers initialization. The JAX counterpart is jax.ShapeDtypeStruct, which records a shape, a dtype, and optionally a sharding without allocating any memory.
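A minimal sketch of the deferred path, assuming we trace the existing initializer with jax.eval_shape (the standalone init_weight helper and the 8192x8192 size are illustrative, not the final API):

    import jax
    import jax.numpy as jnp
    from flax import nnx

    def init_weight(key):
        # Same normal initializer LinearBase uses, but never run eagerly here.
        return nnx.initializers.normal()(key, (8192, 8192), jnp.bfloat16)

    # jax.eval_shape traces init_weight without executing it: no buffer is
    # allocated, and the result is an abstract jax.ShapeDtypeStruct.
    abstract = jax.eval_shape(init_weight, jax.random.key(0))
    print(abstract)  # ShapeDtypeStruct(shape=(8192, 8192), dtype=bfloat16)

The real weights can then be materialized lazily, e.g. by jitting the initializer with out_shardings so each device allocates only its own shard, or by loading checkpoint weights directly into the abstract placeholders.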

Upstream SGL does not have this issue, as its model implementations are already sharded with ColumnParallelLinear/RowParallelLinear.

This is a must-have, high-priority feature. I'll modify the code and open a PR.

Related resources

No response
