From 1fd4882db53d3a5f20a1e9a5d873cc527fb894b9 Mon Sep 17 00:00:00 2001 From: ilanaizelman Date: Thu, 18 Jul 2024 15:14:02 +0300 Subject: [PATCH 1/2] Implemented decoding N tokens, instead of one token after another. --- mamba_ssm/modules/mamba_simple.py | 155 ++++++++++++++++++++++++------ 1 file changed, 127 insertions(+), 28 deletions(-) diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 4c8a3882..2c9a6f0a 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -126,7 +126,7 @@ def forward(self, hidden_states, inference_params=None): conv_state, ssm_state = None, None if inference_params is not None: conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: + if inference_params.seqlen_offset > 0 and seqlen == 1: # The states are updated inplace out, _, _ = self.step(hidden_states, conv_state, ssm_state) return out @@ -161,20 +161,27 @@ def forward(self, hidden_states, inference_params=None): else: x, z = xz.chunk(2, dim=1) # Compute short convolution - if conv_state is not None: - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) - if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - x = causal_conv1d_fn( - x=x, - weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - ) + # if conv_state is not None: + # # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv + # # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. + # conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W) + + # if conv_state is None: + # conv_state = torch.zeros(b, d * self.expand, self.d_conv, device = x.device) + x = torch.cat((conv_state, x), dim=2) + conv_state.copy_(x[:, :, -self.d_conv:]) + x = self.act(self.conv1d(x)[:, :, self.d_conv:self.d_conv + seqlen]) + + # if causal_conv1d_fn is None: + # x = self.act(self.conv1d(x)[..., :seqlen]) + # else: + # assert self.activation in ["silu", "swish"] + # x = causal_conv1d_fn( + # x=x, + # weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), + # bias=self.conv1d.bias, + # activation=self.activation, + # ) # We're careful here about the layout, to avoid extra transposes. 
# We want dt to have d as the slowest moving dimension @@ -186,24 +193,116 @@ def forward(self, hidden_states, inference_params=None): B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() assert self.activation in ["silu", "swish"] - y = selective_scan_fn( - x, - dt, - A, - B, - C, - self.D.float(), - z=z, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=ssm_state is not None, - ) + # y = selective_scan_fn( + # x, + # dt, + # A, + # B, + # C, + # self.D.float(), + # z=z, + # delta_bias=self.dt_proj.bias.float(), + # delta_softplus=True, + # return_last_state=ssm_state is not None, + # ) + y = self.selective_scan(x.float(), dt.float(), A.float(), B.float(), C.float(), self.D.float(), z=z.float(), delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, ssm_state=ssm_state.float()) if ssm_state is not None: y, last_state = y ssm_state.copy_(last_state) y = rearrange(y, "b d l -> b l d") - out = self.out_proj(y) + out = self.out_proj(y.to(torch.bfloat16)) return out + + def selective_scan(self, x, delta, A, B, C, D, z=None, delta_bias=None, delta_softplus=False, ssm_state=None): + x = rearrange(x, 'b d l -> b l d') + B = rearrange(B, 'b d l -> b l d') + C = rearrange(C, 'b d l -> b l d') + + (b, l, hidden_dim) = x.shape + n = A.shape[1] + + if delta_bias is not None: + delta = delta + delta_bias[..., None] + if delta_softplus: + delta = F.softplus(delta) + + # Discretize continuous parameters (A, B) + # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1]) + # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors: + # "A is the more important term and the performance doesn't change much with the simplification on B" + # Perform selective scan (see scan_SSM() in The Annotated S4 [2]) + # Note that the below is sequential, while the official implementation does a much faster parallel scan that + # is additionally hardware-aware (like FlashAttention). + deltaA = torch.exp(torch.einsum('bdl,dn->bldn', delta, A)) + deltaBX = torch.einsum('bdl,bln,bld->bldn', delta, B, x) + y = torch.empty((b, l, hidden_dim), device=deltaA.device) + # -- Build h -- + if ssm_state is not None: + h = ssm_state + else: + h = torch.zeros((b, hidden_dim, n), device=deltaA.device, dtype=deltaA.dtype) + # -- TODO, how to fast it in parallel? -- + for i in range(l): + h = deltaA[:, i] * h + deltaBX[:, i] + + y[:,i] = torch.einsum('bdn,bn->bd', h, C[:, i]) # BUG? 
the h is not h(t), it is already set to h(t+1) in prev line + + # -- Save h -- + # if state is not None: + # ssm_state.copy_(h) + out = y + x * D + out = rearrange(out, 'b l d -> b d l') + if z is not None: + out = out * F.silu(z) + return out, h + + def step(self, hidden_states, conv_state, ssm_state): + dtype = hidden_states.dtype + assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" + xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) + x, z = xz.chunk(2, dim=-1) # (B D) + + # Conv step + if causal_conv1d_update is None: + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + x = x + self.conv1d.bias + x = self.act(x).to(dtype=dtype) + else: + x = causal_conv1d_update( + x, + conv_state, + rearrange(self.conv1d.weight, "d 1 w -> d w"), + self.conv1d.bias, + self.activation, + ) + + x_db = self.x_proj(x) # (B dt_rank+2*d_state) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + # Don't add dt_bias here + dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + + # SSM step + if selective_state_update is None: + # Discretize A and B + dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) + y = y + self.D.to(dtype) * x + y = y * self.act(z) # (B D) + else: + y = selective_state_update( + ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + + out = self.out_proj(y) + return out.unsqueeze(1), conv_state, ssm_state def step(self, hidden_states, conv_state, ssm_state): dtype = hidden_states.dtype From 678ce14e3eca5a67d92c363f25c6d292739663b3 Mon Sep 17 00:00:00 2001 From: ilanaizelman Date: Thu, 18 Jul 2024 15:16:36 +0300 Subject: [PATCH 2/2] added code which shows that chunked decoding works, and also latency test that it's much faster. 
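Note on the chunked scan these tests exercise: the "# BUG?" comment in selective_scan() asks whether y[:, i] should be read from h(t) or from the just-updated state. The existing single-token step() also updates ssm_state first and only then reads y out of it, so computing y from the freshly updated state matches the recurrence h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t, y_t = C_t . h_t + D * x_t, and does not look like a bug. For reference, here is a minimal sequential scan that the chunked path can be compared against. This is only a sketch, not code from this patch: it assumes a (batch, length, dim) layout and assumes dt already has the bias and softplus applied.

import torch

def reference_scan(x, dt, A, B, C, D, h0=None):
    """Sequential reference scan (sketch). x, dt: (b, l, d); A: (d, n);
    B, C: (b, l, n); D: (d,); h0: optional initial state of shape (b, d, n)."""
    b, l, d = x.shape
    n = A.shape[1]
    h = torch.zeros(b, d, n, dtype=A.dtype, device=x.device) if h0 is None else h0
    ys = []
    for t in range(l):
        dA = torch.exp(dt[:, t, :, None] * A)                            # (b, d, n)
        dBx = dt[:, t, :, None] * B[:, t, None, :] * x[:, t, :, None]    # (b, d, n)
        h = dA * h + dBx                                                 # h_t from h_{t-1} and x_t
        ys.append(torch.einsum("bdn,bn->bd", h, C[:, t]) + D * x[:, t])  # y_t reads the updated h_t
    return torch.stack(ys, dim=1), h                                     # (b, l, d) outputs, final state

Feeding the returned final state back in as h0 for the next chunk should reproduce the full-sequence output, which is exactly the property test_mamba_ssm_state.py below checks for both chunked and token-by-token decoding.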
---
 test_mamba_ssm_state.py | 132 ++++++++++++++++++++++++++++++++++++++++
 time_measure.py         |  64 +++++++++++++++++++
 2 files changed, 196 insertions(+)
 create mode 100644 test_mamba_ssm_state.py
 create mode 100644 time_measure.py

diff --git a/test_mamba_ssm_state.py b/test_mamba_ssm_state.py
new file mode 100644
index 00000000..4ea1cb23
--- /dev/null
+++ b/test_mamba_ssm_state.py
@@ -0,0 +1,132 @@
+from mamba_ssm.modules.mamba_simple import Mamba
+from mamba_ssm.utils.generation import InferenceParams
+
+import torch
+import copy
+
+
+def test_state_seq():
+    """Check that Mamba([x0,x1,x2,x3,x4]) == Mamba([x0,x1]) |> step([x2,x3]) |> step([x4]),
+    i.e. chunked and token-by-token decoding match the full-sequence output."""
+    device = "cuda:0"
+    dim_model = 8
+
+    # Generate a model with random weights
+    m = Mamba(dim_model, layer_idx=7).to(device=device)
+    m.requires_grad_(False)  # allows deepcopy of tensors
+
+    # Generate the whole sequence
+    x_all = torch.rand(1, 5, dim_model, device=device)
+    y_all = m(x_all)
+
+    # Introducing empty inference parameters should not change the output
+    inference_all = InferenceParams(max_seqlen=16, max_batch_size=3)
+    y_with_inference = m(x_all, inference_params=inference_all)
+
+    assert len(inference_all.key_value_memory_dict)
+    assert torch.allclose(y_with_inference, y_all)
+
+    # Inference by parts
+    # X0, X1
+    inference_part = InferenceParams(
+        max_seqlen=inference_all.max_seqlen, max_batch_size=inference_all.max_batch_size)
+    y01 = m(x_all[:, 0:2], inference_params=inference_part)
+    assert torch.allclose(y_with_inference[:, :2], y01)
+
+    # (past state up to X1), X2, X3
+    inference_part.seqlen_offset = 2
+    inference_part_b = copy.deepcopy(inference_part)
+    y2 = m(x_all[:, 2:4], inference_params=inference_part)
+
+    # (past state up to X3), X4
+    inference_part.seqlen_offset = 4
+    y3 = m(x_all[:, 4:5], inference_params=inference_part)
+
+    # (past state up to X1), X2 again
+    inference_part_b.seqlen_offset = 2
+    y2_b = m(x_all[:, 2:3], inference_params=inference_part_b)
+    # (past state up to X2), X3 again
+    inference_part_b.seqlen_offset = 3
+    y3_b = m(x_all[:, 3:4], inference_params=inference_part_b)
+
+    # Values should match the result of inference over the whole sequence
+    assert torch.allclose(y_all[:, 0:2], y01)
+    assert torch.allclose(y_all[:, 2:4], y2)  # Decoding a 2-token chunk matches the full sequence
+    assert torch.allclose(y_all[:, 4:5], y3)
+    assert torch.allclose(y_all[:, 2:3], y2_b)
+    assert torch.allclose(y_all[:, 3:4], y3_b)
+
+    # Sanity check
+    assert not torch.allclose(y_all[:, 3:4], y2)
+
+
+def test_state_batch_drop_empty_infer():
+    """Check that you can drop a batch element when the inference params are empty"""
+    device = "cuda"
+    dim_model = 8
+
+    # Generate a model with random weights
+    m = Mamba(dim_model, layer_idx=7).to(device=device)
+    m.requires_grad_(False)  # allows deepcopy of tensors
+
+    x_all = torch.rand(3, 4, dim_model, device=device)
+    y_all = m(x_all)
+
+    # Introducing empty inference parameters should not change the output
+    inference_all = InferenceParams(max_seqlen=16, max_batch_size=3)
+    y_all = m(x_all, inference_params=inference_all)
+    kv = inference_all.key_value_memory_dict[7]
+
+    # Drop the batch element in the middle
+    x_02 = x_all[(0, 2), ...]
+    kv = tuple(batched[(0, 2), ...] for batched in kv)
+    inference_all.key_value_memory_dict[7] = kv
+
+    inference_02 = InferenceParams(max_seqlen=16, max_batch_size=3)
+    y_02 = m(x_02, inference_params=inference_02)
+    y_02_a = y_all[(0, 2), ...]
+    assert torch.allclose(y_02, y_02_a)
+
+
+def test_state_batch_drop_step():
+    """Check that you can drop a batch element when the inference params are filled"""
+
+    device = "cuda"
+    dim_model = 8
+
+    # Generate a model with random weights
+    m = Mamba(dim_model, layer_idx=7).to(device=device)
+    m.requires_grad_(False)  # allows deepcopy of tensors
+
+    x_prefix = torch.rand(3, 4, dim_model, device=device)
+
+    # Run the model forward so that the inference params hold state
+    inference_parms = InferenceParams(max_seqlen=16, max_batch_size=3)
+    _ = m(x_prefix, inference_params=inference_parms)
+
+    x_next = torch.rand(3, 1, dim_model, device=device)
+    inference_parms.seqlen_offset = x_prefix.shape[1]
+    inference_parms_bak = copy.deepcopy(inference_parms)
+
+    # Output with all 3 batch elements
+    y_next = m(x_next, inference_params=inference_parms)
+
+    # Remove the middle batch element from the cache
+    kv = inference_parms_bak.key_value_memory_dict[7]
+    kv = tuple(batched[(0, 2), ...] for batched in kv)
+    inference_parms_bak.key_value_memory_dict[7] = kv
+
+    # Recompute without the middle batch element
+    x_02 = x_next[(0, 2), ...]
+    y_next_parmed = m(x_02, inference_params=inference_parms_bak)
+
+    # Check that the batch element was dropped correctly
+    y_next_a = y_next[(0, 2), ...]
+    assert torch.allclose(y_next_a, y_next_parmed)
+
+    # Sanity check
+    assert not torch.allclose(y_next[(0, 1), ...], y_next_parmed)
+
+
+test_state_seq()
+# test_state_batch_drop_empty_infer()
+# test_state_batch_drop_step()
\ No newline at end of file
diff --git a/time_measure.py b/time_measure.py
new file mode 100644
index 00000000..3894282f
--- /dev/null
+++ b/time_measure.py
@@ -0,0 +1,64 @@
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from mamba_ssm.utils.generation import InferenceParams
+from mamba_ssm import Mamba
+import time
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+
+
+device = "cuda:5"
+x = 'hello world'
+x2 = 'hello world2'
+chunk = 'hello world ' * 300
+question = 'What is the meaning of life?'
+dtype=torch.bfloat16 +repeats = 5 +x3 = torch.rand(2, 64, 16, device=device) +torch.random.manual_seed(0) +model_path = 'state-spaces/mamba-2.8b' +model = MambaLMHeadModel.from_pretrained(model_path, device=device, dtype=dtype) +tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", device=device, dtype=dtype) +model.requires_grad_(False) +model.eval() +input = tokenizer(chunk, truncation=False, return_tensors="pt").to(device) +input2 = tokenizer(question, truncation=False, return_tensors="pt").to(device) +input_ids2 = torch.randint(high=32000, size=(20, 20), dtype=torch.long, device=device) + +times = 0 +max_length = input2['input_ids'].shape[1] + 100 +with torch.inference_mode(): + + for repeat in range(repeats): + #x = torch.randint(high=32000, size=(1, 10000), device=device) + inference_all = InferenceParams(max_seqlen=5000, max_batch_size=20) + inference_all.seqlen_offset = input['input_ids'].shape[1] + # y_all = model.generate(input_ids=input2['input_ids'], max_length=max_length, cg=True, + # return_dict_in_generate=True, output_scores=True) + y_all = model(input_ids=input['input_ids']) + t1 = time.time() + for i in range(input_ids2.shape[1]): #for i in range(input2['input_ids'].shape[1]): + inference_all.seqlen_offset = input_ids2.shape[1] + i + 1 #input['input_ids'].shape[1] + i + 1 + y_with_inference = model(input_ids=input_ids2[:, i:i+1], inference_params=inference_all) #model(input_ids=input2['input_ids'][:, i:i+1], inference_params=inference_all) + t2 = time.time() + inference_time = t2 - t1 + times += inference_time + print(f'{model_path}, Forward: inference time: {inference_time}') + times /= repeats + print(f"1: Average time is : {times}") + + times = 0 + for repeat in range(repeats): + #x = torch.randint(high=32000, size=(1, 10000), device=device) + inference_all = InferenceParams(max_seqlen=5000, max_batch_size=20) + inference_all.seqlen_offset = input['input_ids'].shape[1] + y_all = model(input_ids=input_ids2) + t1 = time.time() + y_with_inference = model(input_ids=input_ids2, inference_params=inference_all) #model(input_ids=input2['input_ids'], inference_params=inference_all) + t2 = time.time() + inference_time = t2 - t1 + times += inference_time + print(f'{model_path}, Forward: inference time: {inference_time}') + times /= repeats + print(f"2: Average time is : {times}")
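
A note on the timing methodology in time_measure.py above: CUDA kernel launches are asynchronous, so wrapping the forward calls in time.time() without synchronizing can misreport latency. A more robust measurement would call torch.cuda.synchronize() before starting and before stopping the timer, and run the prefill outside the timed region. The helper below is only an illustrative sketch built on the same objects the script uses (the model, InferenceParams, and tokenized input tensors); the function name and argument layout are made up for this example, and the chunked=True branch relies on the multi-token path added in patch 1.

import time
import torch
from mamba_ssm.utils.generation import InferenceParams

def timed_decode(model, prompt_ids, step_ids, chunked, repeats=5):
    """Mean wall-clock seconds to decode step_ids after a prompt, either token by
    token (chunked=False) or as a single chunk (chunked=True).
    prompt_ids and step_ids must share the same batch size."""
    total = 0.0
    with torch.inference_mode():
        for _ in range(repeats):
            params = InferenceParams(max_seqlen=5000, max_batch_size=step_ids.shape[0])
            model(input_ids=prompt_ids, inference_params=params)   # prefill, fills the caches
            params.seqlen_offset = prompt_ids.shape[1]
            torch.cuda.synchronize()
            t0 = time.time()
            if chunked:
                model(input_ids=step_ids, inference_params=params)
            else:
                for i in range(step_ids.shape[1]):
                    model(input_ids=step_ids[:, i:i + 1], inference_params=params)
                    params.seqlen_offset += 1
            torch.cuda.synchronize()                               # wait for all queued kernels
            total += time.time() - t0
    return total / repeats

Calling this twice with the same prompt and step tensors, once with chunked=False and once with chunked=True, then reflects the actual speedup of chunked decoding rather than kernel-launch asynchrony.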