Skip to content

Commit 8283de0

Browse files
enhancement: extent websocket handler for general usage
1 parent cfcca94 commit 8283de0

File tree

4 files changed

+116
-32
lines changed

4 files changed

+116
-32
lines changed

transports/bifrost-http/handlers/server.go

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ type BifrostHTTPServer struct {
5858
Client *bifrost.Bifrost
5959
Config *lib.Config
6060

61-
Server *fasthttp.Server
62-
Router *router.Router
61+
Server *fasthttp.Server
62+
Router *router.Router
63+
WebSocketHandler *WebSocketHandler
6364
}
6465

6566
// NewBifrostHTTPServer creates a new instance of BifrostHTTPServer.
@@ -395,14 +396,15 @@ func (s *BifrostHTTPServer) RegisterRoutes(ctx context.Context, middlewares ...l
395396
cacheHandler = NewCacheHandler(semanticCachePlugin, logger)
396397
}
397398
// Websocket handler needs to go below UI handler
398-
var wsHandler *WebSocketHandler
399+
logger.Debug("initializing websocket server")
399400
if loggerPlugin != nil {
400-
logger.Debug("initializing websocket server")
401-
wsHandler = NewWebSocketHandler(ctx, loggerPlugin.GetPluginLogManager(), logger, s.Config.ClientConfig.AllowedOrigins)
402-
loggerPlugin.SetLogCallback(wsHandler.BroadcastLogUpdate)
403-
// Start WebSocket heartbeat
404-
wsHandler.StartHeartbeat()
401+
s.WebSocketHandler = NewWebSocketHandler(ctx, loggerPlugin.GetPluginLogManager(), logger, s.Config.ClientConfig.AllowedOrigins)
402+
loggerPlugin.SetLogCallback(s.WebSocketHandler.BroadcastLogUpdate)
403+
} else {
404+
s.WebSocketHandler = NewWebSocketHandler(ctx, nil, logger, s.Config.ClientConfig.AllowedOrigins)
405405
}
406+
// Start WebSocket heartbeat
407+
s.WebSocketHandler.StartHeartbeat()
406408
middlewaresWithTelemetry := append(middlewares, telemetry.PrometheusMiddleware)
407409
// Chaining all middlewares
408410
// lib.ChainMiddlewares chains multiple middlewares together
@@ -429,8 +431,8 @@ func (s *BifrostHTTPServer) RegisterRoutes(ctx context.Context, middlewares ...l
429431
if loggingHandler != nil {
430432
loggingHandler.RegisterRoutes(s.Router, middlewares...)
431433
}
432-
if wsHandler != nil {
433-
wsHandler.RegisterRoutes(s.Router, middlewares...)
434+
if s.WebSocketHandler != nil {
435+
s.WebSocketHandler.RegisterRoutes(s.Router, middlewares...)
434436
}
435437
//
436438
// Add Prometheus /metrics endpoint

transports/bifrost-http/handlers/websocket.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func NewWebSocketHandler(ctx context.Context, logManager logging.LogManager, log
5151

5252
// RegisterRoutes registers all WebSocket-related routes
5353
func (h *WebSocketHandler) RegisterRoutes(r *router.Router, middlewares ...lib.BifrostHTTPMiddleware) {
54-
r.GET("/ws/logs", lib.ChainMiddlewares(h.connectLogStream, middlewares...))
54+
r.GET("/ws", lib.ChainMiddlewares(h.connectStream, middlewares...))
5555
}
5656

5757
// getUpgrader returns a WebSocket upgrader configured with the current allowed origins
@@ -86,8 +86,8 @@ func isLocalhost(host string) bool {
8686
host == ""
8787
}
8888

89-
// connectLogStream handles WebSocket connections for real-time log streaming
90-
func (h *WebSocketHandler) connectLogStream(ctx *fasthttp.RequestCtx) {
89+
// connectStream handles WebSocket connections for real-time streaming
90+
func (h *WebSocketHandler) connectStream(ctx *fasthttp.RequestCtx) {
9191
upgrader := h.getUpgrader()
9292
err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) {
9393
// Read safety & liveness
@@ -161,6 +161,7 @@ func (h *WebSocketHandler) sendMessageSafely(client *WebSocketClient, messageTyp
161161
client.conn.Close()
162162
}()
163163
}
164+
164165
return err
165166
}
166167

@@ -195,6 +196,11 @@ func (h *WebSocketHandler) BroadcastLogUpdate(logEntry *logstore.Log) {
195196
return
196197
}
197198

199+
h.BroadcastMarshaledMessage(data)
200+
}
201+
202+
// BroadcastMarshaledMessage sends an adaptive routing update to all connected WebSocket clients
203+
func (h *WebSocketHandler) BroadcastMarshaledMessage(data []byte) {
198204
// Get a snapshot of clients to avoid holding the lock during writes
199205
h.mu.RLock()
200206
clients := make([]*WebSocketClient, 0, len(h.clients))

ui/app/logs/page.tsx

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,17 @@ export default function LogsPage() {
185185
});
186186
}, []);
187187

