Skip to content

定义MessageSender,将websocket.send进行统一管理 #1463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions main/xiaozhi-server/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import subprocess
import websockets
from core.handle.mcpHandle import call_mcp_tool
from core.message.sender.message_sender_factory import MessageSenderFactory
from core.utils.util import (
extract_json_from_string,
check_vad_update,
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
self.read_config_from_api = self.config.get("read_config_from_api", False)

self.websocket = None
self.message_sender = None
self.headers = None
self.device_id = None
self.client_ip = None
Expand Down Expand Up @@ -186,12 +188,18 @@ async def handle_connection(self, ws):
self.websocket = ws
self.device_id = self.headers.get("device-id", None)

# 初始化message发送器
message_sender = MessageSenderFactory.create_sender(
self.config, self.websocket
)
self.message_sender = message_sender

# 启动超时检查任务
self.timeout_task = asyncio.create_task(self._check_timeout())

self.welcome_msg = self.config["xiaozhi"]
self.welcome_msg["session_id"] = self.session_id
await self.websocket.send(json.dumps(self.welcome_msg))
await message_sender.send(json.dumps(self.welcome_msg))

# 获取差异化配置
self._initialize_private_config()
Expand Down Expand Up @@ -263,7 +271,7 @@ async def handle_restart(self, message):
self.logger.bind(tag=TAG).info("收到服务器重启指令,准备执行...")

# 发送确认响应
await self.websocket.send(
await self.message_sender.send(
json.dumps(
{
"type": "server",
Expand Down Expand Up @@ -293,7 +301,7 @@ def restart_server():

except Exception as e:
self.logger.bind(tag=TAG).error(f"重启失败: {str(e)}")
await self.websocket.send(
await self.message_sender.send(
json.dumps(
{
"type": "server",
Expand Down
2 changes: 1 addition & 1 deletion main/xiaozhi-server/core/handle/abortHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async def handleAbortMessage(conn):
conn.client_abort = True
conn.clear_queues()
# 打断客户端说话状态
await conn.websocket.send(
await conn.message_sender.send(
json.dumps({"type": "tts", "state": "stop", "session_id": conn.session_id})
)
conn.clearSpeakStatus()
Expand Down
2 changes: 1 addition & 1 deletion main/xiaozhi-server/core/handle/helloHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def handleHelloMessage(conn, msg_json):
# 发送mcp消息,获取tools列表
asyncio.create_task(send_mcp_tools_list_request(conn))

await conn.websocket.send(json.dumps(conn.welcome_msg))
await conn.message_sender.send(json.dumps(conn.welcome_msg))


async def checkWakeupWords(conn, text):
Expand Down
2 changes: 1 addition & 1 deletion main/xiaozhi-server/core/handle/iotHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ async def send_iot_conn(conn, name, method_name, parameters):
if parameters:
command["parameters"] = parameters
send_message = json.dumps({"type": "iot", "commands": [command]})
await conn.websocket.send(send_message)
await conn.message_sender.send(send_message)
conn.logger.bind(tag=TAG).info(f"发送物联网指令: {send_message}")
return
conn.logger.bind(tag=TAG).error(f"未找到方法{method_name}")
2 changes: 1 addition & 1 deletion main/xiaozhi-server/core/handle/mcpHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def send_mcp_message(conn, payload: dict):
message = json.dumps({"type": "mcp", "payload": payload})

try:
await conn.websocket.send(message)
await conn.message_sender.send(message)
conn.logger.bind(tag=TAG).info(f"成功发送MCP消息: {message}")
except Exception as e:
conn.logger.bind(tag=TAG).error(f"发送MCP消息失败: {e}")
Expand Down
10 changes: 5 additions & 5 deletions main/xiaozhi-server/core/handle/sendAudioHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def sendAudioMessage(conn, sentenceType, audios, text):
if text is not None:
emotion = analyze_emotion(text)
emoji = emoji_map.get(emotion, "🙂") # 默认使用笑脸
await conn.websocket.send(
await conn.message_sender.send(
json.dumps(
{
"type": "llm",
Expand Down Expand Up @@ -81,7 +81,7 @@ async def sendAudio(conn, audios, pre_buffer=True):
if pre_buffer:
pre_buffer_frames = min(3, len(audios))
for i in range(pre_buffer_frames):
await conn.websocket.send(audios[i])
await conn.message_sender.send(audios[i])
remaining_audios = audios[pre_buffer_frames:]
else:
remaining_audios = audios
Expand All @@ -104,7 +104,7 @@ async def sendAudio(conn, audios, pre_buffer=True):
if delay > 0:
await asyncio.sleep(delay)

await conn.websocket.send(opus_packet)
await conn.message_sender.send(opus_packet)

play_position += frame_duration

Expand All @@ -129,7 +129,7 @@ async def send_tts_message(conn, state, text=None):
conn.clearSpeakStatus()

# 发送消息到客户端
await conn.websocket.send(json.dumps(message))
await conn.message_sender.send(json.dumps(message))


async def send_stt_message(conn, text):
Expand All @@ -140,7 +140,7 @@ async def send_stt_message(conn, text):

"""发送 STT 状态消息"""
stt_text = get_string_no_punctuation_or_emoji(text)
await conn.websocket.send(
await conn.message_sender.send(
json.dumps({"type": "stt", "text": stt_text, "session_id": conn.session_id})
)
conn.client_is_speaking = True
Expand Down
14 changes: 7 additions & 7 deletions main/xiaozhi-server/core/handle/textHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def handleTextMessage(conn, message):
msg_json = json.loads(message)
if isinstance(msg_json, int):
conn.logger.bind(tag=TAG).info(f"收到文本消息:{message}")
await conn.websocket.send(message)
await conn.message_sender.send(message)
return
if msg_json["type"] == "hello":
conn.logger.bind(tag=TAG).info(f"收到hello消息:{message}")
Expand Down Expand Up @@ -92,7 +92,7 @@ async def handleTextMessage(conn, message):
secret = conn.config["manager-api"].get("secret", "")
# 如果secret不匹配,则返回
if post_secret != secret:
await conn.websocket.send(
await conn.message_sender.send(
json.dumps(
{
"type": "server",
Expand All @@ -107,7 +107,7 @@ async def handleTextMessage(conn, message):
try:
# 更新WebSocketServer的配置
if not conn.server:
await conn.websocket.send(
await conn.message_sender.send(
json.dumps(
{
"type": "server",
Expand All @@ -120,7 +120,7 @@ async def handleTextMessage(conn, message):
return

if not await conn.server.update_config():
await conn.websocket.send(
await conn.message_sender.send(
json.dumps(
{
"type": "server",
Expand All @@ -133,7 +133,7 @@ async def handleTextMessage(conn, message):
return

# 发送成功响应
await conn.websocket.send(
await conn.message_sender.send(
json.dumps(
{
"type": "server",
Expand All @@ -145,7 +145,7 @@ async def handleTextMessage(conn, message):
)
except Exception as e:
conn.logger.bind(tag=TAG).error(f"更新配置失败: {str(e)}")
await conn.websocket.send(
await conn.message_sender.send(
json.dumps(
{
"type": "server",
Expand All @@ -161,4 +161,4 @@ async def handleTextMessage(conn, message):
else:
conn.logger.bind(tag=TAG).error(f"收到未知类型消息:{message}")
except json.JSONDecodeError:
await conn.websocket.send(message)
await conn.message_sender.send(message)
17 changes: 17 additions & 0 deletions main/xiaozhi-server/core/message/sender/message_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod

class MessageSender(ABC):
"""
消息发送的抽象基类。
定义了所有消息发送器都应该实现的发送接口。
"""

@abstractmethod
def send(self, message: any):
"""
抽象方法:发送消息。

Args:
message (any): 要发送的内容。
"""
pass
31 changes: 31 additions & 0 deletions main/xiaozhi-server/core/message/sender/message_sender_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Dict, Any

from core.message.sender.message_sender import MessageSender
from core.message.sender.websocket_sender import WebSocketSender

class MessageSenderFactory:
"""
消息发送器的工厂类,根据配置和已存在的连接创建具体的 MessageSender 实例。
"""

@staticmethod
def create_sender(config: Dict[str, Any], connection_instance: Any) -> MessageSender:
"""
根据配置字典和已存在的连接实例创建并返回一个 MessageSender 实例。

Args:
config (Dict[str, Any]): 配置字典,至少包含 'message_sender_type' 键,指示使用哪种发送器。
connection_instance (Any): 已创建并活跃的 WebSocket 连接对象或 MQTT 客户端对象。

Returns:
MessageSender: 具体的 MessageSender 实例。

Raises:
ValueError: 如果配置中的 'message_sender_type' 不支持或 connection_instance 类型不匹配。
"""
message_sender_type = config.get("message_sender_type", "websocket") # 默认使用 websocket 发送器

if message_sender_type == "websocket":
return WebSocketSender(connection_instance)
else:
raise ValueError(f"不支持的消息发送器类型: {message_sender_type}")
27 changes: 27 additions & 0 deletions main/xiaozhi-server/core/message/sender/websocket_sender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from config.logger import setup_logging
from core.message.sender.message_sender import MessageSender

TAG = __name__
logger = setup_logging()

class WebSocketSender(MessageSender):
"""
通过一个已存在的 WebSocket 连接发送消息的具体实现。
"""
def __init__(self, websocket_connection):
"""
初始化 WebSocketSender。

Args:
websocket_connection: 已建立并活跃的 WebSocket 连接实例。
"""
self.websocket_connection = websocket_connection
logger.bind(tag=TAG).debug("WebSocketSender: 初始化,使用已存在的 WebSocket 连接")
if not self.websocket_connection:
logger.bind(tag=TAG).error("提供的 WebSocket 连接为空或无效。")

async def send(self, message: str):
if self.websocket_connection:
await self.websocket_connection.send(message)
else:
logger.bind(tag=TAG).error(f"WebSocket 连接已关闭或无效,无法发送文字消息。")