From 2109fbcdeb9d2dd4a50eee844db97fd31c3503ac Mon Sep 17 00:00:00 2001
From: mzusman
Date: Wed, 9 Oct 2024 11:48:07 +0300
Subject: [PATCH 01/24] Do not read or write to padding

---
 csrc/mamba/causal_conv1d/causal_conv1d.cu  |  12 +-
 csrc/mamba/mamba_ssm/selective_scan_fwd.cu |  11 +-
 tests/kernels/test_causal_conv1d.py        | 119 ++++++++++++++++++
 tests/kernels/test_mamba_ssm.py            | 138 +++++++++++++++++++++
 4 files changed, 270 insertions(+), 10 deletions(-)

diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index 30831efdfa1..542578b61fe 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -334,16 +334,17 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
         + channel_id * params.out_c_stride;
     float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
 
-    bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
-        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
-
     int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
     : reinterpret_cast<int *>(params.cache_indices_ptr);
     int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
-
-    input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
+
+    // cache_index == -1 is defined as padding and cache should'nt been written/read
+    input_t *conv_states = params.conv_states_ptr == nullptr || cache_index == -1 ? nullptr
         : reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
+    bool has_initial_state = params.has_initial_state_ptr == nullptr || conv_states == nullptr ? false
+        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id] ;
+
     // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
     if (tidx == 0) {
         input_t initial_state[kNElts] = {0};
@@ -528,6 +529,7 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
     const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
         ? batch_id
         : params.conv_state_indices_ptr[batch_id];
+    if (conv_state_batch_coord == -1) return;
     input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
         + conv_state_batch_coord * params.conv_state_batch_stride
         + channel_id * params.conv_state_c_stride;
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 6b225b41d29..c3d49546cf8 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -109,12 +109,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
         sequence_start_index = query_start_loc[batch_id];
         seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
     }
-    const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
-        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
-
     const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
     : reinterpret_cast<int *>(params.cache_indices_ptr);
     const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
+    const bool wr_cache = cache_index != -1;
+    const bool has_initial_state = params.has_initial_state_ptr == nullptr && wr_cache ?
false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride @@ -250,7 +252,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; - if (chunk == n_chunks - 1) { + if (chunk == n_chunks - 1 && wr_cache) { ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); } } @@ -626,7 +628,6 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, TORCH_CHECK(ssm_states.scalar_type() == input_type); TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.stride(-1) == 1); - CHECK_SHAPE(ssm_states, batch_size, dim, dstate); SSMParamsBase params; set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 069020a536d..2c51e8cb0e1 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -6,6 +6,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -271,6 +272,64 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, )) +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) +@pytest.mark.parametrize("seqlen", [1]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize("dim", [2048]) +def test_causal_conv1d_update_with_batch_gather_padding_unchanged( + dim, + width, + seqlen, + has_bias, + silu_activation, + itype +): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + # set )seed + seed_everything(0) + batch = 64 + + x = torch.randn(batch, dim, 1, device=device, dtype=itype) + + total_entries = 10 * batch + conv_state = torch.randn(total_entries, + dim, + width - 1, + device=device, + dtype=itype) + conv_state_before = conv_state.clone() + conv_state_indices = torch.as_tensor( + [PAD_SLOT_ID] * batch, + dtype=torch.int32, + device=device + ) + + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices) + + assert torch.equal(conv_state, conv_state_before) + + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @@ -422,3 +481,63 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), cache_indices, has_initial_states, final_states, activation) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) 
+@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize('seqlen', + [256]) +@pytest.mark.parametrize('dim', [64]) +def test_causal_conv1d_varlen_check_padding_unchanged( + dim, + seqlen, + width, + has_bias, + silu_activation, + itype +): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + seed_everything(0) + batch = 1 + seqlens = [] + nsplits = 3 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + seqlens.append( + torch.diff( + torch.cat( + [torch.tensor([-1]), eos_pos, + torch.tensor([seqlen - 1])])).tolist()) + assert sum(seqlens[-1]) == seqlen + assert all(s > 0 for s in seqlens[-1]) + + cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], + dim=0) + x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, + dtype=itype)[:, 4096:4096 + dim, :] + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + activation = None if not silu_activation else "silu" + final_states = torch.randn(nsplits + 1, + dim, + width - 1, + device=x.device, + dtype=x.dtype) + final_states_before = final_states.clone() + has_initial_states = torch.randint(0, + 2, (cumsum.shape[0] - 1, ), + dtype=torch.bool, + device=x.device) + cache_indices = torch.as_tensor([-1] * (cumsum.shape[0] - 1), + dtype=torch.int32, + device=x.device) + causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), + cache_indices, has_initial_states, final_states, + activation) + assert torch.equal(final_states, final_states_before) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 8fa55e75f6c..1180417cc95 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -515,6 +515,144 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, has_initial_state, prev_state) + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [256]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +def test_selective_scan_varlen_padding_unchanged(is_variable_B, is_variable_C, + varBC_groups, has_D, has_z, + has_delta_bias, + delta_softplus, + return_last_state, seqlen, itype, wtype): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + seqlens = [] + nsplits = 3 + if seqlen < 10: + nsplits = 0 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + seqlens.append( + torch.diff( + torch.cat( + [torch.tensor([-1]), eos_pos, + torch.tensor([seqlen - 
1])])).tolist()) + assert sum(seqlens[-1]) == seqlen + assert all(s > 0 for s in seqlens[-1]) + + cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], + dim=0).cuda() + + dim = 4 + dstate = 8 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + B_shape = [varBC_groups, dstate, seqlen] + B = torch.randn(B_shape, + device=device, + dtype=wtype if not is_variable_B else itype) + C_shape = [varBC_groups, dstate, seqlen] + C = torch.randn(C_shape, + device=device, + dtype=wtype if not is_variable_C else itype) + D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + z = torch.randn(dim, seqlen, device=device, dtype=itype) + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + ) if has_delta_bias else None + u = torch.randn(dim, seqlen, device=device, dtype=itype) + delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)) + prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) + prev_state = torch.randn(prev_state_shape, + device=u.device, + dtype=itype, + requires_grad=False) + prev_state_ref = prev_state.clone() + cache_indices = torch.as_tensor([-1] * (cumsum.shape[0] - 1), + dtype=torch.int32, + device=u.device) + + has_initial_state = torch.randint(0, + 2, (cumsum.shape[0] - 1, ), + dtype=torch.bool, + device=u.device) + out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, + delta_softplus, cumsum, cache_indices, + has_initial_state) + assert torch.equal(prev_state, prev_state_ref) + + + +@pytest.mark.parametrize("itype", + [torch.bfloat16]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("dstate", [16]) +@pytest.mark.parametrize("dim", [2048]) +def test_selective_state_update_with_batch_indices_padding_unchanged( + dim, + dstate, + has_z, + itype +): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 7e-2, 7e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 3 + + total_entries = 10 * batch_size + state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) + state_indices = torch.as_tensor( + [-1] * batch_size, + dtype=torch.int32, + device=device + ) + + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state.clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + assert torch.equal(state_ref, state) + + + + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) From 8545ec40b439ee3c818a42bc11af0def520c77f4 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 9 Oct 2024 11:48:51 +0300 Subject: [PATCH 02/24] Padding support to mamba_ssm --- .../layers/mamba/ops/mamba_ssm.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 08b016c20c4..f1c43928431 
100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -144,9 +144,14 @@ def _selective_scan_update_kernel( z_ptrs = z_ptr + offs_m * stride_z_dim out_ptrs = out_ptr + offs_m * stride_out_dim - state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0) + if HAS_STATE_BATCH_INDICES: + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & (state_batch_idx != -1), + other=0.0) + else: + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -177,9 +182,14 @@ def _selective_scan_update_kernel( dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt state = state * dA + dB * x[:, None] - tl.store(state_ptrs, - state, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + if HAS_STATE_BATCH_INDICES: + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & (state_batch_idx != -1)) + else: + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D From 24758ed5f96811b4c75bdfe6bff9d880f9a210c8 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 9 Oct 2024 11:49:18 +0300 Subject: [PATCH 03/24] continuous batching jamba --- vllm/model_executor/models/jamba.py | 254 +++++++++++----------------- 1 file changed, 102 insertions(+), 152 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 06ec324b3e1..6b12aa02f07 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -9,6 +9,7 @@ from transformers import JambaConfig from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, @@ -43,9 +44,15 @@ @dataclass class MambaCacheParams: - is_prompt: bool = False conv_state: torch.Tensor = torch.Tensor() ssm_state: torch.Tensor = torch.Tensor() + state_indices_tensor: torch.Tensor = torch.Tensor() + + def at_layer_idx(self,layer_idx): + return MambaCacheParams(self.conv_state[layer_idx], + self.ssm_state[layer_idx], + self.state_indices_tensor) + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @@ -82,8 +89,8 @@ def __init__(self, config: JambaConfig, layer_idx): # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = MergedColumnParallelLinear(self.hidden_size, - [self.intermediate_size] * 2, + self.in_proj = ColumnParallelLinear(self.hidden_size, + self.intermediate_size * 2, bias=self.use_bias) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( @@ -137,8 +144,8 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor): eps=config.rms_norm_eps) def forward(self, hidden_states: torch.Tensor, - attn_metadata: AttentionMetadata, conv_state: torch.Tensor, - ssm_state: torch.Tensor): + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams): # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -161,16 +168,18 @@ def forward(self, hidden_states: torch.Tensor, conv_weights, self.conv1d.bias, activation=self.activation, - conv_states=conv_state, + conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, query_start_loc=attn_metadata.query_start_loc) else: hidden_states = causal_conv1d_update( hidden_states.transpose(0, 1), - conv_state, + mamba_cache_params.conv_state, conv_weights, self.conv1d.bias, self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor ) hidden_states = hidden_states.transpose(0, 1) @@ -196,7 +205,7 @@ def forward(self, hidden_states: torch.Tensor, and attn_metadata.context_lens_tensor is not None: scan_outputs = selective_scan_fn( hidden_states, - ssm_state, + mamba_cache_params.ssm_state, discrete_time_step, self.A, B.transpose(-2, -1), @@ -205,11 +214,12 @@ def forward(self, hidden_states: torch.Tensor, gate, time_proj_bias, delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, has_initial_state=attn_metadata.context_lens_tensor > 0, query_start_loc=attn_metadata.query_start_loc) else: scan_outputs = selective_state_update( - ssm_state, + mamba_cache_params.ssm_state, hidden_states.transpose(0, 1), discrete_time_step.transpose(0, 1), self.A, @@ -219,6 +229,7 @@ def forward(self, hidden_states: torch.Tensor, gate.transpose(0, 1), time_proj_bias, dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor ) scan_outputs = scan_outputs.transpose(0, 1) @@ -315,8 +326,7 @@ def forward( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, **kwargs, ): if residual is None: @@ -326,8 +336,11 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, - ssm_state) + hidden_states = self.mamba( + hidden_states, + attn_metadata, + mamba_cache_params + ) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -484,17 +497,14 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None - for i in range(len(self.layers)): layer = self.layers[i] kv_cache = None - current_ssm_state = None - current_conv_state = None + layer_mamba_cache_params = None if isinstance(layer, JambaAttentionDecoderLayer): kv_cache = kv_caches[(i - self.config.attn_layer_offset) // self.config.attn_layer_period] @@ -502,8 +512,9 @@ def forward( current_state_layer = i - (1 + (i - self.config.attn_layer_offset) // self.config.attn_layer_period) - current_ssm_state = ssm_state[current_state_layer] - current_conv_state = conv_state[current_state_layer] + layer_mamba_cache_params = mamba_cache_params.at_layer_idx( + current_state_layer + ) hidden_states, residual = layer( positions=positions, @@ -511,8 +522,7 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, residual=residual, - conv_state=current_conv_state, - ssm_state=current_ssm_state, + mamba_cache_params=layer_mamba_cache_params ) hidden_states, _ = 
self.final_layernorm(hidden_states, residual) return hidden_states @@ -574,7 +584,8 @@ def __init__( self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple() # Maps between the request id and a dict that maps between the seq_id # and its index inside the self.mamba_cache - self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.cache_indices_mapping: Dict[str, Dict[int, int]] = {} + self.free_indices = list(range(self.scheduler_config.max_num_seqs)) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -593,154 +604,86 @@ def forward(self, # We get here only on Prefill/Eager mode runs request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] finished_requests_ids = kwargs["finished_requests_ids"] - mamba_cache = self._release_finished_and_prepare_mamba_cache( + state_indices = self._release_finished_and_prepare_mamba_cache( finished_requests_ids, request_ids_to_seq_ids) + state_indices_tensor = torch.as_tensor( + state_indices, + dtype=torch.int32, + device="cuda" + ) + mamba_cache = self.mamba_cache else: # CUDA graph capturing runs - mamba_cache = kwargs["seqlen_agnostic_capture_inputs"] + ( mamba_cache, state_indices_tensor + ) = kwargs["seqlen_agnostic_capture_inputs"] + mamba_cache_params = MambaCacheParams( + mamba_cache[0], + mamba_cache[1], + state_indices_tensor + ) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, mamba_cache[0], - mamba_cache[1]) + attn_metadata, mamba_cache_params) return hidden_states - def _swap_mamba_cache(self, from_index: int, to_index: int): - assert len(self.mamba_cache) > 0 - for cache_t in self.mamba_cache: - cache_t[:, [to_index,from_index]] = \ - cache_t[:, [from_index,to_index]] - def _copy_mamba_cache(self, from_index: int, to_index: int): assert len(self.mamba_cache) > 0 for cache_t in self.mamba_cache: cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) - def _move_out_if_already_occupied(self, index: int, - all_occupied_indices: List[int]): - if index in all_occupied_indices: - first_free_index = self._first_free_index_in_mamba_cache() - # In case occupied, move the occupied to a new empty block - self._move_cache_index_and_mappings(from_index=index, - to_index=first_free_index) - - def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, - seq_id: int, - destination_index: int): + def _assign_seq_id_to_cache_index( + self, + cur_rid: str, + seq_id: int, + finished_requests_ids + ) -> int: """ Assign (req_id,seq_id) pair to a `destination_index` index, if already occupied, move the occupying index to a free index. 
""" - all_occupied_indices = self._get_all_occupied_indices() - if cur_rid not in self.mamba_cache_indices_mapping: - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - self.mamba_cache_indices_mapping[cur_rid] = { - seq_id: destination_index - } + if cur_rid in finished_requests_ids: + # set as pad, do not allocate destination index + return PAD_SLOT_ID + elif cur_rid not in self.cache_indices_mapping: + destination_index = self.free_indices.pop() + self.cache_indices_mapping[cur_rid] = {seq_id : destination_index} + return destination_index elif seq_id not in (seq_ids2indices := - self.mamba_cache_indices_mapping[cur_rid]): + self.cache_indices_mapping[cur_rid]): # parallel sampling , where n > 1, assume prefill have # already happened now we only need to copy the already # existing cache into the siblings seq_ids caches - self._move_out_if_already_occupied( - index=destination_index, - all_occupied_indices=all_occupied_indices) - index_exists = list(seq_ids2indices.values())[0] + index_exists = next(iter(seq_ids2indices.values())) # case of decoding n>1, copy prefill cache to decoding indices + destination_index = self.free_indices.pop() self._copy_mamba_cache(from_index=index_exists, to_index=destination_index) - self.mamba_cache_indices_mapping[cur_rid][ + self.cache_indices_mapping[cur_rid][ seq_id] = destination_index + return destination_index else: # already exists - cache_index_already_exists = self.mamba_cache_indices_mapping[ + return self.cache_indices_mapping[ cur_rid][seq_id] - if cache_index_already_exists != destination_index: - # In case the seq id already exists but not in - # the right destination, swap it with what's occupying it - self._swap_pair_indices_and_mappings( - from_index=cache_index_already_exists, - to_index=destination_index) def _prepare_current_run_mamba_cache( self, request_ids_to_seq_ids: Dict[str, list[int]], finished_requests_ids: List[str] - ) -> Tuple[torch.Tensor, torch.Tensor]: - running_indices = [] - request_ids_to_seq_ids_flatten = [ - (req_id, seq_id) + ) -> List[int]: + return [ + self._assign_seq_id_to_cache_index( + req_id, + seq_id, + finished_requests_ids + ) for req_id, seq_ids in request_ids_to_seq_ids.items() for seq_id in seq_ids ] - batch_size = len(request_ids_to_seq_ids_flatten) - for dest_index, (request_id, - seq_id) in enumerate(request_ids_to_seq_ids_flatten): - if request_id in finished_requests_ids: - # Do not allocate cache index for requests that run - # and finish right after - continue - self._assign_seq_id_to_mamba_cache_in_specific_dest( - request_id, seq_id, dest_index) - running_indices.append(dest_index) - - self._clean_up_first_bs_blocks(batch_size, running_indices) - conv_state = self.mamba_cache[0][:, :batch_size] - temporal_state = self.mamba_cache[1][:, :batch_size] - - return (conv_state, temporal_state) - - def _get_all_occupied_indices(self): - return [ - cache_idx - for seq_ids2indices in self.mamba_cache_indices_mapping.values() - for cache_idx in seq_ids2indices.values() - ] - - def _clean_up_first_bs_blocks(self, batch_size: int, - indices_for_current_run: List[int]): - # move out all of the occupied but currently not running blocks - # outside of the first n blocks - destination_indices = range(batch_size) - max_possible_batch_size = self.mamba_cache[0].shape[1] - for destination_index in destination_indices: - if destination_index in self._get_all_occupied_indices() and \ - destination_index not in indices_for_current_run: - # move not running 
indices outside of the batch - all_other_indices = list( - range(batch_size, max_possible_batch_size)) - first_avail_index = self._first_free_index_in_mamba_cache( - all_other_indices) - self._swap_indices(from_index=destination_index, - to_index=first_avail_index) - - def _move_cache_index_and_mappings(self, from_index: int, to_index: int): - self._copy_mamba_cache(from_index=from_index, to_index=to_index) - self._update_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): - self._swap_mamba_cache(from_index=from_index, to_index=to_index) - self._swap_mapping_index(from_index=from_index, to_index=to_index) - - def _swap_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - elif to_index == index: - seq_ids2index.update({seq_id: from_index}) - - def _update_mapping_index(self, from_index: int, to_index: int): - for seq_ids2index in self.mamba_cache_indices_mapping.values(): - for seq_id, index in seq_ids2index.items(): - if from_index == index: - seq_ids2index.update({seq_id: to_index}) - return def _release_finished_and_prepare_mamba_cache( self, finished_requests_ids, - request_ids_to_seq_ids) -> Tuple[torch.Tensor, torch.Tensor]: + request_ids_to_seq_ids) -> List[int]: self._release_mamba_cache(finished_requests_ids) return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, finished_requests_ids) @@ -751,34 +694,41 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): that was provided during the capture runs (JambaForCausalLM.mamba_gc_cache_buffer). """ - self._release_finished_and_prepare_mamba_cache( + _, gc_state_indices_t = input_buffers["seqlen_agnostic_capture_inputs"] + state_indices = self._release_finished_and_prepare_mamba_cache( kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"]) + cuda_graph_pad_len = gc_state_indices_t.shape[0] - len(state_indices) + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + gc_state_indices_t.copy_(torch.as_tensor( + state_indices, + dtype=torch.int32, + device="cuda" + )) + + - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + def get_seqlen_agnostic_capture_inputs(self, batch_size): """ - Provide the CUDA graph capture runs with a buffer in adjusted size. + Provide the CUDA graph capture runs with a buffer. The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. """ - return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) + state_indices_tensor = torch.as_tensor( + [-1] * batch_size, + dtype=torch.int32, + device="cuda" + ) + return (self.mamba_cache, state_indices_tensor) def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: - if req_id in self.mamba_cache_indices_mapping: - self.mamba_cache_indices_mapping.pop(req_id) - - def _first_free_index_in_mamba_cache( - self, indices_range: Optional[List[int]] = None) -> int: - assert self.mamba_cache is not None - if indices_range is None: - max_possible_batch_size = self.mamba_cache[0].shape[1] - indices_range = list(range(max_possible_batch_size)) - all_occupied_indices = self._get_all_occupied_indices() - for i in indices_range: - if i not in all_occupied_indices: - return i - raise Exception("Couldn't find a free spot in the mamba cache! 
This" - "should never happen") + if req_id in self.cache_indices_mapping: + for seq_id in self.cache_indices_mapping[req_id]: + self.free_indices.append( + self.cache_indices_mapping[req_id][seq_id] + ) + self.cache_indices_mapping.pop(req_id) + def _get_mamba_cache_shape( self From 385c257c5fe628552a777b9ef139f0456bbdd2f1 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 9 Oct 2024 13:47:43 +0300 Subject: [PATCH 04/24] format --- tests/kernels/test_causal_conv1d.py | 61 ++++------ tests/kernels/test_mamba_ssm.py | 74 +++++------- .../layers/mamba/ops/mamba_ssm.py | 9 +- vllm/model_executor/models/jamba.py | 105 ++++++------------ 4 files changed, 85 insertions(+), 164 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 2c51e8cb0e1..9d04ee4b60d 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -279,19 +279,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048]) def test_causal_conv1d_update_with_batch_gather_padding_unchanged( - dim, - width, - seqlen, - has_bias, - silu_activation, - itype -): + dim, width, seqlen, has_bias, silu_activation, itype): device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - - # set )seed + # set seed seed_everything(0) batch = 64 @@ -304,11 +294,9 @@ def test_causal_conv1d_update_with_batch_gather_padding_unchanged( device=device, dtype=itype) conv_state_before = conv_state.clone() - conv_state_indices = torch.as_tensor( - [PAD_SLOT_ID] * batch, - dtype=torch.int32, - device=device - ) + conv_state_indices = torch.as_tensor([PAD_SLOT_ID] * batch, + dtype=torch.int32, + device=device) weight = torch.randn(dim, width, @@ -320,12 +308,12 @@ def test_causal_conv1d_update_with_batch_gather_padding_unchanged( else: bias = None activation = None if not silu_activation else "silu" - out = causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation, - conv_state_indices=conv_state_indices) + causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=conv_state_indices) assert torch.equal(conv_state, conv_state_before) @@ -487,21 +475,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', - [256]) +@pytest.mark.parametrize('seqlen', [256]) @pytest.mark.parametrize('dim', [64]) -def test_causal_conv1d_varlen_check_padding_unchanged( - dim, - seqlen, - width, - has_bias, - silu_activation, - itype -): +def test_causal_conv1d_varlen_check_padding_state_unchanged( + dim, seqlen, width, has_bias, silu_activation, itype): device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 # set seed seed_everything(0) batch = 1 @@ -534,10 +512,9 @@ def test_causal_conv1d_varlen_check_padding_unchanged( 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=x.device) - cache_indices = torch.as_tensor([-1] * (cumsum.shape[0] - 1), - dtype=torch.int32, - device=x.device) - causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + cache_indices = torch.as_tensor([PAD_SLOT_ID] * 
(cumsum.shape[0] - 1), + dtype=torch.int32, + device=x.device) + causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), cache_indices, + has_initial_states, final_states, activation) assert torch.equal(final_states, final_states_before) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 1180417cc95..1fa736e8332 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -515,7 +516,6 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, has_initial_state, prev_state) - @pytest.mark.parametrize('wtype', [torch.float32]) @pytest.mark.parametrize('itype', [torch.float32]) @pytest.mark.parametrize('seqlen', [256]) @@ -527,21 +527,13 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, @pytest.mark.parametrize("varBC_groups", [1]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan_varlen_padding_unchanged(is_variable_B, is_variable_C, - varBC_groups, has_D, has_z, - has_delta_bias, - delta_softplus, - return_last_state, seqlen, itype, wtype): +def test_selective_scan_varlen_padding_state_unchanged( + is_variable_B, is_variable_C, varBC_groups, has_D, has_z, + has_delta_bias, delta_softplus, return_last_state, seqlen, itype, + wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - rtolw, atolw = (1e-3, 1e-3) - if has_z: # If we have z, the errors on the weights seem higher - rtolw = max(rtolw, rtol) - atolw = max(atolw, atol) # set seed torch.random.manual_seed(0) seqlens = [] @@ -585,48 +577,34 @@ def test_selective_scan_varlen_padding_unchanged(is_variable_B, is_variable_C, requires_grad=False) prev_state_ref = prev_state.clone() cache_indices = torch.as_tensor([-1] * (cumsum.shape[0] - 1), - dtype=torch.int32, - device=u.device) + dtype=torch.int32, + device=u.device) has_initial_state = torch.randint(0, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=u.device) - out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, - has_initial_state) + selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, + delta_softplus, cumsum, cache_indices, has_initial_state) assert torch.equal(prev_state, prev_state_ref) - -@pytest.mark.parametrize("itype", - [torch.bfloat16]) +@pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16]) @pytest.mark.parametrize("dim", [2048]) def test_selective_state_update_with_batch_indices_padding_unchanged( - dim, - dstate, - has_z, - itype -): + dim, dstate, has_z, itype): device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) - if itype == torch.bfloat16: - rtol, atol = 7e-2, 7e-2 - if torch.version.hip: - atol *= 2 # set seed torch.random.manual_seed(0) batch_size = 3 total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) - state_indices = 
torch.as_tensor( - [-1] * batch_size, - dtype=torch.int32, - device=device - ) + state_indices = torch.as_tensor([PAD_SLOT_ID] * batch_size, + dtype=torch.int32, + device=device) x = torch.randn(batch_size, dim, device=device, dtype=itype) dt = torch.randn(batch_size, dim, device=device, dtype=itype) @@ -637,22 +615,20 @@ def test_selective_state_update_with_batch_indices_padding_unchanged( D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None state_ref = state.clone() - out = selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=state_indices) + selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) assert torch.equal(state_ref, state) - - @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index f1c43928431..9d1b94cfc6a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -146,11 +146,13 @@ def _selective_scan_update_kernel( if HAS_STATE_BATCH_INDICES: state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & (state_batch_idx != -1), + mask=(offs_m[:, None] < dim) & + (offs_n[None, :] < dstate) & (state_batch_idx != -1), other=0.0) else: state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + mask=(offs_m[:, None] < dim) & + (offs_n[None, :] < dstate), other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: @@ -185,7 +187,8 @@ def _selective_scan_update_kernel( if HAS_STATE_BATCH_INDICES: tl.store(state_ptrs, state, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & (state_batch_idx != -1)) + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & + (state_batch_idx != -1)) else: tl.store(state_ptrs, state, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 6b12aa02f07..511e2573cd5 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -17,7 +17,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -47,14 +46,13 @@ class MambaCacheParams: conv_state: torch.Tensor = torch.Tensor() ssm_state: torch.Tensor = torch.Tensor() state_indices_tensor: torch.Tensor = torch.Tensor() - - def at_layer_idx(self,layer_idx): + + def at_layer_idx(self, layer_idx): return MambaCacheParams(self.conv_state[layer_idx], self.ssm_state[layer_idx], self.state_indices_tensor) - # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ @@ -67,10 +65,9 @@ class JambaMambaMixer(nn.Module): **selective** state spaces) """ - def __init__(self, config: JambaConfig, layer_idx): + def __init__(self, config: JambaConfig): super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.ssm_state_size = config.mamba_d_state self.conv_kernel_size = config.mamba_d_conv @@ -90,8 +87,8 @@ def __init__(self, config: JambaConfig, layer_idx): self.conv1d.weight.data = 
self.conv1d.weight.data.unsqueeze(1) self.in_proj = ColumnParallelLinear(self.hidden_size, - self.intermediate_size * 2, - bias=self.use_bias) + self.intermediate_size * 2, + bias=self.use_bias) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( self.intermediate_size, @@ -179,8 +176,7 @@ def forward(self, hidden_states: torch.Tensor, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=mamba_cache_params.state_indices_tensor - ) + conv_state_indices=mamba_cache_params.state_indices_tensor) hidden_states = hidden_states.transpose(0, 1) # 3. State Space Model sequence transformation @@ -229,8 +225,7 @@ def forward(self, hidden_states: torch.Tensor, gate.transpose(0, 1), time_proj_bias, dt_softplus=True, - state_batch_indices=mamba_cache_params.state_indices_tensor - ) + state_batch_indices=mamba_cache_params.state_indices_tensor) scan_outputs = scan_outputs.transpose(0, 1) # 4. Final linear projection @@ -311,7 +306,7 @@ def __init__(self, super().__init__() self.layer_idx = layer_idx self.config = config - self.mamba = JambaMambaMixer(config, layer_idx) + self.mamba = JambaMambaMixer(config) num_experts = config.layers_num_experts[layer_idx] ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP @@ -336,11 +331,8 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba( - hidden_states, - attn_metadata, - mamba_cache_params - ) + hidden_states = self.mamba(hidden_states, attn_metadata, + mamba_cache_params) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -513,8 +505,7 @@ def forward( (i - self.config.attn_layer_offset) // self.config.attn_layer_period) layer_mamba_cache_params = mamba_cache_params.at_layer_idx( - current_state_layer - ) + current_state_layer) hidden_states, residual = layer( positions=positions, @@ -522,8 +513,7 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, residual=residual, - mamba_cache_params=layer_mamba_cache_params - ) + mamba_cache_params=layer_mamba_cache_params) hidden_states, _ = self.final_layernorm(hidden_states, residual) return hidden_states @@ -606,22 +596,17 @@ def forward(self, finished_requests_ids = kwargs["finished_requests_ids"] state_indices = self._release_finished_and_prepare_mamba_cache( finished_requests_ids, request_ids_to_seq_ids) - state_indices_tensor = torch.as_tensor( - state_indices, - dtype=torch.int32, - device="cuda" - ) + state_indices_tensor = torch.as_tensor(state_indices, + dtype=torch.int32, + device="cuda") mamba_cache = self.mamba_cache else: # CUDA graph capturing runs - ( mamba_cache, state_indices_tensor - ) = kwargs["seqlen_agnostic_capture_inputs"] + (mamba_cache, + state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"] - mamba_cache_params = MambaCacheParams( - mamba_cache[0], - mamba_cache[1], - state_indices_tensor - ) + mamba_cache_params = MambaCacheParams(mamba_cache[0], mamba_cache[1], + state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params) return hidden_states @@ -632,12 +617,8 @@ def _copy_mamba_cache(self, from_index: int, to_index: int): cache_t[:, to_index].copy_(cache_t[:, from_index], non_blocking=True) - def _assign_seq_id_to_cache_index( - self, - cur_rid: str, - seq_id: int, - finished_requests_ids - ) -> int: + def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, + finished_requests_ids) -> int: """ Assign (req_id,seq_id) pair to a 
`destination_index` index, if already occupied, move the occupying index to a free index. @@ -647,7 +628,7 @@ def _assign_seq_id_to_cache_index( return PAD_SLOT_ID elif cur_rid not in self.cache_indices_mapping: destination_index = self.free_indices.pop() - self.cache_indices_mapping[cur_rid] = {seq_id : destination_index} + self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} return destination_index elif seq_id not in (seq_ids2indices := self.cache_indices_mapping[cur_rid]): @@ -659,31 +640,24 @@ def _assign_seq_id_to_cache_index( destination_index = self.free_indices.pop() self._copy_mamba_cache(from_index=index_exists, to_index=destination_index) - self.cache_indices_mapping[cur_rid][ - seq_id] = destination_index + self.cache_indices_mapping[cur_rid][seq_id] = destination_index return destination_index else: # already exists - return self.cache_indices_mapping[ - cur_rid][seq_id] + return self.cache_indices_mapping[cur_rid][seq_id] def _prepare_current_run_mamba_cache( self, request_ids_to_seq_ids: Dict[str, list[int]], - finished_requests_ids: List[str] - ) -> List[int]: + finished_requests_ids: List[str]) -> List[int]: return [ - self._assign_seq_id_to_cache_index( - req_id, - seq_id, - finished_requests_ids - ) + self._assign_seq_id_to_cache_index(req_id, seq_id, + finished_requests_ids) for req_id, seq_ids in request_ids_to_seq_ids.items() for seq_id in seq_ids ] def _release_finished_and_prepare_mamba_cache( - self, finished_requests_ids, - request_ids_to_seq_ids) -> List[int]: + self, finished_requests_ids, request_ids_to_seq_ids) -> List[int]: self._release_mamba_cache(finished_requests_ids) return self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, finished_requests_ids) @@ -698,14 +672,9 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): state_indices = self._release_finished_and_prepare_mamba_cache( kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"]) cuda_graph_pad_len = gc_state_indices_t.shape[0] - len(state_indices) - state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) - gc_state_indices_t.copy_(torch.as_tensor( - state_indices, - dtype=torch.int32, - device="cuda" - )) - - + state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) + gc_state_indices_t.copy_( + torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) def get_seqlen_agnostic_capture_inputs(self, batch_size): """ @@ -713,11 +682,9 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size): The buffer is used to maintain the Mamba Cache during the CUDA graph replay runs. 
""" - state_indices_tensor = torch.as_tensor( - [-1] * batch_size, - dtype=torch.int32, - device="cuda" - ) + state_indices_tensor = torch.as_tensor([-1] * batch_size, + dtype=torch.int32, + device="cuda") return (self.mamba_cache, state_indices_tensor) def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): @@ -725,11 +692,9 @@ def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): if req_id in self.cache_indices_mapping: for seq_id in self.cache_indices_mapping[req_id]: self.free_indices.append( - self.cache_indices_mapping[req_id][seq_id] - ) + self.cache_indices_mapping[req_id][seq_id]) self.cache_indices_mapping.pop(req_id) - def _get_mamba_cache_shape( self ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]: From 927be2cba8b8984a3da753877e07c8c910b1988d Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 9 Oct 2024 13:56:01 +0300 Subject: [PATCH 05/24] format --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 542578b61fe..a6538403eec 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -338,7 +338,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { : reinterpret_cast(params.cache_indices_ptr); int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - // cache_index == -1 is defined as padding and cache should'nt been written/read + // cache_index == -1 is defined as padding and cache shouldn't been written/read input_t *conv_states = params.conv_states_ptr == nullptr || cache_index == -1 ? nullptr : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; From 5bc07fb0110281ebd5221f28fea8a8b06db31b43 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 17:36:46 +0300 Subject: [PATCH 06/24] Add pad_slot_id to kernels params --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 40 +++++++++---- csrc/mamba/causal_conv1d/causal_conv1d.h | 1 + csrc/mamba/mamba_ssm/selective_scan.h | 1 + csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 33 ++++++----- csrc/ops.h | 6 +- csrc/torch_bindings.cpp | 9 ++- vllm/_custom_ops.py | 58 +++++++++++-------- .../layers/mamba/ops/causal_conv1d.py | 42 +++++++++----- .../layers/mamba/ops/mamba_ssm.py | 33 ++++++++--- 9 files changed, 143 insertions(+), 80 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index a6538403eec..0235f4248bc 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -55,10 +55,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor out, const c10::optional& bias, bool silu_activation, + int64_t pad_slot_id, const c10::optional& query_start_loc = std::nullopt, const c10::optional& cache_indices = std::nullopt, - const c10::optional& has_initial_state = std::nullopt) { - + const c10::optional& has_initial_state = std::nullopt + ) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, params.dim = dim; params.seqlen = seqlen; params.width = width; + params.pad_slot_id = pad_slot_id; params.silu_activation = silu_activation; @@ -97,7 +99,10 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &query_start_loc, const c10::optional &cache_indices, const 
c10::optional<at::Tensor> &has_initial_state,
-                  bool silu_activation) {
+                  bool silu_activation,
+                  // used to identify padding entries if cache_indices provided
+                  // incase of padding, the kernel will return early
+                  int64_t pad_slot_id) {
     auto input_type = x.scalar_type();
     auto weight_type = weight.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -159,6 +164,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
     set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                         bias_,
                         silu_activation,
+                        pad_slot_id,
                         query_start_loc,
                         cache_indices,
                         has_initial_state
@@ -194,7 +200,10 @@ causal_conv1d_update(const at::Tensor &x,
                      const c10::optional<at::Tensor> &bias_,
                      bool silu_activation,
                      const c10::optional<at::Tensor> &cache_seqlens_,
-                     const c10::optional<at::Tensor> &conv_state_indices_) {
+                     const c10::optional<at::Tensor> &conv_state_indices_,
+                     // used to identify padding entries if cache_indices provided
+                     // incase of padding, the kernel will return early
+                     int64_t pad_slot_id) {
     auto input_type = x.scalar_type();
     auto weight_type = weight.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
@@ -232,7 +241,9 @@ causal_conv1d_update(const at::Tensor &x,
     ConvParamsBase params;
     set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                         bias_,
-                        silu_activation);
+                        silu_activation,
+                        pad_slot_id
+                        );
     params.conv_state_ptr = conv_state.data_ptr();
     params.conv_state_len = conv_state_len;
     // All stride are in elements, not bytes.
@@ -334,17 +345,19 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
         + channel_id * params.out_c_stride;
     float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
 
+    bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
+        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
+
     int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
     : reinterpret_cast<int *>(params.cache_indices_ptr);
     int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
-
-    // cache_index == -1 is defined as padding and cache shouldn't been written/read
-    input_t *conv_states = params.conv_states_ptr == nullptr || cache_index == -1 ? nullptr
+    // cache_index == params.pad_slot_id is defined as padding, so we exit early
+    if (cache_index == params.pad_slot_id){
+        return;
+    }
+    input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
         : reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
-    bool has_initial_state = params.has_initial_state_ptr == nullptr || conv_states == nullptr ? false
-        : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id] ;
-
     // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
     if (tidx == 0) {
         input_t initial_state[kNElts] = {0};
@@ -529,7 +542,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
     const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
         ? batch_id
batch_id : params.conv_state_indices_ptr[batch_id]; - if (conv_state_batch_coord == -1) return; + // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early + if (conv_state_batch_coord == params.pad_slot_id){ + return; + } input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 49e37ee4528..e26684a2b98 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -13,6 +13,7 @@ struct ConvParamsBase { using index_t = uint32_t; int batch, dim, seqlen, width; + int64_t pad_slot_id; bool silu_activation; index_t x_batch_stride; diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 580d0b2e17e..563d2fe4ef6 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -21,6 +21,7 @@ struct SSMParamsBase { int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; + int64_t pad_slot_id; bool delta_softplus; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index c3d49546cf8..f5857829025 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -109,14 +109,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { sequence_start_index = query_start_loc[batch_id]; seqlen = query_start_loc[batch_id + 1] - sequence_start_index; } + const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - const bool wr_cache = cache_index != -1; - const bool has_initial_state = params.has_initial_state_ptr == nullptr && wr_cache ? false - : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; - - + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride @@ -252,7 +254,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. 
if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; - if (chunk == n_chunks - 1 && wr_cache) { + if (chunk == n_chunks - 1) { ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); } } @@ -389,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const size_t seqlen, const size_t dstate, const size_t n_groups, - const size_t n_chunks, const bool is_variable_B, const bool is_variable_C, // device pointers @@ -409,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - bool varlen) { + bool varlen, + int64_t pad_slot_id) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -419,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.seqlen = seqlen; params.dstate = dstate; params.n_groups = n_groups; - params.n_chunks = n_chunks; params.dim_ngroups_ratio = dim / n_groups; + params.pad_slot_id = pad_slot_id; params.delta_softplus = delta_softplus; @@ -509,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - const torch::Tensor &ssm_states) { + const torch::Tensor &ssm_states, + // used to identify padding entries if cache_indices provided + // incase of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -620,9 +625,6 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, out_z = z; - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; TORCH_CHECK(ssm_states.scalar_type() == input_type); @@ -630,7 +632,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, TORCH_CHECK(ssm_states.stride(-1) == 1); SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, D_, delta_bias_, @@ -640,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, query_start_loc, cache_indices, has_initial_state, - varlen + varlen, + pad_slot_id ); diff --git a/csrc/ops.h b/csrc/ops.h index fce545f95a7..8bb1252a313 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -157,13 +157,13 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - const torch::Tensor& ssm_states); + const torch::Tensor& ssm_states, int64_t pad_slot_id); at::Tensor causal_conv1d_update( const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, const c10::optional& bias_, bool silu_activation, const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_); + const c10::optional& conv_state_indices_, int64_t pad_slot_id); at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, @@ -171,7 +171,7 @@ at::Tensor 
causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - bool silu_activation); + bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a0100b4a85e..d2360d7d228 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -278,7 +278,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "Tensor! ssm_states) -> ()"); + "Tensor! ssm_states," + "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -288,7 +289,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "bool silu_activation," "Tensor? cache_seqlens_," - "Tensor? conv_state_indices) -> Tensor"); + "Tensor? conv_state_indices," + "int pad_slot_id) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( @@ -298,7 +300,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "bool silu_activation) -> Tensor"); + "bool silu_activation," + "int pad_slot_id) -> Tensor"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 24e008dc380..3326752a9de 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -453,15 +453,18 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: + silu_activation: bool, + pad_slot_id: int) -> torch.Tensor: return torch.empty_like(x) @torch.library.register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: + def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, + bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int) -> torch.Tensor: return torch.empty_like(x) @torch.library.register_fake("_C::selective_scan_fwd") @@ -474,7 +477,8 @@ def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - ssm_states: Optional[torch.Tensor]) -> None: + ssm_states: Optional[torch.Tensor], + pad_slot_id: int) -> None: return None @@ -789,33 +793,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, query_start_loc: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: + silu_activation: bool, pad_slot_id: int) -> torch.Tensor: return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, query_start_loc, cache_indices, - has_initial_state, silu_activation) + has_initial_state, silu_activation, + pad_slot_id) -def causal_conv1d_update( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - 
cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int) -> torch.Tensor: return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, silu_activation, cache_seqlens, - conv_state_indices) - - -def selective_scan_fwd( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): + conv_state_indices, pad_slot_id) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) # moe diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index ed7241af6cd..fffd3704432 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -6,18 +6,18 @@ import torch from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID -def causal_conv1d_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen sequences are concatenated from left to right for varlen @@ -37,6 +37,13 @@ def causal_conv1d_fn( conv_states: (...,dim,width - 1) itype updated inplace if provided activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) """ @@ -48,7 +55,7 @@ def causal_conv1d_fn( out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, cache_indices, has_initial_state, activation - in ["silu", "swish"]) + in ["silu", "swish"], pad_slot_id) return out @@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor, bias: Optional[torch.Tensor] = None, activation: 
Optional[str] = None, cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None): + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim) or (batch, dim, seqlen) conv_state: (batch, dim, state_len), where state_len >= width - 1 @@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor, If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. - + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: @@ -83,7 +96,8 @@ def causal_conv1d_update(x: torch.Tensor, if unsqueeze: x = x.unsqueeze(-1) out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, - cache_seqlens, conv_state_indices) + cache_seqlens, conv_state_indices, + pad_slot_id) if unsqueeze: out = out.squeeze(-1) return out diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 9d1b94cfc6a..148125e0045 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -9,6 +9,7 @@ from packaging import version from vllm import _custom_ops as ops +from vllm.attention.backends.utils import PAD_SLOT_ID TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") @@ -50,6 +51,7 @@ def _selective_scan_update_kernel( z_ptr, out_ptr, state_batch_indices_ptr, + pad_slot_id, # Matrix dimensions batch, nheads, @@ -147,7 +149,8 @@ def _selective_scan_update_kernel( if HAS_STATE_BATCH_INDICES: state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & - (offs_n[None, :] < dstate) & (state_batch_idx != -1), + (offs_n[None, :] < dstate) & + (state_batch_idx != pad_slot_id), other=0.0) else: state = tl.load(state_ptrs, @@ -188,7 +191,7 @@ def _selective_scan_update_kernel( tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & - (state_batch_idx != -1)) + (state_batch_idx != pad_slot_id)) else: tl.store(state_ptrs, state, @@ -211,7 +214,8 @@ def selective_state_update(state, z=None, dt_bias=None, dt_softplus=False, - state_batch_indices=None): + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) @@ -223,6 +227,12 @@ def selective_state_update(state, D: (dim,) or (nheads, dim) z: (batch, dim) or (batch, nheads, dim) dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 Return: out: (batch, dim) or (batch, nheads, dim) """ @@ -289,6 +299,7 @@ def selective_state_update(state, z, out, state_batch_indices, + pad_slot_id, batch, nheads, dim, @@ -345,9 +356,13 @@ def selective_scan_fn( delta_softplus=False, query_start_loc=None, cache_indices=None, - has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]: + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) + applies changes in 
place. + ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) + applies changes in place. delta: (dim, total_length) for varlen or (batch, dim, seqlen) A: (dim, dstate) B: (ngroups, dstate, total_length) for varlen or @@ -370,12 +385,14 @@ def selective_scan_fn( indicate if the ssm_state at the corresponding index should be used as initial state. Not providing argument assumes there's no initial state - + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 returns output: (dim, total_length) for varlen or (batch, dim, seqlen) supports inplace replacement - last_state has shape (batch, dim, dstate). - supports inplace replacement if ssm_state was provided """ if u.stride(-1) != 1: u = u.contiguous() @@ -400,7 +417,7 @@ def selective_scan_fn( ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) if z is None: return delta # output written inplace to delta From 6c9b043ac2edeb217957176e8eda7c8c17085858 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 17:36:59 +0300 Subject: [PATCH 07/24] revert merged column parallel --- vllm/model_executor/models/jamba.py | 46 ++++++++++++++++------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 511e2573cd5..e0a37f26de6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -86,9 +87,9 @@ def __init__(self, config: JambaConfig): # doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - self.in_proj = ColumnParallelLinear(self.hidden_size, - self.intermediate_size * 2, - bias=self.use_bias) + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) # selective projection used to make dt, B and C input dependent self.x_proj = RowParallelLinear( self.intermediate_size, @@ -575,7 +576,7 @@ def __init__( # Maps between the request id and a dict that maps between the seq_id # and its index inside the self.mamba_cache self.cache_indices_mapping: Dict[str, Dict[int, int]] = {} - self.free_indices = list(range(self.scheduler_config.max_num_seqs)) + self.free_cache_indices = list(range(self._get_max_batch_size())) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() @@ -627,17 +628,17 @@ def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, # set as pad, do not allocate destination index return PAD_SLOT_ID elif cur_rid not in self.cache_indices_mapping: - destination_index = self.free_indices.pop() + destination_index = self.free_cache_indices.pop() self.cache_indices_mapping[cur_rid] = {seq_id: destination_index} return destination_index elif seq_id not in (seq_ids2indices := self.cache_indices_mapping[cur_rid]): # parallel sampling , where n > 1, assume prefill have - # already happened now we only need to copy the already + # already happened, so 
we copy the # existing cache into the siblings seq_ids caches index_exists = next(iter(seq_ids2indices.values())) # case of decoding n>1, copy prefill cache to decoding indices - destination_index = self.free_indices.pop() + destination_index = self.free_cache_indices.pop() self._copy_mamba_cache(from_index=index_exists, to_index=destination_index) self.cache_indices_mapping[cur_rid][seq_id] = destination_index @@ -664,25 +665,25 @@ def _release_finished_and_prepare_mamba_cache( def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): """ - Copy the relevant Mamba cache into the CUDA graph input buffer - that was provided during the capture runs - (JambaForCausalLM.mamba_gc_cache_buffer). + Copy the relevant state_indices into the CUDA graph input buffer """ - _, gc_state_indices_t = input_buffers["seqlen_agnostic_capture_inputs"] + _, input_state_indices_buffer = input_buffers[ + "seqlen_agnostic_capture_inputs"] state_indices = self._release_finished_and_prepare_mamba_cache( kwargs["finished_requests_ids"], kwargs["request_ids_to_seq_ids"]) - cuda_graph_pad_len = gc_state_indices_t.shape[0] - len(state_indices) + cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( + state_indices) state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) - gc_state_indices_t.copy_( + + input_state_indices_buffer.copy_( torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) def get_seqlen_agnostic_capture_inputs(self, batch_size): """ - Provide the CUDA graph capture runs with a buffer. - The buffer is used to maintain the Mamba Cache during the CUDA graph - replay runs. + Provide the CUDA graph capture runs with a state_indices buffer. + will be used during the CUDA graph decode runs. """ - state_indices_tensor = torch.as_tensor([-1] * batch_size, + state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, dtype=torch.int32, device="cuda") return (self.mamba_cache, state_indices_tensor) @@ -691,7 +692,7 @@ def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.cache_indices_mapping: for seq_id in self.cache_indices_mapping[req_id]: - self.free_indices.append( + self.free_cache_indices.append( self.cache_indices_mapping[req_id][seq_id]) self.cache_indices_mapping.pop(req_id) @@ -710,14 +711,17 @@ def _get_mamba_cache_shape( ) return conv_state_shape, temporal_state_shape + def _get_max_batch_size(self): + return (_get_graph_batch_size( + self.scheduler_config.max_num_seqs + ) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) + def _prepare_mamba_cache(self): dtype = self.lm_head.weight.dtype layers_type = self.config.layers_block_type mamba_layers = sum( [layer_type == "mamba" for layer_type in layers_type]) - max_batch_size = (_get_graph_batch_size( - self.scheduler_config.max_num_seqs) if self.scheduler_config else - max(_BATCH_SIZES_TO_CAPTURE) + 2) + max_batch_size = self._get_max_batch_size() conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape() assert conv_state_shape is not None and temporal_state_shape is not None From 3a4d02b5e16c47a228bab5f0a4ac35eaa6063918 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 17:37:14 +0300 Subject: [PATCH 08/24] Add TP=2 test --- .../decoder_only/language/test_jamba.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 408d12cd5ff..91000677b36 100644 --- 
a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,5 +1,6 @@ import pytest +from tests.utils import multi_gpu_test from vllm.sampling_params import SamplingParams from vllm.worker.model_runner import _get_graph_batch_size @@ -270,6 +271,34 @@ def test_state_cleanup( "could be related to finished_requests_ids") +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_jamba_distributed_produces_identical_generation( + vllm_runner, + model: str, + dtype: str, + max_tokens: int, + example_prompts +) -> None: + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_tp_1, + outputs_1_lst=vllm_outputs_tp_2, + name_0="vllm_tp_1", + name_1="vllm_tp_2", + ) + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_model_print( From c934c30ec9008bfe0379f4958d5483c57b049eea Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 17:37:35 +0300 Subject: [PATCH 09/24] Add with padding params to mamba tests --- tests/kernels/test_mamba_ssm.py | 260 +++++++++++++------------------- 1 file changed, 107 insertions(+), 153 deletions(-) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 1fa736e8332..5e8fd155e21 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -175,7 +175,8 @@ def selective_scan_opcheck_fn(u, cu_seq_len=None, cache_indices=None, has_initial_state=None, - ssm_states=None): + ssm_states=None, + pad_slot_id=PAD_SLOT_ID): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -204,7 +205,7 @@ def selective_scan_opcheck_fn(u, # a bogus error. 
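The padded-batch test variants added in this patch all build their index tensors the same way: the real cache slots come first and PAD_SLOT_ID entries are appended to pad the batch out. A minimal sketch of that setup (the sizes and tensor names here are illustrative, not taken from any one test):

    import torch

    PAD_SLOT_ID = -1  # assumption: mirrors vllm.attention.backends.utils.PAD_SLOT_ID

    total_entries, batch_size, padding = 30, 3, 5
    state_indices = torch.randperm(total_entries)[:batch_size].to(torch.int32)
    padded_state_indices = torch.cat([
        state_indices,
        torch.full((padding,), PAD_SLOT_ID, dtype=torch.int32),
    ])
    # positions batch_size .. batch_size + padding - 1 must be skipped by the kernel

The first batch_size entries are then compared against the reference implementation, while the slots not listed in state_indices are asserted to be unchanged after the call.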
opcheck(torch.ops._C.selective_scan_fwd, (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states), + cache_indices, has_initial_state, ssm_states, pad_slot_id), test_utils=["test_schema", "test_faketensor"]) @@ -405,7 +406,8 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, +@pytest.mark.parametrize("with_padding", [False, True]) +def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, delta_softplus, return_last_state, seqlen, itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): @@ -421,21 +423,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, # set seed torch.random.manual_seed(0) seqlens = [] - nsplits = 3 + batch_size = 4 if seqlen < 10: - nsplits = 0 + if with_padding: + pytest.skip() + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + + nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( torch.cat( [torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) + + print(seqlens) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0).cuda() + cumsum = torch.concat([ + torch.tensor([0], dtype=torch.int32), + cumsum + ], dim=0).cuda() + print(cumsum) dim = 4 dstate = 8 @@ -463,22 +477,39 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_ref = delta.clone() out = None out_ref = None - prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) + + + prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) prev_state = torch.randn(prev_state_shape, device=u.device, dtype=itype, requires_grad=False) prev_state_ref = prev_state.clone() - cache_indices = torch.randperm(cumsum.shape[0] - 1, + prev_state_for_padding_test = prev_state.clone() + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=u.device) - + device=u.device)[:batch_size] + unused_states_bool = torch.ones( + total_entries, + dtype=torch.bool, + device=device + ) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, + dtype=torch.int32, + device=device + ), + ],dim=-1) + has_initial_state = torch.randint(0, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=u.device) out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state) outs_ref = [] splits = [ @@ -487,147 +518,50 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, ] for i in range(len(seqlens[0])): u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits] - out_ref_s, _ = selective_scan_ref( - u_s, - delta_s, - A_ref, - B_s, - C_s, - D_ref, - z=z_s, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - return_last_state=return_last_state, - prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0) - if 
has_initial_state[i] else None, - final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0)) + if padded_state_indices[i] == PAD_SLOT_ID: + out_ref_s = z_s if has_z else delta_s + else: + out_ref_s, _ = selective_scan_ref( + u_s, + delta_s, + A_ref, + B_s, + C_s, + D_ref, + z=z_s, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) + if has_initial_state[i] else None, + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0)) outs_ref.append(out_ref_s) out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] + # test "real" entries are correct print("Output diff max", (out - out_ref[0]).max()) print("Output diff mean", (out - out_ref[0]).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) + # test padded entries are correct + if with_padding: + assert torch.equal( + prev_state[unused_states_bool], + prev_state_for_padding_test[unused_states_bool] + ) + if has_z: + assert torch.equal(out[batch_size + 1:], z[batch_size + 1:]) + else: + assert torch.equal(out[batch_size + 1:], delta[batch_size + 1:]) + selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state, prev_state) -@pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', [torch.float32]) -@pytest.mark.parametrize('seqlen', [256]) -@pytest.mark.parametrize("return_last_state", [True]) -@pytest.mark.parametrize('has_delta_bias', [True]) -@pytest.mark.parametrize('delta_softplus', [True]) -@pytest.mark.parametrize('has_z', [True]) -@pytest.mark.parametrize('has_D', [True]) -@pytest.mark.parametrize("varBC_groups", [1]) -@pytest.mark.parametrize("is_variable_C", [True]) -@pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan_varlen_padding_state_unchanged( - is_variable_B, is_variable_C, varBC_groups, has_D, has_z, - has_delta_bias, delta_softplus, return_last_state, seqlen, itype, - wtype): - if varBC_groups > 1 and (not is_variable_B or not is_variable_C): - pytest.skip() # This config is not applicable - device = 'cuda' - # set seed - torch.random.manual_seed(0) - seqlens = [] - nsplits = 3 - if seqlen < 10: - nsplits = 0 - eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values - seqlens.append( - torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) - assert sum(seqlens[-1]) == seqlen - assert all(s > 0 for s in seqlens[-1]) - - cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0).cuda() - - dim = 4 - dstate = 8 - A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) - B_shape = [varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, - device=device, - dtype=wtype if not is_variable_B else itype) - C_shape = [varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, - device=device, - dtype=wtype if not is_variable_C else itype) - D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None - z = torch.randn(dim, seqlen, device=device, dtype=itype) - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - ) if 
has_delta_bias else None - u = torch.randn(dim, seqlen, device=device, dtype=itype) - delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)) - prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) - prev_state = torch.randn(prev_state_shape, - device=u.device, - dtype=itype, - requires_grad=False) - prev_state_ref = prev_state.clone() - cache_indices = torch.as_tensor([-1] * (cumsum.shape[0] - 1), - dtype=torch.int32, - device=u.device) - - has_initial_state = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=u.device) - selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, has_initial_state) - assert torch.equal(prev_state, prev_state_ref) - - -@pytest.mark.parametrize("itype", [torch.bfloat16]) -@pytest.mark.parametrize("has_z", [True]) -@pytest.mark.parametrize("dstate", [16]) -@pytest.mark.parametrize("dim", [2048]) -def test_selective_state_update_with_batch_indices_padding_unchanged( - dim, dstate, has_z, itype): - device = "cuda" - # set seed - torch.random.manual_seed(0) - batch_size = 3 - - total_entries = 10 * batch_size - state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) - state_indices = torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device=device) - - x = torch.randn(batch_size, dim, device=device, dtype=itype) - dt = torch.randn(batch_size, dim, device=device, dtype=itype) - dt_bias = torch.rand(dim, device=device) - 4.0 - A = -torch.rand(dim, dstate, device=device) - 1.0 - B = torch.randn(batch_size, dstate, device=device) - C = torch.randn(batch_size, dstate, device=device) - D = torch.randn(dim, device=device) - z = torch.randn_like(x) if has_z else None - state_ref = state.clone() - selective_state_update(state, - x, - dt, - A, - B, - C, - D=D, - z=z, - dt_bias=dt_bias, - dt_softplus=True, - state_batch_indices=state_indices) - assert torch.equal(state_ref, state) - @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @@ -644,21 +578,30 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): # set seed torch.random.manual_seed(0) batch_size = 3 - + padding = 5 + padded_batch_size = batch_size + padding total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( dtype=torch.int32, device=device) - - x = torch.randn(batch_size, dim, device=device, dtype=itype) - dt = torch.randn(batch_size, dim, device=device, dtype=itype) + unused_states_bool= torch.ones(total_entries,dtype=torch.bool,device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([state_indices,torch.as_tensor( + [PAD_SLOT_ID] * padding, + dtype=torch.int32, + device=device + )],dim=0) + + x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 - B = torch.randn(batch_size, dstate, device=device) - C = torch.randn(batch_size, dstate, device=device) + B = torch.randn(padded_batch_size, dstate, device=device) + C = torch.randn(padded_batch_size, dstate, device=device) D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None - state_ref = state[state_indices, :].detach().clone() + state_ref = state[state_indices, :].clone() + state_before = state.clone() out = 
selective_state_update(state, x, dt, @@ -669,15 +612,15 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=padded_state_indices) out_ref = selective_state_update_ref(state_ref, - x, - dt, + x[:batch_size], + dt[:batch_size], A, - B, - C, + B[:batch_size], + C[:batch_size], D=D, - z=z, + z=z[:batch_size], dt_bias=dt_bias, dt_softplus=True) @@ -686,11 +629,22 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): print("Output state diff max", (state[state_indices, :] - state_ref).max()) print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) + # test padded entries stay the same + assert torch.equal( + state_before[unused_states_bool], + state[unused_states_bool] + ) + assert torch.equal(x[batch_size + 1:],x[batch_size + 1:]) + assert torch.equal(dt[batch_size + 1:],dt[batch_size + 1:]) + assert torch.equal(B[batch_size + 1:],B[batch_size + 1:]) + assert torch.equal(C[batch_size + 1:],C[batch_size + 1:]) + + # test "real" entries assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", From 8a6626c4d41e83f1ada791dcecd4567dbf55c519 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 17:39:58 +0300 Subject: [PATCH 10/24] Format --- tests/kernels/test_causal_conv1d.py | 56 +++++---------- tests/kernels/test_mamba_ssm.py | 72 +++++++++---------- .../decoder_only/language/test_jamba.py | 8 +-- .../layers/mamba/ops/mamba_ssm.py | 31 ++++---- vllm/model_executor/models/jamba.py | 5 +- 5 files changed, 68 insertions(+), 104 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 9d04ee4b60d..65ff1d4bac2 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -115,16 +115,15 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_opcheck_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -142,16 +141,9 @@ def causal_conv1d_opcheck_fn( x = x.contiguous() bias = bias.contiguous() if bias is not None else None - opcheck(torch.ops._C.causal_conv1d_fwd, ( - x, - weight, - bias, - conv_states, - cu_seq_len, - cache_indices, - has_initial_state, - activation in ["silu", "swish"], - )) + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, conv_states, cu_seq_len, cache_indices, + has_initial_state, activation in ["silu", "swish"], pad_slot_id)) @pytest.mark.parametrize("itype", [torch.bfloat16, 
torch.float]) @@ -261,15 +253,9 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - None, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, None, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", [torch.bfloat16]) @@ -373,15 +359,9 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - conv_state_indices, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, conv_state_indices, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", [torch.bfloat16]) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 5e8fd155e21..b153142487f 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -407,9 +407,10 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) @pytest.mark.parametrize("with_padding", [False, True]) -def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC_groups, - has_D, has_z, has_delta_bias, delta_softplus, - return_last_state, seqlen, itype, wtype): +def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, + varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, + itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -445,10 +446,8 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([ - torch.tensor([0], dtype=torch.int32), - cumsum - ], dim=0).cuda() + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], + dim=0).cuda() print(cumsum) dim = 4 @@ -477,7 +476,6 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC delta_ref = delta.clone() out = None out_ref = None - prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) prev_state = torch.randn(prev_state_shape, @@ -489,21 +487,17 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[:batch_size] - unused_states_bool = torch.ones( - total_entries, - dtype=torch.bool, - device=device - ) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) unused_states_bool[state_indices] = False padded_state_indices = torch.concat([ state_indices, torch.as_tensor( - [PAD_SLOT_ID] * padding, - dtype=torch.int32, - device=device - ), - ],dim=-1) - + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) + has_initial_state = torch.randint(0, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, @@ -534,7 +528,8 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, 
varBC return_last_state=return_last_state, prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) if has_initial_state[i] else None, - final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze(0)) + final_state_out=prev_state_ref[ + padded_state_indices[i]].unsqueeze(0)) outs_ref.append(out_ref_s) out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] @@ -547,22 +542,18 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) # test padded entries are correct if with_padding: - assert torch.equal( - prev_state[unused_states_bool], - prev_state_for_padding_test[unused_states_bool] - ) + assert torch.equal(prev_state[unused_states_bool], + prev_state_for_padding_test[unused_states_bool]) if has_z: assert torch.equal(out[batch_size + 1:], z[batch_size + 1:]) else: assert torch.equal(out[batch_size + 1:], delta[batch_size + 1:]) - selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cumsum, padded_state_indices, has_initial_state, prev_state) - @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("has_z", [True]) @@ -584,13 +575,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( dtype=torch.int32, device=device) - unused_states_bool= torch.ones(total_entries,dtype=torch.bool,device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) unused_states_bool[state_indices] = False - padded_state_indices = torch.concat([state_indices,torch.as_tensor( - [PAD_SLOT_ID] * padding, - dtype=torch.int32, - device=device - )],dim=0) + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) @@ -630,14 +624,12 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) # test padded entries stay the same - assert torch.equal( - state_before[unused_states_bool], - state[unused_states_bool] - ) - assert torch.equal(x[batch_size + 1:],x[batch_size + 1:]) - assert torch.equal(dt[batch_size + 1:],dt[batch_size + 1:]) - assert torch.equal(B[batch_size + 1:],B[batch_size + 1:]) - assert torch.equal(C[batch_size + 1:],C[batch_size + 1:]) + assert torch.equal(state_before[unused_states_bool], + state[unused_states_bool]) + assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) + assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) + assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) + assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) # test "real" entries assert torch.allclose(state[state_indices, :], diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 91000677b36..384ec77e545 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -276,12 +276,8 @@ def test_state_cleanup( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [64]) def test_jamba_distributed_produces_identical_generation( - 
vllm_runner, - model: str, - dtype: str, - max_tokens: int, - example_prompts -) -> None: + vllm_runner, model: str, dtype: str, max_tokens: int, + example_prompts) -> None: with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 148125e0045..ce15833682f 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,8 +1,6 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py -from typing import Tuple - import torch import triton import triton.language as tl @@ -343,21 +341,20 @@ def selective_state_update(state, return out -def selective_scan_fn( - u, - ssm_states, - delta, - A, - B, - C, - D=None, - z=None, - delta_bias=None, - delta_softplus=False, - query_start_loc=None, - cache_indices=None, - has_initial_state=None, - pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: +def selective_scan_fn(u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) applies changes in place. diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index e0a37f26de6..72ecbe95b83 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -712,9 +712,8 @@ def _get_mamba_cache_shape( return conv_state_shape, temporal_state_shape def _get_max_batch_size(self): - return (_get_graph_batch_size( - self.scheduler_config.max_num_seqs - ) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) + return (_get_graph_batch_size(self.scheduler_config.max_num_seqs) + if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) def _prepare_mamba_cache(self): dtype = self.lm_head.weight.dtype From a6eab1be20daed6a2e393c51ccf959a60887204b Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 19:03:11 +0300 Subject: [PATCH 11/24] causal_conv1d outputs to x inplace --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 12 ++++------ csrc/ops.h | 28 ++++++++++++----------- csrc/torch_bindings.cpp | 4 ++-- vllm/_custom_ops.py | 25 ++++++++++---------- 4 files changed, 33 insertions(+), 36 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 0235f4248bc..167290a81f8 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -92,8 +92,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, } -at::Tensor -causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, +void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, const c10::optional &conv_states, const c10::optional &query_start_loc, @@ -158,7 +157,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, CHECK_SHAPE(cache_indices_, batch_size); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, @@ -189,12 +188,10 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), 
"causal_conv1d_fwd", [&] { causal_conv1d_fwd_cuda(params, stream); }); - return out; } -at::Tensor -causal_conv1d_update(const at::Tensor &x, +void causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, @@ -236,7 +233,7 @@ causal_conv1d_update(const at::Tensor &x, CHECK_SHAPE(bias, dim); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, @@ -285,7 +282,6 @@ causal_conv1d_update(const at::Tensor &x, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); }); - return out; } template diff --git a/csrc/ops.h b/csrc/ops.h index 8bb1252a313..c10c34e0857 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -159,19 +159,21 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const c10::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); -at::Tensor causal_conv1d_update( - const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias_, bool silu_activation, - const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_, int64_t pad_slot_id); - -at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& conv_states, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, - bool silu_activation, int64_t pad_slot_id); +void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation, + const c10::optional& cache_seqlens_, + const c10::optional& conv_state_indices_, + int64_t pad_slot_id); + +void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& conv_states, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index d2360d7d228..d69c4e5afb4 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -290,7 +290,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool silu_activation," "Tensor? cache_seqlens_," "Tensor? conv_state_indices," - "int pad_slot_id) -> Tensor"); + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( @@ -301,7 +301,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? cache_indices," "Tensor? 
has_initial_state," "bool silu_activation," - "int pad_slot_id) -> Tensor"); + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3326752a9de..fa02bd00af8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -453,9 +453,8 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, cu_seq_len: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool, - pad_slot_id: int) -> torch.Tensor: - return torch.empty_like(x) + silu_activation: bool, pad_slot_id: int): + return None @torch.library.register_fake("_C::causal_conv1d_update") def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, @@ -465,7 +464,7 @@ def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, cache_seqlens: Optional[torch.Tensor], conv_state_indices: Optional[torch.Tensor], pad_slot_id: int) -> torch.Tensor: - return torch.empty_like(x) + return None @torch.library.register_fake("_C::selective_scan_fwd") def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, @@ -793,11 +792,11 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, query_start_loc: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool, pad_slot_id: int) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation, - pad_slot_id) + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, @@ -805,10 +804,10 @@ def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, silu_activation: bool, cache_seqlens: Optional[torch.Tensor], conv_state_indices: Optional[torch.Tensor], - pad_slot_id: int) -> torch.Tensor: - return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices, pad_slot_id) + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, From f6d3a0581c6739bcc62f6ee3c52bfe9c9f1887f0 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 19:03:36 +0300 Subject: [PATCH 12/24] Fix tests and add with_padding test --- tests/kernels/test_causal_conv1d.py | 229 ++++++++++------------------ 1 file changed, 79 insertions(+), 150 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 65ff1d4bac2..ef8467e8de5 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -226,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, seed_everything(0) batch = 2 x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + x_ref = x.clone() conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias 
= torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -244,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, weight, bias, activation=activation) - out_ref = causal_conv1d_update_ref(x, + out_ref = causal_conv1d_update_ref(x_ref, conv_state_ref, weight, bias, @@ -258,52 +252,6 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, in ["silu", "swish"], None, None, PAD_SLOT_ID)) -@pytest.mark.parametrize("itype", [torch.bfloat16]) -@pytest.mark.parametrize("silu_activation", [True]) -@pytest.mark.parametrize("has_bias", [True]) -@pytest.mark.parametrize("seqlen", [1]) -@pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize("dim", [2048]) -def test_causal_conv1d_update_with_batch_gather_padding_unchanged( - dim, width, seqlen, has_bias, silu_activation, itype): - device = "cuda" - # set seed - seed_everything(0) - batch = 64 - - x = torch.randn(batch, dim, 1, device=device, dtype=itype) - - total_entries = 10 * batch - conv_state = torch.randn(total_entries, - dim, - width - 1, - device=device, - dtype=itype) - conv_state_before = conv_state.clone() - conv_state_indices = torch.as_tensor([PAD_SLOT_ID] * batch, - dtype=torch.int32, - device=device) - - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None - activation = None if not silu_activation else "silu" - causal_conv1d_update(x, - conv_state, - weight, - bias, - activation=activation, - conv_state_indices=conv_state_indices) - - assert torch.equal(conv_state, conv_state_before) - - @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @@ -311,37 +259,47 @@ def test_causal_conv1d_update_with_batch_gather_padding_unchanged( @pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, +@pytest.mark.parametrize("with_padding", [True, False]) +def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, + seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - # set )seed + # set seed seed_everything(0) - batch = 64 - x = torch.randn(batch, dim, 1, device=device, dtype=itype) + batch_size = 3 + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + total_entries = 10 * batch_size + + x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) + x_ref = x.clone() - total_entries = 10 * batch + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat([ + conv_state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype) - conv_state_indices = torch.randperm(total_entries)[:batch].to( - dtype=torch.int32, device=device) 
+ conv_state_for_padding_test = conv_state.clone() - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -349,19 +307,21 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, weight, bias, activation=activation, - conv_state_indices=conv_state_indices) - out_ref = causal_conv1d_update_ref(x, + conv_state_indices=padded_state_indices) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) + assert torch.equal(conv_state[unused_states_bool], + conv_state_for_padding_test[unused_states_bool]) opcheck(torch.ops._C.causal_conv1d_update, (x, conv_state, weight, bias, activation - in ["silu", "swish"], None, conv_state_indices, PAD_SLOT_ID)) + in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", [torch.bfloat16]) @@ -371,17 +331,25 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) -def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, - itype): +@pytest.mark.parametrize('with_padding', [True, False]) +def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, + silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed seed_everything(0) - batch = 1 seqlens = [] - nsplits = 3 + batch_size = 4 + if seqlen < 10: + if with_padding: + pytest.skip() + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + nsplits = padded_batch_size - 1 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( @@ -391,10 +359,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) - x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, + x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None @@ -402,7 +371,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(nsplits + 1, + final_states = torch.randn(total_entries, dim, width - 1, device=x.device, @@ -412,18 +381,31 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, 
silu_activation, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=x.device) - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=x.device) + device=x.device)[:batch_size] + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) + out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation) out_ref = [] out_ref_b = [] splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] for i in range(len(seqlens[0])): x_s = [v[i].unsqueeze(0) for v in splits][0] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_b.append( causal_conv1d_ref( x_s, @@ -431,70 +413,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, bias_ref, activation=activation, return_final_states=True, - final_states_out=final_states_ref[cache_indices[i]].unsqueeze( - 0), - initial_states=final_states_ref[cache_indices[i]].unsqueeze(0) - if has_initial_states[i] else None)) + final_states_out=final_states_ref[ + padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]]. + unsqueeze(0) if has_initial_states[i] else None)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) - out_ref = torch.cat(out_ref, dim=0) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print("Output state max diff" - f":{(final_states - final_states_ref).abs().max()}") - print("Output state mean diff" - f":{(final_states - final_states_ref).abs().mean()}") - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) - causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) - + out_ref_tensor = torch.cat(out_ref, dim=0) -@pytest.mark.parametrize("itype", [torch.bfloat16]) -@pytest.mark.parametrize("silu_activation", [True]) -@pytest.mark.parametrize("has_bias", [True]) -@pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', [256]) -@pytest.mark.parametrize('dim', [64]) -def test_causal_conv1d_varlen_check_padding_state_unchanged( - dim, seqlen, width, has_bias, silu_activation, itype): - device = "cuda" - # set seed - seed_everything(0) - batch = 1 - seqlens = [] - nsplits = 3 - eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values - seqlens.append( - torch.diff( - torch.cat( - [torch.tensor([-1]), eos_pos, - torch.tensor([seqlen - 1])])).tolist()) - assert sum(seqlens[-1]) == seqlen - assert all(s > 0 for s in seqlens[-1]) + unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) + assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) - cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) - cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], - dim=0) - x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, - dtype=itype)[:, 4096:4096 + dim, :] - weight = torch.randn(dim, width, device=device, dtype=itype) - bias = 
torch.randn(dim, device=device, dtype=itype) if has_bias else None - activation = None if not silu_activation else "silu" - final_states = torch.randn(nsplits + 1, - dim, - width - 1, - device=x.device, - dtype=x.dtype) - final_states_before = final_states.clone() - has_initial_states = torch.randint(0, - 2, (cumsum.shape[0] - 1, ), - dtype=torch.bool, - device=x.device) - cache_indices = torch.as_tensor([PAD_SLOT_ID] * (cumsum.shape[0] - 1), - dtype=torch.int32, - device=x.device) - causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), cache_indices, - has_initial_states, final_states, activation) - assert torch.equal(final_states, final_states_before) + causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), + padded_state_indices, has_initial_states, + final_states, activation) From d96bd011904feadf7c03aa0f6c541a9e8aefe85a Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 19:04:01 +0300 Subject: [PATCH 13/24] Fix tests --- tests/kernels/test_mamba_ssm.py | 75 ++++++++++++++------------------- 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index b153142487f..b056bcb5fb3 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -440,7 +440,6 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, [torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) - print(seqlens) assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) @@ -448,7 +447,6 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() - print(cumsum) dim = 4 dstate = 8 @@ -483,7 +481,6 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, dtype=itype, requires_grad=False) prev_state_ref = prev_state.clone() - prev_state_for_padding_test = prev_state.clone() state_indices = torch.randperm(total_entries, dtype=torch.int32, device=u.device)[:batch_size] @@ -513,42 +510,32 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, for i in range(len(seqlens[0])): u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits] if padded_state_indices[i] == PAD_SLOT_ID: - out_ref_s = z_s if has_z else delta_s - else: - out_ref_s, _ = selective_scan_ref( - u_s, - delta_s, - A_ref, - B_s, - C_s, - D_ref, - z=z_s, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - return_last_state=return_last_state, - prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) - if has_initial_state[i] else None, - final_state_out=prev_state_ref[ - padded_state_indices[i]].unsqueeze(0)) + continue + out_ref_s, _ = selective_scan_ref( + u_s, + delta_s, + A_ref, + B_s, + C_s, + D_ref, + z=z_s, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) + if has_initial_state[i] else None, + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( + 0)) outs_ref.append(out_ref_s) - out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] + out_ref = torch.cat(outs_ref, dim=-1)[0] - # test "real" entries are correct - print("Output diff max", (out - out_ref[0]).max()) - print("Output diff mean", (out - out_ref[0]).mean()) + unpadded_out = out[:, :out_ref[0].shape[-1]] + print("Output diff max", (unpadded_out 
- out_ref).max()) + print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) - # test padded entries are correct - if with_padding: - assert torch.equal(prev_state[unused_states_bool], - prev_state_for_padding_test[unused_states_bool]) - if has_z: - assert torch.equal(out[batch_size + 1:], z[batch_size + 1:]) - else: - assert torch.equal(out[batch_size + 1:], delta[batch_size + 1:]) - + assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, delta_softplus, cumsum, padded_state_indices, has_initial_state, prev_state) @@ -559,7 +546,9 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): +@pytest.mark.parametrize("with_padding", [True, False]) +def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, + has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -569,7 +558,7 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): # set seed torch.random.manual_seed(0) batch_size = 3 - padding = 5 + padding = 5 if with_padding else 0 padded_batch_size = batch_size + padding total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) @@ -585,7 +574,6 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) ], dim=0) - x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 @@ -624,12 +612,13 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) # test padded entries stay the same - assert torch.equal(state_before[unused_states_bool], - state[unused_states_bool]) - assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) - assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) - assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) - assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + if with_padding: + assert torch.equal(state_before[unused_states_bool], + state[unused_states_bool]) + assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) + assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) + assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) + assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) # test "real" entries assert torch.allclose(state[state_indices, :], From a663fa8b4c95f015487e4f7b1a788270522e6283 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 19:04:09 +0300 Subject: [PATCH 14/24] Return none --- .../layers/mamba/ops/causal_conv1d.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 
fffd3704432..ecf35a14b41 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -53,10 +53,10 @@ def causal_conv1d_fn(x: torch.Tensor, x = x.contiguous() bias = bias.contiguous() if bias is not None else None - out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, - cache_indices, has_initial_state, activation - in ["silu", "swish"], pad_slot_id) - return out + ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, + cache_indices, has_initial_state, activation + in ["silu", "swish"], pad_slot_id) + return x def causal_conv1d_update(x: torch.Tensor, @@ -95,9 +95,8 @@ def causal_conv1d_update(x: torch.Tensor, unsqueeze = x.dim() == 2 if unsqueeze: x = x.unsqueeze(-1) - out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, - cache_seqlens, conv_state_indices, - pad_slot_id) + ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, + cache_seqlens, conv_state_indices, pad_slot_id) if unsqueeze: - out = out.squeeze(-1) - return out + x = x.squeeze(-1) + return x From 69ebab82ae930f5457c51970ad75d3de311d66b3 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 19:11:18 +0300 Subject: [PATCH 15/24] fix typo --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 4 ++-- csrc/mamba/mamba_ssm/selective_scan_fwd.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 167290a81f8..1317e0e30b9 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -100,7 +100,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &has_initial_state, bool silu_activation, // used to identify padding entries if cache_indices provided - // incase of padding, the kernel will return early + // in case of padding, the kernel will return early int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); @@ -199,7 +199,7 @@ void causal_conv1d_update(const at::Tensor &x, const c10::optional &cache_seqlens_, const c10::optional &conv_state_indices_, // used to identify padding entries if cache_indices provided - // incase of padding, the kernel will return early + // in case of padding, the kernel will return early int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index f5857829025..71624696338 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -513,7 +513,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const c10::optional &has_initial_state, const torch::Tensor &ssm_states, // used to identify padding entries if cache_indices provided - // incase of padding, the kernel will return early + // in case of padding, the kernel will return early int64_t pad_slot_id) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); From 906379d1ffa10c894b21d5657ddba1898d515702 Mon Sep 17 00:00:00 2001 From: mzusman Date: Thu, 10 Oct 2024 19:17:10 +0300 Subject: [PATCH 16/24] remove diffs and use pad_slot_id as var in tests --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 4 ++-- tests/kernels/test_causal_conv1d.py | 5 +++-- tests/kernels/test_mamba_ssm.py | 6 ++++-- 3 files changed, 9 insertions(+), 6 deletions(-) 
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 1317e0e30b9..a12230764f9 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -58,8 +58,8 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, int64_t pad_slot_id, const c10::optional& query_start_loc = std::nullopt, const c10::optional& cache_indices = std::nullopt, - const c10::optional& has_initial_state = std::nullopt - ) { + const c10::optional& has_initial_state = std::nullopt) { + // Reset the parameters memset(¶ms, 0, sizeof(params)); diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index ef8467e8de5..e099a28d2b4 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -307,7 +307,8 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, weight, bias, activation=activation, - conv_state_indices=padded_state_indices) + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = causal_conv1d_update_ref(x_ref[:batch_size], conv_state_ref, weight, @@ -397,7 +398,7 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), padded_state_indices, has_initial_states, - final_states, activation) + final_states, activation, PAD_SLOT_ID) out_ref = [] out_ref_b = [] diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index b056bcb5fb3..ffe06625bf8 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -594,7 +594,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=padded_state_indices) + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, x[:batch_size], dt[:batch_size], @@ -694,7 +695,8 @@ def test_selective_state_update_with_heads_with_batch_indices( z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, x, dt, From 94fe81909a53c2b3c7040d884c705b9460fffd01 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 13 Oct 2024 13:37:27 +0300 Subject: [PATCH 17/24] Fix tests --- vllm/model_executor/models/jamba.py | 12 ++++++------ vllm/model_executor/models/mamba.py | 22 ++++++++++++---------- vllm/model_executor/models/mamba_cache.py | 21 +++++++++++---------- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 90e8d01f0ad..400519468e0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -41,6 +41,7 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] + # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer class JambaMambaMixer(nn.Module): """ @@ -577,13 +578,12 @@ def forward(self, self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) - mamba_cache_tensors, state_indices_tensor= self.mamba_cache.current_run_tensors(input_ids, attn_metadata, **kwargs) + mamba_cache_tensors, state_indices_tensor = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) - mamba_cache_params = MambaCacheParams( - mamba_cache_tensors[0], - mamba_cache_tensors[1], - state_indices_tensor - ) + mamba_cache_params = 
MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, mamba_cache_params) return hidden_states diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 2c5d83b9664..db61be87b33 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -261,6 +261,7 @@ def forward( attn_metadata: AttentionMetadata, residual: Optional[torch.Tensor], mamba_cache_params: MambaCacheParams, + **kwargs, ): if residual is None: residual = hidden_states @@ -268,11 +269,8 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer( - hidden_states, - attn_metadata, - mamba_cache_params - ) + hidden_states = self.mixer(hidden_states, attn_metadata, + mamba_cache_params) # Fully Connected hidden_states, residual = self.pre_ff_layernorm( hidden_states, residual) @@ -318,6 +316,7 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, mamba_cache_params: MambaCacheParams, ) -> torch.Tensor: @@ -332,8 +331,7 @@ def forward( hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, - mamba_cache_params=mamba_cache_params.at_layer_idx(i) - ) + mamba_cache_params=mamba_cache_params.at_layer_idx(i)) hidden_states, _ = self.norm_f(hidden_states, residual) return hidden_states @@ -407,12 +405,16 @@ def forward(self, self.lm_head.weight.dtype, self.config.num_hidden_layers, max_batch_size, *self._get_mamba_cache_shape()) - mamba_cache_tensors = self.mamba_cache.current_run_tensors( + + mamba_cache_tensors, state_indices_tensor = self.mamba_cache.current_run_tensors( input_ids, attn_metadata, **kwargs) + mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], + mamba_cache_tensors[1], + state_indices_tensor) + hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_tensors[0], - mamba_cache_tensors[1]) + attn_metadata, mamba_cache_params) return hidden_states diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 8cc8fc39e51..cceaa841698 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -6,6 +6,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.utils import PAD_SLOT_ID + @dataclass class MambaCacheParams: conv_state: torch.Tensor = torch.Tensor() @@ -80,9 +81,7 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): self._release_finished_requests(finished_requests_ids) state_indices = self._prepare_current_run_mamba_cache( - request_ids_to_seq_ids, - finished_requests_ids - ) + request_ids_to_seq_ids, finished_requests_ids) cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len( state_indices) state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len) @@ -90,7 +89,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): input_state_indices_buffer.copy_( torch.as_tensor(state_indices, dtype=torch.int32, device="cuda")) - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): """ Provide the CUDA graph capture runs with a buffer in adjusted size. @@ -98,8 +96,8 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int): replay runs. 
""" state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size, - dtype=torch.int32, - device="cuda") + dtype=torch.int32, + device="cuda") return (self.mamba_cache, state_indices_tensor) def _copy_mamba_cache(self, from_index: int, to_index: int): @@ -119,7 +117,9 @@ def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, return PAD_SLOT_ID elif cur_rid not in self.mamba_cache_indices_mapping: destination_index = self.free_cache_indices.pop() - self.mamba_cache_indices_mapping[cur_rid] = {seq_id: destination_index} + self.mamba_cache_indices_mapping[cur_rid] = { + seq_id: destination_index + } return destination_index elif seq_id not in (seq_ids2indices := self.mamba_cache_indices_mapping[cur_rid]): @@ -131,7 +131,8 @@ def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int, destination_index = self.free_cache_indices.pop() self._copy_mamba_cache(from_index=index_exists, to_index=destination_index) - self.mamba_cache_indices_mapping[cur_rid][seq_id] = destination_index + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = destination_index return destination_index else: # already exists @@ -154,8 +155,8 @@ def _get_all_occupied_indices(self): for cache_idx in seq_ids2indices.values() ] - - def _release_finished_requests(self, finished_seq_groups_req_ids: List[str]): + def _release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: if req_id in self.mamba_cache_indices_mapping: for seq_id in self.mamba_cache_indices_mapping[req_id]: From 15442317b05a61617db6582c22bad3f002859b73 Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 13 Oct 2024 13:41:54 +0300 Subject: [PATCH 18/24] format --- vllm/model_executor/models/jamba.py | 18 +++++++++++------- vllm/model_executor/models/mamba.py | 11 ++++++----- vllm/model_executor/models/mamba_cache.py | 9 +-------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 400519468e0..b6e8d5c36cd 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,6 +1,5 @@ # coding=utf-8 """Inference-only Jamba model.""" -from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple import torch @@ -8,7 +7,6 @@ from transformers import JambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -30,7 +28,10 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import MambaCacheManager, MambaCacheParams +from vllm.model_executor.models.mamba_cache import ( + MambaCacheManager, + MambaCacheParams, +) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -577,10 +578,13 @@ def forward(self, self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) - - mamba_cache_tensors, state_indices_tensor = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) - + (mamba_cache_tensors, + 
state_indices_tensor, + ) = self.mamba_cache.current_run_tensors( + input_ids, + attn_metadata, + **kwargs + ) mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], state_indices_tensor) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index db61be87b33..1979b0a14e3 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,6 +1,5 @@ # coding=utf-8 """PyTorch MAMBA model.""" -from dataclasses import dataclass from typing import Iterable, List, Optional, Tuple import torch @@ -405,16 +404,18 @@ def forward(self, self.lm_head.weight.dtype, self.config.num_hidden_layers, max_batch_size, *self._get_mamba_cache_shape()) - - mamba_cache_tensors, state_indices_tensor = self.mamba_cache.current_run_tensors( - input_ids, attn_metadata, **kwargs) + ( + mamba_cache_tensors, + state_indices_tensor, + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], state_indices_tensor) hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + attn_metadata, mamba_cache_params) return hidden_states diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index cceaa841698..79393421f3a 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, List import torch @@ -148,13 +148,6 @@ def _prepare_current_run_mamba_cache( for seq_id in seq_ids ] - def _get_all_occupied_indices(self): - return [ - cache_idx - for seq_ids2indices in self.mamba_cache_indices_mapping.values() - for cache_idx in seq_ids2indices.values() - ] - def _release_finished_requests(self, finished_seq_groups_req_ids: List[str]): for req_id in finished_seq_groups_req_ids: From 158f22dff4391fc154432976e55e72da4d23889a Mon Sep 17 00:00:00 2001 From: mzusman Date: Sun, 13 Oct 2024 13:42:58 +0300 Subject: [PATCH 19/24] Format --- vllm/model_executor/models/jamba.py | 16 ++++++---------- vllm/model_executor/models/mamba.py | 6 ++---- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b6e8d5c36cd..fddd39fb8c8 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -28,10 +28,8 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( composed_weight_loader, default_weight_loader, sharded_weight_loader) -from vllm.model_executor.models.mamba_cache import ( - MambaCacheManager, - MambaCacheParams, -) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors @@ -578,13 +576,11 @@ def forward(self, self.mamba_cache = MambaCacheManager( self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, *self._get_mamba_cache_shape()) - (mamba_cache_tensors, + ( + mamba_cache_tensors, state_indices_tensor, - ) = self.mamba_cache.current_run_tensors( - input_ids, - attn_metadata, - **kwargs - ) + ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata, + **kwargs) mamba_cache_params = 
MambaCacheParams(mamba_cache_tensors[0], mamba_cache_tensors[1], state_indices_tensor) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 1979b0a14e3..9c8082c53f2 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -28,10 +28,8 @@ composed_weight_loader, default_weight_loader, sharded_weight_loader) from vllm.model_executor.models.interfaces import (HasInnerState, IsAttentionFree) -from vllm.model_executor.models.mamba_cache import ( - MambaCacheManager, - MambaCacheParams, -) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors From 40d14eeed8b8ce39cee3eea6b4c4b3df7f34d1b8 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 16 Oct 2024 08:01:23 +0800 Subject: [PATCH 20/24] Address review comments --- csrc/mamba/causal_conv1d/causal_conv1d.cu | 3 +- tests/kernels/test_causal_conv1d.py | 6 ++-- tests/kernels/test_mamba_ssm.py | 7 +++-- .../layers/mamba/ops/causal_conv1d.py | 2 +- .../layers/mamba/ops/mamba_ssm.py | 31 +++++++------------ 5 files changed, 21 insertions(+), 28 deletions(-) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index a12230764f9..3a464c5f327 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -239,8 +239,7 @@ void causal_conv1d_update(const at::Tensor &x, set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, silu_activation, - pad_slot_id - ); + pad_slot_id); params.conv_state_ptr = conv_state.data_ptr(); params.conv_state_len = conv_state_len; // All stride are in elements, not bytes. 
diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index e099a28d2b4..868b8fec29c 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -259,7 +259,8 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, @pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -@pytest.mark.parametrize("with_padding", [True, False]) +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, seqlen, has_bias, silu_activation, itype): @@ -332,6 +333,7 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) +# tests correctness in case subset of the sequences are padded @pytest.mark.parametrize('with_padding', [True, False]) def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, silu_activation, itype): @@ -344,8 +346,6 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, seqlens = [] batch_size = 4 if seqlen < 10: - if with_padding: - pytest.skip() batch_size = 1 padding = 3 if with_padding else 0 padded_batch_size = batch_size + padding diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index ffe06625bf8..e92d401368a 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -406,6 +406,7 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) +# tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [False, True]) def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, varBC_groups, has_D, has_z, has_delta_bias, @@ -426,12 +427,13 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, seqlens = [] batch_size = 4 if seqlen < 10: - if with_padding: - pytest.skip() batch_size = 1 padding = 3 if with_padding else 0 padded_batch_size = batch_size + padding + if with_padding and seqlen < padded_batch_size: + pytest.skip() + nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( @@ -546,6 +548,7 @@ def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +# tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, has_z, itype): diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index ecf35a14b41..be5639df985 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -40,7 +40,7 @@ def causal_conv1d_fn(x: torch.Tensor, pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + for example: 
cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index ce15833682f..609cc2811c2 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -143,18 +143,11 @@ def _selective_scan_update_kernel( if HAS_Z: z_ptrs = z_ptr + offs_m * stride_z_dim out_ptrs = out_ptr + offs_m * stride_out_dim - + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & - (offs_n[None, :] < dstate) & - (state_batch_idx != pad_slot_id), - other=0.0) - else: - state = tl.load(state_ptrs, - mask=(offs_m[:, None] < dim) & - (offs_n[None, :] < dstate), - other=0.0) + mask &= (state_batch_idx != pad_slot_id) + state = tl.load(state_ptrs, mask= mask, other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -185,15 +178,13 @@ def _selective_scan_update_kernel( dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt state = state * dA + dB * x[:, None] + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: - tl.store(state_ptrs, - state, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate) & - (state_batch_idx != pad_slot_id)) - else: - tl.store(state_ptrs, - state, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + mask &= (state_batch_idx != pad_slot_id) + tl.store(state_ptrs, + state, + mask= mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D @@ -228,7 +219,7 @@ def selective_state_update(state, pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 Return: From 73252549169ef6b1e5f714b0e89f4b3cc643dcfa Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 16 Oct 2024 08:15:40 +0800 Subject: [PATCH 21/24] Format --- tests/kernels/test_causal_conv1d.py | 2 +- vllm/model_executor/layers/mamba/ops/mamba_ssm.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 868b8fec29c..32d4a090781 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -260,7 +260,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded -@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("with_padding", [True, False]) def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, seqlen, has_bias, silu_activation, itype): diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 609cc2811c2..1484b79815a 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -146,7 +146,7 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: 
mask &= (state_batch_idx != pad_slot_id) - state = tl.load(state_ptrs, mask= mask, other=0.0) + state = tl.load(state_ptrs, mask=mask, other=0.0) x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) if not TIE_HDIM: @@ -182,9 +182,7 @@ def _selective_scan_update_kernel( mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) if HAS_STATE_BATCH_INDICES: mask &= (state_batch_idx != pad_slot_id) - tl.store(state_ptrs, - state, - mask= mask) + tl.store(state_ptrs, state, mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: out += x * D From f4c1198d79447d60c2753e02254583ad9be8774e Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 16 Oct 2024 08:33:08 +0800 Subject: [PATCH 22/24] Fix mamba --- vllm/model_executor/models/mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 217d9149286..7f2efb9895f 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -354,8 +354,8 @@ def forward(self, mamba_cache_tensors[1], state_indices_tensor) - hidden_states = self.backbone(input_ids, positions, kv_caches, - attn_metadata, mamba_cache_params) + hidden_states = self.backbone(input_ids, positions, attn_metadata, + mamba_cache_params) return hidden_states From d1c5c3203c6bc9bd73ad5229d6e6a004c4dd446f Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 16 Oct 2024 09:06:08 +0800 Subject: [PATCH 23/24] add cache empty for consistency --- tests/kernels/test_causal_conv1d.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 32d4a090781..db12c2a5cba 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -331,13 +331,14 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize('seqlen', - [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) + [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize('with_padding', [True, False]) def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, silu_activation, itype): device = "cuda" + torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -385,16 +386,11 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[:batch_size] - unused_states_bool = torch.ones(total_entries, - dtype=torch.bool, - device=device) - unused_states_bool[state_indices] = False padded_state_indices = torch.concat([ state_indices, torch.as_tensor( [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], - dim=-1) + ], dim=-1) out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), padded_state_indices, has_initial_states, From 9905319a7b956bdfabaa6288b7d9202dd58505a7 Mon Sep 17 00:00:00 2001 From: mzusman Date: Wed, 16 Oct 2024 09:40:37 +0800 Subject: [PATCH 24/24] Format --- tests/kernels/test_causal_conv1d.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index db12c2a5cba..277d7e4977d 100644 --- 
a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -330,8 +330,8 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', - [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) +@pytest.mark.parametrize( + 'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize('with_padding', [True, False]) @@ -390,7 +390,8 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, state_indices, torch.as_tensor( [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), - ], dim=-1) + ], + dim=-1) out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), padded_state_indices, has_initial_states,