1
1
import asyncio
2
+ import importlib .util
2
3
import re
3
4
from collections .abc import AsyncGenerator , Generator
4
5
from dataclasses import dataclass, field
9
10
from huggingface_hub import hf_hub_download
10
11
from numpy .typing import NDArray
11
12
13
+ from fastrtc .utils import async_aggregate_bytes_to_16bit
14
+
12
15
13
16
class TTSOptions:
    """Empty base class that model-specific TTS option classes derive from."""
@@ -20,15 +23,15 @@ class TTSOptions:
20
23
class TTSModel(Protocol[T]):
    """Structural interface that every text-to-speech backend implements.

    Each method takes the text to speak plus an optional backend-specific
    options object and produces ``(sample_rate, audio)`` tuples whose audio
    array holds either float32 or int16 samples.
    """

    def tts(
        self, text: str, options: T | None = None
    ) -> tuple[int, NDArray[np.float32] | NDArray[np.int16]]:
        """Synthesize ``text`` and return one ``(sample_rate, audio)`` tuple."""

    def stream_tts(
        self, text: str, options: T | None = None
    ) -> AsyncGenerator[tuple[int, NDArray[np.float32] | NDArray[np.int16]], None]:
        """Return an async generator of ``(sample_rate, audio)`` chunks."""

    def stream_tts_sync(
        self, text: str, options: T | None = None
    ) -> Generator[
        tuple[int, NDArray[np.float32] | NDArray[np.int16]], None, None
    ]:
        """Return a synchronous generator of ``(sample_rate, audio)`` chunks."""
