Skip to content

Commit 745701c

Browse files
Add first-class support for Cartesia text-to-speech (#298)
* Demo * patient intake * cartesia * Add cartesia * Fix * lint * Move test * Fix * Fix * Fix * Fix
1 parent 24349de commit 745701c

File tree

4 files changed

+99
-14
lines changed

4 files changed

+99
-14
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ jobs:
3434
- name: Run tests
3535
run: |
3636
python -m pip install -U pip
37-
pip install .[dev]
37+
pip install '.[dev, tts]'
3838
python -m pytest --capture=no
3939
shell: bash

backend/fastrtc/text_to_speech/tts.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
import importlib.util
import re
from collections.abc import AsyncGenerator, Generator
from dataclasses import dataclass, field
@@ -9,6 +10,8 @@
910
from huggingface_hub import hf_hub_download
1011
from numpy.typing import NDArray
1112

13+
from fastrtc.utils import async_aggregate_bytes_to_16bit
14+
1215

1316
class TTSOptions:
1417
pass
@@ -20,15 +23,15 @@ class TTSOptions:
2023
class TTSModel(Protocol[T]):
    """Structural interface every text-to-speech backend implements.

    ``T`` is the backend-specific options type. Every method produces
    ``(sample_rate, audio)`` pairs, where ``audio`` is either a float32 or
    an int16 NumPy array depending on the backend.
    """

    def tts(
        self,
        text: str,
        options: T | None = None,
    ) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]:
        """Synthesize ``text`` in one call and return the whole utterance."""
        ...

    def stream_tts(
        self,
        text: str,
        options: T | None = None,
    ) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]:
        """Synthesize ``text`` incrementally as an async generator of chunks."""
        ...

    def stream_tts_sync(
        self,
        text: str,
        options: T | None = None,
    ) -> Generator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None]:
        """Synthesize ``text`` incrementally as a blocking generator of chunks."""
        ...
3235

3336

3437
@dataclass
@@ -39,10 +42,19 @@ class KokoroTTSOptions(TTSOptions):
3942

4043

4144
@lru_cache
def get_tts_model(
    model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
) -> TTSModel:
    """Build, warm up and cache a text-to-speech model.

    Results are memoized by ``lru_cache`` on the exact arguments, so repeated
    calls with the same arguments return the same instance.

    Args:
        model: Backend to instantiate, ``"kokoro"`` or ``"cartesia"``.
        **kwargs: Backend-specific settings. Only ``cartesia_api_key`` is
            read, and only for the ``"cartesia"`` backend; when it is missing
            or empty the ``CARTESIA_API_KEY`` environment variable is used.

    Returns:
        The model instance, after one ``tts("Hello, world!")`` warm-up call.

    Raises:
        ValueError: If ``model`` is not a known backend, or if no Cartesia
            API key could be found.
    """
    if model == "kokoro":
        m = KokoroTTSModel()
    elif model == "cartesia":
        import os

        api_key = kwargs.get("cartesia_api_key") or os.environ.get(
            "CARTESIA_API_KEY", ""
        )
        if not api_key:
            # Fail here with an actionable message instead of letting the
            # warm-up request below fail with a remote authentication error.
            raise ValueError(
                "A Cartesia API key is required: pass cartesia_api_key=... "
                "or set the CARTESIA_API_KEY environment variable."
            )
        m = CartesiaTTSModel(api_key=api_key)
    else:
        raise ValueError(f"Invalid model: {model}")

    # Warm-up synthesis so the first real request does not pay start-up cost.
    m.tts("Hello, world!")
    return m
4658

4759

