Skip to content

Commit 2e8ebec

Browse files
committed
stability fixes
1 parent 661c646 commit 2e8ebec

File tree

3 files changed

+107
-75
lines changed

3 files changed

+107
-75
lines changed

custom_components/openai_tts/manifest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
"iot_class": "cloud_polling",
1111
"issue_tracker": "https://github.com/sfortis/openai_tts/issues",
1212
"requirements": [],
13-
"version": "0.3.0b0"
13+
"version": "0.3.1b0"
1414
}

custom_components/openai_tts/openaitts_engine.py

Lines changed: 56 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""
22
TTS Engine for OpenAI TTS.
33
"""
4-
import asyncio
5-
import threading
4+
import json
65
import logging
7-
import aiohttp
6+
import time
7+
from urllib.request import Request, urlopen
8+
from urllib.error import HTTPError, URLError
9+
from asyncio import CancelledError
10+
11+
from homeassistant.exceptions import HomeAssistantError
812

913
_LOGGER = logging.getLogger(__name__)
1014

@@ -21,64 +25,66 @@ def __init__(self, api_key: str, voice: str, model: str, speed: float, url: str)
2125
self._speed = speed
2226
self._url = url
2327

24-
# Create a dedicated event loop running in a background thread.
25-
self._loop = asyncio.new_event_loop()
26-
self._session = None
27-
self._thread = threading.Thread(target=self._start_loop, daemon=True)
28-
self._thread.start()
29-
# Initialize the aiohttp session in the background event loop.
30-
asyncio.run_coroutine_threadsafe(self._init_session(), self._loop).result()
31-
32-
def _start_loop(self):
33-
asyncio.set_event_loop(self._loop)
34-
self._loop.run_forever()
28+
def get_tts(self, text: str, speed: float = None, voice: str = None) -> AudioResponse:
29+
"""Synchronous TTS request using urllib.request
30+
If the API call fails, waits for 1 second and retries once.
31+
"""
32+
if speed is None:
33+
speed = self._speed
34+
if voice is None:
35+
voice = self._voice
3536

36-
async def _init_session(self):
37-
# Create a persistent aiohttp session for reuse.
38-
self._session = aiohttp.ClientSession()
37+
headers = {"Content-Type": "application/json"}
38+
if self._api_key:
39+
headers["Authorization"] = f"Bearer {self._api_key}"
3940

40-
async def _async_get_tts(self, text: str, speed: float, voice: str) -> AudioResponse:
41-
headers = {"Authorization": f"Bearer {self._api_key}"} if self._api_key else {}
4241
data = {
4342
"model": self._model,
4443
"input": text,
4544
"voice": voice,
4645
"response_format": "wav",
47-
"speed": speed,
48-
"stream": True
46+
"speed": speed
4947
}
50-
# Use separate timeouts for connecting and reading.
51-
timeout = aiohttp.ClientTimeout(total=None, sock_connect=5, sock_read=25)
52-
async with self._session.post(self._url, headers=headers, json=data, timeout=timeout) as resp:
53-
resp.raise_for_status()
54-
audio_chunks = []
55-
# Optimize the chunk size to 4096 bytes.
56-
async for chunk in resp.content.iter_chunked(4096):
57-
if chunk:
58-
audio_chunks.append(chunk)
59-
audio_data = b"".join(audio_chunks)
60-
return AudioResponse(audio_data)
6148

62-
def get_tts(self, text: str, speed: float = None, voice: str = None) -> AudioResponse:
63-
"""Synchronous wrapper that runs the asynchronous TTS request on a dedicated event loop.
64-
If 'speed' or 'voice' are provided, they override the stored values.
65-
"""
66-
try:
67-
if speed is None:
68-
speed = self._speed
69-
if voice is None:
70-
voice = self._voice
71-
future = asyncio.run_coroutine_threadsafe(self._async_get_tts(text, speed, voice), self._loop)
72-
return future.result()
73-
except Exception as e:
74-
_LOGGER.error("Error in asynchronous get_tts: %s", e)
75-
raise e
49+
max_retries = 1
50+
attempt = 0
51+
while True:
52+
try:
53+
req = Request(
54+
self._url,
55+
data=json.dumps(data).encode("utf-8"),
56+
headers=headers,
57+
method="POST"
58+
)
59+
# Use a 30-second socket timeout; it applies to each blocking operation (connect, read), not to the request as a whole.
60+
with urlopen(req, timeout=30) as response:
61+
content = response.read()
62+
return AudioResponse(content)
63+
except CancelledError as ce:
64+
_LOGGER.exception("TTS request cancelled")
65+
raise # Propagate cancellation.
66+
except (HTTPError, URLError) as net_err:
67+
_LOGGER.exception("Network error in synchronous get_tts on attempt %d", attempt + 1)
68+
if attempt < max_retries:
69+
attempt += 1
70+
time.sleep(1) # Wait for 1 second before retrying.
71+
_LOGGER.debug("Retrying HTTP call (attempt %d)", attempt + 1)
72+
continue
73+
else:
74+
raise HomeAssistantError("Network error occurred while fetching TTS audio") from net_err
75+
except Exception as exc:
76+
_LOGGER.exception("Unknown error in synchronous get_tts on attempt %d", attempt + 1)
77+
if attempt < max_retries:
78+
attempt += 1
79+
time.sleep(1)
80+
_LOGGER.debug("Retrying HTTP call (attempt %d)", attempt + 1)
81+
continue
82+
else:
83+
raise HomeAssistantError("An unknown error occurred while fetching TTS audio") from exc
7684

7785
def close(self):
78-
"""Clean up the aiohttp session and event loop on shutdown."""
79-
if self._session:
80-
asyncio.run_coroutine_threadsafe(self._session.close(), self._loop).result()
81-
self._loop.call_soon_threadsafe(self._loop.stop())
86+
"""Nothing to close in the synchronous version."""
87+
pass
8288

8389
@staticmethod
8490
def get_supported_langs() -> list:

custom_components/openai_tts/tts.py

Lines changed: 50 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from __future__ import annotations
55
import io
66
import math
7-
import re
87
import struct
98
import wave
109
import logging
10+
from asyncio import CancelledError
1111

1212
from homeassistant.components.tts import TextToSpeechEntity
1313
from homeassistant.config_entries import ConfigEntry
@@ -20,16 +20,22 @@
2020

2121
_LOGGER = logging.getLogger(__name__)
2222

23-
# --- Helper Functions - Chime & silence synthesis --
23+
# --- Helper Functions - Chime & Silence Synthesis ---
2424

2525
def synthesize_chime(sample_rate: int = 44100, channels: int = 1, sampwidth: int = 2, duration: float = 1.0) -> bytes:
26-
_LOGGER.debug("Synthesizing chime: sample_rate=%d, channels=%d, sampwidth=%d, duration=%.2f", sample_rate, channels, sampwidth, duration)
27-
frequency1 = 440.0 # Note A
26+
_LOGGER.debug(
27+
"Synthesizing chime: sample_rate=%d, channels=%d, sampwidth=%d, duration=%.2f",
28+
sample_rate,
29+
channels,
30+
sampwidth,
31+
duration,
32+
)
33+
frequency1 = 440.0 # Note A
2834
frequency2 = 587.33 # Note D
2935
amplitude = 0.8
3036
num_samples = int(sample_rate * duration)
3137
output = io.BytesIO()
32-
with wave.open(output, 'wb') as wf:
38+
with wave.open(output, "wb") as wf:
3339
wf.setnchannels(channels)
3440
wf.setsampwidth(sampwidth)
3541
wf.setframerate(sample_rate)
@@ -40,33 +46,43 @@ def synthesize_chime(sample_rate: int = 44100, channels: int = 1, sampwidth: int
4046
sample2 = math.sin(2 * math.pi * frequency2 * t)
4147
sample = amplitude * fade * ((sample1 + sample2) / 2)
4248
int_sample = int(sample * 32767)
43-
wf.writeframes(struct.pack('<h', int_sample))
49+
wf.writeframes(struct.pack("<h", int_sample))
4450
chime_data = output.getvalue()
4551
_LOGGER.debug("Chime synthesized, length: %d bytes", len(chime_data))
4652
return chime_data
4753

4854
def synthesize_silence(sample_rate: int, channels: int, sampwidth: int, duration: float = 0.3) -> bytes:
49-
_LOGGER.debug("Synthesizing silence: sample_rate=%d, channels=%d, sampwidth=%d, duration=%.2f", sample_rate, channels, sampwidth, duration)
55+
_LOGGER.debug(
56+
"Synthesizing silence: sample_rate=%d, channels=%d, sampwidth=%d, duration=%.2f",
57+
sample_rate,
58+
channels,
59+
sampwidth,
60+
duration,
61+
)
5062
num_samples = int(sample_rate * duration)
5163
output = io.BytesIO()
52-
with wave.open(output, 'wb') as wf:
64+
with wave.open(output, "wb") as wf:
5365
wf.setnchannels(channels)
5466
wf.setsampwidth(sampwidth)
5567
wf.setframerate(sample_rate)
5668
for _ in range(num_samples):
57-
wf.writeframes(struct.pack('<h', 0))
69+
wf.writeframes(struct.pack("<h", 0))
5870
silence_data = output.getvalue()
5971
_LOGGER.debug("Silence synthesized, length: %d bytes", len(silence_data))
6072
return silence_data
6173

6274
def combine_wav_files(chime_bytes: bytes, pause_bytes: bytes, tts_bytes: bytes) -> bytes:
63-
_LOGGER.debug("Combining WAV files: chime (%d bytes), pause (%d bytes), TTS (%d bytes)",
64-
len(chime_bytes), len(pause_bytes), len(tts_bytes))
75+
_LOGGER.debug(
76+
"Combining WAV files: chime (%d bytes), pause (%d bytes), TTS (%d bytes)",
77+
len(chime_bytes),
78+
len(pause_bytes),
79+
len(tts_bytes),
80+
)
6581
chime_io = io.BytesIO(chime_bytes)
6682
pause_io = io.BytesIO(pause_bytes)
6783
tts_io = io.BytesIO(tts_bytes)
68-
69-
with wave.open(chime_io, 'rb') as w1, wave.open(pause_io, 'rb') as w2, wave.open(tts_io, 'rb') as w3:
84+
85+
with wave.open(chime_io, "rb") as w1, wave.open(pause_io, "rb") as w2, wave.open(tts_io, "rb") as w3:
7086
params1 = w1.getparams()
7187
params2 = w2.getparams()
7288
params3 = w3.getparams()
@@ -75,9 +91,9 @@ def combine_wav_files(chime_bytes: bytes, pause_bytes: bytes, tts_bytes: bytes)
7591
frames_chime = w1.readframes(w1.getnframes())
7692
frames_pause = w2.readframes(w2.getnframes())
7793
frames_tts = w3.readframes(w3.getnframes())
78-
94+
7995
output = io.BytesIO()
80-
with wave.open(output, 'wb') as wout:
96+
with wave.open(output, "wb") as wout:
8197
wout.setparams(params1)
8298
wout.writeframes(frames_chime)
8399
wout.writeframes(frames_pause)
@@ -110,7 +126,7 @@ async def async_setup_entry(
110126
config_entry.data[CONF_VOICE],
111127
config_entry.data[CONF_MODEL],
112128
config_entry.data.get(CONF_SPEED, 1.0),
113-
config_entry.data[CONF_URL]
129+
config_entry.data[CONF_URL],
114130
)
115131
async_add_entities([OpenAITTSEntity(hass, config_entry, engine)])
116132

@@ -142,14 +158,16 @@ def device_info(self) -> dict:
142158
return {
143159
"identifiers": {(DOMAIN, self._attr_unique_id)},
144160
"model": self._config.data.get(CONF_MODEL),
145-
"manufacturer": "OpenAI"
161+
"manufacturer": "OpenAI",
146162
}
147163

148164
@property
149165
def name(self) -> str:
150166
return _map_model(self._config.data.get(CONF_MODEL, "")).upper()
151167

152-
def get_tts_audio(self, message: str, language: str, options: dict | None = None) -> tuple[str, bytes] | tuple[None, None]:
168+
def get_tts_audio(
169+
self, message: str, language: str, options: dict | None = None
170+
) -> tuple[str, bytes] | tuple[None, None]:
153171
try:
154172
if len(message) > 4096:
155173
raise MaxLengthExceeded("Message exceeds maximum allowed length")
@@ -167,27 +185,35 @@ def get_tts_audio(self, message: str, language: str, options: dict | None = None
167185
if chime_enabled:
168186
_LOGGER.debug("Chime option enabled; synthesizing chime and pause.")
169187
tts_io = io.BytesIO(audio_content)
170-
with wave.open(tts_io, 'rb') as tts_wave:
188+
with wave.open(tts_io, "rb") as tts_wave:
171189
sample_rate = tts_wave.getframerate()
172190
channels = tts_wave.getnchannels()
173191
sampwidth = tts_wave.getsampwidth()
174192
tts_frames = tts_wave.getnframes()
175-
_LOGGER.debug("TTS parameters: sample_rate=%d, channels=%d, sampwidth=%d, frames=%d",
176-
sample_rate, channels, sampwidth, tts_frames)
193+
_LOGGER.debug(
194+
"TTS parameters: sample_rate=%d, channels=%d, sampwidth=%d, frames=%d",
195+
sample_rate,
196+
channels,
197+
sampwidth,
198+
tts_frames,
199+
)
177200
chime_audio = synthesize_chime(sample_rate=sample_rate, channels=channels, sampwidth=sampwidth, duration=1.0)
178201
pause_audio = synthesize_silence(sample_rate=sample_rate, channels=channels, sampwidth=sampwidth, duration=0.3)
179202
try:
180203
combined_audio = combine_wav_files(chime_audio, pause_audio, audio_content)
181204
_LOGGER.debug("Combined audio generated (chime -> pause -> TTS).")
182205
return "wav", combined_audio
183206
except Exception as ce:
184-
_LOGGER.error("Error combining audio: %s", ce)
207+
_LOGGER.exception("Error combining audio")
185208
return "wav", audio_content
186209
else:
187210
_LOGGER.debug("Chime option disabled; returning TTS audio only.")
188211
return "wav", audio_content
212+
except CancelledError as ce:
213+
_LOGGER.exception("TTS task cancelled")
214+
return None, None
189215
except MaxLengthExceeded as mle:
190-
_LOGGER.error("Maximum message length exceeded: %s", mle)
216+
_LOGGER.exception("Maximum message length exceeded")
191217
except Exception as e:
192-
_LOGGER.error("Unknown error in get_tts_audio: %s", e)
218+
_LOGGER.exception("Unknown error in get_tts_audio")
193219
return None, None

0 commit comments

Comments
 (0)