diff --git a/demo/gemini_audio_video/app.py b/demo/gemini_audio_video/app.py index f8e01a7a..5678035c 100644 --- a/demo/gemini_audio_video/app.py +++ b/demo/gemini_audio_video/app.py @@ -1,50 +1,136 @@ +# https://huggingface.co/spaces/freddyaboulton/gemini-audio-video-chat +# related demos: https://github.com/freddyaboulton/gradio-webrtc + import asyncio import base64 import os import time -from io import BytesIO +import logging +import traceback +import cv2 import gradio as gr import numpy as np -from dotenv import load_dotenv +from google import genai from fastrtc import ( AsyncAudioVideoStreamHandler, - Stream, WebRTC, + async_aggregate_bytes_to_16bit, + VideoEmitType, + AudioEmitType, get_twilio_turn_credentials, ) -from google import genai -from gradio.utils import get_space -from PIL import Image +import requests # Use requests for synchronous Twilio check + +# --- Setup Logging --- +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) -load_dotenv() +# --- Global State --- +twilio_available = None # Will be set *before* Gradio initialization +gemini_connected = False +# --- Helper Functions --- def encode_audio(data: np.ndarray) -> dict: - """Encode Audio data to send to the server""" - return { - "mime_type": "audio/pcm", - "data": base64.b64encode(data.tobytes()).decode("UTF-8"), - } - - -def encode_image(data: np.ndarray) -> dict: - with BytesIO() as output_bytes: - pil_image = Image.fromarray(data) - pil_image.save(output_bytes, "JPEG") - bytes_data = output_bytes.getvalue() - base64_str = str(base64.b64encode(bytes_data), "utf-8") + if not isinstance(data, np.ndarray): + raise TypeError("encode_audio expected a numpy.ndarray") + try: + return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")} + except Exception as e: + logger.error(f"Error encoding audio: {e}") + raise + +def encode_image(data: np.ndarray, quality: int = 85) -> dict: + """ + Encodes a NumPy array (image) to a JPEG, Base64-encoded UTF-8 string using OpenCV. + Handles various input data types. + + Args: + data: A NumPy array of shape (n, n, 3). + quality: JPEG quality (0-100). + + Returns: + A dictionary with keys "mime_type" and "data". + + Raises: + TypeError: If input is not a NumPy array. + ValueError: If input shape is incorrect or contains NaN/Inf. + Exception: If JPEG encoding fails. + """ + + # Input Validation (shape and dimensions) + if not isinstance(data, np.ndarray): + raise TypeError("Input must be a NumPy array.") + if data.ndim != 3 or data.shape[2] != 3: + raise ValueError("Input array must have shape (n, n, 3).") + if 0 in data.shape: + raise ValueError("Input array cannot have a dimension of size 0.") + + # Handle NaN/Inf (regardless of data type) + if np.any(np.isnan(data)) or np.any(np.isinf(data)): + raise ValueError("Input array contains NaN or Inf") + + # Normalize and convert to uint8 + if np.issubdtype(data.dtype, np.floating) or np.issubdtype(data.dtype, np.integer): + scaled_data = cv2.normalize(data, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8) + else: + raise TypeError("Input array must have a floating-point or integer data type.") + + # JPEG Encoding (with quality control and error handling) + try: + retval, buf = cv2.imencode(".jpg", scaled_data, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) + if not retval: + raise Exception("cv2.imencode failed") + except Exception as e: + raise Exception(f"JPEG encoding failed: {e}") + + # Base64 Encoding + jpeg_bytes = np.array(buf).tobytes() + base64_str = base64.b64encode(jpeg_bytes).decode('utf-8') + return {"mime_type": "image/jpeg", "data": base64_str} +def check_twilio_availability_sync() -> bool: + """Checks Twilio TURN server availability (synchronous version).""" + global twilio_available + retries = 3 + delay = 2 + + for attempt in range(retries): + try: + logger.info(f"Attempting to get Twilio credentials (attempt {attempt + 1})...") + credentials = get_twilio_turn_credentials() + logger.info(f"Twilio credentials response: {credentials}") + if credentials: + twilio_available = True + logger.info("Twilio TURN server available.") + return True + except requests.exceptions.RequestException as e: + logger.warning(f"Attempt {attempt + 1}: {e}") + logger.warning(traceback.format_exc()) + if attempt < retries - 1: + time.sleep(delay) + except Exception as e: + logger.exception(f"Unexpected error checking Twilio: {e}") + twilio_available = False + return False + + twilio_available = False + logger.warning("Twilio TURN server unavailable.") + return False + + +# --- Gemini Handler Class --- class GeminiHandler(AsyncAudioVideoStreamHandler): def __init__( - self, + self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480 ) -> None: super().__init__( - "mono", - output_sample_rate=24000, - output_frame_size=480, + expected_layout, + output_sample_rate, + output_frame_size, input_sample_rate=16000, ) self.audio_queue = asyncio.Queue() @@ -52,84 +138,149 @@ def __init__( self.quit = asyncio.Event() self.session = None self.last_frame_time = 0 - self.quit = asyncio.Event() def copy(self) -> "GeminiHandler": - return GeminiHandler() - - async def start_up(self): - client = genai.Client( - api_key=os.getenv("GEMINI_API_KEY"), http_options={"api_version": "v1alpha"} + return GeminiHandler( + expected_layout=self.expected_layout, + output_sample_rate=self.output_sample_rate, + output_frame_size=self.output_frame_size, ) - config = {"response_modalities": ["AUDIO"]} - async with client.aio.live.connect( - model="gemini-2.0-flash-exp", config=config - ) as session: - self.session = session - print("set session") - while not self.quit.is_set(): - turn = self.session.receive() - async for response in turn: - if data := response.data: - audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1) - self.audio_queue.put_nowait(audio) async def video_receive(self, frame: np.ndarray): if self.session: - # send image every 1 second - print(time.time() - self.last_frame_time) - if time.time() - self.last_frame_time > 1: - self.last_frame_time = time.time() - await self.session.send(input=encode_image(frame)) - if self.latest_args[1] is not None: - await self.session.send(input=encode_image(self.latest_args[1])) - + try: + if time.time() - self.last_frame_time > 1: + self.last_frame_time = time.time() + await self.session.send(encode_image(frame)) + if self.latest_args[2] is not None: + await self.session.send(encode_image(self.latest_args[2])) + except Exception as e: + logger.error(f"Error sending video frame: {e}") + gr.Warning("Error sending video to Gemini.") self.video_queue.put_nowait(frame) - async def video_emit(self): - return await self.video_queue.get() + async def video_emit(self) -> VideoEmitType: + try: + return await self.video_queue.get() + except asyncio.CancelledError: + logger.info("Video emit cancelled.") + return None + except Exception as e: + logger.exception(f"Error in video_emit: {e}") + return None + + async def connect(self, api_key: str): + global gemini_connected + if self.session is None: + try: + client = genai.Client(api_key=api_key, http_options={"api_version": "v1alpha"}) + config = {"response_modalities": ["AUDIO"]} + async with client.aio.live.connect( + model="gemini-2.0-flash-exp", config=config + ) as session: + self.session = session + gemini_connected = True + asyncio.create_task(self.receive_audio()) + await self.quit.wait() + except Exception as e: + logger.error(f"Error connecting to Gemini: {e}") + gemini_connected = False + self.shutdown() + gr.Warning(f"Failed to connect to Gemini: {e}") + finally: + update_gemini_status_sync() + + async def generator(self): + if not self.session: + logger.warning("Gemini session is not initialized.") + return + + while not self.quit.is_set(): + try: + await asyncio.sleep(0) # Yield to the event loop + if self.quit.is_set(): + break + turn = self.session.receive() + async for response in turn: + if self.quit.is_set(): + break # Exit inner loop if quit is set. + if data := response.data: + yield data + except Exception as e: + logger.error(f"Error receiving from Gemini: {e}") + self.quit.set() # set quit if we error. + break + + async def receive_audio(self): + try: + async for audio_response in async_aggregate_bytes_to_16bit(self.generator()): + self.audio_queue.put_nowait(audio_response) + except Exception as e: + logger.exception(f"Error in receive_audio: {e}") async def receive(self, frame: tuple[int, np.ndarray]) -> None: _, array = frame array = array.squeeze() - audio_message = encode_audio(array) - if self.session: - await self.session.send(input=audio_message) + try: + audio_message = encode_audio(array) + if self.session: + await self.session.send(audio_message) + except Exception as e: + logger.error(f"Error sending audio: {e}") + gr.Warning("Error sending audio to Gemini.") + + async def emit(self) -> AudioEmitType: + if not self.args_set.is_set(): + await self.wait_for_args() + if self.session is None: + asyncio.create_task(self.connect(self.latest_args[1])) - async def emit(self): - array = await self.audio_queue.get() - return (self.output_sample_rate, array) + try: + array = await self.audio_queue.get() + return (self.output_sample_rate, array) + except asyncio.CancelledError: + logger.info("Audio emit cancelled.") + return (self.output_sample_rate, np.array([])) + except Exception as e: + logger.exception(f"Error in emit: {e}") + return (self.output_sample_rate, np.array([])) - async def shutdown(self) -> None: + def shutdown(self) -> None: + global gemini_connected + gemini_connected = False + logger.info("Shutting down GeminiHandler.") if self.session: - self.quit.set() - await self.session._websocket.close() - self.quit.clear() - - -stream = Stream( - handler=GeminiHandler(), - modality="audio-video", - mode="send-receive", - rtc_configuration=get_twilio_turn_credentials() - if get_space() == "spaces" - else None, - time_limit=90 if get_space() else None, - additional_inputs=[ - gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"]) - ], - ui_args={ - "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png", - "pulse_color": "rgb(255, 255, 255)", - "icon_button_color": "rgb(255, 255, 255)", - "title": "Gemini Audio Video Chat", - }, -) + try: + # await self.session.close() # There is no async close + pass + except Exception: + pass + self.quit.set() # Set quit *after* attempting to close the session + self.connection = None + self.args_set.clear() + + self.quit.clear() + update_gemini_status_sync() + +def update_gemini_status_sync(): + """Updates the Gemini status message (synchronous version).""" + status = "✅ Gemini: Connected" if gemini_connected else "❌ Gemini: Disconnected" + if 'demo' in locals() and demo.running: + gr.update(value=status) + + + +# --- Gradio UI --- css = """ #video-source {max-width: 600px !important; max-height: 600 !important;} """ +# Perform Twilio check *before* Gradio UI definition (synchronously) +if __name__ == "__main__": + check_twilio_availability_sync() + + with gr.Blocks(css=css) as demo: gr.HTML( """ @@ -146,40 +297,83 @@ async def shutdown(self) -> None: """ ) - with gr.Row() as row: + twilio_status_message = gr.Markdown("❓ Twilio: Checking...") + gemini_status_message = gr.Markdown("❓ Gemini: Checking...") + + with gr.Row() as api_key_row: + api_key = gr.Textbox( + label="API Key", + type="password", + placeholder="Enter your API Key", + value=os.getenv("GOOGLE_API_KEY"), + ) + with gr.Row(visible=False) as row: with gr.Column(): + # Set rtc_configuration based on the *pre-checked* twilio_available + rtc_config = get_twilio_turn_credentials() if twilio_available else None + # Explicitly specify codecs (example - you might need to adjust) + if rtc_config: + rtc_config['codecs'] = ['VP8', 'H264'] # Prefer VP8, then H.264 webrtc = WebRTC( label="Video Chat", modality="audio-video", mode="send-receive", elem_id="video-source", - rtc_configuration=get_twilio_turn_credentials() - if get_space() == "spaces" - else None, + rtc_configuration=rtc_config, icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png", - pulse_color="rgb(255, 255, 255)", - icon_button_color="rgb(255, 255, 255)", + pulse_color="rgb(35, 157, 225)", + icon_button_color="rgb(35, 157, 225)", ) with gr.Column(): - image_input = gr.Image( - label="Image", type="numpy", sources=["upload", "clipboard"] - ) + image_input = gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"]) + + + def update_twilio_status_ui(): + if twilio_available: + message = "✅ Twilio: Available" + else: + message = "❌ Twilio: Unavailable (connection may be less reliable)" + return gr.update(value=message) + + demo.load(update_twilio_status_ui, [], [twilio_status_message]) + + handler = GeminiHandler() + webrtc.stream( + handler, + inputs=[webrtc, api_key, image_input], + outputs=[webrtc], + time_limit=90, + concurrency_limit=None, + ) + - webrtc.stream( - GeminiHandler(), - inputs=[webrtc, image_input], - outputs=[webrtc], - time_limit=60 if get_space() else None, - concurrency_limit=2 if get_space() else None, + def check_api_key(api_key_str): + if not api_key_str: + return ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(value="Please enter a valid API key"), + gr.update(value="❓ Gemini: Checking..."), + ) + return ( + gr.update(visible=False), + gr.update(visible=True), + gr.update(value=""), + gr.update(value="❓ Gemini: Checking..."), ) -stream.ui = demo + api_key.submit( + check_api_key, + [api_key], + [api_key_row, row, twilio_status_message, gemini_status_message], + ) + # If API key is already set via environment variables, hide the API key row and show content + if os.getenv("GOOGLE_API_KEY"): + demo.load( + lambda: (gr.update(visible=False), gr.update(visible=True)), + None, + [api_key_row, row], + ) -if __name__ == "__main__": - if (mode := os.getenv("MODE")) == "UI": - stream.ui.launch(server_port=7860) - elif mode == "PHONE": - raise ValueError("Phone mode not supported for this demo") - else: - stream.ui.launch(server_port=7860) +demo.launch()