Skip to content

Commit b5a4889

Browse files
Merge pull request #221 from cloudera/liu/refactor
Move model get logic to model provider classes
2 parents 3fa0201 + 0a162bb commit b5a4889

File tree

9 files changed

+196
-135
lines changed

9 files changed

+196
-135
lines changed

llm-service/app/services/models/embedding.py

Lines changed: 8 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -39,9 +39,6 @@
3939

4040
from fastapi import HTTPException
4141
from llama_index.core.base.embeddings.base import BaseEmbedding
42-
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
43-
from llama_index.embeddings.bedrock import BedrockEmbedding
44-
from llama_index.embeddings.openai import OpenAIEmbedding
4542

4643
from . import _model_type, _noop
4744
from .providers import (
@@ -50,9 +47,7 @@
5047
CAIIModelProvider,
5148
)
5249
from .providers.openai import OpenAiModelProvider
53-
from ..caii.caii import get_embedding_model as caii_embedding
5450
from ..caii.types import ModelResponse
55-
from ...config import settings
5651

5752

5853
class Embedding(_model_type.ModelType[BaseEmbedding]):
@@ -62,24 +57,12 @@ def get(cls, model_name: Optional[str] = None) -> BaseEmbedding:
6257
model_name = cls.list_available()[0].model_id
6358

6459
if AzureModelProvider.is_enabled():
65-
return AzureOpenAIEmbedding(
66-
model_name=model_name,
67-
deployment_name=model_name,
68-
# must be passed manually otherwise AzureOpenAIEmbedding checks OPENAI_API_KEY
69-
api_key=settings.azure_openai_api_key,
70-
)
71-
60+
return AzureModelProvider.get_embedding_model(model_name)
7261
if CAIIModelProvider.is_enabled():
73-
return caii_embedding(model_name=model_name)
74-
62+
return CAIIModelProvider.get_embedding_model(model_name)
7563
if OpenAiModelProvider.is_enabled():
76-
return OpenAIEmbedding(
77-
model_name=model_name,
78-
api_key=settings.openai_api_key,
79-
api_base=settings.openai_api_base,
80-
)
81-
82-
return BedrockEmbedding(model_name=model_name)
64+
return OpenAiModelProvider.get_embedding_model(model_name)
65+
return BedrockModelProvider.get_embedding_model(model_name)
8366

8467
@staticmethod
8568
def get_noop() -> BaseEmbedding:
@@ -88,15 +71,12 @@ def get_noop() -> BaseEmbedding:
8871
@staticmethod
8972
def list_available() -> list[ModelResponse]:
9073
if AzureModelProvider.is_enabled():
91-
return AzureModelProvider.get_embedding_models()
92-
74+
return AzureModelProvider.list_embedding_models()
9375
if CAIIModelProvider.is_enabled():
94-
return CAIIModelProvider.get_embedding_models()
95-
76+
return CAIIModelProvider.list_embedding_models()
9677
if OpenAiModelProvider.is_enabled():
97-
return OpenAiModelProvider.get_embedding_models()
98-
99-
return BedrockModelProvider.get_embedding_models()
78+
return OpenAiModelProvider.list_embedding_models()
79+
return BedrockModelProvider.list_embedding_models()
10080

10181
@classmethod
10282
def test(cls, model_name: str) -> str:

llm-service/app/services/models/llm.py

Lines changed: 9 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,6 @@
4040
from fastapi import HTTPException
4141
from llama_index.core import llms
4242
from llama_index.core.base.llms.types import ChatMessage, MessageRole
43-
from llama_index.llms.azure_openai import AzureOpenAI
44-
from llama_index.llms.bedrock_converse import BedrockConverse
45-
from llama_index.llms.openai import OpenAI
4643

