245 changes: 245 additions & 0 deletions vllm_ascend/torchair/ops/shared_weight_layer.py
@@ -0,0 +1,245 @@
from dataclasses import dataclass
from typing import Callable, Optional

import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase


def dispose_tensor(x: torch.Tensor):
    # Free the tensor's storage by re-pointing it at an empty placeholder
    # with the same device and dtype.
    x.set_(torch.empty([], device=x.device, dtype=x.dtype))


@dataclass
class LayerMetadata:
"""Metadata for a layer.
"""
layer: Optional[LinearBase] # The layer object.
post_method: Callable[[
torch.nn.Module
], None] # The `process_weights_after_loading` method from the quant method.
weight: torch.Tensor # The weight tensor.
window_idx: int # The index of the window.


@dataclass
class SharedWindowMetadata:
"""Metadata for a shared window.
"""
weight: torch.Tensor # The weight tensor to be shared by layers.
    data_layer_idx: int  # The index of the layer whose weight this window currently holds.
work: Optional[torch.distributed.Work] # The asynchronous broadcast work.


@dataclass
class SeriesMetadata:
"""Metadata for a weight shared series.
"""
group: GroupCoordinator
start_layer: int
end_layer: int
num_layers: int
prefetch_step: int
dummy_weight: torch.Tensor # Dummy weight to replace the loaded weight matrix. All the layers in the series share the same dummy weight tensor.
layers: list[LayerMetadata]
shared_windows: list[
SharedWindowMetadata] # Shared windows for prefetching. The window size is (`prefetch_step` + 1), as only the weights for the next (`prefetch_step` + 1) layers need to be stored.
    window_offset: int  # The index of the window to use for the next upcoming layer.

def is_source(self, layer_idx) -> bool:
return layer_idx % self.group.world_size == self.group.rank_in_group

def post_process_after_loading(self):
# This method only needs to be called once per series.
if self.shared_windows:
return
for layer_idx in range(self.start_layer, self.end_layer):
layer = self.layers[layer_idx - self.start_layer]
is_source = self.is_source(layer_idx)
            # If the layer uses the dummy weight, give it a temporary copy so that the post-processing call won't affect other layers that also use the dummy weight.
if not is_source:
layer.weight.set_(torch.empty_like(self.dummy_weight))
# Broadcast to get the true weight.
dist.broadcast(layer.weight,
src=self.group.ranks[layer_idx %
self.group.world_size],
group=self.group.device_group)
assert layer.layer is not None
# Call `process_weights_after_loading` from the quant method.
layer.post_method(layer.layer)
step = layer_idx - self.start_layer
if step < self.prefetch_step:
                # Build windows for the first `prefetch_step` layers. Their weights are needed by the first `prefetch_step` layers in `forward()`, so clone them as well.
self.shared_windows.append(
SharedWindowMetadata(
weight=layer.weight.clone().detach(),
data_layer_idx=layer_idx,
work=None,
))
layer.window_idx = step
                # If this layer's weight is not meant to be stored on this device, link it to the corresponding window's tensor.
if not is_source:
layer.weight.set_(self.shared_windows[-1].weight)
else:
                # Build one extra window for prefetching. Its contents are never used here, so only the shape matters.
if step == self.prefetch_step:
self.shared_windows.append(
SharedWindowMetadata(
weight=torch.empty_like(layer.weight),
data_layer_idx=-1,
work=None,
))
                # If this layer's weight is not meant to be stored on this device, dispose of the tensor.
if not is_source:
dispose_tensor(layer.weight)

dispose_tensor(self.dummy_weight)

def reach_layer(self, layer_idx: int):
# The index of the layer to be prefetched.
next_layer_idx = (layer_idx + self.prefetch_step
) % self.num_layers + self.start_layer
next_layer = self.layers[next_layer_idx - self.start_layer]
        # The index of the window that will hold the weight for the upcoming layer.
next_layer.window_idx = self.window_offset
window = self.shared_windows[next_layer.window_idx]
        # If the layer's weight is not meant to be stored on this device, link it to the corresponding window's tensor.
if not self.is_source(next_layer_idx):
next_layer.weight.set_(window.weight)
# Update `window_offset` by rolling one step.
self.window_offset = (self.window_offset + 1) % (self.prefetch_step +
1)
assert window.data_layer_idx != next_layer_idx
window.data_layer_idx = next_layer_idx
# Start asynchronous broadcast work.
window.work = dist.broadcast(
next_layer.weight,
src=self.group.ranks[next_layer_idx % self.group.world_size],
group=self.group.device_group,
async_op=True)

def wait_weight(self, layer_idx: int):
# Find the asynchronous broadcast work and wait for it.
assert self.shared_windows
window = self.shared_windows[self.layers[layer_idx -
self.start_layer].window_idx]
# Make sure the data in the corresponding shared window is for the current layer.
assert window.data_layer_idx == layer_idx
if window.work is not None:
window.work.wait()
window.work = None


@dataclass
class LayerExternalMetadata:
"""External metadata for a layer.
"""
series: SeriesMetadata
layer_idx: int


_series_dict: dict[str, SeriesMetadata] = {}

_layer_external_dict: dict[int, LayerExternalMetadata] = {}


def _create_forward_wrapper(forward: Callable, series: SeriesMetadata,
layer_idx: int) -> Callable:

def wrapped_forward(*args, **kwargs):
        # Wait for this layer's weight broadcast to complete before running the original forward.
series.wait_weight(layer_idx)
return forward(*args, **kwargs)

return wrapped_forward


"""
Register linear layers into a shared storage series.

In a parallel group, each device stores a distinct, non-overlapping subset of layers from the series. All layers in a series must have the same structure (are isomorphic). The weight matrix for the i-th layer is stored on device (i % n), where n is the number of devices.

After loading the model, you must call `post_process_after_loading_for_shared_weight_series(layer)` on any layer of this series to complete the initialization.

During execution, each time a new layer is reached, you must call `reach_layer_for_shared_weight_series(layer)` for that layer to prefetch the weights. The argument `prefetch_step` is a non-negative integer k that manages asynchronous weight prefetching. Each call to `reach_layer_for_shared_weight_series(current_layer)` method will trigger an asynchronous prefetch for the weights of the k-th subsequent layer after `current_layer` within the series.

Note: The layers are managed as a circular buffer. The index of the layer to prefetch is determined by the formula:
- total_layers = end_layer - start_layer
- prefetch_layer_idx = (layer_idx + prefetch_step) % total_layers + start_layer

To hold the weights for the current layer and the k prefetched layers, a pool of (k + 1) shared tensor buffers will be created for this series.

Arguments:
series_name: This name identifies which series this layer belongs to.
group: The group coordinator for handling asynchronous communications. It is recommended to create a new group coordinator for each new series.
start_layer: The index of the first layer in the series (inclusive).
end_layer: The index of the last layer in the series (exclusive). Thus, the series includes all layers with indices in the range [start_layer, end_layer).
layer_idx: The index of the current layer.
layer: The linear layer object to register.
prefetch_step: An integer that manages asynchronous weight prefetching. Setting it to 0 or 1 can cover most cases.
"""


def register_layer_to_shared_weight_series(
series_name: str,
group: GroupCoordinator,
start_layer: int,
end_layer: int,
layer_idx: int,
layer: LinearBase,
prefetch_step: int = 1,
):
global _series_dict
if series_name not in _series_dict:
num_layers = end_layer - start_layer
assert num_layers > 0
assert prefetch_step >= 0 and prefetch_step <= num_layers - 2
_series_dict[series_name] = SeriesMetadata(
group=group,
start_layer=start_layer,
end_layer=end_layer,
num_layers=num_layers,
prefetch_step=prefetch_step,
dummy_weight=torch.empty_like(layer.weight),
layers=[
LayerMetadata(
layer=None,
post_method=lambda layer: None,
weight=torch.empty([]),
window_idx=-1,
) for _ in range(num_layers)
],
shared_windows=[],
window_offset=prefetch_step,
)
series = _series_dict[series_name]
assert layer.quant_method is not None
series.layers[layer_idx - start_layer] = LayerMetadata(
layer=layer,
post_method=layer.quant_method.process_weights_after_loading,
weight=layer.weight,
window_idx=-1,
)
    # Replace the original `process_weights_after_loading` with a no-op so that it won't be invoked again elsewhere.
layer.quant_method.process_weights_after_loading = lambda layer: None
    # If this layer's weight is not meant to be stored on this device, dispose of the tensor and skip weight loading.
if not series.is_source(layer_idx):
dispose_tensor(layer.weight)
layer.weight.weight_loader = lambda *args, **kwargs: None
layer.forward = _create_forward_wrapper(layer.forward, series, layer_idx)
global _layer_external_dict
_layer_external_dict[id(layer)] = LayerExternalMetadata(
series=series,
layer_idx=layer_idx,
)


def post_process_after_loading_for_shared_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.post_process_after_loading()


def reach_layer_for_shared_weight_series(layer: LinearBase):
ext = _layer_external_dict[id(layer)]
ext.series.reach_layer(ext.layer_idx)
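

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so it is never executed).
# It assumes a decoder model whose layers each expose an `o_proj` LinearBase
# module and a dedicated GroupCoordinator named `shared_weight_group`; both
# names are placeholders, not part of this module's API.
#
#   # 1. At model-construction time, register each layer's projection into
#   #    the same series:
#   for idx in range(start_layer, end_layer):
#       register_layer_to_shared_weight_series(
#           series_name="o_proj_series",
#           group=shared_weight_group,
#           start_layer=start_layer,
#           end_layer=end_layer,
#           layer_idx=idx,
#           layer=model.layers[idx].o_proj,
#           prefetch_step=1,
#       )
#
#   # 2. After weight loading, finalize the series once; any layer of the
#   #    series may be passed as the argument:
#   post_process_after_loading_for_shared_weight_series(
#       model.layers[start_layer].o_proj)
#
#   # 3. During the forward pass, announce each layer as it is reached so the
#   #    broadcast of an upcoming layer's weight starts asynchronously; the
#   #    wrapped `forward()` waits for its own weight automatically.
#   reach_layer_for_shared_weight_series(model.layers[idx].o_proj)
# ---------------------------------------------------------------------------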