188-
const { isConnected: isSocketConnected, setMessageHandler } = useWebSocket();
188+
const { isConnected: isSocketConnected, subscribe } = useWebSocket();
189189

190-
// Set up the message handler when the component mounts
190+
// Subscribe to log messages
191191
useEffect(() => {
192-
setMessageHandler(handleLogMessage);
193-
}, [handleLogMessage, setMessageHandler]);
192+
const unsubscribe = subscribe("log", (data) => {
193+
const { payload, operation } = data;
194+
handleLogMessage(payload, operation);
195+
});
196+
197+
return unsubscribe;
198+
}, [handleLogMessage, subscribe]);
194199

195200
// Cleanup timeouts on unmount
196201
useEffect(() => {

ui/hooks/useWebSocket.tsx

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,61 @@
11
"use client";
22

3-
import React, { createContext, useContext, useEffect, useRef, useState, type ReactNode } from "react";
4-
import type { LogEntry, WebSocketLogMessage } from "../lib/types/logs";
3+
import React, { createContext, useContext, useEffect, useRef, useState, useCallback, type ReactNode } from "react";
54
import { getWebSocketUrl } from "@/lib/utils/port";
65

6+
type MessageHandler = (data: any) => void;
7+
78
interface WebSocketContextType {
89
isConnected: boolean;
910
ws: React.RefObject<WebSocket | null>;
10-
setMessageHandler: (handler: (log: LogEntry, operation: "create" | "update") => void) => void;
11+
subscribe: (channel: string, handler: MessageHandler) => () => void;
12+
send: (data: any) => void;
1113
}
1214

1315
const WebSocketContext = createContext<WebSocketContextType | null>(null);
1416

1517
interface WebSocketProviderProps {
1618
children: ReactNode;
19+
path?: string;
1720
}
1821

1922
// Global reference to maintain state across component remounts
2023
let globalWsRef: WebSocket | null = null;
21-
let globalMessageHandler: ((log: LogEntry, operation: "create" | "update") => void) | null = null;
24+
const messageHandlers = new Map<string, Set<MessageHandler>>();
2225

23-
export function WebSocketProvider({ children }: WebSocketProviderProps) {
26+
export function WebSocketProvider({ children, path = "/ws" }: WebSocketProviderProps) {
2427
const wsRef = useRef<WebSocket | null>(globalWsRef);
2528
const reconnectTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
29+
const pingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
30+
const retryCountRef = useRef(0);
2631
const [isConnected, setIsConnected] = useState(false);
2732

28-
const setMessageHandler = (handler: (log: LogEntry, operation: "create" | "update") => void) => {
29-
globalMessageHandler = handler;
33+
const subscribe = useCallback<(channel: string, handler: MessageHandler) => () => void>((channel, handler) => {
34+
if (!messageHandlers.has(channel)) {
35+
messageHandlers.set(channel, new Set());
36+
}
37+
messageHandlers.get(channel)!.add(handler);
38+
39+
// Return unsubscribe function
40+
return () => {
41+
const handlers = messageHandlers.get(channel);
42+
if (handlers) {
43+
handlers.delete(handler);
44+
if (handlers.size === 0) {
45+
messageHandlers.delete(channel);
46+
}
47+
}
48+
};
49+
}, []);
50+
51+
const send = (data: any) => {
52+
if (wsRef.current?.readyState === WebSocket.OPEN) {
53+
try {
54+
wsRef.current.send(typeof data === "string" ? data : JSON.stringify(data));
55+
} catch (error) {
56+
console.error("Failed to send WebSocket message:", error);
57+
}
58+
}
3059
};
3160

3261
useEffect(() => {
@@ -35,7 +64,7 @@ export function WebSocketProvider({ children }: WebSocketProviderProps) {
3564
return;
3665
}
3766

38-
const wsUrl = getWebSocketUrl("/ws/logs");
67+
const wsUrl = getWebSocketUrl(path);
3968

4069
const ws = new WebSocket(wsUrl);
4170
wsRef.current = ws;
@@ -44,18 +73,44 @@ export function WebSocketProvider({ children }: WebSocketProviderProps) {
4473
ws.onopen = () => {
4574
console.log("WebSocket connected");
4675
setIsConnected(true);
76+
retryCountRef.current = 0; // Reset retry count on successful connection
77+
4778
// Clear any pending reconnection attempts
4879
if (reconnectTimeoutRef.current) {
4980
clearTimeout(reconnectTimeoutRef.current);
5081
reconnectTimeoutRef.current = null;
5182
}
83+
84+
// Start heartbeat/ping to keep connection alive
85+
if (pingTimerRef.current) {
86+
clearInterval(pingTimerRef.current);
87+
}
88+
pingTimerRef.current = setInterval(() => {
89+
if (ws.readyState === WebSocket.OPEN) {
90+
try {
91+
ws.send("ping");
92+
} catch (error) {
93+
console.error("Ping failed:", error);
94+
}
95+
}
96+
}, 25000); // Ping every 25 seconds
5297
};
5398

5499
ws.onmessage = (event) => {
55100
try {
56-
const data: WebSocketLogMessage = JSON.parse(event.data);
57-
if (data.type === "log" && globalMessageHandler) {
58-
globalMessageHandler(data.payload, data.operation);
101+
const data = JSON.parse(event.data);
102+
const messageType = data.type || "default";
103+
104+
// Notify all subscribers for this message type
105+
const handlers = messageHandlers.get(messageType);
106+
if (handlers) {
107+
handlers.forEach((handler) => handler(data));
108+
}
109+
110+
// Also notify wildcard subscribers
111+
const wildcardHandlers = messageHandlers.get("*");
112+
if (wildcardHandlers) {
113+
wildcardHandlers.forEach((handler) => handler(data));
59114
}
60115
} catch (error) {
61116
console.error("Failed to parse WebSocket message:", error);
@@ -65,11 +120,23 @@ export function WebSocketProvider({ children }: WebSocketProviderProps) {
65120
ws.onclose = () => {
66121
console.log("WebSocket disconnected, attempting to reconnect...");
67122
setIsConnected(false);
68-
// Attempt to reconnect after 5 seconds
69-
reconnectTimeoutRef.current = setTimeout(connect, 5000);
123+
124+
// Clear ping timer
125+
if (pingTimerRef.current) {
126+
clearInterval(pingTimerRef.current);
127+
pingTimerRef.current = null;
128+
}
129+
130+
// Exponential backoff: 0.5s, 1s, 2s, 4s, 8s, 16s, 32s (max)
131+
retryCountRef.current = Math.min(retryCountRef.current + 1, 6);
132+
const delay = Math.pow(2, retryCountRef.current) * 500;
133+
console.log(`Reconnecting in ${delay}ms...`);
134+
135+
reconnectTimeoutRef.current = setTimeout(connect, delay);
70136
};
71137

72138
ws.onerror = (error) => {
139+
console.error("WebSocket error:", error);
73140
setIsConnected(false);
74141
ws.close();
75142
};
@@ -84,10 +151,14 @@ export function WebSocketProvider({ children }: WebSocketProviderProps) {
84151
clearTimeout(reconnectTimeoutRef.current);
85152
reconnectTimeoutRef.current = null;
86153
}
154+
if (pingTimerRef.current) {
155+
clearInterval(pingTimerRef.current);
156+
pingTimerRef.current = null;
157+
}
87158
};
88-
}, []);
159+
}, [path]);
89160

90-
return <WebSocketContext.Provider value={{ isConnected, ws: wsRef, setMessageHandler }}>{children}</WebSocketContext.Provider>;
161+
return <WebSocketContext.Provider value={{ isConnected, ws: wsRef, subscribe, send }}>{children}</WebSocketContext.Provider>;
91162
}
92163

93164
export function useWebSocket() {

0 commit comments

Comments
 (0)