[Feat] Add a new kind of linear operation: LayerShardLinear #2931
Open
clrs97 wants to merge 4 commits into vllm-project:main from clrs97:layer-shard-linear
+245 −0
@@ -0,0 +1,313 @@
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
from torch.nn.parameter import Parameter
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs

from vllm_ascend.utils import dispose_tensor


@dataclass
class LayerMetadata:
    """Metadata for a layer.
    """
    layer: Optional[LinearBase]  # The layer object.
    post_method: Callable[[
        torch.nn.Module
    ], None]  # The `process_weights_after_loading` method from the quant method.
    weight: torch.Tensor  # The weight tensor.
    window_idx: int  # The index of the window.


@dataclass
class SharedWindowMetadata:
    """Metadata for a shared window.
    """
    weight: torch.Tensor  # The weight tensor to be shared by layers.
    data_layer_idx: int  # The index of the layer whose weight this window currently holds.
    work: Optional[torch.distributed.Work]  # The asynchronous broadcast work.


@dataclass
class ClusterMetadata:
    """Metadata for a cluster.
    """
    group: GroupCoordinator
    start_layer: int
    end_layer: int
    num_layers: int
    prefetch_step: int
    # Dummy weight that replaces the loaded weight matrix. All layers in the
    # cluster share the same dummy weight tensor.
    dummy_weight: torch.Tensor
    layers: list[LayerMetadata]
    # Shared windows for prefetching. The number of windows is
    # (`prefetch_step` + 1), as only the weights for the next
    # (`prefetch_step` + 1) layers need to be stored.
    shared_windows: list[SharedWindowMetadata]
    window_offset: int  # The index of the window for the next incoming layer.

    def is_source(self, layer_idx) -> bool:
        return layer_idx % self.group.world_size == self.group.rank_in_group

    def post_process_after_loading(self):
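        # Summary: broadcast each layer's true weight from its source rank,
        # run the quant method's `process_weights_after_loading` on it, and
        # set up the (`prefetch_step` + 1) shared window buffers used for
        # prefetching.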
        # This method only needs to be called once per cluster.
        if self.shared_windows:
            return
        for layer_idx in range(self.start_layer, self.end_layer):
            layer = self.layers[layer_idx - self.start_layer]
            is_source = self.is_source(layer_idx)
            # If the layer uses the dummy weight, make a temporary copy so
            # that the post method call won't affect other layers that also
            # use the dummy weight.
            if not is_source:
                layer.weight.set_(torch.empty_like(self.dummy_weight))
            # Broadcast to get the true weight.
            dist.broadcast(layer.weight,
                           src=self.group.ranks[layer_idx %
                                                self.group.world_size],
                           group=self.group.device_group)
            assert layer.layer is not None
            # Call `process_weights_after_loading` from the quant method.
            layer.post_method(layer.layer)
            step = layer_idx - self.start_layer
            if step < self.prefetch_step:
                # Build the windows for the first `prefetch_step` layers.
                # Their weights are used by the first `prefetch_step` calls to
                # `forward()`, so also clone the weights.
                self.shared_windows.append(
                    SharedWindowMetadata(
                        weight=layer.weight.clone().detach(),
                        data_layer_idx=layer_idx,
                        work=None,
                    ))
                layer.window_idx = step
                # When the layer is not intended to be stored on this device,
                # link its weight to the corresponding window's tensor.
                if not is_source:
                    layer.weight.set_(self.shared_windows[-1].weight)
            else:
                # Build one more window for prefetching. Its current data is
                # unused, so only the shape matters.
                if step == self.prefetch_step:
                    self.shared_windows.append(
                        SharedWindowMetadata(
                            weight=torch.empty_like(layer.weight),
                            data_layer_idx=-1,
                            work=None,
                        ))
                # When the layer is not intended to be stored on this device,
                # dispose the tensor.
                if not is_source:
                    dispose_tensor(layer.weight)

        dispose_tensor(self.dummy_weight)

    def get_shared_window(self, layer_idx: int):
        assert self.shared_windows
        return self.shared_windows[self.layers[layer_idx -
                                               self.start_layer].window_idx]

    def reach_layer(self, layer_idx: int):
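        # For example, with `prefetch_step` = 2 there are 3 shared windows:
        # after post-processing, windows 0 and 1 hold the weights of the first
        # two layers and window 2 is a spare. Reaching the first layer fills
        # the spare window with the weight of the layer 2 steps ahead; each
        # later call recycles the window holding the previous (already
        # consumed) layer's weight, filling it via an asynchronous broadcast
        # that `forward()` later waits on.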
        # The index of the layer to be prefetched.
        next_layer_idx = (layer_idx + self.prefetch_step
                          ) % self.num_layers + self.start_layer
        next_layer = self.layers[next_layer_idx - self.start_layer]
        # The index of the window that will store the weight for the coming layer.
        next_layer.window_idx = self.window_offset
        window = self.shared_windows[next_layer.window_idx]
        # When the layer is not intended to be stored on this device, link its
        # weight to the corresponding window's tensor.
        if not self.is_source(next_layer_idx):
            next_layer.weight.set_(window.weight)
        # Update `window_offset` by rolling one step.
        self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
                                                         1)
        assert window.data_layer_idx != next_layer_idx
        window.data_layer_idx = next_layer_idx
        # Start the asynchronous broadcast work.
        window.work = dist.broadcast(
            next_layer.weight,
            src=self.group.ranks[next_layer_idx % self.group.world_size],
            group=self.group.device_group,
            async_op=True)


_cluster_dict: dict[str, ClusterMetadata] = {}


def register_layer_to_cluster(
    name: str,
    group: GroupCoordinator,
    start_layer: int,
    end_layer: int,
    prefetch_step: int,
    layer_idx: int,
    layer: LinearBase,
) -> ClusterMetadata:
    global _cluster_dict
    if name not in _cluster_dict:
        num_layers = end_layer - start_layer
        assert num_layers > 0
        assert prefetch_step >= 0 and prefetch_step <= num_layers - 2

        _cluster_dict[name] = ClusterMetadata(
            group=group,
            start_layer=start_layer,
            end_layer=end_layer,
            num_layers=num_layers,
            prefetch_step=prefetch_step,
            dummy_weight=torch.empty_like(layer.weight),
            layers=[
                LayerMetadata(
                    layer=None,
                    post_method=lambda layer: None,
                    weight=torch.empty([]),
                    window_idx=-1,
                ) for _ in range(num_layers)
            ],
            shared_windows=[],
            window_offset=prefetch_step,
        )
    cluster = _cluster_dict[name]
    assert layer.quant_method is not None
    cluster.layers[layer_idx - start_layer] = LayerMetadata(
        layer=layer,
        post_method=layer.quant_method.process_weights_after_loading,
        weight=layer.weight,
        window_idx=-1,
    )
    # Discard the original `process_weights_after_loading` method so that it
    # won't be called by others.
    layer.quant_method.process_weights_after_loading = lambda layer: None

    # When the layer is not intended to be stored on this device, dispose the
    # tensor.
    if not cluster.is_source(layer_idx):
        dispose_tensor(layer.weight)
    return cluster


@CustomOp.register("layer_shard_linear")
class LayerShardLinear(LinearBase):
    """Linear layer with sharded storage.

    The devices in the parallel group evenly store disjoint sets of layers.
    All layers must have the same structure. Assuming there are n devices, the
    weight matrix of the i-th layer is stored on the (i % n)-th device.

    After loading the model, you must call
    `post_process_after_loading_for_cluster()` to complete the initialization.

    Each time a new layer is reached, you must call `reach_layer()` to
    prefetch the weights.

    Arguments:
        input_size: first dimension of matrix.
        output_size: second dimension of matrix.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.self_attn.o_proj).
        return_bias: If true, return bias together with outputs in forward pass.
        cluster_name: A set of isomorphic layers is defined as a "cluster".
                      This name identifies which cluster this layer belongs to.
        group: The group coordinator for handling asynchronous communications.
               It is recommended to create a new group coordinator for each
               new cluster.
        start_layer: The index of the first layer in the cluster (inclusive).
        end_layer: The index of the last layer in the cluster (exclusive).
                   Thus, the cluster includes all layers with indices in the
                   range [start_layer, end_layer).
        layer_idx: The index of the current layer.
        prefetch_step: If set to 0, no weights are prefetched. If set to k,
                       the weights of the next k layers are prefetched.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        cluster_name: str,
        group: GroupCoordinator,
        start_layer: int,
        end_layer: int,
        layer_idx: int,
        prefetch_step: int = 0,
    ):
        self.input_size = input_size
        self.output_size = output_size
        self.output_partition_sizes = [output_size]
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix,
                         return_bias=return_bias)
        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

        self.layer_idx = layer_idx
        self.cluster = register_layer_to_cluster(
            name=cluster_name,
            group=group,
            start_layer=start_layer,
            end_layer=end_layer,
            prefetch_step=prefetch_step,
            layer_idx=layer_idx,
            layer=self,
        )

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Skip loading the matrix weight when it is not intended to be stored
        # on this device.
        if param is self.weight and not self.cluster.is_source(self.layer_idx):
            return
        assert not getattr(param, "is_gguf_weight", False)
        assert not getattr(param, "is_gguf_weight_type", False)
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)
        assert param.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter of size {param.size()}")
        param.data.copy_(loaded_weight)

    def post_process_after_loading_for_cluster(self):
        self.cluster.post_process_after_loading()

    def reach_layer(self):
        self.cluster.reach_layer(self.layer_idx)

    def forward(
        self,
        input,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
        # Find the async broadcast work and wait for it.
        window = self.cluster.get_shared_window(self.layer_idx)
        # Make sure the data in the corresponding shared window is for the
        # current layer.
        assert window.data_layer_idx == self.layer_idx
        if window.work is not None:
            window.work.wait()
            window.work = None
        # Matrix multiply.
        bias_ = None if self.skip_bias_add else self.bias
        output = self.quant_method.apply(self, input, bias=bias_)
        output_bias = self.bias if self.skip_bias_add else None
        if not self.return_bias:
            return output
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
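
A minimal usage sketch of the new op, assuming a decoder model whose `o_proj` layers form one cluster. The names `hidden_size`, `num_hidden_layers`, `layer_idx`, `quant_config`, and `o_proj_group` are placeholders and not part of this PR; the call sequence follows the class docstring: construct one `LayerShardLinear` per layer with a shared `cluster_name`, call `post_process_after_loading_for_cluster()` after the checkpoint is loaded (it is a no-op after the first call per cluster), and call `reach_layer()` each time a layer is entered so the prefetch broadcast overlaps with earlier layers.

```python
# Hypothetical wiring inside a decoder block; only LayerShardLinear comes from this PR.
self.o_proj = LayerShardLinear(
    input_size=hidden_size,
    output_size=hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"model.layers.{layer_idx}.self_attn.o_proj",
    cluster_name="o_proj",       # all o_proj layers share one cluster
    group=o_proj_group,          # a GroupCoordinator dedicated to this cluster
    start_layer=0,
    end_layer=num_hidden_layers,
    layer_idx=layer_idx,
    prefetch_step=2,             # keep the weights of the next 2 layers in flight
)

# Once, after the model weights have been loaded:
self.o_proj.post_process_after_loading_for_cluster()

# At the start of each decoder layer's forward pass:
self.o_proj.reach_layer()                    # start the async prefetch
attn_output, _ = self.o_proj(hidden_states)  # waits on this layer's broadcast, then matmuls
```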