From c02d610734bc0ef21ef8ce28ffe388ffc3210295 Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Fri, 31 Oct 2025 16:48:03 +0000 Subject: [PATCH 1/8] bump --- setup.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index d2f7fa514..217eaf649 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.52.4", + "transformers~=4.55.2", ] @@ -71,12 +71,7 @@ # packaging: "packaging" # # some of the values are versioned whereas others aren't. -deps = { - b: a - for a, b in ( - re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in _deps - ) -} +deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)} def deps_list(*pkgs): @@ -114,9 +109,7 @@ def deps_list(*pkgs): "torchvision", ) -extras["quality"] = deps_list( - "black", "datasets", "isort", "flake8", "GitPython" -) +extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython") extras["docs"] = deps_list( "docutils", From 3df88aeef5fc59430d546c80ee901da366fdd3d3 Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Sat, 1 Nov 2025 07:07:46 +0000 Subject: [PATCH 2/8] update branch --- .gitmodules | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitmodules b/.gitmodules index 2c1a30f22..49eb8dbbe 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,4 @@ [submodule "hf_transformers"] path = hf_transformers url = https://github.com/huggingface/transformers.git + branch = v4.55.2 From bf3f9fec347162c9bf00cb1608d56a95ca245e4e Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Sat, 1 Nov 2025 07:15:29 +0000 Subject: [PATCH 3/8] update --- hf_transformers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hf_transformers b/hf_transformers index 51f94ea06..a0bf5a82e 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit 51f94ea06d19a6308c61bbb4dc97c40aabd12bad +Subproject commit a0bf5a82eebf88ee9f52145be427f6f1541329f6 From 4bf687e4ad55ba9bc9f9be3461289a07d00a50b3 Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Sat, 1 Nov 2025 07:18:51 +0000 Subject: [PATCH 4/8] revert --- .gitmodules | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 49eb8dbbe..2c1a30f22 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,3 @@ [submodule "hf_transformers"] path = hf_transformers url = https://github.com/huggingface/transformers.git - branch = v4.55.2 From 637b2d4420c8a70f7bbff53e8a877695b0a3b1aa Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Sat, 1 Nov 2025 08:09:19 +0000 Subject: [PATCH 5/8] bump version --- hf_transformers | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hf_transformers b/hf_transformers index a0bf5a82e..0f3f8a1e0 160000 --- a/hf_transformers +++ b/hf_transformers @@ -1 +1 @@ -Subproject commit a0bf5a82eebf88ee9f52145be427f6f1541329f6 +Subproject commit 0f3f8a1e076198b9ab627f2f85fedfff428925ec diff --git a/setup.py b/setup.py index 217eaf649..8ef68a200 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.55.2", + "transformers~=4.57.1", ] From b4e01c5c4aa3a7836572e7fd2ae98a679216c68c Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Sun, 2 Nov 2025 07:58:17 +0000 Subject: [PATCH 6/8] mv to 4.56.0 --- hf_transformers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hf_transformers b/hf_transformers index 0f3f8a1e0..e7d351ceb 160000 --- a/hf_transformers +++ 
b/hf_transformers @@ -1 +1 @@ -Subproject commit 0f3f8a1e076198b9ab627f2f85fedfff428925ec +Subproject commit e7d351cebad5f6dcdd169b0c034fdee0a000e6a9 From e6db4a4758555b2289dc0588942ec5b132d0a355 Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Sun, 2 Nov 2025 07:58:30 +0000 Subject: [PATCH 7/8] bump minor --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8ef68a200..be07e44df 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "timeout-decorator", "torch", "torchvision", - "transformers~=4.57.1", + "transformers~=4.56.0", ] From a6cb1a361382ddf0abf23ef736ca4bc1d7915fbe Mon Sep 17 00:00:00 2001 From: Filipe Assuncao Date: Mon, 3 Nov 2025 08:00:14 +0000 Subject: [PATCH 8/8] partially fix roberta --- .../xlm_roberta/modeling_xlm_roberta.py | 184 +++++++++--------- 1 file changed, 95 insertions(+), 89 deletions(-) diff --git a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py index af981456f..3534ed64e 100644 --- a/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/adapters/models/xlm_roberta/modeling_xlm_roberta.py @@ -28,6 +28,8 @@ XLMRobertaSelfAttention, XLMRobertaSelfOutput, ) +from transformers.cache_utils import Cache, EncoderDecoderCache +from transformers.utils.deprecation import deprecate_kwarg from transformers.utils import logging from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel @@ -39,6 +41,7 @@ # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->XLMRoberta +@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") class XLMRobertaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, XLMRobertaSelfAttention): def forward( self, @@ -46,54 +49,58 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - attention_mask = prefix_attention_mask(attention_mask) # type: ignore - - mixed_query_layer = self.query(hidden_states) - - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. 
+ cache_position: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor]: + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True - query_layer = self.transpose_for_scores(mixed_query_layer) # >>> START AH Changes <<< query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) # >>> END AH Changes <<< - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) - + # TODO - what to do with this? # >>> START AH Changes <<< key_layer, value_layer, attention_mask = self.prefix_tuning( key_layer, value_layer, hidden_states, attention_mask @@ -106,7 +113,7 @@ def forward( if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -148,23 +155,20 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class XLMRobertaSdpaSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, XLMRobertaSdpaSelfAttention): + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor]: # >>> START AH Changes <<< attention_mask = prefix_attention_mask(attention_mask, [2, 3]) # type: ignore @@ -184,46 +188,67 @@ def forward( attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) - # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention - # mask needs to be such that the encoder's padding tokens are not attended to. 
is_cross_attention = encoder_hidden_states is not None - current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True # >>> START AH Changes <<< query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer) (attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask) # >>> END AH Changes <<< - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. + is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 # >>> START AH Changes <<< key_layer, value_layer, attention_mask = self.prefix_tuning( @@ -233,22 +258,6 @@ def forward( bsz = query_layer.size(0) # >>> END AH Changes <<< - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create - # a causal mask in case tgt_len == 1. - is_causal = ( - True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False - ) - attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, @@ -261,10 +270,7 @@ def forward( attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta
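For context on patch 8: the adapter attention overrides move from the legacy tuple-based `past_key_value` handling (manual `torch.cat` of cached key/value states) to the `Cache` / `EncoderDecoderCache` API that transformers uses from the pinned 4.56 line onward. The following is a minimal sketch of that cache flow, assuming transformers >= 4.56 with `DynamicCache` available; shapes are illustrative only and this is not code from the PR.

    # Sketch of the Cache API that patch 8 migrates to (assumes transformers>=4.56).
    import torch
    from transformers import DynamicCache, EncoderDecoderCache

    batch_size, num_heads, head_dim, layer_idx = 1, 12, 64, 0
    cache = DynamicCache()

    # Prefill: store the key/value states of a 5-token prompt for layer 0.
    k = torch.randn(batch_size, num_heads, 5, head_dim)
    v = torch.randn(batch_size, num_heads, 5, head_dim)
    keys, values = cache.update(k, v, layer_idx)
    print(keys.shape)  # torch.Size([1, 12, 5, 64])

    # Decoding step: update() appends along the sequence dimension, replacing the
    # removed torch.cat([past_key_value[0], key_layer], dim=2) pattern.
    k_new = torch.randn(batch_size, num_heads, 1, head_dim)
    v_new = torch.randn(batch_size, num_heads, 1, head_dim)
    keys, values = cache.update(k_new, v_new, layer_idx)
    print(keys.shape)                          # torch.Size([1, 12, 6, 64])
    print(cache.layers[layer_idx].keys.shape)  # same tensors, accessed as in the patched code

    # Encoder-decoder models wrap a self-attention and a cross-attention cache;
    # is_updated tracks whether a layer's cross-attention K/V can be reused.
    enc_dec_cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
    print(enc_dec_cache.is_updated)            # empty until layers mark themselves updated

In the patched forward methods, `curr_past_key_value.update(...)` plays the role of the `cache.update(...)` calls above, and the `is_updated` flag is what allows cross-attention key/value states to be computed once and reused on subsequent decoding steps.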