Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 55 additions & 45 deletions python/helpers/kokoro_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,103 +11,113 @@
warnings.filterwarnings("ignore", category=FutureWarning)

_pipeline = None
_voice = "am_puck,am_onyx"
_speed = 1.1
is_updating_model = False

def _get_voice_settings():
"""Get voice configuration from settings"""
from python.helpers import settings
current_settings = settings.get_settings()

primary_voice = current_settings.get("kokoro_voice", "af_alloy")
secondary_voice = current_settings.get("kokoro_voice_blend", "")
voice_ratio = current_settings.get("kokoro_voice_ratio", 0.5)
speed = current_settings.get("kokoro_speed", 1.1)

# Validate primary voice exists and is not a placeholder
if (not primary_voice or not primary_voice.strip() or
primary_voice in ["", "No blending"]):
primary_voice = "af_alloy" # fallback to default

# Validate speed is within reasonable bounds
if speed < 0.1 or speed > 5.0:
speed = 1.1 # fallback to default

# Build voice string for Kokoro
if (secondary_voice and secondary_voice.strip() and
secondary_voice != primary_voice and
secondary_voice not in ["", "No blending"]):
# Voice blending: use ratio to determine blend
voice_string = f"{primary_voice},{secondary_voice}"
else:
# Single voice
voice_string = primary_voice

return voice_string, speed

async def preload():
try:
# return await runtime.call_development_function(_preload)
return await _preload()
return await runtime.call_development_function(_preload)
except Exception as e:
# if not runtime.is_development():
raise e
if not runtime.is_development():
raise e
# Fallback to direct execution if RFC fails in development
# PrintStyle.standard("RFC failed, falling back to direct execution...")
# return await _preload()

PrintStyle.standard("RFC failed, falling back to direct execution...")
return await _preload()

async def _preload():
global _pipeline, is_updating_model

while is_updating_model:
await asyncio.sleep(0.1)

try:
is_updating_model = True
if not _pipeline:
PrintStyle.standard("Loading Kokoro TTS model...")
from kokoro import KPipeline
_pipeline = KPipeline(lang_code="a")
_pipeline = KPipeline(lang_code='a')
finally:
is_updating_model = False


async def is_downloading():
try:
# return await runtime.call_development_function(_is_downloading)
return _is_downloading()
return await runtime.call_development_function(_is_downloading)
except Exception as e:
# if not runtime.is_development():
raise e
if not runtime.is_development():
raise e
# Fallback to direct execution if RFC fails in development
# return _is_downloading()

return _is_downloading()

def _is_downloading():
return is_updating_model

async def is_downloaded():
try:
# return await runtime.call_development_function(_is_downloaded)
return _is_downloaded()
except Exception as e:
# if not runtime.is_development():
raise e
# Fallback to direct execution if RFC fails in development
# return _is_downloaded()

def _is_downloaded():
return _pipeline is not None


async def synthesize_sentences(sentences: list[str]):
"""Generate audio for multiple sentences and return concatenated base64 audio"""
try:
# return await runtime.call_development_function(_synthesize_sentences, sentences)
return await _synthesize_sentences(sentences)
return await runtime.call_development_function(_synthesize_sentences, sentences)
except Exception as e:
# if not runtime.is_development():
raise e
if not runtime.is_development():
raise e
# Fallback to direct execution if RFC fails in development
# return await _synthesize_sentences(sentences)

return await _synthesize_sentences(sentences)

async def _synthesize_sentences(sentences: list[str]):
await _preload()

# Get current voice settings
voice_string, speed = _get_voice_settings()

combined_audio = []

try:
for sentence in sentences:
if sentence.strip():
segments = _pipeline(sentence.strip(), voice=_voice, speed=_speed) # type: ignore
segments = _pipeline(sentence.strip(), voice=voice_string, speed=speed)
segment_list = list(segments)

for segment in segment_list:
audio_tensor = segment.audio
audio_numpy = audio_tensor.detach().cpu().numpy() # type: ignore
audio_numpy = audio_tensor.detach().cpu().numpy()
combined_audio.extend(audio_numpy)

# Convert combined audio to bytes
buffer = io.BytesIO()
sf.write(buffer, combined_audio, 24000, format="WAV")
sf.write(buffer, combined_audio, 24000, format='WAV')
audio_bytes = buffer.getvalue()

# Return base64 encoded audio
return base64.b64encode(audio_bytes).decode("utf-8")
return base64.b64encode(audio_bytes).decode('utf-8')

except Exception as e:
PrintStyle.error(f"Error in Kokoro TTS synthesis: {e}")
raise
raise
Loading