refactor(model): move api-based embeddings/reranking calls out of model server #5216
Conversation
Greptile Summary
This PR implements a significant architectural refactoring (DAN-2326) that moves API-based embedding and reranking calls out of the model server and routes them directly through the Onyx NLP layer. The change creates a cleaner separation of concerns where the model server handles only local models that require GPU resources, while API-based providers (OpenAI, Cohere, Voyage, Azure, Google Vertex) are called directly from the NLP layer.
The refactor introduces a new `CloudEmbedding` class in `onyx/natural_language_processing/search_nlp_models.py` that handles direct API calls to various cloud providers with async context management. The routing logic in `EmbeddingModel` and `RerankerModel` now checks the `provider_type`: when it is `None`, requests go to the model server for local model processing; otherwise, they use direct API calls via the new cloud embedding implementation.
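A minimal sketch of that routing decision, reusing the request helpers named elsewhere in this PR (`_make_model_server_request`, `_make_direct_api_call`); the `embed` wrapper shown here is illustrative, not the PR's actual method:

```python
import asyncio

class EmbeddingModel:
    # provider_type, api_key, api_url, model_name are set in __init__ (omitted)

    def embed(self, texts: list[str]) -> list[list[float]]:
        if self.provider_type is None:
            # Local model: route through the model server (the GPU host)
            return self._make_model_server_request(texts)
        # API provider (OpenAI, Cohere, Voyage, Azure, Vertex): call directly
        return asyncio.run(self._make_direct_api_call(texts))
```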
Key changes include:
- Moving the `CloudEmbedding` class from `model_server/encoders.py` to the NLP layer with comprehensive provider-specific methods (see the sketch after this list)
- Simplifying the model server's `embed_text` function to handle only local models, removing 8 parameters related to API providers
- Adding provider constants and text type mappings to `onyx/natural_language_processing/constants.py`
- Creating utility functions in `onyx/utils/search_nlp_models_utils.py` for AWS key parsing
- Adding new dependencies (`voyageai==0.2.3` and `cohere==5.6.1`) to support direct API client calls
- Updating tests to reflect the new architecture, removing cloud provider tests from the model server and adding comprehensive tests for the new `CloudEmbedding` class
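Based on the summary's mention of async context management around direct API calls, the moved class plausibly has roughly this shape (the constructor signature and method bodies are assumptions, not the PR's actual code):

```python
import httpx

class CloudEmbedding:
    """Direct API client for cloud embedding providers (sketch)."""

    def __init__(self, api_key: str, provider: str, api_url: str | None = None) -> None:
        self.api_key = api_key
        self.provider = provider
        self.api_url = api_url
        self._client = httpx.AsyncClient()

    async def __aenter__(self) -> "CloudEmbedding":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        # Async context management: ensure the underlying HTTP client is closed
        await self._client.aclose()
```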
The refactoring eliminates unnecessary network hops through the model server for API-based operations, reducing latency and simplifying the debugging path. This architectural change makes the system more maintainable by consolidating cloud provider integration logic in one place while keeping the model server focused on its core responsibility of local model inference.
Confidence score: 4/5
- This PR involves significant architectural changes but appears well-structured with proper separation of concerns and comprehensive test coverage
- Score reflects the complexity of moving critical embedding logic across modules, though the changes follow good design principles
- Pay close attention to `onyx/natural_language_processing/search_nlp_models.py` for the new routing logic and cloud provider implementations
9 files reviewed, 8 comments
```python
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
    with patch("httpx.AsyncClient") as mock:
        client = AsyncMock(spec=AsyncClient)
        mock.return_value = client
        client.post = AsyncMock()
        async with client as c:
            yield c
```
style: The mock_http_client fixture is defined but never used in any tests - consider removing it or adding tests that actually use HTTP clients
Suggested change (drop the unused fixture, keeping only `sample_embeddings`):

```python
@pytest.fixture
def sample_embeddings() -> List[List[float]]:
    return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
```
```python
credentials = service_account.Credentials.from_service_account_info(
    json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
```
style: JSON-parsing the API key twice could be avoided by parsing once and storing the result.
Suggested change:

```python
service_account_info = json.loads(self.api_key)
credentials = service_account.Credentials.from_service_account_info(
    service_account_info
)
project_id = service_account_info["project_id"]
vertexai.init(project=project_id, credentials=credentials)
```
I agree with greptile, might as well clean this up since we're moving it
```python
response = asyncio.run(
    self._make_direct_api_call(
        embed_request, tenant_id=tenant_id, request_id=request_id
    )
)
```
logic: Using asyncio.run() in a potentially multithreaded context could cause issues. Consider using asyncio.new_event_loop() or ensuring thread safety.
```python
# Route between direct API calls and model server calls
if self.provider_type is not None:
    # For API providers, make direct API call
    return asyncio.run(self._make_direct_rerank_call(query, passages))
```
logic: Same asyncio.run() concern applies here for the reranking predict method.
backend/model_server/encoders.py (outdated)
```python
if rerank_request.provider_type is None:
    sim_scores = await local_rerank(
        query=rerank_request.query,
        docs=rerank_request.documents,
        model_name=rerank_request.model_name,
        region_name=aws_region,
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
    )
    return RerankResponse(scores=sim_scores)
else:
    raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
raise ValueError("Neither model name nor provider specified for reranking")
```
logic: Lines 262-270 are unreachable because line 254 already checks `if rerank_request.provider_type is not None` and raises an exception. The else block on line 269 will never execute.
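One way to resolve this (a sketch that also matches the "fail fast on non-local providers" request later in this review) is to guard first and keep a single return path:

```python
# Fail fast: only local models are served here after the refactor
if rerank_request.provider_type is not None:
    raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")

sim_scores = await local_rerank(
    query=rerank_request.query,
    docs=rerank_request.documents,
    model_name=rerank_request.model_name,
    region_name=aws_region,
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
```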
```python
except Exception as e:
    raise ValueError(f"Failed to parse AWS key components: {str(e)}")
```
logic: Generic Exception catch is too broad - the try block only contains tuple unpacking which has specific failure modes
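If the try block only unpacks the comma-delimited key, catching `ValueError` (what a mis-sized unpack raises) is tighter; `delimited_key` below is a hypothetical variable name standing in for the actual input:

```python
try:
    # Unpacking a wrong-sized sequence raises ValueError, not a broader Exception
    aws_access_key_id, aws_secret_access_key, aws_region = delimited_key.split(",")
except ValueError as e:
    raise ValueError(f"Failed to parse AWS key components: {e}") from e
```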
3 issues found across 10 files
```python
# Route between direct API calls and model server calls
if self.provider_type is not None:
    # For API providers, make direct API call
    return asyncio.run(self._make_direct_rerank_call(query, passages))
```
Do not call asyncio.run() inside predict; this may fail under an active event loop and is unsafe in multi-threaded contexts. Use a per-thread event loop or schedule the coroutine on the appropriate loop
Prompt for AI agents: address the comment above on `backend/onyx/natural_language_processing/search_nlp_models.py` at line 922. File context:

```diff
@@ -360,29 +863,85 @@ def __init__(
     model_server_host: str = MODEL_SERVER_HOST,
     model_server_port: int = MODEL_SERVER_PORT,
 ) -> None:
-    model_server_url = build_model_server_url(model_server_host, model_server_port)
-    self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
     self.model_name = model_name
     self.provider_type = provider_type
     self.api_key = api_key
     self.api_url = api_url
```
```python
# Route between direct API calls and model server calls
if self.provider_type is not None:
    # For API providers, make direct API call
    response = asyncio.run(
```
Avoid calling asyncio.run() here; it can raise at runtime when invoked from a running event loop and can conflict with threaded execution. Create and manage a dedicated event loop per thread or refactor to run the coroutine on an existing loop safely
Prompt for AI agents: address the comment above on `backend/onyx/natural_language_processing/search_nlp_models.py` at line 718. File context:

```diff
@@ -219,11 +710,23 @@ def process_batch(
         reduced_dimension=self.reduced_dimension,
     )
-    start_time = time.time()
-    response = self._make_model_server_request(
-        embed_request, tenant_id=tenant_id, request_id=request_id
-    )
-    end_time = time.time()
+    start_time = time.monotonic()
```
I agree with the bot comments; there are some examples in the repo that make a new event loop I believe
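A common pattern for that (a sketch of the general technique, not the repo's actual helper) keeps one lazily created event loop per thread:

```python
import asyncio
import threading
from typing import Any, Coroutine

_thread_local = threading.local()

def run_in_thread_loop(coro: Coroutine[Any, Any, Any]) -> Any:
    """Run a coroutine on this thread's dedicated event loop.

    Unlike asyncio.run(), this reuses one loop per thread, so repeated
    calls are cheap and it never tears down a loop another caller owns.
    """
    loop = getattr(_thread_local, "loop", None)
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        _thread_local.loop = loop
    return loop.run_until_complete(coro)
```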
```python
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
    """Get provider-specific text type string."""
    return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
```
The `get_type` method lacks error handling for unsupported provider/text_type combinations and could raise `KeyError`.
Prompt for AI agents: address the comment above on `backend/onyx/natural_language_processing/constants.py` at line 40. File context:

```diff
@@ -0,0 +1,40 @@
+"""
+Constants for natural language processing, including embedding and reranking models.
+
+This file contains constants moved from model_server to support the gradual migration
+of API-based calls to bypass the model server.
+"""
+
+from shared_configs.enums import EmbeddingProvider
+from shared_configs.enums import EmbedTextType
```
looks good after addressing comments!
```python
error_message += "\n".join(texts)
logger.error(error_message)
raise ValueError(error_message)
# Only local models should call this function now
```
would like to have a code enforcement of the new flow: i.e. if model provider is not None raise a ValueError
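Something like the following guard at the top of the model server's `embed_text` would enforce that (placement and message are assumptions; the merge commit notes enforcement was added):

```python
# Fail fast: the model server now only handles local models
if embed_request.provider_type is not None:
    raise ValueError(
        "API-based providers must be called directly from the NLP layer; "
        f"got provider_type={embed_request.provider_type}"
    )
```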
```python
_RERANK_MODEL: Optional["CrossEncoder"] = None

# If we are not only indexing, dont want retry very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
```
can these be removed from here now?
I think those are still being used by local model calls in `encoder/bi-encoder-embed` and `encoder/cross-encoder-scores`
tested and lgtm!
refactor(model): move api-based embeddings/reranking calls out of model server (onyx-dot-app#5216)

* move api-based embeddings/reranking calls to api server out of model server, added/modified unit tests
* ran pre-commit
* fix mypy errors
* mypy and precommit
* move utils to right place and add requirements
* precommit check
* removed extra constants, changed error msg
* Update backend/onyx/utils/search_nlp_models_utils.py (Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>)
* greptile
* addressed comments
* added code enforcement to throw error

Co-authored-by: Jessica Singh <jessicasingh@Mac.attlocal.net>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
DAN-2326
Description
[Provide a brief description of the changes in this PR]
How Has This Been Tested?
[Describe the tests you ran to verify your changes]
Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
Summary by cubic
Route API-based embedding and reranking calls directly through the Onyx NLP layer, leaving the model server to serve local models only. This simplifies the request path, reduces latency, and fulfills DAN-2326.
Refactors
Migration