60 changes: 22 additions & 38 deletions python/sgl_jax/srt/entrypoints/engine.py
@@ -36,7 +36,10 @@
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
)
from sgl_jax.srt.managers.scheduler import run_scheduler_process, run_scheduler_thread
from sgl_jax.srt.managers.scheduler import (
run_scheduler_loop_thread_after_create,
run_scheduler_process,
)
from sgl_jax.srt.managers.template_manager import TemplateManager
from sgl_jax.srt.managers.tokenizer_manager import TokenizerManager
from sgl_jax.srt.sampling.sampling_params import SamplingParams
@@ -46,7 +49,6 @@
get_zmq_socket,
kill_process_tree,
launch_dummy_health_check_server,
pathways_available,
prepare_model_and_tokenizer,
set_ulimit,
)
@@ -139,7 +141,11 @@ def generate(
token_ids_logprob=token_ids_logprob,
stream=stream,
)
loop = asyncio.get_event_loop()
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
generator = self.tokenizer_manager.generate_request(obj, None)

if stream:
@@ -241,7 +247,7 @@ def rerank(
def shutdown(self):
"""Shutdown the engine"""
kill_process_tree(os.getpid(), include_parent=False)
if pathways_available():
if self.server_args.enable_single_process:
self.send_to_rpc.close()

def __enter__(self):
@@ -394,6 +400,7 @@ def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params.get(p) is not None
}
self.default_sampling_params = diff_sampling_param

else:
self.default_sampling_params = {}

@@ -402,7 +409,7 @@ def get_default_sampling_params(self) -> SamplingParams:
return SamplingParams()


def _set_envs_and_config():
def _set_envs_and_config(server_args):
# Set ulimit
set_ulimit()

@@ -426,7 +433,7 @@ def sigquit_handler(signum, frame):
kill_process_tree(os.getpid())

signal.signal(signal.SIGQUIT, sigquit_handler)
if not pathways_available():
if not server_args.enable_single_process:
# Set mp start method
mp.set_start_method("spawn", force=True)
else:
@@ -442,7 +449,7 @@ def _launch_subprocesses(
# Configure global environment
configure_logger(server_args)
server_args.check_server_args()
_set_envs_and_config()
_set_envs_and_config(server_args)

# Allocate ports for inter-process communications
if port_args is None:
@@ -545,7 +552,7 @@ def _launch_threads(
# Configure global environment
configure_logger(server_args)
server_args.check_server_args()
_set_envs_and_config()
_set_envs_and_config(server_args)

# Allocate ports for inter-process communications
if port_args is None:
@@ -558,23 +565,11 @@
)

scheduler_threads = []
scheduler_infos = []
if server_args.dp_size == 1:
scheduler_pipe_readers = []
reader, writer = mp.Pipe(duplex=False)
thread = threading.Thread(
target=run_scheduler_thread,
args=(
server_args,
port_args,
None,
writer,
),
daemon=True,
)
# with memory_saver_adapter.configure_subprocess():
thread.start()
scheduler_threads.append(thread)
scheduler_pipe_readers.append(reader)
scheduler_info = run_scheduler_loop_thread_after_create(server_args, port_args)
scheduler_infos.append(scheduler_info)
else:
pass

@@ -620,25 +615,14 @@ def _launch_threads(
)

# Wait for the model to finish loading
scheduler_infos = []
for i in range(len(scheduler_pipe_readers)):
try:
data = scheduler_pipe_readers[i].recv()
except EOFError:
logger.error(
f"Node {i} jax_scheduler is dead. Please check if there are relevant logs."
)
scheduler_threads[i].join()
logger.error(f"{scheduler_threads[i].name} eof")
raise

if data["status"] != "ready":
for i in range(len(scheduler_infos)):
if scheduler_infos[i]["status"] != "ready":
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
scheduler_infos.append(data)

# Assume all schedulers have the same scheduler_info
assert len(scheduler_infos) > 0, "scheduler_infos is empty"
scheduler_info = scheduler_infos[0]
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
return tokenizer_manager, template_manager, scheduler_info
@@ -647,7 +631,7 @@ def _launch_threads(
def _launch_subprocesses_or_threads(
server_args, port_args: Optional[PortArgs] = None
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
if pathways_available():
if server_args.enable_single_process:
return _launch_threads(server_args, port_args)
else:
return _launch_subprocesses(server_args, port_args)
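For context, the single-process path replaces the `mp.Pipe` readiness handshake with a direct return value: the scheduler is constructed synchronously in the caller's process and only its event loop runs on a daemon thread. A minimal sketch of that pattern, using simplified stand-ins rather than the actual sgl_jax classes:

```python
import threading

class DummyScheduler:
    """Simplified stand-in for the real Scheduler."""
    max_req_input_len = 4096

    def event_loop(self):
        pass  # placeholder for the real scheduling loop

def launch_in_process():
    scheduler = DummyScheduler()  # constructed synchronously, so init errors surface here
    thread = threading.Thread(target=scheduler.event_loop, daemon=True)
    thread.start()
    # Readiness info is returned directly -- no pipe handshake -- and the
    # scheduler object itself can be handed back to the in-process engine.
    return {
        "status": "ready",
        "max_req_input_len": scheduler.max_req_input_len,
        "scheduler": scheduler,
    }
```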
3 changes: 3 additions & 0 deletions python/sgl_jax/srt/entrypoints/http_server.py
@@ -812,6 +812,9 @@ def launch_server(
tokenizer_manager, template_manager, scheduler_info = (
_launch_subprocesses_or_threads(server_args=server_args, port_args=None)
)
# Don't expose the scheduler object in server mode
if "scheduler" in scheduler_info:
del scheduler_info["scheduler"]
set_global_state(
_GlobalState(
tokenizer_manager=tokenizer_manager,
39 changes: 24 additions & 15 deletions python/sgl_jax/srt/managers/scheduler.py
@@ -1098,12 +1098,33 @@ def run_scheduler_process(
parent_process.send_signal(signal.SIGQUIT)


def run_scheduler_thread(
def run_scheduler_loop_thread_after_create(
server_args: ServerArgs,
port_args: PortArgs,
dp_rank: Optional[int],
pipe_writer,
):
current_process = psutil.Process()
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args)
scheduler_thread = threading.Thread(
target=scheduler_loop_after_create,
args=(server_args, scheduler),
daemon=True,
)
scheduler_thread.start()
return {
"status": "ready",
"max_total_num_tokens": scheduler.max_total_num_tokens,
"max_req_input_len": scheduler.max_req_input_len,
"scheduler": scheduler,
}
except Exception:
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")
current_process.send_signal(signal.SIGQUIT)


def scheduler_loop_after_create(server_args, scheduler):
# Generate the prefix
prefix = ""
if server_args.nnodes > 1:
@@ -1117,23 +1138,11 @@ def run_scheduler_thread(

# Configure the logger
configure_logger(server_args, prefix=prefix)

# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args)
pipe_writer.send(
{
"status": "ready",
"max_total_num_tokens": scheduler.max_total_num_tokens,
"max_req_input_len": scheduler.max_req_input_len,
}
)

if scheduler.enable_overlap:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()

except Exception:
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")
1 change: 1 addition & 0 deletions python/sgl_jax/srt/managers/tp_worker_overlap_thread.py
@@ -51,6 +51,7 @@ def __init__(
# JAX handles device execution automatically, no need for explicit streams
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
daemon=True if server_args.enable_single_process else False,
)
self.forward_thread.start()
self.parent_process = psutil.Process().parent()
9 changes: 5 additions & 4 deletions python/sgl_jax/srt/model_loader/loader.py
@@ -97,11 +97,12 @@ def _initialize_model(self, model_config: ModelConfig) -> Any:
return model_class

def _get_model(self, model_class: Any, model_config: ModelConfig) -> nnx.Module:
model = nnx.eval_shape(
lambda: model_class(
model_config.hf_config, model_config.dtype, self.rng, self.mesh
with self.mesh:
model = nnx.eval_shape(
lambda: model_class(
model_config.hf_config, model_config.dtype, self.rng, self.mesh
)
)
)

model.load_weights(model_config, self.rng.default.key.value)
return model
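The loader change enters the mesh context before `nnx.eval_shape`, so sharding annotations that rely on an ambient mesh resolve correctly while the abstract model is built. A hedged sketch of the underlying JAX idiom (toy one-dimensional mesh, not the loader's actual configuration):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

# Toy 1D mesh over all local devices; the real loader uses the mesh
# that was handed to the model constructor.
mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))

@jax.jit
def constrain(x):
    # A bare PartitionSpec resolves against the ambient mesh, which is why
    # entering `with mesh:` before tracing/shape evaluation matters.
    return jax.lax.with_sharding_constraint(x, PartitionSpec("tp"))

with mesh:
    y = constrain(jnp.zeros((jax.device_count(), 8)))
```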
19 changes: 17 additions & 2 deletions python/sgl_jax/srt/sampling/sampling_batch_info.py
@@ -250,7 +250,7 @@ def from_model_worker_batch_for_precompile(
padded_top_ks = np.concat(
[
batch.sampling_info.top_ks,
np.array([-1] * pad_size, dtype=batch.sampling_info.top_ks.dtype),
np.array([1] * pad_size, dtype=batch.sampling_info.top_ks.dtype),
]
)
padded_min_ps = np.concat(
@@ -259,6 +259,21 @@
np.array([0.0] * pad_size, dtype=batch.sampling_info.min_ps.dtype),
]
)
if batch.sampling_info.sampling_seeds is not None:
padded_sampling_seeds = np.concat(
[
batch.sampling_info.sampling_seeds,
np.array(
[DEFAULT_SAMPLING_SEED] * pad_size,
dtype=batch.sampling_info.sampling_seeds.dtype,
),
]
)
sampling_seeds_device = device_array(
padded_sampling_seeds, sharding=sharding
)
else:
sampling_seeds_device = None

(temperatures_device, top_ps_device, top_ks_device, min_ps_device) = (
device_array(
@@ -304,10 +319,10 @@ def from_model_worker_batch_for_precompile(
top_ps=top_ps_device,
top_ks=top_ks_device,
min_ps=min_ps_device,
sampling_seeds=sampling_seeds_device,
is_all_greedy=batch.sampling_info.is_all_greedy,
need_min_p_sampling=batch.sampling_info.need_min_p_sampling,
linear_penalty=linear_penalty_device,
sampling_seeds=sampling_seeds_device,
do_penalties=True,
)

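For reference, the new seed handling mirrors the padding of the other per-request arrays: padded slots receive `DEFAULT_SAMPLING_SEED` so the batch stays dense and deterministic. A tiny standalone sketch of the same pattern (toy values; the constant's real value is defined elsewhere in the module):

```python
import numpy as np

DEFAULT_SAMPLING_SEED = 42  # placeholder value for illustration only
seeds = np.array([7, 11, 13], dtype=np.int64)
pad_size = 2

padded = np.concatenate(
    [seeds, np.full(pad_size, DEFAULT_SAMPLING_SEED, dtype=seeds.dtype)]
)
# padded -> array([ 7, 11, 13, 42, 42])
```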
7 changes: 7 additions & 0 deletions python/sgl_jax/srt/server_args.py
@@ -132,6 +132,7 @@ class ServerArgs:

# For deterministic sampling
enable_deterministic_sampling: bool = False
enable_single_process: bool = False

def __post_init__(self):
# Set missing default values
Expand Down Expand Up @@ -771,6 +772,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Enable deterministic sampling",
)

parser.add_argument(
"--enable-single-process",
action="store_true",
help="Enable run the engine with single process.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
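A brief usage note: single-process mode can now be enabled from the CLI via `--enable-single-process`, or programmatically on `ServerArgs`. A hedged sketch (the `model_path` field is assumed here purely for illustration):

```python
from sgl_jax.srt.server_args import ServerArgs

# Hedged sketch: `model_path` is a hypothetical required field used only
# for illustration; pass whatever fields the dataclass actually defines.
args = ServerArgs(
    model_path="Qwen/Qwen2-7B",
    enable_single_process=True,  # schedulers run as in-process daemon threads
)
assert args.enable_single_process
```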
1 change: 0 additions & 1 deletion python/sgl_jax/srt/utils/__init__.py
@@ -13,4 +13,3 @@
set_ulimit,
set_uvicorn_logging_configs,
)
from .tunix_utils import pathways_available
8 changes: 0 additions & 8 deletions python/sgl_jax/srt/utils/tunix_utils.py

This file was deleted.

28 changes: 28 additions & 0 deletions test/srt/test_features.py
@@ -128,6 +128,34 @@ def test_abort_all(self):
future.result()["meta_info"]["finish_reason"]["type"], "abort"
)

def test_cache_miss_prefill(self):
args = SimpleNamespace(
base_url=self.base_url,
text="the capital of France is",
temperature=0,
max_new_tokens=1,
)

resp = run_curl(args)

if "cache_miss_count" not in resp["meta_info"]:
raise "[prefill] cache_miss_count is missed in response"
self.assertEqual(resp["meta_info"]["cache_miss_count"], 0)

def test_cache_miss_decode(self):
args = SimpleNamespace(
base_url=self.base_url,
text="the capital of France is",
temperature=0,
max_new_tokens=2,
)

resp = run_curl(args)

if "cache_miss_count" not in resp["meta_info"]:
raise "[prefill] cache_miss_count is missed in response"
self.assertEqual(resp["meta_info"]["cache_miss_count"], 0)

def test_logprobs(self):
# Note: test_logprobs is skipped until the accuracy score is high enough to pin expected logits.
# For now, every accuracy improvement causes tiny differences in the values, so skip it and support it in the future.