diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 5b3ddfcf..a6f064a2 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -225,7 +225,7 @@ def tie_weights(self):
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, attention_mask=None, labels=None):
         """
         "position_ids" is just to be compatible with Transformer generation. We don't use it.
         num_last_tokens: if > 0, only return the logits for the last n tokens
@@ -234,8 +234,25 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
-        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
-        return CausalLMOutput(logits=lm_logits)
+
+        # Source: https://github.com/huggingface/transformers/blob/80377eb018c077dba434bc8e7912bcaed3a64d09/src/transformers/models/llama/modeling_llama.py#L1196
+        if labels is not None:
+            logits = lm_logits
+            vocab_size = logits.shape[-1]
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = torch.nn.CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+            return (loss,)
+        else:
+            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+            return CausalLMOutput(logits=lm_logits)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
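
A minimal usage sketch of the patched forward pass (not part of the patch itself): it exercises the new `labels` branch and the unchanged logits-only path. The `MambaConfig`/`MambaLMHeadModel` imports follow the stock mamba_ssm API; the toy sizes and tensors below are illustrative assumptions, and the Mamba kernels assume a CUDA device.

    import torch
    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    # Toy configuration; real checkpoints would come from MambaLMHeadModel.from_pretrained(...).
    config = MambaConfig(d_model=256, n_layer=2, vocab_size=1000)
    model = MambaLMHeadModel(config, device="cuda")

    input_ids = torch.randint(0, 1000, (2, 128), device="cuda")
    labels = input_ids.clone()  # standard causal-LM setup; the shift happens inside forward()

    # With labels: the patched forward returns a 1-tuple containing the cross-entropy loss.
    (loss,) = model(input_ids, labels=labels)
    loss.backward()

    # Without labels: behavior is unchanged, a namedtuple exposing .logits.
    logits = model(input_ids).logits  # shape (batch, seqlen, padded vocab size)

Note that `attention_mask` is accepted but unused here; it only keeps the signature compatible with Hugging Face `Trainer`-style callers, and padding tokens would still contribute to the loss unless masked out via the label values.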