[Draft] Add Llasa TTS family of models #39760


Draft: wants to merge 7 commits into main

Conversation

@ebezzam ebezzam (Contributor) commented Jul 29, 2025

What does this PR do?

This PR adds the Llasa TTS family of models:

TODO

  • Tests (e.g., against the tokenizer and speech tokens of the original models)
  • Create docstrings, overwriting those inherited from Llama with corresponding Llasa examples
  • Make make fixup happy (auto docs, unused attributes, etc.)
  • Conversion configurations for 3B and 8B
  • Create public model cards; at the moment the 1B card can be found here. Change tags to audio generation, TTS, etc.
  • Integrate with XCodec2 (Transformers version) once [WiP] Add xcodec2 model #37868 is merged

Example usage

Below is example usage with my Hub checkpoint (compared to that of the original authors).

"""
pip install torchao xcodec2==0.1.3
"""

import torch
from transformers import LlasaTokenizer, LlasaForCausalLM, LlasaProcessor
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model

model_repo = "bezzam/Llasa-1B"
# model_repo = "bezzam/Llasa-3B"
# model_repo = "bezzam/Llasa-8B"
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# load processor (tokenizer + audio codec)
processor = LlasaProcessor(
    LlasaTokenizer.from_pretrained(model_repo),
    XCodec2Model.from_pretrained("HKUSTAudio/xcodec2").eval().to(torch_device)
)

# load model
model = LlasaForCausalLM.from_pretrained(model_repo)
model.eval().to(torch_device)

# TTS: some text inputs don't work, which shows the limitations of this approach
input_text = "How much wood would a woodchuck chuck if a woodchuck could chuck speech tokens?"
with torch.no_grad():

    # Tokenize the text
    encoded_text = processor(input_text).to(torch_device)

    # Generate the speech autoregressively
    outputs = model.generate(
        encoded_text["input_ids"],
        do_sample=False,
        max_length=600,    # generates up to ~10 s of audio; max allowed is 2048, the max length Llasa was trained with
        top_p=1,           # nucleus sampling threshold (only takes effect when do_sample=True)
        temperature=0.8,   # sampling temperature (only takes effect when do_sample=True)
    )

# decode to audio
gen_wav = processor.decode(outputs, input_offset=encoded_text["input_offset"])
fn = f"gen_{model_repo.split('/')[-1]}.wav"
sf.write(fn, gen_wav.cpu().numpy(), model.config.sampling_rate)
print(f"Generated speech saved to {fn}")

@ebezzam ebezzam marked this pull request as draft July 29, 2025 14:42
model_config.max_length = config.original_model.model_max_length
model = LlasaForCausalLM(model_config)
if config.remote_repo.dtype == "bfloat16":
    model.to(torch.bfloat16)
ebezzam (Contributor, Author):
Is this fine? Original models are trained in bf16 (config) and their Hub checkpoints are also in bf16 (e.g., 1B)
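
As a quick sanity check (a sketch, not part of this PR; it assumes the conversion-script config object above and that the checkpoint has already been pushed), the dtype recorded on the Hub can be verified by loading with torch_dtype="auto":

# Sketch: confirm the pushed checkpoint loads back in bf16 when dtype is resolved from config.json
check = LlasaForCausalLM.from_pretrained(config.remote_repo.id, torch_dtype="auto")
print(next(check.parameters()).dtype)  # expected: torch.bfloat16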

print(f"Pushing a {model.__class__.__name__} to Hugging Face Hub: {config.remote_repo.id}")
model.push_to_hub(config.remote_repo.id, private=config.remote_repo.private, use_temp_dir=True)
print(f"Pushing a {tokenizer.__class__.__name__} to Hugging Face Hub: {config.remote_repo.id}")
tokenizer.push_to_hub(config.remote_repo.id, private=True, use_temp_dir=True)
ebezzam (Contributor, Author):
Should the processing config also be pushed to the Hub? It seems this was done for Dia:

ebezzam (Contributor, Author):
Although it seems preprocessor_config.json was added manually (see commit). Is there a way to upload/add it with the conversion script?
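
If LlasaProcessor ends up subclassing ProcessorMixin, one option (a sketch under that assumption; the constructor call is hypothetical and the kwargs simply mirror the model/tokenizer pushes above) would be to push the processor from the conversion script instead of adding the file manually:

# Sketch: push the processor so preprocessor_config.json lands in the same repo
# (assumes LlasaProcessor inherits from ProcessorMixin and thus has push_to_hub support)
processor = LlasaProcessor(tokenizer=tokenizer)  # hypothetical: codec loaded separately at inference
processor.push_to_hub(config.remote_repo.id, private=config.remote_repo.private, use_temp_dir=True)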

Comment on lines +68 to +74
def from_pretrained_llm(cls, *args, **kwargs):
"""
Load the tokenizer from a pre-trained LLM model, and add relevant speech and Llasa tokens.
"""
tokenizer = super().from_pretrained(*args, **kwargs)
tokenizer.add_tokens(list(tokenizer.llasa_token.values()) + tokenizer.speech_tokens)
return tokenizer
ebezzam (Contributor, Author):

Is something like this fine? (also for LlasaConfig)

The difference from the conventional from_pretrained is that this one increases the vocab size to account for the speech and Llasa tokens. These methods are useful for the conversion script, which copies the tokenizer and config from Llama (an LLM).

But when using Llasa, from_pretrained will be used as usual, loading from actual Llasa tokenizers and configs that don't require explicitly adding tokens.
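
For reference, the LlasaConfig counterpart could follow the same pattern; a minimal sketch (num_extra_tokens is a placeholder argument, not an attribute from this PR) might look like:

@classmethod
def from_pretrained_llm(cls, *args, num_extra_tokens=0, **kwargs):
    """Hypothetical: load a Llama config and grow the vocab for the speech and Llasa tokens."""
    config = super().from_pretrained(*args, **kwargs)
    # in practice num_extra_tokens would be len(llasa_token) + len(speech_tokens)
    config.vocab_size += num_extra_tokens
    return config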

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 168 to 194
# TODO: how to overwrite generate method?
# Not necessary but could be nice to check max_length < 2048 (what model was trained on)
# I get the following error (I think because `generate` isn't method of LlamaForCausalLM but its parent):
# ```
# File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 355, in replace_super_calls
#     original_modeling_method_body = self.original_modeling_methods[func_name].body.body
# KeyError: 'generate'
# ```
# """
# @torch.no_grad()
# def generate(
#     inputs,
#     max_length=2048,
#     **kwargs,
# ):
#     """
#     Set specific parameters from Llasa processor output
#     """
#     if max_length > 2048:
#         raise ValueError("Max length should be less than or equal to 2048.")

#     # Call the parent class's generate method
#     return super().generate(
#         inputs,
#         max_length=inputs["max_length"],
#         **kwargs
#     )
@ebezzam ebezzam (Contributor, Author) commented Jul 29, 2025:

I was trying to overwrite the generate method but got the error below when running the modular script.

Maybe the issue is that generate is not a method of LlamaForCausalLM but of its parent class GenerationMixin?

In any case, it isn't absolutely necessary to overwrite generate, but it could be nice for adding a check that users don't request outputs larger than model.generation_config.max_length (=2048), which is the max length the models were trained on. But maybe there's another way to restrict the output sizes that users request? (One alternative is sketched after the traceback below.)

# command: python utils/modular_model_converter.py --files-to-parse src/transformers/models/llasa/modular_llasa.py
Traceback (most recent call last):
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1779, in <module>
    converted_files = convert_modular_file(file_name)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1693, in convert_modular_file
    for file, module in create_modules(cst_transformers).items():
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1634, in create_modules
    nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1577, in get_class_node_and_dependencies
    updated_node = replace_class_node(mapper, node, renamed_super_class, super_class)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 1064, in replace_class_node
    new_replacement_class = new_module.visit(
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/metadata/wrapper.py", line 204, in visit
    return self.module.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/module.py", line 89, in visit
    result = super(Module, self).visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 228, in visit
    _CSTNodeSelfT, self._visit_and_replace_children(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/module.py", line 74, in _visit_and_replace_children
    body=visit_body_sequence(self, "body", self.body, visitor),
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 227, in visit_body_sequence
    return tuple(visit_body_iterable(parent, fieldname, children, visitor))
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 193, in visit_body_iterable
    new_child = child.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 228, in visit
    _CSTNodeSelfT, self._visit_and_replace_children(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/statement.py", line 1989, in _visit_and_replace_children
    body=visit_required(self, "body", self.body, visitor),
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 81, in visit_required
    result = node.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 228, in visit
    _CSTNodeSelfT, self._visit_and_replace_children(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/statement.py", line 704, in _visit_and_replace_children
    body=visit_body_sequence(self, "body", self.body, visitor),
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 227, in visit_body_sequence
    return tuple(visit_body_iterable(parent, fieldname, children, visitor))
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/internal.py", line 193, in visit_body_iterable
    new_child = child.visit(visitor)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_nodes/base.py", line 237, in visit
    leave_result = visitor.on_leave(self, with_updated_children)
  File "/home/eric_bezzam/transformers/py310/lib/python3.10/site-packages/libcst/_visitors.py", line 71, in on_leave
    updated_node = leave_func(original_node, updated_node)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 369, in leave_FunctionDef
    new_body = self.replace_super_calls(updated_node.body, name)
  File "/home/eric_bezzam/transformers/utils/modular_model_converter.py", line 355, in replace_super_calls
    original_modeling_method_body = self.original_modeling_methods[func_name].body.body
KeyError: 'generate'
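
One lighter-weight alternative (a sketch; the helper name and where it lives are assumptions, not part of this PR) is to leave generate untouched, set model.generation_config.max_length = 2048 in the conversion script, and validate the requested length separately, e.g. in LlasaProcessor or the docs example:

# Sketch: standalone length check that could live in the processor or conversion script
MAX_TRAINED_LENGTH = 2048  # Llasa models were trained with a max length of 2048

def check_max_length(max_length: int) -> None:
    """Raise if the requested generation length exceeds what Llasa was trained on."""
    if max_length > MAX_TRAINED_LENGTH:
        raise ValueError(
            f"max_length={max_length} exceeds the training max length of {MAX_TRAINED_LENGTH}."
        )

check_max_length(600)     # OK
# check_max_length(4096)  # would raise ValueError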

@ebezzam ebezzam requested a review from eustlb July 29, 2025 15:06
@Rocketknight1 (Member) commented:
cc @eustlb for TTS

github-actions bot commented Aug 4, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, llasa