4860
class KokoroFixedBatchSize:
@@ -139,3 +151,77 @@ def stream_tts_sync(
139151
yield loop.run_until_complete(iterator.__anext__())
140152
except StopAsyncIteration:
141153
break
154+
155+
156+
@dataclass
class CartesiaTTSOptions(TTSOptions):
    """Per-request settings for the Cartesia text-to-speech backend.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``CartesiaTTSOptions(sample_rate=16_000)``.
    """

    # Cartesia voice identifier, sent as the voice ``id`` of each request.
    voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
    # Language code for the transcript.
    language: str = "en"
    # NOTE(review): not referenced by the synthesis code visible in this
    # module; a fresh list per instance avoids sharing one mutable default.
    emotion: list[str] = field(default_factory=list)
    # NOTE(review): not referenced by the synthesis code visible in this module.
    cartesia_version: str = "2024-06-10"
    # Cartesia model identifier.
    model: str = "sonic-2"
    # Sample rate (Hz) requested from Cartesia and reported with each chunk.
    sample_rate: int = 22_050
163+
164+
165+
class CartesiaTTSModel(TTSModel):
    """Text-to-speech backend that requests raw 16-bit PCM audio from Cartesia.

    The ``cartesia`` package is imported lazily in ``__init__`` so it remains
    an optional dependency of this module.
    """

    def __init__(self, api_key: str):
        """Create the async Cartesia client.

        Args:
            api_key: API key handed to ``AsyncCartesia``.

        Raises:
            RuntimeError: If the ``cartesia`` package cannot be found.
        """
        if importlib.util.find_spec("cartesia") is None:
            raise RuntimeError(
                "cartesia is not installed. Please install it using 'pip install cartesia'."
            )
        from cartesia import AsyncCartesia

        self.client = AsyncCartesia(api_key=api_key)

    async def stream_tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
        """Synthesize ``text`` sentence by sentence, yielding audio chunks.

        Args:
            text: Text to speak. It is split after ``.``, ``!`` or ``?``
                followed by whitespace; empty pieces are skipped.
            options: Voice, model, language and sample rate to use. Defaults
                to ``CartesiaTTSOptions()``.

        Yields:
            ``(sample_rate, samples)`` where ``samples`` is an int16 array.
        """
        options = options or CartesiaTTSOptions()

        sentences = re.split(r"(?<=[.!?])\s+", text.strip())

        for sentence in sentences:
            if not sentence.strip():
                continue
            async for output in async_aggregate_bytes_to_16bit(
                self.client.tts.bytes(
                    # Honor the caller's options; the defaults are
                    # "sonic-2" / "en", matching the previous fixed values.
                    model_id=options.model,
                    transcript=sentence,
                    voice={"id": options.voice},  # type: ignore
                    language=options.language,
                    output_format={
                        "container": "raw",
                        "sample_rate": options.sample_rate,
                        "encoding": "pcm_s16le",
                    },
                )
            ):
                yield options.sample_rate, np.frombuffer(output, dtype=np.int16)

    def stream_tts_sync(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
        """Blocking counterpart of ``stream_tts``.

        Drives the async generator on a private event loop, yielding each
        ``(sample_rate, samples)`` chunk as soon as it is available.
        """
        # NOTE(review): the loop is left unclosed on purpose; the shared
        # client may keep connections bound to it - confirm before closing.
        loop = asyncio.new_event_loop()

        iterator = self.stream_tts(text, options).__aiter__()
        while True:
            try:
                yield loop.run_until_complete(iterator.__anext__())
            except StopAsyncIteration:
                break

    def tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> tuple[int, NDArray[np.int16]]:
        """Synthesize ``text`` and return the complete utterance.

        Returns:
            ``(sample_rate, samples)`` with every streamed chunk joined into
            one int16 array; the array is empty when nothing was produced.
        """
        options = options or CartesiaTTSOptions()

        # Gather all chunks first and join once; growing the buffer chunk by
        # chunk would re-copy the accumulated audio on every iteration.
        chunks = [chunk for _, chunk in self.stream_tts_sync(text, options)]
        if chunks:
            buffer = np.concatenate(chunks)
        else:
            buffer = np.array([], dtype=np.int16)
        return options.sample_rate, buffer
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1+
import pytest
12
from fastrtc.text_to_speech.tts import get_tts_model
23

34

4-
def test_tts_long_prompt():
5-
model = get_tts_model()
5+
@pytest.mark.parametrize("model", ["kokoro"])
def test_tts_long_prompt(model):
    """Stream a long, punctuation-free prompt and check that audio comes back."""
    # Keep the parametrized backend name and the model object in separate names.
    tts_model = get_tts_model(model=model)
    prompt = "It may be that this communication will be considered as a madman's freak but at any rate it must be admitted that in its clearness and frankness it left nothing to be desired The serious part of it was that the Federal Government had undertaken to treat a sale by auction as a valid concession of these undiscovered territories Opinions on the matter were many Some readers saw in it only one of those prodigious outbursts of American humbug which would exceed the limits of puffism if the depths of human credulity were not unfathomable"

    chunk_count = 0
    for i, chunk in enumerate(tts_model.stream_tts_sync(prompt)):
        print(f"Chunk {i}: {chunk[1].shape}")
        chunk_count += 1

    # Without this the test would pass even if the stream produced nothing.
    assert chunk_count > 0, "stream_tts_sync produced no audio chunks"
10-
11-
12-
if __name__ == "__main__":
13-
test_tts_long_prompt()

test/test_webrtc_connection_mixin.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
)
4040
self.time_limit = time_limit
4141
self.allow_extra_tracks = allow_extra_tracks
42+
self.server_rtc_configuration = None
4243

4344
def mount(self, app: FastAPI, path: str = ""):
4445
from fastapi import APIRouter

0 commit comments

Comments
 (0)