diff --git a/controller/attribute/llm_response_tmpl.py b/controller/attribute/llm_response_tmpl.py
index 311d183c..118d906e 100644
--- a/controller/attribute/llm_response_tmpl.py
+++ b/controller/attribute/llm_response_tmpl.py
@@ -61,6 +61,13 @@ class LLMProvider_A2VYBG(Enum):
     "presence_penalty": float("@@PRESENCE_PENALTY@@"),
 }
 
+IS_O_SERIES_A2VYBG = bool("@@IS_O_SERIES@@")
+
+if IS_O_SERIES_A2VYBG:
+    del LLM_KWARGS_A2VYBG["temperature"]
+    LLM_KWARGS_A2VYBG["max_completion_tokens"] = LLM_KWARGS_A2VYBG.pop("max_tokens")
+
+
 SYSTEM_PROMPT_A2VYBG = (
     """@@SYSTEM_PROMPT@@ You must only output valid JSON. """
     "If there is not yet a schema defined for the JSON output, "
@@ -291,16 +298,29 @@ async def get_llm_response(record: dict, cached_records: dict):
     if curr_running_id in cached_records:
         return cached_records[curr_running_id]
 
-    messages = [
-        {
-            "role": "system",
-            "content": SYSTEM_PROMPT_A2VYBG,
-        },
-        {
-            "role": "user",
-            "content": USER_PROMPT_A2VYBG,
-        },
-    ]
+    if IS_O_SERIES_A2VYBG:
+        # o-series models don't support a system prompt; fold it into the user message
+        messages = [
+            {
+                "role": "user",
+                "content": f"""Instructions:
+{SYSTEM_PROMPT_A2VYBG}
+Further information:
+{USER_PROMPT_A2VYBG}
+""",
+            },
+        ]
+    else:
+        messages = [
+            {
+                "role": "system",
+                "content": SYSTEM_PROMPT_A2VYBG,
+            },
+            {
+                "role": "user",
+                "content": USER_PROMPT_A2VYBG,
+            },
+        ]
     exception = None
     for _ in range(int(MAX_RETRIES_A2VYBG)):
         try:
diff --git a/controller/attribute/util.py b/controller/attribute/util.py
index 3f60386c..b106855d 100644
--- a/controller/attribute/util.py
+++ b/controller/attribute/util.py
@@ -80,19 +80,22 @@ def prepare_sample_records_doc_bin(
     return prefixed_doc_bin
 
 
-def test_openai_llm_connection(api_key: str, model: str):
+def test_openai_llm_connection(api_key: str, model: str, is_o_series: bool = False):
     # more here: https://platform.openai.com/docs/api-reference/making-requests
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {api_key}",
     }
-
+    if is_o_series:
+        add_payload = {"max_completion_tokens": 5}
+    else:
+        add_payload = {"max_tokens": 5}
     payload = {
         "model": model,
         "messages": [
             {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
         ],
-        "max_tokens": 5,
+        **add_payload,
     }
 
     response = requests.post(
@@ -124,7 +127,11 @@ def test_azure_foundry_llm_connection(api_key: str, base_endpoint: str):
 
 
 def test_azure_llm_connection(
-    api_key: str, base_endpoint: str, api_version: str, model: str
+    api_key: str,
+    base_endpoint: str,
+    api_version: str,
+    model: str,
+    is_o_series: bool = False,
 ):
     # more here: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference-preview
     base_endpoint = base_endpoint.rstrip("/")
@@ -146,11 +153,15 @@ def test_azure_llm_connection(
         "api-key": api_key,
     }
 
+    if is_o_series:
+        add_payload = {"max_completion_tokens": 5}
+    else:
+        add_payload = {"max_tokens": 5}
     payload = {
         "messages": [
             {"role": "user", "content": [{"type": "text", "text": "only say 'hello'"}]},
         ],
-        "max_tokens": 5,
+        **add_payload,
     }
 
     response = requests.post(final_endpoint, headers=headers, json=payload)
@@ -190,13 +201,15 @@ def validate_llm_config(llm_config: Dict[str, Any]):
         test_openai_llm_connection(
             api_key=llm_config["apiKey"],
             model=llm_config["model"],
+            is_o_series=llm_config.get("openAioSeries", False),
         )
     elif llm_config["llmIdentifier"] == enums.LLMProvider.AZURE.value:
         test_azure_llm_connection(
             api_key=llm_config["apiKey"],
             model=llm_config["model"],
             base_endpoint=llm_config["apiBase"],
             api_version=llm_config["apiVersion"],
+            is_o_series=llm_config.get("openAioSeries", False),
         )
     elif llm_config["llmIdentifier"] == enums.LLMProvider.AZURE_FOUNDRY.value:
         test_azure_foundry_llm_connection(
@@ -291,6 +304,8 @@ async def ac(record):
             "@@CACHE_FILE_UPLOAD_LINK@@": llm_config.get(
                 "llmAcCacheFileUploadLink", ""
            ),
+            # replace the placeholder including its quotes, since bool("False") == True
+            '"@@IS_O_SERIES@@"': str(llm_config.get("openAioSeries", False)),
         }
     except KeyError:
         raise LlmResponseError(
diff --git a/submodules/model b/submodules/model
index 6a570fef..b2d3a38b 160000
--- a/submodules/model
+++ b/submodules/model
@@ -1 +1 @@
-Subproject commit 6a570fef7d7fef3b77e9d5155fd1eb05db9d8a83
+Subproject commit b2d3a38b971eb9fc29afab8ddf867b53c8ca5579
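
A note on the `@@IS_O_SERIES@@` substitution above: the template is rendered by plain string replacement, and `bool()` of any non-empty string is `True`, so replacing only the bare placeholder would produce `bool("False")`, which is truthy. The replacement key therefore includes the surrounding quotes, so the rendered template contains a bare `True`/`False` literal. Below is a minimal sketch of the pitfall; `render_template` is a hypothetical stand-in for the real substitution step, not code from this repository.

def render_template(template: str, replacements: dict) -> str:
    # Apply each placeholder -> value substitution verbatim.
    for placeholder, value in replacements.items():
        template = template.replace(placeholder, value)
    return template

TEMPLATE = 'IS_O_SERIES_A2VYBG = bool("@@IS_O_SERIES@@")'

# Replacing only the bare placeholder keeps the quotes, so the rendered
# code evaluates bool("False") -- a non-empty string, hence True.
naive = render_template(TEMPLATE, {"@@IS_O_SERIES@@": "False"})
assert naive == 'IS_O_SERIES_A2VYBG = bool("False")'

# Replacing the quoted placeholder (quotes included) yields a bare literal,
# so the boolean round-trips correctly.
fixed = render_template(TEMPLATE, {'"@@IS_O_SERIES@@"': "False"})
assert fixed == "IS_O_SERIES_A2VYBG = bool(False)"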