Skip to content

Commit ba2207f

Browse files
author
pathfinder-fp
committed
use server parameters
1 parent f130bfe commit ba2207f

File tree

5 files changed

+13
-17
lines changed

5 files changed

+13
-17
lines changed

python/sgl_jax/srt/entrypoints/engine.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,6 @@
4949
get_zmq_socket,
5050
kill_process_tree,
5151
launch_dummy_health_check_server,
52-
pathways_available,
5352
prepare_model_and_tokenizer,
5453
set_ulimit,
5554
)
@@ -248,7 +247,7 @@ def rerank(
248247
def shutdown(self):
249248
"""Shutdown the engine"""
250249
kill_process_tree(os.getpid(), include_parent=False)
251-
if pathways_available():
250+
if self.server_args.enable_single_process:
252251
self.send_to_rpc.close()
253252

254253
def __enter__(self):
@@ -410,7 +409,7 @@ def get_default_sampling_params(self) -> SamplingParams:
410409
return SamplingParams()
411410

412411

413-
def _set_envs_and_config():
412+
def _set_envs_and_config(server_args):
414413
# Set ulimit
415414
set_ulimit()
416415

@@ -434,7 +433,7 @@ def sigquit_handler(signum, frame):
434433
kill_process_tree(os.getpid())
435434

436435
signal.signal(signal.SIGQUIT, sigquit_handler)
437-
if not pathways_available():
436+
if not server_args.enable_single_process:
438437
# Set mp start method
439438
mp.set_start_method("spawn", force=True)
440439
else:
@@ -450,7 +449,7 @@ def _launch_subprocesses(
450449
# Configure global environment
451450
configure_logger(server_args)
452451
server_args.check_server_args()
453-
_set_envs_and_config()
452+
_set_envs_and_config(server_args)
454453

455454
# Allocate ports for inter-process communications
456455
if port_args is None:
@@ -632,7 +631,7 @@ def _launch_threads(
632631
def _launch_subprocesses_or_threads(
633632
server_args, port_args: Optional[PortArgs] = None
634633
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
635-
if pathways_available():
634+
if server_args.enable_single_process:
636635
return _launch_threads(server_args, port_args)
637636
else:
638637
return _launch_subprocesses(server_args, port_args)

python/sgl_jax/srt/managers/tp_worker_overlap_thread.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
from sgl_jax.srt.managers.utils import resolve_future_token_ids, set_future_token_ids
1919
from sgl_jax.srt.sampling.sampling_batch_info import SamplingMetadata
2020
from sgl_jax.srt.server_args import ServerArgs
21-
from sgl_jax.srt.utils import pathways_available
2221
from sgl_jax.utils import get_exception_traceback
2322

2423
logger = logging.getLogger(__name__)
@@ -52,7 +51,7 @@ def __init__(
5251
# JAX handles device execution automatically, no need for explicit streams
5352
self.forward_thread = threading.Thread(
5453
target=self.forward_thread_func,
55-
daemon=True if pathways_available() else False,
54+
daemon=True if server_args.enable_single_process else False,
5655
)
5756
self.forward_thread.start()
5857
self.parent_process = psutil.Process().parent()

python/sgl_jax/srt/server_args.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ class ServerArgs:
132132

133133
# For deterministic sampling
134134
enable_deterministic_sampling: bool = False
135+
enable_single_process: bool = False
135136

136137
def __post_init__(self):
137138
# Set missing default values
@@ -771,6 +772,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
771772
help="Enable deterministic sampling",
772773
)
773774

775+
parser.add_argument(
776+
"--enable-single-process",
777+
action = "store_true",
778+
help = "Enable run the engine with single process.",
779+
)
780+
774781
@classmethod
775782
def from_cli_args(cls, args: argparse.Namespace):
776783
args.tp_size = args.tensor_parallel_size

python/sgl_jax/srt/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,4 +13,3 @@
1313
set_ulimit,
1414
set_uvicorn_logging_configs,
1515
)
16-
from .tunix_utils import pathways_available

python/sgl_jax/srt/utils/tunix_utils.py

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)