Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion python/sgl_jax/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os
import signal
import threading
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union, Any

import zmq
import zmq.asyncio
Expand Down Expand Up @@ -43,6 +43,7 @@
set_ulimit,
)
from sgl_jax.version import __version__
from sgl_jax.srt.sampling import SamplingParams

logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
Expand Down Expand Up @@ -99,6 +100,8 @@ def __init__(self, **kwargs):
context, zmq.DEALER, self.port_args.rpc_ipc_name, True
)

self.default_sampling_params: Union[dict[str, Any], None] = None

def generate(
self,
prompt: Optional[Union[List[str], str]] = None,
Expand Down Expand Up @@ -343,6 +346,15 @@ async def async_score(
request=None,
)

def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
self.default_sampling_params = (
self.llm_engine.model_config.get_diff_sampling_param())
if self.default_sampling_params:
return SamplingParams.from_optional(**self.default_sampling_params)
return SamplingParams()
Comment on lines +349 to +355

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This new method has several issues that will cause it to fail at runtime:

  1. self.llm_engine is not an attribute of the Engine class. The Engine class acts as a client and does not hold a reference to the LLMEngine instance, which runs in a separate process.
  2. The ModelConfig class does not have a method named get_diff_sampling_param. This call will result in an AttributeError.
  3. The SamplingParams class does not have a class method named from_optional. This will also raise an AttributeError.

It seems the intention is to fetch default sampling parameters from the model's configuration. This likely requires fetching this information from the scheduler process, for example by introducing a new RPC call.

Additionally, the logic if self.default_sampling_params: might be incorrect if an empty dictionary is a valid return value for sampling parameters, as it would evaluate to False. You might want to check for is not None instead.

To instantiate SamplingParams from a dictionary, you can use its constructor directly, e.g., SamplingParams(**params_dict).




def _set_envs_and_config():
# Set ulimit
Expand Down
Loading