diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py
index 1bf919e12..cc8aadc65 100644
--- a/daras_ai_v2/asr.py
+++ b/daras_ai_v2/asr.py
@@ -244,10 +244,11 @@
 } # fmt: skip
 
 SUNBIRD_SUPPORTED_LANGUAGES = {
-    "eng", "swa", "ach", "lgg", "lug", "nyn",
-    "teo", "xog", "ttj", "kin", "myx",
+    "ach": "<|su|>", "eng": "<|en|>", "kin": "<|as|>", "lgg": "<|jw|>", "lug": "<|ba|>", "myx": "<|mg|>",
+    "nyn": "<|ha|>", "swa": "<|sw|>", "teo": "<|ln|>", "ttj": "<|tt|>", "xog": "<|haw|>"
 } # fmt: skip
 
+
 # https://translation.ghananlp.org/api-details#api=ghananlp-translation-webservice-api
 GHANA_NLP_SUPPORTED = {'en': 'English', 'tw': 'Twi', 'gaa': 'Ga', 'ee': 'Ewe', 'fat': 'Fante', 'dag': 'Dagbani', 'gur': 'Gurene', 'yo': 'Yoruba', 'ki': 'Kikuyu', 'luo': 'Luo', 'mer': 'Kimeru'} # fmt: skip
 
@@ -1301,6 +1302,8 @@ def run_asr(
         # don't pass language or task
         kwargs.pop("task", None)
         kwargs["max_length"] = 448
+    elif selected_model == AsrModels.whisper_sunbird_large_v3:
+        kwargs["language"] = SUNBIRD_SUPPORTED_LANGUAGES[language.strip()]
     elif "whisper" in selected_model.name:
         forced_lang = forced_asr_languages.get(selected_model)
         if forced_lang: