Skip to content

Commit 5669b6e

Browse files
authored
[Feat] support for "O" series models of OpenAI (#160)
* Added condition to check whether model is of o1 type
* Updated unstract-sdk's version in __init__.py
* Added docs to init method of openai adapter
* Updated comment in open_ai.py
* Added pdm.lock

Signed-off-by: Praveen Kumar <praveen@zipstack.com>
1 parent 75091c8 commit 5669b6e

File tree

5 files changed

+33
-23
lines changed

5 files changed

+33
-23
lines changed

pdm.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"llama-index-vector-stores-weaviate==1.3.1",
3535
"llama-index-vector-stores-pinecone==0.4.2",
3636
"llama-index-vector-stores-qdrant==0.4.2",
37-
"llama-index-llms-openai==0.3.12",
37+
"llama-index-llms-openai==0.3.17",
3838
"llama-index-llms-palm==0.3.0",
3939
"llama-index-llms-mistralai==0.3.1",
4040
"mistralai==1.2.5",

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.57.0rc2"
1+
__version__ = "0.57.0rc3"
22

33

44
def get_sdk_version():

src/unstract/sdk/adapters/llm/llm_adapter.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55

66
from llama_index.core.llms import LLM, MockLLM
7+
from llama_index.llms.openai.utils import O1_MODELS
78

89
from unstract.sdk.adapters.base import Adapter
910
from unstract.sdk.adapters.enums import AdapterTypes
@@ -72,9 +73,14 @@ def _test_llm_instance(llm: Optional[LLM]) -> bool:
7273
message="Unable to connect to LLM, please recheck the configuration",
7374
status_code=400,
7475
)
76+
# Get completion kwargs based on model capabilities
77+
completion_kwargs = {}
78+
if hasattr(llm, 'model') and getattr(llm, 'model') not in O1_MODELS:
79+
completion_kwargs['temperature'] = 0.003
80+
7581
response = llm.complete(
7682
"The capital of Tamilnadu is ",
77-
temperature=0.003,
83+
**completion_kwargs
7884
)
7985
response_lower_case: str = response.text.lower()
8086
find_match = re.search("chennai", response_lower_case)

src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py

Lines changed: 20 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -3,14 +3,14 @@
33

44
from llama_index.core.llms import LLM
55
from llama_index.llms.openai import OpenAI
6+
from llama_index.llms.openai.utils import O1_MODELS
67
from openai import APIError as OpenAIAPIError
78

89
from unstract.sdk.adapters.exceptions import AdapterError
910
from unstract.sdk.adapters.llm.constants import LLMKeys
1011
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
1112
from unstract.sdk.exceptions import LLMError
1213

13-
1414
class Constants:
1515
MODEL = "model"
1616
API_KEY = "api_key"
@@ -23,6 +23,7 @@ class Constants:
2323

2424

2525
class OpenAILLM(LLMAdapter):
26+
2627
def __init__(self, settings: dict[str, Any]):
2728
super().__init__("OpenAI")
2829
self.config = settings
@@ -53,21 +54,24 @@ def get_llm_instance(self) -> LLM:
5354
try:
5455
max_tokens = self.config.get(Constants.MAX_TOKENS)
5556
max_tokens = int(max_tokens) if max_tokens else None
56-
llm: LLM = OpenAI(
57-
model=str(self.config.get(Constants.MODEL)),
58-
api_key=str(self.config.get(Constants.API_KEY)),
59-
api_base=str(self.config.get(Constants.API_BASE)),
60-
api_version=str(self.config.get(Constants.API_VERSION)),
61-
max_retries=int(
62-
self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES)
63-
),
64-
api_type="openai",
65-
temperature=0,
66-
timeout=float(
67-
self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT)
68-
),
69-
max_tokens=max_tokens,
70-
)
57+
model = str(self.config.get(Constants.MODEL))
58+
59+
llm_kwargs = {
60+
"model": model,
61+
"api_key": str(self.config.get(Constants.API_KEY)),
62+
"api_base": str(self.config.get(Constants.API_BASE)),
63+
"api_version": str(self.config.get(Constants.API_VERSION)),
64+
"max_retries": int(self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES)),
65+
"api_type": "openai",
66+
"timeout": float(self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT)),
67+
"max_tokens": max_tokens,
68+
}
69+
70+
# O-series models default to temperature=1, ignoring passed values, so it's not set explicitly.
71+
if model not in O1_MODELS:
72+
llm_kwargs["temperature"] = 0
73+
74+
llm = OpenAI(**llm_kwargs)
7175
return llm
7276
except Exception as e:
7377
raise AdapterError(str(e))

0 commit comments

Comments (0)