From a6218ff8a313b9085eb0a0853f2ea28aaeaf6af2 Mon Sep 17 00:00:00 2001 From: Nika Kudukhashvili Date: Thu, 24 Apr 2025 23:25:33 +0400 Subject: [PATCH 1/2] AvaForCausalLM --- .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 +- src/transformers/models/ava/__init__.py | 28 + .../models/ava/configuration_ava.py | 271 ++++++++ src/transformers/models/ava/modeling_ava.py | 602 ++++++++++++++++++ 5 files changed, 905 insertions(+), 1 deletion(-) create mode 100644 src/transformers/models/ava/__init__.py create mode 100644 src/transformers/models/ava/configuration_ava.py create mode 100644 src/transformers/models/ava/modeling_ava.py diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7d13bc788d46..f9591d39c9b2 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -39,6 +39,7 @@ ("aria_text", "AriaTextConfig"), ("audio-spectrogram-transformer", "ASTConfig"), ("autoformer", "AutoformerConfig"), + ("ava", "AvaConfig"), ("aya_vision", "AyaVisionConfig"), ("bamba", "BambaConfig"), ("bark", "BarkConfig"), @@ -383,6 +384,7 @@ ("aria_text", "AriaText"), ("audio-spectrogram-transformer", "Audio Spectrogram Transformer"), ("autoformer", "Autoformer"), + ("ava", "AVA"), ("aya_vision", "AyaVision"), ("bamba", "Bamba"), ("bark", "Bark"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index a7271d04f607..ae671c58fe73 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -26,7 +26,6 @@ ) from .configuration_auto import CONFIG_MAPPING_NAMES - logger = logging.get_logger(__name__) MODEL_MAPPING_NAMES = OrderedDict( @@ -39,6 +38,7 @@ ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), ("autoformer", "AutoformerModel"), + ("ava", "AvaModel"), ("bamba", "BambaModel"), ("bark", "BarkModel"), ("bart", "BartModel"), @@ -508,6 +508,7 @@ [ # Model for Causal LM mapping ("aria_text", "AriaTextForCausalLM"), + ("ava", "AvaForCausalLM"), ("bamba", "BambaForCausalLM"), ("bart", "BartForCausalLM"), ("bert", "BertLMHeadModel"), diff --git a/src/transformers/models/ava/__init__.py b/src/transformers/models/ava/__init__.py new file mode 100644 index 000000000000..1a98303e00c8 --- /dev/null +++ b/src/transformers/models/ava/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 Nika Kudukhashvili . All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_ava import * + from .modeling_ava import * +else: + import sys + + _file = globals()['__file__'] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/ava/configuration_ava.py b/src/transformers/models/ava/configuration_ava.py new file mode 100644 index 000000000000..d27fcb6211e4 --- /dev/null +++ b/src/transformers/models/ava/configuration_ava.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Nika Kudukhashvili . All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + +logger = logging.get_logger(__name__) + +class AvaConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`AvaModel`]. It is used to instantiate an AVA model + according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the AVA model. Defines the number of different tokens that can be represented. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 16): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer. + rms_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the RMS normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the beginning-of-sequence token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the end-of-sequence token. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + kv_heads (`int`, *optional*): + Number of key/value heads (for Grouped Query Attention). Defaults to num_attention_heads. + head_dim (`int`, *optional*): + The dimension of each attention head. Defaults to hidden_size // num_attention_heads. + """ + + model_type = "ava" + PREDEFINED_MODELS = { + # Tiny models (Edge devices, IoT, offline agents, chatbots) + '100m': { + 'hidden_size': 768, + 'intermediate_size': 3072, + 'num_hidden_layers': 6, + 'num_attention_heads': 12, + 'max_position_embeddings': 2048, + 'head_dim': 64, + 'kv_heads': 4 + }, + '500m': { + 'hidden_size': 1024, + 'intermediate_size': 4096, + 'num_hidden_layers': 8, + 'num_attention_heads': 16, + 'max_position_embeddings': 2048, + 'head_dim': 64, + 'kv_heads': 4 + }, + # Small models (Mobile apps, personal assistants, summarization) + '1b': { + 'hidden_size': 1280, + 'intermediate_size': 5120, + 'num_hidden_layers': 12, + 'num_attention_heads': 16, + 'max_position_embeddings': 4096, + 'head_dim': 80, + 'kv_heads': 8 + }, + '3b': { + 'hidden_size': 1600, + 'intermediate_size': 6400, + 'num_hidden_layers': 24, + 'num_attention_heads': 16, + 'max_position_embeddings': 4096, + 'head_dim': 100, + 'kv_heads': 8 + }, + # Medium models (Coding, reasoning, multi-turn chat, translation) + '7b': { + 'hidden_size': 4096, + 'intermediate_size': 11008, + 'num_hidden_layers': 32, + 'num_attention_heads': 32, + 'max_position_embeddings': 8192, + 'head_dim': 128, + 'kv_heads': 8 + }, + '13b': { + 'hidden_size': 5120, + 'intermediate_size': 13824, + 'num_hidden_layers': 40, + 'num_attention_heads': 40, + 'max_position_embeddings': 8192, + 'head_dim': 128, + 'kv_heads': 8 + }, + # Large models (Research, enterprise-level applications) + '30b': { + 'hidden_size': 6656, + 'intermediate_size': 17920, + 'num_hidden_layers': 60, + 'num_attention_heads': 52, + 'max_position_embeddings': 8192, + 'head_dim': 128, + 'kv_heads': 8 + }, + '65b': { + 'hidden_size': 8192, + 'intermediate_size': 22016, + 'num_hidden_layers': 80, + 'num_attention_heads': 64, + 'max_position_embeddings': 8192, + 'head_dim': 128, + 'kv_heads': 8 + }, + # Massive models (AGI research, cutting-edge LLMs) + '100b': { + 'hidden_size': 12288, + 'intermediate_size': 33024, + 'num_hidden_layers': 96, + 'num_attention_heads': 96, + 'max_position_embeddings': 16384, + 'head_dim': 128, + 'kv_heads': 8 + } + } + + def __init__( + self, + vocab_size=32000, + hidden_size=2048, + intermediate_size=8192, + num_hidden_layers=16, + num_attention_heads=16, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + attention_dropout=0.0, + kv_heads=None, + head_dim=None, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + 
self.kv_heads = kv_heads if kv_heads is not None else num_attention_heads + self.head_dim = head_dim if head_dim is not None else hidden_size // num_attention_heads + + # Validate parameters + self._validate_config() + + def _validate_config(self): + """Validate the configuration parameters""" + if self.vocab_size <= 0: + raise ValueError(f"vocab_size must be positive, got {self.vocab_size}") + + if self.hidden_size <= 0: + raise ValueError(f"hidden_size must be positive, got {self.hidden_size}") + + if self.num_attention_heads <= 0: + raise ValueError(f"num_attention_heads must be positive, got {self.num_attention_heads}") + + if self.head_dim <= 0: + raise ValueError(f"head_dim must be positive, got {self.head_dim}") + + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_attention_heads, got {self.hidden_size} and {self.num_attention_heads}" + ) + + if self.kv_heads <= 0: + raise ValueError(f"kv_heads must be positive, got {self.kv_heads}") + + if self.num_attention_heads % self.kv_heads != 0: + raise ValueError( + f"num_attention_heads must be divisible by kv_heads, got {self.num_attention_heads} and {self.kv_heads}" + ) + + @classmethod + def from_predefined(cls, model_size="7b", **kwargs): + """ + Instantiate a config from a predefined model architecture. + + Args: + model_size (`str`): + One of the predefined model sizes (e.g., '100m', '500m', '1b', '3b', '7b', '13b', '30b', '65b', '100b') + **kwargs: + Additional arguments to override the predefined config + + Returns: + AvaConfig: The configuration object + """ + if model_size not in cls.PREDEFINED_MODELS: + raise ValueError( + f"Unknown model size '{model_size}'. Available sizes: {list(cls.PREDEFINED_MODELS.keys())}" + ) + + config_dict = cls.PREDEFINED_MODELS[model_size].copy() + config_dict.update(kwargs) + + return cls(**config_dict) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """ + Instantiate a config from a pretrained model name or path. + + This method handles both: + - Actual pretrained model paths (files/directories) + - Predefined model shortcuts (e.g., "ava/7b") + """ + if isinstance(pretrained_model_name_or_path, str): + if pretrained_model_name_or_path.startswith("ava/"): + model_size = pretrained_model_name_or_path.split("/")[-1] + return cls.from_predefined(model_size, **kwargs) + + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) \ No newline at end of file diff --git a/src/transformers/models/ava/modeling_ava.py b/src/transformers/models/ava/modeling_ava.py new file mode 100644 index 000000000000..5fa0d167e945 --- /dev/null +++ b/src/transformers/models/ava/modeling_ava.py @@ -0,0 +1,602 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Nika Kudukhashvili . All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from .configuration_ava import AvaConfig + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + + return torch.cat([-x2, x1], dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin): + cos = cos[:, :, :q.shape[2], :] + sin = sin[:, :, :q.shape[2], :] + + q_embed = q * cos + rotate_half(q) * sin + k_embed = k * cos + rotate_half(k) * sin + + return q_embed, k_embed + + +class AvaAttention(nn.Module): + def __init__(self, config: AvaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_heads = getattr(config, "kv_heads", self.num_heads) + self.kv_dim = self.head_dim * self.kv_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.kv_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.kv_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + self.dropout = nn.Dropout(config.attention_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + rotary_emb=None, + position_ids=None + ): + B, T, _ = hidden_states.shape + + query = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2) # (B, nh, T, hs) + key = self.k_proj(hidden_states).view(B, T, self.kv_heads, self.head_dim).transpose(1, 2) # (B, kvh, T, hs) + value = self.v_proj(hidden_states).view(B, T, self.kv_heads, self.head_dim).transpose(1, 2) # (B, kvh, T, hs) + + if past_key_value is not None: + key = torch.cat([past_key_value[0], key], dim=2) + value = torch.cat([past_key_value[1], value], dim=2) + + past_key_value = (key, value) if use_cache else None + + if rotary_emb is not None: + cos, sin = rotary_emb(query, seq_len=query.shape[2]) + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if self.kv_heads != self.num_heads: + repeat_factor = self.num_heads // self.kv_heads + key = key.repeat_interleave(repeat_factor, dim=1) + value = value.repeat_interleave(repeat_factor, dim=1) + + attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_scores += attention_mask + + attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype) + attn_probs = self.dropout(attn_probs) + + context = torch.matmul(attn_probs, value) + context = context.transpose(1, 2).contiguous().view(B, T, self.hidden_size) + + output = self.o_proj(context) + + if output_attentions: + return output, past_key_value, attn_probs + + return output, past_key_value + +class AvaRMSNorm(nn.Module): + def __init__( + self, + hidden_size: int, + epsilon: float = 1e-5 + ): + + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.epsilon = epsilon + + def forward(self, hidden_states): + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) + + return self.weight * hidden_states.to(hidden_states.dtype) + +class AvaRotaryEmbedding(nn.Module): + """Rotary Position Embeddings""" + + def __init__(self, + dim, + max_position_embeddings = 2048, + base = 10000.0 + ): + + super().__init__() + + self.dim = dim + 
self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = max_position_embeddings + + t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False) + self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + + t = torch.arange( + self.max_seq_len_cached, + device = x.device, + dtype = self.inv_freq.dtype + ) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False) + self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), + ) + +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class AvaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + self.gate_proj = nn.Linear( + self.hidden_size, + self.intermediate_size, + bias = False + ) + + self.up_proj = nn.Linear( + self.hidden_size, + self.intermediate_size, + bias = False + ) + + self.down_proj = nn.Linear( + self.intermediate_size, + self.hidden_size, + bias = False + ) + + self.act_fn = SiLU() + + def forward(self, x): + return self.down_proj( + self.act_fn(self.gate_proj(x)) * self.up_proj(x) + ) + +class AvaDecoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = AvaAttention(config) + + self.mlp = AvaMLP(config) + self.input_layernorm = AvaRMSNorm( + config.hidden_size, + epsilon = config.rms_norm_eps + ) + + self.post_attention_layernorm = AvaRMSNorm( + config.hidden_size, + epsilon = config.rms_norm_eps + ) + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + output_attentions=False, + use_cache=False, + rotary_emb=None, + ): + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + attn_outputs = self.self_attn( + hidden_states = hidden_states, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_value = past_key_value, + output_attentions = output_attentions, + use_cache = use_cache, + rotary_emb = rotary_emb, + ) + + hidden_states = attn_outputs[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (attn_outputs[1],) + + if output_attentions: + outputs += (attn_outputs[2],) + + return outputs + +class AvaModel(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = 
nn.ModuleList([AvaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = AvaRMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + self.rotary_emb = AvaRotaryEmbedding( + config.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + + + self.apply(self._init_weights) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + + def forward( + self, + input_ids, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + batch_size, seq_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] + + if position_ids is None: + position_ids = torch.arange( + seq_length, + dtype = torch.long, + device = input_ids.device + ).unsqueeze(0) + + past_length = 0 + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + position_ids = position_ids[:, past_length:] + + + if attention_mask is not None: + causal_mask = torch.triu( + torch.full((seq_length, seq_length), -float('inf'), device=attention_mask.device), + diagonal=1, + ) + + expanded_attn_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * torch.finfo(torch.float32).min + expanded_attn_mask = expanded_attn_mask + causal_mask.unsqueeze(0) + else: + causal_mask = torch.triu( + torch.full( + (seq_length, seq_length), + -float('inf'), + device = input_ids.device + ), + + diagonal=1, + ) + + expanded_attn_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1) + + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_past_key_values = () if use_cache else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer( + hidden_states, + attention_mask=expanded_attn_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + rotary_emb=self.rotary_emb, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_past_key_values += (layer_outputs[1],) + + if output_attentions: + all_self_attns += (layer_outputs[2],) + + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + return { + 'last_hidden_state': hidden_states, + 'past_key_values': next_past_key_values, + 'hidden_states': all_hidden_states, + 'attentions': all_self_attns, + } + +class AvaForCausalLM(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.model = AvaModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.apply(self._init_weights) + + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + + if module.bias is not None: + module.bias.data.zero_() + + 
elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output = self.model( + input_ids = input_ids, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + inputs_embeds = inputs_embeds, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + return_dict = return_dict, + ) + + hidden_states = output['last_hidden_state'] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + + + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + return { + 'loss': loss, + 'logits': logits, + 'past_key_values': output.get('past_key_values', None), + 'hidden_states': output.get('hidden_states', None), + 'attentions': output.get('attentions', None), + } + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs): + input_shape = input_ids.shape + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache', True), + } + + @torch.no_grad() + def generate( + self, + input_ids, + attention_mask=None, + max_length=None, + temperature=1.0, + top_k=50, + top_p=0.9, + repetition_penalty=1.0, + do_sample=True, + num_return_sequences=1, + pad_token_id=None, + eos_token_id=None, + use_cache=True, + streamer=None, + early_stopping=True, + ): + """ + Improved generate method with streamer support and better caching + """ + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id + max_length = max_length if max_length is not None else self.config.max_position_embeddings + batch_size = input_ids.shape[0] + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + input_ids_seq_length = input_ids.shape[-1] + generated_tokens = input_ids.clone() + cached_position_ids = torch.arange(input_ids_seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, -1) + past_key_values = None + + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + self.eval() + + for current_length in range(input_ids_seq_length, max_length): + if past_key_values is not None: + inputs = generated_tokens[:, -1].unsqueeze(-1) + else: + inputs = generated_tokens + + if past_key_values is not None: + position_ids = cached_position_ids[:, -1].unsqueeze(-1) + 1 + cached_position_ids = 
torch.cat([cached_position_ids, position_ids], dim=-1) + else: + position_ids = cached_position_ids + + if attention_mask is not None and past_key_values is not None: + attention_mask = torch.cat([ + attention_mask, + unfinished_sequences.unsqueeze(-1) + ], dim = -1) + + model_inputs = self.prepare_inputs_for_generation( + inputs, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + ) + + outputs = self.forward(**model_inputs) + next_token_logits = outputs['logits'][:, -1, :] + past_key_values = outputs['past_key_values'] + next_token_logits = next_token_logits / temperature + + if repetition_penalty != 1.0: + for i in range(batch_size): + for previous_token in generated_tokens[i]: + if previous_token in [pad_token_id, eos_token_id]: + continue + + next_token_logits[i, previous_token] /= repetition_penalty + + if top_k > 0: + top_k_values, top_k_indices = torch.topk(next_token_logits, top_k) + next_token_logits = torch.full_like(next_token_logits, float('-inf')) + next_token_logits.scatter_(1, top_k_indices, top_k_values) + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + for i in range(batch_size): + indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] + next_token_logits[i, indices_to_remove] = float('-inf') + + if do_sample: + probs = F.softmax(next_token_logits, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + + else: + next_tokens = torch.argmax(next_token_logits, dim=-1) + + if eos_token_id is not None: + next_tokens = next_tokens * unfinished_sequences + eos_token_id * (1 - unfinished_sequences) + + generated_tokens = torch.cat([generated_tokens, next_tokens.unsqueeze(-1)], dim=-1) + + if eos_token_id is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) + + if streamer is not None: + streamer.put(next_tokens.unsqueeze(-1)) + + if unfinished_sequences.max() == 0 or (early_stopping and current_length > input_ids_seq_length + 50): + break + + if streamer is not None: + streamer.end() + + return generated_tokens From 069bc442347acbfbddd603f3b8d9e91292359547 Mon Sep 17 00:00:00 2001 From: Nika Kudukhashvili Date: Thu, 24 Apr 2025 23:58:24 +0400 Subject: [PATCH 2/2] Fixed 'Blank line contains whitespace' --- src/transformers/models/ava/__init__.py | 1 - .../models/ava/configuration_ava.py | 27 +- src/transformers/models/ava/modeling_ava.py | 233 +++++++++--------- 3 files changed, 128 insertions(+), 133 deletions(-) diff --git a/src/transformers/models/ava/__init__.py b/src/transformers/models/ava/__init__.py index 1a98303e00c8..aba96fd17ca9 100644 --- a/src/transformers/models/ava/__init__.py +++ b/src/transformers/models/ava/__init__.py @@ -17,7 +17,6 @@ from ...utils import _LazyModule from ...utils.import_utils import define_import_structure - if TYPE_CHECKING: from .configuration_ava import * from .modeling_ava import * diff --git a/src/transformers/models/ava/configuration_ava.py b/src/transformers/models/ava/configuration_ava.py index d27fcb6211e4..c28c9f2d4b29 100644 --- a/src/transformers/models/ava/configuration_ava.py +++ b/src/transformers/models/ava/configuration_ava.py @@ -14,9 +14,6 @@ # limitations under the License. 
from ...configuration_utils import PretrainedConfig -from ...utils import logging - -logger = logging.get_logger(__name__) class AvaConfig(PretrainedConfig): """ @@ -207,24 +204,24 @@ def _validate_config(self): """Validate the configuration parameters""" if self.vocab_size <= 0: raise ValueError(f"vocab_size must be positive, got {self.vocab_size}") - + if self.hidden_size <= 0: raise ValueError(f"hidden_size must be positive, got {self.hidden_size}") - + if self.num_attention_heads <= 0: raise ValueError(f"num_attention_heads must be positive, got {self.num_attention_heads}") - + if self.head_dim <= 0: raise ValueError(f"head_dim must be positive, got {self.head_dim}") - + if self.hidden_size % self.num_attention_heads != 0: raise ValueError( f"hidden_size must be divisible by num_attention_heads, got {self.hidden_size} and {self.num_attention_heads}" ) - + if self.kv_heads <= 0: raise ValueError(f"kv_heads must be positive, got {self.kv_heads}") - + if self.num_attention_heads % self.kv_heads != 0: raise ValueError( f"num_attention_heads must be divisible by kv_heads, got {self.num_attention_heads} and {self.kv_heads}" @@ -234,13 +231,13 @@ def _validate_config(self): def from_predefined(cls, model_size="7b", **kwargs): """ Instantiate a config from a predefined model architecture. - + Args: model_size (`str`): One of the predefined model sizes (e.g., '100m', '500m', '1b', '3b', '7b', '13b', '30b', '65b', '100b') **kwargs: Additional arguments to override the predefined config - + Returns: AvaConfig: The configuration object """ @@ -248,17 +245,17 @@ def from_predefined(cls, model_size="7b", **kwargs): raise ValueError( f"Unknown model size '{model_size}'. Available sizes: {list(cls.PREDEFINED_MODELS.keys())}" ) - + config_dict = cls.PREDEFINED_MODELS[model_size].copy() config_dict.update(kwargs) - + return cls(**config_dict) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): """ Instantiate a config from a pretrained model name or path. 
- + This method handles both: - Actual pretrained model paths (files/directories) - Predefined model shortcuts (e.g., "ava/7b") @@ -267,5 +264,5 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): if pretrained_model_name_or_path.startswith("ava/"): model_size = pretrained_model_name_or_path.split("/")[-1] return cls.from_predefined(model_size, **kwargs) - + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) \ No newline at end of file diff --git a/src/transformers/models/ava/modeling_ava.py b/src/transformers/models/ava/modeling_ava.py index 5fa0d167e945..6dbe3e5ecfc3 100644 --- a/src/transformers/models/ava/modeling_ava.py +++ b/src/transformers/models/ava/modeling_ava.py @@ -23,9 +23,8 @@ def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] - - return torch.cat([-x2, x1], dim=-1) + return torch.cat([-x2, x1], dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): cos = cos[:, :, :q.shape[2], :] @@ -36,7 +35,6 @@ def apply_rotary_pos_emb(q, k, cos, sin): return q_embed, k_embed - class AvaAttention(nn.Module): def __init__(self, config: AvaConfig): super().__init__() @@ -84,7 +82,7 @@ def forward( key = key.repeat_interleave(repeat_factor, dim=1) value = value.repeat_interleave(repeat_factor, dim=1) - attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim) if attention_mask is not None: attn_scores += attention_mask @@ -92,23 +90,23 @@ def forward( attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype) attn_probs = self.dropout(attn_probs) - context = torch.matmul(attn_probs, value) - context = context.transpose(1, 2).contiguous().view(B, T, self.hidden_size) + context = torch.matmul(attn_probs, value) + context = context.transpose(1, 2).contiguous().view(B, T, self.hidden_size) output = self.o_proj(context) if output_attentions: return output, past_key_value, attn_probs - + return output, past_key_value - + class AvaRMSNorm(nn.Module): def __init__( - self, - hidden_size: int, + self, + hidden_size: int, epsilon: float = 1e-5 ): - + super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.epsilon = epsilon @@ -119,22 +117,22 @@ def forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) return self.weight * hidden_states.to(hidden_states.dtype) - + class AvaRotaryEmbedding(nn.Module): """Rotary Position Embeddings""" - def __init__(self, - dim, - max_position_embeddings = 2048, + def __init__(self, + dim, + max_position_embeddings = 2048, base = 10000.0 ): super().__init__() - + self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) self.max_seq_len_cached = max_position_embeddings @@ -151,22 +149,22 @@ def forward(self, x, seq_len=None): self.max_seq_len_cached = seq_len t = torch.arange( - self.max_seq_len_cached, - device = x.device, + self.max_seq_len_cached, + device = x.device, dtype = self.inv_freq.dtype ) - + freqs = torch.einsum('i,j->ij', t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.register_buffer('cos_cached', emb.cos()[None, None, :, :], persistent=False) self.register_buffer('sin_cached', emb.sin()[None, None, :, :], persistent=False) - + return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype, device=x.device), self.sin_cached[:, 
:, :seq_len, ...].to(dtype=x.dtype, device=x.device), ) - + class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) @@ -178,32 +176,32 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - + self.gate_proj = nn.Linear( - self.hidden_size, - self.intermediate_size, + self.hidden_size, + self.intermediate_size, bias = False ) - + self.up_proj = nn.Linear( - self.hidden_size, - self.intermediate_size, + self.hidden_size, + self.intermediate_size, bias = False ) - + self.down_proj = nn.Linear( - self.intermediate_size, - self.hidden_size, + self.intermediate_size, + self.hidden_size, bias = False ) - + self.act_fn = SiLU() - + def forward(self, x): return self.down_proj( self.act_fn(self.gate_proj(x)) * self.up_proj(x) ) - + class AvaDecoderLayer(nn.Module): def __init__(self, config): super().__init__() @@ -213,15 +211,15 @@ def __init__(self, config): self.mlp = AvaMLP(config) self.input_layernorm = AvaRMSNorm( - config.hidden_size, + config.hidden_size, epsilon = config.rms_norm_eps ) - + self.post_attention_layernorm = AvaRMSNorm( - config.hidden_size, + config.hidden_size, epsilon = config.rms_norm_eps ) - + def forward( self, hidden_states, @@ -232,10 +230,10 @@ def forward( use_cache=False, rotary_emb=None, ): - + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - + attn_outputs = self.self_attn( hidden_states = hidden_states, attention_mask = attention_mask, @@ -245,23 +243,23 @@ def forward( use_cache = use_cache, rotary_emb = rotary_emb, ) - + hidden_states = attn_outputs[0] hidden_states = residual + hidden_states - + residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - + outputs = (hidden_states,) - + if use_cache: - outputs += (attn_outputs[1],) - + outputs += (attn_outputs[1],) + if output_attentions: - outputs += (attn_outputs[2],) - + outputs += (attn_outputs[2],) + return outputs class AvaModel(nn.Module): @@ -269,7 +267,7 @@ def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size - + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([AvaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = AvaRMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) @@ -278,21 +276,21 @@ def __init__(self, config): max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) - - + + self.apply(self._init_weights) - + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) - + if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - + def forward( self, input_ids, @@ -306,56 +304,56 @@ def forward( return_dict=None, ): batch_size, seq_length = input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] - + if position_ids is None: position_ids = torch.arange( - seq_length, - dtype = torch.long, + seq_length, + dtype = torch.long, device = input_ids.device ).unsqueeze(0) - + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] position_ids = position_ids[:, past_length:] - - + + if attention_mask is not None: causal_mask = torch.triu( torch.full((seq_length, seq_length), -float('inf'), 
device=attention_mask.device), diagonal=1, ) - + expanded_attn_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * torch.finfo(torch.float32).min expanded_attn_mask = expanded_attn_mask + causal_mask.unsqueeze(0) else: causal_mask = torch.triu( torch.full( - (seq_length, seq_length), - -float('inf'), + (seq_length, seq_length), + -float('inf'), device = input_ids.device ), - + diagonal=1, ) expanded_attn_mask = causal_mask.unsqueeze(0).expand(batch_size, -1, -1) - - + + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - + hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_past_key_values = () if use_cache else None - + for i, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - + past_key_value = past_key_values[i] if past_key_values is not None else None - + layer_outputs = layer( hidden_states, attention_mask=expanded_attn_mask, @@ -365,28 +363,28 @@ def forward( use_cache=use_cache, rotary_emb=self.rotary_emb, ) - + hidden_states = layer_outputs[0] - + if use_cache: next_past_key_values += (layer_outputs[1],) - + if output_attentions: all_self_attns += (layer_outputs[2],) - - + + hidden_states = self.norm(hidden_states) - + if output_hidden_states: all_hidden_states += (hidden_states,) - + return { 'last_hidden_state': hidden_states, 'past_key_values': next_past_key_values, 'hidden_states': all_hidden_states, 'attentions': all_self_attns, } - + class AvaForCausalLM(nn.Module): def __init__(self, config): super().__init__() @@ -395,21 +393,21 @@ def __init__(self, config): self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.apply(self._init_weights) - + if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight - + def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) - + if module.bias is not None: module.bias.data.zero_() - + elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - + def forward( self, input_ids=None, @@ -434,24 +432,24 @@ def forward( output_hidden_states = output_hidden_states, return_dict = return_dict, ) - + hidden_states = output['last_hidden_state'] logits = self.lm_head(hidden_states) - + loss = None if labels is not None: - + shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() - + loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id) shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - - + + shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) - + return { 'loss': loss, 'logits': logits, @@ -459,21 +457,21 @@ def forward( 'hidden_states': output.get('hidden_states', None), 'attentions': output.get('attentions', None), } - + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs): input_shape = input_ids.shape if past_key_values is not None: input_ids = input_ids[:, -1].unsqueeze(-1) - - + + position_ids = kwargs.get('position_ids', None) if attention_mask is not None and position_ids is None: - + position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values is not None: position_ids = position_ids[:, -1].unsqueeze(-1) - + return { 'input_ids': input_ids, 'attention_mask': 
attention_mask, @@ -481,7 +479,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attenti 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), } - + @torch.no_grad() def generate( self, @@ -503,40 +501,41 @@ def generate( """ Improved generate method with streamer support and better caching """ + pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id max_length = max_length if max_length is not None else self.config.max_position_embeddings batch_size = input_ids.shape[0] - + if attention_mask is None: attention_mask = torch.ones_like(input_ids) - + input_ids_seq_length = input_ids.shape[-1] generated_tokens = input_ids.clone() cached_position_ids = torch.arange(input_ids_seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, -1) past_key_values = None - + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) self.eval() - + for current_length in range(input_ids_seq_length, max_length): if past_key_values is not None: inputs = generated_tokens[:, -1].unsqueeze(-1) else: inputs = generated_tokens - + if past_key_values is not None: position_ids = cached_position_ids[:, -1].unsqueeze(-1) + 1 cached_position_ids = torch.cat([cached_position_ids, position_ids], dim=-1) else: position_ids = cached_position_ids - + if attention_mask is not None and past_key_values is not None: attention_mask = torch.cat([ - attention_mask, + attention_mask, unfinished_sequences.unsqueeze(-1) ], dim = -1) - + model_inputs = self.prepare_inputs_for_generation( inputs, past_key_values=past_key_values, @@ -544,12 +543,12 @@ def generate( position_ids=position_ids, use_cache=use_cache, ) - + outputs = self.forward(**model_inputs) next_token_logits = outputs['logits'][:, -1, :] past_key_values = outputs['past_key_values'] next_token_logits = next_token_logits / temperature - + if repetition_penalty != 1.0: for i in range(batch_size): for previous_token in generated_tokens[i]: @@ -557,12 +556,12 @@ def generate( continue next_token_logits[i, previous_token] /= repetition_penalty - + if top_k > 0: top_k_values, top_k_indices = torch.topk(next_token_logits, top_k) next_token_logits = torch.full_like(next_token_logits, float('-inf')) next_token_logits.scatter_(1, top_k_indices, top_k_values) - + if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) @@ -570,33 +569,33 @@ def generate( sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 - + for i in range(batch_size): indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] next_token_logits[i, indices_to_remove] = float('-inf') - + if do_sample: probs = F.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_logits, dim=-1) - + if eos_token_id is not None: next_tokens = next_tokens * unfinished_sequences + eos_token_id * (1 - unfinished_sequences) - + generated_tokens = torch.cat([generated_tokens, next_tokens.unsqueeze(-1)], dim=-1) - + if eos_token_id is not None: unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long()) - + if streamer is not None: streamer.put(next_tokens.unsqueeze(-1)) - + if 
unfinished_sequences.max() == 0 or (early_stopping and current_length > input_ids_seq_length + 50): break - + if streamer is not None: streamer.end() - + return generated_tokens
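
For review purposes, a minimal end-to-end sketch of the API this patch adds, assuming the patch is applied and the `transformers.models.ava` package is importable. The weights are randomly initialized, so the outputs are placeholders, and the dummy token ids stand in for a real tokenizer, which this patch does not include.

import torch
from transformers.models.ava.configuration_ava import AvaConfig
from transformers.models.ava.modeling_ava import AvaForCausalLM

# Build the smallest predefined architecture; any field can be overridden via kwargs.
config = AvaConfig.from_predefined("100m")
model = AvaForCausalLM(config)

# Dummy prompt: ids are kept above the special tokens (pad=0, bos=1, eos=2).
input_ids = torch.randint(4, config.vocab_size, (1, 16))
attention_mask = torch.ones_like(input_ids)

# Training-style forward pass: passing labels triggers the shifted cross-entropy loss.
out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
print(out["loss"].item(), out["logits"].shape)

# Sampling with the custom generate() loop (top-k/top-p, KV cache, optional streamer).
generated = model.generate(input_ids, attention_mask=attention_mask, max_length=32)
print(generated.shape)  # (1, <=32); shorter if eos is sampled early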
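
A second small sketch, also assuming the patch is applied, of the configuration shortcuts: `AvaConfig.from_pretrained` routes any "ava/<size>" string to the local `PREDEFINED_MODELS` table instead of the Hub, and `_validate_config` rejects inconsistent head/hidden-size combinations.

from transformers.models.ava.configuration_ava import AvaConfig

# The predefined sizes shipped with the config class.
print(list(AvaConfig.PREDEFINED_MODELS))  # ['100m', '500m', '1b', '3b', '7b', '13b', '30b', '65b', '100b']

# "ava/<size>" is resolved locally via from_predefined(); kwargs override individual fields.
config = AvaConfig.from_pretrained("ava/7b", max_position_embeddings=4096)
print(config.hidden_size, config.num_attention_heads, config.kv_heads)  # 4096 32 8

# Shape validation: hidden_size must be divisible by num_attention_heads.
try:
    AvaConfig(hidden_size=1000, num_attention_heads=7)
except ValueError as err:
    print(err)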
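
Finally, a quick property check of the rotary-embedding helpers, again assuming the patch is applied: `apply_rotary_pos_emb` rotates each `(x_i, x_{i+d/2})` pair by a position-dependent angle, so it leaves per-position vector norms unchanged while making relative position recoverable from the query/key dot product.

import torch
from transformers.models.ava.modeling_ava import AvaRotaryEmbedding, apply_rotary_pos_emb

batch, heads, seq, head_dim = 2, 4, 8, 64
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

rope = AvaRotaryEmbedding(head_dim, max_position_embeddings=128)
cos, sin = rope(q, seq_len=seq)  # cached tables, shape (1, 1, seq, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

# The rotation is norm-preserving for every (batch, head, position) vector.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-5)
print(q_rot.shape, k_rot.shape)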