From 90c423ed7412d1a8c5590b4833513196ade0191f Mon Sep 17 00:00:00 2001 From: ethan Date: Tue, 1 Jul 2025 00:03:46 -0700 Subject: [PATCH 1/5] improve the audio model accuracy --- notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb | 2 +- notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb b/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb index 63293b7e592..9781e72a48f 100644 --- a/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb +++ b/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb @@ -86,7 +86,7 @@ "\n", "\n", "pip_install(\n", - " \"transformers>=4.52.0\",\n", + " \"transformers==4.52.0\",\n", " \"torchvision\",\n", " \"accelerate\",\n", " \"qwen-omni-utils[decord]\",\n", diff --git a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py index b5d04f93709..e49aae8e4fc 100644 --- a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py +++ b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py @@ -334,9 +334,9 @@ def forward_wrap_audio_state(self, each_audio_states): ov_model = ov.convert_model( audio, example_input={ - "padded_feature": torch.randn([1, 128, 9], dtype=torch.float32), - "padded_mask": torch.ones([1, 1, 9], dtype=torch.int32), - "padded_mask_after_cnn": torch.ones([1, 5], dtype=torch.bool), + "padded_feature": torch.randn([3, 128, 200], dtype=torch.float32), + "padded_mask": torch.ones([3, 1, 200], dtype=torch.int32), + "padded_mask_after_cnn": torch.ones([3, 100], dtype=torch.bool), }, ) ov.save_model(ov_model, thinker_audio_path) From 530c46427cf175a17ea4b390182ad0c8006935cd Mon Sep 17 00:00:00 2001 From: ethan Date: Tue, 1 Jul 2025 01:28:58 -0700 Subject: [PATCH 2/5] switch to SDPA implementation --- notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py index e49aae8e4fc..86a9f2ef9af 100644 --- a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py +++ b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py @@ -25,6 +25,7 @@ ALL_ATTENTION_FUNCTIONS, apply_rotary_pos_emb, ) +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniConfig from pathlib import Path import types from itertools import accumulate @@ -275,8 +276,14 @@ def convert_qwen2_5_omni_model(model_id, output_dir, quantization_config=None, u ckpt = model_id if not (Path(output_dir) / "spk_dict.pt").exists(): hf_hub_download(model_id, filename="spk_dict.pt", local_dir=output_dir) - - model = Qwen2_5OmniForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.float16) + + + config = Qwen2_5OmniConfig.from_pretrained(ckpt) + config.thinker_config._attn_implementation_autoset = False + config.thinker_config._attn_implementation = "sdpa" + config.talker_config._attn_implementation_autoset = False + config.talker_config._attn_implementation = "sdpa" + model = Qwen2_5OmniForConditionalGeneration.from_pretrained(ckpt, config=config, torch_dtype=torch.float16) model.eval() processor = AutoProcessor.from_pretrained(ckpt) From 547f5dba26bd0300cd4646c1368296bf5f2ce2b3 Mon Sep 17 00:00:00 2001 From: ethan Date: Wed, 2 Jul 2025 18:45:52 -0700 Subject: [PATCH 3/5] reformat --- .../qwen2_5_omni_helper.py | 1101 +++++++++++++---- 1 file changed, 873 insertions(+), 228 deletions(-) diff
--git a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py index 86a9f2ef9af..847612565e4 100644 --- a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py +++ b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py @@ -25,7 +25,9 @@ ALL_ATTENTION_FUNCTIONS, apply_rotary_pos_emb, ) -from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniConfig +from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( + Qwen2_5OmniConfig, +) from pathlib import Path import types from itertools import accumulate @@ -55,7 +57,9 @@ def model_has_input_output_name(ov_model: ov.Model, name: str): Returns: True if input or output with requested name exists else False """ - return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []) + return name in sum( + [list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [] + ) def fuse_cache_reorder( @@ -88,7 +92,9 @@ def fuse_cache_reorder( if model_has_input_output_name(ov_model, "beam_idx"): raise ValueError("Model already has fused cache") input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0] - beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])) + beam_idx = opset13.parameter( + name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]) + ) beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted? ov_model.add_parameters([beam_idx]) not_kv_inputs.append(ov_model.inputs[-1]) @@ -96,7 +102,9 @@ def fuse_cache_reorder( for input_name in key_value_input_names: parameter_output_port = ov_model.input(input_name) consumers = parameter_output_port.get_target_inputs() - gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim)) + gather = opset13.gather( + parameter_output_port, beam_idx, opset13.constant(gather_dim) + ) for consumer in consumers: consumer.replace_source_output(gather.output(0)) ov_model.validate_nodes_and_infer_types() @@ -122,9 +130,18 @@ def build_state_initializer(ov_model: ov.Model, batch_dim: int): if op.get_type_name() == "ReadValue": dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))] dims[batch_dim] = batch - dims = [(opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims] + dims = [ + ( + opset13.constant(np.array([dim], dtype=np.int64)) + if isinstance(dim, int) + else dim + ) + for dim in dims + ] shape = opset13.concat(dims, axis=0) - broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape) + broadcast = opset13.broadcast( + opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape + ) op.set_arguments([broadcast]) ov_model.validate_nodes_and_infer_types() @@ -188,7 +205,11 @@ def make_stateful( def patch_stateful(ov_model, dim): key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]] key_value_output_names = [key.get_any_name() for key in ov_model.outputs[dim:]] - not_kv_inputs = [input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())] + not_kv_inputs = [ + input + for input in ov_model.inputs + if not any(name in key_value_input_names for name in input.get_names()) + ] if not key_value_input_names or not key_value_output_names: return batch_dim = 0 @@ -233,7 +254,9 @@ def cleanup_torchscript_cache(): TOKEN2WAV_BIGVGAN_NAME = "openvino_token2wav_bigvgan_model.xml" -def 
convert_qwen2_5_omni_model(model_id, output_dir, quantization_config=None, use_local_dir=False): +def convert_qwen2_5_omni_model( + model_id, output_dir, quantization_config=None, use_local_dir=False +): thinker_output_dir = Path(output_dir) / "thinker" talker_output_dir = Path(output_dir) / "talker" @@ -262,7 +285,9 @@ def convert_qwen2_5_omni_model(model_id, output_dir, quantization_config=None, u token2wav_bigvgan_path.exists(), ] ): - print(f"✅ {model_id} model already converted. You can find results in {output_dir}") + print( + f"✅ {model_id} model already converted. You can find results in {output_dir}" + ) return print(f"⌛ {model_id} conversion started. Be patient, it may takes some time.") print("⌛ Load Original model") @@ -276,14 +301,15 @@ def convert_qwen2_5_omni_model(model_id, output_dir, quantization_config=None, u ckpt = model_id if not (Path(output_dir) / "spk_dict.pt").exists(): hf_hub_download(model_id, filename="spk_dict.pt", local_dir=output_dir) - - + config = Qwen2_5OmniConfig.from_pretrained(ckpt) config.thinker_config._attn_implementation_autoset = False config.thinker_config._attn_implementation = "sdpa" config.talker_config._attn_implementation_autoset = False config.talker_config._attn_implementation = "sdpa" - model = Qwen2_5OmniForConditionalGeneration.from_pretrained(ckpt, config=config, torch_dtype=torch.float16) + model = Qwen2_5OmniForConditionalGeneration.from_pretrained( + ckpt, config=config, torch_dtype=torch.float16 + ) model.eval() processor = AutoProcessor.from_pretrained(ckpt) @@ -308,7 +334,9 @@ def forward_wrap_audio(self, padded_feature, padded_mask, padded_mask_after_cnn) padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) - padded_embed = padded_embed + self.positional_embedding.positional_embedding[: padded_embed.shape[1], :].unsqueeze(0).to(padded_embed.dtype) + padded_embed = padded_embed + self.positional_embedding.positional_embedding[ + : padded_embed.shape[1], : + ].unsqueeze(0).to(padded_embed.dtype) hidden_states = padded_embed[padded_mask_after_cnn] cu_seqlens = torch.cat( ( @@ -327,7 +355,9 @@ def forward_wrap_audio(self, padded_feature, padded_mask, padded_mask_after_cnn) return hidden_states def forward_wrap_audio_state(self, each_audio_states): - each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) + each_audio_states = self.avg_pooler( + each_audio_states.transpose(0, 1) + ).transpose_(0, 1) each_audio_states = self.ln_post(each_audio_states) each_audio_states = self.proj(each_audio_states) return each_audio_states @@ -375,7 +405,10 @@ def forward_wrap_audio_state(self, each_audio_states): vision_embed_tokens = model.thinker.visual if not thinker_patcher_path.exists(): __make_16bit_traceable(vision_embed_tokens.patch_embed) - ov_model = ov.convert_model(vision_embed_tokens.patch_embed, example_input={"hidden_states": torch.randn([8, 1176])}) + ov_model = ov.convert_model( + vision_embed_tokens.patch_embed, + example_input={"hidden_states": torch.randn([8, 1176])}, + ) ov.save_model(ov_model, thinker_patcher_path) del ov_model cleanup_torchscript_cache() @@ -389,10 +422,14 @@ def image_embed_forward( rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: seq_len = hidden_states.shape[0] - hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) 
hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 + ) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) for layer_num, blk in enumerate(self.blocks): @@ -400,7 +437,11 @@ def image_embed_forward( attention_mask_now = attention_mask else: attention_mask_now = window_attention_mask - hidden_states = blk(hidden_states, attention_mask=attention_mask_now, rotary_pos_emb=rotary_pos_emb) + hidden_states = blk( + hidden_states, + attention_mask=attention_mask_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -408,8 +449,15 @@ def image_embed_forward( return hidden_states - def sdpa_attn_forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None) -> torch.Tensor: - from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import apply_rotary_pos_emb_vision + def sdpa_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import ( + apply_rotary_pos_emb_vision, + ) seq_length = hidden_states.shape[0] q = self.q(hidden_states).reshape(seq_length, self.num_heads, -1) @@ -421,19 +469,29 @@ def sdpa_attn_forward(self, hidden_states: torch.Tensor, attention_mask: torch.T q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) + attn_output = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=0.0 + ) attn_output = attn_output.transpose(0, 1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output - def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor: - hidden_states = hidden_states + self.attn(self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) + def block_forward( + self, hidden_states, attention_mask, rotary_pos_emb + ) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states if not thinker_merger_path.exists(): - vision_embed_tokens.forward = types.MethodType(image_embed_forward, vision_embed_tokens) + vision_embed_tokens.forward = types.MethodType( + image_embed_forward, vision_embed_tokens + ) for block in vision_embed_tokens.blocks: block.forward = types.MethodType(block_forward, block) block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn) @@ -475,9 +533,19 @@ def forward_wrap_thinker( """take care of image_encode, position_ids and (attention_mask = None is fine)""" if past_key_values is not None: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict 
is not None else self.config.use_return_dict + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) outputs = self.model( input_ids=input_ids, @@ -492,7 +560,9 @@ def forward_wrap_thinker( cache_position=cache_position, ) if past_key_values is not None: - outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + outputs["past_key_values"] = outputs[ + "past_key_values" + ].to_legacy_cache() hidden_states = outputs[0] logits = self.lm_head(hidden_states) output = (logits,) + outputs[:] @@ -505,7 +575,12 @@ def forward_wrap_thinker( lang_model.forward = types.MethodType(forward_wrap_thinker, lang_model) num_pkv = lang_model.model.config.num_hidden_layers - pkv_shape = (2, lang_model.model.config.num_key_value_heads, 2, hidden_size // lang_model.model.config.num_attention_heads) + pkv_shape = ( + 2, + lang_model.model.config.num_key_value_heads, + 2, + hidden_size // lang_model.model.config.num_attention_heads, + ) # input_embeds = torch.randn((1, 1, hidden_size)) cache_position = torch.arange(2, 4) position_ids = cache_position.view(1, 1, -1).expand(3, 2, -1) @@ -518,21 +593,41 @@ def forward_wrap_thinker( for i in range(num_pkv): kv = [torch.randn(pkv_shape) for _ in range(2)] past_key_values.append(kv) - input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"]) + input_names.extend( + [f"past_key_values.{i}.key", f"past_key_values.{i}.value"] + ) output_names.extend([f"present.{i}.key", f"present.{i}.value"]) input_names.append("inputs_embeds") - example_input = {"inputs_embeds": input_embeds, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values} + example_input = { + "inputs_embeds": input_embeds, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + } input_shapes = [ ov.PartialShape([-1, -1]), ov.PartialShape([3, -1, -1]), ] input_shapes += ( - [ov.PartialShape([-1, lang_model.model.config.num_key_value_heads, -1, hidden_size // lang_model.model.config.num_attention_heads])] * 2 * num_pkv + [ + ov.PartialShape( + [ + -1, + lang_model.model.config.num_key_value_heads, + -1, + hidden_size // lang_model.model.config.num_attention_heads, + ] + ) + ] + * 2 + * num_pkv ) input_shapes += [ov.PartialShape([-1, -1, input_embeds.shape[-1]])] __make_16bit_traceable(lang_model) - ov_model = ov.convert_model(lang_model, example_input=example_input, input=input_shapes) + ov_model = ov.convert_model( + lang_model, example_input=example_input, input=input_shapes + ) for input, input_name in zip(ov_model.inputs, input_names): input.get_tensor().set_names({input_name}) @@ -542,7 +637,9 @@ def forward_wrap_thinker( print("✅ Thinker language model successfully converted") if quantization_config is not None: - print(f"⌛ Weights compression with {quantization_config['mode']} mode started") + print( + f"⌛ Weights compression with {quantization_config['mode']} mode started" + ) ov_model = nncf.compress_weights(ov_model, **quantization_config) print("✅ Weights compression finished") @@ -550,7 +647,9 @@ def forward_wrap_thinker( del ov_model cleanup_torchscript_cache() gc.collect() - print(f"✅ Thinker model conversion finished. 
You can find results in {output_dir}") + print( + f"✅ Thinker model conversion finished. You can find results in {output_dir}" + ) if not talker_embedding_path.exists(): print("⌛ Convert talker embedding model") @@ -582,9 +681,19 @@ def forward_wrap_talker( """take care of image_encode, position_ids and (attention_mask = None is fine)""" if past_key_values is not None: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) outputs = self.model( @@ -598,7 +707,9 @@ def forward_wrap_talker( return_dict=return_dict, ) if past_key_values is not None: - outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + outputs["past_key_values"] = outputs[ + "past_key_values" + ].to_legacy_cache() hidden_states = outputs[0] logits = self.codec_head(hidden_states) @@ -614,7 +725,12 @@ def forward_wrap_talker( lang_model.forward = types.MethodType(forward_wrap_talker, lang_model) num_pkv = lang_model.model.config.num_hidden_layers - pkv_shape = (2, lang_model.model.config.num_key_value_heads, 2, lang_model.model.config.head_dim) + pkv_shape = ( + 2, + lang_model.model.config.num_key_value_heads, + 2, + lang_model.model.config.head_dim, + ) # input_embeds = torch.randn((1, 1, hidden_size)) cache_position = torch.arange(2, 4) position_ids = cache_position.view(1, 1, -1).expand(3, 2, -1) @@ -627,20 +743,42 @@ def forward_wrap_talker( for i in range(num_pkv): kv = [torch.randn(pkv_shape) for _ in range(2)] past_key_values.append(kv) - input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"]) + input_names.extend( + [f"past_key_values.{i}.key", f"past_key_values.{i}.value"] + ) output_names.extend([f"present.{i}.key", f"present.{i}.value"]) input_names.append("inputs_embeds") - example_input = {"inputs_embeds": input_embeds, "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values} + example_input = { + "inputs_embeds": input_embeds, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_key_values, + } input_shapes = [ ov.PartialShape([-1, -1]), ov.PartialShape([3, -1, -1]), ] - input_shapes += [ov.PartialShape([-1, lang_model.model.config.num_key_value_heads, -1, lang_model.model.config.head_dim])] * 2 * num_pkv + input_shapes += ( + [ + ov.PartialShape( + [ + -1, + lang_model.model.config.num_key_value_heads, + -1, + lang_model.model.config.head_dim, + ] + ) + ] + * 2 + * num_pkv + ) input_shapes += [ov.PartialShape([-1, -1, input_embeds.shape[-1]])] __make_16bit_traceable(lang_model) - ov_model = ov.convert_model(lang_model, example_input=example_input, input=input_shapes) + ov_model = ov.convert_model( + lang_model, example_input=example_input, input=input_shapes + ) for input, input_name in zip(ov_model.inputs, input_names): 
input.get_tensor().set_names({input_name}) @@ -650,7 +788,9 @@ def forward_wrap_talker( print("✅ Talker language model successfully converted") if quantization_config is not None: - print(f"⌛ Weights compression with {quantization_config['mode']} mode started") + print( + f"⌛ Weights compression with {quantization_config['mode']} mode started" + ) ov_model = nncf.compress_weights(ov_model, **quantization_config) print("✅ Weights compression finished") @@ -658,7 +798,9 @@ def forward_wrap_talker( del ov_model cleanup_torchscript_cache() gc.collect() - print(f"✅ Talker model conversion finished. You can find results in {output_dir}") + print( + f"✅ Talker model conversion finished. You can find results in {output_dir}" + ) if not token2wav_dit_path.exists(): print("⌛ Convert token2wav DIT model") @@ -686,7 +828,9 @@ def forward_wrap_dit_attention( # apply rotary position embedding # Due to training process, only first head is applied with RoPE, will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) + query[:, :1], key[:, :1] = apply_rotary_pos_emb( + query[:, :1], key[:, :1], cos, sin + ) attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] attn_mask = torch.zeros_like(attention_mask, dtype=torch.float32) @@ -701,7 +845,9 @@ def forward_wrap_dit_attention( ) # mask. e.g. inference got a batch with different target durations, mask out the padding - attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) + attention_weights = attention_weights.reshape( + batch_size, -1, self.heads * head_dim + ) attention_weights = attention_weights.to(query.dtype) # linear proj @@ -712,16 +858,28 @@ def forward_wrap_dit_attention( code2wav_dit = model.token2wav.code2wav_dit_model for block in code2wav_dit.transformer_blocks: - block.attn.forward = types.MethodType(forward_wrap_dit_attention, block.attn) + block.attn.forward = types.MethodType( + forward_wrap_dit_attention, block.attn + ) __make_16bit_traceable(code2wav_dit) ov_model = ov.convert_model( code2wav_dit, example_input={ - "hidden_states": torch.randn([1, 4, model.token2wav.code2wav_dit_model.config.mel_dim], dtype=torch.float32), + "hidden_states": torch.randn( + [1, 4, model.token2wav.code2wav_dit_model.config.mel_dim], + dtype=torch.float32, + ), "quantized_code": torch.ones([1, 2], dtype=torch.int64), - "speaker_embedding": torch.randn([1, 4, model.token2wav.code2wav_dit_model.config.enc_emb_dim], dtype=torch.float32), - "condition_vector": torch.full((1, 400, model.token2wav.code2wav_dit_model.config.mel_dim), fill_value=-11.5129, dtype=torch.float32), + "speaker_embedding": torch.randn( + [1, 4, model.token2wav.code2wav_dit_model.config.enc_emb_dim], + dtype=torch.float32, + ), + "condition_vector": torch.full( + (1, 400, model.token2wav.code2wav_dit_model.config.mel_dim), + fill_value=-11.5129, + dtype=torch.float32, + ), "time_step": torch.tensor(0.0051, dtype=torch.float32), }, ) @@ -742,7 +900,9 @@ def forward_wrap_bigvgan(self, mel_spectrogram): for layer_index in range(self.num_upsample_layers): hidden_representation = self.ups[layer_index][0](hidden_representation) residual_output = sum( - self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) + self.resblocks[ + layer_index * self.num_residual_blocks + block_index + ](hidden_representation) for block_index in range(self.num_residual_blocks) ) residual_output = residual_output / self.num_residual_blocks @@ 
-754,12 +914,16 @@ def forward_wrap_bigvgan(self, mel_spectrogram): return audio code2wav_bigvgan = model.token2wav.code2wav_bigvgan_model - code2wav_bigvgan.forward = types.MethodType(forward_wrap_bigvgan, code2wav_bigvgan) + code2wav_bigvgan.forward = types.MethodType( + forward_wrap_bigvgan, code2wav_bigvgan + ) __make_16bit_traceable(code2wav_bigvgan) ov_model = ov.convert_model( code2wav_bigvgan, example_input={ - "mel_spectrogram": torch.randn([1, code2wav_bigvgan.config.mel_dim, 2], dtype=torch.float32), + "mel_spectrogram": torch.randn( + [1, code2wav_bigvgan.config.mel_dim, 2], dtype=torch.float32 + ), }, ) ov.save_model(ov_model, token2wav_bigvgan_path) @@ -768,7 +932,9 @@ def forward_wrap_bigvgan(self, mel_spectrogram): del model gc.collect() print("✅ Token2wav BIGVGAN model successfully converted") - print(f"✅ {model_id} model conversion finished. You can find results in {output_dir}") + print( + f"✅ {model_id} model conversion finished. You can find results in {output_dir}" + ) def get_llm_pos_ids_for_vision( @@ -782,9 +948,25 @@ def get_llm_pos_ids_for_vision( llm_pos_ids_list = [] llm_grid_h = grid_hs[vision_idx] // spatial_merge_size llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() - w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() - t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long() + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(len(t_index), -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(len(t_index), llm_grid_h, -1) + .flatten() + ) + t_index = ( + torch.Tensor(t_index) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + .long() + ) _llm_pos_ids = torch.stack([t_index, h_index, w_index]) # + 1 ) # 12.09 by malinhan llm_pos_ids_list.append(_llm_pos_ids + start_idx) @@ -830,7 +1012,9 @@ def get_rope_index( seconds_per_chunk = config.seconds_per_chunk mrope_position_deltas = [] - if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -846,18 +1030,34 @@ def get_rope_index( for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums, audio_nums = 0, 0, 0 - vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] audio_nums = torch.sum(input_ids == audio_start_token_id) image_nums = (vision_tokens == image_token_id).sum() - video_nums = (vision_tokens == audio_start_token_id).sum() if use_audio_in_video else (vision_tokens == video_token_id).sum() + video_nums = ( + (vision_tokens == audio_start_token_id).sum() + if use_audio_in_video + else (vision_tokens == video_token_id).sum() + ) input_tokens = input_ids.tolist() llm_pos_ids_list: list = [] st = 0 - remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums - multimodal_nums = image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums + remain_images, remain_videos, remain_audios = ( + image_nums, + video_nums, + 
audio_nums, + ) + multimodal_nums = ( + image_nums + audio_nums + if use_audio_in_video + else image_nums + video_nums + audio_nums + ) for _ in range(multimodal_nums): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) if image_token_id in input_tokens and remain_images > 0: ed_image = input_tokens.index(image_token_id, st) else: @@ -874,21 +1074,45 @@ def get_rope_index( if min_ed == ed_audio: text_len = min_ed - st - 1 if text_len != 0: - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) bos_len = 1 - llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 - llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + llm_pos_ids = ( + torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + ) llm_pos_ids_list.append(llm_pos_ids) - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) eos_len = 1 - llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx + ) st += text_len + bos_len + audio_len + eos_len audio_idx += 1 @@ -897,25 +1121,53 @@ def get_rope_index( elif min_ed == ed_image: text_len = min_ed - st - 1 if text_len != 0: - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) bos_len = 1 - llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) grid_t = image_grid_thw[image_idx][0] grid_hs = image_grid_thw[:, 1] grid_ws = image_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() - llm_pos_ids = get_llm_pos_ids_for_vision(st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws) - image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) + t_index = ( + torch.arange(grid_t) * 1 * position_id_per_seconds + ).long() + llm_pos_ids = 
get_llm_pos_ids_for_vision( + st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + image_len = image_grid_thw[image_idx].prod() // ( + spatial_merge_size**2 + ) llm_pos_ids_list.append(llm_pos_ids) - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) eos_len = 1 - llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx + ) st += text_len + bos_len + image_len + eos_len image_idx += 1 @@ -924,25 +1176,55 @@ def get_rope_index( elif min_ed == ed_video and not use_audio_in_video: text_len = min_ed - st - 1 if text_len != 0: - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) bos_len = 1 - llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] grid_ws = video_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds).long() - llm_pos_ids = get_llm_pos_ids_for_vision(st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws) - video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + t_index = ( + torch.arange(grid_t) + * second_per_grids[video_idx].cpu().float() + * position_id_per_seconds + ).long() + llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) llm_pos_ids_list.append(llm_pos_ids) - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) eos_len = 1 - llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx + ) st += text_len + bos_len + video_len + eos_len video_idx += 1 @@ -951,45 +1233,105 @@ def get_rope_index( elif min_ed == ed_video and use_audio_in_video: text_len = min_ed - st - 2 if text_len != 0: - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) bos_len = 1 - llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + 
st_idx) - llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx + ) + llm_pos_ids_list.append( + torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 - audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + audio_llm_pos_ids = ( + torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx + ) grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] grid_ws = video_grid_thw[:, 2] - t_index = (torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds).long() - video_llm_pos_ids = get_llm_pos_ids_for_vision(st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws) - - t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) - video_chunk_indexes = get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx) - audio_chunk_indexes = get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx) + t_index = ( + torch.arange(grid_t) + * second_per_grids[video_idx].cpu().float() + * position_id_per_seconds + ).long() + video_llm_pos_ids = get_llm_pos_ids_for_vision( + st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws + ) + + t_ntoken_per_chunk = int( + position_id_per_seconds * seconds_per_chunk + ) + video_chunk_indexes = get_chunked_index( + video_llm_pos_ids, t_ntoken_per_chunk, st_idx + ) + audio_chunk_indexes = get_chunked_index( + audio_llm_pos_ids, t_ntoken_per_chunk, st_idx + ) sub_len = 0 - for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): - video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None - audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None + for j in range( + max(len(video_chunk_indexes), len(audio_chunk_indexes)) + ): + video_chunk_index = ( + video_chunk_indexes[j] + if j < len(video_chunk_indexes) + else None + ) + audio_chunk_index = ( + audio_chunk_indexes[j] + if j < len(audio_chunk_indexes) + else None + ) if video_chunk_index is not None: sub_len += video_chunk_index[1] - video_chunk_index[0] - llm_pos_ids_list.append(video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]]) + llm_pos_ids_list.append( + video_llm_pos_ids[ + :, video_chunk_index[0] : video_chunk_index[1] + ] + ) if audio_chunk_index is not None: sub_len += audio_chunk_index[1] - audio_chunk_index[0] - llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]]) - video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) - - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + audio_llm_pos_ids[ + :, audio_chunk_index[0] : audio_chunk_index[1] + ] + ) + video_len = video_grid_thw[video_idx].prod() // ( + spatial_merge_size**2 + ) + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) eos_len = 1 - llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) - llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx + ) + llm_pos_ids_list.append( + torch.arange(eos_len).view(1, 
-1).expand(3, -1) + st_idx + ) st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 @@ -999,23 +1341,37 @@ def get_rope_index( remain_audios -= 1 if st < len(input_tokens): - st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) text_len = len(input_tokens) - st - llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) - mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) return position_ids, mrope_position_deltas else: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] - mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + position_ids = ( + position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[ + 0 + ] + mrope_position_deltas = ( + max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) + ) return position_ids, mrope_position_deltas @@ -1024,12 +1380,22 @@ class OVQwen2_5OmniThinkerForConditionalGeneration(GenerationMixin): def __init__(self, model_dir, device, config): self.model = core.read_model(model_dir / THINKER_LANGUAGE_NAME) self.audio = core.compile_model(model_dir / THINKER_AUDIO_NAME, device) - self.audio_state = core.compile_model(model_dir / THINKER_AUDIO_STATE_NAME, device) - self.visual_patcher = core.compile_model(model_dir / THINKER_PATCHER_NAME, device) + self.audio_state = core.compile_model( + model_dir / THINKER_AUDIO_STATE_NAME, device + ) + self.visual_patcher = core.compile_model( + model_dir / THINKER_PATCHER_NAME, device + ) self.visual_merger = core.compile_model(model_dir / THINKER_MERGER_NAME, device) - self.embed_tokens = core.compile_model(model_dir / THINKER_EMBEDDING_NAME, device) - self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} - self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} + self.embed_tokens = core.compile_model( + model_dir / THINKER_EMBEDDING_NAME, device + ) + self.input_names = { + key.get_any_name(): idx for idx, key in enumerate(self.model.inputs) + } + self.output_names = { + key.get_any_name(): idx for idx, key in enumerate(self.model.outputs) + } compiled_model = core.compile_model(self.model, device) self.request = compiled_model.create_infer_request() self.main_input_name = "input_ids" @@ -1044,7 +1410,10 @@ def __init__(self, model_dir, device, config): self.patch_size = self.config.vision_config.patch_size self.fullatt_block_indexes = self.config.vision_config.fullatt_block_indexes self.window_size = self.config.vision_config.window_size - self.spatial_merge_unit = 
self.config.vision_config.spatial_merge_size * self.config.vision_config.spatial_merge_size + self.spatial_merge_unit = ( + self.config.vision_config.spatial_merge_size + * self.config.vision_config.spatial_merge_size + ) self._skip_keys_device_placement = "past_key_values" self._supports_flash_attn_2 = True self._supports_sdpa = True @@ -1054,15 +1423,21 @@ def __init__(self, model_dir, device, config): class Qwen2_5_VisionRotaryEmbedding(torch.nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + inv_freq = 1.0 / ( + theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + seq = torch.arange( + seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(seq, self.inv_freq) return freqs - head_dim = self.config.vision_config.hidden_size // self.config.vision_config.num_heads + head_dim = ( + self.config.vision_config.hidden_size // self.config.vision_config.num_heads + ) self._rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) def can_generate(self): @@ -1149,19 +1524,25 @@ def get_window_index(self, grid_thw): window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 - vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + vit_merger_window_size = ( + self.window_size // self.spatial_merge_size // self.patch_size + ) for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size, ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w + ) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = torch.nn.functional.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = torch.nn.functional.pad( + index, (0, pad_w, 0, pad_h), "constant", -100 + ) index_padded = index_padded.reshape( grid_t, num_windows_h, @@ -1179,7 +1560,9 @@ def get_window_index(self, grid_thw): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_seqlens_tmp = ( + seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] + ) cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) @@ -1195,23 +1578,47 @@ def visual(self, pixel_values, grid_thw, **kwargs): dtype=torch.int32, ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] + ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) - attention_mask = torch.zeros((1, 
hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) + attention_mask = torch.zeros( + (1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool + ) causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) for i in range(1, len(cu_seqlens)): - attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + attention_mask[ + ..., + cu_seqlens[i - 1] : cu_seqlens[i], + cu_seqlens[i - 1] : cu_seqlens[i], + ] = True causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf")) - window_attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) + window_attention_mask = torch.zeros( + (1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool + ) window_causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) for i in range(1, len(cu_window_seqlens)): - window_attention_mask[..., cu_window_seqlens[i - 1] : cu_window_seqlens[i], cu_window_seqlens[i - 1] : cu_window_seqlens[i]] = True - - window_causal_mask.masked_fill_(torch.logical_not(window_attention_mask), float("-inf")) + window_attention_mask[ + ..., + cu_window_seqlens[i - 1] : cu_window_seqlens[i], + cu_window_seqlens[i - 1] : cu_window_seqlens[i], + ] = True + + window_causal_mask.masked_fill_( + torch.logical_not(window_attention_mask), float("-inf") + ) - res = self.visual_merger([hidden_states, causal_mask, window_causal_mask, window_index, rotary_pos_emb])[0] + res = self.visual_merger( + [ + hidden_states, + causal_mask, + window_causal_mask, + window_index, + rotary_pos_emb, + ] + )[0] return torch.from_numpy(res) def __call__( @@ -1242,7 +1649,9 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): output_lengths = (input_lengths - 2) // 2 + 1 return input_lengths, output_lengths - def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): + def padded_and_mask_function( + self, tensor_list, tensor_len, padding_value=0, padding_side="right" + ): max_len = tensor_len.max() dim = tensor_list[0].shape[0] padded_tensor = torch.full( @@ -1302,11 +1711,17 @@ def forward( ) -> Union[tuple, BaseModelOutputWithPast]: if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) - input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) + input_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) else: audio_feature_lengths = None if attention_mask is not None and position_ids is None: - if cache_position is None or (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + if ( + cache_position is None + or (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + ): delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) position_ids, rope_deltas = get_rope_index( self.config, @@ -1322,7 +1737,11 @@ def forward( self.rope_deltas = rope_deltas else: batch_size, seq_length = input_ids.shape - delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + delta = ( + cache_position[0] + self.rope_deltas + if cache_position is not None + else 0 + ) position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.add(delta) @@ -1332,10 +1751,19 @@ def forward( inputs_embeds = torch.from_numpy(self.embed_tokens(input_ids)[0]) if input_ids is not 
None and input_ids.shape[1] != 1: # Prefill stage if input_features is not None: - audio_feat_lengths, audio_output_lengths = self._get_feat_extract_output_lengths( - audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) + ( + audio_feat_lengths, + audio_output_lengths, + ) = self._get_feat_extract_output_lengths( + audio_feature_lengths + if audio_feature_lengths is not None + else feature_attention_mask.sum(-1) + ) + feature_lens = ( + audio_feature_lengths + if audio_feature_lengths is not None + else feature_attention_mask.sum(-1) ) - feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() chunk_lengths = torch.tensor( @@ -1343,39 +1771,75 @@ def forward( dtype=torch.long, device=feature_lens.device, ) - tail_chunk_index = list(accumulate(chunk_num.tolist(), func=operator.add, initial=-1))[1:] + tail_chunk_index = list( + accumulate(chunk_num.tolist(), func=operator.add, initial=-1) + )[1:] chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) + chunk_lengths = torch.where( + chunk_lengths == 0, self.n_window * 2, chunk_lengths + ) chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) - padded_feature, padded_mask, padded_mask_after_cnn = self.padded_and_mask_function( + ( + padded_feature, + padded_mask, + padded_mask_after_cnn, + ) = self.padded_and_mask_function( chunk_list, chunk_lengths, padding_value=0, padding_side="right" ) - hidden_states = torch.from_numpy(self.audio([padded_feature, padded_mask, padded_mask_after_cnn])[0]) - hidden_states_list = hidden_states.split(audio_feat_lengths.tolist(), dim=0) + hidden_states = torch.from_numpy( + self.audio([padded_feature, padded_mask, padded_mask_after_cnn])[0] + ) + hidden_states_list = hidden_states.split( + audio_feat_lengths.tolist(), dim=0 + ) token_audio_list = [] for each_audio_states in hidden_states_list: - each_audio_states = torch.from_numpy(self.audio_state([each_audio_states])[0]) + each_audio_states = torch.from_numpy( + self.audio_state([each_audio_states])[0] + ) token_audio_list.append(each_audio_states) audio_features = torch.cat(token_audio_list, dim=0) if audio_features.shape[0] != sum(audio_output_lengths.tolist()): - raise ValueError("length of audio_features should match audio_output_lengths") - audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1).expand_as(inputs_embeds) - audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + raise ValueError( + "length of audio_features should match audio_output_lengths" + ) + audio_mask = ( + (input_ids == self.config.audio_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + ) + audio_features = audio_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) inputs_embeds 
= inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if attention_mask is not None: @@ -1390,7 +1854,11 @@ def forward( inputs["attention_mask"] = attention_mask inputs["position_ids"] = position_ids if "beam_idx" in self.input_names: - inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int) + inputs["beam_idx"] = ( + self.next_beam_idx + if self.next_beam_idx is not None + else np.arange(inputs_embeds.shape[0], dtype=int) + ) self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = self.request.get_tensor("logits").data @@ -1401,16 +1869,23 @@ def forward( embeds_to_talker = inputs_embeds.clone() hidden_states_output = hidden_states.clone() return Qwen2_5OmniThinkerCausalLMOutputWithPast( - logits=logits, past_key_values=past_key_values, rope_deltas=rope_deltas, hidden_states=(embeds_to_talker, hidden_states_output) + logits=logits, + past_key_values=past_key_values, + rope_deltas=rope_deltas, + hidden_states=(embeds_to_talker, hidden_states_output), ) - def _reorder_cache(self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple[tuple[torch.Tensor]]: + def _reorder_cache( + self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> tuple[tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. 
""" - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array( + beam_idx + ) # save beam_idx to be used as an input in the next iteration return past_key_values def _get_past_length(self, past_key_values=None): @@ -1445,8 +1920,12 @@ def __init__(self, model_dir, device, config): self.model = core.read_model(model_dir / TALKER_LANGUAGE_NAME) self.embed_tokens = core.compile_model(model_dir / TALKER_EMBEDDING_NAME, "CPU") - self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} - self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} + self.input_names = { + key.get_any_name(): idx for idx, key in enumerate(self.model.inputs) + } + self.output_names = { + key.get_any_name(): idx for idx, key in enumerate(self.model.outputs) + } compiled_model = core.compile_model(self.model, device) self.request = compiled_model.create_infer_request() self.config = config.talker_config @@ -1522,7 +2001,9 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutputWithPast]: if attention_mask is not None and position_ids is None: - if cache_position is None or (cache_position is not None and cache_position[0] == 0): + if cache_position is None or ( + cache_position is not None and cache_position[0] == 0 + ): position_ids, rope_deltas = get_rope_index( self.config, input_text_ids, @@ -1534,15 +2015,31 @@ def forward( video_second_per_grid, ) inputs_embeds[:, -1, :] += torch.from_numpy( - self.embed_tokens(torch.tensor([[self.codec_bos_token]], dtype=torch.long, device=inputs_embeds.device))[0][0] + self.embed_tokens( + torch.tensor( + [[self.codec_bos_token]], + dtype=torch.long, + device=inputs_embeds.device, + ) + )[0][0] ) inputs_embeds[:, -2, :] += torch.from_numpy( - self.embed_tokens(torch.tensor([[self.codec_pad_token]], dtype=torch.long, device=inputs_embeds.device))[0][0] + self.embed_tokens( + torch.tensor( + [[self.codec_pad_token]], + dtype=torch.long, + device=inputs_embeds.device, + ) + )[0][0] ) else: batch_size, seq_length = input_ids.shape - delta = cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 + delta = ( + cache_position[0] + rope_deltas + if cache_position is not None and rope_deltas is not None + else 0 + ) position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.add(delta) @@ -1563,7 +2060,11 @@ def forward( inputs["attention_mask"] = attention_mask inputs["position_ids"] = position_ids if "beam_idx" in self.input_names: - inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int) + inputs["beam_idx"] = ( + self.next_beam_idx + if self.next_beam_idx is not None + else np.arange(inputs_embeds.shape[0], dtype=int) + ) self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = self.request.get_tensor("logits").data @@ -1580,7 +2081,9 @@ def forward( def _get_initial_cache_position(self, input_ids, device, model_kwargs): # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily inputs_embeds = model_kwargs.pop("inputs_embeds") - model_kwargs = super()._get_initial_cache_position(input_ids, device, model_kwargs) + model_kwargs = super()._get_initial_cache_position( + input_ids, device, model_kwargs + ) model_kwargs["inputs_embeds"] = inputs_embeds return 
model_kwargs @@ -1629,13 +2132,17 @@ def prepare_inputs_for_generation( return model_inputs - def _reorder_cache(self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple[tuple[torch.Tensor]]: + def _reorder_cache( + self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> tuple[tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ - self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array( + beam_idx + ) # save beam_idx to be used as an input in the next iteration return past_key_values def _get_past_length(self, past_key_values=None): @@ -1654,7 +2161,9 @@ def _update_model_kwargs_for_generation( if getattr(outputs, "attention_mask", None) is not None: model_kwargs["attention_mask"] = outputs.attention_mask - model_kwargs = super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens) + model_kwargs = super()._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder, num_new_tokens + ) if getattr(outputs, "rope_deltas", None) is not None: model_kwargs["rope_deltas"] = outputs.rope_deltas @@ -1673,18 +2182,48 @@ def __init__(self, function, initial_value): self._one_third = 1 / 3 self._two_thirds = 2 / 3 - def _rk4_step(self, function, time_start, time_step, time_end, value_start, function_value_start=None): - k1 = function_value_start if function_value_start is not None else function(time_start, value_start) - k2 = function(time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third) - k3 = function(time_start + time_step * self._two_thirds, value_start + time_step * (k2 - k1 * self._one_third)) + def _rk4_step( + self, + function, + time_start, + time_step, + time_end, + value_start, + function_value_start=None, + ): + k1 = ( + function_value_start + if function_value_start is not None + else function(time_start, value_start) + ) + k2 = function( + time_start + time_step * self._one_third, + value_start + time_step * k1 * self._one_third, + ) + k3 = function( + time_start + time_step * self._two_thirds, + value_start + time_step * (k2 - k1 * self._one_third), + ) k4 = function(time_end, value_start + time_step * (k1 - k2 + k3)) return (k1 + 3 * (k2 + k3) + k4) * time_step / 8 def _compute_step(self, function, time_start, time_step, time_end, value_start): function_value_start = function(time_start, value_start) - return self._rk4_step(function, time_start, time_step, time_end, value_start, function_value_start=function_value_start), function_value_start + return ( + self._rk4_step( + function, + time_start, + time_step, + time_end, + value_start, + function_value_start=function_value_start, + ), + function_value_start, + ) - def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): + def _linear_interpolation( + self, time_start, time_end, value_start, value_end, time_point + ): if time_point == time_start: return value_start if time_point == time_end: @@ -1705,11 +2244,22 @@ def integrate(self, time_points): current_value = self.initial_value for time_start, time_end in zip(time_points[:-1], time_points[1:]): time_step = time_end - time_start - delta_value, _ = self._compute_step(self.function, time_start, time_step, 
time_end, current_value) + delta_value, _ = self._compute_step( + self.function, time_start, time_step, time_end, current_value + ) next_value = current_value + delta_value - while current_index < len(time_points) and time_end >= time_points[current_index]: - solution[current_index] = self._linear_interpolation(time_start, time_end, current_value, next_value, time_points[current_index]) + while ( + current_index < len(time_points) + and time_end >= time_points[current_index] + ): + solution[current_index] = self._linear_interpolation( + time_start, + time_end, + current_value, + next_value, + time_points[current_index], + ) current_index += 1 current_value = next_value @@ -1723,7 +2273,9 @@ def __init__(self, model_dir, thinker_device, talker_device, token2wav_device): self.has_talker = self.config.enable_audio_output model_path = Path(model_dir) - self.thinker = OVQwen2_5OmniThinkerForConditionalGeneration(model_path / "thinker", thinker_device, self.config) + self.thinker = OVQwen2_5OmniThinkerForConditionalGeneration( + model_path / "thinker", thinker_device, self.config + ) self.speaker_map = {} if self.config.enable_audio_output: self.enable_talker(model_path, talker_device, token2wav_device) @@ -1734,9 +2286,15 @@ def __init__(self, model_dir, thinker_device, talker_device, token2wav_device): def enable_talker(self, model_path, device, token2wav_device=None): if token2wav_device is None: token2wav_device = device - self.talker = OVQwen2_5OmniTalkerForConditionalGeneration(model_path / "talker", device, self.config) - self.token2wav_dit = core.compile_model(model_path / TOKEN2WAV_DIT_NAME, token2wav_device) - self.token2wav_bigvgan = core.compile_model(model_path / TOKEN2WAV_BIGVGAN_NAME, token2wav_device) + self.talker = OVQwen2_5OmniTalkerForConditionalGeneration( + model_path / "talker", device, self.config + ) + self.token2wav_dit = core.compile_model( + model_path / TOKEN2WAV_DIT_NAME, token2wav_device + ) + self.token2wav_bigvgan = core.compile_model( + model_path / TOKEN2WAV_BIGVGAN_NAME, token2wav_device + ) self.has_talker = True def load_speakers(self, path): @@ -1796,13 +2354,19 @@ def generate( - **Audio waveform** (`torch.Tensor`): Generated audio waveform. """ if speaker not in self.speaker_map: - raise ValueError(f"{speaker} is not availible, availible speakers: {self.speaker_map.keys()}") + raise ValueError( + f"{speaker} is not availible, availible speakers: {self.speaker_map.keys()}" + ) if return_audio and not self.has_talker: - raise ValueError("Cannot use talker when talker module not initalized. Use `enable_talker` method or set enable_talker in config to enable talker.") + raise ValueError( + "Cannot use talker when talker module not initalized. Use `enable_talker` method or set enable_talker in config to enable talker." + ) if return_audio is None: return_audio = self.has_talker if input_ids.shape[0] != 1 and return_audio: - raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output") + raise NotImplementedError( + "Qwen2.5-Omni currently does not support batched inference with audio output" + ) shared_kwargs = {"use_audio_in_video": use_audio_in_video} thinker_kwargs = { @@ -1860,14 +2424,24 @@ def generate( return thinker_result # 2. 
Generate speech tokens from talker module - thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device) - thinker_token_embeds = [x[0].to(self.talker.device) for x in thinker_result.hidden_states] - thinker_hidden_states = [x[1].to(self.talker.device) for x in thinker_result.hidden_states] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to( + self.talker.device + ) + thinker_token_embeds = [ + x[0].to(self.talker.device) for x in thinker_result.hidden_states + ] + thinker_hidden_states = [ + x[1].to(self.talker.device) for x in thinker_result.hidden_states + ] talker_text_bos_token = speaker_params["bos_token"] talker_input_text_ids = torch.cat( [ input_ids.to(self.talker.device), - torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.talker.device), + torch.tensor( + [[talker_text_bos_token]], + dtype=torch.long, + device=self.talker.device, + ), thinker_generate_ids[:, :1], ], dim=-1, @@ -1875,17 +2449,35 @@ def generate( talker_input_ids = torch.cat( [ - torch.full_like(input_ids, fill_value=self.talker.codec_mask_token, device=self.talker.device), - torch.tensor([[self.talker.codec_pad_token]], dtype=torch.long, device=self.talker.device), - torch.tensor([[self.talker.codec_bos_token]], dtype=torch.long, device=self.talker.device), + torch.full_like( + input_ids, + fill_value=self.talker.codec_mask_token, + device=self.talker.device, + ), + torch.tensor( + [[self.talker.codec_pad_token]], + dtype=torch.long, + device=self.talker.device, + ), + torch.tensor( + [[self.talker.codec_bos_token]], + dtype=torch.long, + device=self.talker.device, + ), ], dim=1, ) - thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( + thinker_token_embeds[1:], dim=1 + ) talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] - talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device) - talker_text_bos_embed = torch.from_numpy(self.thinker.embed_tokens(talker_text_bos_token)[0]).to(self.talker.device) + talker_text_bos_token = torch.tensor( + [[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device + ) + talker_text_bos_embed = torch.from_numpy( + self.thinker.embed_tokens(talker_text_bos_token)[0] + ).to(self.talker.device) talker_inputs_embeds = torch.cat( [ @@ -1897,11 +2489,23 @@ def generate( ) eos_embedding = torch.from_numpy( - self.thinker.embed_tokens(torch.tensor([[self.talker.text_eos_token]], dtype=torch.long, device=self.thinker.device))[0] + self.thinker.embed_tokens( + torch.tensor( + [[self.talker.text_eos_token]], + dtype=torch.long, + device=self.thinker.device, + ) + )[0] ).to(self.talker.device) pad_embedding = torch.from_numpy( - self.thinker.embed_tokens(torch.tensor([[self.talker.text_pad_token]], dtype=torch.long, device=self.thinker.device))[0] + self.thinker.embed_tokens( + torch.tensor( + [[self.talker.text_pad_token]], + dtype=torch.long, + device=self.thinker.device, + ) + )[0] ).to(self.talker.device) thinker_reply_part = torch.cat( @@ -1915,7 +2519,10 @@ def generate( talker_attention_mask = None if "attention_mask" in kwargs: - talker_attention_mask = torch.cat([kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], dim=1).to(self.talker.device) + talker_attention_mask = torch.cat( + [kwargs["attention_mask"], kwargs["attention_mask"].new_ones((1, 2))], + dim=1, + 
).to(self.talker.device) print("[===start talker===]") talker_result = self.talker.generate( input_ids=talker_input_ids, @@ -1924,19 +2531,34 @@ def generate( inputs_embeds=talker_inputs_embeds, attention_mask=talker_attention_mask, suppress_tokens=[self.talker.codec_bos_token], - **{k: (v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, + **{ + k: (v.to(self.talker.device) if torch.is_tensor(v) else v) + for k, v in talker_kwargs.items() + }, ) talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] print("[===start token2wav===]") # 3. Generate wavs from code - reference_mel_spectrogram = speaker_params["ref_mel"].to(torch.device("cpu")).float() + reference_mel_spectrogram = ( + speaker_params["ref_mel"].to(torch.device("cpu")).float() + ) conditioning_vector = speaker_params["cond"].to(torch.device("cpu")).float() - noise_initialization = torch.randn([1, 30000, self.config.token2wav_config.dit_config.mel_dim], dtype=reference_mel_spectrogram.dtype) - maximum_duration = talker_generate_codes.shape[1] * self.config.token2wav_config.dit_config.repeats - initial_state = noise_initialization[:, :maximum_duration].to(talker_generate_codes.device) + noise_initialization = torch.randn( + [1, 30000, self.config.token2wav_config.dit_config.mel_dim], + dtype=reference_mel_spectrogram.dtype, + ) + maximum_duration = ( + talker_generate_codes.shape[1] + * self.config.token2wav_config.dit_config.repeats + ) + initial_state = noise_initialization[:, :maximum_duration].to( + talker_generate_codes.device + ) batch_size = reference_mel_spectrogram.shape[0] - conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) + conditioning_vector = conditioning_vector.unsqueeze(1).repeat( + 1, maximum_duration, 1 + ) if batch_size != 1: raise ValueError("Only batch size = 1 is currently supported") guidance_scale = 0.5 @@ -1944,23 +2566,46 @@ def generate( def ode_function(time_step, hidden_states): model_output = torch.from_numpy( - self.token2wav_dit([hidden_states, reference_mel_spectrogram, conditioning_vector, talker_generate_codes, time_step])[0] + self.token2wav_dit( + [ + hidden_states, + reference_mel_spectrogram, + conditioning_vector, + talker_generate_codes, + time_step, + ] + )[0] ) guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) - return guided_prediction + (guided_prediction - null_prediction) * guidance_scale + return ( + guided_prediction + + (guided_prediction - null_prediction) * guidance_scale + ) initial_time = 0 - time_embedding = torch.linspace(initial_time, 1, 10, device=talker_generate_codes.device, dtype=conditioning_vector.dtype) + time_embedding = torch.linspace( + initial_time, + 1, + 10, + device=talker_generate_codes.device, + dtype=conditioning_vector.dtype, + ) if sway_coefficient is not None: - time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) + time_embedding += sway_coefficient * ( + torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding + ) - ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) + ode_solver = RungeKutta4ODESolver( + function=ode_function, initial_value=initial_state + ) solution_trajectory = ode_solver.integrate(time_embedding) generated_waveform = solution_trajectory[-1] generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) - waveform = torch.from_numpy(self.token2wav_bigvgan([generated_mel_spectrogram])[0]) + waveform = torch.from_numpy( + 
self.token2wav_bigvgan([generated_mel_spectrogram])[0] + ) waveform.squeeze().cpu() return thinker_result.sequences, waveform.float() From 79c88e780cfebbaad7bb5e6e6d9935631a6f604a Mon Sep 17 00:00:00 2001 From: ethan Date: Wed, 2 Jul 2025 18:47:29 -0700 Subject: [PATCH 4/5] reformat --- .../qwen2_5_omni_helper.py | 830 ++++-------------- 1 file changed, 189 insertions(+), 641 deletions(-) diff --git a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py index 847612565e4..30c6f021355 100644 --- a/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py +++ b/notebooks/qwen2.5-omni-chatbot/qwen2_5_omni_helper.py @@ -57,9 +57,7 @@ def model_has_input_output_name(ov_model: ov.Model, name: str): Returns: True if input or output with requested name exists else False """ - return name in sum( - [list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], [] - ) + return name in sum([list(t.get_names()) for t in ov_model.inputs + ov_model.outputs], []) def fuse_cache_reorder( @@ -92,9 +90,7 @@ def fuse_cache_reorder( if model_has_input_output_name(ov_model, "beam_idx"): raise ValueError("Model already has fused cache") input_batch = ov_model.input("inputs_embeds").get_partial_shape()[0] - beam_idx = opset13.parameter( - name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch]) - ) + beam_idx = opset13.parameter(name="beam_idx", dtype=ov.Type.i32, shape=ov.PartialShape([input_batch])) beam_idx.output(0).get_tensor().add_names({"beam_idx"}) # why list is not accepted? ov_model.add_parameters([beam_idx]) not_kv_inputs.append(ov_model.inputs[-1]) @@ -102,9 +98,7 @@ def fuse_cache_reorder( for input_name in key_value_input_names: parameter_output_port = ov_model.input(input_name) consumers = parameter_output_port.get_target_inputs() - gather = opset13.gather( - parameter_output_port, beam_idx, opset13.constant(gather_dim) - ) + gather = opset13.gather(parameter_output_port, beam_idx, opset13.constant(gather_dim)) for consumer in consumers: consumer.replace_source_output(gather.output(0)) ov_model.validate_nodes_and_infer_types() @@ -130,18 +124,9 @@ def build_state_initializer(ov_model: ov.Model, batch_dim: int): if op.get_type_name() == "ReadValue": dims = [dim.min_length for dim in list(op.get_output_partial_shape(0))] dims[batch_dim] = batch - dims = [ - ( - opset13.constant(np.array([dim], dtype=np.int64)) - if isinstance(dim, int) - else dim - ) - for dim in dims - ] + dims = [(opset13.constant(np.array([dim], dtype=np.int64)) if isinstance(dim, int) else dim) for dim in dims] shape = opset13.concat(dims, axis=0) - broadcast = opset13.broadcast( - opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape - ) + broadcast = opset13.broadcast(opset13.constant(0.0, dtype=op.get_output_element_type(0)), shape) op.set_arguments([broadcast]) ov_model.validate_nodes_and_infer_types() @@ -205,11 +190,7 @@ def make_stateful( def patch_stateful(ov_model, dim): key_value_input_names = [key.get_any_name() for key in ov_model.inputs[2:-1]] key_value_output_names = [key.get_any_name() for key in ov_model.outputs[dim:]] - not_kv_inputs = [ - input - for input in ov_model.inputs - if not any(name in key_value_input_names for name in input.get_names()) - ] + not_kv_inputs = [input for input in ov_model.inputs if not any(name in key_value_input_names for name in input.get_names())] if not key_value_input_names or not key_value_output_names: return batch_dim = 0 @@ -254,9 +235,7 @@ def cleanup_torchscript_cache(): 
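# --- Illustrative sketch, not part of the patched helper --------------------------
# How the fused "beam_idx" input created by fuse_cache_reorder()/make_stateful()
# above is expected to be driven at inference time: the KV-cache lives inside the
# graph as ReadValue/Assign state with a Gather on "beam_idx", so the Python
# wrapper only passes an index vector each step. The names step_inputs and
# next_beam_idx are assumptions for this example, not code from the patch.
import numpy as np

def step_inputs(inputs_embeds, attention_mask, position_ids, next_beam_idx=None):
    # Greedy/sampling: identity permutation, the cached state is kept as-is.
    # Beam search: _reorder_cache() records the beam permutation, and feeding it
    # here lets the in-graph Gather reorder the hidden KV state for the next step.
    batch = inputs_embeds.shape[0]
    beam_idx = np.arange(batch, dtype=int) if next_beam_idx is None else np.array(next_beam_idx)
    return {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "beam_idx": beam_idx,
    }
# -----------------------------------------------------------------------------------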
TOKEN2WAV_BIGVGAN_NAME = "openvino_token2wav_bigvgan_model.xml" -def convert_qwen2_5_omni_model( - model_id, output_dir, quantization_config=None, use_local_dir=False -): +def convert_qwen2_5_omni_model(model_id, output_dir, quantization_config=None, use_local_dir=False): thinker_output_dir = Path(output_dir) / "thinker" talker_output_dir = Path(output_dir) / "talker" @@ -285,9 +264,7 @@ def convert_qwen2_5_omni_model( token2wav_bigvgan_path.exists(), ] ): - print( - f"✅ {model_id} model already converted. You can find results in {output_dir}" - ) + print(f"✅ {model_id} model already converted. You can find results in {output_dir}") return print(f"⌛ {model_id} conversion started. Be patient, it may takes some time.") print("⌛ Load Original model") @@ -307,9 +284,7 @@ def convert_qwen2_5_omni_model( config.thinker_config._attn_implementation = "sdpa" config.talker_config._attn_implementation_autoset = False config.talker_config._attn_implementation = "sdpa" - model = Qwen2_5OmniForConditionalGeneration.from_pretrained( - ckpt, config=config, torch_dtype=torch.float16 - ) + model = Qwen2_5OmniForConditionalGeneration.from_pretrained(ckpt, config=config, torch_dtype=torch.float16) model.eval() processor = AutoProcessor.from_pretrained(ckpt) @@ -334,9 +309,7 @@ def forward_wrap_audio(self, padded_feature, padded_mask, padded_mask_after_cnn) padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) - padded_embed = padded_embed + self.positional_embedding.positional_embedding[ - : padded_embed.shape[1], : - ].unsqueeze(0).to(padded_embed.dtype) + padded_embed = padded_embed + self.positional_embedding.positional_embedding[: padded_embed.shape[1], :].unsqueeze(0).to(padded_embed.dtype) hidden_states = padded_embed[padded_mask_after_cnn] cu_seqlens = torch.cat( ( @@ -355,9 +328,7 @@ def forward_wrap_audio(self, padded_feature, padded_mask, padded_mask_after_cnn) return hidden_states def forward_wrap_audio_state(self, each_audio_states): - each_audio_states = self.avg_pooler( - each_audio_states.transpose(0, 1) - ).transpose_(0, 1) + each_audio_states = self.avg_pooler(each_audio_states.transpose(0, 1)).transpose_(0, 1) each_audio_states = self.ln_post(each_audio_states) each_audio_states = self.proj(each_audio_states) return each_audio_states @@ -422,14 +393,10 @@ def image_embed_forward( rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: seq_len = hidden_states.shape[0] - hidden_states = hidden_states.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 - ) + hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) hidden_states = hidden_states[window_index, :, :] hidden_states = hidden_states.reshape(seq_len, -1) - rotary_pos_emb = rotary_pos_emb.reshape( - seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1 - ) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) rotary_pos_emb = rotary_pos_emb[window_index, :, :] rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) for layer_num, blk in enumerate(self.blocks): @@ -469,17 +436,13 @@ def sdpa_attn_forward( q = q.transpose(0, 1) k = k.transpose(0, 1) v = v.transpose(0, 1) - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attention_mask, dropout_p=0.0 - ) + attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0) attn_output = attn_output.transpose(0, 
1) attn_output = attn_output.reshape(seq_length, -1) attn_output = self.proj(attn_output) return attn_output - def block_forward( - self, hidden_states, attention_mask, rotary_pos_emb - ) -> torch.Tensor: + def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), attention_mask=attention_mask, @@ -489,9 +452,7 @@ def block_forward( return hidden_states if not thinker_merger_path.exists(): - vision_embed_tokens.forward = types.MethodType( - image_embed_forward, vision_embed_tokens - ) + vision_embed_tokens.forward = types.MethodType(image_embed_forward, vision_embed_tokens) for block in vision_embed_tokens.blocks: block.forward = types.MethodType(block_forward, block) block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn) @@ -533,19 +494,9 @@ def forward_wrap_thinker( """take care of image_encode, position_ids and (attention_mask = None is fine)""" if past_key_values is not None: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids=input_ids, @@ -560,9 +511,7 @@ def forward_wrap_thinker( cache_position=cache_position, ) if past_key_values is not None: - outputs["past_key_values"] = outputs[ - "past_key_values" - ].to_legacy_cache() + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() hidden_states = outputs[0] logits = self.lm_head(hidden_states) output = (logits,) + outputs[:] @@ -593,9 +542,7 @@ def forward_wrap_thinker( for i in range(num_pkv): kv = [torch.randn(pkv_shape) for _ in range(2)] past_key_values.append(kv) - input_names.extend( - [f"past_key_values.{i}.key", f"past_key_values.{i}.value"] - ) + input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"]) output_names.extend([f"present.{i}.key", f"present.{i}.value"]) input_names.append("inputs_embeds") example_input = { @@ -625,9 +572,7 @@ def forward_wrap_thinker( ) input_shapes += [ov.PartialShape([-1, -1, input_embeds.shape[-1]])] __make_16bit_traceable(lang_model) - ov_model = ov.convert_model( - lang_model, example_input=example_input, input=input_shapes - ) + ov_model = ov.convert_model(lang_model, example_input=example_input, input=input_shapes) for input, input_name in zip(ov_model.inputs, input_names): input.get_tensor().set_names({input_name}) @@ -637,9 +582,7 @@ def forward_wrap_thinker( print("✅ Thinker language model successfully converted") if quantization_config is not None: - print( - f"⌛ Weights compression with {quantization_config['mode']} mode started" - ) + print(f"⌛ Weights compression with {quantization_config['mode']} mode started") ov_model = nncf.compress_weights(ov_model, **quantization_config) print("✅ Weights compression finished") @@ -647,9 +590,7 @@ def forward_wrap_thinker( del ov_model cleanup_torchscript_cache() gc.collect() - print( - f"✅ 
Thinker model conversion finished. You can find results in {output_dir}" - ) + print(f"✅ Thinker model conversion finished. You can find results in {output_dir}") if not talker_embedding_path.exists(): print("⌛ Convert talker embedding model") @@ -681,19 +622,9 @@ def forward_wrap_talker( """take care of image_encode, position_ids and (attention_mask = None is fine)""" if past_key_values is not None: past_key_values = DynamicCache.from_legacy_cache(past_key_values) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict talker_lm_input = self.thinker_to_talker_proj(inputs_embeds) outputs = self.model( @@ -707,9 +638,7 @@ def forward_wrap_talker( return_dict=return_dict, ) if past_key_values is not None: - outputs["past_key_values"] = outputs[ - "past_key_values" - ].to_legacy_cache() + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() hidden_states = outputs[0] logits = self.codec_head(hidden_states) @@ -743,9 +672,7 @@ def forward_wrap_talker( for i in range(num_pkv): kv = [torch.randn(pkv_shape) for _ in range(2)] past_key_values.append(kv) - input_names.extend( - [f"past_key_values.{i}.key", f"past_key_values.{i}.value"] - ) + input_names.extend([f"past_key_values.{i}.key", f"past_key_values.{i}.value"]) output_names.extend([f"present.{i}.key", f"present.{i}.value"]) input_names.append("inputs_embeds") example_input = { @@ -776,9 +703,7 @@ def forward_wrap_talker( input_shapes += [ov.PartialShape([-1, -1, input_embeds.shape[-1]])] __make_16bit_traceable(lang_model) - ov_model = ov.convert_model( - lang_model, example_input=example_input, input=input_shapes - ) + ov_model = ov.convert_model(lang_model, example_input=example_input, input=input_shapes) for input, input_name in zip(ov_model.inputs, input_names): input.get_tensor().set_names({input_name}) @@ -788,9 +713,7 @@ def forward_wrap_talker( print("✅ Talker language model successfully converted") if quantization_config is not None: - print( - f"⌛ Weights compression with {quantization_config['mode']} mode started" - ) + print(f"⌛ Weights compression with {quantization_config['mode']} mode started") ov_model = nncf.compress_weights(ov_model, **quantization_config) print("✅ Weights compression finished") @@ -798,9 +721,7 @@ def forward_wrap_talker( del ov_model cleanup_torchscript_cache() gc.collect() - print( - f"✅ Talker model conversion finished. You can find results in {output_dir}" - ) + print(f"✅ Talker model conversion finished. 
You can find results in {output_dir}") if not token2wav_dit_path.exists(): print("⌛ Convert token2wav DIT model") @@ -828,9 +749,7 @@ def forward_wrap_dit_attention( # apply rotary position embedding # Due to training process, only first head is applied with RoPE, will be fixed at next release cos, sin = position_embeddings - query[:, :1], key[:, :1] = apply_rotary_pos_emb( - query[:, :1], key[:, :1], cos, sin - ) + query[:, :1], key[:, :1] = apply_rotary_pos_emb(query[:, :1], key[:, :1], cos, sin) attention_interface = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] attn_mask = torch.zeros_like(attention_mask, dtype=torch.float32) @@ -845,9 +764,7 @@ def forward_wrap_dit_attention( ) # mask. e.g. inference got a batch with different target durations, mask out the padding - attention_weights = attention_weights.reshape( - batch_size, -1, self.heads * head_dim - ) + attention_weights = attention_weights.reshape(batch_size, -1, self.heads * head_dim) attention_weights = attention_weights.to(query.dtype) # linear proj @@ -858,9 +775,7 @@ def forward_wrap_dit_attention( code2wav_dit = model.token2wav.code2wav_dit_model for block in code2wav_dit.transformer_blocks: - block.attn.forward = types.MethodType( - forward_wrap_dit_attention, block.attn - ) + block.attn.forward = types.MethodType(forward_wrap_dit_attention, block.attn) __make_16bit_traceable(code2wav_dit) ov_model = ov.convert_model( @@ -900,9 +815,7 @@ def forward_wrap_bigvgan(self, mel_spectrogram): for layer_index in range(self.num_upsample_layers): hidden_representation = self.ups[layer_index][0](hidden_representation) residual_output = sum( - self.resblocks[ - layer_index * self.num_residual_blocks + block_index - ](hidden_representation) + self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation) for block_index in range(self.num_residual_blocks) ) residual_output = residual_output / self.num_residual_blocks @@ -914,16 +827,12 @@ def forward_wrap_bigvgan(self, mel_spectrogram): return audio code2wav_bigvgan = model.token2wav.code2wav_bigvgan_model - code2wav_bigvgan.forward = types.MethodType( - forward_wrap_bigvgan, code2wav_bigvgan - ) + code2wav_bigvgan.forward = types.MethodType(forward_wrap_bigvgan, code2wav_bigvgan) __make_16bit_traceable(code2wav_bigvgan) ov_model = ov.convert_model( code2wav_bigvgan, example_input={ - "mel_spectrogram": torch.randn( - [1, code2wav_bigvgan.config.mel_dim, 2], dtype=torch.float32 - ), + "mel_spectrogram": torch.randn([1, code2wav_bigvgan.config.mel_dim, 2], dtype=torch.float32), }, ) ov.save_model(ov_model, token2wav_bigvgan_path) @@ -932,9 +841,7 @@ def forward_wrap_bigvgan(self, mel_spectrogram): del model gc.collect() print("✅ Token2wav BIGVGAN model successfully converted") - print( - f"✅ {model_id} model conversion finished. You can find results in {output_dir}" - ) + print(f"✅ {model_id} model conversion finished. 
You can find results in {output_dir}") def get_llm_pos_ids_for_vision( @@ -948,25 +855,9 @@ def get_llm_pos_ids_for_vision( llm_pos_ids_list = [] llm_grid_h = grid_hs[vision_idx] // spatial_merge_size llm_grid_w = grid_ws[vision_idx] // spatial_merge_size - h_index = ( - torch.arange(llm_grid_h) - .view(1, -1, 1) - .expand(len(t_index), -1, llm_grid_w) - .flatten() - ) - w_index = ( - torch.arange(llm_grid_w) - .view(1, 1, -1) - .expand(len(t_index), llm_grid_h, -1) - .flatten() - ) - t_index = ( - torch.Tensor(t_index) - .view(-1, 1) - .expand(-1, llm_grid_h * llm_grid_w) - .flatten() - .long() - ) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(len(t_index), -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(len(t_index), llm_grid_h, -1).flatten() + t_index = torch.Tensor(t_index).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten().long() _llm_pos_ids = torch.stack([t_index, h_index, w_index]) # + 1 ) # 12.09 by malinhan llm_pos_ids_list.append(_llm_pos_ids + start_idx) @@ -1012,9 +903,7 @@ def get_rope_index( seconds_per_chunk = config.seconds_per_chunk mrope_position_deltas = [] - if input_ids is not None and ( - image_grid_thw is not None or video_grid_thw is not None - ): + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): total_input_ids = input_ids if attention_mask is None: attention_mask = torch.ones_like(total_input_ids) @@ -1030,17 +919,11 @@ def get_rope_index( for i, input_ids in enumerate(total_input_ids): input_ids = input_ids[attention_mask[i] == 1] image_nums, video_nums, audio_nums = 0, 0, 0 - vision_start_indices = torch.argwhere( - input_ids == vision_start_token_id - ).squeeze(1) + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) vision_tokens = input_ids[vision_start_indices + 1] audio_nums = torch.sum(input_ids == audio_start_token_id) image_nums = (vision_tokens == image_token_id).sum() - video_nums = ( - (vision_tokens == audio_start_token_id).sum() - if use_audio_in_video - else (vision_tokens == video_token_id).sum() - ) + video_nums = (vision_tokens == audio_start_token_id).sum() if use_audio_in_video else (vision_tokens == video_token_id).sum() input_tokens = input_ids.tolist() llm_pos_ids_list: list = [] st = 0 @@ -1049,15 +932,9 @@ def get_rope_index( video_nums, audio_nums, ) - multimodal_nums = ( - image_nums + audio_nums - if use_audio_in_video - else image_nums + video_nums + audio_nums - ) + multimodal_nums = image_nums + audio_nums if use_audio_in_video else image_nums + video_nums + audio_nums for _ in range(multimodal_nums): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 if image_token_id in input_tokens and remain_images > 0: ed_image = input_tokens.index(image_token_id, st) else: @@ -1074,45 +951,21 @@ def get_rope_index( if min_ed == ed_audio: text_len = min_ed - st - 1 if text_len != 0: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 bos_len = 1 - 
llm_pos_ids_list.append( - torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx - ) - - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 - llm_pos_ids = ( - torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx llm_pos_ids_list.append(llm_pos_ids) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 eos_len = 1 - llm_pos_ids_list.append( - torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) st += text_len + bos_len + audio_len + eos_len audio_idx += 1 @@ -1121,53 +974,25 @@ def get_rope_index( elif min_ed == ed_image: text_len = min_ed - st - 1 if text_len != 0: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 bos_len = 1 - llm_pos_ids_list.append( - torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx - ) - - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 grid_t = image_grid_thw[image_idx][0] grid_hs = image_grid_thw[:, 1] grid_ws = image_grid_thw[:, 2] - t_index = ( - torch.arange(grid_t) * 1 * position_id_per_seconds - ).long() - llm_pos_ids = get_llm_pos_ids_for_vision( - st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - image_len = image_grid_thw[image_idx].prod() // ( - spatial_merge_size**2 - ) + t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).long() + llm_pos_ids = get_llm_pos_ids_for_vision(st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws) + image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2) llm_pos_ids_list.append(llm_pos_ids) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 eos_len = 1 - llm_pos_ids_list.append( - torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) st += text_len + bos_len + image_len + eos_len image_idx += 1 @@ -1176,55 +1001,25 @@ def get_rope_index( elif min_ed == ed_video and not use_audio_in_video: text_len = min_ed - st - 1 if text_len != 0: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, 
-1).expand(3, -1) + st_idx) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 bos_len = 1 - llm_pos_ids_list.append( - torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx - ) - - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] grid_ws = video_grid_thw[:, 2] - t_index = ( - torch.arange(grid_t) - * second_per_grids[video_idx].cpu().float() - * position_id_per_seconds - ).long() - llm_pos_ids = get_llm_pos_ids_for_vision( - st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - video_len = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) + t_index = (torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds).long() + llm_pos_ids = get_llm_pos_ids_for_vision(st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) llm_pos_ids_list.append(llm_pos_ids) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 eos_len = 1 - llm_pos_ids_list.append( - torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) st += text_len + bos_len + video_len + eos_len video_idx += 1 @@ -1233,105 +1028,45 @@ def get_rope_index( elif min_ed == ed_video and use_audio_in_video: text_len = min_ed - st - 2 if text_len != 0: - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 bos_len = 1 - llm_pos_ids_list.append( - torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx - ) - llm_pos_ids_list.append( - torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx - ) - - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 audio_len = ((audio_seqlens[audio_idx] - 1) // 2 + 1 - 2) // 2 + 1 - audio_llm_pos_ids = ( - torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx - ) + audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx grid_t = video_grid_thw[video_idx][0] grid_hs = video_grid_thw[:, 1] grid_ws = video_grid_thw[:, 2] - t_index = ( - torch.arange(grid_t) - * second_per_grids[video_idx].cpu().float() - * position_id_per_seconds - ).long() - video_llm_pos_ids = get_llm_pos_ids_for_vision( - st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws - ) - - t_ntoken_per_chunk = int( - position_id_per_seconds * 
seconds_per_chunk - ) - video_chunk_indexes = get_chunked_index( - video_llm_pos_ids, t_ntoken_per_chunk, st_idx - ) - audio_chunk_indexes = get_chunked_index( - audio_llm_pos_ids, t_ntoken_per_chunk, st_idx - ) + t_index = (torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds).long() + video_llm_pos_ids = get_llm_pos_ids_for_vision(st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws) + + t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk) + video_chunk_indexes = get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx) + audio_chunk_indexes = get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx) sub_len = 0 - for j in range( - max(len(video_chunk_indexes), len(audio_chunk_indexes)) - ): - video_chunk_index = ( - video_chunk_indexes[j] - if j < len(video_chunk_indexes) - else None - ) - audio_chunk_index = ( - audio_chunk_indexes[j] - if j < len(audio_chunk_indexes) - else None - ) + for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))): + video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None + audio_chunk_index = audio_chunk_indexes[j] if j < len(audio_chunk_indexes) else None if video_chunk_index is not None: sub_len += video_chunk_index[1] - video_chunk_index[0] - llm_pos_ids_list.append( - video_llm_pos_ids[ - :, video_chunk_index[0] : video_chunk_index[1] - ] - ) + llm_pos_ids_list.append(video_llm_pos_ids[:, video_chunk_index[0] : video_chunk_index[1]]) if audio_chunk_index is not None: sub_len += audio_chunk_index[1] - audio_chunk_index[0] - llm_pos_ids_list.append( - audio_llm_pos_ids[ - :, audio_chunk_index[0] : audio_chunk_index[1] - ] - ) - video_len = video_grid_thw[video_idx].prod() // ( - spatial_merge_size**2 - ) - - st_idx = ( - llm_pos_ids_list[-1].max() + 1 - if len(llm_pos_ids_list) > 0 - else 0 - ) + llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_chunk_index[0] : audio_chunk_index[1]]) + video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2) + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 eos_len = 1 - llm_pos_ids_list.append( - torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx - ) - llm_pos_ids_list.append( - torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) + llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx) st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 @@ -1341,37 +1076,23 @@ def get_rope_index( remain_audios -= 1 if st < len(input_tokens): - st_idx = ( - llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( - torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx - ) + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( - position_ids.device - ) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids)) - mrope_position_deltas = torch.tensor( - mrope_position_deltas, device=input_ids.device - ).unsqueeze(1) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) return 
position_ids, mrope_position_deltas else: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = ( - position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) - ) - max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[ - 0 - ] - mrope_position_deltas = ( - max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) - ) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True) return position_ids, mrope_position_deltas @@ -1380,22 +1101,12 @@ class OVQwen2_5OmniThinkerForConditionalGeneration(GenerationMixin): def __init__(self, model_dir, device, config): self.model = core.read_model(model_dir / THINKER_LANGUAGE_NAME) self.audio = core.compile_model(model_dir / THINKER_AUDIO_NAME, device) - self.audio_state = core.compile_model( - model_dir / THINKER_AUDIO_STATE_NAME, device - ) - self.visual_patcher = core.compile_model( - model_dir / THINKER_PATCHER_NAME, device - ) + self.audio_state = core.compile_model(model_dir / THINKER_AUDIO_STATE_NAME, device) + self.visual_patcher = core.compile_model(model_dir / THINKER_PATCHER_NAME, device) self.visual_merger = core.compile_model(model_dir / THINKER_MERGER_NAME, device) - self.embed_tokens = core.compile_model( - model_dir / THINKER_EMBEDDING_NAME, device - ) - self.input_names = { - key.get_any_name(): idx for idx, key in enumerate(self.model.inputs) - } - self.output_names = { - key.get_any_name(): idx for idx, key in enumerate(self.model.outputs) - } + self.embed_tokens = core.compile_model(model_dir / THINKER_EMBEDDING_NAME, device) + self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} + self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} compiled_model = core.compile_model(self.model, device) self.request = compiled_model.create_infer_request() self.main_input_name = "input_ids" @@ -1410,10 +1121,7 @@ def __init__(self, model_dir, device, config): self.patch_size = self.config.vision_config.patch_size self.fullatt_block_indexes = self.config.vision_config.fullatt_block_indexes self.window_size = self.config.vision_config.window_size - self.spatial_merge_unit = ( - self.config.vision_config.spatial_merge_size - * self.config.vision_config.spatial_merge_size - ) + self.spatial_merge_unit = self.config.vision_config.spatial_merge_size * self.config.vision_config.spatial_merge_size self._skip_keys_device_placement = "past_key_values" self._supports_flash_attn_2 = True self._supports_sdpa = True @@ -1423,21 +1131,15 @@ def __init__(self, model_dir, device, config): class Qwen2_5_VisionRotaryEmbedding(torch.nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - inv_freq = 1.0 / ( - theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim) - ) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange( - seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype - ) + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.outer(seq, self.inv_freq) return freqs - head_dim = ( - self.config.vision_config.hidden_size // 
self.config.vision_config.num_heads - ) + head_dim = self.config.vision_config.hidden_size // self.config.vision_config.num_heads self._rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) def can_generate(self): @@ -1524,25 +1226,19 @@ def get_window_index(self, grid_thw): window_index: list = [] cu_window_seqlens: list = [0] window_index_id = 0 - vit_merger_window_size = ( - self.window_size // self.spatial_merge_size // self.patch_size - ) + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size for grid_t, grid_h, grid_w in grid_thw: llm_grid_h, llm_grid_w = ( grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size, ) - index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( - grid_t, llm_grid_h, llm_grid_w - ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size - index_padded = torch.nn.functional.pad( - index, (0, pad_w, 0, pad_h), "constant", -100 - ) + index_padded = torch.nn.functional.pad(index, (0, pad_w, 0, pad_h), "constant", -100) index_padded = index_padded.reshape( grid_t, num_windows_h, @@ -1560,9 +1256,7 @@ def get_window_index(self, grid_thw): index_padded = index_padded.reshape(-1) index_new = index_padded[index_padded != -100] window_index.append(index_new + window_index_id) - cu_seqlens_tmp = ( - seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] - ) + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1] cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() window_index = torch.cat(window_index, dim=0) @@ -1578,13 +1272,9 @@ def visual(self, pixel_values, grid_thw, **kwargs): dtype=torch.int32, ) cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) - cu_seqlens = torch.repeat_interleave( - grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] - ).cumsum(dim=0, dtype=torch.int32) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0, dtype=torch.int32) cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) - attention_mask = torch.zeros( - (1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool - ) + attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) for i in range(1, len(cu_seqlens)): attention_mask[ @@ -1595,9 +1285,7 @@ def visual(self, pixel_values, grid_thw, **kwargs): causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf")) - window_attention_mask = torch.zeros( - (1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool - ) + window_attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) window_causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) for i in range(1, len(cu_window_seqlens)): window_attention_mask[ @@ -1606,9 +1294,7 @@ def visual(self, pixel_values, grid_thw, **kwargs): cu_window_seqlens[i - 1] : cu_window_seqlens[i], ] = True - window_causal_mask.masked_fill_( - torch.logical_not(window_attention_mask), float("-inf") - ) + window_causal_mask.masked_fill_(torch.logical_not(window_attention_mask), 
float("-inf")) res = self.visual_merger( [ @@ -1649,9 +1335,7 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): output_lengths = (input_lengths - 2) // 2 + 1 return input_lengths, output_lengths - def padded_and_mask_function( - self, tensor_list, tensor_len, padding_value=0, padding_side="right" - ): + def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): max_len = tensor_len.max() dim = tensor_list[0].shape[0] padded_tensor = torch.full( @@ -1711,17 +1395,11 @@ def forward( ) -> Union[tuple, BaseModelOutputWithPast]: if feature_attention_mask is not None: audio_feature_lengths = torch.sum(feature_attention_mask, dim=1) - input_features = input_features.permute(0, 2, 1)[ - feature_attention_mask.bool() - ].permute(1, 0) + input_features = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0) else: audio_feature_lengths = None if attention_mask is not None and position_ids is None: - if ( - cache_position is None - or (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - ): + if cache_position is None or (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1) position_ids, rope_deltas = get_rope_index( self.config, @@ -1737,11 +1415,7 @@ def forward( self.rope_deltas = rope_deltas else: batch_size, seq_length = input_ids.shape - delta = ( - cache_position[0] + self.rope_deltas - if cache_position is not None - else 0 - ) + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.add(delta) @@ -1754,16 +1428,8 @@ def forward( ( audio_feat_lengths, audio_output_lengths, - ) = self._get_feat_extract_output_lengths( - audio_feature_lengths - if audio_feature_lengths is not None - else feature_attention_mask.sum(-1) - ) - feature_lens = ( - audio_feature_lengths - if audio_feature_lengths is not None - else feature_attention_mask.sum(-1) - ) + ) = self._get_feat_extract_output_lengths(audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1)) + feature_lens = audio_feature_lengths if audio_feature_lengths is not None else feature_attention_mask.sum(-1) chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() chunk_lengths = torch.tensor( @@ -1771,13 +1437,9 @@ def forward( dtype=torch.long, device=feature_lens.device, ) - tail_chunk_index = list( - accumulate(chunk_num.tolist(), func=operator.add, initial=-1) - )[1:] + tail_chunk_index = list(accumulate(chunk_num.tolist(), func=operator.add, initial=-1))[1:] chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) - chunk_lengths = torch.where( - chunk_lengths == 0, self.n_window * 2, chunk_lengths - ) + chunk_lengths = torch.where(chunk_lengths == 0, self.n_window * 2, chunk_lengths) chunk_list = input_features.split(chunk_lengths.tolist(), dim=1) @@ -1785,61 +1447,31 @@ def forward( padded_feature, padded_mask, padded_mask_after_cnn, - ) = self.padded_and_mask_function( - chunk_list, chunk_lengths, padding_value=0, padding_side="right" - ) - hidden_states = torch.from_numpy( - self.audio([padded_feature, padded_mask, padded_mask_after_cnn])[0] - ) - hidden_states_list = hidden_states.split( - audio_feat_lengths.tolist(), dim=0 - ) + ) = self.padded_and_mask_function(chunk_list, chunk_lengths, 
padding_value=0, padding_side="right") + hidden_states = torch.from_numpy(self.audio([padded_feature, padded_mask, padded_mask_after_cnn])[0]) + hidden_states_list = hidden_states.split(audio_feat_lengths.tolist(), dim=0) token_audio_list = [] for each_audio_states in hidden_states_list: - each_audio_states = torch.from_numpy( - self.audio_state([each_audio_states])[0] - ) + each_audio_states = torch.from_numpy(self.audio_state([each_audio_states])[0]) token_audio_list.append(each_audio_states) audio_features = torch.cat(token_audio_list, dim=0) if audio_features.shape[0] != sum(audio_output_lengths.tolist()): - raise ValueError( - "length of audio_features should match audio_output_lengths" - ) - audio_mask = ( - (input_ids == self.config.audio_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - ) - audio_features = audio_features.to( - inputs_embeds.device, inputs_embeds.dtype - ) + raise ValueError("length of audio_features should match audio_output_lengths") + audio_mask = (input_ids == self.config.audio_token_index).unsqueeze(-1).expand_as(inputs_embeds) + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - video_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to( - inputs_embeds.device, inputs_embeds.dtype - ) + video_mask = (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if attention_mask is not None: @@ -1854,11 +1486,7 @@ def forward( inputs["attention_mask"] = attention_mask inputs["position_ids"] = position_ids if "beam_idx" in self.input_names: - inputs["beam_idx"] = ( - self.next_beam_idx - if self.next_beam_idx is not None - else np.arange(inputs_embeds.shape[0], dtype=int) - ) + inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int) self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = self.request.get_tensor("logits").data @@ -1875,17 +1503,13 @@ def forward( hidden_states=(embeds_to_talker, hidden_states_output), ) - def _reorder_cache( - self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: + def _reorder_cache(self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple[tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct beam_idx at every generation step. """ - self.next_beam_idx = np.array( - beam_idx - ) # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration return past_key_values def _get_past_length(self, past_key_values=None): @@ -1920,12 +1544,8 @@ def __init__(self, model_dir, device, config): self.model = core.read_model(model_dir / TALKER_LANGUAGE_NAME) self.embed_tokens = core.compile_model(model_dir / TALKER_EMBEDDING_NAME, "CPU") - self.input_names = { - key.get_any_name(): idx for idx, key in enumerate(self.model.inputs) - } - self.output_names = { - key.get_any_name(): idx for idx, key in enumerate(self.model.outputs) - } + self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)} + self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} compiled_model = core.compile_model(self.model, device) self.request = compiled_model.create_infer_request() self.config = config.talker_config @@ -2001,9 +1621,7 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[tuple, BaseModelOutputWithPast]: if attention_mask is not None and position_ids is None: - if cache_position is None or ( - cache_position is not None and cache_position[0] == 0 - ): + if cache_position is None or (cache_position is not None and cache_position[0] == 0): position_ids, rope_deltas = get_rope_index( self.config, input_text_ids, @@ -2021,7 +1639,9 @@ def forward( dtype=torch.long, device=inputs_embeds.device, ) - )[0][0] + )[ + 0 + ][0] ) inputs_embeds[:, -2, :] += torch.from_numpy( self.embed_tokens( @@ -2030,16 +1650,14 @@ def forward( dtype=torch.long, device=inputs_embeds.device, ) - )[0][0] + )[ + 0 + ][0] ) else: batch_size, seq_length = input_ids.shape - delta = ( - cache_position[0] + rope_deltas - if cache_position is not None and rope_deltas is not None - else 0 - ) + delta = cache_position[0] + rope_deltas if cache_position is not None and rope_deltas is not None else 0 position_ids = torch.arange(seq_length, device=input_ids.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.add(delta) @@ -2060,11 +1678,7 @@ def forward( inputs["attention_mask"] = attention_mask inputs["position_ids"] = position_ids if "beam_idx" in self.input_names: - inputs["beam_idx"] = ( - self.next_beam_idx - if self.next_beam_idx is not None - else np.arange(inputs_embeds.shape[0], dtype=int) - ) + inputs["beam_idx"] = self.next_beam_idx if self.next_beam_idx is not None else np.arange(inputs_embeds.shape[0], dtype=int) self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = self.request.get_tensor("logits").data @@ -2081,9 +1695,7 @@ def forward( def _get_initial_cache_position(self, input_ids, device, model_kwargs): # Talker needs to calculate cache_position with input_ids, so pop inputs_embeds temporarily inputs_embeds = model_kwargs.pop("inputs_embeds") - model_kwargs = super()._get_initial_cache_position( - input_ids, device, model_kwargs - ) + model_kwargs = super()._get_initial_cache_position(input_ids, device, model_kwargs) model_kwargs["inputs_embeds"] = inputs_embeds return model_kwargs @@ -2132,17 +1744,13 @@ def prepare_inputs_for_generation( return model_inputs - def _reorder_cache( - self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> tuple[tuple[torch.Tensor]]: + def _reorder_cache(self, 
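During incremental decoding, the talker (and the thinker) rebuild `position_ids` from the cached `rope_deltas` offset instead of re-running `get_rope_index`. A tiny numeric sketch of that line, with made-up numbers:

import torch

batch_size, seq_length = 1, 1           # decoding a single new token
cache_position = torch.tensor([37])     # index of the first new token in the sequence
rope_deltas = torch.tensor([[-5]])      # offset produced by get_rope_index at prefill

delta = cache_position[0] + rope_deltas
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1).add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)  # temporal/height/width sections
print(position_ids.flatten().tolist())  # [32, 32, 32]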
past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple[tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ - self.next_beam_idx = np.array( - beam_idx - ) # save beam_idx to be used as an input in the next iteration + self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration return past_key_values def _get_past_length(self, past_key_values=None): @@ -2161,9 +1769,7 @@ def _update_model_kwargs_for_generation( if getattr(outputs, "attention_mask", None) is not None: model_kwargs["attention_mask"] = outputs.attention_mask - model_kwargs = super()._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder, num_new_tokens - ) + model_kwargs = super()._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder, num_new_tokens) if getattr(outputs, "rope_deltas", None) is not None: model_kwargs["rope_deltas"] = outputs.rope_deltas @@ -2191,11 +1797,7 @@ def _rk4_step( value_start, function_value_start=None, ): - k1 = ( - function_value_start - if function_value_start is not None - else function(time_start, value_start) - ) + k1 = function_value_start if function_value_start is not None else function(time_start, value_start) k2 = function( time_start + time_step * self._one_third, value_start + time_step * k1 * self._one_third, @@ -2221,9 +1823,7 @@ def _compute_step(self, function, time_start, time_step, time_end, value_start): function_value_start, ) - def _linear_interpolation( - self, time_start, time_end, value_start, value_end, time_point - ): + def _linear_interpolation(self, time_start, time_end, value_start, value_end, time_point): if time_point == time_start: return value_start if time_point == time_end: @@ -2244,15 +1844,10 @@ def integrate(self, time_points): current_value = self.initial_value for time_start, time_end in zip(time_points[:-1], time_points[1:]): time_step = time_end - time_start - delta_value, _ = self._compute_step( - self.function, time_start, time_step, time_end, current_value - ) + delta_value, _ = self._compute_step(self.function, time_start, time_step, time_end, current_value) next_value = current_value + delta_value - while ( - current_index < len(time_points) - and time_end >= time_points[current_index] - ): + while current_index < len(time_points) and time_end >= time_points[current_index]: solution[current_index] = self._linear_interpolation( time_start, time_end, @@ -2273,9 +1868,7 @@ def __init__(self, model_dir, thinker_device, talker_device, token2wav_device): self.has_talker = self.config.enable_audio_output model_path = Path(model_dir) - self.thinker = OVQwen2_5OmniThinkerForConditionalGeneration( - model_path / "thinker", thinker_device, self.config - ) + self.thinker = OVQwen2_5OmniThinkerForConditionalGeneration(model_path / "thinker", thinker_device, self.config) self.speaker_map = {} if self.config.enable_audio_output: self.enable_talker(model_path, talker_device, token2wav_device) @@ -2286,15 +1879,9 @@ def __init__(self, model_dir, thinker_device, talker_device, token2wav_device): def enable_talker(self, model_path, device, token2wav_device=None): if token2wav_device is None: token2wav_device = device - self.talker = OVQwen2_5OmniTalkerForConditionalGeneration( - model_path / "talker", device, self.config - ) - self.token2wav_dit 
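The `RungeKutta4ODESolver` above appears to use the 3/8-rule RK4 variant (the visible k2 stage samples at t + h/3). The later stages are cut off in this hunk, so the k3/k4 lines below are an assumption consistent with that rule, sanity-checked on dy/dt = y:

import math


def rk4_step(f, t0, dt, y0):
    # 3/8-rule Runge-Kutta step; k1/k2 match the helper, k3/k4 and the weights are assumed.
    k1 = f(t0, y0)
    k2 = f(t0 + dt / 3, y0 + dt * k1 / 3)
    k3 = f(t0 + 2 * dt / 3, y0 + dt * (k2 - k1 / 3))
    k4 = f(t0 + dt, y0 + dt * (k1 - k2 + k3))
    return (k1 + 3 * (k2 + k3) + k4) * dt / 8


y, t, dt = 1.0, 0.0, 0.1
for _ in range(10):
    y += rk4_step(lambda t_, y_: y_, t, dt, y)
    t += dt
print(round(y, 6), round(math.exp(t), 6))  # both ~2.718282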
= core.compile_model( - model_path / TOKEN2WAV_DIT_NAME, token2wav_device - ) - self.token2wav_bigvgan = core.compile_model( - model_path / TOKEN2WAV_BIGVGAN_NAME, token2wav_device - ) + self.talker = OVQwen2_5OmniTalkerForConditionalGeneration(model_path / "talker", device, self.config) + self.token2wav_dit = core.compile_model(model_path / TOKEN2WAV_DIT_NAME, token2wav_device) + self.token2wav_bigvgan = core.compile_model(model_path / TOKEN2WAV_BIGVGAN_NAME, token2wav_device) self.has_talker = True def load_speakers(self, path): @@ -2354,19 +1941,13 @@ def generate( - **Audio waveform** (`torch.Tensor`): Generated audio waveform. """ if speaker not in self.speaker_map: - raise ValueError( - f"{speaker} is not availible, availible speakers: {self.speaker_map.keys()}" - ) + raise ValueError(f"{speaker} is not availible, availible speakers: {self.speaker_map.keys()}") if return_audio and not self.has_talker: - raise ValueError( - "Cannot use talker when talker module not initalized. Use `enable_talker` method or set enable_talker in config to enable talker." - ) + raise ValueError("Cannot use talker when talker module not initalized. Use `enable_talker` method or set enable_talker in config to enable talker.") if return_audio is None: return_audio = self.has_talker if input_ids.shape[0] != 1 and return_audio: - raise NotImplementedError( - "Qwen2.5-Omni currently does not support batched inference with audio output" - ) + raise NotImplementedError("Qwen2.5-Omni currently does not support batched inference with audio output") shared_kwargs = {"use_audio_in_video": use_audio_in_video} thinker_kwargs = { @@ -2424,15 +2005,9 @@ def generate( return thinker_result # 2. Generate speech tokens from talker module - thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to( - self.talker.device - ) - thinker_token_embeds = [ - x[0].to(self.talker.device) for x in thinker_result.hidden_states - ] - thinker_hidden_states = [ - x[1].to(self.talker.device) for x in thinker_result.hidden_states - ] + thinker_generate_ids = thinker_result.sequences[:, input_ids.size(1) :].to(self.talker.device) + thinker_token_embeds = [x[0].to(self.talker.device) for x in thinker_result.hidden_states] + thinker_hidden_states = [x[1].to(self.talker.device) for x in thinker_result.hidden_states] talker_text_bos_token = speaker_params["bos_token"] talker_input_text_ids = torch.cat( [ @@ -2468,16 +2043,10 @@ def generate( dim=1, ) - thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat( - thinker_token_embeds[1:], dim=1 - ) + thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1) talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0] - talker_text_bos_token = torch.tensor( - [[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device - ) - talker_text_bos_embed = torch.from_numpy( - self.thinker.embed_tokens(talker_text_bos_token)[0] - ).to(self.talker.device) + talker_text_bos_token = torch.tensor([[talker_text_bos_token]], dtype=torch.long, device=self.thinker.device) + talker_text_bos_embed = torch.from_numpy(self.thinker.embed_tokens(talker_text_bos_token)[0]).to(self.talker.device) talker_inputs_embeds = torch.cat( [ @@ -2531,34 +2100,22 @@ def generate( inputs_embeds=talker_inputs_embeds, attention_mask=talker_attention_mask, suppress_tokens=[self.talker.codec_bos_token], - **{ - k: (v.to(self.talker.device) if torch.is_tensor(v) else v) - for k, v in talker_kwargs.items() - }, + **{k: 
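A shape-level sketch of the thinker-to-talker handoff above: per-step hidden states and token embeddings are summed, the streamed part is kept for later talker steps, and a BOS embedding is concatenated onto the prompt part. Random tensors stand in for the real model outputs, and the concatenation is simplified relative to the full helper.

import torch

hidden_size = 8          # illustrative; the real hidden size comes from the config
prompt_len, new_tokens = 5, 3

# One (embeds, hidden) pair per generation step, as in thinker_result.hidden_states.
thinker_token_embeds = [torch.randn(1, prompt_len, hidden_size)] + [
    torch.randn(1, 1, hidden_size) for _ in range(new_tokens)
]
thinker_hidden_states = [torch.randn(1, prompt_len, hidden_size)] + [
    torch.randn(1, 1, hidden_size) for _ in range(new_tokens)
]

# Streamed part fed to the talker later, and the prompt part used as its initial embeddings.
thinker_reply_part = torch.cat(thinker_hidden_states[1:], dim=1) + torch.cat(thinker_token_embeds[1:], dim=1)
talker_inputs_embeds = thinker_hidden_states[0] + thinker_token_embeds[0]

talker_text_bos_embed = torch.randn(1, 1, hidden_size)  # stands in for embed_tokens(bos_token)
talker_inputs_embeds = torch.cat([talker_inputs_embeds, talker_text_bos_embed], dim=1)
print(talker_inputs_embeds.shape, thinker_reply_part.shape)  # (1, 6, 8) (1, 3, 8)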
(v.to(self.talker.device) if torch.is_tensor(v) else v) for k, v in talker_kwargs.items()}, ) talker_generate_codes = talker_result[:, talker_input_ids.shape[1] : -1] print("[===start token2wav===]") # 3. Generate wavs from code - reference_mel_spectrogram = ( - speaker_params["ref_mel"].to(torch.device("cpu")).float() - ) + reference_mel_spectrogram = speaker_params["ref_mel"].to(torch.device("cpu")).float() conditioning_vector = speaker_params["cond"].to(torch.device("cpu")).float() noise_initialization = torch.randn( [1, 30000, self.config.token2wav_config.dit_config.mel_dim], dtype=reference_mel_spectrogram.dtype, ) - maximum_duration = ( - talker_generate_codes.shape[1] - * self.config.token2wav_config.dit_config.repeats - ) - initial_state = noise_initialization[:, :maximum_duration].to( - talker_generate_codes.device - ) + maximum_duration = talker_generate_codes.shape[1] * self.config.token2wav_config.dit_config.repeats + initial_state = noise_initialization[:, :maximum_duration].to(talker_generate_codes.device) batch_size = reference_mel_spectrogram.shape[0] - conditioning_vector = conditioning_vector.unsqueeze(1).repeat( - 1, maximum_duration, 1 - ) + conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1) if batch_size != 1: raise ValueError("Only batch size = 1 is currently supported") guidance_scale = 0.5 @@ -2578,10 +2135,7 @@ def ode_function(time_step, hidden_states): ) guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0) - return ( - guided_prediction - + (guided_prediction - null_prediction) * guidance_scale - ) + return guided_prediction + (guided_prediction - null_prediction) * guidance_scale initial_time = 0 time_embedding = torch.linspace( @@ -2593,19 +2147,13 @@ def ode_function(time_step, hidden_states): ) if sway_coefficient is not None: - time_embedding += sway_coefficient * ( - torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding - ) + time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding) - ode_solver = RungeKutta4ODESolver( - function=ode_function, initial_value=initial_state - ) + ode_solver = RungeKutta4ODESolver(function=ode_function, initial_value=initial_state) solution_trajectory = ode_solver.integrate(time_embedding) generated_waveform = solution_trajectory[-1] generated_mel_spectrogram = generated_waveform.permute(0, 2, 1) - waveform = torch.from_numpy( - self.token2wav_bigvgan([generated_mel_spectrogram])[0] - ) + waveform = torch.from_numpy(self.token2wav_bigvgan([generated_mel_spectrogram])[0]) waveform.squeeze().cpu() return thinker_result.sequences, waveform.float() From 837f9b78e524d93f4917c02e2252235a0f20496d Mon Sep 17 00:00:00 2001 From: ethan Date: Thu, 3 Jul 2025 08:57:30 -0700 Subject: [PATCH 5/5] change the quantization to reduce accuracy on audio inputs --- notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb b/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb index 9781e72a48f..4689ed1aff6 100644 --- a/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb +++ b/notebooks/qwen2.5-omni-chatbot/qwen2.5-omni-chatbot.ipynb @@ -201,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "6e938ae8-7e49-4c61-88b9-0b73c8fa8407", "metadata": {}, "outputs": [ @@ -218,9 +218,9 @@ "from qwen2_5_omni_helper import convert_qwen2_5_omni_model\n", "\n", 
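Two details of the token2wav step above in isolation: the classifier-free-guidance extrapolation applied to the doubled batch coming out of the DiT, and the sway-warped time grid handed to the ODE solver. The DiT call is replaced by a stub and the sway coefficient is an assumed value; only the two formulas are taken from the code above.

import torch

guidance_scale = 0.5
sway_coefficient = -1.0   # assumed; the real value is passed into generate()


def dit_stub(hidden_states):
    # Stand-in for the token2wav DiT: returns [conditioned, null] predictions.
    half = hidden_states[: hidden_states.shape[0] // 2]
    return torch.cat([half * 0.9, half * 0.1], dim=0)


doubled_batch = torch.randn(2, 64, 80)                      # (2 * batch, frames, mel_dim)
guided, null = torch.chunk(dit_stub(doubled_batch), 2, dim=0)
prediction = guided + (guided - null) * guidance_scale      # CFG extrapolation

steps = 10
time_embedding = torch.linspace(0, 1, steps)
time_embedding = time_embedding + sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)
print(prediction.shape, [round(t, 3) for t in time_embedding[:3].tolist()])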
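For reference, how a `compression_configuration` like the one defined in the lines that follow is typically consumed: `nncf.compress_weights` quantizes roughly `ratio` of the weight tensors to the chosen 4-bit mode (group size 128) and leaves the remainder in an 8-bit backup precision. The model path below is a placeholder, not a file produced by this notebook.

import nncf
import openvino as ov

compression_configuration = {
    "mode": nncf.CompressWeightsMode.INT4_SYM,
    "group_size": 128,
    "ratio": 0.6,
}

core = ov.Core()
ov_model = core.read_model("thinker_language_model.xml")              # placeholder path
compressed = nncf.compress_weights(ov_model, **compression_configuration)
ov.save_model(compressed, "thinker_language_model_int4.xml")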
"compression_configuration = {\n", - " \"mode\": nncf.CompressWeightsMode.INT4_ASYM,\n", + " \"mode\": nncf.CompressWeightsMode.INT4_SYM,\n", " \"group_size\": 128,\n", - " \"ratio\": 0.8,\n", + " \"ratio\": 0.6,\n", "}\n", "\n", "convert_qwen2_5_omni_model(model_id, model_dir, compression_configuration)"