32
35
33
36
34
37
@dataclass
@@ -39,10 +42,19 @@ class KokoroTTSOptions(TTSOptions):
39
42
40
43
41
44
@lru_cache
def get_tts_model(
    model: Literal["kokoro", "cartesia"] = "kokoro", **kwargs
) -> TTSModel:
    """Build the requested TTS backend and run one synthesis call on it.

    Results are memoized by ``lru_cache``, so repeated calls with the same
    arguments return the same model instance.

    Args:
        model: Backend name, either ``"kokoro"`` or ``"cartesia"``.
        **kwargs: Backend-specific settings. Only ``cartesia_api_key`` is
            read (for the ``"cartesia"`` backend; defaults to ``""``).

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    if model == "kokoro":
        tts_model = KokoroTTSModel()
    elif model == "cartesia":
        tts_model = CartesiaTTSModel(api_key=kwargs.get("cartesia_api_key", ""))
    else:
        raise ValueError(f"Invalid model: {model}")

    # One throwaway synthesis before handing the model out -- presumably a
    # warm-up so the first real request is not the slow one.
    tts_model.tts("Hello, world!")
    return tts_model
46
58
47
59
48
60
class KokoroFixedBatchSize :
@@ -139,3 +151,77 @@ def stream_tts_sync(
139
151
yield loop .run_until_complete (iterator .__anext__ ())
140
152
except StopAsyncIteration :
141
153
break
154
+
155
+
156
@dataclass
class CartesiaTTSOptions(TTSOptions):
    """Per-request settings for the Cartesia text-to-speech backend.

    Declared as a dataclass so individual fields can be overridden through
    the constructor, e.g. ``CartesiaTTSOptions(sample_rate=44_100)``.
    """

    # Cartesia voice identifier.
    voice: str = "71a7ad14-091c-4e8e-a314-022ece01c121"
    # Language code of the transcript.
    language: str = "en"
    # Built per instance: a plain ``[]`` default would be a single list
    # shared (and mutated) across every options object.
    emotion: list[str] = field(default_factory=list)
    # Cartesia API version string.
    cartesia_version: str = "2024-06-10"
    # Cartesia model identifier.
    model: str = "sonic-2"
    # Sample rate, in Hz, requested for the generated audio.
    sample_rate: int = 22_050
163
+
164
+
165
class CartesiaTTSModel(TTSModel):
    """Text-to-speech backend that synthesizes audio through the Cartesia API.

    Of the fields on ``CartesiaTTSOptions`` only ``voice``, ``model``,
    ``language`` and ``sample_rate`` are forwarded to the API.
    """

    def __init__(self, api_key: str):
        """Create the async Cartesia client.

        Args:
            api_key: Cartesia API key handed to ``AsyncCartesia``.

        Raises:
            RuntimeError: If the optional ``cartesia`` package is missing.
        """
        if importlib.util.find_spec("cartesia") is None:
            raise RuntimeError(
                "cartesia is not installed. Please install it using 'pip install cartesia'."
            )
        # Imported lazily so the module stays importable without cartesia.
        from cartesia import AsyncCartesia

        self.client = AsyncCartesia(api_key=api_key)

    async def stream_tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> AsyncGenerator[tuple[int, NDArray[np.int16]], None]:
        """Synthesize ``text`` sentence by sentence, yielding audio chunks.

        Args:
            text: Text to speak. It is split after ``.``, ``!`` or ``?``
                followed by whitespace; blank pieces are skipped.
            options: Request settings; defaults to ``CartesiaTTSOptions()``.

        Yields:
            ``(sample_rate, samples)`` where ``samples`` is an int16 array
            decoded from the raw ``pcm_s16le`` bytes returned by the API.
        """
        options = options or CartesiaTTSOptions()

        sentences = re.split(r"(?<=[.!?])\s+", text.strip())

        for sentence in sentences:
            if not sentence.strip():
                continue
            async for output in async_aggregate_bytes_to_16bit(
                self.client.tts.bytes(
                    # Honor the caller's options instead of fixed literals;
                    # the option defaults ("sonic-2", "en") match the values
                    # that used to be hard-coded here.
                    model_id=options.model,
                    transcript=sentence,
                    voice={"id": options.voice},  # type: ignore
                    language=options.language,
                    output_format={
                        "container": "raw",
                        "sample_rate": options.sample_rate,
                        "encoding": "pcm_s16le",
                    },
                )
            ):
                yield options.sample_rate, np.frombuffer(output, dtype=np.int16)

    def stream_tts_sync(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> Generator[tuple[int, NDArray[np.int16]], None, None]:
        """Blocking counterpart of ``stream_tts``.

        Drives the async generator on a private event loop that is closed
        when the stream is exhausted, fails, or is abandoned by the caller.

        Args:
            text: Text to speak.
            options: Request settings; defaults to ``CartesiaTTSOptions()``.

        Yields:
            The same ``(sample_rate, samples)`` tuples as ``stream_tts``.
        """
        loop = asyncio.new_event_loop()
        iterator = self.stream_tts(text, options).__aiter__()
        try:
            while True:
                try:
                    item = loop.run_until_complete(iterator.__anext__())
                except StopAsyncIteration:
                    break
                yield item
        finally:
            try:
                # Finalize the async generator so an early exit still runs
                # its cleanup; this is a no-op once it is exhausted.
                loop.run_until_complete(iterator.aclose())
            finally:
                loop.close()

    def tts(
        self, text: str, options: CartesiaTTSOptions | None = None
    ) -> tuple[int, NDArray[np.int16]]:
        """Synthesize ``text`` and return the complete audio in one array.

        Args:
            text: Text to speak.
            options: Request settings; defaults to ``CartesiaTTSOptions()``.

        Returns:
            ``(sample_rate, samples)`` with all chunks joined into a single
            int16 array (empty when nothing was synthesized).
        """
        options = options or CartesiaTTSOptions()

        # Gather first and join once; concatenating inside the loop would
        # re-copy the growing buffer for every chunk.
        chunks = [chunk for _, chunk in self.stream_tts_sync(text, options)]
        if chunks:
            buffer = np.concatenate(chunks)
        else:
            buffer = np.array([], dtype=np.int16)
        return options.sample_rate, buffer
0 commit comments