Support SpeechT5 text-to-speech pipeline by OpenVINO #1230

Collaborator

@rkazants rkazants commented Apr 9, 2025

What does this PR do?

This PR adds support for the SpeechT5 text-to-speech pipeline using OpenVINO. Here is a demo:

import soundfile as sf
import torch
from datasets import load_dataset
from optimum.intel import OVModelForTextToSpeechSeq2Seq
from transformers import SpeechT5Processor

model_id = "microsoft/speecht5_tts"
vocoder_id = "microsoft/speecht5_hifigan"

ov_pipe = OVModelForTextToSpeechSeq2Seq.from_pretrained(model_id, export=True, vocoder=vocoder_id)
ov_pipe.save_pretrained("speecht5_tts")
ov_pipe = OVModelForTextToSpeechSeq2Seq.from_pretrained("speecht5_tts")

processor = SpeechT5Processor.from_pretrained(model_id)

inputs = processor(text="Hello, this PR introduces support of SpeechT5 text-to-speech pipeline using OpenVINO.",
                   return_tensors="pt")

# load vector containing speaker's voice characteristics from a dataset
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

speech = ov_pipe.generate(input_ids=inputs["input_ids"],
                          speaker_embeddings=speaker_embeddings)

sf.write("speech.wav", speech.numpy()[0], samplerate=16000)
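The demo ends by writing `speech.wav` at 16 kHz. A minimal sketch of a sanity check on such a file, using only the standard library `wave` module, is shown here. It assumes the file is 16-bit PCM WAV (an assumption about the writer's default subtype, not something stated in this PR). The helper names `wav_info` and `write_test_tone` are hypothetical, and a synthetic tone stands in for the generated speech so the snippet runs without the model.

```python
import math
import os
import struct
import tempfile
import wave

SAMPLE_RATE = 16000  # matches samplerate=16000 in the demo


def wav_info(path):
    """Return (sample_rate, n_channels, duration_seconds) of a PCM WAV file."""
    with wave.open(path, "rb") as wf:
        rate = wf.getframerate()
        return rate, wf.getnchannels(), wf.getnframes() / rate


def write_test_tone(path, seconds=0.5, freq=440.0):
    """Write a mono 16-bit sine tone; a stand-in for the generated speech."""
    n_samples = int(SAMPLE_RATE * seconds)
    frames = b"".join(
        struct.pack(
            "<h",
            int(32767 * 0.3 * math.sin(2 * math.pi * freq * i / SAMPLE_RATE)),
        )
        for i in range(n_samples)
    )
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)   # mono
        wf.setsampwidth(2)   # 2 bytes per sample = 16-bit PCM
        wf.setframerate(SAMPLE_RATE)
        wf.writeframes(frames)


# Write the stand-in file to a temporary directory and inspect it.
tmp_path = os.path.join(tempfile.mkdtemp(), "tone.wav")
write_test_tone(tmp_path)
rate, channels, duration = wav_info(tmp_path)
print(rate, channels, round(duration, 3))
```

Pointing `wav_info` at the real `speech.wav` should report the same sample rate if the file was written as in the demo; the duration depends on the input text.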

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@eaidova
Collaborator

eaidova commented Apr 14, 2025

@rkazants could you please provide tests?

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@rkazants rkazants requested a review from eaidova April 18, 2025 14:35
@rkazants
Collaborator Author

> @rkazants could you please provide tests?

Done

@rkazants
Collaborator Author

@eaidova, @IlyasMoutawwakil, @echarlaix, could you please review the PR?

Thanks,
Roman

@rkazants
Collaborator Author

rkazants commented Apr 18, 2025

The CI failures do not seem related to my changes in this PR. For example, I see an issue with the Whisper model. Please correct me if I am wrong.

@eaidova
Collaborator

eaidova commented Apr 21, 2025

@IlyasMoutawwakil could you please rerun the CI? Thanks.

@IlyasMoutawwakil
Member

@eaidova done. Was this issue fixed?

FAILED tests/openvino/test_modeling.py::OVModelForTextToSpeechSeq2SeqIntegrationTest::test_compare_to_transformers_0_speecht5 - RuntimeError: The size of tensor a (512) must match the size of tensor b (32) at non-singleton dimension 1
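The error message above is the standard broadcasting failure: two tensors disagree in a dimension where neither has size 1. The thread does not show the actual tensors involved, so the shapes `(1, 512)` and `(1, 32)` below are hypothetical, chosen only to match the sizes and dimension index in the message. The following is a pure-Python sketch of the broadcasting rule (the helper `broadcast_shape` is illustrative, not PyTorch's implementation).

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise on a mismatch."""
    ndim = max(len(a), len(b))
    result = []
    # Align shapes on their trailing dimensions; pad the shorter one with 1s.
    pairs = zip_longest(reversed(a), reversed(b), fillvalue=1)
    for offset, (x, y) in enumerate(pairs):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            # Convert the offset from the end into a front-based index.
            dim = ndim - 1 - offset
            raise ValueError(
                f"size {x} must match size {y} at non-singleton dimension {dim}"
            )
    return tuple(reversed(result))


# Compatible: a size-1 dimension stretches to match the other shape.
ok_shape = broadcast_shape((1, 512), (1, 1))

# Incompatible: 512 vs 32 in dimension 1, neither of which is 1.
try:
    broadcast_shape((1, 512), (1, 32))
    error_text = ""
except ValueError as exc:
    error_text = str(exc)
print(ok_shape, error_text)
```

A mismatch like this typically points at one operand having the wrong feature size, which is a cue to compare the shapes fed to the failing operation against those in the reference model.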

@eaidova
Collaborator

eaidova commented Apr 21, 2025

> @eaidova done. Was this issue fixed?
>
> FAILED tests/openvino/test_modeling.py::OVModelForTextToSpeechSeq2SeqIntegrationTest::test_compare_to_transformers_0_speecht5 - RuntimeError: The size of tensor a (512) must match the size of tensor b (32) at non-singleton dimension 1

@IlyasMoutawwakil thanks. Yes, @rkazants is working on the fix.

@echarlaix echarlaix added the openvino-test Trigger OpenVINO slow tests label Apr 24, 2025
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Member

@IlyasMoutawwakil IlyasMoutawwakil left a comment

LGTM, great addition!
I think there's still some redundancy and room to make the implementation leaner. For example, Whisper has a custom generate method, and for it we only make sure our class is compliant with its behavior and use the method directly from transformers.
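The reuse pattern described in this comment can be sketched in plain Python. Every name below (`ReferenceGenerationMixin`, `AcceleratedModel`, the toy `forward`) is hypothetical and stands in for the real classes; this is not the transformers or optimum-intel API, only an illustration of inheriting a generation loop and supplying a compliant `forward`.

```python
class ReferenceGenerationMixin:
    """Stand-in for an upstream class that owns the generation loop."""

    def generate(self, input_ids, max_steps=3):
        # The loop relies only on self.forward(), so any subclass that
        # provides a compatible forward() can reuse it unchanged.
        state = list(input_ids)
        outputs = []
        for _ in range(max_steps):
            step = self.forward(state)
            outputs.append(step)
            state.append(step)
        return outputs


class AcceleratedModel(ReferenceGenerationMixin):
    """Hypothetical backend class: implements forward() and nothing else."""

    def forward(self, state):
        # Toy computation standing in for an inference call on a backend.
        return sum(state) % 10


model = AcceleratedModel()
result = model.generate([1, 2, 3])
print(result)
```

The trade-off is that the subclass must keep its `forward` contract in step with whatever the inherited loop expects, in exchange for not maintaining a second copy of the loop.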

@IlyasMoutawwakil IlyasMoutawwakil removed the openvino-test Trigger OpenVINO slow tests label Apr 25, 2025
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@rkazants
Collaborator Author

> LGTM, great addition! I think there's still some redundancy and room to make the implementation leaner. For example, Whisper has a custom generate method, and for it we only make sure our class is compliant with its behavior and use the method directly from transformers.

Responded here #1230 (comment)
Thanks

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Collaborator

@echarlaix echarlaix left a comment

Left a couple of minor comments; good to merge once resolved.

rkazants and others added 3 commits April 29, 2025 11:30
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Co-authored-by: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
@nikita-savelyevv nikita-savelyevv merged commit 1949522 into huggingface:main Apr 29, 2025
16 of 18 checks passed
6 participants