4744
from . import _model_type, _noop
4845
from .providers import (
@@ -51,10 +48,7 @@
5148
CAIIModelProvider,
5249
)
5350
from .providers.openai import OpenAiModelProvider
54-
from ..caii.caii import get_llm as caii_llm
5551
from ..caii.types import ModelResponse
56-
from ..llama_utils import completion_to_prompt, messages_to_prompt
57-
from ...config import settings
5852

5953

6054
class LLM(_model_type.ModelType[llms.LLM]):
@@ -64,37 +58,12 @@ def get(cls, model_name: Optional[str] = None) -> llms.LLM:
6458
model_name = cls.list_available()[0].model_id
6559

6660
if AzureModelProvider.is_enabled():
67-
return AzureOpenAI(
68-
model=model_name,
69-
engine=model_name,
70-
messages_to_prompt=messages_to_prompt,
71-
completion_to_prompt=completion_to_prompt,
72-
max_tokens=2048,
73-
)
74-
75-
if OpenAiModelProvider.is_enabled():
76-
return OpenAI(
77-
model=model_name,
78-
messages_to_prompt=messages_to_prompt,
79-
completion_to_prompt=completion_to_prompt,
80-
max_tokens=2048,
81-
api_base=settings.openai_api_base,
82-
api_key=settings.openai_api_key,
83-
)
84-
61+
return AzureModelProvider.get_llm_model(model_name)
8562
if CAIIModelProvider.is_enabled():
86-
return caii_llm(
87-
endpoint_name=model_name,
88-
messages_to_prompt=messages_to_prompt,
89-
completion_to_prompt=completion_to_prompt,
90-
)
91-
92-
return BedrockConverse(
93-
model=model_name,
94-
messages_to_prompt=messages_to_prompt,
95-
completion_to_prompt=completion_to_prompt,
96-
max_tokens=2048,
97-
)
63+
return CAIIModelProvider.get_llm_model(model_name)
64+
if OpenAiModelProvider.is_enabled():
65+
return OpenAiModelProvider.get_llm_model(model_name)
66+
return BedrockModelProvider.get_llm_model(model_name)
9867

9968
@staticmethod
10069
def get_noop() -> llms.LLM:
@@ -103,15 +72,12 @@ def get_noop() -> llms.LLM:
10372
@staticmethod
10473
def list_available() -> list[ModelResponse]:
10574
if AzureModelProvider.is_enabled():
106-
return AzureModelProvider.get_llm_models()
107-
75+
return AzureModelProvider.list_llm_models()
10876
if CAIIModelProvider.is_enabled():
109-
return CAIIModelProvider.get_llm_models()
110-
77+
return CAIIModelProvider.list_llm_models()
11178
if OpenAiModelProvider.is_enabled():
112-
return OpenAiModelProvider.get_llm_models()
113-
114-
return BedrockModelProvider.get_llm_models()
79+
return OpenAiModelProvider.list_llm_models()
80+
return BedrockModelProvider.list_llm_models()
11581

11682
@classmethod
11783
def test(cls, model_name: str) -> Literal["ok"]:

llm-service/app/services/models/providers/_model_provider.py

Lines changed: 28 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,10 @@
3737
#
3838
import abc
3939
import os
40-
from typing import List
40+
41+
from llama_index.core.base.embeddings.base import BaseEmbedding
42+
from llama_index.core.llms import LLM
43+
from llama_index.core.postprocessor.types import BaseNodePostprocessor
4144

4245
from ...caii.types import ModelResponse
4346

@@ -56,18 +59,36 @@ def get_env_var_names() -> set[str]:
5659

5760
@staticmethod
5861
@abc.abstractmethod
59-
def get_llm_models() -> List[ModelResponse]:
60-
"""Return available LLM models."""
62+
def list_llm_models() -> list[ModelResponse]:
63+
"""Return names and IDs of available LLM models."""
64+
raise NotImplementedError
65+
66+
@staticmethod
67+
@abc.abstractmethod
68+
def list_embedding_models() -> list[ModelResponse]:
69+
"""Return names and IDs of available embedding models."""
70+
raise NotImplementedError
71+
72+
@staticmethod
73+
@abc.abstractmethod
74+
def list_reranking_models() -> list[ModelResponse]:
75+
"""Return names and IDs of available reranking models."""
76+
raise NotImplementedError
77+
78+
@staticmethod
79+
@abc.abstractmethod
80+
def get_llm_model(name: str) -> LLM:
81+
"""Return LLM model with `name`."""
6182
raise NotImplementedError
6283

