@@ -442,6 +442,197 @@ async def voice_search(
442
442
)
443
443
444
444
445
+ @router .post (
446
+ "/voice-chat" ,
447
+ response_model = QueryAudioResponse ,
448
+ responses = {
449
+ status .HTTP_400_BAD_REQUEST : {
450
+ "model" : QueryResponseError ,
451
+ "description" : "Bad Request" ,
452
+ },
453
+ status .HTTP_500_INTERNAL_SERVER_ERROR : {
454
+ "model" : QueryResponseError ,
455
+ "description" : "Internal Server Error" ,
456
+ },
457
+ },
458
+ )
459
+ @observe ()
460
+ async def voice_chat (
461
+ file_url : str ,
462
+ request : Request ,
463
+ session_id : int | None = None ,
464
+ reset_chat_history : bool = False ,
465
+ asession : AsyncSession = Depends (get_async_session ),
466
+ workspace_db : WorkspaceDB = Depends (authenticate_key ),
467
+ ) -> QueryAudioResponse | JSONResponse :
468
+ """Endpoint to transcribe audio from a provided URL, generate an LLM response, by
469
+ default `generate_tts` is set to `True`, and return a public random URL of an audio
470
+ file containing the spoken version of the generated response.
471
+
472
+ Parameters
473
+ ----------
474
+ file_url
475
+ The URL of the audio file.
476
+ request
477
+ The FastAPI request object.
478
+ session_id
479
+ The session ID for the chat, in case of a continuation of a previous chat.
480
+ reset_chat_history
481
+ Specifies whether to reset the chat history.
482
+ asession
483
+ The SQLAlchemy async session to use for all database connections.
484
+ workspace_db
485
+ The authenticated workspace object.
486
+
487
+ Returns
488
+ -------
489
+ QueryAudioResponse | JSONResponse
490
+ The query audio response object or an appropriate JSON response.
491
+ """
492
+
493
+ workspace_id = workspace_db .workspace_id
494
+
495
+ try :
496
+ file_stream , content_type , file_extension = await download_file_from_url (
497
+ file_url = file_url
498
+ )
499
+ assert isinstance (file_stream , BytesIO )
500
+
501
+ unique_filename = generate_random_filename (extension = file_extension )
502
+ destination_blob_name = f"stt-voice-notes/{ unique_filename } "
503
+
504
+ await upload_file_to_gcs (
505
+ bucket_name = GCS_SPEECH_BUCKET ,
506
+ content_type = content_type ,
507
+ destination_blob_name = destination_blob_name ,
508
+ file_stream = file_stream ,
509
+ )
510
+
511
+ file_path = f"temp/{ unique_filename } "
512
+ with open (file_path , "wb" ) as f :
513
+ file_stream .seek (0 )
514
+ f .write (file_stream .read ())
515
+ file_stream .seek (0 )
516
+
517
+ if CUSTOM_STT_ENDPOINT is not None :
518
+ transcription = await post_to_speech_stt (
519
+ file_path = file_path , endpoint_url = CUSTOM_STT_ENDPOINT
520
+ )
521
+ transcription_result = transcription ["text" ]
522
+ else :
523
+ transcription_result = await transcribe_audio (audio_filename = file_path )
524
+
525
+ user_query = QueryBase (
526
+ generate_llm_response = True ,
527
+ query_metadata = {},
528
+ query_text = transcription_result ,
529
+ session_id = session_id ,
530
+ )
531
+
532
+ # 1.
533
+ user_query = await init_user_query_and_chat_histories (
534
+ redis_client = request .app .state .redis ,
535
+ reset_chat_history = reset_chat_history ,
536
+ user_query = user_query ,
537
+ )
538
+
539
+ # 2.
540
+ (
541
+ user_query_db ,
542
+ user_query_refined_template ,
543
+ response_template ,
544
+ ) = await get_user_query_and_response (
545
+ asession = asession ,
546
+ generate_tts = True ,
547
+ user_query = user_query ,
548
+ workspace_id = workspace_id ,
549
+ )
550
+ assert isinstance (user_query_db , QueryDB )
551
+
552
+ response = await get_search_response (
553
+ asession = asession ,
554
+ exclude_archived = True ,
555
+ n_similar = int (N_TOP_CONTENT ),
556
+ n_to_crossencoder = int (N_TOP_CONTENT_TO_CROSSENCODER ),
557
+ query_refined = user_query_refined_template ,
558
+ request = request ,
559
+ response = response_template ,
560
+ workspace_id = workspace_id ,
561
+ )
562
+
563
+ if user_query .generate_llm_response : # Should be always true in this case
564
+ response = await get_generation_response (
565
+ query_refined = user_query_refined_template , response = response
566
+ )
567
+
568
+ langfuse_context .update_current_trace (
569
+ name = "voice-chat" ,
570
+ session_id = user_query_refined_template .session_id ,
571
+ metadata = {"query_id" : response .query_id , "workspace_id" : workspace_id },
572
+ )
573
+
574
+ await save_query_response_to_db (
575
+ asession = asession ,
576
+ response = response ,
577
+ user_query_db = user_query_db ,
578
+ workspace_id = workspace_id ,
579
+ )
580
+ await increment_query_count (
581
+ asession = asession ,
582
+ contents = response .search_results ,
583
+ workspace_id = workspace_id ,
584
+ )
585
+ await save_content_for_query_to_db (
586
+ asession = asession ,
587
+ contents = response .search_results ,
588
+ query_id = response .query_id ,
589
+ session_id = user_query .session_id ,
590
+ workspace_id = workspace_id ,
591
+ )
592
+
593
+ if os .path .exists (file_path ):
594
+ os .remove (file_path )
595
+ file_stream .close ()
596
+
597
+ if isinstance (response , QueryResponseError ):
598
+ return JSONResponse (
599
+ status_code = status .HTTP_400_BAD_REQUEST , content = response .model_dump ()
600
+ )
601
+
602
+ if isinstance (response , QueryAudioResponse ):
603
+ return response
604
+
605
+ return JSONResponse (
606
+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
607
+ content = {"error" : "Internal server error" },
608
+ )
609
+
610
+ # Unsure where to place this
611
+ except LLMCallException :
612
+ return JSONResponse (
613
+ status_code = status .HTTP_502_BAD_GATEWAY ,
614
+ content = {
615
+ "error_message" : (
616
+ "LLM call returned an error: Please check LLM configuration"
617
+ )
618
+ },
619
+ )
620
+
621
+ except ValueError as ve :
622
+ logger .error (f"ValueError: { str (ve )} " )
623
+ return JSONResponse (
624
+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
625
+ content = {"error" : f"Value error: { str (ve )} " },
626
+ )
627
+
628
+ except Exception as e : # pylint: disable=W0718
629
+ logger .error (f"Unexpected error: { str (e )} " )
630
+ return JSONResponse (
631
+ status_code = status .HTTP_500_INTERNAL_SERVER_ERROR ,
632
+ content = {"error" : "Internal server error" },
633
+ )
634
+
635
+
445
636
@identify_language__before
446
637
@classify_safety__before
447
638
@translate_question__before
0 commit comments