6384
@staticmethod
6485
@abc.abstractmethod
65-
def get_embedding_models() -> List[ModelResponse]:
66-
"""Return available embedding models."""
86+
def get_embedding_model(name: str) -> BaseEmbedding:
87+
"""Return embedding model with `name`."""
6788
raise NotImplementedError
6889

6990
@staticmethod
7091
@abc.abstractmethod
71-
def get_reranking_models() -> List[ModelResponse]:
72-
"""Return available reranking models."""
92+
def get_reranking_model(name: str, top_n: int) -> BaseNodePostprocessor:
93+
"""Return reranking model with `name`."""
7394
raise NotImplementedError

llm-service/app/services/models/providers/azure.py

Lines changed: 32 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,14 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38+
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
39+
from llama_index.llms.azure_openai import AzureOpenAI
3840

39-
from typing import List
40-
41-
from ...caii.types import ModelResponse
4241
from ._model_provider import ModelProvider
42+
from ...caii.types import ModelResponse
43+
from ...llama_utils import completion_to_prompt, messages_to_prompt
44+
from ...query.simple_reranker import SimpleReranker
45+
from ....config import settings
4346

4447

4548
class AzureModelProvider(ModelProvider):
@@ -48,7 +51,7 @@ def get_env_var_names() -> set[str]:
4851
return {"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "OPENAI_API_VERSION"}
4952

5053
@staticmethod
51-
def get_llm_models() -> List[ModelResponse]:
54+
def list_llm_models() -> list[ModelResponse]:
5255
return [
5356
ModelResponse(
5457
model_id="gpt-4o",
@@ -61,7 +64,7 @@ def get_llm_models() -> List[ModelResponse]:
6164
]
6265

6366
@staticmethod
64-
def get_embedding_models() -> List[ModelResponse]:
67+
def list_embedding_models() -> list[ModelResponse]:
6568
return [
6669
ModelResponse(
6770
model_id="text-embedding-ada-002",
@@ -74,9 +77,32 @@ def get_embedding_models() -> List[ModelResponse]:
7477
]
7578

7679
@staticmethod
77-
def get_reranking_models() -> List[ModelResponse]:
80+
def list_reranking_models() -> list[ModelResponse]:
7881
return []
7982

83+
@staticmethod
84+
def get_llm_model(name: str) -> AzureOpenAI:
85+
return AzureOpenAI(
86+
model=name,
87+
engine=name,
88+
messages_to_prompt=messages_to_prompt,
89+
completion_to_prompt=completion_to_prompt,
90+
max_tokens=2048,
91+
)
92+
93+
@staticmethod
94+
def get_embedding_model(name: str) -> AzureOpenAIEmbedding:
95+
return AzureOpenAIEmbedding(
96+
model_name=name,
97+
deployment_name=name,
98+
# must be passed manually otherwise AzureOpenAIEmbedding checks OPENAI_API_KEY
99+
api_key=settings.azure_openai_api_key,
100+
)
101+
102+
@staticmethod
103+
def get_reranking_model(name: str, top_n: int) -> SimpleReranker:
104+
return SimpleReranker(top_n=top_n)
105+
80106

81107
# ensure interface is implemented
82108
_ = AzureModelProvider()

llm-service/app/services/models/providers/bedrock.py

Lines changed: 34 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -35,13 +35,17 @@
3535
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
3636
# DATA.
3737
#
38-
from typing import List, Optional, cast
38+
from typing import Optional, cast
3939

4040
import boto3
41+
from llama_index.embeddings.bedrock import BedrockEmbedding
42+
from llama_index.llms.bedrock_converse import BedrockConverse
43+
from llama_index.postprocessor.bedrock_rerank import AWSBedrockRerank
4144

4245
from app.config import settings
43-
from ...caii.types import ModelResponse
4446
from ._model_provider import ModelProvider
47+
from ...caii.types import ModelResponse
48+
from ...llama_utils import completion_to_prompt, messages_to_prompt
4549

4650
DEFAULT_BEDROCK_LLM_MODEL = "meta.llama3-1-8b-instruct-v1:0"
4751
DEFAULT_BEDROCK_RERANK_MODEL = "cohere.rerank-v3-5:0"
@@ -53,7 +57,7 @@ def get_env_var_names() -> set[str]:
5357
return {"AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"}
5458

5559
@staticmethod
56-
def get_llm_models() -> List[ModelResponse]:
60+
def list_llm_models() -> list[ModelResponse]:
5761
models = [
5862
ModelResponse(
5963
model_id=DEFAULT_BEDROCK_LLM_MODEL, name="Llama3.1 8B Instruct v1"
@@ -91,7 +95,8 @@ def get_llm_models() -> List[ModelResponse]:
9195

9296
@staticmethod
9397
def _get_model_arn_by_profiles(
94-
suffix: str, profiles: List[dict[str, str]]
98+
suffix: str,
99+
profiles: list[dict[str, str]],
95100
) -> Optional[ModelResponse]:
96101
for profile in profiles:
97102
if profile["inferenceProfileId"].endswith(suffix):
@@ -102,13 +107,16 @@ def _get_model_arn_by_profiles(
102107
return None
103108

104109
@staticmethod
105-
def _get_model_arns() -> List[dict[str, str]]:
106-
bedrock_client = boto3.client("bedrock", region_name=settings.aws_default_region)
110+
def _get_model_arns() -> list[dict[str, str]]:
111+
bedrock_client = boto3.client(
112+
"bedrock",
113+
region_name=settings.aws_default_region,
114+
)
107115
profiles = bedrock_client.list_inference_profiles()["inferenceProfileSummaries"]
108-
return cast(List[dict[str, str]], profiles)
116+
return cast(list[dict[str, str]], profiles)
109117

110118
@staticmethod
111-
def get_embedding_models() -> List[ModelResponse]:
119+
def list_embedding_models() -> list[ModelResponse]:
112120
return [
113121
ModelResponse(
114122
model_id="cohere.embed-english-v3",
@@ -121,7 +129,7 @@ def get_embedding_models() -> List[ModelResponse]:
121129
]
122130

123131
@staticmethod
124-
def get_reranking_models() -> List[ModelResponse]:
132+
def list_reranking_models() -> list[ModelResponse]:
125133
return [
126134
ModelResponse(
127135
model_id=DEFAULT_BEDROCK_RERANK_MODEL,
@@ -133,6 +141,23 @@ def get_reranking_models() -> List[ModelResponse]:
133141
),
134142
]
135143

144+
@staticmethod
145+
def get_llm_model(name: str) -> BedrockConverse:
146+
return BedrockConverse(
147+
model=name,
148+
messages_to_prompt=messages_to_prompt,
149+
completion_to_prompt=completion_to_prompt,
150+
max_tokens=2048,
151+
)
152+
153+
@staticmethod
154+
def get_embedding_model(name: str) -> BedrockEmbedding:
155+
return BedrockEmbedding(model_name=name)
156+
157+
@staticmethod
158+
def get_reranking_model(name: str, top_n: int) -> AWSBedrockRerank:
159+
return AWSBedrockRerank(rerank_model_name=name, top_n=top_n)
160+
136161

137162
# ensure interface is implemented
138163
_ = BedrockModelProvider()

0 commit comments

Comments (0)