From 5749e9d05163ec3b7856bc88fd16015f0b2ea034 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Wed, 15 Oct 2025 20:47:30 -0400 Subject: [PATCH 01/18] ruff all --- .pre-commit-config.yaml | 4 +- python/sgl_jax/bench_one_batch.py | 10 ++-- python/sgl_jax/bench_serving.py | 2 +- python/sgl_jax/check_env.py | 3 +- python/sgl_jax/profiler.py | 1 - python/sgl_jax/srt/entrypoints/EngineBase.py | 2 +- python/sgl_jax/srt/entrypoints/engine.py | 7 ++- python/sgl_jax/srt/entrypoints/http_server.py | 5 +- .../entrypoints/openai/serving_embedding.py | 4 +- .../srt/function_call/function_call_parser.py | 2 +- python/sgl_jax/srt/hf_transformers_utils.py | 2 +- python/sgl_jax/srt/jinja_template_utils.py | 1 - .../srt/layers/attention/base_attn_backend.py | 1 - .../flash_attn_kernel/flash_attention.py | 3 -- .../flash_attn_kernel/tuned_block_sizes.py | 2 - .../attention/flashattention_backend.py | 1 - .../srt/layers/attention/native_backend.py | 2 - .../srt/layers/gmm/megablox_gmm_kernel/gmm.py | 1 + python/sgl_jax/srt/layers/linear.py | 2 +- python/sgl_jax/srt/layers/logits_processor.py | 1 - python/sgl_jax/srt/layers/moe.py | 2 +- python/sgl_jax/srt/layers/radix_attention.py | 3 -- python/sgl_jax/srt/layers/sampler.py | 3 +- python/sgl_jax/srt/managers/schedule_batch.py | 7 +-- .../sgl_jax/srt/managers/schedule_policy.py | 2 - python/sgl_jax/srt/managers/scheduler.py | 7 ++- .../srt/managers/scheduler_metrics_mixin.py | 2 +- python/sgl_jax/srt/managers/tp_worker.py | 2 +- python/sgl_jax/srt/mem_cache/allocator.py | 1 - python/sgl_jax/srt/mem_cache/chunk_cache.py | 2 - python/sgl_jax/srt/mem_cache/radix_cache.py | 9 ++-- python/sgl_jax/srt/memory_profiler.py | 1 - .../srt/model_executor/forward_batch_info.py | 9 ++-- .../srt/model_executor/model_runner.py | 7 ++- python/sgl_jax/srt/models/llama.py | 4 +- python/sgl_jax/srt/models/qwen.py | 1 - python/sgl_jax/srt/models/qwen2.py | 1 - python/sgl_jax/srt/models/qwen3.py | 1 - python/sgl_jax/srt/models/qwen3_moe.py | 1 - python/sgl_jax/srt/precision_tracer.py | 6 +-- .../srt/sampling/sampling_batch_info.py | 1 - python/sgl_jax/srt/server_args.py | 2 +- python/sgl_jax/srt/utils/__init__.py | 1 + python/sgl_jax/srt/utils/common_utils.py | 6 +-- python/sgl_jax/srt/utils/jax_utils.py | 2 +- python/sgl_jax/srt/utils/tunix_utils.py | 9 ++++ python/sgl_jax/srt/utils/weight_utils.py | 3 +- .../sgl_jax/test/mem_cache/test_kv_cache.py | 5 -- .../test/mem_cache/test_radix_cache.py | 2 - .../test/model_executor/test_model_runner.py | 1 + python/sgl_jax/test/models/test_qwen_model.py | 17 +++--- python/sgl_jax/test/run_curl.py | 1 - python/sgl_jax/test/run_jax_loader_test.py | 9 ++-- python/sgl_jax/test/run_qwen3_moe_test.py | 17 +++--- python/sgl_jax/test/run_qwen_test.py | 17 +++--- python/sgl_jax/test/runners.py | 15 ------ python/sgl_jax/test/simple_eval_gpqa.py | 1 - python/sgl_jax/test/simple_eval_humaneval.py | 2 - python/sgl_jax/test/simple_eval_mgsm.py | 2 +- python/sgl_jax/test/test_flashattention.py | 17 +++--- python/sgl_jax/test/test_jax_model_loader.py | 4 +- python/sgl_jax/test/test_model_loader.py | 53 +++++++++---------- python/sgl_jax/test/test_sampler.py | 4 +- python/sgl_jax/test/test_utils.py | 2 - python/sgl_jax/tools/trace_diff.py | 4 +- python/sgl_jax/utils.py | 12 +---- 66 files changed, 130 insertions(+), 206 deletions(-) create mode 100644 python/sgl_jax/srt/utils/tunix_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d3a40d13a..be12f9ec1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,8 +28,8 
@@ repos: rev: v0.11.7 hooks: - id: ruff - args: [--select=F401, --fixable=F401] - files: ^(benchmark/|docs/|examples/) + args: ["--select=F401,F821", "--fixable=F401,F821"] + files: ^(python/|benchmark/|docs/|examples/) exclude: \.ipynb$ - repo: https://github.com/psf/black rev: 24.10.0 diff --git a/python/sgl_jax/bench_one_batch.py b/python/sgl_jax/bench_one_batch.py index f0377adee..c8d08cb8b 100644 --- a/python/sgl_jax/bench_one_batch.py +++ b/python/sgl_jax/bench_one_batch.py @@ -66,7 +66,7 @@ from sgl_jax.srt.sampling.sampling_batch_info import SamplingMetadata from sgl_jax.srt.sampling.sampling_params import SamplingParams from sgl_jax.srt.server_args import PortArgs, ServerArgs -from sgl_jax.srt.utils import configure_logger, get_bool_env_var, kill_process_tree +from sgl_jax.srt.utils import configure_logger, kill_process_tree @dataclasses.dataclass @@ -128,7 +128,7 @@ def load_model(server_args, port_args, tp_rank): # TODO: pass in tp_size # server_args.tp_size = 1 rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + # moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) model_config = ModelConfig.from_server_args(server_args) @@ -404,7 +404,6 @@ def latency_test_run_once( tot_latency = 0 - profiler = None if profile: profile_dir = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.tb" os.makedirs(profile_dir, exist_ok=True) @@ -469,10 +468,7 @@ def latency_test( bench_args, tp_rank, ): - # Set CPU affinity - if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): - set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank) - + # TODO: Fix this function # Configure the logger configure_logger(server_args, prefix=f" TP{tp_rank}") rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None diff --git a/python/sgl_jax/bench_serving.py b/python/sgl_jax/bench_serving.py index bed9f9170..486261a4d 100644 --- a/python/sgl_jax/bench_serving.py +++ b/python/sgl_jax/bench_serving.py @@ -1054,7 +1054,7 @@ def sample_generated_shared_prefix_requests( random.shuffle(input_requests) # Print statistics - print(f"\nGenerated shared prefix dataset statistics:") + print("\nGenerated shared prefix dataset statistics:") print(f"Number of groups: {num_groups}") print(f"Prompts per group: {prompts_per_group}") print(f"Total prompts: {len(input_requests)}") diff --git a/python/sgl_jax/check_env.py b/python/sgl_jax/check_env.py index d31a98300..47127ef0b 100644 --- a/python/sgl_jax/check_env.py +++ b/python/sgl_jax/check_env.py @@ -2,9 +2,8 @@ import importlib.metadata import resource -import subprocess import sys -from collections import OrderedDict, defaultdict +from collections import OrderedDict import jax diff --git a/python/sgl_jax/profiler.py b/python/sgl_jax/profiler.py index 3503ae7fc..d872ca320 100644 --- a/python/sgl_jax/profiler.py +++ b/python/sgl_jax/profiler.py @@ -9,7 +9,6 @@ import json import os import time -import urllib.parse from argparse import ArgumentParser from pathlib import Path from typing import List, Optional diff --git a/python/sgl_jax/srt/entrypoints/EngineBase.py b/python/sgl_jax/srt/entrypoints/EngineBase.py index e0704b4f3..73c9abc6d 100644 --- a/python/sgl_jax/srt/entrypoints/EngineBase.py +++ b/python/sgl_jax/srt/entrypoints/EngineBase.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union class 
EngineBase(ABC): diff --git a/python/sgl_jax/srt/entrypoints/engine.py b/python/sgl_jax/srt/entrypoints/engine.py index afa4c05d1..9778a5763 100644 --- a/python/sgl_jax/srt/entrypoints/engine.py +++ b/python/sgl_jax/srt/entrypoints/engine.py @@ -4,6 +4,8 @@ This file implements python APIs for the inference engine. """ +import json +import uvloop import asyncio import atexit import dataclasses @@ -17,13 +19,10 @@ import zmq import zmq.asyncio +# ruff: noqa: E402 # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) -import json - -import uvloop - from sgl_jax.srt.entrypoints.EngineBase import EngineBase from sgl_jax.srt.hf_transformers_utils import get_generation_config from sgl_jax.srt.managers.detokenizer_manager import ( diff --git a/python/sgl_jax/srt/entrypoints/http_server.py b/python/sgl_jax/srt/entrypoints/http_server.py index b3935b08d..641bd8a18 100644 --- a/python/sgl_jax/srt/entrypoints/http_server.py +++ b/python/sgl_jax/srt/entrypoints/http_server.py @@ -16,6 +16,7 @@ from http import HTTPStatus from typing import Any, AsyncIterator, Callable, Dict, List, Optional +# ruff: noqa: E402 # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) @@ -430,7 +431,7 @@ async def start_trace_async(obj: Optional[StartTraceReqInput] = None): ) precision_tracer.start_trace(req_num=obj.req_num, output_file=output_file) - logger.info(f"[HTTP] Sending trace state to scheduler...") + logger.info("[HTTP] Sending trace state to scheduler...") trace_state = { "precision_tracer": { "trace_active": True, @@ -479,7 +480,7 @@ async def stop_trace_async(obj: Optional[StopTraceReqInput] = None): """Stop precision tracing.""" try: output_file = precision_tracer.stop_trace() - print(f"[HTTP] Sending stop trace state to scheduler...") + print("[HTTP] Sending stop trace state to scheduler...") trace_state = { "precision_tracer": { "trace_active": False, diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py b/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py index c11ff9cc3..256927834 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py @@ -54,14 +54,14 @@ def _validate_request(self, request: EmbeddingRequest) -> Optional[str]: # List of strings for i, item in enumerate(input): if not isinstance(item, str): - return f"All items in input list must be strings" + return "All items in input list must be strings" if not item.strip(): return f"Input at index {i} cannot be empty or whitespace only" elif isinstance(first_item, int): # List of integers (token IDs) for i, item in enumerate(input): if not isinstance(item, int): - return f"All items in input list must be integers" + return "All items in input list must be integers" if item < 0: return f"Token ID at index {i} must be non-negative" return None diff --git a/python/sgl_jax/srt/function_call/function_call_parser.py b/python/sgl_jax/srt/function_call/function_call_parser.py index 3700b48a4..6744c7b56 100644 --- a/python/sgl_jax/srt/function_call/function_call_parser.py +++ b/python/sgl_jax/srt/function_call/function_call_parser.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) diff --git a/python/sgl_jax/srt/hf_transformers_utils.py b/python/sgl_jax/srt/hf_transformers_utils.py index 5b5955a92..f169a200e 100644 --- 
a/python/sgl_jax/srt/hf_transformers_utils.py +++ b/python/sgl_jax/srt/hf_transformers_utils.py @@ -123,7 +123,7 @@ def get_generation_config( return GenerationConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) - except OSError as e: + except OSError: return None diff --git a/python/sgl_jax/srt/jinja_template_utils.py b/python/sgl_jax/srt/jinja_template_utils.py index bba6849e9..daa89301c 100644 --- a/python/sgl_jax/srt/jinja_template_utils.py +++ b/python/sgl_jax/srt/jinja_template_utils.py @@ -4,7 +4,6 @@ """ import logging -from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) diff --git a/python/sgl_jax/srt/layers/attention/base_attn_backend.py b/python/sgl_jax/srt/layers/attention/base_attn_backend.py index 0c2a87469..587c69cc7 100644 --- a/python/sgl_jax/srt/layers/attention/base_attn_backend.py +++ b/python/sgl_jax/srt/layers/attention/base_attn_backend.py @@ -5,7 +5,6 @@ import jax from flax import nnx -from jax.sharding import Mesh if TYPE_CHECKING: from sgl_jax.srt.layers.radix_attention import RadixAttention diff --git a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py index eed38d669..bb7fce216 100644 --- a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py +++ b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py @@ -280,9 +280,6 @@ def _ragged_paged_attention_kernel( kv_packing, _, ) = kv_cache_fused_hbm_ref.shape - max_num_seqs = kv_lens_ref.shape[0] - num_page_indices = page_indices_ref.shape[0] - pages_per_seq = num_page_indices // max_num_seqs num_q_heads_per_kv_head = num_q_heads_per_kv_head_per_packing * q_packing q_dtype = q_hbm_ref.dtype kv_dtype = kv_cache_fused_hbm_ref.dtype diff --git a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py index a2b3c354c..a99fb8169 100644 --- a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py +++ b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py @@ -3,9 +3,7 @@ import jax.numpy as jnp from sgl_jax.srt.layers.attention.flash_attn_kernel.util import ( - align_to, get_device_name, - get_dtype_packing, get_tpu_version, next_power_of_2, ) diff --git a/python/sgl_jax/srt/layers/attention/flashattention_backend.py b/python/sgl_jax/srt/layers/attention/flashattention_backend.py index d55701b63..608cbfffb 100644 --- a/python/sgl_jax/srt/layers/attention/flashattention_backend.py +++ b/python/sgl_jax/srt/layers/attention/flashattention_backend.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Tuple import jax import jax.numpy as jnp diff --git a/python/sgl_jax/srt/layers/attention/native_backend.py b/python/sgl_jax/srt/layers/attention/native_backend.py index 2adcdcec9..9eeee07ff 100644 --- a/python/sgl_jax/srt/layers/attention/native_backend.py +++ b/python/sgl_jax/srt/layers/attention/native_backend.py @@ -2,7 +2,6 @@ import jax import jax.numpy as jnp -from jax.sharding import Mesh from jax.tree_util import register_pytree_node_class from sgl_jax.srt.layers.attention.base_attn_backend import AttentionBackend @@ -245,7 +244,6 @@ def _apply_extend_mask( Applies a block-diagonal and optionally a causal mask in a unified, efficient way, correctly handling padding. 
""" - batch_size = seq_lengths.shape[0] _, query_len, key_len = attn_weights.shape # --- Create validity masks to handle padding --- diff --git a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py index 42bcf7144..a33544582 100644 --- a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py +++ b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py @@ -425,6 +425,7 @@ def _accum(is_last_k_tile): mask_k_rem_lhs = partial(mask_k_rem, dim=1) mask_k_rem_rhs = partial(mask_k_rem, dim=int(transpose_rhs)) else: + # ruff: noqa: E731 mask_k_rem_lhs = lambda x: x mask_k_rem_rhs = lambda x: x diff --git a/python/sgl_jax/srt/layers/linear.py b/python/sgl_jax/srt/layers/linear.py index 6acc788ae..b847d93ae 100644 --- a/python/sgl_jax/srt/layers/linear.py +++ b/python/sgl_jax/srt/layers/linear.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union +from typing import Iterable, Optional, Sequence, Tuple import jax from flax import nnx diff --git a/python/sgl_jax/srt/layers/logits_processor.py b/python/sgl_jax/srt/layers/logits_processor.py index c46255731..e00daab12 100644 --- a/python/sgl_jax/srt/layers/logits_processor.py +++ b/python/sgl_jax/srt/layers/logits_processor.py @@ -1,5 +1,4 @@ import dataclasses -from functools import partial from typing import List, Optional import jax diff --git a/python/sgl_jax/srt/layers/moe.py b/python/sgl_jax/srt/layers/moe.py index 7e14fa15f..c4edba147 100644 --- a/python/sgl_jax/srt/layers/moe.py +++ b/python/sgl_jax/srt/layers/moe.py @@ -170,7 +170,7 @@ def _detect_device_capabilities(self): primary_device = device_types[0] if device_types else "unknown" return can_use_ragged, primary_device - except Exception as e: + except Exception as _: return False, "cpu" def __call__(self, inputs, router_logits=None): diff --git a/python/sgl_jax/srt/layers/radix_attention.py b/python/sgl_jax/srt/layers/radix_attention.py index 77f35f057..22d055684 100644 --- a/python/sgl_jax/srt/layers/radix_attention.py +++ b/python/sgl_jax/srt/layers/radix_attention.py @@ -3,10 +3,7 @@ from enum import Enum import jax -import jax.numpy as jnp from flax import nnx -from jax.sharding import NamedSharding -from jax.sharding import PartitionSpec as P from sgl_jax.srt.mem_cache.memory_pool import KVCache from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch diff --git a/python/sgl_jax/srt/layers/sampler.py b/python/sgl_jax/srt/layers/sampler.py index f64479c9d..28de0e167 100644 --- a/python/sgl_jax/srt/layers/sampler.py +++ b/python/sgl_jax/srt/layers/sampler.py @@ -1,5 +1,4 @@ -from functools import partial -from typing import List, Optional +from typing import List import jax import numpy as np diff --git a/python/sgl_jax/srt/managers/schedule_batch.py b/python/sgl_jax/srt/managers/schedule_batch.py index 2a79dfca3..50754f359 100644 --- a/python/sgl_jax/srt/managers/schedule_batch.py +++ b/python/sgl_jax/srt/managers/schedule_batch.py @@ -1,3 +1,4 @@ +# ruff: noqa: E402 from __future__ import annotations """ @@ -1338,9 +1339,9 @@ def _available_and_evictable_str(self) -> str: return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n" -def align_to_size(l: list, size: int, value: int = 0) -> list: - align_len = (len(l) + size - 1) // size * size - return l[:] + [value] * (align_len - len(l)) +def align_to_size(lst: list, size: int, value: int = 0) -> list: + align_len = (len(lst) + size - 1) // size * size + return lst[:] 
+ [value] * (align_len - len(lst)) @dataclasses.dataclass diff --git a/python/sgl_jax/srt/managers/schedule_policy.py b/python/sgl_jax/srt/managers/schedule_policy.py index 230defa0e..822d67e4f 100644 --- a/python/sgl_jax/srt/managers/schedule_policy.py +++ b/python/sgl_jax/srt/managers/schedule_policy.py @@ -1,7 +1,5 @@ from __future__ import annotations -"""Request scheduler policy""" - import os import random from collections import defaultdict diff --git a/python/sgl_jax/srt/managers/scheduler.py b/python/sgl_jax/srt/managers/scheduler.py index daaa47b5b..eba78e6c1 100644 --- a/python/sgl_jax/srt/managers/scheduler.py +++ b/python/sgl_jax/srt/managers/scheduler.py @@ -54,7 +54,6 @@ from sgl_jax.srt.mem_cache.radix_cache import RadixCache from sgl_jax.srt.model_executor.forward_batch_info import ForwardMode from sgl_jax.srt.precision_tracer import precision_tracer -from sgl_jax.srt.sampling.sampling_batch_info import SamplingMetadata from sgl_jax.srt.server_args import PortArgs, ServerArgs from sgl_jax.srt.utils.common_utils import ( configure_logger, @@ -307,9 +306,9 @@ def __init__( ) if not server_args.disable_jax_precompile: - logger.info(f"[Scheduler] Begins to run worker precompile.") + logger.info("[Scheduler] Begins to run worker precompile.") self.tp_worker.run_precompile() - logger.info(f"[Scheduler] Completes worker precompile.") + logger.info("[Scheduler] Completes worker precompile.") def sync_pub(self): logger.info( @@ -635,7 +634,7 @@ def set_internal_state(self, recv_req: SetInternalStateReq): precision_tracer._completed_requests_count = 0 precision_tracer._request_traces = {} logger.info( - f"[SCHEDULER] Reset request_counter, completed_count and traces" + "[SCHEDULER] Reset request_counter, completed_count and traces" ) if "max_requests" in tracer_config: diff --git a/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py b/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py index 238e25015..161091512 100644 --- a/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py +++ b/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py @@ -1,7 +1,7 @@ import logging import time from collections import defaultdict -from typing import List, Optional +from typing import List from sgl_jax.srt.managers.schedule_policy import PrefillAdder from sgl_jax.srt.managers.scheduler import Req, ScheduleBatch diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py index 7474fb18d..ada6e9ea6 100644 --- a/python/sgl_jax/srt/managers/tp_worker.py +++ b/python/sgl_jax/srt/managers/tp_worker.py @@ -106,7 +106,7 @@ def __init__( constraints = [server_limit, pool_limit, attn_backend_limit] self.max_running_requests = min(constraints) # Log each constraint for debugging - logger.info(f"Max running requests constraints:") + logger.info("Max running requests constraints:") logger.info( f" - Server limit: {server_limit} {'(max_total_tokens//2)' if server_args.max_running_requests is None else '(configured)'}" ) diff --git a/python/sgl_jax/srt/mem_cache/allocator.py b/python/sgl_jax/srt/mem_cache/allocator.py index 3cd36a22b..0c217c767 100644 --- a/python/sgl_jax/srt/mem_cache/allocator.py +++ b/python/sgl_jax/srt/mem_cache/allocator.py @@ -205,7 +205,6 @@ def alloc_extend( page_idx = 0 for seq_idx in range(batch_size): - seq_len = seq_lens_np[seq_idx] pre_len = prefix_lens_np[seq_idx] last_loc = last_loc_np[seq_idx] extend_len = extend_lens[seq_idx] diff --git a/python/sgl_jax/srt/mem_cache/chunk_cache.py b/python/sgl_jax/srt/mem_cache/chunk_cache.py index 
dce6d69d1..201c553ce 100644 --- a/python/sgl_jax/srt/mem_cache/chunk_cache.py +++ b/python/sgl_jax/srt/mem_cache/chunk_cache.py @@ -2,9 +2,7 @@ from typing import TYPE_CHECKING, Any, Optional -import jax import numpy as np -from jax import numpy as jnp from sgl_jax.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sgl_jax.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult diff --git a/python/sgl_jax/srt/mem_cache/radix_cache.py b/python/sgl_jax/srt/mem_cache/radix_cache.py index f6afe90a9..87c51c0e7 100644 --- a/python/sgl_jax/srt/mem_cache/radix_cache.py +++ b/python/sgl_jax/srt/mem_cache/radix_cache.py @@ -15,7 +15,7 @@ from sgl_jax.srt.mem_cache.memory_pool import ReqToTokenPool if TYPE_CHECKING: - from sgl_jax.srt.managers.schedule_batch import Req + pass class TreeNode: @@ -278,7 +278,6 @@ def get_cached_kv(self, token_ids: List[int]) -> Tuple[jnp.ndarray, int]: match_result = self.match_prefix(token_ids) matched_tokens = match_result.device_indices - last_node = match_result.last_device_node matched_len = len(matched_tokens) if matched_len == 0: @@ -325,9 +324,9 @@ def get_cached_kv(self, token_ids: List[int]) -> Tuple[jnp.ndarray, int]: k_data = jnp.stack( k_data_list, axis=0 ) # (layer_num, matched_len, head_num, head_dim) - v_data = jnp.stack( - v_data_list, axis=0 - ) # (layer_num, matched_len, head_num, head_dim) + # v_data = jnp.stack( + # v_data_list, axis=0 + # ) # (layer_num, matched_len, head_num, head_dim) # For this implementation, we return K data (could also return concatenated K,V) kv_data = k_data diff --git a/python/sgl_jax/srt/memory_profiler.py b/python/sgl_jax/srt/memory_profiler.py index 67d98dcb5..a8d56f079 100644 --- a/python/sgl_jax/srt/memory_profiler.py +++ b/python/sgl_jax/srt/memory_profiler.py @@ -7,7 +7,6 @@ from typing import Callable, Dict, List, Optional, Union import jax -import jax.numpy as jnp import jax.profiler logger = logging.getLogger(__name__) diff --git a/python/sgl_jax/srt/model_executor/forward_batch_info.py b/python/sgl_jax/srt/model_executor/forward_batch_info.py index c911b0543..cde2ed3a6 100644 --- a/python/sgl_jax/srt/model_executor/forward_batch_info.py +++ b/python/sgl_jax/srt/model_executor/forward_batch_info.py @@ -23,20 +23,17 @@ from typing import TYPE_CHECKING, List, Optional import jax - -logger = logging.getLogger(__name__) - from jax.sharding import NamedSharding, PartitionSpec - +from jax.tree_util import register_pytree_node_class from sgl_jax.srt.utils.jax_utils import device_array +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from sgl_jax.srt.layers.attention.base_attn_backend import AttentionBackend from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch from sgl_jax.srt.model_executor.model_runner import ModelRunner -from jax.tree_util import register_pytree_node_class - class ForwardMode(IntEnum): # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt). 
diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py index 3b9ee3947..54137edf2 100644 --- a/python/sgl_jax/srt/model_executor/model_runner.py +++ b/python/sgl_jax/srt/model_executor/model_runner.py @@ -13,13 +13,12 @@ from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P -from sgl_jax.srt.configs.load_config import LoadConfig, LoadFormat +from sgl_jax.srt.configs.load_config import LoadConfig from sgl_jax.srt.configs.model_config import AttentionArch, MockModelConfig, ModelConfig from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessorOutput from sgl_jax.srt.layers.sampler import Sampler from sgl_jax.srt.managers.schedule_batch import ( GLOBAL_SERVER_ARGS_KEYS, - ModelWorkerBatch, global_server_args_dict, ) from sgl_jax.srt.mem_cache.allocator import ( @@ -28,10 +27,10 @@ TokenToKVPoolAllocator, ) from sgl_jax.srt.mem_cache.memory_pool import MHATokenToKVPool, ReqToTokenPool -from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch from sgl_jax.srt.model_loader.loader import get_model_loader from sgl_jax.srt.precision_tracer import precision_tracer -from sgl_jax.srt.sampling.sampling_batch_info import SamplingBatchInfo, SamplingMetadata +from sgl_jax.srt.sampling.sampling_batch_info import SamplingMetadata from sgl_jax.srt.server_args import ServerArgs from sgl_jax.srt.utils.common_utils import get_bool_env_var from sgl_jax.srt.utils.jax_utils import get_available_device_memory diff --git a/python/sgl_jax/srt/models/llama.py b/python/sgl_jax/srt/models/llama.py index b93330f9e..9b1f76961 100644 --- a/python/sgl_jax/srt/models/llama.py +++ b/python/sgl_jax/srt/models/llama.py @@ -17,7 +17,7 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import jax import jax.numpy as jnp @@ -28,7 +28,6 @@ from sgl_jax.srt.layers.embeddings import ( Embed, ParallelLMHead, - RotaryEmbedding, get_rope, ) from sgl_jax.srt.layers.layernorm import RMSNorm @@ -36,7 +35,6 @@ from sgl_jax.srt.layers.logits_processor import ( LogitsMetadata, LogitsProcessor, - LogitsProcessorOutput, ) from sgl_jax.srt.layers.radix_attention import RadixAttention from sgl_jax.srt.mem_cache.memory_pool import KVCache diff --git a/python/sgl_jax/srt/models/qwen.py b/python/sgl_jax/srt/models/qwen.py index 4de8f531c..15ffcb2bc 100644 --- a/python/sgl_jax/srt/models/qwen.py +++ b/python/sgl_jax/srt/models/qwen.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp from flax import nnx -from jax.sharding import PartitionSpec from transformers import PretrainedConfig from sgl_jax.srt.configs.model_config import ModelConfig diff --git a/python/sgl_jax/srt/models/qwen2.py b/python/sgl_jax/srt/models/qwen2.py index ddce08b5a..7a821214a 100644 --- a/python/sgl_jax/srt/models/qwen2.py +++ b/python/sgl_jax/srt/models/qwen2.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp from flax import nnx -from jax import numpy as jnp from transformers import PretrainedConfig from sgl_jax.srt.configs.model_config import ModelConfig diff --git a/python/sgl_jax/srt/models/qwen3.py b/python/sgl_jax/srt/models/qwen3.py index a1a782939..16aa6130b 100644 --- a/python/sgl_jax/srt/models/qwen3.py +++ b/python/sgl_jax/srt/models/qwen3.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp from flax 
import nnx -from jax import numpy as jnp from transformers import PretrainedConfig from sgl_jax.srt.configs.model_config import ModelConfig diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py index 0efa5c71b..faafedb57 100644 --- a/python/sgl_jax/srt/models/qwen3_moe.py +++ b/python/sgl_jax/srt/models/qwen3_moe.py @@ -4,7 +4,6 @@ from flax import nnx from jax import jax from jax import numpy as jnp -from jax.sharding import get_abstract_mesh from transformers import PretrainedConfig from sgl_jax.srt.configs.model_config import ModelConfig diff --git a/python/sgl_jax/srt/precision_tracer.py b/python/sgl_jax/srt/precision_tracer.py index fb1f70e15..9c699cbec 100644 --- a/python/sgl_jax/srt/precision_tracer.py +++ b/python/sgl_jax/srt/precision_tracer.py @@ -143,10 +143,6 @@ def get_max_requests(self): with self.lock: return self._max_requests - def get_completed_requests_count(self): - with self.lock: - return self._completed_requests_count - def get_request_counter(self): with self.lock: return self._request_counter @@ -222,7 +218,7 @@ def start_trace( os.makedirs(os.path.dirname(self._trace_output_file), exist_ok=True) - with open(self._trace_output_file, "w") as f: + with open(self._trace_output_file, "w") as _: pass logger.info(f"Request tracing started. Output: {self._trace_output_file}") diff --git a/python/sgl_jax/srt/sampling/sampling_batch_info.py b/python/sgl_jax/srt/sampling/sampling_batch_info.py index 2c95dd276..1f41b4350 100644 --- a/python/sgl_jax/srt/sampling/sampling_batch_info.py +++ b/python/sgl_jax/srt/sampling/sampling_batch_info.py @@ -20,7 +20,6 @@ import jax import jax.numpy as jnp import numpy as np -from jax._src import mesh as mesh_lib logger = logging.getLogger(__name__) diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py index 52cec1584..c91e604f3 100644 --- a/python/sgl_jax/srt/server_args.py +++ b/python/sgl_jax/srt/server_args.py @@ -6,7 +6,7 @@ import logging import os import tempfile -from typing import List, Optional, Union +from typing import List, Optional import jax diff --git a/python/sgl_jax/srt/utils/__init__.py b/python/sgl_jax/srt/utils/__init__.py index 02312d5ff..e33ba6b06 100644 --- a/python/sgl_jax/srt/utils/__init__.py +++ b/python/sgl_jax/srt/utils/__init__.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 from .common_utils import ( add_api_key_middleware, cdiv, diff --git a/python/sgl_jax/srt/utils/common_utils.py b/python/sgl_jax/srt/utils/common_utils.py index 1693f2021..d09fc4d2d 100644 --- a/python/sgl_jax/srt/utils/common_utils.py +++ b/python/sgl_jax/srt/utils/common_utils.py @@ -22,7 +22,7 @@ import traceback from collections import OrderedDict from pathlib import Path -from typing import Any, Callable, Optional, Set, Union +from typing import Any, Callable, Optional, Set, Union, Sequence import numpy as np import psutil @@ -378,10 +378,10 @@ def retry( return fn() except Exception as e: if try_index >= max_retry: - raise Exception(f"retry() exceed maximum number of retries.") + raise Exception("retry() exceed maximum number of retries.") if not should_retry(e): - raise Exception(f"retry() observe errors that should not be retried.") + raise Exception("retry() observe errors that should not be retried.") delay = min(initial_delay * (2**try_index), max_delay) * ( 0.75 + 0.25 * random.random() diff --git a/python/sgl_jax/srt/utils/jax_utils.py b/python/sgl_jax/srt/utils/jax_utils.py index abce88a4b..2da16202b 100644 --- a/python/sgl_jax/srt/utils/jax_utils.py +++ 
b/python/sgl_jax/srt/utils/jax_utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import jax import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec +from jax.sharding import PartitionSpec def get_num_kv_heads_by_tp(total_num_kv_heads: int, tp_size: int) -> int: diff --git a/python/sgl_jax/srt/utils/tunix_utils.py b/python/sgl_jax/srt/utils/tunix_utils.py new file mode 100644 index 000000000..e837373a1 --- /dev/null +++ b/python/sgl_jax/srt/utils/tunix_utils.py @@ -0,0 +1,9 @@ +# refer to tunix +def pathways_available() -> bool: + try: + # ruff: noqa: F401 + import pathwaysutils + + return True + except ImportError: + return False diff --git a/python/sgl_jax/srt/utils/weight_utils.py b/python/sgl_jax/srt/utils/weight_utils.py index f6ff8c7d8..c7d2647f8 100644 --- a/python/sgl_jax/srt/utils/weight_utils.py +++ b/python/sgl_jax/srt/utils/weight_utils.py @@ -1,10 +1,9 @@ -import functools import glob import logging import math import os from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import jax import jax.numpy as jnp diff --git a/python/sgl_jax/test/mem_cache/test_kv_cache.py b/python/sgl_jax/test/mem_cache/test_kv_cache.py index d280dac47..c43830d2a 100644 --- a/python/sgl_jax/test/mem_cache/test_kv_cache.py +++ b/python/sgl_jax/test/mem_cache/test_kv_cache.py @@ -1,4 +1,3 @@ -import random import unittest import jax @@ -142,10 +141,6 @@ def test_kv_cache_update_page_size_1_with_padding(self): # Verify that padding tokens didn't affect the cache padding_mask = loc == -1 if jnp.any(padding_mask): - # Check that original cache values at padding positions are unchanged - original_k_cache = jnp.zeros_like(k_cache) # Original was all zeros - original_v_cache = jnp.zeros_like(v_cache) - # For positions that should be ignored (padding), cache should remain unchanged for i in range(total_tokens): if loc[i] == -1: diff --git a/python/sgl_jax/test/mem_cache/test_radix_cache.py b/python/sgl_jax/test/mem_cache/test_radix_cache.py index 0ede312d2..82370e850 100644 --- a/python/sgl_jax/test/mem_cache/test_radix_cache.py +++ b/python/sgl_jax/test/mem_cache/test_radix_cache.py @@ -10,7 +10,6 @@ os.environ["JAX_PLATFORMS"] = "cpu" import unittest -from unittest.mock import Mock import jax import jax.numpy as jnp @@ -740,7 +739,6 @@ def print_sharding(obj, name, prefix=""): def test_cache_finished_req_disabled(self): """test cache finished request disabled""" # create disabled cache - mesh = Mesh([self.devices[0]], axis_names=("tensor",)) disabled_cache = RadixCache( req_to_token_pool=self.req_pool, token_to_kv_pool_allocator=self.allocator, diff --git a/python/sgl_jax/test/model_executor/test_model_runner.py b/python/sgl_jax/test/model_executor/test_model_runner.py index 9ea4f886e..6b65082f3 100644 --- a/python/sgl_jax/test/model_executor/test_model_runner.py +++ b/python/sgl_jax/test/model_executor/test_model_runner.py @@ -1,5 +1,6 @@ import os +# ruff: noqa: E402 TP_SIZE = int(os.environ.get("TP_SIZE", 1)) os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={TP_SIZE}" import unittest diff --git a/python/sgl_jax/test/models/test_qwen_model.py b/python/sgl_jax/test/models/test_qwen_model.py index 9b91a9392..d2bcaa84f 100644 --- a/python/sgl_jax/test/models/test_qwen_model.py +++ b/python/sgl_jax/test/models/test_qwen_model.py @@ -1,7 +1,6 @@ import os import unittest from pathlib import Path -from unittest.mock import patch import jax.numpy as jnp from flax 
import nnx @@ -395,11 +394,11 @@ def test_qwen_model_forward(self, batch_size: int = None): input_texts = self._generate_random_questions(batch_size) print(f"\nGenerated {batch_size} random questions for batch testing") - print(f"\nBatch Configuration:") + print("\nBatch Configuration:") print(f" Total requests: {len(input_texts)}") print(f" Batch size: {len(input_texts)}") - print(f"\nSample questions:") + print("\nSample questions:") for i, text in enumerate(input_texts[: min(5, len(input_texts))]): print(f" {i+1}: '{text}'") if len(input_texts) > 5: @@ -414,7 +413,7 @@ def test_qwen_model_forward(self, batch_size: int = None): self._create_batch_from_texts(model.config, input_texts, tokenizer) ) - print(f"\n Batch Processing Info:") + print("\n Batch Processing Info:") print(f" Input tokens shape: {input_ids_array.shape}") print( f" Actual sequence lengths: {actual_seq_lens[:10]}{'...' if len(actual_seq_lens) > 10 else ''}" @@ -517,8 +516,8 @@ def test_qwen_model_forward(self, batch_size: int = None): len(r["output"]) for r in final_results.values() ) / len(final_results) - print(f"\n === Generation Results Summary ===") - print(f"Performance Metrics:") + print("\n === Generation Results Summary ===") + print("Performance Metrics:") print(f" Total time: {total_time:.2f} seconds") print(f" Requests processed: {len(input_texts)}") print(f" Requests finished: {finished_count}/{len(input_texts)}") @@ -528,7 +527,7 @@ def test_qwen_model_forward(self, batch_size: int = None): # Print detailed results for small batches if len(input_texts) <= 10: - print(f"\nDetailed Results:") + print("\nDetailed Results:") for i in range(len(input_texts)): result = final_results[i] status = ( @@ -541,7 +540,7 @@ def test_qwen_model_forward(self, batch_size: int = None): print(f" Output: '{result['output']}'") else: # Show only a few examples - print(f"\n Sample Results (first 3):") + print("\n Sample Results (first 3):") for i in range(min(3, len(input_texts))): result = final_results[i] status = " Finished" if result["finished"] else "Max iterations" @@ -559,7 +558,7 @@ def test_qwen_model_forward(self, batch_size: int = None): self.assertIsNotNone(result["output"]) self.assertTrue(len(result["output"]) >= len(result["input"])) - print(f"\n Batch test completed successfully!") + print("\n Batch test completed successfully!") return { "total_time": total_time, "throughput": len(input_texts) / total_time, diff --git a/python/sgl_jax/test/run_curl.py b/python/sgl_jax/test/run_curl.py index c0d86484e..ce59b621e 100644 --- a/python/sgl_jax/test/run_curl.py +++ b/python/sgl_jax/test/run_curl.py @@ -5,7 +5,6 @@ import argparse import json -import time import requests diff --git a/python/sgl_jax/test/run_jax_loader_test.py b/python/sgl_jax/test/run_jax_loader_test.py index 26c231cde..6b3b226d2 100644 --- a/python/sgl_jax/test/run_jax_loader_test.py +++ b/python/sgl_jax/test/run_jax_loader_test.py @@ -25,7 +25,7 @@ import os import subprocess import sys -import unittest +import importlib from pathlib import Path @@ -55,8 +55,8 @@ def check_jax_dependencies(): def check_sglang_dependencies(): """Check if SGLang dependencies are available""" try: - from sgl_jax.srt.configs.load_config import LoadFormat - from sgl_jax.srt.model_loader.loader import JAXModelLoader + importlib.util.find_spec("sgl_jax.srt.configs.load_config.LoadFormat") + importlib.util.find_spec("sgl_jax.srt.model_loader.loader.JAXModelLoader") print("✓ SGLang JAXModelLoader available") return True @@ -96,7 +96,6 @@ def run_tests(test_name=None, 
model_path=None, verbose=False): def create_sample_jax_model(output_dir): """Create a sample JAX model directory for testing""" import json - import tempfile model_dir = Path(output_dir) / "sample_jax_model" model_dir.mkdir(parents=True, exist_ok=True) @@ -175,7 +174,7 @@ def main(): if args.create_sample: try: model_path = create_sample_jax_model(args.create_sample) - print(f"\nYou can now run tests with:") + print("\nYou can now run tests with:") print(f"python {__file__} --model-path {model_path}") return 0 except Exception as e: diff --git a/python/sgl_jax/test/run_qwen3_moe_test.py b/python/sgl_jax/test/run_qwen3_moe_test.py index 98989e7b8..0ecaf3d0e 100644 --- a/python/sgl_jax/test/run_qwen3_moe_test.py +++ b/python/sgl_jax/test/run_qwen3_moe_test.py @@ -28,7 +28,7 @@ import os import subprocess import sys -import unittest +import importlib from pathlib import Path os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" @@ -39,13 +39,13 @@ def check_jax_dependencies(): try: import jax import jax.numpy as jnp - from flax import nnx + importlib.util.find_spec("flax.nnx") print(f"✓ JAX version: {jax.__version__}") print(f"✓ JAX backend: {jax.default_backend()}") print(f"✓ Available devices: {len(jax.devices())} devices") print(f"✓ Device types: {[d.platform for d in jax.devices()]}") - print(f"✓ Flax NNX available") + print("✓ Flax NNX available") # Test basic JAX operations x = jnp.array([1, 2, 3]) @@ -65,9 +65,9 @@ def check_jax_dependencies(): def check_sglang_dependencies(): """Check if SGLang dependencies are available""" try: - from sgl_jax.srt.configs.load_config import LoadFormat - from sgl_jax.srt.model_loader.loader import JAXModelLoader - from sgl_jax.srt.models.qwen3_moe import Qwen3MoeForCausalLMJaxModel + importlib.util.find_spec("sgl_jax.srt.configs.load_config.LoadFormat") + importlib.util.find_spec("sgl_jax.srt.model_loader.loader.JAXModelLoader") + importlib.util.find_spec("sgl_jax.srt.models.qwen3_moe.Qwen3MoeForCausalLMJaxModel") print("✓ SGLang JAXModelLoader available") print("✓ Qwen3MoeForCausalLMJaxModel available") @@ -82,7 +82,6 @@ def check_transformers_dependencies(): """Check if Transformers dependencies are available""" try: import transformers - from transformers import PretrainedConfig print(f"✓ Transformers version: {transformers.__version__}") return True @@ -271,7 +270,7 @@ def create_sample_qwen3_moe_model(output_dir): msgpack.pack(mock_weights, f) print(f"Created sample Qwen3 MoE JAX model at: {model_dir}") - print(f" - config.json: Model configuration with MoE settings") + print(" - config.json: Model configuration with MoE settings") print(f" - model.msgpack: Mock weights ({msgpack_file.stat().st_size} bytes)") print( f" - MoE config: {config['num_experts']} experts, {config['num_experts_per_tok']} experts per token" @@ -369,7 +368,7 @@ def main(): if args.create_sample: try: model_path = create_sample_qwen3_moe_model(args.create_sample) - print(f"\nYou can now run tests with:") + print("\nYou can now run tests with:") print(f"python {__file__} --model-path {model_path}") return 0 except Exception as e: diff --git a/python/sgl_jax/test/run_qwen_test.py b/python/sgl_jax/test/run_qwen_test.py index 45ee5ca39..d74ee587a 100644 --- a/python/sgl_jax/test/run_qwen_test.py +++ b/python/sgl_jax/test/run_qwen_test.py @@ -2,7 +2,7 @@ import os import subprocess import sys -import unittest +import importlib from pathlib import Path @@ -11,11 +11,11 @@ def check_jax_dependencies(): try: import jax import jax.numpy as jnp - from flax import nnx + 
importlib.util.find_spec("flax.nnx") print(f"✓ JAX version: {jax.__version__}") print(f"✓ JAX backend: {jax.default_backend()}") - print(f"✓ Flax NNX available") + print("✓ Flax NNX available") # Test basic JAX operations x = jnp.array([1, 2, 3]) @@ -35,9 +35,9 @@ def check_jax_dependencies(): def check_sglang_dependencies(): """Check if SGLang dependencies are available""" try: - from sgl_jax.srt.configs.load_config import LoadFormat - from sgl_jax.srt.model_loader.loader import JAXModelLoader - from sgl_jax.srt.models.qwen import QWenLMHeadJaxModel + importlib.util.find_spec("sgl_jax.srt.configs.load_config.LoadFormat") + importlib.util.find_spec("sgl_jax.srt.model_loader.loader.JAXModelLoader") + importlib.util.find_spec("sgl_jax.srt.models.qwen.QWenLMHeadJaxModel") print("✓ SGLang JAXModelLoader available") print("✓ QWenLMHeadJaxModel available") @@ -52,7 +52,6 @@ def check_transformers_dependencies(): """Check if Transformers dependencies are available""" try: import transformers - from transformers import PretrainedConfig print(f"✓ Transformers version: {transformers.__version__}") return True @@ -188,7 +187,7 @@ def create_sample_qwen_model(output_dir): msgpack.pack(mock_weights, f) print(f"Created sample QWen JAX model at: {model_dir}") - print(f" - config.json: Model configuration") + print(" - config.json: Model configuration") print(f" - model.msgpack: Mock weights ({msgpack_file.stat().st_size} bytes)") except ImportError: @@ -290,7 +289,7 @@ def main(): if args.create_sample: try: model_path = create_sample_qwen_model(args.create_sample) - print(f"\nYou can now run tests with:") + print("\nYou can now run tests with:") print(f"python {__file__} --model-path {model_path}") return 0 except Exception as e: diff --git a/python/sgl_jax/test/runners.py b/python/sgl_jax/test/runners.py index ae5122b24..37d25e4c4 100644 --- a/python/sgl_jax/test/runners.py +++ b/python/sgl_jax/test/runners.py @@ -12,24 +12,9 @@ # limitations under the License. # ============================================================================== -import multiprocessing as mp import os -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union -import transformers -from transformers import ( - AutoConfig, - AutoModel, - AutoModelForCausalLM, - AutoModelForVision2Seq, - AutoProcessor, - GenerationConfig, -) -from sgl_jax.srt.entrypoints.engine import Engine -from sgl_jax.srt.hf_transformers_utils import get_tokenizer -from sgl_jax.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l DEFAULT_PROMPTS = [ "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", diff --git a/python/sgl_jax/test/simple_eval_gpqa.py b/python/sgl_jax/test/simple_eval_gpqa.py index 3ad6cb39f..0a834a595 100644 --- a/python/sgl_jax/test/simple_eval_gpqa.py +++ b/python/sgl_jax/test/simple_eval_gpqa.py @@ -18,7 +18,6 @@ HTML_JINJA, Eval, EvalResult, - MessageList, SamplerBase, SingleEvalResult, format_multichoice_question, diff --git a/python/sgl_jax/test/simple_eval_humaneval.py b/python/sgl_jax/test/simple_eval_humaneval.py index e15f152b4..1113bf4f7 100644 --- a/python/sgl_jax/test/simple_eval_humaneval.py +++ b/python/sgl_jax/test/simple_eval_humaneval.py @@ -11,7 +11,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Optional -import tqdm try: from human_eval.data import read_problems @@ -41,7 +40,6 @@ def evaluate_functional_correctness( Evaluates the functional correctness of generated samples, and writes results to f"{sample_file}_results.jsonl.gz" """ - import copy # Check the generated samples against test suites. with ThreadPoolExecutor(max_workers=n_workers) as executor: diff --git a/python/sgl_jax/test/simple_eval_mgsm.py b/python/sgl_jax/test/simple_eval_mgsm.py index d88445c54..7dc6d32cb 100644 --- a/python/sgl_jax/test/simple_eval_mgsm.py +++ b/python/sgl_jax/test/simple_eval_mgsm.py @@ -175,7 +175,7 @@ def fn(example: dict[str, str]): ] try: response_text = sampler(prompt_messages) - except Exception as e: + except Exception: response_text = "" answer_prefix = LANG_TO_ANSWER_PREFIX[language] diff --git a/python/sgl_jax/test/test_flashattention.py b/python/sgl_jax/test/test_flashattention.py index 47c8f2fd5..8a5958eaa 100644 --- a/python/sgl_jax/test/test_flashattention.py +++ b/python/sgl_jax/test/test_flashattention.py @@ -13,7 +13,6 @@ from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch from sgl_jax.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sgl_jax.srt.model_executor.model_runner import ModelRunner from sgl_jax.srt.utils.mesh_utils import create_device_mesh from sgl_jax.test.test_utils import CustomTestCase @@ -42,7 +41,7 @@ def create_qkv_cache( page_size=1, ): batched_q_len = sum([q_len for q_len, _ in lens]) - batched_kv_len = sum([kv_len for _, kv_len in lens]) + # batched_kv_len = sum([kv_len for _, kv_len in lens]) # Calculate aligned batched_kv_len seq_lens = jnp.array([kv_len for _, kv_len in lens], dtype=jnp.int32) @@ -165,9 +164,9 @@ def create_test_data( q, k, v = create_qkv_cache(lens, num_heads, head_dim, num_kv_heads, page_size) # cache loc - match schedule_batch.py logic with align_to_size - def align_to_size(l, size, value=0): - align_len = (len(l) + size - 1) // size * size - return l + [value] * (align_len - len(l)) + def align_to_size(lst, size, value=0): + align_len = (len(lst) + size - 1) // size * size + return lst + [value] * (align_len - len(lst)) cache_loc_flat = [] current_aligned_pos = 0 # Track aligned position in k/v cache @@ -312,7 +311,7 @@ def run_test(self, mode, lens, mode_args): ) # Debug cache mapping - print(f"=== Cache Mapping Debug ===") + print("=== Cache Mapping Debug ===") print(f"lens: {lens}") print(f"seq_lens: {forward_batch.seq_lens}") print(f"cu_q_lens: {forward_batch.attn_backend.forward_metadata.cu_q_lens}") @@ -395,13 +394,13 @@ def jit_attn(q, k, v, forward_batch, token_to_kv_pool: KVCache): diff = np.abs(jax_flat - expected_flat) max_diff = np.max(diff) - print(f"=== Detailed Analysis ===") + print("=== Detailed 
Analysis ===") print(f"JAX output shape: {jax_flat.shape}") print(f"Expected shape: {expected_flat.shape}") print(f"Max difference: {max_diff}") # Analyze by token dimension (rows) - show only first 5 tokens - print(f"\n=== Token-wise Analysis (first 20 tokens) ===") + print("\n=== Token-wise Analysis (first 20 tokens) ===") num_tokens = jax_flat.shape[0] for i in range(min(num_tokens, 20)): jax_row = np.asarray(jax_flat[i]) @@ -418,7 +417,7 @@ def jit_attn(q, k, v, forward_batch, token_to_kv_pool: KVCache): print() # Overall statistics - print(f"=== Overall Statistics ===") + print("=== Overall Statistics ===") print( f"JAX output: mean={float(np.mean(jax_flat)):.6f}, std={float(np.std(jax_flat)):.6f}" ) diff --git a/python/sgl_jax/test/test_jax_model_loader.py b/python/sgl_jax/test/test_jax_model_loader.py index c134c0d1a..b5f9369e6 100644 --- a/python/sgl_jax/test/test_jax_model_loader.py +++ b/python/sgl_jax/test/test_jax_model_loader.py @@ -87,7 +87,7 @@ def _print_pytree_structure(self, pytree, prefix="", max_depth=10, current_depth memory_str = f"{memory_bytes / 1024:.2f} KB" else: memory_str = f"{memory_bytes} bytes" - except: + except Exception: memory_str = "unknown" stats_str = "" @@ -100,7 +100,7 @@ def _print_pytree_structure(self, pytree, prefix="", max_depth=10, current_depth stats_str = f", range=[{min_val:.6f}, {max_val:.6f}], mean={mean_val:.6f}" else: stats_str = f", range=[{min_val:.6f}, {max_val:.6f}]" - except: + except Exception: pass # Print tensor data preview diff --git a/python/sgl_jax/test/test_model_loader.py b/python/sgl_jax/test/test_model_loader.py index 9af0505ff..b273a11f9 100644 --- a/python/sgl_jax/test/test_model_loader.py +++ b/python/sgl_jax/test/test_model_loader.py @@ -2,7 +2,6 @@ import os import tempfile import unittest -from pathlib import Path # Set up multi-device simulation for tensor parallelism if os.environ.get("USE_DEVICE_TYPE") == "cpu": @@ -84,7 +83,7 @@ def test_multi_device_environment_setup(self): """Test that multi-device environment is properly configured.""" devices = jax.devices() - print(f" Environment validation:") + print(" Environment validation:") print(f" XLA_FLAGS: {os.environ.get('XLA_FLAGS', 'Not set')}") print(f" JAX_PLATFORMS: {os.environ.get('JAX_PLATFORMS', 'Not set')}") print(f" Detected devices: {len(devices)}") @@ -94,12 +93,12 @@ def test_multi_device_environment_setup(self): if "--xla_force_host_platform_device_count=8" in os.environ.get( "XLA_FLAGS", "" ): - print(f"PASS: Multi-device simulation properly configured") + print("PASS: Multi-device simulation properly configured") self.assertGreaterEqual( len(devices), 2, "Should have at least 2 simulated devices" ) else: - print(f"WARNING: Multi-device simulation not configured") + print("WARNING: Multi-device simulation not configured") # Test mesh creation with available devices if len(devices) >= 4: @@ -121,7 +120,7 @@ def test_multi_device_environment_setup(self): device.platform, "cpu", f"Expected CPU device, got {device.platform}" ) - print(f"PASS: Multi-device environment validation completed!") + print("PASS: Multi-device environment validation completed!") def test_sharding_configuration(self): """Test that sharding configuration works correctly.""" @@ -160,7 +159,7 @@ def test_sharding_configuration(self): "Array should be distributed across multiple devices", ) - print(f"PASS: Sharding configuration test completed!") + print("PASS: Sharding configuration test completed!") def tearDown(self): """Clean up test fixtures.""" @@ -259,7 +258,7 @@ def 
test_model_config_creation(self): self.assertGreater(model_config.hidden_size, 0) self.assertGreater(model_config.num_attention_heads, 0) - print(f"PASS: Model config created successfully:") + print("PASS: Model config created successfully:") print(f" Model path: {model_config.model_path}") print(f" Architecture: {model_config.hf_config.architectures}") print(f" Hidden size: {model_config.hidden_size}") @@ -286,7 +285,7 @@ def test_qwen_model_instantiation(self): self.assertEqual(model.mesh, self.mesh) self.assertTrue(hasattr(model, "load_weights")) - print(f"PASS: QWen model instantiated successfully") + print("PASS: QWen model instantiated successfully") except Exception as e: self.fail(f"Failed to instantiate QWen model: {e}") @@ -321,7 +320,7 @@ def test_weight_loading_process(self): # Print the actual parameter structure of the model try: params = nnx.state(model) - print(f" Model parameter structure:") + print(" Model parameter structure:") self._print_param_structure(params, "", max_depth=3) except Exception as e: print(f"WARNING: Could not extract model parameters: {e}") @@ -332,7 +331,7 @@ def test_weight_loading_process(self): # Attempt to load weights model.load_weights(jax.random.PRNGKey(42)) - print(f"PASS: Weight loading completed successfully!") + print("PASS: Weight loading completed successfully!") except Exception as e: # Print detailed error for debugging @@ -388,7 +387,7 @@ def test_model_actual_structure_debug(self): model_path=self.model_path, trust_remote_code=True, dtype="bfloat16" ) - print(f" Model Config Details:") + print(" Model Config Details:") print(f" Architecture: {model_config.hf_config.architectures}") print(f" Hidden size: {model_config.hidden_size}") print(f" Num layers: {model_config.hf_config.num_hidden_layers}") @@ -400,14 +399,14 @@ def test_model_actual_structure_debug(self): ) # Check what attributes the model actually has - print(f" Model Attributes:") + print(" Model Attributes:") for attr in dir(model): if not attr.startswith("_"): try: value = getattr(model, attr) if not callable(value): print(f" {attr}: {type(value)}") - except: + except Exception: pass except Exception as e: @@ -427,7 +426,7 @@ def test_full_model_loading_pipeline(self): # Create loader loader = get_model_loader(self.load_config, self.rng, self.mesh) - print(f"🔄 Testing full loading pipeline...") + print("🔄 Testing full loading pipeline...") # Test download_model (should be no-op for local path) loader.download_model(model_config) @@ -436,7 +435,7 @@ def test_full_model_loading_pipeline(self): model = loader.load_model(model_config=model_config) self.assertIsNotNone(model) - print(f"PASS: Full model loading pipeline completed!") + print("PASS: Full model loading pipeline completed!") state = nnx.state(model) def print_sharding(params, prefix=""): @@ -500,13 +499,13 @@ def test_model_parameter_structure_validation(self): model_config, model_config.dtype, self.rng, self.mesh ) - print(f" Validating Model Parameter Structure:") + print(" Validating Model Parameter Structure:") print(f" Model type: {type(model).__name__}") # Get model state try: params = nnx.state(model) - print(f" Full Model Parameter Structure:") + print(" Full Model Parameter Structure:") self._print_param_structure_detailed(params, "", max_depth=5) # Check key mappings we're expecting @@ -524,7 +523,7 @@ def test_model_parameter_structure_validation(self): "transformer.h.0.mlp.c_proj.weight", ] - print(f"\n Checking Expected Parameter Paths:") + print("\n Checking Expected Parameter Paths:") for path in 
expected_paths: try: param = self._get_param_by_path(params, path) @@ -625,7 +624,7 @@ def test_multi_device_tensor_parallelism(self): # Check if weights are properly sharded state = nnx.state(model) - print(f" Checking weight sharding across devices:") + print(" Checking weight sharding across devices:") # Check a few key parameters for sharding key_params = [ @@ -664,7 +663,7 @@ def test_multi_device_tensor_parallelism(self): except Exception as e: print(f" ERROR: Could not check {param_path}: {e}") - print(f"PASS: Multi-device tensor parallelism test completed!") + print("PASS: Multi-device tensor parallelism test completed!") except Exception as e: print(f"ERROR: Multi-device tensor parallelism test failed: {e}") @@ -696,13 +695,13 @@ def test_tensor_parallel_computation(self): model_path=self.model_path, trust_remote_code=True, dtype="bfloat16" ) - print(f"🔄 Testing tensor parallel computation...") + print("🔄 Testing tensor parallel computation...") # Create loader and load model loader = get_model_loader(self.load_config, self.rng, self.mesh) model = loader.load_model(model_config=model_config) - print(f"PASS: Model loaded for computation test") + print("PASS: Model loaded for computation test") # Create a simple test input # Note: We're not actually running forward pass here since we'd need @@ -727,7 +726,7 @@ def test_tensor_parallel_computation(self): ) print(f" 📊 Sum result sharding: {weight_sum.sharding}") - print(f"PASS: Tensor parallel computation test completed!") + print("PASS: Tensor parallel computation test completed!") except Exception as e: print(f"ERROR: Tensor parallel computation test failed: {e}") @@ -775,20 +774,20 @@ def tearDown(self): def test_nonexistent_model_path(self): """Test handling of nonexistent model path.""" - load_config = LoadConfig(load_format=LoadFormat.JAX) + LoadConfig(load_format=LoadFormat.JAX) with self.assertRaises(Exception): - model_config = ModelConfig( + ModelConfig( model_path="/nonexistent/path", trust_remote_code=True ) def test_empty_directory(self): """Test handling of empty model directory.""" - load_config = LoadConfig(load_format=LoadFormat.JAX) + LoadConfig(load_format=LoadFormat.JAX) # This should fail when trying to create ModelConfig with self.assertRaises(Exception): - model_config = ModelConfig( + ModelConfig( model_path=self.temp_dir, trust_remote_code=True # Empty directory ) diff --git a/python/sgl_jax/test/test_sampler.py b/python/sgl_jax/test/test_sampler.py index 57ba42ef3..fa9403f5b 100644 --- a/python/sgl_jax/test/test_sampler.py +++ b/python/sgl_jax/test/test_sampler.py @@ -12,8 +12,8 @@ class TestMultinomialWithSeed(unittest.TestCase): def test_deterministic_sampling_with_same_seed(self): """Test that same (inputs, seed) pair always yields the same sample.""" # Setup test data - batch_size = 4 - vocab_size = 10 + # batch_size = 4 + # vocab_size = 10 # Create logits that simulate different temperature scenarios flatter_distribution = jnp.array( diff --git a/python/sgl_jax/test/test_utils.py b/python/sgl_jax/test/test_utils.py index 7a53a1ef1..2c6646cc5 100644 --- a/python/sgl_jax/test/test_utils.py +++ b/python/sgl_jax/test/test_utils.py @@ -3,7 +3,6 @@ import logging import os import re -import shutil import signal import subprocess import sys @@ -15,7 +14,6 @@ from typing import Awaitable, Callable, Optional, Sequence import jax -import jax.numpy as jnp import numpy as np import psutil import requests diff --git a/python/sgl_jax/tools/trace_diff.py b/python/sgl_jax/tools/trace_diff.py index fc0d3e736..94d9ebc19 
 100755
--- a/python/sgl_jax/tools/trace_diff.py
+++ b/python/sgl_jax/tools/trace_diff.py
@@ -1,7 +1,7 @@
 import argparse
 import json
 import sys
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Tuple


 class Colors:
@@ -640,7 +640,6 @@ def compare_trace_files(
     for content_hash in sorted(only_in_1):
         traces = groups1[content_hash]
         trace = traces[0]
-        records = trace.get("precision_records", [])
         print(
             f" {Colors.YELLOW}-{Colors.RESET} {content_hash} (Request ID: {trace.get('request_id', 'N/A')})"
         )
@@ -650,7 +649,6 @@
     for content_hash in sorted(only_in_2):
         traces = groups2[content_hash]
         trace = traces[0]
-        records = trace.get("precision_records", [])
         print(
             f" {Colors.YELLOW}+{Colors.RESET} {content_hash} (Request ID: {trace.get('request_id', 'N/A')})"
         )
diff --git a/python/sgl_jax/utils.py b/python/sgl_jax/utils.py
index 83e44c246..4cc9defb4 100644
--- a/python/sgl_jax/utils.py
+++ b/python/sgl_jax/utils.py
@@ -1,15 +1,7 @@
 import logging
-import os
-import subprocess
 import traceback
-from io import BytesIO
-from typing import Any, Callable, List, Tuple, Type, Union
-
-import psutil
-import pybase64
-import requests
-import zmq
-from PIL import Image
+from typing import Any, Callable, List, Tuple, Type
+


 logger = logging.getLogger(__name__)

From 92005b7930c3e565c8bf81a4f22e134ebe126887 Mon Sep 17 00:00:00 2001
From: chhzh123
Date: Wed, 15 Oct 2025 21:06:28 -0400
Subject: [PATCH 02/18] Update precommit script

---
 script.sh | 9 +++++++++
 1 file changed, 9 insertions(+)
 create mode 100644 script.sh

diff --git a/script.sh b/script.sh
new file mode 100644
index 000000000..e98723f6b
--- /dev/null
+++ b/script.sh
@@ -0,0 +1,9 @@
+sky launch config.yaml -y --use-spot --infra=gcp -i 5 --down
+ssh sky-466b-chhzh123
+conda create -n sglang python=3.12
+conda activate sglang
+git clone https://github.com/sgl-project/sglang-jax.git
+cd sglang-jax
+git checkout feat/grok
+pip install --upgrade pip setuptools packaging
+pip install -e "python[all]"
\ No newline at end of file

From 90b68a77d8e18d8ccfd7a1e1765fa6279cd723a7 Mon Sep 17 00:00:00 2001
From: chhzh123
Date: Wed, 15 Oct 2025 21:07:41 -0400
Subject: [PATCH 03/18] Fix

---
 config.yaml | 17 +++++++++++++++++
 script.sh | 9 ---------
 2 files changed, 17 insertions(+), 9 deletions(-)
 create mode 100644 config.yaml
 delete mode 100644 script.sh

diff --git a/config.yaml b/config.yaml
new file mode 100644
index 000000000..14ff8032e
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,17 @@
+resources:
+  accelerators: tpu-v6e-4 #
+  accelerator_args:
+    tpu_vm: True
+    runtime_version: v2-alpha-tpuv6e # optional
+file_mounts:
+  ~/.ssh/id_rsa: ~/.ssh/id_rsa
+setup: |
+  chmod 600 ~/.ssh/id_rsa
+  rm ~/.ssh/config
+  # GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" git clone git@github.com:Furion-cn/sgl-jax.git
+  # cd /home/gcpuser/sky_workdir/sgl-jax && uv pip install -e python/
+
+# run: |
+# SGL_JAX_USE_JIT=1 JAX_COMPILATION_CACHE_DIR=/tmp/jax_compilation_cache
+# uv run python -u -m sgl_jax.launch_server --model-path Qwen/Qwen-7B --engine-
+# type jax --trust-remote-code --skip-server-warmup --dist-init-addr=0.0.0.0:10011 --nnodes=1 --tp-size=4 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.1 --max-prefill-tokens=4096 --download-dir=/tmp/ --kv-cache-dtype=bf16
\ No newline at end of file

diff --git a/script.sh b/script.sh
deleted file mode 100644
index e98723f6b..000000000
--- a/script.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-sky launch config.yaml -y --use-spot --infra=gcp -i 5 --down
-ssh sky-466b-chhzh123
-conda create -n sglang python=3.12
-conda activate sglang
-git clone https://github.com/sgl-project/sglang-jax.git
-cd sglang-jax
-git checkout feat/grok
-pip install --upgrade pip setuptools packaging
-pip install -e "python[all]"
\ No newline at end of file

From eaec004ecb1a77524c3a92b3b962ae938f3535a0 Mon Sep 17 00:00:00 2001
From: chhzh123
Date: Thu, 16 Oct 2025 04:33:11 +0000
Subject: [PATCH 04/18] Update ruff

---
 .pre-commit-config.yaml | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index be12f9ec1..21a62d62e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,12 +25,13 @@ repos:
         args: ["--profile=black"]
         exclude: ^python/sgl_jax/test/run_eval\.py$
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.7
+    rev: v0.13.3
     hooks:
-      - id: ruff
-        args: ["--select=F401,F821", "--fixable=F401,F821"]
+      - id: ruff-check
+        args: [--output-format, github, --fix]
         files: ^(python/|benchmark/|docs/|examples/)
         exclude: \.ipynb$
+      - id: ruff-format
   - repo: https://github.com/psf/black
     rev: 24.10.0
     hooks:

From dcb443f79bcc9bb671cb9a2c172c740dd11308ec Mon Sep 17 00:00:00 2001
From: chhzh123
Date: Thu, 16 Oct 2025 04:35:12 +0000
Subject: [PATCH 05/18] Fix

---
 config.yaml | 17 -----------------
 python/sgl_jax/test/test_model_loader.py | 7 +++----
 2 files changed, 3 insertions(+), 21 deletions(-)
 delete mode 100644 config.yaml

diff --git a/config.yaml b/config.yaml
deleted file mode 100644
index 14ff8032e..000000000
--- a/config.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-resources:
-  accelerators: tpu-v6e-4 #
-  accelerator_args:
-    tpu_vm: True
-    runtime_version: v2-alpha-tpuv6e # optional
-file_mounts:
-  ~/.ssh/id_rsa: ~/.ssh/id_rsa
-setup: |
-  chmod 600 ~/.ssh/id_rsa
-  rm ~/.ssh/config
-  # GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" git clone git@github.com:Furion-cn/sgl-jax.git
-  # cd /home/gcpuser/sky_workdir/sgl-jax && uv pip install -e python/
-
-# run: |
-# SGL_JAX_USE_JIT=1 JAX_COMPILATION_CACHE_DIR=/tmp/jax_compilation_cache
-# uv run python -u -m sgl_jax.launch_server --model-path Qwen/Qwen-7B --engine-
-# type jax --trust-remote-code --skip-server-warmup --dist-init-addr=0.0.0.0:10011 --nnodes=1 --tp-size=4 --device=tpu --random-seed=3 --node-rank=0 --mem-fraction-static=0.1 --max-prefill-tokens=4096 --download-dir=/tmp/ --kv-cache-dtype=bf16
\ No newline at end of file

diff --git a/python/sgl_jax/test/test_model_loader.py b/python/sgl_jax/test/test_model_loader.py
index b273a11f9..ac53bf44f 100644
--- a/python/sgl_jax/test/test_model_loader.py
+++ b/python/sgl_jax/test/test_model_loader.py
@@ -777,9 +777,7 @@ def test_nonexistent_model_path(self):
         LoadConfig(load_format=LoadFormat.JAX)

         with self.assertRaises(Exception):
-            ModelConfig(
-                model_path="/nonexistent/path", trust_remote_code=True
-            )
+            ModelConfig(model_path="/nonexistent/path", trust_remote_code=True)

     def test_empty_directory(self):
         """Test handling of empty model directory."""
@@ -788,7 +786,8 @@ def test_empty_directory(self):
         # This should fail when trying to create ModelConfig
         with self.assertRaises(Exception):
             ModelConfig(
-                model_path=self.temp_dir, trust_remote_code=True # Empty directory
+                model_path=self.temp_dir,
+                trust_remote_code=True, # Empty directory
             )

From fbb2cf124604957a2ec0eaf3df19a12409f23ed2 Mon Sep 17 00:00:00 2001
From: chhzh123
Date: Fri, 17 Oct 2025 03:48:28 +0000
Subject: [PATCH 06/18] Format

---
.pre-commit-config.yaml | 1 - .../flash_attention/bench_flashattention.py | 2 +- .../flash_attention/get_block_spec_config.py | 4 +- .../megablox_gmm/bench_megablox_gmm.py | 2 +- .../update_kv_cache/bench_update_kv_cache.py | 4 +- python/pyproject.toml | 39 ++ python/sgl_jax/bench_offline_throughput.py | 21 +- python/sgl_jax/bench_one_batch.py | 22 +- python/sgl_jax/bench_one_batch_server.py | 9 +- python/sgl_jax/bench_serving.py | 129 ++++--- python/sgl_jax/check_env.py | 4 +- python/sgl_jax/profiler.py | 17 +- python/sgl_jax/srt/configs/load_config.py | 11 +- python/sgl_jax/srt/configs/model_config.py | 68 ++-- python/sgl_jax/srt/conversation.py | 17 +- python/sgl_jax/srt/entrypoints/EngineBase.py | 36 +- python/sgl_jax/srt/entrypoints/engine.py | 108 +++--- python/sgl_jax/srt/entrypoints/http_server.py | 84 +++-- .../srt/entrypoints/openai/protocol.py | 343 +++++++++--------- .../srt/entrypoints/openai/serving_base.py | 18 +- .../srt/entrypoints/openai/serving_chat.py | 56 +-- .../entrypoints/openai/serving_completions.py | 24 +- .../entrypoints/openai/serving_embedding.py | 8 +- .../srt/entrypoints/openai/serving_rerank.py | 15 +- .../srt/entrypoints/openai/serving_score.py | 3 +- .../srt/entrypoints/openai/usage_processor.py | 9 +- .../sgl_jax/srt/entrypoints/openai/utils.py | 11 +- .../srt/function_call/function_call_parser.py | 12 +- python/sgl_jax/srt/hf_transformers_utils.py | 36 +- .../flash_attn_kernel/flash_attention.py | 18 +- .../flash_attn_kernel/tuned_block_sizes.py | 9 +- .../attention/flashattention_backend.py | 7 +- .../srt/layers/attention/native_backend.py | 17 +- python/sgl_jax/srt/layers/embeddings.py | 23 +- .../srt/layers/gmm/megablox_gmm_kernel/gmm.py | 14 +- python/sgl_jax/srt/layers/linear.py | 10 +- python/sgl_jax/srt/layers/logits_processor.py | 41 +-- python/sgl_jax/srt/layers/moe.py | 12 +- python/sgl_jax/srt/layers/sampler.py | 6 +- .../srt/managers/detokenizer_manager.py | 21 +- python/sgl_jax/srt/managers/io_struct.py | 168 ++++----- python/sgl_jax/srt/managers/schedule_batch.py | 119 +++--- .../sgl_jax/srt/managers/schedule_policy.py | 36 +- python/sgl_jax/srt/managers/scheduler.py | 121 +++--- .../srt/managers/scheduler_metrics_mixin.py | 5 +- .../scheduler_output_processor_mixin.py | 34 +- .../srt/managers/scheduler_profiler_mixing.py | 25 +- .../sgl_jax/srt/managers/template_manager.py | 15 +- .../sgl_jax/srt/managers/tokenizer_manager.py | 168 +++++---- python/sgl_jax/srt/managers/tp_worker.py | 53 +-- .../srt/managers/tp_worker_overlap_thread.py | 7 +- python/sgl_jax/srt/managers/utils.py | 9 +- python/sgl_jax/srt/mem_cache/allocator.py | 21 +- .../srt/mem_cache/base_prefix_cache.py | 12 +- python/sgl_jax/srt/mem_cache/chunk_cache.py | 4 +- python/sgl_jax/srt/mem_cache/memory_pool.py | 139 ++----- python/sgl_jax/srt/mem_cache/radix_cache.py | 28 +- python/sgl_jax/srt/memory_profiler.py | 48 +-- .../srt/model_executor/forward_batch_info.py | 11 +- .../srt/model_executor/model_runner.py | 63 ++-- python/sgl_jax/srt/model_loader/arch.py | 7 +- python/sgl_jax/srt/model_loader/loader.py | 10 +- python/sgl_jax/srt/models/llama.py | 19 +- python/sgl_jax/srt/models/qwen.py | 6 +- python/sgl_jax/srt/models/qwen2.py | 13 +- python/sgl_jax/srt/models/qwen3.py | 13 +- python/sgl_jax/srt/models/qwen3_moe.py | 12 +- python/sgl_jax/srt/models/registry.py | 21 +- python/sgl_jax/srt/precision_tracer.py | 91 ++--- python/sgl_jax/srt/reasoning_parser.py | 11 +- .../srt/sampling/penaltylib/orchestrator.py | 12 +- .../srt/sampling/sampling_batch_info.py | 18 +- 
.../sgl_jax/srt/sampling/sampling_params.py | 38 +- python/sgl_jax/srt/server_args.py | 82 ++--- python/sgl_jax/srt/utils/common_utils.py | 69 ++-- python/sgl_jax/srt/utils/jax_utils.py | 1 - python/sgl_jax/srt/utils/mesh_utils.py | 2 +- python/sgl_jax/srt/utils/weight_utils.py | 103 +++--- .../sgl_jax/test/mem_cache/test_kv_cache.py | 2 +- .../test/model_executor/test_model_runner.py | 2 +- python/sgl_jax/test/models/test_qwen_model.py | 10 +- python/sgl_jax/test/run_jax_loader_test.py | 2 +- python/sgl_jax/test/run_qwen3_moe_test.py | 7 +- python/sgl_jax/test/run_qwen_test.py | 3 +- python/sgl_jax/test/runners.py | 4 +- python/sgl_jax/test/simple_eval_common.py | 57 +-- python/sgl_jax/test/simple_eval_gpqa.py | 3 +- python/sgl_jax/test/simple_eval_humaneval.py | 12 +- python/sgl_jax/test/simple_eval_math.py | 3 +- python/sgl_jax/test/simple_eval_mgsm.py | 3 +- python/sgl_jax/test/simple_eval_mmlu.py | 3 +- python/sgl_jax/test/test_flashattention.py | 11 +- python/sgl_jax/test/test_model_loader.py | 19 +- .../test/test_multi_process_radix_cache.py | 10 +- python/sgl_jax/test/test_sampler.py | 1 - python/sgl_jax/test/test_utils.py | 44 +-- python/sgl_jax/tools/trace_diff.py | 53 ++- python/sgl_jax/utils.py | 6 +- .../openai_server/basic/test_serving_chat.py | 9 +- .../test_openai_server_params_validation.py | 1 - test/srt/test_bench_one_batch.py | 1 - test/srt/test_eval_accuracy_large.py | 6 +- test/srt/test_features.py | 1 - 103 files changed, 1566 insertions(+), 1612 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21a62d62e..26c4173fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,6 @@ repos: args: [--output-format, github, --fix] files: ^(python/|benchmark/|docs/|examples/) exclude: \.ipynb$ - - id: ruff-format - repo: https://github.com/psf/black rev: 24.10.0 hooks: diff --git a/benchmark/kernels/flash_attention/bench_flashattention.py b/benchmark/kernels/flash_attention/bench_flashattention.py index 1683e1bfb..45d28e098 100644 --- a/benchmark/kernels/flash_attention/bench_flashattention.py +++ b/benchmark/kernels/flash_attention/bench_flashattention.py @@ -226,7 +226,7 @@ def main(): except Exception as e: raise ValueError(f"run failed: {e=}") - print(f"cost: {flash_time*1000}ms") + print(f"cost: {flash_time * 1000}ms") if __name__ == "__main__": diff --git a/benchmark/kernels/flash_attention/get_block_spec_config.py b/benchmark/kernels/flash_attention/get_block_spec_config.py index 9bdc1c090..ce7fa546f 100644 --- a/benchmark/kernels/flash_attention/get_block_spec_config.py +++ b/benchmark/kernels/flash_attention/get_block_spec_config.py @@ -202,7 +202,7 @@ def main(): block_spec_configs.append((num_kv_pages_per_blk, num_queries_per_block)) print( - f"(q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens): (num_kv_pages_per_block, num_queries_per_block)" + "(q_dtype, kv_dtype, num_q_heads_per_blk, num_kv_heads_per_blk, head_dim, page_size, max_num_batched_tokens): (num_kv_pages_per_block, num_queries_per_block)" ) for i, ( @@ -237,7 +237,7 @@ def main(): if flash_time < best_output: best_output = flash_time best_config = (num_kv_pages_per_blk, num_queries_per_block) - except Exception as e: + except Exception: pass print( diff --git a/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py b/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py index 8f7212633..e2e04e665 100644 --- a/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py +++ 
b/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py @@ -138,7 +138,7 @@ def main(): ) print( - f"Config {valid_config_count}: m={adjusted_m}, k={k}, n={n}, groups={num_groups}, group_size={adjusted_m//num_groups}" + f"Config {valid_config_count}: m={adjusted_m}, k={k}, n={n}, groups={num_groups}, group_size={adjusted_m // num_groups}" ) try: diff --git a/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py b/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py index a60d84844..e26651e49 100644 --- a/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py +++ b/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py @@ -161,11 +161,11 @@ def main(): min_cost = cost fastest_num_slices_per_block = num_slices_per_block print( - f"[num_slices_per_block={num_slices_per_block}] avg cost: {cost*1000} ms" + f"[num_slices_per_block={num_slices_per_block}] avg cost: {cost * 1000} ms" ) print( - f"Fastest [num_slices_per_block={fastest_num_slices_per_block}] costs: {min_cost*1000} ms" + f"Fastest [num_slices_per_block={fastest_num_slices_per_block}] costs: {min_cost * 1000} ms" ) diff --git a/python/pyproject.toml b/python/pyproject.toml index 1938187ae..a490e9485 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -66,3 +66,42 @@ exclude = [ "scripts*", "tests*", ] + +[tool.black] +line-length = 100 + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort (handled by pre-commit isort hook) + # "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # line too long + "E501", + # lambda expression assignment + "E731", + # zip without `strict=` + "B905", + # Loop control variable not used within loop body + "B007", +] + +[tool.ruff.format] +docstring-code-format = true diff --git a/python/sgl_jax/bench_offline_throughput.py b/python/sgl_jax/bench_offline_throughput.py index 5087c877b..08ca400ec 100644 --- a/python/sgl_jax/bench_offline_throughput.py +++ b/python/sgl_jax/bench_offline_throughput.py @@ -19,7 +19,6 @@ import os import random import time -from typing import Dict, List, Optional import numpy as np @@ -41,8 +40,8 @@ class BenchArgs: dataset_name: str = "sharegpt" dataset_path: str = "" num_prompts: int = 1000 - sharegpt_output_len: Optional[int] = None - sharegpt_context_len: Optional[int] = None + sharegpt_output_len: int | None = None + sharegpt_context_len: int | None = None random_input_len: int = 1024 random_output_len: int = 1024 random_range_ratio: float = 0.0 @@ -53,7 +52,7 @@ class BenchArgs: gsp_output_len: int = 256 seed: int = 1 disable_ignore_eos: bool = False - extra_request_body: Optional[str] = None + extra_request_body: str | None = None apply_chat_template: bool = False profile: bool = False skip_warmup: bool = False @@ -110,15 +109,13 @@ def add_cli_args(parser: argparse.ArgumentParser): "--random-range-ratio", type=float, default=BenchArgs.random_range_ratio, - help="Range of sampled ratio of input/output length, " - "used only for random dataset.", + help="Range of sampled ratio of input/output length, used only for random dataset.", ) parser.add_argument( "--gsp-num-groups", type=int, default=BenchArgs.gsp_num_groups, - help="Number of groups with shared prefix, used" - "only for generate-shared-prefix", + help="Number of groups with shared prefix, usedonly for generate-shared-prefix", ) parser.add_argument( "--gsp-prompts-per-group", @@ -131,13 +128,13 @@ def 
add_cli_args(parser: argparse.ArgumentParser): "--gsp-system-prompt-len", type=int, default=BenchArgs.gsp_system_prompt_len, - help="System prompt length, used" "only for generate-shared-prefix", + help="System prompt length, usedonly for generate-shared-prefix", ) parser.add_argument( "--gsp-question-len", type=int, default=BenchArgs.gsp_question_len, - help="Question length, used" "only for generate-shared-prefix", + help="Question length, usedonly for generate-shared-prefix", ) parser.add_argument( "--gsp-output-len", @@ -196,9 +193,9 @@ def from_cli_args(cls, args: argparse.Namespace): def throughput_test_once( backend_name: str, backend, - reqs: List[DatasetRow], + reqs: list[DatasetRow], ignore_eos: bool, - extra_request_body: Dict, + extra_request_body: dict, profile: bool, ): measurement_results = { diff --git a/python/sgl_jax/bench_one_batch.py b/python/sgl_jax/bench_one_batch.py index c8d08cb8b..0260ccd19 100644 --- a/python/sgl_jax/bench_one_batch.py +++ b/python/sgl_jax/bench_one_batch.py @@ -49,7 +49,6 @@ import logging import os import time -from typing import Tuple import jax import numpy as np @@ -72,9 +71,9 @@ @dataclasses.dataclass class BenchArgs: run_name: str = "default" - batch_size: Tuple[int] = (1,) - input_len: Tuple[int] = (1024,) - output_len: Tuple[int] = (16,) + batch_size: tuple[int] = (1,) + input_len: tuple[int] = (1024,) + output_len: tuple[int] = (16,) result_filename: str = "result.jsonl" correctness_test: bool = False # This is only used for correctness test @@ -156,9 +155,9 @@ def load_model(server_args, port_args, tp_rank): if tp > 1: try: jax_mh.sync_global_devices("load_model") - except Exception as e: + except Exception as err: logging.info( - f"Could not sync global devices (expected in single-host): {e}" + "Could not sync global devices (expected in single-host): %s", err ) return model_runner, tokenizer @@ -560,7 +559,11 @@ def main(server_args, bench_args): tokens_needed = (tokens_needed // page) * page server_args.max_total_tokens = max(tokens_needed, page) logging.info( - f"Setting max_total_tokens={server_args.max_total_tokens} (bs={bs_max}, in={in_max}, out={out_max}) to limit static KV memory on single TPU" + "Setting max_total_tokens=%s (bs=%s, in=%s, out=%s) to limit static KV memory on single TPU", + server_args.max_total_tokens, + bs_max, + in_max, + out_max, ) # Prefer native attention on single-TPU runs to avoid large FA compile-time temps @@ -577,10 +580,7 @@ def main(server_args, bench_args): _set_envs_and_config() if server_args.model_path: - if bench_args.correctness_test: - work_func = correctness_test - else: - work_func = latency_test + work_func = correctness_test if bench_args.correctness_test else latency_test else: raise ValueError( "Provide --model-path for running the tests or " diff --git a/python/sgl_jax/bench_one_batch_server.py b/python/sgl_jax/bench_one_batch_server.py index 1680c229a..a4a5da29f 100644 --- a/python/sgl_jax/bench_one_batch_server.py +++ b/python/sgl_jax/bench_one_batch_server.py @@ -18,7 +18,6 @@ import multiprocessing import os import time -from typing import Tuple import requests @@ -33,9 +32,9 @@ @dataclasses.dataclass class BenchArgs: run_name: str = "default" - batch_size: Tuple[int] = (1,) - input_len: Tuple[int] = (1024,) - output_len: Tuple[int] = (16,) + batch_size: tuple[int] = (1,) + input_len: tuple[int] = (1024,) + output_len: tuple[int] = (16,) temperature: float = 0.0 return_logprob: bool = False client_stream_interval: int = 1 @@ -372,7 +371,7 @@ def run_benchmark(server_args: 
ServerArgs, bench_args: BenchArgs): f"{input_throughput:.2f} | " f"{output_throughput:.2f} | " f"{accept_length} | " - f"{1 / (output_throughput/batch_size) * 1000:.2f} | " + f"{1 / (output_throughput / batch_size) * 1000:.2f} | " f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | " f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |" ) diff --git a/python/sgl_jax/bench_serving.py b/python/sgl_jax/bench_serving.py index 486261a4d..7e5cbea99 100644 --- a/python/sgl_jax/bench_serving.py +++ b/python/sgl_jax/bench_serving.py @@ -22,11 +22,12 @@ import traceback import warnings from argparse import ArgumentParser +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from datetime import datetime from json import JSONDecodeError from pathlib import Path -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union +from typing import Any import aiohttp import numpy as np @@ -72,7 +73,7 @@ class RequestFuncInput: model: str lora_name: str image_data: str - extra_request_body: Dict[str, Any] + extra_request_body: dict[str, Any] @dataclass @@ -82,7 +83,7 @@ class RequestFuncOutput: latency: float = 0.0 ttft: float = 0.0 # Time to first token # List of inter-token latencies - itl: List[float] = field(default_factory=list) + itl: list[float] = field(default_factory=list) prompt_len: int = 0 error: str = "" output_len: int = 0 @@ -102,7 +103,7 @@ def remove_suffix(text: str, suffix: str) -> str: return text[: -len(suffix)] if text.endswith(suffix) else text -def get_auth_headers() -> Dict[str, str]: +def get_auth_headers() -> dict[str, str]: api_key = os.environ.get("OPENAI_API_KEY") if api_key: return {"Authorization": f"Bearer {api_key}"} @@ -114,7 +115,7 @@ def get_auth_headers() -> Dict[str, str]: # https://github.com/triton-inference-server/tensorrtllm_backend/issues/505 async def async_request_trt_llm( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith("generate_stream") @@ -183,7 +184,7 @@ async def async_request_trt_llm( # set ignore_eos True by default async def async_request_openai_completions( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( @@ -268,7 +269,7 @@ async def async_request_openai_completions( async def async_request_truss( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url @@ -346,7 +347,7 @@ async def async_request_truss( async def async_request_sglang_generate( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url prompt = request_func_input.prompt @@ -445,7 +446,7 @@ async def async_request_sglang_generate( async def async_request_gserver( request_func_input: RequestFuncInput, - pbar: Optional[tqdm] = None, + pbar: tqdm | None = None, ) -> RequestFuncOutput: raise NotImplementedError() @@ -485,7 +486,7 @@ def get_model(pretrained_model_name_or_path: str) -> str: def get_tokenizer( pretrained_model_name_or_path: str, -) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: +) -> PreTrainedTokenizer | PreTrainedTokenizerFast: assert ( pretrained_model_name_or_path is not None and pretrained_model_name_or_path != "" @@ -603,7 +604,7 @@ class 
BenchmarkMetrics: SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" -def download_and_cache_file(url: str, filename: Optional[str] = None): +def download_and_cache_file(url: str, filename: str | None = None): """Read and cache a file from a url.""" if filename is None: filename = os.path.join("/tmp", url.split("/")[-1]) @@ -623,13 +624,16 @@ def download_and_cache_file(url: str, filename: Optional[str] = None): chunk_size = 1024 # Download in chunks of 1KB # Use tqdm to display the progress bar - with open(filename, "wb") as f, tqdm( - desc=filename, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: + with ( + open(filename, "wb") as f, + tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar, + ): for chunk in response.iter_content(chunk_size=chunk_size): f.write(chunk) bar.update(len(chunk)) @@ -657,15 +661,15 @@ class DatasetRow: prompt: str prompt_len: int output_len: int - image_data: Optional[str] = None + image_data: str | None = None def sample_mmmu_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, + fixed_output_len: int | None = None, random_sample: bool = True, -) -> List[DatasetRow]: +) -> list[DatasetRow]: """ Sample requests from the MMMU dataset using HuggingFace datasets. @@ -683,8 +687,8 @@ def sample_mmmu_requests( import io from datasets import load_dataset - except ImportError: - raise ImportError("Please install datasets: pip install datasets") + except ImportError as err: + raise ImportError("Please install datasets: pip install datasets") from err print("Loading MMMU dataset from HuggingFace...") @@ -694,9 +698,9 @@ def sample_mmmu_requests( print( f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples" ) - except Exception as e: - print(f"Failed to load MMMU Math dataset: {e}") - raise ValueError(f"Failed to load MMMU dataset: {e}") + except Exception as err: + print(f"Failed to load MMMU Math dataset: {err}") + raise ValueError(f"Failed to load MMMU dataset: {err}") from err # Sample from the dataset if len(mmmu_dataset) > num_requests: @@ -791,11 +795,11 @@ def sample_sharegpt_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, - fixed_output_len: Optional[int] = None, - context_len: Optional[int] = None, - prompt_suffix: Optional[str] = "", + fixed_output_len: int | None = None, + context_len: int | None = None, + prompt_suffix: str | None = "", apply_chat_template=False, -) -> List[DatasetRow]: +) -> list[DatasetRow]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -826,7 +830,7 @@ def sample_sharegpt_requests( random.shuffle(dataset) # Filter out sequences that are too long or too short - filtered_dataset: List[DatasetRow] = [] + filtered_dataset: list[DatasetRow] = [] for i in range(len(dataset)): if len(filtered_dataset) == num_requests: break @@ -882,7 +886,7 @@ def sample_random_requests( dataset_path: str, random_sample: bool = True, return_text: bool = True, -) -> List[DatasetRow]: +) -> list[DatasetRow]: input_lens = np.random.randint( max(int(input_len * range_ratio), 1), input_len + 1, @@ -922,7 +926,7 @@ def sample_random_requests( random.shuffle(dataset) # Filter out sequences that are too long or too short - input_requests: List[DatasetRow] = [] + input_requests: list[DatasetRow] = [] for data in 
dataset: i = len(input_requests) if i == num_prompts: @@ -1004,7 +1008,7 @@ def sample_generated_shared_prefix_requests( output_len: int, tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace, -) -> List[DatasetRow]: +) -> list[DatasetRow]: """Generate benchmark requests with shared system prompts using random tokens and caching.""" cache_path = get_gen_prefix_cache_path(args, tokenizer) @@ -1077,7 +1081,7 @@ def sample_generated_shared_prefix_requests( async def get_request( - input_requests: List[DatasetRow], + input_requests: list[DatasetRow], request_rate: float, ) -> AsyncGenerator[DatasetRow, None]: input_requests = iter(input_requests) @@ -1095,20 +1099,20 @@ async def get_request( def calculate_metrics( - input_requests: List[DatasetRow], - outputs: List[RequestFuncOutput], + input_requests: list[DatasetRow], + outputs: list[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, backend: str, -) -> Tuple[BenchmarkMetrics, List[int]]: - output_lens: List[int] = [] - retokenized_output_lens: List[int] = [] +) -> tuple[BenchmarkMetrics, list[int]]: + output_lens: list[int] = [] + retokenized_output_lens: list[int] = [] total_input = 0 completed = 0 - itls: List[float] = [] - tpots: List[float] = [] - ttfts: List[float] = [] - e2e_latencies: List[float] = [] + itls: list[float] = [] + tpots: list[float] = [] + ttfts: list[float] = [] + e2e_latencies: list[float] = [] for i in range(len(outputs)): if outputs[i].success: output_len = outputs[i].output_len @@ -1179,12 +1183,12 @@ async def benchmark( base_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, - input_requests: List[DatasetRow], + input_requests: list[DatasetRow], request_rate: float, - max_concurrency: Optional[int], + max_concurrency: int | None, disable_tqdm: bool, - lora_names: List[str], - extra_request_body: Dict[str, Any], + lora_names: list[str], + extra_request_body: dict[str, Any], profile: bool, pd_separated: bool = False, flush_cache: bool = False, @@ -1211,10 +1215,9 @@ async def limited_request_func(request_func_input, pbar): # Use the first request for all warmup iterations test_request = input_requests[0] - if lora_names is not None and len(lora_names) != 0: - lora_name = lora_names[0] - else: - lora_name = None + lora_name = ( + lora_names[0] if lora_names is not None and len(lora_names) != 0 else None + ) # Create the test input once test_input = RequestFuncInput( @@ -1267,13 +1270,13 @@ async def limited_request_func(request_func_input, pbar): # Run all requests benchmark_start_time = time.perf_counter() - tasks: List[asyncio.Task] = [] + tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - if lora_names is not None and len(lora_names) != 0: - idx = random.randint(0, len(lora_names) - 1) - lora_name = lora_names[idx] - else: - lora_name = None + lora_name = ( + lora_names[random.randint(0, len(lora_names) - 1)] + if lora_names is not None and len(lora_names) != 0 + else None + ) request_func_input = RequestFuncInput( model=model_id, @@ -1291,7 +1294,7 @@ async def limited_request_func(request_func_input, pbar): limited_request_func(request_func_input=request_func_input, pbar=pbar) ) ) - outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) # Stop profiler if profile: @@ -1458,10 +1461,7 @@ async def limited_request_func(request_func_input, pbar): # Append results to a JSONL file with open(output_file_name, "a") as file: - if args.output_details: - 
result_for_dump = result | result_details - else: - result_for_dump = result + result_for_dump = result | result_details if args.output_details else result file.write(json.dumps(result_for_dump) + "\n") return result | result_details @@ -1726,8 +1726,7 @@ def __call__(self, parser, namespace, values, option_string=None): "--random-range-ratio", type=float, default=0.0, - help="Range of sampled ratio of input/output length, " - "used only for random dataset.", + help="Range of sampled ratio of input/output length, used only for random dataset.", ) parser.add_argument( "--request-rate", diff --git a/python/sgl_jax/check_env.py b/python/sgl_jax/check_env.py index 47127ef0b..62f43e55c 100644 --- a/python/sgl_jax/check_env.py +++ b/python/sgl_jax/check_env.py @@ -56,9 +56,7 @@ def get_device_info(): if len(device_list) == 0: return {"Device": "no device found"} for i, device in enumerate(device_list): - if "TPU" in device.device_kind: - device_info[f"[{device.device_kind}-{i}]"] = f"{device}" - elif "cpu" in device.device_kind: + if "TPU" in device.device_kind or "cpu" in device.device_kind: device_info[f"[{device.device_kind}-{i}]"] = f"{device}" else: raise ValueError(f"invalid device kind: {device.device_kind}") diff --git a/python/sgl_jax/profiler.py b/python/sgl_jax/profiler.py index d872ca320..c7fd18dd0 100644 --- a/python/sgl_jax/profiler.py +++ b/python/sgl_jax/profiler.py @@ -11,7 +11,6 @@ import time from argparse import ArgumentParser from pathlib import Path -from typing import List, Optional import requests @@ -19,11 +18,11 @@ def _run_profile( - url: Optional[str], + url: str | None, num_steps: int, - activities: List[str], - output_dir: Optional[str] = None, - profile_name: Optional[str] = None, + activities: list[str], + output_dir: str | None = None, + profile_name: str | None = None, profile_by_stage: bool = False, ) -> str: if output_dir is None: @@ -70,11 +69,11 @@ def _run_profile( def run_profile( - url: Optional[str], + url: str | None, num_steps: int, - activities: List[str], - output_dir: Optional[str] = None, - profile_name: Optional[str] = None, + activities: list[str], + output_dir: str | None = None, + profile_name: str | None = None, profile_by_stage: bool = False, ): # step based profile will self terminate on num_steps constraints diff --git a/python/sgl_jax/srt/configs/load_config.py b/python/sgl_jax/srt/configs/load_config.py index 1d99ac867..75d730c0a 100644 --- a/python/sgl_jax/srt/configs/load_config.py +++ b/python/sgl_jax/srt/configs/load_config.py @@ -2,7 +2,6 @@ import json import logging from dataclasses import dataclass, field -from typing import List, Optional, Union logger = logging.getLogger(__name__) @@ -45,11 +44,11 @@ class LoadConfig: from this file (after PBKDF2). 
""" - load_format: Union[str, LoadFormat] = LoadFormat.AUTO - download_dir: Optional[str] = None - model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) - ignore_patterns: Optional[Union[List[str], str]] = None - decryption_key_file: Optional[str] = None + load_format: str | LoadFormat = LoadFormat.AUTO + download_dir: str | None = None + model_loader_extra_config: str | dict | None = field(default_factory=dict) + ignore_patterns: list[str] | str | None = None + decryption_key_file: str | None = None def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} diff --git a/python/sgl_jax/srt/configs/model_config.py b/python/sgl_jax/srt/configs/model_config.py index 60c9028e7..f0d90dd6d 100644 --- a/python/sgl_jax/srt/configs/model_config.py +++ b/python/sgl_jax/srt/configs/model_config.py @@ -2,7 +2,6 @@ import logging import os from enum import Enum, IntEnum, auto -from typing import List, Optional, Set, Union import jax.numpy as jnp from transformers import PretrainedConfig @@ -35,18 +34,17 @@ def __init__( self, model_path: str, trust_remote_code: bool = True, - revision: Optional[str] = None, - context_length: Optional[int] = None, + revision: str | None = None, + context_length: int | None = None, model_override_args: str = "{}", - is_embedding: Optional[bool] = None, + is_embedding: bool | None = None, dtype: str = "auto", - override_config_file: Optional[str] = None, + override_config_file: str | None = None, is_draft_model: bool = False, - model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, - quantization: Optional[str] = None, - model_layer_nums: Optional[int] = None, + model_impl: str | ModelImpl = ModelImpl.AUTO, + quantization: str | None = None, + model_layer_nums: int | None = None, ) -> None: - self.model_path = model_path self.revision = revision self.model_impl = model_impl @@ -102,15 +100,14 @@ def __init__( "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True" ): logger.warning( - f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " - f"This may lead to incorrect model outputs or CUDA errors." + "Warning: User-specified context_length (%s) is greater than the derived context_length (%s). This may lead to incorrect model outputs or CUDA errors.", + context_length, + derived_context_len, ) self.context_len = context_length else: raise ValueError( - f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). " - f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. " - f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" + f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1" ) else: self.context_len = context_length @@ -150,8 +147,9 @@ def __init__( ) if model_layer_nums > self.num_hidden_layers: logger.warning( - f"model_layer_nums ({model_layer_nums}) is greater than the original " - f"num_hidden_layers ({self.num_hidden_layers}). Using original value." 
+ "model_layer_nums (%s) is greater than the original num_hidden_layers (%s). Using original value.", + model_layer_nums, + self.num_hidden_layers, ) elif model_layer_nums != self.num_hidden_layers: self.num_hidden_layers = model_layer_nums @@ -299,7 +297,7 @@ def configure_for_tensor_parallel(self, tensor_parallel_size: int): self.hf_text_config.num_key_value_heads = total_kv_heads else: # For MHA models, dynamically add the attribute - setattr(self.hf_text_config, "num_key_value_heads", total_kv_heads) + self.hf_text_config.num_key_value_heads = total_kv_heads def get_original_kv_head_id(self, tp_rank: int, tensor_parallel_size: int) -> int: """Determine which original KV head this device should use.""" @@ -335,18 +333,22 @@ def log_kv_heads_info(self, tensor_parallel_size: int): if needs_replication: num_replicas = self.get_num_kv_head_replicas(tensor_parallel_size) logger.info( - f"KV heads replication enabled for {model_type} model: " - f"original_kv_heads={original_kv_heads}, tp_size={tensor_parallel_size}, " - f"each device gets {kv_heads_per_device} head(s), " - f"each original head replicated {num_replicas} times, " - f"padding_strategy={padding_strategy}" + "KV heads replication enabled for %s model: original_kv_heads=%s, tp_size=%s, each device gets %s head(s), each original head replicated %s times, padding_strategy=%s", + model_type, + original_kv_heads, + tensor_parallel_size, + kv_heads_per_device, + num_replicas, + padding_strategy, ) else: logger.info( - f"KV heads distribution for {model_type} model: " - f"original_kv_heads={original_kv_heads}, tp_size={tensor_parallel_size}, " - f"each device gets {kv_heads_per_device} head(s), no replication needed, " - f"padding_strategy={padding_strategy}" + "KV heads distribution for %s model: original_kv_heads=%s, tp_size=%s, each device gets %s head(s), no replication needed, padding_strategy=%s", + model_type, + original_kv_heads, + tensor_parallel_size, + kv_heads_per_device, + padding_strategy, ) def validate_tensor_parallel_config(self, tensor_parallel_size: int): @@ -390,7 +392,7 @@ def _parse_quant_hf_config(self): quant_cfg = modelopt_quant_config return quant_cfg - def get_hf_eos_token_id(self) -> Optional[Set[int]]: + def get_hf_eos_token_id(self) -> set[int] | None: eos_ids = getattr(self.hf_config, "eos_token_id", None) if eos_ids: # it can be either int or list of int @@ -439,11 +441,11 @@ def maybe_pull_model_tokenizer_from_remote(self) -> None: def _get_and_verify_dtype( config: PretrainedConfig, - dtype: Union[str, jnp.dtype], + dtype: str | jnp.dtype, ) -> jnp.dtype: config_dtype = getattr(config, "torch_dtype", None) if isinstance(config_dtype, str): - config_dtype = _STR_DTYPE_TO_JAX_DTYPE.get(config_dtype, None) + config_dtype = _STR_DTYPE_TO_JAX_DTYPE.get(config_dtype) elif config_dtype is not None: config_dtype = _STR_DTYPE_TO_JAX_DTYPE.get( str(config_dtype).replace("torch.", ""), None @@ -458,8 +460,8 @@ def _get_and_verify_dtype( jax_dtype = config_dtype if config_dtype != jnp.bfloat16: logger.warning( - f"Model dtype is {config_dtype}. " - "On TPU, using non-bfloat16 models may reduce performance." + "Model dtype is %s. On TPU, using non-bfloat16 models may reduce performance.", + config_dtype, ) else: if dtype not in _STR_DTYPE_TO_JAX_DTYPE: @@ -483,10 +485,10 @@ def _get_and_verify_dtype( else: # Casting between float16 and bfloat16 is allowed with a warning. 
logger.warning("Casting %s to %s.", config_dtype, jax_dtype) - return jax_dtype + return jax_dtype -def is_generation_model(model_architectures: List[str], is_embedding: bool = False): +def is_generation_model(model_architectures: list[str], is_embedding: bool = False): # We have two ways to determine whether a model is a generative model. # 1. Check the model architecture # 2. check the `is_embedding` server args diff --git a/python/sgl_jax/srt/conversation.py b/python/sgl_jax/srt/conversation.py index 98e9525f9..15b7ba9c4 100644 --- a/python/sgl_jax/srt/conversation.py +++ b/python/sgl_jax/srt/conversation.py @@ -4,15 +4,16 @@ """ import logging -from typing import Any, Callable, Dict, List, Optional +from collections.abc import Callable +from typing import Any logger = logging.getLogger(__name__) def generate_chat_conv( - messages: List[Dict[str, Any]], + messages: list[dict[str, Any]], tokenizer: Any = None, - chat_template: Optional[str] = None, + chat_template: str | None = None, ) -> str: """ Generate a conversation from chat messages. @@ -25,7 +26,7 @@ def generate_chat_conv( Returns: Formatted conversation string """ - logger.info(f"Generating chat conversation from {len(messages)} messages") + logger.info("Generating chat conversation from %s messages", len(messages)) # Basic implementation - just concatenate messages conversation_parts = [] @@ -35,7 +36,7 @@ def generate_chat_conv( conversation_parts.append(f"{role}: {content}") conversation = "\n".join(conversation_parts) - logger.debug(f"Generated conversation: {conversation[:100]}...") + logger.debug("Generated conversation: %s...", conversation[:100]) return conversation @@ -43,7 +44,7 @@ def generate_chat_conv( class Conversation: """Basic conversation class for handling chat interactions.""" - def __init__(self, messages: List[Dict[str, Any]]): + def __init__(self, messages: list[dict[str, Any]]): self.messages = messages self.history = [] @@ -63,8 +64,8 @@ def clear(self): # A global registry for all conversation templates -chat_templates: Dict[str, Conversation] = {} -matching_function_registry: List[Callable] = [] +chat_templates: dict[str, Conversation] = {} +matching_function_registry: list[Callable] = [] def register_conv_template(template: Conversation, override: bool = False): diff --git a/python/sgl_jax/srt/entrypoints/EngineBase.py b/python/sgl_jax/srt/entrypoints/EngineBase.py index 73c9abc6d..f1e45e067 100644 --- a/python/sgl_jax/srt/entrypoints/EngineBase.py +++ b/python/sgl_jax/srt/entrypoints/EngineBase.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, Iterator, List, Optional, Union +from collections.abc import Iterator class EngineBase(ABC): @@ -11,23 +11,23 @@ class EngineBase(ABC): @abstractmethod def generate( self, - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - image_data: Optional[Union[List[str], str]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None, - custom_logit_processor: Optional[Union[List[str], str]] = None, - return_hidden_states: Optional[bool] = None, - stream: Optional[bool] = None, - bootstrap_host: Optional[Union[List[str], str]] = None, - 
bootstrap_port: Optional[Union[List[int], int]] = None, - bootstrap_room: Optional[Union[List[int], int]] = None, - data_parallel_rank: Optional[int] = None, - ) -> Union[Dict, Iterator[Dict]]: + prompt: list[str] | str | None = None, + sampling_params: list[dict] | dict | None = None, + input_ids: list[list[int]] | list[int] | None = None, + image_data: list[str] | str | None = None, + return_logprob: list[bool] | bool | None = False, + logprob_start_len: list[int] | int | None = None, + top_logprobs_num: list[int] | int | None = None, + token_ids_logprob: list[list[int]] | list[int] | None = None, + lora_path: list[str | None] | str | None | None = None, + custom_logit_processor: list[str] | str | None = None, + return_hidden_states: bool | None = None, + stream: bool | None = None, + bootstrap_host: list[str] | str | None = None, + bootstrap_port: list[int] | int | None = None, + bootstrap_room: list[int] | int | None = None, + data_parallel_rank: int | None = None, + ) -> dict | Iterator[dict]: """Generate outputs based on given inputs.""" pass diff --git a/python/sgl_jax/srt/entrypoints/engine.py b/python/sgl_jax/srt/entrypoints/engine.py index 9778a5763..f8de4c8d8 100644 --- a/python/sgl_jax/srt/entrypoints/engine.py +++ b/python/sgl_jax/srt/entrypoints/engine.py @@ -4,24 +4,25 @@ This file implements python APIs for the inference engine. """ -import json -import uvloop import asyncio import atexit import dataclasses +import json import logging import multiprocessing as mp import os import signal import threading -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union +from collections.abc import AsyncIterator, Iterator +from typing import Any +import uvloop import zmq import zmq.asyncio # ruff: noqa: E402 # Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +threading._register_atexit = lambda *args, **kwargs: None from sgl_jax.srt.entrypoints.EngineBase import EngineBase from sgl_jax.srt.hf_transformers_utils import get_generation_config @@ -91,7 +92,7 @@ def __init__(self, **kwargs): # Allocate ports for inter-process communications self.port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") + logger.info("server_args=%s", server_args) # Launch subprocesses or threads tokenizer_manager, template_manager, scheduler_info = ( @@ -104,7 +105,7 @@ def __init__(self, **kwargs): self.tokenizer_manager = tokenizer_manager self.template_manager = template_manager self.scheduler_info = scheduler_info - self.default_sampling_params: Union[dict[str, Any], None] = None + self.default_sampling_params: dict[str, Any] | None = None context = zmq.Context(2) self.send_to_rpc = get_zmq_socket( context, zmq.DEALER, self.port_args.rpc_ipc_name, True @@ -112,16 +113,16 @@ def __init__(self, **kwargs): def generate( self, - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, + prompt: list[str] | str | None = None, + sampling_params: list[dict] | dict | None = None, # The token ids for text; one can either specify text or input_ids. 
- input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + input_ids: list[list[int]] | list[int] | None = None, + return_logprob: list[bool] | bool | None = False, + logprob_start_len: list[int] | int | None = None, + top_logprobs_num: list[int] | int | None = None, + token_ids_logprob: list[list[int]] | list[int] | None = None, stream: bool = False, - ) -> Union[Dict, Iterator[Dict]]: + ) -> dict | Iterator[dict]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. @@ -164,15 +165,15 @@ def generator_wrapper(): async def async_generate( self, - sampling_params: Optional[Union[List[Dict], Dict]] = None, + sampling_params: list[dict] | dict | None = None, # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None, + input_ids: list[list[int]] | list[int] | None = None, + return_logprob: list[bool] | bool | None = False, + logprob_start_len: list[int] | int | None = None, + top_logprobs_num: list[int] | int | None = None, + token_ids_logprob: list[list[int]] | list[int] | None = None, stream: bool = False, - ) -> Union[Dict, AsyncIterator[Dict]]: + ) -> dict | AsyncIterator[dict]: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. Please refer to `GenerateReqInput` for the documentation. @@ -199,8 +200,8 @@ async def async_generate( def encode( self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ) -> Dict: + prompt: str | list[str] | list[dict] | list[list[dict]], + ) -> dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. Please refer to `EmbeddingReqInput` for the documentation. @@ -215,8 +216,8 @@ def encode( async def async_encode( self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ) -> Dict: + prompt: str | list[str] | list[dict] | list[list[dict]], + ) -> dict: """ Asynchronous version of encode method. @@ -231,8 +232,8 @@ async def async_encode( def rerank( self, - prompt: Union[List[List[str]]], - ) -> Dict: + prompt: list[list[str]], + ) -> dict: """ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. Please refer to `EmbeddingReqInput` for the documentation. 
@@ -280,14 +281,14 @@ def get_server_info(self): "version": __version__, } - def release_memory_occupation(self, tags: Optional[List[str]] = None): + def release_memory_occupation(self, tags: list[str] | None = None): obj = ReleaseMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() return loop.run_until_complete( self.tokenizer_manager.release_memory_occupation(obj, None) ) - def resume_memory_occupation(self, tags: Optional[List[str]] = None): + def resume_memory_occupation(self, tags: list[str] | None = None): obj = ResumeMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() return loop.run_until_complete( @@ -296,12 +297,12 @@ def resume_memory_occupation(self, tags: Optional[List[str]] = None): def score( self, - query: Optional[Union[str, List[int]]] = None, - items: Optional[Union[str, List[str], List[List[int]]]] = None, - label_token_ids: Optional[List[int]] = None, + query: str | list[int] | None = None, + items: str | list[str] | list[list[int]] | None = None, + label_token_ids: list[int] | None = None, apply_softmax: bool = False, item_first: bool = False, - ) -> List[List[float]]: + ) -> list[list[float]]: """ Score the probability of specified token IDs appearing after the given (query + item) pair. For example: query = "<|user|>Is the following city the capital of France? " @@ -347,12 +348,12 @@ def score( async def async_score( self, - query: Optional[Union[str, List[int]]] = None, - items: Optional[Union[str, List[str], List[List[int]]]] = None, - label_token_ids: Optional[List[int]] = None, + query: str | list[int] | None = None, + items: str | list[str] | list[list[int]] | None = None, + label_token_ids: list[int] | None = None, apply_softmax: bool = False, item_first: bool = False, - ) -> List[List[float]]: + ) -> list[list[float]]: """ Asynchronous version of score method. @@ -416,9 +417,11 @@ def sigchld_handler(signum, frame): pid, exitcode = os.waitpid(0, os.WNOHANG) if exitcode != 0: logger.warning( - f"Child process unexpectedly failed with {exitcode=}. {pid=}" + "Child process unexpectedly failed with exitcode=%s. pid=%s", + exitcode, + pid, ) - logger.warning(f"Child process {pid=} frame={frame}") + logger.warning("Child process pid=%s frame=%s", pid, frame) signal.signal(signal.SIGCHLD, sigchld_handler) @@ -443,8 +446,8 @@ def sigquit_handler(signum, frame): def _launch_subprocesses( - server_args, port_args: Optional[PortArgs] = None -) -> Tuple[TokenizerManager, TemplateManager, Dict]: + server_args, port_args: PortArgs | None = None +) -> tuple[TokenizerManager, TemplateManager, dict]: # Configure global environment configure_logger(server_args) server_args.check_server_args() @@ -453,7 +456,7 @@ def _launch_subprocesses( # Allocate ports for inter-process communications if port_args is None: port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") + logger.info("server_args=%s", server_args) # If using model from www.modelscope.cn, first download the model. server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( @@ -497,7 +500,9 @@ def _launch_subprocesses( for proc in scheduler_procs: proc.join() logger.error( - f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + "Scheduler or DataParallelController %s terminated with %s", + proc.pid, + proc.exitcode, ) return None, None, None @@ -527,10 +532,11 @@ def _launch_subprocesses( data = scheduler_pipe_readers[i].recv() except EOFError: logger.error( - f"Node {i} jax_scheduler is dead. 
Please check if there are relevant logs." + "Node %s jax_scheduler is dead. Please check if there are relevant logs.", + i, ) scheduler_procs[i].join() - logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + logger.error("Exit code: %s", scheduler_procs[i].exitcode) raise if data["status"] != "ready": @@ -546,8 +552,8 @@ def _launch_subprocesses( def _launch_threads( - server_args, port_args: Optional[PortArgs] = None -) -> Tuple[TokenizerManager, TemplateManager, Dict]: + server_args, port_args: PortArgs | None = None +) -> tuple[TokenizerManager, TemplateManager, dict]: # Configure global environment configure_logger(server_args) server_args.check_server_args() @@ -556,7 +562,7 @@ def _launch_threads( # Allocate ports for inter-process communications if port_args is None: port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") + logger.info("server_args=%s", server_args) # If using model from www.modelscope.cn, first download the model. server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( @@ -589,7 +595,7 @@ def _launch_threads( for thread in scheduler_threads: thread.join() logger.error( - f"Scheduler or DataParallelController {thread.name} terminated" + "Scheduler or DataParallelController %s terminated", thread.name ) return None, None, None @@ -628,8 +634,8 @@ def _launch_threads( def _launch_subprocesses_or_threads( - server_args, port_args: Optional[PortArgs] = None -) -> Tuple[TokenizerManager, TemplateManager, Dict]: + server_args, port_args: PortArgs | None = None +) -> tuple[TokenizerManager, TemplateManager, dict]: if server_args.enable_single_process: return _launch_threads(server_args, port_args) else: diff --git a/python/sgl_jax/srt/entrypoints/http_server.py b/python/sgl_jax/srt/entrypoints/http_server.py index 641bd8a18..d18a01fb5 100644 --- a/python/sgl_jax/srt/entrypoints/http_server.py +++ b/python/sgl_jax/srt/entrypoints/http_server.py @@ -13,12 +13,13 @@ import random import threading import time +from collections.abc import AsyncIterator, Callable from http import HTTPStatus -from typing import Any, AsyncIterator, Callable, Dict, List, Optional +from typing import Any # ruff: noqa: E402 # Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) +threading._register_atexit = lambda *args, **kwargs: None from contextlib import asynccontextmanager @@ -88,10 +89,10 @@ class _GlobalState: tokenizer_manager: TokenizerManager template_manager: TemplateManager - scheduler_info: Dict + scheduler_info: dict -_global_state: Optional[_GlobalState] = None +_global_state: _GlobalState | None = None def set_global_state(global_state: _GlobalState): @@ -150,10 +151,9 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE exc_str = str(exc) errors_str = str(exc.errors()) - if errors_str and errors_str != exc_str: - message = f"{exc_str} {errors_str}" - else: - message = exc_str + message = ( + f"{exc_str} {errors_str}" if errors_str and errors_str != exc_str else exc_str + ) err = ErrorResponse( message=message, @@ -231,9 +231,10 @@ async def gen(): "%H:%M:%S", time.localtime(_global_state.tokenizer_manager.last_receive_tstamp) ) logger.error( - f"Health check failed. Server couldn't get a response from detokenizer for last " - f"{HEALTH_CHECK_TIMEOUT} seconds. tic start time: {tic_time}. " - f"last_heartbeat time: {last_receive_time}" + "Health check failed. Server couldn't get a response from detokenizer for last %s seconds. tic start time: %s. 
last_heartbeat time: %s", + HEALTH_CHECK_TIMEOUT, + tic_time, + last_receive_time, ) _global_state.tokenizer_manager.rid_to_state.pop(rid, None) _global_state.tokenizer_manager.health_check_failed = True @@ -255,7 +256,7 @@ async def get_model_info(): @app.get("/get_server_info") async def get_server_info(): # Returns interna states per DP. - internal_states: List[Dict[Any, Any]] = ( + internal_states: list[dict[Any, Any]] = ( await _global_state.tokenizer_manager.get_internal_state() ) return { @@ -290,15 +291,19 @@ async def stream_results() -> AsyncIterator[bytes]: async for out in _global_state.tokenizer_manager.generate_request( obj, request ): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" + yield ( + b"data: " + + orjson.dumps(out, option=orjson.OPT_NON_STR_KEYS) + + b"\n\n" + ) except ValueError as e: out = {"error": {"message": str(e)}} - logger.error(f"[http_server] Error: {e}") - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" + logger.error("[http_server] Error: %s", e) + yield ( + b"data: " + + orjson.dumps(out, option=orjson.OPT_NON_STR_KEYS) + + b"\n\n" + ) yield b"data: [DONE]\n\n" return StreamingResponse( @@ -313,7 +318,7 @@ async def stream_results() -> AsyncIterator[bytes]: ).__anext__() return ret except ValueError as e: - logger.error(f"[http_server] Error: {e}") + logger.error("[http_server] Error: %s", e) return _create_error_response(e) @@ -337,7 +342,7 @@ async def generate_from_file_request(file: UploadFile, request: Request): ).__anext__() return ret except ValueError as e: - logger.error(f"Error: {e}") + logger.error("Error: %s", e) return _create_error_response(e) @@ -377,7 +382,7 @@ async def flush_cache(): @app.api_route("/start_profile", methods=["GET", "POST"]) -async def start_profile_async(obj: Optional[ProfileReqInput] = None): +async def start_profile_async(obj: ProfileReqInput | None = None): """Start profiling.""" if obj is None: obj = ProfileReqInput() @@ -406,7 +411,7 @@ async def stop_profile_async(): @app.api_route("/start_trace", methods=["GET", "POST"]) -async def start_trace_async(obj: Optional[StartTraceReqInput] = None): +async def start_trace_async(obj: StartTraceReqInput | None = None): """Start precision tracing.""" if obj is None: obj = StartTraceReqInput() @@ -444,9 +449,9 @@ async def start_trace_async(obj: Optional[StartTraceReqInput] = None): result = await _global_state.tokenizer_manager.set_internal_state( SetInternalStateReq(request_id="trace_state", state_data=trace_state) ) - logger.info(f"[HTTP] Set internal state result: {result}") + logger.info("[HTTP] Set internal state result: %s", result) except Exception as e: - logger.info(f"[HTTP] Error setting internal state: {e}") + logger.info("[HTTP] Error setting internal state: %s", e) precision_tracer.stop_trace() return ORJSONResponse( content={ @@ -476,7 +481,7 @@ async def start_trace_async(obj: Optional[StartTraceReqInput] = None): @app.api_route("/stop_trace", methods=["GET", "POST"]) -async def stop_trace_async(obj: Optional[StopTraceReqInput] = None): +async def stop_trace_async(obj: StopTraceReqInput | None = None): """Stop precision tracing.""" try: output_file = precision_tracer.stop_trace() @@ -521,7 +526,7 @@ async def stop_trace_async(obj: Optional[StopTraceReqInput] = None): @app.api_route("/trace_status", methods=["GET", "POST"]) -async def trace_status_async(obj: Optional[TraceStatusReqInput] = None): +async def trace_status_async(obj: TraceStatusReqInput | None = None): """Get precision 
tracing status.""" try: return ORJSONResponse( @@ -789,8 +794,8 @@ def _create_error_response(e): def launch_server( server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, - launch_callback: Optional[Callable[[], None]] = None, + pipe_finish_writer: multiprocessing.connection.Connection | None = None, + launch_callback: Callable[[], None] | None = None, ): """ Launch SRT (SGLang Runtime) Server. @@ -859,7 +864,7 @@ def launch_server( def _execute_server_warmup( server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection], + pipe_finish_writer: multiprocessing.connection.Connection | None, ): headers = {} url = server_args.url() @@ -882,7 +887,7 @@ def _execute_server_warmup( if not success: if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. warmup error: {last_traceback}") + logger.error("Initialization failed. warmup error: %s", last_traceback) kill_process_tree(os.getpid()) return success @@ -919,7 +924,7 @@ def _execute_server_warmup( last_traceback = get_exception_traceback() if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. warmup error: {last_traceback}") + logger.error("Initialization failed. warmup error: %s", last_traceback) kill_process_tree(os.getpid()) return False @@ -930,15 +935,14 @@ def _execute_server_warmup( def _wait_and_warmup( server_args: ServerArgs, - pipe_finish_writer: Optional[multiprocessing.connection.Connection], - launch_callback: Optional[Callable[[], None]] = None, + pipe_finish_writer: multiprocessing.connection.Connection | None, + launch_callback: Callable[[], None] | None = None, ): - if not server_args.skip_server_warmup: - if not _execute_server_warmup( - server_args, - pipe_finish_writer, - ): - return + if not server_args.skip_server_warmup and not _execute_server_warmup( + server_args, + pipe_finish_writer, + ): + return logger.info("The server is fired up and ready to roll!") diff --git a/python/sgl_jax/srt/entrypoints/openai/protocol.py b/python/sgl_jax/srt/entrypoints/openai/protocol.py index 5d9fcb5f8..10b594915 100644 --- a/python/sgl_jax/srt/entrypoints/openai/protocol.py +++ b/python/sgl_jax/srt/entrypoints/openai/protocol.py @@ -2,10 +2,9 @@ import time from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Literal from pydantic import BaseModel, Field, root_validator, validator -from typing_extensions import Literal class ModelCard(BaseModel): @@ -15,68 +14,68 @@ class ModelCard(BaseModel): object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) owned_by: str = "sglang" - root: Optional[str] = None - max_model_len: Optional[int] = None + root: str | None = None + max_model_len: int | None = None class ModelList(BaseModel): """Model list consists of model cards.""" object: str = "list" - data: List[ModelCard] = Field(default_factory=list) + data: list[ModelCard] = Field(default_factory=list) class ErrorResponse(BaseModel): object: str = "error" message: str type: str - param: Optional[str] = None + param: str | None = None code: int class LogProbs(BaseModel): - text_offset: List[int] = Field(default_factory=list) - token_logprobs: List[Optional[float]] = Field(default_factory=list) - tokens: List[str] = Field(default_factory=list) - top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + text_offset: list[int] = 
Field(default_factory=list) + token_logprobs: list[float | None] = Field(default_factory=list) + tokens: list[str] = Field(default_factory=list) + top_logprobs: list[dict[str, float] | None] = Field(default_factory=list) class TopLogprob(BaseModel): token: str - bytes: List[int] + bytes: list[int] logprob: float class ChatCompletionTokenLogprob(BaseModel): token: str - bytes: List[int] + bytes: list[int] logprob: float - top_logprobs: List[TopLogprob] + top_logprobs: list[TopLogprob] class ChoiceLogprobs(BaseModel): # build for v1/chat/completions response - content: List[ChatCompletionTokenLogprob] + content: list[ChatCompletionTokenLogprob] class UsageInfo(BaseModel): prompt_tokens: int = 0 total_tokens: int = 0 - completion_tokens: Optional[int] = 0 + completion_tokens: int | None = 0 # only used to return cached tokens when --enable-cache-report is set - prompt_tokens_details: Optional[Dict[str, int]] = None + prompt_tokens_details: dict[str, int] | None = None class StreamOptions(BaseModel): - include_usage: Optional[bool] = False + include_usage: bool | None = False class JsonSchemaResponseFormat(BaseModel): name: str - description: Optional[str] = None + description: str | None = None # use alias to workaround pydantic conflict - schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) - strict: Optional[bool] = False + schema_: dict[str, object] | None = Field(alias="schema", default=None) + strict: bool | None = False class FileRequest(BaseModel): @@ -108,77 +107,77 @@ class BatchRequest(BaseModel): ) endpoint: str # The endpoint to be used for all requests in the batch completion_window: str # The time frame within which the batch should be processed - metadata: Optional[dict] = None # Optional custom metadata for the batch + metadata: dict | None = None # Optional custom metadata for the batch class BatchResponse(BaseModel): id: str object: str = "batch" endpoint: str - errors: Optional[dict] = None + errors: dict | None = None input_file_id: str completion_window: str status: str = "validating" - output_file_id: Optional[str] = None - error_file_id: Optional[str] = None + output_file_id: str | None = None + error_file_id: str | None = None created_at: int - in_progress_at: Optional[int] = None - expires_at: Optional[int] = None - finalizing_at: Optional[int] = None - completed_at: Optional[int] = None - failed_at: Optional[int] = None - expired_at: Optional[int] = None - cancelling_at: Optional[int] = None - cancelled_at: Optional[int] = None - request_counts: Optional[dict] = None - metadata: Optional[dict] = None + in_progress_at: int | None = None + expires_at: int | None = None + finalizing_at: int | None = None + completed_at: int | None = None + failed_at: int | None = None + expired_at: int | None = None + cancelling_at: int | None = None + cancelled_at: int | None = None + request_counts: dict | None = None + metadata: dict | None = None class CompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: str - prompt: Union[List[int], List[List[int]], str, List[str]] - best_of: Optional[int] = None + prompt: list[int] | list[list[int]] | str | list[str] + best_of: int | None = None echo: bool = False frequency_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None - logprobs: Optional[int] = None + logit_bias: dict[str, float] | None = None + logprobs: int | None = None max_tokens: int = 16 n: int = 1 presence_penalty: float = 0.0 - seed: 
Optional[int] = None - stop: Optional[Union[str, List[str]]] = None + seed: int | None = None + stop: str | list[str] | None = None stream: bool = False - stream_options: Optional[StreamOptions] = None - suffix: Optional[str] = None + stream_options: StreamOptions | None = None + suffix: str | None = None temperature: float = 1.0 top_p: float = 1.0 - user: Optional[str] = None + user: str | None = None return_hidden_states: bool = False # Extra parameters for SRT backend only and will be ignored by OpenAI models. top_k: int = -1 min_p: float = 0.0 min_tokens: int = 0 - json_schema: Optional[str] = None - regex: Optional[str] = None - ebnf: Optional[str] = None + json_schema: str | None = None + regex: str | None = None + ebnf: str | None = None repetition_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = None + stop_token_ids: list[int] | None = None no_stop_trim: bool = False ignore_eos: bool = False skip_special_tokens: bool = True - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - session_params: Optional[Dict] = None + lora_path: list[str | None] | str | None | None = None + session_params: dict | None = None # For PD disaggregation - bootstrap_host: Optional[str] = None - bootstrap_port: Optional[int] = None - bootstrap_room: Optional[int] = None + bootstrap_host: str | None = None + bootstrap_port: int | None = None + bootstrap_room: int | None = None # For request id - rid: Optional[Union[List[str], str]] = None + rid: list[str] | str | None = None @validator("max_tokens") @classmethod @@ -191,10 +190,10 @@ def validate_max_tokens_positive(cls, v): class CompletionResponseChoice(BaseModel): index: int text: str - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None - matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None + logprobs: LogProbs | None = None + finish_reason: Literal["stop", "length", "content_filter", "abort"] | None = None + matched_stop: None | int | str = None + hidden_states: object | None = None # @model_serializer(mode="wrap") # Not available in pydantic v1 # def _serialize(self, handler): @@ -209,17 +208,17 @@ class CompletionResponse(BaseModel): object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[CompletionResponseChoice] + choices: list[CompletionResponseChoice] usage: UsageInfo class CompletionResponseStreamChoice(BaseModel): index: int text: str - logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None - matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None + logprobs: LogProbs | None = None + finish_reason: Literal["stop", "length", "content_filter", "abort"] | None = None + matched_stop: None | int | str = None + hidden_states: object | None = None # @model_serializer(mode="wrap") # Not available in pydantic v1 # def _serialize(self, handler): @@ -234,8 +233,8 @@ class CompletionStreamResponse(BaseModel): object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[CompletionResponseStreamChoice] - usage: Optional[UsageInfo] = None + choices: list[CompletionResponseStreamChoice] + usage: UsageInfo | None = None class ChatCompletionMessageContentTextPart(BaseModel): @@ -245,7 +244,7 @@ class ChatCompletionMessageContentTextPart(BaseModel): class ChatCompletionMessageContentImageURL(BaseModel): url: str 
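For context on the annotation style used throughout these request/response models: the patch swaps typing.Optional, Union, List, and Dict for PEP 585 builtin generics and PEP 604 `X | Y` unions. The builtin generics need Python 3.9+, and the `|` unions generally need Python 3.10+ when evaluated at runtime, as pydantic does. A minimal illustrative sketch, not part of this diff; ExampleRequest and its fields are made-up names:

    from pydantic import BaseModel

    class ExampleRequest(BaseModel):
        # Optional[List[int]] becomes:
        ids: list[int] | None = None
        # Optional[Union[str, List[str]]] becomes:
        stop: str | list[str] | None = None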
- detail: Optional[Literal["auto", "low", "high"]] = "auto" + detail: Literal["auto", "low", "high"] | None = "auto" class ChatCompletionMessageContentVideoURL(BaseModel): @@ -259,7 +258,7 @@ class ChatCompletionMessageContentAudioURL(BaseModel): class ChatCompletionMessageContentImagePart(BaseModel): type: Literal["image_url"] image_url: ChatCompletionMessageContentImageURL - modalities: Optional[Literal["image", "multi-images", "video"]] = "image" + modalities: Literal["image", "multi-images", "video"] | None = "image" class ChatCompletionMessageContentVideoPart(BaseModel): @@ -272,37 +271,37 @@ class ChatCompletionMessageContentAudioPart(BaseModel): audio_url: ChatCompletionMessageContentAudioURL -ChatCompletionMessageContentPart = Union[ - ChatCompletionMessageContentTextPart, - ChatCompletionMessageContentImagePart, - ChatCompletionMessageContentVideoPart, - ChatCompletionMessageContentAudioPart, -] +ChatCompletionMessageContentPart = ( + ChatCompletionMessageContentTextPart + | ChatCompletionMessageContentImagePart + | ChatCompletionMessageContentVideoPart + | ChatCompletionMessageContentAudioPart +) class FunctionResponse(BaseModel): """Function response.""" - name: Optional[str] = None - arguments: Optional[str] = None + name: str | None = None + arguments: str | None = None class ToolCall(BaseModel): """Tool call response.""" - id: Optional[str] = None - index: Optional[int] = None + id: str | None = None + index: int | None = None type: Literal["function"] = "function" function: FunctionResponse class ChatCompletionMessageGenericParam(BaseModel): role: Literal["system", "assistant", "tool"] - content: Union[str, List[ChatCompletionMessageContentTextPart], None] - tool_call_id: Optional[str] = None - name: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + content: str | list[ChatCompletionMessageContentTextPart] | None + tool_call_id: str | None = None + name: str | None = None + reasoning_content: str | None = None + tool_calls: list[ToolCall] | None = Field(default=None, examples=[None]) @validator("role", pre=True) @classmethod @@ -319,37 +318,37 @@ def _normalize_role(cls, v): class ChatCompletionMessageUserParam(BaseModel): role: Literal["user"] - content: Union[str, List[ChatCompletionMessageContentPart]] + content: str | list[ChatCompletionMessageContentPart] -ChatCompletionMessageParam = Union[ - ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam -] +ChatCompletionMessageParam = ( + ChatCompletionMessageGenericParam | ChatCompletionMessageUserParam +) class ResponseFormat(BaseModel): type: Literal["text", "json_object", "json_schema"] - json_schema: Optional[JsonSchemaResponseFormat] = None + json_schema: JsonSchemaResponseFormat | None = None class StructuresResponseFormat(BaseModel): begin: str - schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + schema_: dict[str, object] | None = Field(alias="schema", default=None) end: str class StructuralTagResponseFormat(BaseModel): type: Literal["structural_tag"] - structures: List[StructuresResponseFormat] - triggers: List[str] + structures: list[StructuresResponseFormat] + triggers: list[str] class Function(BaseModel): """Function descriptions.""" - description: Optional[str] = Field(default=None, examples=[None]) - name: Optional[str] = None - parameters: Optional[object] = None + description: str | None = Field(default=None, examples=[None]) + name: str | None = None + parameters: object | None = 
None strict: bool = False @@ -363,7 +362,7 @@ class Tool(BaseModel): class ToolChoiceFuncName(BaseModel): """The name of tool choice function.""" - name: Optional[str] = None + name: str | None = None class ToolChoice(BaseModel): @@ -376,34 +375,34 @@ class ToolChoice(BaseModel): class ChatCompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create - messages: List[ChatCompletionMessageParam] + messages: list[ChatCompletionMessageParam] model: str frequency_penalty: float = 0.0 - logit_bias: Optional[Dict[str, float]] = None + logit_bias: dict[str, float] | None = None logprobs: bool = False - top_logprobs: Optional[int] = None - max_tokens: Optional[int] = Field( + top_logprobs: int | None = None + max_tokens: int | None = Field( default=None, deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", description="The maximum number of tokens that can be generated in the chat completion. ", ) - max_completion_tokens: Optional[int] = Field( + max_completion_tokens: int | None = Field( default=None, description="The maximum number of completion tokens for a chat completion request, " "including visible output tokens and reasoning tokens. Input tokens are not included. ", ) n: int = 1 presence_penalty: float = 0.0 - response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None - seed: Optional[int] = None - stop: Optional[Union[str, List[str]]] = None + response_format: ResponseFormat | StructuralTagResponseFormat | None = None + seed: int | None = None + stop: str | list[str] | None = None stream: bool = False - stream_options: Optional[StreamOptions] = None + stream_options: StreamOptions | None = None temperature: float = 0.7 top_p: float = 1.0 - user: Optional[str] = None - tools: Optional[List[Tool]] = Field(default=None, examples=[None]) - tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( + user: str | None = None + tools: list[Tool] | None = Field(default=None, examples=[None]) + tool_choice: ToolChoice | Literal["auto", "required", "none"] = Field( default="auto", examples=["none"] ) # noqa return_hidden_states: bool = False @@ -422,47 +421,48 @@ def set_tool_choice_default(cls, values): top_k: int = -1 min_p: float = 0.0 min_tokens: int = 0 - regex: Optional[str] = None - ebnf: Optional[str] = None + regex: str | None = None + ebnf: str | None = None repetition_penalty: float = 1.0 - stop_token_ids: Optional[List[int]] = None + stop_token_ids: list[int] | None = None no_stop_trim: bool = False ignore_eos: bool = False continue_final_message: bool = False skip_special_tokens: bool = True - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - session_params: Optional[Dict] = None + lora_path: list[str | None] | str | None | None = None + session_params: dict | None = None separate_reasoning: bool = True stream_reasoning: bool = True - chat_template_kwargs: Optional[Dict] = None + chat_template_kwargs: dict | None = None # For request id - rid: Optional[Union[List[str], str]] = None + rid: list[str] | str | None = None # For PD disaggregation - bootstrap_host: Optional[str] = None - bootstrap_port: Optional[int] = None - bootstrap_room: Optional[int] = None + bootstrap_host: str | None = None + bootstrap_port: int | None = None + bootstrap_room: int | None = None class ChatMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: 
Optional[List[ToolCall]] = Field(default=None, examples=[None]) + role: str | None = None + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[ToolCall] | None = Field(default=None, examples=[None]) class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None - finish_reason: Optional[ + logprobs: LogProbs | ChoiceLogprobs | None = None + finish_reason: ( Literal[ "stop", "length", "tool_calls", "content_filter", "function_call", "abort" ] - ] = None - matched_stop: Union[None, int, str] = None - hidden_states: Optional[object] = None + | None + ) = None + matched_stop: None | int | str = None + hidden_states: object | None = None # @model_serializer(mode="wrap") # Not available in pydantic v1 # def _serialize(self, handler): @@ -477,16 +477,16 @@ class ChatCompletionResponse(BaseModel): object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionResponseChoice] + choices: list[ChatCompletionResponseChoice] usage: UsageInfo class DeltaMessage(BaseModel): - role: Optional[str] = None - content: Optional[str] = None - reasoning_content: Optional[str] = None - tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) - hidden_states: Optional[object] = None + role: str | None = None + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[ToolCall] | None = Field(default=None, examples=[None]) + hidden_states: object | None = None # @model_serializer(mode="wrap") # Not available in pydantic v1 # def _serialize(self, handler): @@ -499,13 +499,14 @@ class DeltaMessage(BaseModel): class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage - logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None - finish_reason: Optional[ + logprobs: LogProbs | ChoiceLogprobs | None = None + finish_reason: ( Literal[ "stop", "length", "tool_calls", "content_filter", "function_call", "abort" ] - ] = None - matched_stop: Union[None, int, str] = None + | None + ) = None + matched_stop: None | int | str = None class ChatCompletionStreamResponse(BaseModel): @@ -513,18 +514,18 @@ class ChatCompletionStreamResponse(BaseModel): object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) model: str - choices: List[ChatCompletionResponseStreamChoice] - usage: Optional[UsageInfo] = None + choices: list[ChatCompletionResponseStreamChoice] + usage: UsageInfo | None = None class MultimodalEmbeddingInput(BaseModel): - text: Optional[str] = None - image: Optional[str] = None + text: str | None = None + image: str | None = None -EmbeddingInput = Union[ - List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] -] +EmbeddingInput = ( + list[int] | list[list[int]] | str | list[str] | list[MultimodalEmbeddingInput] +) class EmbeddingRequest(BaseModel): @@ -533,69 +534,65 @@ class EmbeddingRequest(BaseModel): input: EmbeddingInput model: str encoding_format: str = "float" - dimensions: Optional[int] = None - user: Optional[str] = None + dimensions: int | None = None + user: str | None = None # The request id. 
- rid: Optional[Union[List[str], str]] = None + rid: list[str] | str | None = None class EmbeddingObject(BaseModel): - embedding: List[float] + embedding: list[float] index: int object: str = "embedding" class EmbeddingResponse(BaseModel): - data: List[EmbeddingObject] + data: list[EmbeddingObject] model: str object: str = "list" - usage: Optional[UsageInfo] = None + usage: UsageInfo | None = None class ScoringRequest(BaseModel): - query: Optional[Union[str, List[int]]] = ( - None # Query text or pre-tokenized token IDs - ) - items: Optional[Union[str, List[str], List[List[int]]]] = ( + query: str | list[int] | None = None # Query text or pre-tokenized token IDs + items: str | list[str] | list[list[int]] | None = ( None # Item text(s) or pre-tokenized token IDs ) - label_token_ids: Optional[List[int]] = ( - None # Token IDs to compute probabilities for - ) + label_token_ids: list[int] | None = None # Token IDs to compute probabilities for apply_softmax: bool = False item_first: bool = False model: str class ScoringResponse(BaseModel): - scores: List[ - List[float] + scores: list[ + list[float] ] # List of lists of probabilities, each in the order of label_token_ids model: str - usage: Optional[UsageInfo] = None + usage: UsageInfo | None = None object: str = "scoring" class V1RerankReqInput(BaseModel): query: str - documents: List[str] + documents: list[str] class RerankResponse(BaseModel): score: float document: str index: int - meta_info: Optional[dict] = None + meta_info: dict | None = None -OpenAIServingRequest = Union[ - ChatCompletionRequest, - CompletionRequest, - EmbeddingRequest, - ScoringRequest, - V1RerankReqInput, -] +OpenAIServingRequest = ( + ChatCompletionRequest + | CompletionRequest + | EmbeddingRequest + | ScoringRequest + | V1RerankReqInput +) @dataclass @@ -617,10 +614,10 @@ class MessageProcessingResult: """ prompt: str - prompt_ids: Union[str, List[int]] - image_data: Optional[Any] - audio_data: Optional[Any] - video_data: Optional[Any] - modalities: List[str] - stop: List[str] - tool_call_constraint: Optional[Any] = None + prompt_ids: str | list[int] + image_data: Any | None + audio_data: Any | None + video_data: Any | None + modalities: list[str] + stop: list[str] + tool_call_constraint: Any | None = None diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_base.py b/python/sgl_jax/srt/entrypoints/openai/serving_base.py index 44100f7e5..773815735 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_base.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_base.py @@ -2,7 +2,7 @@ import logging import uuid from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import Any from fastapi import Request from fastapi.responses import ORJSONResponse, StreamingResponse @@ -23,7 +23,7 @@ def __init__(self, tokenizer_manager: TokenizerManager): async def handle_request( self, request: OpenAIServingRequest, raw_request: Request - ) -> Union[Any, StreamingResponse, ErrorResponse]: + ) -> Any | StreamingResponse | ErrorResponse: """Handle the specific request type with common pattern""" try: # Validate request @@ -47,7 +47,7 @@ async def handle_request( ) except Exception as e: - logger.exception(f"Error in request: {e}") + logger.exception("Error in request: %s", e) return self.create_error_response( message=f"Internal server error: {str(e)}", err_type="InternalServerError", @@ -59,7 +59,7 @@ def _request_id_prefix(self) -> str: """Generate request ID based on request type""" pass - def _generate_request_id_base(self, request: 
OpenAIServingRequest) -> Optional[str]: + def _generate_request_id_base(self, request: OpenAIServingRequest) -> str | None: """Generate request ID based on request type""" return None @@ -82,7 +82,7 @@ async def _handle_streaming_request( adapted_request: GenerateReqInput, request: OpenAIServingRequest, raw_request: Request, - ) -> Union[StreamingResponse, ErrorResponse, ORJSONResponse]: + ) -> StreamingResponse | ErrorResponse | ORJSONResponse: """Handle streaming request Override this method in child classes that support streaming requests. @@ -98,7 +98,7 @@ async def _handle_non_streaming_request( adapted_request: GenerateReqInput, request: OpenAIServingRequest, raw_request: Request, - ) -> Union[Any, ErrorResponse, ORJSONResponse]: + ) -> Any | ErrorResponse | ORJSONResponse: """Handle non-streaming request Override this method in child classes that support non-streaming requests. @@ -109,16 +109,16 @@ async def _handle_non_streaming_request( status_code=501, ) - def _validate_request(self, _: OpenAIServingRequest) -> Optional[str]: + def _validate_request(self, _: OpenAIServingRequest) -> str | None: """Validate request""" - pass + return None def create_error_response( self, message: str, err_type: str = "BadRequestError", status_code: int = 400, - param: Optional[str] = None, + param: str | None = None, ) -> ORJSONResponse: """Create an error response""" error = ErrorResponse( diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_chat.py b/python/sgl_jax/srt/entrypoints/openai/serving_chat.py index c20d79104..9fa64a0ed 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_chat.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_chat.py @@ -3,7 +3,8 @@ import logging import time import uuid -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from collections.abc import AsyncGenerator +from typing import Any from fastapi import Request from fastapi.responses import ORJSONResponse, StreamingResponse @@ -122,7 +123,7 @@ def _process_messages( def _apply_jinja_template( self, request: ChatCompletionRequest, - tools: Optional[List[Dict]], + tools: list[dict] | None, is_multimodal: bool, ) -> MessageProcessingResult: """Apply Jinja chat template""" @@ -165,21 +166,20 @@ def _apply_jinja_template( except json.JSONDecodeError as e: # Log a warning or error if JSON parsing fails for arguments logger.warning( - f"Failed to parse tool call arguments as JSON: {e}" + "Failed to parse tool call arguments as JSON: %s", e ) # Decide whether to continue or raise the exception based on desired behavior continue # Or raise e if strict parsing is required openai_compatible_messages.append(processed_msg) # Handle assistant prefix for continue_final_message - assistant_prefix = None if ( openai_compatible_messages and openai_compatible_messages[-1]["role"] == "assistant" + and request.continue_final_message ): - if request.continue_final_message: - assistant_prefix = openai_compatible_messages[-1]["content"] - openai_compatible_messages = openai_compatible_messages[:-1] + assistant_prefix = openai_compatible_messages[-1]["content"] + openai_compatible_messages = openai_compatible_messages[:-1] try: # Check if tokenizer has a chat template, if not, provide a default one @@ -310,9 +310,9 @@ def _apply_conversation_template( def _build_sampling_params( self, request: ChatCompletionRequest, - stop: List[str], - tool_call_constraint: Optional[Any], - ) -> Dict[str, Any]: + stop: list[str], + tool_call_constraint: Any | None, + ) -> dict[str, Any]: """Build sampling parameters for 
the request""" sampling_params = { @@ -599,7 +599,7 @@ async def _handle_non_streaming_request( adapted_request: GenerateReqInput, request: ChatCompletionRequest, raw_request: Request, - ) -> Union[ChatCompletionResponse, ErrorResponse, ORJSONResponse]: + ) -> ChatCompletionResponse | ErrorResponse | ORJSONResponse: """Handle non-streaming chat completion request""" try: ret = await self.tokenizer_manager.generate_request( @@ -622,9 +622,9 @@ async def _handle_non_streaming_request( def _build_chat_response( self, request: ChatCompletionRequest, - ret: List[Dict[str, Any]], + ret: list[dict[str, Any]], created: int, - ) -> Union[ChatCompletionResponse, ORJSONResponse]: + ) -> ChatCompletionResponse | ORJSONResponse: """Build chat completion response from generation results""" choices = [] @@ -650,7 +650,7 @@ def _build_chat_response( ) reasoning_text, text = parser.parse_non_stream(text) except Exception as e: - logger.error(f"Reasoning parsing error: {e}") + logger.error("Reasoning parsing error: %s", e) return self.create_error_response( "Failed to parse reasoning content", err_type="InternalServerError", @@ -701,7 +701,7 @@ def _build_chat_response( def _process_logprobs_tokens( self, logprobs: LogProbs, use_token_index: bool = False - ) -> List[ChatCompletionTokenLogprob]: + ) -> list[ChatCompletionTokenLogprob]: """Common helper to process logprobs tokens for both streaming and non-streaming Args: @@ -741,7 +741,7 @@ def _process_logprobs_tokens( return token_logprobs - def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs: + def _process_response_logprobs(self, ret_item: dict[str, Any]) -> ChoiceLogprobs: """Process logprobs for non-streaming response""" logprobs = to_openai_style_logprobs( output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], @@ -754,10 +754,10 @@ def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs def _process_tool_calls( self, text: str, - tools: List[Any], - tool_call_parser: Optional[str], - finish_reason: Dict[str, Any], - ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]: + tools: list[Any], + tool_call_parser: str | None, + finish_reason: dict[str, Any], + ) -> tuple[list[ToolCall] | None, str, dict[str, Any]]: """Process tool calls in the response""" parser = FunctionCallParser(tools, tool_call_parser) if parser.has_tool_call(text): @@ -777,14 +777,14 @@ def _process_tool_calls( ] return tool_calls, text, finish_reason except Exception as e: - logger.error(f"Tool call parsing error: {e}") + logger.error("Tool call parsing error: %s", e) # Return error but don't fail the whole request return None, text, finish_reason return None, text, finish_reason def _process_streaming_logprobs( - self, content: Dict[str, Any], n_prev_token: int + self, content: dict[str, Any], n_prev_token: int ) -> ChoiceLogprobs: """Process logprobs for streaming response""" logprobs = to_openai_style_logprobs( @@ -803,10 +803,10 @@ def _process_reasoning_stream( self, index: int, delta: str, - reasoning_parser_dict: Dict[int, ReasoningParser], - content: Dict[str, Any], + reasoning_parser_dict: dict[int, ReasoningParser], + content: dict[str, Any], request: ChatCompletionRequest, - ) -> tuple[Optional[str], str]: + ) -> tuple[str | None, str]: """Process reasoning content in streaming response""" if index not in reasoning_parser_dict: reasoning_parser_dict[index] = ReasoningParser( @@ -839,10 +839,10 @@ async def _process_tool_call_stream( self, index: int, delta: str, - parser_dict: Dict[int, 
FunctionCallParser], - content: Dict[str, Any], + parser_dict: dict[int, FunctionCallParser], + content: dict[str, Any], request: ChatCompletionRequest, - finish_reason_type: Optional[str], + finish_reason_type: str | None, ): """Process tool calls in streaming response""" if index not in parser_dict: diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_completions.py b/python/sgl_jax/srt/entrypoints/openai/serving_completions.py index ade9017c7..aa3c2e606 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_completions.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_completions.py @@ -1,6 +1,7 @@ import logging import time -from typing import Any, AsyncGenerator, Dict, List, Union +from collections.abc import AsyncGenerator +from typing import Any from fastapi import Request from fastapi.responses import ORJSONResponse, StreamingResponse @@ -58,10 +59,7 @@ def _convert_to_internal_request( pass # Set logprob start length based on echo and logprobs - if request.echo and request.logprobs: - logprob_start_len = 0 - else: - logprob_start_len = -1 + logprob_start_len = 0 if request.echo and request.logprobs else -1 # Build sampling parameters sampling_params = self._build_sampling_params(request) @@ -92,7 +90,7 @@ def _convert_to_internal_request( return adapted_request, request - def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]: + def _build_sampling_params(self, request: CompletionRequest) -> dict[str, Any]: """Build sampling parameters for the request""" # Start with common parameters sampling_params = { @@ -164,11 +162,9 @@ async def _generate_completion_stream( hidden_states[index] = content["meta_info"].get("hidden_states", None) stream_buffer = stream_buffers.get(index, "") - # Handle echo for first chunk - if not stream_buffer: # The first chunk - if request.echo: - echo_text = self._get_echo_text(request, index) - text = echo_text + text + if not stream_buffer and request.echo: # The first chunk + echo_text = self._get_echo_text(request, index) + text = echo_text + text # Handle logprobs logprobs = None @@ -278,7 +274,7 @@ async def _handle_non_streaming_request( adapted_request: GenerateReqInput, request: CompletionRequest, raw_request: Request, - ) -> Union[CompletionResponse, ErrorResponse, ORJSONResponse]: + ) -> CompletionResponse | ErrorResponse | ORJSONResponse: """Handle non-streaming completion request""" try: generator = self.tokenizer_manager.generate_request( @@ -302,7 +298,7 @@ async def _handle_non_streaming_request( def _build_completion_response( self, request: CompletionRequest, - ret: List[Dict[str, Any]], + ret: list[dict[str, Any]], created: int, ) -> CompletionResponse: """Build completion response from generation results""" @@ -399,7 +395,7 @@ def _get_echo_text(self, request: CompletionRequest, index: int) -> str: ) return "" - def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]: + def _prepare_echo_prompts(self, request: CompletionRequest) -> list[str]: """Prepare echo prompts for non-streaming response""" if isinstance(request.prompt, list) and isinstance(request.prompt[0], str): # for the case of multiple str prompts diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py b/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py index 256927834..a6db0ebc4 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_embedding.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing 
import Any from fastapi import Request from fastapi.responses import ORJSONResponse @@ -31,7 +31,7 @@ def __init__( def _request_id_prefix(self) -> str: return "embd-" - def _validate_request(self, request: EmbeddingRequest) -> Optional[str]: + def _validate_request(self, request: EmbeddingRequest) -> str | None: """Validate that the input is not empty or whitespace only.""" if not (input := request.input): return "Input cannot be empty" @@ -130,7 +130,7 @@ async def _handle_non_streaming_request( adapted_request: EmbeddingReqInput, request: EmbeddingRequest, raw_request: Request, - ) -> Union[EmbeddingResponse, ErrorResponse, ORJSONResponse]: + ) -> EmbeddingResponse | ErrorResponse | ORJSONResponse: """Handle the embedding request""" try: ret = await self.tokenizer_manager.generate_request( @@ -145,7 +145,7 @@ async def _handle_non_streaming_request( response = self._build_embedding_response(ret) return response - def _build_embedding_response(self, ret: List[Dict[str, Any]]) -> EmbeddingResponse: + def _build_embedding_response(self, ret: list[dict[str, Any]]) -> EmbeddingResponse: """Build the embedding response""" embedding_objects = [] prompt_tokens = 0 diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_rerank.py b/python/sgl_jax/srt/entrypoints/openai/serving_rerank.py index 070b8d560..f8cf86cd3 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_rerank.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_rerank.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any from fastapi import Request from fastapi.responses import ORJSONResponse @@ -24,14 +24,13 @@ class OpenAIServingRerank(OpenAIServingBase): def _request_id_prefix(self) -> str: return "rerank-" - def _validate_request(self, request: V1RerankReqInput) -> Optional[str]: + def _validate_request(self, request: V1RerankReqInput) -> str | None: """Validate rerank request format and content""" if not request.query: return "Query cannot be empty" - if isinstance(request.query, str): - if not request.query.strip(): - return "Query cannot be empty or whitespace only" + if isinstance(request.query, str) and not request.query.strip(): + return "Query cannot be empty or whitespace only" if not request.documents: return "Documents cannot be empty" @@ -65,7 +64,7 @@ async def _handle_non_streaming_request( adapted_request: EmbeddingReqInput, request: V1RerankReqInput, raw_request: Request, - ) -> Union[List[RerankResponse], ErrorResponse, ORJSONResponse]: + ) -> list[RerankResponse] | ErrorResponse | ORJSONResponse: """Handle the rerank request""" try: ret = await self.tokenizer_manager.generate_request( @@ -82,8 +81,8 @@ async def _handle_non_streaming_request( return responses def _build_rerank_response( - self, ret: List[Dict[str, Any]], request: V1RerankReqInput - ) -> List[RerankResponse]: + self, ret: list[dict[str, Any]], request: V1RerankReqInput + ) -> list[RerankResponse]: """Build the rerank response from generation results""" responses = [] for idx, ret_item in enumerate(ret): diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_score.py b/python/sgl_jax/srt/entrypoints/openai/serving_score.py index c3e926260..e046b4b04 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_score.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_score.py @@ -1,5 +1,4 @@ import logging -from typing import Union from fastapi import Request @@ -37,7 +36,7 @@ async def _handle_non_streaming_request( adapted_request: ScoringRequest, request: ScoringRequest, 
raw_request: Request, - ) -> Union[ScoringResponse, ErrorResponse]: + ) -> ScoringResponse | ErrorResponse: """Handle the scoring request""" try: # Use tokenizer_manager's score_request method directly diff --git a/python/sgl_jax/srt/entrypoints/openai/usage_processor.py b/python/sgl_jax/srt/entrypoints/openai/usage_processor.py index 65152cf09..dfd2d7957 100644 --- a/python/sgl_jax/srt/entrypoints/openai/usage_processor.py +++ b/python/sgl_jax/srt/entrypoints/openai/usage_processor.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Mapping, Optional, final +from collections.abc import Mapping +from typing import Any, final from sgl_jax.srt.entrypoints.openai.protocol import UsageInfo @@ -10,13 +11,13 @@ class UsageProcessor: """Stateless helpers that turn raw token counts into a UsageInfo.""" @staticmethod - def _details_if_cached(count: int) -> Optional[Dict[str, int]]: + def _details_if_cached(count: int) -> dict[str, int] | None: """Return {"cached_tokens": N} only when N > 0 (keeps JSON slim).""" return {"cached_tokens": count} if count > 0 else None @staticmethod def calculate_response_usage( - responses: List[Dict[str, Any]], + responses: list[dict[str, Any]], n_choices: int = 1, enable_cache_report: bool = False, ) -> UsageInfo: @@ -70,7 +71,7 @@ def calculate_streaming_usage( def calculate_token_usage( prompt_tokens: int, completion_tokens: int, - cached_tokens: Optional[Dict[str, int]] = None, + cached_tokens: dict[str, int] | None = None, ) -> UsageInfo: """Calculate token usage information""" return UsageInfo( diff --git a/python/sgl_jax/srt/entrypoints/openai/utils.py b/python/sgl_jax/srt/entrypoints/openai/utils.py index 7fb4a665f..478fa2640 100644 --- a/python/sgl_jax/srt/entrypoints/openai/utils.py +++ b/python/sgl_jax/srt/entrypoints/openai/utils.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any from sgl_jax.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, @@ -48,12 +48,9 @@ def append_top_logprobs(top_logprobs): def process_hidden_states_from_ret( - ret_item: Dict[str, Any], - request: Union[ - ChatCompletionRequest, - CompletionRequest, - ], -) -> Optional[List]: + ret_item: dict[str, Any], + request: ChatCompletionRequest | CompletionRequest, +) -> list | None: """Process hidden states from a ret item in non-streaming response. Args: diff --git a/python/sgl_jax/srt/function_call/function_call_parser.py b/python/sgl_jax/srt/function_call/function_call_parser.py index 6744c7b56..9aa2bfbf8 100644 --- a/python/sgl_jax/srt/function_call/function_call_parser.py +++ b/python/sgl_jax/srt/function_call/function_call_parser.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -13,11 +13,11 @@ class FunctionCallParser: """Parser for handling function calls and tool calls.""" - def __init__(self, parser_type: Optional[str] = None): + def __init__(self, parser_type: str | None = None): self.parser_type = parser_type or "default" - logger.info(f"Initialized FunctionCallParser with type: {self.parser_type}") + logger.info("Initialized FunctionCallParser with type: %s", self.parser_type) - def parse_function_call(self, text: str) -> Optional[Dict[str, Any]]: + def parse_function_call(self, text: str) -> dict[str, Any] | None: """ Parse function call from text. 
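The logger calls in this file, like the rest of the patch, move from f-strings to %-style arguments so that formatting is deferred to the logging framework and only happens when a record is actually emitted. A minimal sketch of both forms, assuming a module-level logger; demo is a made-up function name:

    import logging

    logger = logging.getLogger(__name__)

    def demo(parser_type: str) -> None:
        # Eager: the message is formatted even when INFO is disabled.
        logger.info(f"Initialized FunctionCallParser with type: {parser_type}")
        # Lazy: interpolation happens only if a handler processes the record.
        logger.info("Initialized FunctionCallParser with type: %s", parser_type)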
@@ -38,10 +38,10 @@ def parse_function_call(self, text: str) -> Optional[Dict[str, Any]]: return json.loads(json_str) return None except (json.JSONDecodeError, ValueError) as e: - logger.debug(f"Failed to parse function call: {e}") + logger.debug("Failed to parse function call: %s", e) return None - def extract_tool_calls(self, text: str) -> List[Dict[str, Any]]: + def extract_tool_calls(self, text: str) -> list[dict[str, Any]]: """ Extract tool calls from text. diff --git a/python/sgl_jax/srt/hf_transformers_utils.py b/python/sgl_jax/srt/hf_transformers_utils.py index f169a200e..fbc9a653d 100644 --- a/python/sgl_jax/srt/hf_transformers_utils.py +++ b/python/sgl_jax/srt/hf_transformers_utils.py @@ -4,7 +4,6 @@ import os import warnings from pathlib import Path -from typing import Dict, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( @@ -21,7 +20,7 @@ from sgl_jax.srt.utils.common_utils import is_remote_url, lru_cache_frozenset -_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {} +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {} for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): @@ -51,10 +50,8 @@ def get_hf_text_config(config: PretrainedConfig): # qwen2.5 omni thinker_config = config.thinker_config if hasattr(thinker_config, "text_config"): - setattr( - thinker_config.text_config, - "torch_dtype", - getattr(thinker_config, "torch_dtype", None), + thinker_config.text_config.torch_dtype = getattr( + thinker_config, "torch_dtype", None ) return thinker_config.text_config return thinker_config @@ -66,8 +63,8 @@ def get_hf_text_config(config: PretrainedConfig): def get_config( model: str, trust_remote_code: bool, - revision: Optional[str] = None, - model_override_args: Optional[dict] = None, + revision: str | None = None, + model_override_args: dict | None = None, **kwargs, ): is_gguf = check_gguf_file(model) @@ -89,7 +86,7 @@ def get_config( config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`. - setattr(config, "_name_or_path", model) + config._name_or_path = model if isinstance(model, str) and config.model_type == "internvl_chat": for key, val in config.llm_config.__dict__.items(): @@ -116,7 +113,7 @@ def get_config( def get_generation_config( model: str, trust_remote_code: bool, - revision: Optional[str] = None, + revision: str | None = None, **kwargs, ): try: @@ -169,9 +166,9 @@ def get_tokenizer( *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, - tokenizer_revision: Optional[str] = None, + tokenizer_revision: str | None = None, **kwargs, -) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: +) -> PreTrainedTokenizer | PreTrainedTokenizerFast: """Gets a tokenizer for the given model name via Huggingface.""" if tokenizer_mode == "slow": if kwargs.get("use_fast", False): @@ -228,8 +225,8 @@ def get_tokenizer( if not isinstance(tokenizer, PreTrainedTokenizerFast): warnings.warn( - "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead." + "Using a slow tokenizer. This might cause a significant slowdown. 
Consider using a fast tokenizer instead.", + stacklevel=2, ) attach_additional_stop_token_ids(tokenizer) @@ -248,8 +245,8 @@ def get_processor( *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, - tokenizer_revision: Optional[str] = None, - use_fast: Optional[bool] = True, + tokenizer_revision: str | None = None, + use_fast: bool | None = True, **kwargs, ): # pop 'revision' from kwargs if present. @@ -263,9 +260,8 @@ def get_processor( ) # fix: for Qwen2-VL model, inject default 'size' if not provided. - if config.model_type in {"qwen2_vl"}: - if "size" not in kwargs: - kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} + if config.model_type in {"qwen2_vl"} and "size" not in kwargs: + kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} if config.model_type not in {"llava", "clip"}: kwargs["use_fast"] = use_fast @@ -294,7 +290,7 @@ def attach_additional_stop_token_ids(tokenizer): tokenizer.additional_stop_token_ids = None -def check_gguf_file(model: Union[str, os.PathLike]) -> bool: +def check_gguf_file(model: str | os.PathLike) -> bool: """Check if the file is a GGUF model.""" model = Path(model) if not model.is_file(): diff --git a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py index bb7fce216..39fb6cc58 100644 --- a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py +++ b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py @@ -460,7 +460,7 @@ def wait_send_bo(bo_sem_idx): old_seq_idx = bo_ids_ref[bo_sem_idx] old_bo_idx = bo_ids_ref[bo_sem_idx + 2] - @pl.when(jnp.logical_and(0 <= old_seq_idx, old_seq_idx <= seq_idx)) + @pl.when(jnp.logical_and(old_seq_idx >= 0, old_seq_idx <= seq_idx)) def _(): _send_bo(old_seq_idx, old_bo_idx, bo_sem_idx, wait=True) @@ -989,8 +989,7 @@ def static_validate_inputs_fused( if actual_num_q_heads % actual_num_kv_heads != 0: raise ValueError( - f"Expected {actual_num_q_heads=} to be divisible by" - f" {actual_num_kv_heads=}." + f"Expected {actual_num_q_heads=} to be divisible by {actual_num_kv_heads=}." 
) # Validate fused KV cache @@ -1039,8 +1038,7 @@ def static_validate_inputs_fused( if not (len(kv_lens.shape) == len(page_indices.shape) == len(cu_q_lens.shape) == 1): raise ValueError( - f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=}," - f" {cu_q_lens.shape=}" + f"Expected 1D array for {kv_lens.shape=}, {page_indices.shape=}, {cu_q_lens.shape=}" ) max_num_seqs = kv_lens.shape[0] @@ -1057,12 +1055,10 @@ def static_validate_inputs_fused( raise ValueError(f"{soft_cap=} must not be 0.0.") if chunk_prefill_size is not None and chunk_prefill_size <= 0: raise ValueError(f"{chunk_prefill_size=} must be positive.") - if num_kv_pages_per_block is not None: - if num_kv_pages_per_block <= 0: - raise ValueError(f"{num_kv_pages_per_block=} must be positive.") - if num_queries_per_block is not None: - if num_queries_per_block <= 0: - raise ValueError(f"{num_queries_per_block=} must be positive.") + if num_kv_pages_per_block is not None and num_kv_pages_per_block <= 0: + raise ValueError(f"{num_kv_pages_per_block=} must be positive.") + if num_queries_per_block is not None and num_queries_per_block <= 0: + raise ValueError(f"{num_queries_per_block=} must be positive.") if vmem_limit_bytes is not None and vmem_limit_bytes <= 0: raise ValueError(f"{vmem_limit_bytes=} must be positive.") diff --git a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py index a99fb8169..1ff0b5ff7 100644 --- a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py +++ b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py @@ -142,9 +142,9 @@ ("bfloat16", "bfloat16", 4, 4, 128, 256, 64): (8, 16), ("bfloat16", "bfloat16", 4, 4, 128, 256, 128): (8, 4), ("bfloat16", "bfloat16", 4, 4, 128, 256, 256): (8, 1), - ("bfloat16", "bfloat16", 4, 4, 128, 256, 512): (4, 128), + ("bfloat16", "bfloat16", 4, 4, 128, 256, 512): (1, 128), ("bfloat16", "bfloat16", 4, 4, 128, 256, 1024): (4, 128), - ("bfloat16", "bfloat16", 4, 4, 128, 256, 2048): (1, 128), + ("bfloat16", "bfloat16", 4, 4, 128, 256, 2048): (4, 128), ("bfloat16", "bfloat16", 4, 4, 128, 256, 4096): (8, 128), ("bfloat16", "bfloat16", 4, 4, 128, 256, 8192): (8, 128), ("bfloat16", "bfloat16", 8, 2, 128, 64, 1): (1, 16), @@ -688,9 +688,8 @@ def get_tuned_block_sizes( # TPUv4 has much smaller VMEM size so we pick fixed block sizes. 
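The joined error messages above keep the f-string `=` debug specifier, which renders both the expression and its value; only the adjacent string literals were merged into one. A short self-contained illustration; num_q_heads and num_kv_heads are made-up stand-ins for the kernel's variables:

    num_q_heads, num_kv_heads = 12, 5
    if num_q_heads % num_kv_heads != 0:
        # Renders: Expected num_q_heads=12 to be divisible by num_kv_heads=5.
        raise ValueError(f"Expected {num_q_heads=} to be divisible by {num_kv_heads=}.")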
bkv_p, bq = (512 // page_size, 32) else: - if device_name in TUNED_BLOCK_SIZES: - if keys in TUNED_BLOCK_SIZES[device_name]: - bkv_p, bq = TUNED_BLOCK_SIZES[device_name][keys] + if device_name in TUNED_BLOCK_SIZES and keys in TUNED_BLOCK_SIZES[device_name]: + bkv_p, bq = TUNED_BLOCK_SIZES[device_name][keys] return (min(pages_per_seq, bkv_p), min(max_num_tokens, bq)) diff --git a/python/sgl_jax/srt/layers/attention/flashattention_backend.py b/python/sgl_jax/srt/layers/attention/flashattention_backend.py index 608cbfffb..b3a1d2a9f 100644 --- a/python/sgl_jax/srt/layers/attention/flashattention_backend.py +++ b/python/sgl_jax/srt/layers/attention/flashattention_backend.py @@ -206,10 +206,9 @@ def __call__( forward_batch, token_to_kv_pool, layer.layer_id ) - if layer.scaling is None: - scale = 1.0 / jnp.sqrt(layer.head_dim) - else: - scale = layer.scaling + scale = ( + 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling + ) # Prepare fused KV cache for paged format: [num_pages, page_size, num_kv_heads * 2, head_dim] total_tokens = kv_cache_fused.shape[0] diff --git a/python/sgl_jax/srt/layers/attention/native_backend.py b/python/sgl_jax/srt/layers/attention/native_backend.py index 9eeee07ff..3193c2c64 100644 --- a/python/sgl_jax/srt/layers/attention/native_backend.py +++ b/python/sgl_jax/srt/layers/attention/native_backend.py @@ -1,5 +1,3 @@ -from typing import Tuple - import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class @@ -64,17 +62,14 @@ def __call__( k, v, forward_batch, token_to_kv_pool, layer.layer_id ) - if layer.scaling is None: - scale = 1.0 / jnp.sqrt(layer.head_dim) - else: - scale = layer.scaling + scale = ( + 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling + ) - is_causal = True - if ( + is_causal = not ( forward_batch.forward_mode == ForwardMode.DECODE or layer.attn_type == AttentionType.ENCODER_ONLY - ): - is_causal = False + ) attn_output = forward_attention( q, @@ -101,7 +96,7 @@ def _get_and_update_kv_cache( forward_batch: ForwardBatch, token_to_kv_pool: KVCache, layer_id: int, - ) -> Tuple[jax.Array, jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array, jax.Array]: """ Get the kv cache from the forward batch. 
""" diff --git a/python/sgl_jax/srt/layers/embeddings.py b/python/sgl_jax/srt/layers/embeddings.py index be4758541..c8de05835 100644 --- a/python/sgl_jax/srt/layers/embeddings.py +++ b/python/sgl_jax/srt/layers/embeddings.py @@ -15,7 +15,7 @@ """Embedding Layers.""" import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import jax import jax.numpy as jnp @@ -40,7 +40,7 @@ def __init__( self, num_embeddings: int, features: int, - dtype: Optional[jnp.dtype] = None, + dtype: jnp.dtype | None = None, param_dtype: jnp.dtype = jnp.bfloat16, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: nnx.Rngs = None, @@ -125,7 +125,7 @@ def __init__( self, num_embeddings: int, features: int, - dtype: Optional[jnp.dtype] = None, + dtype: jnp.dtype | None = None, param_dtype: jnp.dtype = jnp.bfloat16, promote_dtype: PromoteDtypeFn = dtypes.promote_dtype, rngs: nnx.Rngs = None, @@ -204,7 +204,7 @@ def __call__( positions: jax.Array, query: jax.Array, key: jax.Array, - ) -> Tuple[jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array]: positions = positions.flatten() # [num_tokens] inv_freq = jnp.asarray(self._inv_freq_np, dtype=self.dtype) @@ -232,7 +232,7 @@ def __call__( return query, key - def _compute_inv_freq(self, base: Union[int, float]) -> jax.Array: + def _compute_inv_freq(self, base: int | float) -> jax.Array: """Compute the inverse frequency.""" inv_freq = 1.0 / ( base @@ -251,7 +251,6 @@ def _compute_cos_sin_cache(self) -> jax.Array: class Llama3RotaryEmbedding(RotaryEmbedding): - def __init__( self, head_size: int, @@ -273,7 +272,7 @@ def __init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) - def _compute_inv_freq(self, base: Union[int, float]) -> jax.Array: + def _compute_inv_freq(self, base: int | float) -> jax.Array: inv_freqs = super()._compute_inv_freq(base) low_freq_wavelen = self.orig_max_position / self.low_freq_factor high_freq_wavelen = self.orig_max_position / self.high_freq_factor @@ -306,7 +305,7 @@ def rotary_embedding_forward( rotary_dim: int, head_size: int, is_neox_style: bool, -) -> Tuple[jax.Array, jax.Array]: +) -> tuple[jax.Array, jax.Array]: """Rotary Position Embedding.""" positions = positions.flatten() num_tokens = positions.shape[0] @@ -360,7 +359,7 @@ def _apply_rotary_emb( return stacked.reshape(*stacked.shape[:-2], -1) -_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} +_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} def get_rope( @@ -369,10 +368,10 @@ def get_rope( max_position: int, base: int, is_neox_style: bool = True, - rope_scaling: Optional[Dict[str, Any]] = None, - dtype: Optional[jnp.dtype] = jnp.bfloat16, + rope_scaling: dict[str, Any] | None = None, + dtype: jnp.dtype | None = jnp.bfloat16, partial_rotary_factor: float = 1.0, - dual_chunk_attention_config: Optional[Dict[str, Any]] = None, + dual_chunk_attention_config: dict[str, Any] | None = None, ) -> RotaryEmbedding: if rope_scaling is not None: # Transforms every value that is a list into a tuple for caching calls diff --git a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py index a33544582..030e305db 100644 --- a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py +++ b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py @@ -1,7 +1,7 @@ import functools from collections.abc import Callable from functools import partial -from typing import Any, Optional +from typing import Any import jax import jax.numpy as jnp @@ -28,8 +28,7 @@ def _validate_args( # 
Validate 'rhs'. if rhs.ndim != expected_rhs_dims: raise ValueError( - f"Expected {expected_rhs_dims}-tensor for 'rhs' but got" - f" {rhs.ndim}-tensor." + f"Expected {expected_rhs_dims}-tensor for 'rhs' but got {rhs.ndim}-tensor." ) common.assert_is_supported_dtype(rhs.dtype) @@ -275,7 +274,7 @@ def _zero_uninitialized_memory( return jnp.where(valid_mask[:, None], out, 0) -LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]] +LutFn = Callable[[int, int, int], tuple[int, int, int] | None] @functools.partial( @@ -363,10 +362,9 @@ def gmm( visit_empty_groups=False, ) - if transpose_rhs: - dot_general_dims = (((1,), (1,)), ((), ())) - else: - dot_general_dims = (((1,), (0,)), ((), ())) + dot_general_dims = ( + (((1,), (1,)), ((), ())) if transpose_rhs else (((1,), (0,)), ((), ())) + ) def kernel( group_metadata, diff --git a/python/sgl_jax/srt/layers/linear.py b/python/sgl_jax/srt/layers/linear.py index b847d93ae..5e1a9f4b0 100644 --- a/python/sgl_jax/srt/layers/linear.py +++ b/python/sgl_jax/srt/layers/linear.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Sequence, Tuple +from collections.abc import Iterable, Sequence import jax from flax import nnx @@ -12,7 +12,7 @@ def _canonicalize_tuple(x): return (x,) -def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]: +def _normalize_axes(axes: Iterable[int], ndim: int) -> tuple[int, ...]: return tuple(ax if ax >= 0 else ndim + ax for ax in axes) @@ -35,8 +35,8 @@ def __init__( output_size: int, use_bias: bool = True, skip_bias_add: bool = False, - params_dtype: Optional[jnp.dtype] = jnp.bfloat16, - kernel_axes: Optional[Sequence[str]] = None, + params_dtype: jnp.dtype | None = jnp.bfloat16, + kernel_axes: Sequence[str] | None = None, rngs: nnx.Rngs = None, ): """Initialize parameters and quantization method.""" @@ -55,7 +55,7 @@ def __init__( else: self.bias = None - def __call__(self, x: jax.Array) -> Tuple[jax.Array, Optional[jax.Array]]: + def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]: """Forward pass of the linear layer.""" bias = self.bias if not self.skip_bias_add else None # Access the underlying JAX array using .value property diff --git a/python/sgl_jax/srt/layers/logits_processor.py b/python/sgl_jax/srt/layers/logits_processor.py index e00daab12..521fb1212 100644 --- a/python/sgl_jax/srt/layers/logits_processor.py +++ b/python/sgl_jax/srt/layers/logits_processor.py @@ -1,5 +1,4 @@ import dataclasses -from typing import List, Optional import jax import jax.nn as nn @@ -24,27 +23,27 @@ class LogitsProcessorOutput: next_token_logits: jax.Array # Used by speculative decoding (EAGLE) # The last hidden layers - hidden_states: Optional[jax.Array] = None + hidden_states: jax.Array | None = None ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler # The logprobs of the next tokens. shape: [#seq] - next_token_logprobs: Optional[jax.Array] = None + next_token_logprobs: jax.Array | None = None # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] - next_token_top_logprobs_val: Optional[List] = None - next_token_top_logprobs_idx: Optional[List] = None + next_token_top_logprobs_val: list | None = None + next_token_top_logprobs_idx: list | None = None # The logprobs and ids of the requested token ids in output positions. 
shape: [#seq, n] (n is the number of requested token ids) - next_token_token_ids_logprobs_val: Optional[List] = None - next_token_token_ids_logprobs_idx: Optional[List] = None + next_token_token_ids_logprobs_val: list | None = None + next_token_token_ids_logprobs_idx: list | None = None ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor # The logprobs of input tokens. shape: [#token] - input_token_logprobs: Optional[jax.Array] = None + input_token_logprobs: jax.Array | None = None # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] - input_top_logprobs_val: List = None - input_top_logprobs_idx: List = None + input_top_logprobs_val: list = None + input_top_logprobs_idx: list = None # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids) - input_token_ids_logprobs_val: Optional[List] = None - input_token_ids_logprobs_idx: Optional[List] = None + input_token_ids_logprobs_val: list | None = None + input_token_ids_logprobs_idx: list | None = None def tree_flatten(self): children = ( @@ -107,13 +106,13 @@ class LogitsMetadata: extend_return_logprob: bool = False extend_return_top_logprob: bool = False extend_token_ids_logprob: bool = False - extend_seq_lens: Optional[jax.Array] = None - extend_seq_lens_cpu: Optional[List[int]] = None - extend_logprob_start_lens_cpu: Optional[List[int]] = None - extend_logprob_pruned_lens_cpu: Optional[List[int]] = None - top_logprobs_nums: Optional[List[int]] = None - extend_input_logprob_token_ids_device: Optional[jax.Array] = None - token_ids_logprobs: Optional[List[List[int]]] = None + extend_seq_lens: jax.Array | None = None + extend_seq_lens_cpu: list[int] | None = None + extend_logprob_start_lens_cpu: list[int] | None = None + extend_logprob_pruned_lens_cpu: list[int] | None = None + top_logprobs_nums: list[int] | None = None + extend_input_logprob_token_ids_device: jax.Array | None = None + token_ids_logprobs: list[list[int]] | None = None # logits and logprobs post processing temp_scaled_logprobs: bool = False @@ -293,7 +292,7 @@ def __call__( logits[sample_indices] if sample_indices is not None else logits ) - hidden_states_to_store: Optional[jax.Array] = None + hidden_states_to_store: jax.Array | None = None if logits_metadata.capture_hidden_mode.need_capture(): if logits_metadata.capture_hidden_mode.is_full(): hidden_states_to_store = hidden_states @@ -306,7 +305,7 @@ def __call__( else pruned_states ) else: - assert False, "Should never reach" + raise AssertionError() if not logits_metadata.extend_return_logprob: # Decode mode or extend mode without return_logprob. 
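
Note: the dataclass-field rewrites above are the core pattern of this patch. Ruff's pyupgrade rules replace `typing.Optional`, `typing.List`, and `typing.Dict` with the built-in generics and `X | None` unions from PEP 585/604. A minimal before/after sketch with a hypothetical function, not code from this repository:

    # Previously: def lookup(table: Dict[str, List[int]], key: str) -> Optional[List[int]]
    # After the pyupgrade fixes, the same signature uses built-in generics and "|":
    def lookup(table: dict[str, list[int]], key: str) -> list[int] | None:
        return table.get(key)

The same hunk also swaps `assert False, "Should never reach"` for `raise AssertionError()`, since assertions are stripped under `python -O` and should not guard unreachable branches.
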
diff --git a/python/sgl_jax/srt/layers/moe.py b/python/sgl_jax/srt/layers/moe.py index c4edba147..e220cf657 100644 --- a/python/sgl_jax/srt/layers/moe.py +++ b/python/sgl_jax/srt/layers/moe.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Sequence, Tuple, Union +from collections.abc import Iterable, Sequence import jax from flax import nnx @@ -32,19 +32,18 @@ class GateLogit(nnx.Module): def __init__( self, input_size: int, - features: Union[Iterable[int], int], + features: Iterable[int] | int, model_name: str, - axis: Union[Iterable[int], int] = -1, + axis: Iterable[int] | int = -1, weight_dtype: jnp.dtype = jnp.float32, dtype: jnp.dtype = jnp.float32, - kernel_axes: Optional[Sequence[str]] = None, + kernel_axes: Sequence[str] | None = None, use_bias: bool = False, score_func: str = "", matmul_precision: str = "default", layer_id: int = 0, rngs: nnx.Rngs = None, ): - self.features = linear._canonicalize_tuple(features) self.axis = linear._canonicalize_tuple(axis) self.model_name = model_name @@ -78,7 +77,7 @@ def __init__( else: self.bias = None - def __call__(self, inputs: jax.Array) -> Tuple[jax.Array, Optional[jax.Array]]: + def __call__(self, inputs: jax.Array) -> tuple[jax.Array, jax.Array | None]: inputs = jnp.asarray(inputs, self.dtype) kernel = jnp.asarray(self.kernel.value, self.dtype) @@ -113,7 +112,6 @@ def __init__( layer_id: int = 0, rngs: nnx.Rngs = None, ): - self.config = config self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok diff --git a/python/sgl_jax/srt/layers/sampler.py b/python/sgl_jax/srt/layers/sampler.py index 28de0e167..5d5ba642d 100644 --- a/python/sgl_jax/srt/layers/sampler.py +++ b/python/sgl_jax/srt/layers/sampler.py @@ -1,5 +1,3 @@ -from typing import List - import jax import numpy as np from flax import nnx @@ -209,7 +207,7 @@ def __call__( return batch_next_token_ids -def get_top_logprobs(logprobs: jax.Array, top_logprobs_nums: List[int]): +def get_top_logprobs(logprobs: jax.Array, top_logprobs_nums: list[int]): max_k = max(top_logprobs_nums) values, indices = jax.lax.top_k(logprobs, max_k) values = values.tolist() @@ -223,7 +221,7 @@ def get_top_logprobs(logprobs: jax.Array, top_logprobs_nums: List[int]): return output_top_logprobs_val, output_top_logprobs_idx -def get_token_ids_logprobs(logprobs: jax.Array, token_ids_logprobs: List[List[int]]): +def get_token_ids_logprobs(logprobs: jax.Array, token_ids_logprobs: list[list[int]]): output_token_ids_logprobs_val = [] output_token_ids_logprobs_idx = [] for i, token_ids in enumerate(token_ids_logprobs): diff --git a/python/sgl_jax/srt/managers/detokenizer_manager.py b/python/sgl_jax/srt/managers/detokenizer_manager.py index 84f32d38f..526aad7da 100644 --- a/python/sgl_jax/srt/managers/detokenizer_manager.py +++ b/python/sgl_jax/srt/managers/detokenizer_manager.py @@ -6,7 +6,6 @@ import signal import threading from collections import OrderedDict -from typing import Dict, List, Optional, Union import psutil import setproctitle @@ -40,7 +39,7 @@ class DecodeStatus: """Store the status of incremental decoding.""" decoded_text: str - decode_ids: List[int] + decode_ids: list[int] surr_offset: int read_offset: int # Offset that's sent to tokenizer for incremental update. 
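
A related cleanup, visible in the linear.py and moe.py hunks above, moves the abstract container types (`Iterable`, `Sequence`, `Callable`) from `typing` to `collections.abc`, their canonical home; the `typing` aliases have been deprecated since Python 3.9 (PEP 585). A small illustrative sketch with a hypothetical helper, not repository code:

    from collections.abc import Iterable, Sequence

    def weighted_total(values: Iterable[int], weights: Sequence[float] | None = None) -> float:
        # Hypothetical helper showing the preferred import locations.
        if weights is None:
            return float(sum(values))
        return float(sum(v * w for v, w in zip(values, weights)))
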
@@ -93,12 +92,12 @@ def event_loop(self): self.send_to_tokenizer.send_pyobj(output) def trim_matched_stop( - self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool + self, output: str | list[int], finished_reason: dict, no_stop_trim: bool ): if no_stop_trim or not finished_reason: return output - matched = finished_reason.get("matched", None) + matched = finished_reason.get("matched") if not matched: return output @@ -228,7 +227,7 @@ def deep_flatten(lst): for i in range(bs): try: s = self.decode_status[recv_obj.rids[i]] - except KeyError: + except KeyError as e: raise RuntimeError( f"Decode status not found for request {recv_obj.rids[i]}. " "It may be due to the request being evicted from the decode status due to memory pressure. " @@ -236,7 +235,7 @@ def deep_flatten(lst): "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " f"The current value is {DETOKENIZER_MAX_STATES}. " "For more details, see: https://github.com/sgl-project/sglang/issues/2812" - ) + ) from e new_text = read_texts[i][len(surr_texts[i]) :] new_token_ids = read_ids[i][len(surr_ids[i]) :] if recv_obj.finished_reasons[i] is None: @@ -293,9 +292,9 @@ def deep_flatten(lst): def process_special_tokens_spaces( - token_ids: Optional[List[int]] = None, - skip_special_tokens: Optional[bool] = None, - all_special_ids: Optional[List[int]] = None, + token_ids: list[int] | None = None, + skip_special_tokens: bool | None = None, + all_special_ids: list[int] | None = None, ): if all_special_ids is None or not skip_special_tokens or token_ids is None: return token_ids @@ -329,7 +328,7 @@ def run_detokenizer_process( manager.event_loop() except Exception: traceback = get_exception_traceback() - logger.error(f"DetokenizerManager hit an exception: {traceback}") + logger.error("DetokenizerManager hit an exception: %s", traceback) parent_process.send_signal(signal.SIGQUIT) @@ -347,5 +346,5 @@ def run_detokenizer_thread( manager.event_loop() except Exception: traceback = get_exception_traceback() - logger.error(f"DetokenizerManager hit an exception: {traceback}") + logger.error("DetokenizerManager hit an exception: %s", traceback) current_process.send_signal(signal.SIGQUIT) diff --git a/python/sgl_jax/srt/managers/io_struct.py b/python/sgl_jax/srt/managers/io_struct.py index 1e9f48aaf..dd63aff44 100644 --- a/python/sgl_jax/srt/managers/io_struct.py +++ b/python/sgl_jax/srt/managers/io_struct.py @@ -3,7 +3,7 @@ import uuid from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any from sgl_jax.srt.managers.schedule_batch import BaseFinishReason @@ -11,35 +11,35 @@ @dataclass class BatchStrOut: # The request id - rids: List[str] + rids: list[str] # The finish reason - finished_reasons: List[dict] + finished_reasons: list[dict] # The output decoded strings - output_strs: List[str] + output_strs: list[str] # The token ids - output_ids: Optional[List[int]] + output_ids: list[int] | None # Token counts - prompt_tokens: List[int] - completion_tokens: List[int] - cached_tokens: List[int] + prompt_tokens: list[int] + completion_tokens: list[int] + cached_tokens: list[int] # Logprobs - input_token_logprobs_val: List[float] - input_token_logprobs_idx: List[int] - output_token_logprobs_val: List[float] - output_token_logprobs_idx: List[int] - input_top_logprobs_val: List[List] - input_top_logprobs_idx: List[List] - output_top_logprobs_val: List[List] - output_top_logprobs_idx: List[List] - 
input_token_ids_logprobs_val: List[List] - input_token_ids_logprobs_idx: List[List] - output_token_ids_logprobs_val: List[List] - output_token_ids_logprobs_idx: List[List] + input_token_logprobs_val: list[float] + input_token_logprobs_idx: list[int] + output_token_logprobs_val: list[float] + output_token_logprobs_idx: list[int] + input_top_logprobs_val: list[list] + input_top_logprobs_idx: list[list] + output_top_logprobs_val: list[list] + output_top_logprobs_idx: list[list] + input_token_ids_logprobs_val: list[list] + input_token_ids_logprobs_idx: list[list] + output_token_ids_logprobs_val: list[list] + output_token_ids_logprobs_idx: list[list] # Hidden states - output_hidden_states: List[List[float]] + output_hidden_states: list[list[float]] # Cache miss count cache_miss_count: int = None @@ -48,41 +48,41 @@ class BatchStrOut: @dataclass class BatchTokenIDOut: # The request id - rids: List[str] + rids: list[str] # The finish reason - finished_reasons: List[BaseFinishReason] + finished_reasons: list[BaseFinishReason] # For incremental decoding - decoded_texts: List[str] - decode_ids: List[List[int]] - read_offsets: List[int] + decoded_texts: list[str] + decode_ids: list[list[int]] + read_offsets: list[int] # Only used when `--skip-tokenizer-init` is on - output_ids: Optional[List[int]] + output_ids: list[int] | None # Detokenization configs - skip_special_tokens: List[bool] - spaces_between_special_tokens: List[bool] - no_stop_trim: List[bool] + skip_special_tokens: list[bool] + spaces_between_special_tokens: list[bool] + no_stop_trim: list[bool] # Token counts - prompt_tokens: List[int] - completion_tokens: List[int] - cached_tokens: List[int] + prompt_tokens: list[int] + completion_tokens: list[int] + cached_tokens: list[int] # Logprobs - input_token_logprobs_val: List[float] - input_token_logprobs_idx: List[int] - output_token_logprobs_val: List[float] - output_token_logprobs_idx: List[int] - input_top_logprobs_val: List[List] - input_top_logprobs_idx: List[List] - output_top_logprobs_val: List[List] - output_top_logprobs_idx: List[List] - input_token_ids_logprobs_val: List[List] - input_token_ids_logprobs_idx: List[List] - output_token_ids_logprobs_val: List[List] - output_token_ids_logprobs_idx: List[List] + input_token_logprobs_val: list[float] + input_token_logprobs_idx: list[int] + output_token_logprobs_val: list[float] + output_token_logprobs_idx: list[int] + input_top_logprobs_val: list[list] + input_top_logprobs_idx: list[list] + output_top_logprobs_val: list[list] + output_top_logprobs_idx: list[list] + input_token_ids_logprobs_val: list[list] + input_token_ids_logprobs_idx: list[list] + output_token_ids_logprobs_val: list[list] + output_token_ids_logprobs_idx: list[list] # Hidden states - output_hidden_states: List[List[float]] + output_hidden_states: list[list[float]] # Cache miss count cache_miss_count: int = None @@ -91,22 +91,22 @@ class BatchTokenIDOut: @dataclass class TokenizedGenerateReqInput: # The request id. - rid: Optional[Union[List[str], str]] = None + rid: list[str] | str | None = None # The input prompt. It can be a single prompt or a batch of prompts. - text: Optional[Union[List[str], str]] = None + text: list[str] | str | None = None # The token ids for text; one can specify either text or input_ids - input_ids: Optional[Union[List[List[int]], List[int]]] = None + input_ids: list[list[int]] | list[int] | None = None # The sampling_params. See descriptions below. 
- sampling_params: Optional[Union[List[Dict], Dict]] = None + sampling_params: list[dict] | dict | None = None # Whether to return logprobs. - return_logprob: Optional[Union[List[bool], bool]] = None + return_logprob: list[bool] | bool | None = None # If return logprobs, the start location in the prompt for returning logprobs. # By default, this value is "-1", which means it will only return logprobs for output tokens. - logprob_start_len: Optional[Union[List[int], int]] = -1 + logprob_start_len: list[int] | int | None = -1 # If return logprobs, the number of top logprobs to return at each position. - top_logprobs_num: Optional[Union[List[int], int]] = None + top_logprobs_num: list[int] | int | None = None # If return logprobs, the token ids to return logprob for. - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None + token_ids_logprob: list[list[int]] | list[int] | None = None # Whether to stream output stream: bool = False @@ -126,7 +126,7 @@ class EmbeddingReqInput: rid: str = None text: str = "" - input_ids: List[int] = None + input_ids: list[int] = None normalize: bool = True @@ -135,23 +135,23 @@ class GenerateReqInput: """Request input for text generation.""" batch_size: int = 1 - rid: Optional[Union[List[str], str]] = None - text: Optional[Union[List[str], str]] = None - input_ids: List[int] = None + rid: list[str] | str | None = None + text: list[str] | str | None = None + input_ids: list[int] = None # The embeddings for input_ids; one can specify either text or input_ids or input_embeds. - input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None - sampling_params: Optional[Any] = ( + input_embeds: list[list[list[float]]] | list[list[float]] | None = None + sampling_params: Any | None = ( None # Using Any for now to avoid SamplingParams serialization issues ) stream: bool = False is_single: bool = True - return_logprob: Optional[Union[List[bool], bool]] = None + return_logprob: list[bool] | bool | None = None # If return logprobs, the start location in the prompt for returning logprobs. - logprob_start_len: Optional[Union[List[int], int]] = None + logprob_start_len: list[int] | int | None = None # If return logprobs, the number of top logprobs to return at each position. - top_logprobs_num: Optional[Union[List[int], int]] = None + top_logprobs_num: list[int] | int | None = None # If return logprobs, the token ids to return logprob for. - token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None + token_ids_logprob: list[list[int]] | list[int] | None = None # Whether to detokenize tokens in text in the returned logprobs. 
return_text_in_logprobs: bool = False @@ -387,9 +387,9 @@ class ResumeMemoryOccupationReqInput(RpcReqInput): class BatchEmbeddingOut: """Batch embedding output.""" - rids: List[str] - embeddings: List[List[float]] - prompt_tokens: List[int] + rids: list[str] + embeddings: list[list[float]] + prompt_tokens: list[int] # Request input classes for sessions @@ -413,7 +413,7 @@ class TokenizedEmbeddingReqInput: rid: str text: str - input_ids: List[int] + input_ids: list[int] normalize: bool = True @@ -423,7 +423,7 @@ class ConfigureLoggingReq(RpcReqInput): """Request to configure logging.""" log_level: str - log_file: Optional[str] = None + log_file: str | None = None @dataclass @@ -441,14 +441,14 @@ class GetInternalStateReq: @dataclass class GetInternalStateReqOutput: - internal_state: Dict[Any, Any] + internal_state: dict[Any, Any] @dataclass class SetInternalStateReq(RpcReqInput): """Request to set internal state.""" - state_data: Dict[str, Any] + state_data: dict[str, Any] # Profile classes @@ -456,19 +456,19 @@ class SetInternalStateReq(RpcReqInput): @dataclass class ProfileReqInput: - output_dir: Optional[str] = None - start_step: Optional[int] = None - num_steps: Optional[int] = None + output_dir: str | None = None + start_step: int | None = None + num_steps: int | None = None # Sets the trace level for host-side activities. # 0: Disables host (CPU) tracing entirely. # 1: Enables tracing of only user-instrumented TraceMe events (this is the default). # 2: Includes level 1 traces plus high-level program execution details like expensive XLA operations. # 3: Includes level 2 traces plus more verbose, low-level program execution details such as cheap XLA operations. - host_tracer_level: Optional[int] = None + host_tracer_level: int | None = None # Controls whether Python tracing is enabled. # 0: Disables Python function call tracing. # 1: Enables Python tracing (this is the default). 
- python_tracer_level: Optional[int] = None + python_tracer_level: int | None = None class ProfileReqType(Enum): @@ -479,12 +479,12 @@ class ProfileReqType(Enum): @dataclass class ProfileReq: type: ProfileReqType - output_dir: Optional[str] = None - start_step: Optional[int] = None - num_steps: Optional[int] = None - host_tracer_level: Optional[int] = None - python_tracer_level: Optional[int] = None - profile_id: Optional[str] = None + output_dir: str | None = None + start_step: int | None = None + num_steps: int | None = None + host_tracer_level: int | None = None + python_tracer_level: int | None = None + profile_id: str | None = None @dataclass @@ -579,8 +579,8 @@ class VertexGenerateReqInput(GenerateReqInput): class StartTraceReqInput(RpcReqInput): """Request to start precision tracing.""" - req_num: Optional[int] = None # Maximum number of requests to trace - output_file: Optional[str] = None # Output file path + req_num: int | None = None # Maximum number of requests to trace + output_file: str | None = None # Output file path request_id: str = "" # Override base class field with default def __post_init__(self): diff --git a/python/sgl_jax/srt/managers/schedule_batch.py b/python/sgl_jax/srt/managers/schedule_batch.py index 50754f359..5846877a8 100644 --- a/python/sgl_jax/srt/managers/schedule_batch.py +++ b/python/sgl_jax/srt/managers/schedule_batch.py @@ -21,7 +21,7 @@ import logging import threading from http import HTTPStatus -from typing import Any, List, Optional, Set, Tuple, Union +from typing import Any import numpy as np from jax import numpy as jnp @@ -68,7 +68,7 @@ def to_json(self): class FINISH_MATCHED_TOKEN(BaseFinishReason): - def __init__(self, matched: Union[int, List[int]]): + def __init__(self, matched: int | list[int]): super().__init__() self.matched = matched @@ -126,15 +126,15 @@ def __init__( self, rid: str, origin_input_text: str, - origin_input_ids: List[int], + origin_input_ids: list[int], sampling_params: SamplingParams, return_logprob: bool = False, top_logprobs_num: int = 0, - token_ids_logprob: List[int] = None, + token_ids_logprob: list[int] = None, stream: bool = False, - origin_input_ids_unpadded: Optional[Tuple[int]] = None, - eos_token_ids: Optional[Set[int]] = None, - vocab_size: Optional[int] = None, + origin_input_ids_unpadded: tuple[int] | None = None, + eos_token_ids: set[int] | None = None, + vocab_size: int | None = None, ): # Input and output info self.rid = rid @@ -154,7 +154,7 @@ def __init__( self.sampling_params = sampling_params # Memory pool info - self.req_pool_idx: Optional[int] = None + self.req_pool_idx: int | None = None # Check finish self.tokenizer = None @@ -219,18 +219,18 @@ def __init__( # Logprobs (return values) # True means the input logprob has been already sent to detokenizer. 
self.input_logprob_sent: bool = False - self.input_token_logprobs_val: Optional[List[float]] = None - self.input_token_logprobs_idx: Optional[List[int]] = None - self.input_top_logprobs_val: Optional[List[float]] = None - self.input_top_logprobs_idx: Optional[List[int]] = None - self.input_token_ids_logprobs_val: Optional[List[float]] = None - self.input_token_ids_logprobs_idx: Optional[List[int]] = None + self.input_token_logprobs_val: list[float] | None = None + self.input_token_logprobs_idx: list[int] | None = None + self.input_top_logprobs_val: list[float] | None = None + self.input_top_logprobs_idx: list[int] | None = None + self.input_token_ids_logprobs_val: list[float] | None = None + self.input_token_ids_logprobs_idx: list[int] | None = None # Temporary holder to store input_token_logprobs. - self.input_token_logprobs: Optional[List[Tuple[int]]] = None - self.temp_input_top_logprobs_val: Optional[List[np.ndarray]] = None - self.temp_input_top_logprobs_idx: Optional[List[int]] = None - self.temp_input_token_ids_logprobs_val: Optional[List[float]] = None - self.temp_input_token_ids_logprobs_idx: Optional[List[int]] = None + self.input_token_logprobs: list[tuple[int]] | None = None + self.temp_input_top_logprobs_val: list[np.ndarray] | None = None + self.temp_input_top_logprobs_idx: list[int] | None = None + self.temp_input_token_ids_logprobs_val: list[float] | None = None + self.temp_input_token_ids_logprobs_idx: list[int] | None = None if return_logprob: # shape: (bs, 1) @@ -247,7 +247,7 @@ def __init__( ) = self.output_top_logprobs_idx = self.output_token_ids_logprobs_val = ( self.output_token_ids_logprobs_idx ) = None - self.hidden_states: List[List[float]] = [] + self.hidden_states: list[list[float]] = [] # The number of cached tokens that were already cached in the KV cache self.cached_tokens = 0 @@ -284,7 +284,7 @@ def finished(self) -> bool: def init_next_round_input( self, - tree_cache: Optional[BasePrefixCache] = None, + tree_cache: BasePrefixCache | None = None, ): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: @@ -435,7 +435,7 @@ class ScheduleBatch: """Store all information of a batch on the scheduler.""" # Request, memory pool, and cache - reqs: List[Req] + reqs: list[Req] req_to_token_pool: ReqToTokenPool = None token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None tree_cache: BasePrefixCache = None @@ -449,7 +449,7 @@ class ScheduleBatch: # This is an optimization to reduce the overhead of the prefill check. batch_is_full: bool = False - chunked_req: Optional[Req] = None + chunked_req: Req | None = None # Sampling info sampling_info: SamplingBatchInfo = None @@ -468,21 +468,21 @@ class ScheduleBatch: # For processing logprobs return_logprob: bool = False - top_logprobs_nums: Optional[List[int]] = None - token_ids_logprobs: Optional[List[List[int]]] = None + top_logprobs_nums: list[int] | None = None + token_ids_logprobs: list[list[int]] | None = None # For logits and logprob post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False # For extend and mixed chunekd prefill - prefix_lens: List[int] = None - extend_lens: List[int] = None - extend_num_tokens: Optional[int] = None - decoding_reqs: List[Req] = None - extend_logprob_start_lens: List[int] = None + prefix_lens: list[int] = None + extend_lens: list[int] = None + extend_num_tokens: int | None = None + decoding_reqs: list[Req] = None + extend_logprob_start_lens: list[int] = None # It comes empty list if logprob is not required. 
- extend_input_logprob_token_ids: Optional[np.ndarray] = None + extend_input_logprob_token_ids: np.ndarray | None = None # Stream has_stream: bool = False @@ -499,19 +499,19 @@ class ScheduleBatch: is_prefill_only: bool = False # Events - launch_done: Optional[threading.Event] = None + launch_done: threading.Event | None = None @classmethod def init_new( cls, - reqs: List[Req], + reqs: list[Req], req_to_token_pool: ReqToTokenPool, token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator, tree_cache: BasePrefixCache, model_config: ModelConfig, enable_overlap: bool, enable_custom_logit_processor: bool = False, - chunked_req: Optional[Req] = None, + chunked_req: Req | None = None, mesh: mesh_lib.Mesh = None, ): return_logprob = any(req.return_logprob for req in reqs) @@ -575,9 +575,9 @@ def alloc_token_slots(self, num_tokens: int, backup_state: bool = False): def alloc_paged_token_slots_extend( self, - prefix_lens: List[int], - seq_lens: List[int], - last_loc: List[int], + prefix_lens: list[int], + seq_lens: list[int], + last_loc: list[int], extend_num_tokens: int, backup_state: bool = False, ): @@ -609,8 +609,8 @@ def alloc_paged_token_slots_extend( def alloc_paged_token_slots_decode( self, - seq_lens: List[int], - last_loc: List[int], + seq_lens: list[int], + last_loc: list[int], backup_state: bool = False, ): num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size @@ -635,7 +635,7 @@ def alloc_paged_token_slots_decode( else: return out_cache_loc - def mix_with_running(self, running_batch: "ScheduleBatch"): + def mix_with_running(self, running_batch: ScheduleBatch): # Use EXTEND instead of MIXED for precompile cache hit self.forward_mode = ForwardMode.EXTEND running_bs = running_batch.batch_size() @@ -977,8 +977,8 @@ def prepare_for_decode(self): def filter_batch( self, - chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, - keep_indices: Optional[List[int]] = None, + chunked_req_to_exclude: Req | list[Req] | None = None, + keep_indices: list[int] | None = None, ): if keep_indices is None: if isinstance(chunked_req_to_exclude, Req): @@ -1021,7 +1021,7 @@ def filter_batch( self.sampling_info.filter_batch(np.array(keep_indices)) - def merge_batch(self, other: "ScheduleBatch"): + def merge_batch(self, other: ScheduleBatch): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. 
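
One more pattern from the schedule_batch.py hunks above: self-referential annotations such as `running_batch: "ScheduleBatch"` lose their quotes. That rewrite is only valid when the module evaluates annotations lazily, for example via `from __future__ import annotations` (PEP 563). The sketch below uses a hypothetical class to show why, and assumes nothing about the real module beyond that:

    from __future__ import annotations

    from dataclasses import dataclass, field

    @dataclass
    class Batch:
        items: list[str] = field(default_factory=list)

        def merge(self, other: Batch) -> Batch:
            # Without postponed evaluation (Python <= 3.13), `Batch` would not yet be
            # bound when this annotation is evaluated, raising NameError at class
            # definition time.
            self.items.extend(other.items)
            return self
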
@@ -1059,7 +1059,7 @@ def get_model_worker_batch( self, token_paddings: list, bs_paddings: list, - cache_loc_paddings: List, + cache_loc_paddings: list, page_size: int, ) -> ModelWorkerBatch: if self.forward_mode.is_decode_or_idle(): @@ -1282,9 +1282,8 @@ def get_model_worker_batch( launch_done=self.launch_done, ) - def _generate_trace_info(self, real_bs: int, bid: int) -> List[str]: + def _generate_trace_info(self, real_bs: int, bid: int) -> list[str]: for req in self.reqs[:real_bs]: - if precision_tracer.get_trace_active(): # for chunked prefill trace if req.fill_ids: @@ -1304,7 +1303,9 @@ def _generate_trace_info(self, real_bs: int, bid: int) -> List[str]: if self.forward_mode == ForwardMode.EXTEND: precision_tracer.add_request_counter() logger.info( - f"Starting trace for request {precision_tracer.get_request_counter()}: {req.rid}" + "Starting trace for request %d: %s", + precision_tracer.get_request_counter(), + req.rid, ) def copy(self): @@ -1326,9 +1327,11 @@ def _evict_tree_cache_if_needed( if isinstance(self.tree_cache, ChunkCache): return - if self.token_to_kv_pool_allocator.available_size() < num_tokens: - if self.tree_cache is not None: - self.tree_cache.evict(num_tokens) + if ( + self.token_to_kv_pool_allocator.available_size() < num_tokens + and self.tree_cache is not None + ): + self.tree_cache.evict(num_tokens) def _is_available_size_sufficient(self, num_tokens: int) -> bool: return self.token_to_kv_pool_allocator.available_size() >= num_tokens @@ -1371,15 +1374,15 @@ class ModelWorkerBatch: # For logprob return_logprob: bool - top_logprobs_nums: Optional[List[int]] - token_ids_logprobs: Optional[List[List[int]]] + top_logprobs_nums: list[int] | None + token_ids_logprobs: list[list[int]] | None # For extend # extend_num_tokens: Optional[int] - extend_seq_lens: Optional[np.ndarray] - extend_prefix_lens: Optional[np.ndarray] - extend_logprob_start_lens: Optional[List[int]] - extend_input_logprob_token_ids: Optional[np.ndarray] + extend_seq_lens: np.ndarray | None + extend_prefix_lens: np.ndarray | None + extend_logprob_start_lens: list[int] | None + extend_input_logprob_token_ids: np.ndarray | None # For padding real_bs: int @@ -1393,10 +1396,10 @@ class ModelWorkerBatch: top_p: np.ndarray = None # Events - launch_done: Optional[threading.Event] = None + launch_done: threading.Event | None = None # Pre-initialized ForwardBatch for overlap scheduling optimization - forward_batch: Optional[Any] = None + forward_batch: Any | None = None def get_last_loc( diff --git a/python/sgl_jax/srt/managers/schedule_policy.py b/python/sgl_jax/srt/managers/schedule_policy.py index 822d67e4f..dcc174eec 100644 --- a/python/sgl_jax/srt/managers/schedule_policy.py +++ b/python/sgl_jax/srt/managers/schedule_policy.py @@ -5,7 +5,7 @@ from collections import defaultdict from contextlib import contextmanager from enum import Enum, auto -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING from jax import numpy as jnp @@ -59,7 +59,7 @@ class CacheAgnosticPolicy(Enum): class SchedulePolicy: - Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy] + Policy = CacheAwarePolicy | CacheAgnosticPolicy def __init__( self, @@ -77,7 +77,7 @@ def __init__( disable=False, ) - def calc_priority(self, waiting_queue: List[Req]) -> bool: + def calc_priority(self, waiting_queue: list[Req]) -> bool: if self.policy == CacheAgnosticPolicy.FCFS: # A shortcut for FCFS return False @@ -110,7 +110,7 @@ def calc_priority(self, waiting_queue: List[Req]) -> bool: return 
prefix_computed - def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy: + def _determine_active_policy(self, waiting_queue: list[Req]) -> Policy: if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128: # Turn off the expensive prefix matching and sorting when the #queue is large. return CacheAgnosticPolicy.FCFS @@ -131,17 +131,17 @@ def _validate_and_adjust_policy( except ValueError: try: return CacheAgnosticPolicy(policy) - except ValueError: - raise ValueError(f"Unknown schedule_policy: {policy=}") + except ValueError as inner_err: + raise ValueError(f"Unknown schedule_policy: {policy=}") from inner_err def _compute_prefix_matches( - self, waiting_queue: List[Req], policy: CacheAwarePolicy - ) -> Set[int]: + self, waiting_queue: list[Req], policy: CacheAwarePolicy + ) -> set[int]: """ Computes and caches the matching prefixes for requests in the waiting queue, and handles in-batch prefix caching logic. """ - temporary_deprioritized: Set[int] = set() + temporary_deprioritized: set[int] = set() self.waiting_queue_radix_tree.reset() for r in waiting_queue: @@ -179,7 +179,7 @@ def _compute_prefix_matches( @staticmethod def _sort_by_longest_prefix( - waiting_queue: List[Req], temporary_deprioritized: Set[int] + waiting_queue: list[Req], temporary_deprioritized: set[int] ) -> None: """Sorts the waiting queue based on the longest prefix match.""" waiting_queue.sort( @@ -192,7 +192,7 @@ def _sort_by_longest_prefix( @staticmethod def _sort_by_dfs_weight( - waiting_queue: List[Req], tree_cache: BasePrefixCache + waiting_queue: list[Req], tree_cache: BasePrefixCache ) -> None: """Sorts the waiting queue based on a depth-first search weighting.""" last_node_to_reqs = defaultdict(list) @@ -213,17 +213,17 @@ def _sort_by_dfs_weight( ) @staticmethod - def _sort_by_longest_output(waiting_queue: List[Req]) -> None: + def _sort_by_longest_output(waiting_queue: list[Req]) -> None: """Sorts the waiting queue based on the longest output (max_new_tokens).""" waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) @staticmethod - def _sort_randomly(waiting_queue: List[Req]) -> None: + def _sort_randomly(waiting_queue: list[Req]) -> None: """Shuffles the waiting queue randomly.""" random.shuffle(waiting_queue) @staticmethod - def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None: + def _calc_weight(cur_node: TreeNode, node_to_weight: dict[TreeNode, int]) -> None: for child in cur_node.children.values(): SchedulePolicy._calc_weight(child, node_to_weight) node_to_weight[cur_node] += node_to_weight[child] @@ -231,9 +231,9 @@ def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> Non @staticmethod def _get_dfs_priority( cur_node: TreeNode, - node_to_priority: Dict[TreeNode, int], - last_node_to_reqs: Dict[TreeNode, List[Req]], - q: List, + node_to_priority: dict[TreeNode, int], + last_node_to_reqs: dict[TreeNode, list[Req]], + q: list, ) -> None: childs = [child for child in cur_node.children.values()] childs.sort(key=lambda x: -node_to_priority[x]) @@ -259,7 +259,7 @@ def __init__( running_batch: ScheduleBatch, new_token_ratio: float, rem_input_tokens: int, - rem_chunk_tokens: Optional[int], + rem_chunk_tokens: int | None, mixed_with_decode_tokens: int = 0, ): self.page_size = page_size diff --git a/python/sgl_jax/srt/managers/scheduler.py b/python/sgl_jax/srt/managers/scheduler.py index eba78e6c1..5255987fc 100644 --- a/python/sgl_jax/srt/managers/scheduler.py +++ b/python/sgl_jax/srt/managers/scheduler.py @@ -11,7 +11,6 
@@ from collections import deque from dataclasses import dataclass from types import SimpleNamespace -from typing import Dict, List, Optional, Union import jax import numpy as np @@ -88,10 +87,10 @@ class ReceiveDataError(Exception): @dataclass class GenerationBatchResult: - logits_output: Optional[LogitsProcessorOutput] - next_token_ids: Optional[List[int]] # on device - extend_input_len_per_req: List[int] - extend_logprob_start_len_per_req: List[int] + logits_output: LogitsProcessorOutput | None + next_token_ids: list[int] | None # on device + extend_input_len_per_req: list[int] + extend_logprob_start_len_per_req: list[int] bid: int cache_miss_count: int @@ -202,10 +201,7 @@ def __init__( ici_parallelism=[-1, self.tp_size, 1], dcn_parallelism=[1, 1, 1] ) - if self.enable_overlap: - TpWorkerClass = ModelWorkerClient - else: - TpWorkerClass = ModelWorker + TpWorkerClass = ModelWorkerClient if self.enable_overlap else ModelWorker self.tp_worker = TpWorkerClass( server_args=server_args, @@ -234,15 +230,15 @@ def __init__( self.init_memory_pool_and_cache() # Init running status - self.waiting_queue: List[Req] = [] + self.waiting_queue: list[Req] = [] # The aborted requests - self.aborted_reqs: Dict[str, Req] = {} + self.aborted_reqs: dict[str, Req] = {} # The running decoding batch for continuous batching self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False) # The current forward batch - self.cur_batch: Optional[ScheduleBatch] = None + self.cur_batch: ScheduleBatch | None = None # The last forward batch - self.last_batch: Optional[ScheduleBatch] = None + self.last_batch: ScheduleBatch | None = None self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 @@ -312,7 +308,9 @@ def __init__( def sync_pub(self): logger.info( - f"[Publisher {self.node_rank}] Begins to synchronize, wait {self.nnodes-1} Subscribers" + "[Publisher %s] Begins to synchronize, wait %s Subscribers", + self.node_rank, + self.nnodes - 1, ) ready_count = 0 try: @@ -321,47 +319,47 @@ def sync_pub(self): if message == "READY": ready_count += 1 logger.info( - f"[Publisher {self.node_rank}] receives {ready_count} READY signal" + "[Publisher %s] receives %s READY signal", + self.node_rank, + ready_count, ) self.publisher_sync.send_string("ACK") else: self.publisher_sync.send_string("NACK") except zmq.Again: logger.error( - f"[Publisher {self.node_rank}] Fails to synchronize due to timeout" + "[Publisher %s] Fails to synchronize due to timeout", self.node_rank ) return False except Exception as e: - logger.error(f"[Publisher {self.node_rank}] Encounters error: {e}") + logger.error("[Publisher %s] Encounters error: %s", self.node_rank, e) return False - logger.info(f"[Publisher {self.node_rank}] Succeeds to synchronize!") + logger.info("[Publisher %s] Succeeds to synchronize!", self.node_rank) return True def sync_sub(self): - logger.info(f"[Subscriber {self.node_rank}] Begins to synchronize") + logger.info("[Subscriber %s] Begins to synchronize", self.node_rank) try: self.subscriber_sync.send_string("READY") ack = self.subscriber_sync.recv_string() if ack == "ACK": - logger.info(f"[Subscriber {self.node_rank}] Succeeds to synchronizes!") + logger.info("[Subscriber %s] Succeeds to synchronizes!", self.node_rank) return True else: logger.error( - f"[Subscriber {self.node_rank}] Fails to synchroinze with ack: {ack}" + "[Subscriber %s] Fails to synchroinze with ack: %s", + self.node_rank, + ack, ) return False except Exception as e: logger.error( - f"[Subscriber {self.node_rank}] Fails to 
synchronize with error: {e}" + "[Subscriber %s] Fails to synchronize with error: %s", self.node_rank, e ) return False def sync_pub_sub(self): - success = False - if self.node_rank == 0: - success = self.sync_pub() - else: - success = self.sync_sub() + success = self.sync_pub() if self.node_rank == 0 else self.sync_sub() if not success: raise SyncError("Fail to synchronize between publisher and subscribers") @@ -479,7 +477,9 @@ def run_publisher(self, recv_reqs): return True except Exception as e: logger.error( - f"[Publisher {self.node_rank}] Fails to send data with error: {e}" + "[Publisher %s] Fails to send data with error: %s", + self.node_rank, + e, ) return False @@ -491,11 +491,14 @@ def run_subscriber(self): return pickle.loads(serialized_data) except zmq.Again: logger.error( - f"[Subscriber {self.node_rank}] Fails to receive data with timeout, and try again" + "[Subscriber %s] Fails to receive data with timeout, and try again", + self.node_rank, ) except Exception as e: logger.error( - f"[Subscriber {self.node_rank}] Fails to receive or deserialize with error: {e}, and try again" + "[Subscriber %s] Fails to receive or deserialize with error: %s, and try again", + self.node_rank, + e, ) return None @@ -511,7 +514,7 @@ def broadcast_pyobj(self, recv_reqs): ) return recv_reqs - def recv_requests(self) -> List[Req]: + def recv_requests(self) -> list[Req]: """Receive results at node_rank = 0 and broadcast it to all other Node ranks.""" if self.node_rank == 0: recv_reqs = [] @@ -536,7 +539,7 @@ def recv_requests(self) -> List[Req]: recv_reqs = self.broadcast_pyobj(recv_reqs) return recv_reqs - def process_input_requests(self, recv_reqs: List): + def process_input_requests(self, recv_reqs: list): for recv_req in recv_reqs: output = self._request_dispatcher(recv_req) if output is not None: @@ -621,11 +624,13 @@ def set_internal_state(self, recv_req: SetInternalStateReq): # Update precision_tracer state in this process if "trace_active" in tracer_config: logger.info( - f"[SCHEDULER] check trace_active: {precision_tracer.get_trace_active()}" + "[SCHEDULER] check trace_active: %s", + precision_tracer.get_trace_active(), ) precision_tracer._trace_active = tracer_config["trace_active"] logger.info( - f"[SCHEDULER] Updated trace_active to: {precision_tracer._trace_active}" + "[SCHEDULER] Updated trace_active to: %s", + precision_tracer._trace_active, ) # Reset counters when starting trace @@ -640,23 +645,25 @@ def set_internal_state(self, recv_req: SetInternalStateReq): if "max_requests" in tracer_config: precision_tracer._max_requests = tracer_config["max_requests"] logger.info( - f"[SCHEDULER] Updated max_requests to: {precision_tracer._max_requests}" + "[SCHEDULER] Updated max_requests to: %s", + precision_tracer._max_requests, ) if "output_file" in tracer_config: precision_tracer._trace_output_file = tracer_config["output_file"] logger.info( - f"[SCHEDULER] Updated output_file to: {precision_tracer._trace_output_file}" + "[SCHEDULER] Updated output_file to: %s", + precision_tracer._trace_output_file, ) logger.info( - f"[SCHEDULER] Precision tracer state updated: {tracer_config}" + "[SCHEDULER] Precision tracer state updated: %s", tracer_config ) except Exception as e: success = False error_msg = str(e) - logger.info(f"[SCHEDULER] Error updating internal state: {error_msg}") + logger.info("[SCHEDULER] Error updating internal state: %s", error_msg) return SetInternalStateReqOutput( request_id=recv_req.request_id, success=success, error_msg=error_msg @@ -666,7 +673,7 @@ def 
_add_request_to_queue(self, req: Req): req.queue_time_start = time.perf_counter() self.waiting_queue.append(req) - def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): + def _extend_requests_to_queue(self, reqs: list[Req], is_retracted: bool = False): self.waiting_queue.extend(reqs) def check_memory(self): @@ -676,7 +683,7 @@ def check_memory(self): token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n" if memory_leak: - msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" + msg = f"token_to_kv_pool_allocator memory leak detected! {token_msg}" raise ValueError(msg) req_total_size = self.req_to_token_pool.size @@ -699,7 +706,7 @@ def _get_token_info(self): token_usage = num_used / self.max_total_num_tokens return num_used, token_usage, available_size, evictable_size - def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: + def get_next_batch_to_run(self) -> ScheduleBatch | None: chunked_req_to_exclude = set() if self.chunked_req: # Move the chunked request out of the batch so that we can merge @@ -746,7 +753,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: return ret - def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: + def get_new_batch_prefill(self) -> ScheduleBatch | None: # Handle the cases where prefill is not allowed if ( self.running_batch.batch_is_full or len(self.waiting_queue) == 0 @@ -796,7 +803,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: break # Update waiting queue - can_run_list: List[Req] = adder.can_run_list + can_run_list: list[Req] = adder.can_run_list if len(can_run_list) == 0: return None @@ -848,7 +855,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: return new_batch - def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: + def update_running_batch(self, batch: ScheduleBatch) -> ScheduleBatch | None: """Update the current running decoding batch.""" initial_bs = batch.batch_size() @@ -868,9 +875,10 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: self.new_token_ratio = new_token_ratio logger.info( - "KV cache pool is full. Retract requests. " - f"#retracted_reqs: {num_retracted_reqs}, " - f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" + "KV cache pool is full. Retract requests. 
#retracted_reqs: %d, #new_token_ratio: %.4f -> %.4f", + num_retracted_reqs, + old_ratio, + self.new_token_ratio, ) self._extend_requests_to_queue(retracted_reqs, is_retracted=True) @@ -887,7 +895,7 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: batch.prepare_for_decode() return batch - def run_batch(self, batch: ScheduleBatch) -> Union[GenerationBatchResult]: + def run_batch(self, batch: ScheduleBatch) -> GenerationBatchResult: """Run a batch.""" self.forward_ct += 1 @@ -964,10 +972,9 @@ def run_batch(self, batch: ScheduleBatch) -> Union[GenerationBatchResult]: def process_batch_result( self, batch: ScheduleBatch, - result: Union[GenerationBatchResult], - launch_done: Optional[threading.Event] = None, + result: GenerationBatchResult, + launch_done: threading.Event | None = None, ): - if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) elif batch.forward_mode.is_extend(): @@ -1014,7 +1021,7 @@ def watchdog_thread(self): time.sleep(self.watchdog_timeout // 2) pyspy_dump_schedulers() - logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") + logger.error("Watchdog timeout (watchdog_timeout=%s)", self.watchdog_timeout) print(file=sys.stderr, flush=True) print(file=sys.stdout, flush=True) @@ -1036,7 +1043,7 @@ def abort_request(self, recv_req: AbortReq): # We still need to send something back to TokenizerManager to clean up the state. req = self.waiting_queue.pop(i) self.send_to_tokenizer.send_pyobj(AbortReq(req.rid)) - logger.debug(f"Abort queued request. {req.rid=}") + logger.debug("Abort queued request. rid=%s", req.rid) # Delete requests in the running batch if self.cur_batch is self.running_batch or self.cur_batch is None: @@ -1051,14 +1058,14 @@ def abort_request(self, recv_req: AbortReq): # Abort method 3: set `to_abort=True` # The request will still run one decode forward pass. # Then we reuse all existing code to clean up the KV cache allocation. - logger.debug(f"Abort running request. {req.rid=}") + logger.debug("Abort running request. 
rid=%s", req.rid) req.to_abort = True def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, - dp_rank: Optional[int], + dp_rank: int | None, pipe_writer, ): # Generate the prefix @@ -1093,7 +1100,7 @@ def run_scheduler_process( except Exception: traceback = get_exception_traceback() - logger.error(f"Scheduler hit an exception: {traceback}") + logger.error("Scheduler hit an exception: %s", traceback) parent_process.send_signal(signal.SIGQUIT) @@ -1144,5 +1151,5 @@ def scheduler_loop_after_create(server_args, scheduler): scheduler.event_loop_normal() except Exception: traceback = get_exception_traceback() - logger.error(f"Scheduler hit an exception: {traceback}") + logger.error("Scheduler hit an exception: %s", traceback) current_process.send_signal(signal.SIGQUIT) diff --git a/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py b/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py index 161091512..49dfea89e 100644 --- a/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py +++ b/python/sgl_jax/srt/managers/scheduler_metrics_mixin.py @@ -1,7 +1,6 @@ import logging import time from collections import defaultdict -from typing import List from sgl_jax.srt.managers.schedule_policy import PrefillAdder from sgl_jax.srt.managers.scheduler import Req, ScheduleBatch @@ -26,7 +25,7 @@ def init_metrics(self): def log_prefill_stats( self, adder: PrefillAdder, - can_run_list: List[Req], + can_run_list: list[Req], running_bs: int, ): gap_latency = time.perf_counter() - self.last_prefill_stats_tic @@ -60,7 +59,7 @@ def log_decode_stats(self, running_batch: ScheduleBatch = None): self.num_generated_tokens = 0 num_running_reqs = len(batch.reqs) num_used, token_usage, _, _ = self._get_token_info() - token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, " + token_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, " if RECORD_STEP_TIME: self.step_time_dict[num_running_reqs].append( diff --git a/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py b/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py index 8979b3ef9..975cfb759 100644 --- a/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py @@ -2,7 +2,7 @@ import logging import threading -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING import jax @@ -32,8 +32,8 @@ class SchedulerOutputProcessorMixin: def process_batch_result_prefill( self: Scheduler, batch: ScheduleBatch, - result: Union[GenerationBatchResult], - launch_done: Optional[threading.Event] = None, + result: GenerationBatchResult, + launch_done: threading.Event | None = None, ): skip_stream_req = None @@ -89,7 +89,10 @@ def process_batch_result_prefill( precision_tracer.add_completed_requests_count() precision_tracer.set_end_time_and_duration(req.rid) logger.info( - f"Request trace completed ({precision_tracer.get_completed_requests_count()}/{precision_tracer.get_max_requests()}): {req.rid}" + "Request trace completed (%d/%d): %s", + precision_tracer.get_completed_requests_count(), + precision_tracer.get_max_requests(), + req.rid, ) if ( precision_tracer.get_completed_requests_count() @@ -145,7 +148,7 @@ def process_batch_result_prefill( if batch.cache_miss_count > 0: logger.info( - f"Prefill batch. #bid: {result.bid}, #cache_miss: {cache_miss_count}" + "Prefill batch. 
#bid: %s, #cache_miss: %s", result.bid, cache_miss_count ) self.set_next_batch_sampling_info_done(batch) @@ -158,7 +161,7 @@ def process_batch_result_decode( self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult, - launch_done: Optional[threading.Event] = None, + launch_done: threading.Event | None = None, ): logits_output, next_token_ids, cache_miss_count = ( result.logits_output, @@ -212,7 +215,10 @@ def process_batch_result_decode( precision_tracer.add_completed_requests_count() precision_tracer.set_end_time_and_duration(req.rid) logger.info( - f"Request trace completed ({precision_tracer.get_completed_requests_count()}/{precision_tracer.get_max_requests()}): {req.rid}" + "Request trace completed (%d/%d): %s", + precision_tracer.get_completed_requests_count(), + precision_tracer.get_max_requests(), + req.rid, ) if ( precision_tracer.get_completed_requests_count() @@ -297,7 +303,7 @@ def add_input_logprob_return_values( # Important for the performance. assert isinstance(output.input_token_logprobs, tuple) - input_token_logprobs: Tuple[int] = output.input_token_logprobs + input_token_logprobs: tuple[int] = output.input_token_logprobs input_token_logprobs = input_token_logprobs[ logprob_pt : logprob_pt + num_input_logprobs ] @@ -394,7 +400,7 @@ def add_logprob_return_values( i: int, req: Req, pt: int, - next_token_ids: List[int], + next_token_ids: list[int], num_input_logprobs: int, output: LogitsProcessorOutput, ): @@ -422,9 +428,9 @@ def add_logprob_return_values( def stream_output( self: Scheduler, - reqs: List[Req], + reqs: list[Req], return_logprob: bool, - skip_req: Optional[Req] = None, + skip_req: Req | None = None, cache_miss_count: int = None, ): """Stream the output to detokenizer.""" @@ -433,13 +439,13 @@ def stream_output( def stream_output_generation( self: Scheduler, - reqs: List[Req], + reqs: list[Req], return_logprob: bool, - skip_req: Optional[Req] = None, + skip_req: Req | None = None, cache_miss_count: int = None, ): rids = [] - finished_reasons: List[BaseFinishReason] = [] + finished_reasons: list[BaseFinishReason] = [] decoded_texts = [] decode_ids_list = [] diff --git a/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py b/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py index 907186146..5d7714983 100644 --- a/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py +++ b/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py @@ -1,7 +1,6 @@ import logging import os from pathlib import Path -from typing import Optional import jax @@ -12,20 +11,20 @@ class SchedulerProfilerMixin: def init_profier(self): - self.profiler_output_dir: Optional[str] = None - self.profile_id: Optional[str] = None - self.profiler_start_forward_ct: Optional[int] = None - self.profiler_target_forward_ct: Optional[int] = None - self.profile_steps: Optional[int] = None + self.profiler_output_dir: str | None = None + self.profile_id: str | None = None + self.profiler_start_forward_ct: int | None = None + self.profiler_target_forward_ct: int | None = None + self.profile_steps: int | None = None self.profile_in_progress: bool = False def start_profile( self, - output_dir: Optional[str], - start_step: Optional[int], - num_steps: Optional[int], - host_tracer_level: Optional[int], - python_tracer_level: Optional[int], + output_dir: str | None, + start_step: int | None, + num_steps: int | None, + host_tracer_level: int | None, + python_tracer_level: int | None, profile_id: str, ) -> ProfileReqOutput: if self.profile_in_progress: @@ -68,7 +67,9 @@ def start_profile( return 
ProfileReqOutput(success=True, message="Succeeded") logger.info( - f"Profiling starts. Traces will be saved to: {self.profiler_output_dir} (with profile id: {self.profile_id})", + "Profiling starts. Traces will be saved to: %s (with profile id: %s)", + self.profiler_output_dir, + self.profile_id, ) profiler_options = jax.profiler.ProfileOptions() diff --git a/python/sgl_jax/srt/managers/template_manager.py b/python/sgl_jax/srt/managers/template_manager.py index a64442933..252fba1ff 100644 --- a/python/sgl_jax/srt/managers/template_manager.py +++ b/python/sgl_jax/srt/managers/template_manager.py @@ -6,7 +6,6 @@ """ import logging -from typing import Optional from sgl_jax.srt.conversation import get_conv_template_by_model_path @@ -24,22 +23,22 @@ class TemplateManager: def __init__(self): pass - self._chat_template_name: Optional[str] = None - self._completion_template_name: Optional[str] = None - self._jinja_template_content_format: Optional[str] = None + self._chat_template_name: str | None = None + self._completion_template_name: str | None = None + self._jinja_template_content_format: str | None = None @property - def chat_template_name(self) -> Optional[str]: + def chat_template_name(self) -> str | None: """Get the current chat template name.""" return self._chat_template_name @property - def completion_template_name(self) -> Optional[str]: + def completion_template_name(self) -> str | None: """Get the current completion template name.""" return self._completion_template_name @property - def jinja_template_content_format(self) -> Optional[str]: + def jinja_template_content_format(self) -> str | None: """Get the detected template content format ('string' or 'openai' or None).""" return self._jinja_template_content_format @@ -52,7 +51,7 @@ def guess_chat_template_from_model_path(self, model_path: str) -> None: """ template_name = get_conv_template_by_model_path(model_path) if template_name is not None: - logger.info(f"Inferred chat template from model path: {template_name}") + logger.info("Inferred chat template from model path: %s", template_name) self._chat_template_name = template_name def initialize_templates( diff --git a/python/sgl_jax/srt/managers/tokenizer_manager.py b/python/sgl_jax/srt/managers/tokenizer_manager.py index 98729d24f..124087c6f 100644 --- a/python/sgl_jax/srt/managers/tokenizer_manager.py +++ b/python/sgl_jax/srt/managers/tokenizer_manager.py @@ -16,7 +16,7 @@ from collections import deque from datetime import datetime from http import HTTPStatus -from typing import Any, Deque, Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import Any, Generic, TypeVar import fastapi import jax @@ -75,10 +75,10 @@ class ReqState: """Store the state a request.""" - out_list: List[Dict[Any, Any]] + out_list: list[dict[Any, Any]] finished: bool event: asyncio.Event - obj: Union[GenerateReqInput, EmbeddingReqInput] + obj: GenerateReqInput | EmbeddingReqInput # For metrics created_time: float @@ -91,19 +91,19 @@ class ReqState: last_output_offset: int = 0 text: str = "" - output_ids: List[int] = dataclasses.field(default_factory=list) - input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) - input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) - output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list) - output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list) - input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) - input_top_logprobs_idx: 
List[List[int]] = dataclasses.field(default_factory=list) - output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list) - output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list) - input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) - input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) - output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list) - output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list) + output_ids: list[int] = dataclasses.field(default_factory=list) + input_token_logprobs_val: list[float] = dataclasses.field(default_factory=list) + input_token_logprobs_idx: list[int] = dataclasses.field(default_factory=list) + output_token_logprobs_val: list[float] = dataclasses.field(default_factory=list) + output_token_logprobs_idx: list[int] = dataclasses.field(default_factory=list) + input_top_logprobs_val: list[list[float]] = dataclasses.field(default_factory=list) + input_top_logprobs_idx: list[list[int]] = dataclasses.field(default_factory=list) + output_top_logprobs_val: list[list[float]] = dataclasses.field(default_factory=list) + output_top_logprobs_idx: list[list[int]] = dataclasses.field(default_factory=list) + input_token_ids_logprobs_val: list = dataclasses.field(default_factory=list) + input_token_ids_logprobs_idx: list = dataclasses.field(default_factory=list) + output_token_ids_logprobs_val: list = dataclasses.field(default_factory=list) + output_token_ids_logprobs_idx: list = dataclasses.field(default_factory=list) class TokenizerManager: @@ -159,14 +159,14 @@ def __init__( # Store states self.no_create_loop = False - self.rid_to_state: Dict[str, ReqState] = {} + self.rid_to_state: dict[str, ReqState] = {} self.health_check_failed = False self.gracefully_exit = False self.last_receive_tstamp = 0 self.dump_requests_folder = "" # By default do not dump self.dump_requests_threshold = 1000 - self.dump_request_list: List[Tuple] = [] - self.crash_dump_request_list: deque[Tuple] = deque() + self.dump_request_list: list[tuple] = [] + self.crash_dump_request_list: deque[tuple] = deque() self.log_request_metadata = self.get_log_request_metadata() self.session_futures = {} # session_id -> asyncio event self.max_req_input_len = None @@ -239,8 +239,8 @@ def __init__( async def generate_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - request: Optional[fastapi.Request] = None, + obj: GenerateReqInput | EmbeddingReqInput, + request: fastapi.Request | None = None, ): created_time = time.time() async with self._cond: @@ -258,7 +258,8 @@ async def generate_request( if self.log_requests: max_length, skip_names, _ = self.log_request_metadata logger.info( - f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}" + "Receive: obj=%s", + dataclass_to_string_truncated(obj, max_length, skip_names=skip_names), ) if obj.is_single: @@ -274,7 +275,7 @@ async def generate_request( async def _tokenize_one_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: GenerateReqInput | EmbeddingReqInput, ): """Tokenize one request.""" @@ -292,7 +293,7 @@ async def _tokenize_one_request( return self._create_tokenized_object(obj, input_text, input_ids) def _validate_one_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int] + self, obj: GenerateReqInput | EmbeddingReqInput, input_ids: list[int] ) -> None: """Validates that the input token count and the requested token count 
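The Optional/List/Union rewrites above rely on the newer annotation syntax, which needs Python 3.10+ at runtime unless annotations are postponed. A small sketch with a hypothetical dataclass (not the real ReqState):

from __future__ import annotations  # lets `str | None` appear in annotations on older interpreters

import dataclasses


@dataclasses.dataclass
class ExampleState:
    text: str = ""
    output_ids: list[int] = dataclasses.field(default_factory=list)  # PEP 585 built-in generic
    finish_reason: str | None = None  # PEP 604 union instead of Optional[str]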
doesn't exceed the model's context length.""" @@ -321,7 +322,7 @@ def _validate_one_request( raise ValueError(error_msg) def _validate_input_ids_in_vocab( - self, input_ids: List[int], vocab_size: int + self, input_ids: list[int], vocab_size: int ) -> None: if any(id >= vocab_size for id in input_ids): raise ValueError( @@ -332,7 +333,7 @@ def _create_tokenized_object( self, obj: GenerateReqInput, input_text: str, - input_ids: List[int], + input_ids: list[int], ) -> TokenizedGenerateReqInput: """Create a tokenized request object from common parameters.""" # Parse sampling parameters @@ -364,9 +365,9 @@ def _create_tokenized_object( async def _batch_tokenize_and_process( self, batch_size: int, obj: GenerateReqInput - ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]: + ) -> list[TokenizedGenerateReqInput | TokenizedEmbeddingReqInput]: """Handle batch tokenization for text inputs only.""" - logger.debug(f"Starting batch tokenization for {batch_size} text requests") + logger.debug("Starting batch tokenization for %s text requests", batch_size) # Collect requests and texts requests = [obj[i] for i in range(batch_size)] @@ -385,11 +386,11 @@ async def _batch_tokenize_and_process( req, req.text, input_ids_list[i], None, None ) ) - logger.debug(f"Completed batch processing for {batch_size} requests") + logger.debug("Completed batch processing for %s requests", batch_size) return tokenized_objs def _validate_batch_tokenization_constraints( - self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput] + self, batch_size: int, obj: GenerateReqInput | EmbeddingReqInput ) -> None: """Validate constraints for batch tokenization processing.""" for i in range(batch_size): @@ -408,9 +409,9 @@ def _validate_batch_tokenization_constraints( def _send_one_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], - created_time: Optional[float] = None, + obj: GenerateReqInput | EmbeddingReqInput, + tokenized_obj: TokenizedGenerateReqInput | TokenizedEmbeddingReqInput, + created_time: float | None = None, ): self.send_to_scheduler.send_pyobj(tokenized_obj) state = ReqState([], False, asyncio.Event(), obj, created_time=created_time) @@ -421,9 +422,9 @@ def _send_one_request( async def _wait_one_response( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], + obj: GenerateReqInput | EmbeddingReqInput, state: ReqState, - request: Optional[fastapi.Request] = None, + request: fastapi.Request | None = None, ): """Wait for the response of one request.""" while True: @@ -434,9 +435,14 @@ async def _wait_one_response( # Abort the request for disconnected requests (non-streaming, waiting queue) self.abort_request(obj.rid) # Use exception to kill the whole call stack and asyncio task - raise ValueError( - f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}" - ) + try: + raise ValueError( + f"Request is disconnected from the client side (type 1). Abort request rid={obj.rid}" + ) + except ValueError as e: + raise ValueError( + f"Request is disconnected from the client side (type 1). 
Abort request rid={obj.rid}" + ) from e continue out = state.out_list[-1] @@ -475,9 +481,9 @@ async def _wait_one_response( async def _handle_batch_request( self, - obj: Union[GenerateReqInput, EmbeddingReqInput], - request: Optional[fastapi.Request] = None, - created_time: Optional[float] = None, + obj: GenerateReqInput | EmbeddingReqInput, + request: fastapi.Request | None = None, + created_time: float | None = None, ): batch_size = obj.batch_size @@ -574,11 +580,11 @@ def abort_request(self, rid: str = "", abort_all: bool = False): async def start_profile( self, - output_dir: Optional[str] = None, - start_step: Optional[int] = None, - num_steps: Optional[int] = None, - host_tracer_level: Optional[int] = None, - python_tracer_level: Optional[int] = None, + output_dir: str | None = None, + start_step: int | None = None, + num_steps: int | None = None, + host_tracer_level: int | None = None, + python_tracer_level: int | None = None, ): self.auto_create_handle_loop() req = ProfileReq( @@ -616,7 +622,7 @@ async def continue_generation(self): async def release_memory_occupation( self, obj: ReleaseMemoryOccupationReqInput, - request: Optional[fastapi.Request] = None, + request: fastapi.Request | None = None, ): self.auto_create_handle_loop() await self.release_memory_occupation_communicator(obj) @@ -624,13 +630,13 @@ async def release_memory_occupation( async def resume_memory_occupation( self, obj: ResumeMemoryOccupationReqInput, - request: Optional[fastapi.Request] = None, + request: fastapi.Request | None = None, ): self.auto_create_handle_loop() await self.resume_memory_occupation_communicator(obj) async def open_session( - self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None + self, obj: OpenSessionReqInput, request: fastapi.Request | None = None ): self.auto_create_handle_loop() @@ -647,13 +653,13 @@ async def open_session( return session_id async def close_session( - self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None + self, obj: CloseSessionReqInput, request: fastapi.Request | None = None ): await self.send_to_scheduler.send_pyobj(obj) - async def get_internal_state(self) -> List[Dict[Any, Any]]: + async def get_internal_state(self) -> list[dict[Any, Any]]: req = GetInternalStateReq() - responses: List[GetInternalStateReqOutput] = ( + responses: list[GetInternalStateReqOutput] = ( await self.get_internal_state_communicator(req) ) # Many DP ranks @@ -670,7 +676,7 @@ async def set_internal_state( self, obj: SetInternalStateReq ) -> SetInternalStateReqOutput: self.auto_create_handle_loop() - responses: List[SetInternalStateReqOutput] = ( + responses: list[SetInternalStateReqOutput] = ( await self.set_internal_state_communicator(obj) ) return ( @@ -748,7 +754,7 @@ def configure_logging(self, obj: ConfigureLoggingReq): self.dump_requests_threshold = obj.dump_requests_threshold if obj.crash_dump_folder is not None: self.crash_dump_folder = obj.crash_dump_folder - logging.info(f"Config logging: {obj=}") + logging.info("Config logging: %s", obj) self.log_request_metadata = self.get_log_request_metadata() def create_abort_task(self, obj: GenerateReqInput): @@ -802,7 +808,10 @@ def dump_requests_before_crash(self): "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping." ) return - logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}") + logger.error( + "Dumping requests before crash. 
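The crash-dump filename hunk only flips the quote nesting; both spellings produce the same name, but before Python 3.12 the quotes inside a replacement field must differ from the outer ones. A quick check with a fixed timestamp:

from datetime import datetime

stamp = datetime(2025, 10, 15, 20, 47, 30)
a = f'crash_dump_{stamp.strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
b = f"crash_dump_{stamp.strftime('%Y-%m-%d_%H-%M-%S')}.pkl"
assert a == b == "crash_dump_2025-10-15_20-47-30.pkl"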
crash_dump_folder=%s", + self.crash_dump_folder, + ) self.crash_dump_performed = True if not self.crash_dump_folder: return @@ -827,7 +836,7 @@ def dump_requests_before_crash(self): filename = os.path.join( self.crash_dump_folder, os.getenv("HOSTNAME", None), - f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl', + f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl", ) os.makedirs(os.path.dirname(filename), exist_ok=True) @@ -839,7 +848,10 @@ def dump_requests_before_crash(self): with open(filename, "wb") as f: pickle.dump(data_to_dump_with_server_args, f) logger.error( - f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}" + "Dumped %d finished and %d unfinished requests before crash to %s", + len(self.crash_dump_request_list), + len(unfinished_requests), + filename, ) async def sigterm_watchdog(self): @@ -868,7 +880,8 @@ async def sigterm_watchdog(self): break logger.info( - f"Gracefully exiting... remaining number of requests {remain_num_req}" + "Gracefully exiting... remaining number of requests %d", + remain_num_req, ) if remain_num_req > 0: await asyncio.sleep(5) @@ -889,13 +902,14 @@ async def handle_loop(self): def _handle_batch_output( self, - recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut], + recv_obj: BatchStrOut | BatchEmbeddingOut | BatchTokenIDOut, ): for i, rid in enumerate(recv_obj.rids): state = self.rid_to_state.get(rid, None) if state is None: logger.error( - f"Received output for {rid=} but the state was deleted in TokenizerManager." + "Received output for rid=%s but the state was deleted in TokenizerManager.", + rid, ) continue @@ -980,7 +994,7 @@ def convert_logprob_style( meta_info: dict, state: ReqState, top_logprobs_num: int, - token_ids_logprob: List[int], + token_ids_logprob: list[int], return_text_in_logprobs: bool, recv_obj: BatchStrOut, recv_obj_index: int, @@ -1065,8 +1079,8 @@ def convert_logprob_style( def detokenize_logprob_tokens( self, - token_logprobs_val: List[float], - token_logprobs_idx: List[int], + token_logprobs_val: list[float], + token_logprobs_idx: list[int], decode_to_text: bool, ): if not decode_to_text: @@ -1081,8 +1095,8 @@ def detokenize_logprob_tokens( def detokenize_top_logprobs_tokens( self, - token_logprobs_val: List[float], - token_logprobs_idx: List[int], + token_logprobs_val: list[float], + token_logprobs_idx: list[int], decode_to_text: bool, ): # We should batch all top-k tokens in all positions. @@ -1108,7 +1122,7 @@ def dump_requests(self, state: ReqState, out_dict: dict): self.dump_requests_folder, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", ) - logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}") + logger.info("Dump %s requests to %s", len(self.dump_request_list), filename) to_dump = self.dump_request_list self.dump_request_list = [] @@ -1164,13 +1178,13 @@ def _handle_open_session_req_output(self, recv_obj): async def score_request( self, - query: Optional[Union[str, List[int]]] = None, - items: Optional[Union[str, List[str], List[List[int]]]] = None, - label_token_ids: Optional[List[int]] = None, + query: str | list[int] | None = None, + items: str | list[str] | list[list[int]] | None = None, + label_token_ids: list[int] | None = None, apply_softmax: bool = False, item_first: bool = False, - request: Optional[Any] = None, - ) -> List[List[float]]: + request: Any | None = None, + ) -> list[list[float]]: """ See Engine.score() for more details. 
""" @@ -1266,7 +1280,7 @@ async def print_exception_wrapper(func): await func() except Exception: traceback = get_exception_traceback() - logger.error(f"TokenizerManager hit an exception: {traceback}") + logger.error("TokenizerManager hit an exception: %s", traceback) if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager): func.__self__.dump_requests_before_crash() kill_process_tree(os.getpid(), include_parent=True) @@ -1279,7 +1293,9 @@ def __init__(self, tokenizer_manager: TokenizerManager): def sigterm_handler(self, signum=None, frame=None): logger.warning( - f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." + "SIGTERM received. signum=%s frame=%s. Draining requests and shutting down...", + signum, + frame, ) self.tokenizer_manager.gracefully_exit = True @@ -1300,9 +1316,9 @@ class _Communicator(Generic[T]): def __init__(self, sender, fan_out: int): self._sender = sender self._fan_out = fan_out - self._result_event: Optional[asyncio.Event] = None - self._result_values: Optional[List[T]] = None - self._ready_queue: Deque[asyncio.Future] = deque() + self._result_event: asyncio.Event | None = None + self._result_values: list[T] | None = None + self._ready_queue: deque[asyncio.Future] = deque() async def __call__(self, obj): ready_event = asyncio.Event() diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py index ada6e9ea6..1be5ada87 100644 --- a/python/sgl_jax/srt/managers/tp_worker.py +++ b/python/sgl_jax/srt/managers/tp_worker.py @@ -4,7 +4,6 @@ import logging import threading import time -from typing import Optional, Tuple, Union import jax import jax.numpy as jnp @@ -46,7 +45,7 @@ def __init__( self, server_args: ServerArgs, mesh: jax.sharding.Mesh, - req_to_token_pool: Optional[ReqToTokenPool] = None, + req_to_token_pool: ReqToTokenPool | None = None, ): # Parse args self.tp_size = server_args.tp_size @@ -62,14 +61,11 @@ def __init__( # Sync random seed across TP workers # Each process may have different random_seed. After broadcast, all processes will have the same random_seed. 
- # self.random_seed = broadcast_one_to_all(server_args.random_seed).item() if server_args.random_seed is None: with jax.default_device(jax.local_devices()[0]): - if jax.process_index() == 0: - seed_to_broadcast = server_args.random_seed - else: - seed_to_broadcast = 0 - + seed_to_broadcast = ( + server_args.random_seed if jax.process_index() == 0 else 0 + ) self.random_seed = broadcast_one_to_all(seed_to_broadcast).item() else: self.random_seed = server_args.random_seed @@ -108,13 +104,22 @@ def __init__( # Log each constraint for debugging logger.info("Max running requests constraints:") logger.info( - f" - Server limit: {server_limit} {'(max_total_tokens//2)' if server_args.max_running_requests is None else '(configured)'}" + " - Server limit: %s %s", + server_limit, + ( + "(max_total_tokens//2)" + if server_args.max_running_requests is None + else "(configured)" + ), ) - logger.info(f" - Token pool size: {pool_limit}") + logger.info(" - Token pool size: %s", pool_limit) logger.info( - f" - Attention backend: {attn_backend_limit} (context_len={self.model_config.context_len}, page_size={self.page_size})" + " - Attention backend: %s (context_len=%s, page_size=%s)", + attn_backend_limit, + self.model_config.context_len, + self.page_size, ) - logger.info(f" → Final max_running_requests: {self.max_running_requests}") + logger.info(" → Final max_running_requests: %s", self.max_running_requests) assert self.max_running_requests > 0, "max_running_request is zero" self.max_req_len = min( @@ -198,7 +203,9 @@ def run_precompile(self, future_token_ids_map=None): def precompile_extend(self, future_token_ids_map=None): start_time = time.perf_counter() logger.info( - f"[EXTEND] Begin to precompile bs_paddings={self.precompile_bs_paddings[-1:]} token_paddings={self.precompile_token_paddings}" + "[EXTEND] Begin to precompile bs_paddings=%s token_paddings=%s", + self.precompile_bs_paddings[-1:], + self.precompile_token_paddings, ) bs, _ = self.get_max_padded_size() @@ -210,7 +217,9 @@ def precompile_extend(self, future_token_ids_map=None): bs, num_tokens = pair[0], pair[1] pbar.set_postfix(bs=bs, tokens=num_tokens) if bs > num_tokens: - logger.warning(f"{bs=} > {num_tokens=}, skip this pair") + logger.warning( + "bs=%s > num_tokens=%s, skip this pair", bs, num_tokens + ) continue model_worker_batch = self.generate_model_worker_batch( bs, @@ -241,7 +250,8 @@ def precompile_extend(self, future_token_ids_map=None): def precompile_decode(self, future_token_ids_map=None): start_time = time.perf_counter() logger.info( - f"[DECODE] Begin to precompile bs_paddings={self.precompile_bs_paddings}" + "[DECODE] Begin to precompile bs_paddings=%s", + self.precompile_bs_paddings, ) with tqdm( @@ -287,7 +297,8 @@ def precompile_penalties(self, future_token_ids_map=None): """Precompile penalty application for different batch sizes and penalty combinations.""" start_time = time.perf_counter() logger.info( - f"[PENALTIES] Begin to precompile penalty applications bs_paddings={self.precompile_bs_paddings}" + "[PENALTIES] Begin to precompile penalty applications bs_paddings=%s", + self.precompile_bs_paddings, ) with tqdm( @@ -460,11 +471,11 @@ def get_memory_pool(self): def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, - launch_done: Optional[threading.Event] = None, + launch_done: threading.Event | None = None, skip_sample: bool = False, sampling_metadata: SamplingMetadata = None, forward_metadata=None, - ) -> Tuple[Union[LogitsProcessorOutput, jax.Array, int], Optional[jax.Array]]: + ) -> 
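The seed-sync hunk above collapses the if/else into a ternary; the underlying pattern is that only process 0's value survives the broadcast. A hedged sketch, assuming `broadcast_one_to_all` here is `jax.experimental.multihost_utils.broadcast_one_to_all` (the real import path in this repo may differ):

import random

import jax
from jax.experimental import multihost_utils


def sync_random_seed(configured_seed: int | None = None) -> int:
    if configured_seed is not None:
        return configured_seed
    # Non-zero only on process 0; every process receives process 0's choice.
    local_choice = random.randint(0, 2**31 - 1) if jax.process_index() == 0 else 0
    return multihost_utils.broadcast_one_to_all(local_choice).item()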
tuple[LogitsProcessorOutput | jax.Array | int, jax.Array | None]: # Use pre-initialized ForwardBatch if available (for overlap scheduling optimization) if model_worker_batch.forward_batch is not None: forward_batch = model_worker_batch.forward_batch @@ -606,10 +617,10 @@ def get_memory_pool(self): def forward_batch_generation( self, _model_worker_batch: ModelWorkerBatch, - _launch_done: Optional[threading.Event] = None, + _launch_done: threading.Event | None = None, _skip_sample: bool = False, - _sampling_metadata: Optional[SamplingMetadata] = None, - ) -> Tuple[Union[LogitsProcessorOutput, jax.Array], Optional[jax.Array]]: + _sampling_metadata: SamplingMetadata | None = None, + ) -> tuple[LogitsProcessorOutput | jax.Array, jax.Array | None]: return ( LogitsProcessorOutput( next_token_logits=jnp.array([0, 1]), diff --git a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py index 90d1e0821..9530d13af 100644 --- a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py +++ b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py @@ -5,7 +5,6 @@ import signal import threading from queue import Queue -from typing import Optional, Tuple import jax import jax.numpy as jnp @@ -85,7 +84,7 @@ def forward_thread_func(self): self.forward_thread_func_() except Exception: traceback = get_exception_traceback() - logger.error(f"ModelWorkerClient hit an exception: {traceback}") + logger.error("ModelWorkerClient hit an exception: %s", traceback) self.parent_process.send_signal(signal.SIGQUIT) def forward_thread_func_(self): @@ -131,7 +130,7 @@ def forward_thread_func_(self): (None, logits_output, next_token_ids, cache_miss_count) ) - def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None): + def resolve_last_batch_result(self, launch_done: threading.Event | None = None): """ This function is called to resolve the last batch result and wait for the current batch to be launched. Used in overlap mode. @@ -158,7 +157,7 @@ def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, sampling_metadata: SamplingMetadata = None, - ) -> Tuple[None, jax.Array, int]: + ) -> tuple[None, jax.Array, int]: # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch. sampling_info = model_worker_batch.sampling_info sampling_info.update_penalties() diff --git a/python/sgl_jax/srt/managers/utils.py b/python/sgl_jax/srt/managers/utils.py index b14255b17..19c5d8e4f 100644 --- a/python/sgl_jax/srt/managers/utils.py +++ b/python/sgl_jax/srt/managers/utils.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import jax from jax import numpy as jnp @@ -11,7 +10,7 @@ def validate_input_length( req: Req, max_req_input_len: int, allow_auto_truncate: bool -) -> Optional[str]: +) -> str | None: """Validate and potentially truncate input length. Args: @@ -25,9 +24,9 @@ def validate_input_length( if len(req.origin_input_ids) >= max_req_input_len: if allow_auto_truncate: logger.warning( - "Request length is longer than the KV cache pool size or " - "the max context length. Truncated. " - f"{len(req.origin_input_ids)=}, {max_req_input_len=}." + "Request length is longer than the KV cache pool size or the max context length. Truncated. 
len(origin_input_ids)=%s, max_req_input_len=%s", + len(req.origin_input_ids), + max_req_input_len, ) req.origin_input_ids = req.origin_input_ids[:max_req_input_len] return None diff --git a/python/sgl_jax/srt/mem_cache/allocator.py b/python/sgl_jax/srt/mem_cache/allocator.py index 0c217c767..ec27689c2 100644 --- a/python/sgl_jax/srt/mem_cache/allocator.py +++ b/python/sgl_jax/srt/mem_cache/allocator.py @@ -1,6 +1,5 @@ import abc import logging -from typing import List, Optional import numpy as np @@ -77,7 +76,7 @@ def clear(self): raise NotImplementedError() @abc.abstractmethod - def alloc(self, need_size: int) -> Optional[np.ndarray]: + def alloc(self, need_size: int) -> np.ndarray | None: raise NotImplementedError() @abc.abstractmethod @@ -105,7 +104,7 @@ def available_size(self) -> int: # To avoid minor "len(free_slots) * 1" overhead return len(self.free_slots) - def alloc(self, need_size: int) -> Optional[np.ndarray]: + def alloc(self, need_size: int) -> np.ndarray | None: if need_size > self.available_size(): return None @@ -143,7 +142,7 @@ def __init__( self.debug_mode = debug_mode self.clear() - def alloc(self, need_size: int) -> Optional[np.ndarray]: + def alloc(self, need_size: int) -> np.ndarray | None: # page-aligned allocation, returning contiguous indices of pages assert ( need_size % self.page_size == 0 @@ -166,11 +165,11 @@ def alloc(self, need_size: int) -> Optional[np.ndarray]: def alloc_extend( self, - prefix_lens: List[int], - seq_lens: List[int], - last_loc: List[int], + prefix_lens: list[int], + seq_lens: list[int], + last_loc: list[int], extend_num_tokens: int, - ) -> Optional[np.ndarray]: + ) -> np.ndarray | None: # Convert to numpy for internal operations seq_lens_np = np.array(seq_lens) prefix_lens_np = np.array(prefix_lens) @@ -266,9 +265,9 @@ def alloc_extend( def alloc_decode( self, - seq_lens: List[int], - last_loc: List[int], - ) -> Optional[np.ndarray]: + seq_lens: list[int], + last_loc: list[int], + ) -> np.ndarray | None: # Convert inputs to numpy for calculations seq_lens_np = np.array(seq_lens) last_loc_np = np.array(last_loc) diff --git a/python/sgl_jax/srt/mem_cache/base_prefix_cache.py b/python/sgl_jax/srt/mem_cache/base_prefix_cache.py index 2de4f98ee..07a968862 100644 --- a/python/sgl_jax/srt/mem_cache/base_prefix_cache.py +++ b/python/sgl_jax/srt/mem_cache/base_prefix_cache.py @@ -1,5 +1,5 @@ import abc -from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple +from typing import TYPE_CHECKING, Any, NamedTuple import jax.numpy as jnp @@ -23,8 +23,8 @@ class MatchResult(NamedTuple): """ device_indices: jnp.ndarray - last_device_node: Optional[TreeNode] - last_host_node: Optional[TreeNode] + last_device_node: TreeNode | None + last_host_node: TreeNode | None host_hit_length: int = 0 @@ -36,7 +36,7 @@ def reset(self): pass @abc.abstractmethod - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + def match_prefix(self, key: list[int], **kwargs) -> MatchResult: pass @abc.abstractmethod @@ -56,7 +56,7 @@ def inc_lock_ref(self, node: Any): pass @abc.abstractmethod - def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None): + def dec_lock_ref(self, node: Any, swa_uuid_for_lock: str | None = None): pass def evictable_size(self): @@ -87,7 +87,7 @@ def init_load_back( self, last_host_node: Any, host_hit_length: int, - ) -> Tuple[jnp.ndarray, Any]: + ) -> tuple[jnp.ndarray, Any]: """ Preparing KV cache loading from host to device. 
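The paged allocator above hands out whole pages, so extending a request only needs the pages its new tokens spill into. A small sketch of that arithmetic (helper name is made up):

import numpy as np


def new_pages_for_extend(prefix_lens, seq_lens, page_size: int) -> np.ndarray:
    # Pages already covered by the cached prefix versus pages needed for the full sequence;
    # the difference is what alloc_extend has to hand out, per request.
    prefix = np.asarray(prefix_lens)
    seqs = np.asarray(seq_lens)
    return -(-seqs // page_size) - (-(-prefix // page_size))


assert list(new_pages_for_extend([0, 5], [7, 9], page_size=4)) == [2, 1]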
""" diff --git a/python/sgl_jax/srt/mem_cache/chunk_cache.py b/python/sgl_jax/srt/mem_cache/chunk_cache.py index 201c553ce..6fee53785 100644 --- a/python/sgl_jax/srt/mem_cache/chunk_cache.py +++ b/python/sgl_jax/srt/mem_cache/chunk_cache.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import numpy as np @@ -51,7 +51,7 @@ def evict(self, num_tokens: int): def inc_lock_ref(self, node: Any): return 0 - def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None): + def dec_lock_ref(self, node: Any, swa_uuid_for_lock: str | None = None): return 0 def pretty_print(self): diff --git a/python/sgl_jax/srt/mem_cache/memory_pool.py b/python/sgl_jax/srt/mem_cache/memory_pool.py index 9db380107..60d4acbc3 100644 --- a/python/sgl_jax/srt/mem_cache/memory_pool.py +++ b/python/sgl_jax/srt/mem_cache/memory_pool.py @@ -2,7 +2,6 @@ import logging import time from functools import partial -from typing import List, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -95,7 +94,7 @@ def available_size(self) -> int: """Return number of available request slots""" return len(self.free_slots) - def alloc(self, need_size: int = 1) -> List[int]: + def alloc(self, need_size: int = 1) -> list[int]: """Allocate request slots""" if need_size > len(self.free_slots): return None @@ -104,7 +103,7 @@ def alloc(self, need_size: int = 1) -> List[int]: self.free_slots = self.free_slots[need_size:] return select_indices - def free(self, free_index: Union[int, List[int]]): + def free(self, free_index: int | list[int]): """Free request slots""" if isinstance(free_index, int): self.free_slots.append(free_index) @@ -126,8 +125,8 @@ def __init__( dtype: jnp.dtype, layer_num: int, mesh: Mesh, - start_layer: Optional[int] = None, - end_layer: Optional[int] = None, + start_layer: int | None = None, + end_layer: int | None = None, ): self.size = size self.page_size = page_size @@ -170,7 +169,7 @@ def get_fused_kv_buffer(self, layer_id: int) -> jnp.ndarray: raise NotImplementedError() @abc.abstractmethod - def get_kv_buffer(self, layer_id: int) -> Tuple[jnp.ndarray, jnp.ndarray]: + def get_kv_buffer(self, layer_id: int) -> tuple[jnp.ndarray, jnp.ndarray]: """Get separate K and V buffers for native attention. 
Returns: @@ -213,8 +212,8 @@ def __init__( head_dim: int, layer_num: int, mesh: Mesh, - start_layer: Optional[int] = None, - end_layer: Optional[int] = None, + start_layer: int | None = None, + end_layer: int | None = None, ): super().__init__( size, page_size, dtype, layer_num, mesh, start_layer, end_layer @@ -272,7 +271,7 @@ def _create_buffers(self): """Create sharded fused KV cache buffers with proper distributed allocation""" self.kv_sharding = NamedSharding(self.mesh, P(None, self.kv_partition_axis)) - logger.info(f"Creating fused KV buffers for {self.layer_num} layers") + logger.info("Creating fused KV buffers for %s layers", self.layer_num) start_time = time.time() fused_buffer_shape = ( @@ -287,7 +286,9 @@ def _create_buffers(self): * jnp.dtype(self.dtype).itemsize ) logger.info( - f"Total fused KV cache memory per layer: {total_memory_per_layer / 1024**3:.2f} GB, dtype: {self.dtype}" + "Total fused KV cache memory per layer: %.2f GB, dtype: %s", + total_memory_per_layer / 1024**3, + self.dtype, ) with self.mesh: self.kv_buffer = [] @@ -304,7 +305,9 @@ def _create_buffers(self): end_time = time.time() logger.info( - f"Total time to create {self.layer_num} buffers: {end_time - start_time:.2f} seconds" + "Total time to create %s buffers: %.2f seconds", + self.layer_num, + end_time - start_time, ) def _calculate_memory_usage(self): @@ -320,8 +323,9 @@ def _calculate_memory_usage(self): self.mem_usage = fused_kv_size / GB logger.info( - f"JAX Fused KV Cache allocated. #tokens: {self.size}, " - f"Fused KV size: {fused_kv_size / GB:.2f} GB" + "JAX Fused KV Cache allocated. #tokens: %s, Fused KV size: %.2f GB", + self.size, + fused_kv_size / GB, ) def get_kv_size_bytes(self): @@ -342,7 +346,7 @@ def get_kv_size_bytes(self): def get_fused_kv_buffer(self, layer_id: int) -> jnp.ndarray: return self.kv_buffer[layer_id - self.start_layer] - def get_kv_buffer(self, layer_id: int) -> Tuple[jnp.ndarray, jnp.ndarray]: + def get_kv_buffer(self, layer_id: int) -> tuple[jnp.ndarray, jnp.ndarray]: layer_idx = layer_id - self.start_layer fused_kv = self.kv_buffer[layer_idx] # [cache_size, num_kv_heads * 2, head_dim] @@ -388,7 +392,7 @@ def set_kv_buffer( def get_kv_data( self, layer_id: int, indices: jnp.ndarray - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> tuple[jnp.ndarray, jnp.ndarray]: """Get KV data at specified positions""" layer_idx = layer_id - self.start_layer fused_kv_data = self.kv_buffer[layer_idx][indices] @@ -837,10 +841,7 @@ def update_fused_kv_cache_vectorized( def get_best_num_slices_per_block(head_num, cache_len, new_kv_len, head_dim, page_size): # keep same to original implementation - if page_size == 1: - num_slices_per_block = 4 - else: - num_slices_per_block = page_size + num_slices_per_block = 4 if page_size == 1 else page_size return num_slices_per_block @@ -893,7 +894,7 @@ def find_value(lst, target_num) -> int: # @partial(jax.jit, static_argnames=["layer_id"]) def _get_kv_buffer( layer_id: int, k_cache: jax.Array, v_cache: jax.Array -) -> Tuple[jax.Array, jax.Array]: +) -> tuple[jax.Array, jax.Array]: return k_cache[layer_id], v_cache[layer_id] @@ -908,8 +909,8 @@ def __init__( layer_num: int, mesh: Mesh, kv_partition_axis: str = "data", # Note: ignored in MLA, no sharding applied - start_layer: Optional[int] = None, - end_layer: Optional[int] = None, + start_layer: int | None = None, + end_layer: int | None = None, ): super().__init__( size, page_size, dtype, layer_num, mesh, start_layer, end_layer @@ -952,8 +953,9 @@ def _calculate_memory_usage(self): self.mem_usage = 
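The per-layer figure logged above comes straight from the fused buffer shape [cache_size, num_kv_heads * 2, head_dim]. For example:

def fused_kv_bytes_per_layer(num_tokens: int, num_kv_heads: int, head_dim: int, itemsize: int = 2) -> int:
    # K and V live in one fused buffer, hence the factor of 2; itemsize=2 corresponds to bf16.
    return num_tokens * num_kv_heads * 2 * head_dim * itemsize


# 1M cached tokens, 8 KV heads, head_dim 128, bf16  ->  ~3.8 GiB per layer
assert fused_kv_bytes_per_layer(1_000_000, 8, 128) == 4_096_000_000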
kv_size / GB logger.info( - f"JAX MLA KV Cache allocated. #tokens: {self.size}, " - f"KV size: {kv_size / GB:.2f} GB" + "JAX MLA KV Cache allocated. #tokens: %s, KV size: %.2f GB", + self.size, + kv_size / GB, ) def get_kv_size_bytes(self): @@ -974,7 +976,7 @@ def get_fused_kv_buffer(self, layer_id: int) -> jnp.ndarray: """ return self.kv_buffer[layer_id - self.start_layer] - def get_kv_buffer(self, layer_id: int) -> Tuple[jnp.ndarray, jnp.ndarray]: + def get_kv_buffer(self, layer_id: int) -> tuple[jnp.ndarray, jnp.ndarray]: """Get separate K and V buffers for native attention from MLA KV cache. Note: MLA architecture differs from standard MHA. For native attention compatibility, @@ -1186,87 +1188,12 @@ def load_cpu_copy(self, kv_cache_host, indices): "hn_32_mcl_320000_nvl_16384_hd_128_ps_128": 4, "hn_32_mcl_640000_nvl_1024_hd_128_ps_128": 512, "hn_32_mcl_640000_nvl_2048_hd_128_ps_128": 4096, - "hn_32_mcl_640000_nvl_4096_hd_128_ps_128": 4, - "hn_32_mcl_640000_nvl_9182_hd_128_ps_128": 2, - "hn_32_mcl_640000_nvl_16384_hd_128_ps_128": 1024, + "hn_32_mcl_640000_nvl_4096_hd_128_ps_128": 4096, + "hn_32_mcl_640000_nvl_9182_hd_128_ps_128": 256, + "hn_32_mcl_640000_nvl_16384_hd_128_ps_128": 128, "hn_32_mcl_1280000_nvl_1024_hd_128_ps_128": 32, "hn_32_mcl_1280000_nvl_2048_hd_128_ps_128": 2048, - "hn_32_mcl_1280000_nvl_4096_hd_128_ps_128": 128, - "hn_32_mcl_1280000_nvl_9182_hd_128_ps_128": 1024, - "hn_32_mcl_1280000_nvl_16384_hd_128_ps_128": 1024, - "hn_8_mcl_80000_nvl_1024_hd_128_ps_256": 2, - "hn_8_mcl_80000_nvl_2048_hd_128_ps_256": 4, - "hn_8_mcl_80000_nvl_4096_hd_128_ps_256": 2, - "hn_8_mcl_80000_nvl_9182_hd_128_ps_256": 512, - "hn_8_mcl_80000_nvl_16384_hd_128_ps_256": 2048, - "hn_8_mcl_160000_nvl_1024_hd_128_ps_256": 4, - "hn_8_mcl_160000_nvl_2048_hd_128_ps_256": 256, - "hn_8_mcl_160000_nvl_4096_hd_128_ps_256": 128, - "hn_8_mcl_160000_nvl_9182_hd_128_ps_256": 2048, - "hn_8_mcl_160000_nvl_16384_hd_128_ps_256": 2048, - "hn_8_mcl_320000_nvl_1024_hd_128_ps_256": 4, - "hn_8_mcl_320000_nvl_2048_hd_128_ps_256": 1024, - "hn_8_mcl_320000_nvl_4096_hd_128_ps_256": 64, - "hn_8_mcl_320000_nvl_9182_hd_128_ps_256": 16, - "hn_8_mcl_320000_nvl_16384_hd_128_ps_256": 512, - "hn_8_mcl_640000_nvl_1024_hd_128_ps_256": 1024, - "hn_8_mcl_640000_nvl_2048_hd_128_ps_256": 8, - "hn_8_mcl_640000_nvl_4096_hd_128_ps_256": 16, - "hn_8_mcl_640000_nvl_9182_hd_128_ps_256": 16, - "hn_8_mcl_640000_nvl_16384_hd_128_ps_256": 4096, - "hn_8_mcl_1280000_nvl_1024_hd_128_ps_256": 64, - "hn_8_mcl_1280000_nvl_2048_hd_128_ps_256": 2, - "hn_8_mcl_1280000_nvl_4096_hd_128_ps_256": 2048, - "hn_8_mcl_1280000_nvl_9182_hd_128_ps_256": 1024, - "hn_8_mcl_1280000_nvl_16384_hd_128_ps_256": 128, - "hn_16_mcl_80000_nvl_1024_hd_128_ps_256": 2, - "hn_16_mcl_80000_nvl_2048_hd_128_ps_256": 16, - "hn_16_mcl_80000_nvl_4096_hd_128_ps_256": 64, - "hn_16_mcl_80000_nvl_9182_hd_128_ps_256": 256, - "hn_16_mcl_80000_nvl_16384_hd_128_ps_256": 16, - "hn_16_mcl_160000_nvl_1024_hd_128_ps_256": 4, - "hn_16_mcl_160000_nvl_2048_hd_128_ps_256": 2, - "hn_16_mcl_160000_nvl_4096_hd_128_ps_256": 128, - "hn_16_mcl_160000_nvl_9182_hd_128_ps_256": 16, - "hn_16_mcl_160000_nvl_16384_hd_128_ps_256": 8, - "hn_16_mcl_320000_nvl_1024_hd_128_ps_256": 16, - "hn_16_mcl_320000_nvl_2048_hd_128_ps_256": 8, - "hn_16_mcl_320000_nvl_4096_hd_128_ps_256": 4, - "hn_16_mcl_320000_nvl_9182_hd_128_ps_256": 8, - "hn_16_mcl_320000_nvl_16384_hd_128_ps_256": 8, - "hn_16_mcl_640000_nvl_1024_hd_128_ps_256": 512, - "hn_16_mcl_640000_nvl_2048_hd_128_ps_256": 1024, - "hn_16_mcl_640000_nvl_4096_hd_128_ps_256": 
2048, - "hn_16_mcl_640000_nvl_9182_hd_128_ps_256": 4096, - "hn_16_mcl_640000_nvl_16384_hd_128_ps_256": 32, - "hn_16_mcl_1280000_nvl_1024_hd_128_ps_256": 4, - "hn_16_mcl_1280000_nvl_2048_hd_128_ps_256": 2, - "hn_16_mcl_1280000_nvl_4096_hd_128_ps_256": 1024, - "hn_16_mcl_1280000_nvl_9182_hd_128_ps_256": 2048, - "hn_16_mcl_1280000_nvl_16384_hd_128_ps_256": 16, - "hn_32_mcl_80000_nvl_1024_hd_128_ps_256": 4, - "hn_32_mcl_80000_nvl_2048_hd_128_ps_256": 256, - "hn_32_mcl_80000_nvl_4096_hd_128_ps_256": 4096, - "hn_32_mcl_80000_nvl_9182_hd_128_ps_256": 128, - "hn_32_mcl_80000_nvl_16384_hd_128_ps_256": 512, - "hn_32_mcl_160000_nvl_1024_hd_128_ps_256": 64, - "hn_32_mcl_160000_nvl_2048_hd_128_ps_256": 4096, - "hn_32_mcl_160000_nvl_4096_hd_128_ps_256": 4096, - "hn_32_mcl_160000_nvl_9182_hd_128_ps_256": 256, - "hn_32_mcl_160000_nvl_16384_hd_128_ps_256": 128, - "hn_32_mcl_320000_nvl_1024_hd_128_ps_256": 4, - "hn_32_mcl_320000_nvl_2048_hd_128_ps_256": 64, - "hn_32_mcl_320000_nvl_4096_hd_128_ps_256": 1024, - "hn_32_mcl_320000_nvl_9182_hd_128_ps_256": 256, - "hn_32_mcl_320000_nvl_16384_hd_128_ps_256": 32, - "hn_32_mcl_640000_nvl_1024_hd_128_ps_256": 256, - "hn_32_mcl_640000_nvl_2048_hd_128_ps_256": 8, - "hn_32_mcl_640000_nvl_4096_hd_128_ps_256": 64, - "hn_32_mcl_640000_nvl_9182_hd_128_ps_256": 32, - "hn_32_mcl_640000_nvl_16384_hd_128_ps_256": 32, - "hn_32_mcl_1280000_nvl_1024_hd_128_ps_256": 256, - "hn_32_mcl_1280000_nvl_2048_hd_128_ps_256": 2048, - "hn_32_mcl_1280000_nvl_4096_hd_128_ps_256": 8, - "hn_32_mcl_1280000_nvl_9182_hd_128_ps_256": 4, - "hn_32_mcl_1280000_nvl_16384_hd_128_ps_256": 2, + "hn_32_mcl_1280000_nvl_4096_hd_128_ps_128": 8, + "hn_32_mcl_1280000_nvl_9182_hd_128_ps_128": 4, + "hn_32_mcl_1280000_nvl_16384_hd_128_ps_128": 2, } diff --git a/python/sgl_jax/srt/mem_cache/radix_cache.py b/python/sgl_jax/srt/mem_cache/radix_cache.py index 87c51c0e7..e55c6c154 100644 --- a/python/sgl_jax/srt/mem_cache/radix_cache.py +++ b/python/sgl_jax/srt/mem_cache/radix_cache.py @@ -4,7 +4,7 @@ import time from collections import defaultdict from functools import partial -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING import jax import jax.numpy as jnp @@ -21,7 +21,7 @@ class TreeNode: counter = 0 - def __init__(self, id: Optional[int] = None): + def __init__(self, id: int | None = None): self.children = defaultdict(TreeNode) self.parent = None self.key = None @@ -46,11 +46,11 @@ def evicted(self): def backuped(self): return self.host_value is not None - def __lt__(self, other: "TreeNode"): + def __lt__(self, other: TreeNode): return self.last_access_time < other.last_access_time -def _key_match_page_size1(key0: List, key1: List): +def _key_match_page_size1(key0: list, key1: list): i = 0 for k0, k1 in zip(key0, key1): if k0 != k1: @@ -59,7 +59,7 @@ def _key_match_page_size1(key0: List, key1: List): return i -def _key_match_paged(key0: List, key1: List, page_size: int): +def _key_match_paged(key0: list, key1: list, page_size: int): min_len = min(len(key0), len(key1)) i = 0 @@ -116,7 +116,7 @@ def __init__( ) self.reset() - def _create_tokens_data(self, tokens: List[int]) -> np.ndarray: + def _create_tokens_data(self, tokens: list[int]) -> np.ndarray: if self.disable: return np.array(tokens, dtype=np.int32) @@ -130,7 +130,7 @@ def reset(self): self.evictable_size_ = 0 self.protected_size_ = 0 - def match_prefix(self, key: List[int], **kwargs) -> MatchResult: + def match_prefix(self, key: list[int], **kwargs) -> MatchResult: if self.disable or len(key) == 0: empty_array = 
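_key_match_page_size1 above walks the two keys token by token; the paged variant compares whole pages so the hit length stays page-aligned. A sketch of the paged idea (not necessarily byte-for-byte the implementation here):

def key_match_paged(key0: list[int], key1: list[int], page_size: int) -> int:
    """Length of the common prefix, rounded down to a page boundary (illustrative)."""
    min_len = min(len(key0), len(key1))
    i = 0
    while i + page_size <= min_len and key0[i : i + page_size] == key1[i : i + page_size]:
        i += page_size
    return i


assert key_match_paged([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 9, 9], page_size=2) == 4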
np.empty((0,), dtype=np.int32) @@ -170,7 +170,7 @@ def match_prefix(self, key: List[int], **kwargs) -> MatchResult: host_hit_length=0, ) - def insert(self, key: List, value=None): + def insert(self, key: list, value=None): if self.disable: return 0 @@ -266,7 +266,7 @@ def cache_unfinished_req(self, req): req.last_node = new_last_node # note: get_cached_kv is only used by test, skip to replace jnp with np - def get_cached_kv(self, token_ids: List[int]) -> Tuple[jnp.ndarray, int]: + def get_cached_kv(self, token_ids: list[int]) -> tuple[jnp.ndarray, int]: if self.disable: with jax.default_device(self.cpu_device): empty_kv = jnp.zeros( @@ -381,7 +381,7 @@ def inc_lock_ref(self, node: TreeNode): node = node.parent return delta - def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[str] = None): + def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: str | None = None): if self.disable: return 0 @@ -411,13 +411,13 @@ def take_events(self): ##### Internal Helper Functions ##### - def _match_prefix_helper(self, node: TreeNode, key: List): + def _match_prefix_helper(self, node: TreeNode, key: list): node.last_access_time = time.monotonic() child_key = self.get_child_key_fn(key) token_sequences = [] - while len(key) > 0 and child_key in node.children.keys(): + while len(key) > 0 and child_key in node.children: child = node.children[child_key] child.last_access_time = time.monotonic() prefix_len = self.key_match_fn(child.key, key) @@ -451,7 +451,7 @@ def _split_node(self, key, child: TreeNode, split_len: int): return new_node - def _insert_helper(self, node: TreeNode, key: List, value): + def _insert_helper(self, node: TreeNode, key: list, value): if isinstance(value, jnp.ndarray): assert value.ndim == 1, "value must be a 1D array" @@ -462,7 +462,7 @@ def _insert_helper(self, node: TreeNode, key: List, value): child_key = self.get_child_key_fn(key) total_prefix_length = 0 - while len(key) > 0 and child_key in node.children.keys(): + while len(key) > 0 and child_key in node.children: node = node.children[child_key] node.last_access_time = time.monotonic() prefix_len = self.key_match_fn(node.key, key) diff --git a/python/sgl_jax/srt/memory_profiler.py b/python/sgl_jax/srt/memory_profiler.py index a8d56f079..4ec3086ff 100644 --- a/python/sgl_jax/srt/memory_profiler.py +++ b/python/sgl_jax/srt/memory_profiler.py @@ -3,8 +3,8 @@ import json import logging import os +from collections.abc import Callable from contextlib import contextmanager -from typing import Callable, Dict, List, Optional, Union import jax import jax.profiler @@ -49,7 +49,7 @@ def from_env(self): def configure_memory_profiler( enabled: bool = None, output_dir: str = None, - layer_filter: Union[List[int], Callable, None] = None, + layer_filter: list[int] | Callable | None = None, generate_prof: bool = None, generate_reports: bool = None, log_to_console: bool = None, @@ -71,7 +71,7 @@ def configure_memory_profiler( _config.log_to_console = log_to_console -def _should_profile_layer(layer_id: Optional[int]) -> bool: +def _should_profile_layer(layer_id: int | None) -> bool: if not _config.enabled or layer_id is None: return False @@ -111,10 +111,10 @@ def _save_memory_snapshot(filename: str, condition: bool = True): ) jax.profiler.save_device_memory_profile(output_path) except Exception as e: - logger.warning(f"Failed to save memory snapshot {filename}: {e}") + logger.warning("Failed to save memory snapshot %s: %s", filename, e) -def _log_tensor_memory(stage: str, layer_id: Optional[int] = None, **tensors): +def 
_log_tensor_memory(stage: str, layer_id: int | None = None, **tensors): if not _config.enabled or not _config.log_to_console: return @@ -122,7 +122,7 @@ def _log_tensor_memory(stage: str, layer_id: Optional[int] = None, **tensors): return layer_info = f"Layer {layer_id}" if layer_id is not None else "Global" - logger.info(f" [Memory] {layer_info} - {stage}:") + logger.info(" [Memory] %s - %s:", layer_info, stage) total_memory = 0 tensor_info = [] @@ -133,10 +133,14 @@ def _log_tensor_memory(stage: str, layer_id: Optional[int] = None, **tensors): total_memory += memory_mb tensor_info.append((name, memory_mb, tensor.shape, tensor.dtype)) logger.info( - f" {name:<20}: {memory_mb:>8.2f} MB - {tensor.shape} {tensor.dtype}" + " %-20s: %8.2f MB - %s %s", + name, + memory_mb, + tensor.shape, + tensor.dtype, ) - logger.info(f" Total Memory: {total_memory:.2f} MB") + logger.info(" Total Memory: %.2f MB", total_memory) logger.info("-" * 80) return tensor_info, total_memory @@ -144,8 +148,8 @@ def _log_tensor_memory(stage: str, layer_id: Optional[int] = None, **tensors): def _create_memory_report( stage: str, - tensor_dict: Dict[str, jax.Array], - layer_id: Optional[int] = None, + tensor_dict: dict[str, jax.Array], + layer_id: int | None = None, report_type: str = "general", ): if not _config.enabled or not _config.generate_reports: @@ -227,17 +231,19 @@ def _create_memory_report( with open(json_report_path, "w") as f: json.dump(json_report, f, indent=2) - logger.debug(f" Generated memory reports: {report_path}, {json_report_path}") + logger.debug( + " Generated memory reports: %s, %s", report_path, json_report_path + ) except Exception as e: - logger.warning(f"Failed to create memory report for {stage}: {e}") + logger.warning("Failed to create memory report for %s: %s", stage, e) class MemoryProfiler: def __init__( self, stage: str, - layer_id: Optional[int] = None, + layer_id: int | None = None, report_type: str = "general", auto_snapshot: bool = True, ): @@ -279,12 +285,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): def memory_profile( stage: str, - layer_id: Optional[int] = None, + layer_id: int | None = None, report_type: str = "general", include_args: bool = False, include_result: bool = True, ): - def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): @@ -331,7 +336,6 @@ def wrapper(*args, **kwargs): @contextmanager def profile_memory_scope(stage: str, **tensors): - if not _config.enabled: yield return @@ -366,7 +370,7 @@ def move_reports_to_output_dir(): return moved_files -def generate_summary_report(output_dir: Optional[str] = None): +def generate_summary_report(output_dir: str | None = None): if output_dir is None: output_dir = _config.output_dir @@ -385,7 +389,7 @@ def generate_summary_report(output_dir: Optional[str] = None): for json_file in json_files: try: - with open(json_file, "r") as f: + with open(json_file) as f: report = json.load(f) stage = report.get("stage", "unknown") @@ -411,21 +415,21 @@ def generate_summary_report(output_dir: Optional[str] = None): summary["layer_analysis"][layer_id][stage] = total_memory except Exception as e: - logger.warning(f"Failed to process {json_file}: {e}") + logger.warning("Failed to process %s: %s", json_file, e) summary_path = os.path.join(output_dir, "memory_summary.json") with open(summary_path, "w") as f: json.dump(summary, f, indent=2) - logger.info(f"Generated memory summary report: {summary_path}") + logger.info("Generated memory summary report: %s", summary_path) return summary -def profile_attention(stage: str, 
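The per-tensor MB figures logged above amount to element count times element size (a float32 1024x1024 array is 4 MiB), for instance:

import jax.numpy as jnp


def tensor_memory_mb(x) -> float:
    # Element count times bytes per element, reported in MiB.
    return x.size * x.dtype.itemsize / (1024 * 1024)


assert tensor_memory_mb(jnp.zeros((1024, 1024), jnp.float32)) == 4.0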
layer_id: Optional[int] = None): +def profile_attention(stage: str, layer_id: int | None = None): return memory_profile(stage, layer_id, report_type="attention", include_result=True) -def profile_mlp(stage: str, layer_id: Optional[int] = None): +def profile_mlp(stage: str, layer_id: int | None = None): return memory_profile(stage, layer_id, report_type="mlp", include_result=True) diff --git a/python/sgl_jax/srt/model_executor/forward_batch_info.py b/python/sgl_jax/srt/model_executor/forward_batch_info.py index cde2ed3a6..a6b300122 100644 --- a/python/sgl_jax/srt/model_executor/forward_batch_info.py +++ b/python/sgl_jax/srt/model_executor/forward_batch_info.py @@ -20,11 +20,12 @@ from dataclasses import dataclass from enum import IntEnum, auto from functools import total_ordering -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING import jax from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import register_pytree_node_class + from sgl_jax.srt.utils.jax_utils import device_array logger = logging.getLogger(__name__) @@ -153,11 +154,11 @@ class ForwardBatch: cache_loc: jax.Array = None # For extend - extend_prefix_lens: Optional[jax.Array] = None - extend_seq_lens: Optional[jax.Array] = None + extend_prefix_lens: jax.Array | None = None + extend_seq_lens: jax.Array | None = None - trace_request_ids: Optional[List[str]] = None - trace_request_objects: Optional[List] = None + trace_request_ids: list[str] | None = None + trace_request_objects: list | None = None def tree_flatten(self): children = ( diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py index 54137edf2..3615226bf 100644 --- a/python/sgl_jax/srt/model_executor/model_runner.py +++ b/python/sgl_jax/srt/model_executor/model_runner.py @@ -3,7 +3,6 @@ import logging import os from functools import partial -from typing import Optional, Tuple, Union import jax import jax.numpy as jnp @@ -61,8 +60,8 @@ def __init__( tp_size: int, server_args: ServerArgs, mesh: jax.sharding.Mesh, - req_to_token_pool: Optional[ReqToTokenPool] = None, - token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None, + req_to_token_pool: ReqToTokenPool | None = None, + token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator | None = None, rngs: nnx.Rngs = None, ): # Parse args @@ -197,18 +196,18 @@ def get_available_device_memory(self): # Check memory for tensor parallelism local_device_memory = get_available_device_memory(self.device) - if self.tp_size > 1: - if min_available_device_memory < local_device_memory * 0.9: - if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): - logger.warning( - "The memory capacity is unbalanced. " - f"{min_available_device_memory=}, {local_device_memory=}, {local_device_memory * 0.9=}" - ) - else: - raise ValueError( - "The memory capacity is unbalanced. " - f"{min_available_device_memory=}, {local_device_memory=}, {local_device_memory * 0.9=}" - ) + if self.tp_size > 1 and min_available_device_memory < local_device_memory * 0.9: + if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"): + logger.warning( + "The memory capacity is unbalanced. min_available_device_memory=%s, local_device_memory=%s, local_device_memory*0.9=%s", + min_available_device_memory, + local_device_memory, + local_device_memory * 0.9, + ) + else: + raise ValueError( + f"The memory capacity is unbalanced. 
min_available_device_memory={min_available_device_memory}, local_device_memory={local_device_memory}, local_device_memory*0.9={local_device_memory * 0.9}" + ) return min_available_device_memory @@ -257,20 +256,20 @@ def profile_max_num_token(self, total_device_memory: int): max_tokens = max(1, int(available_kv_cache_bytes // cell_size)) logger.info( - f"TPU Memory profiling: " - f"available_device_memory={available_device_memory / (1024**3):.1f}GB, " - f"available_kv_cache={available_kv_cache_bytes / (1024**3):.1f}GB, " - f"max_tokens={max_tokens}, " - f"cell_size={cell_size}bytes" + "TPU Memory profiling: available_device_memory=%.1fGB, available_kv_cache=%.1fGB, max_tokens=%d, cell_size=%dbytes", + available_device_memory / (1024**3), + available_kv_cache_bytes / (1024**3), + max_tokens, + cell_size, ) return max_tokens def init_memory_pool( self, - max_num_reqs: Optional[int] = None, - max_total_tokens: Optional[int] = None, - total_device_memory: Optional[int] = None, + max_num_reqs: int | None = None, + max_total_tokens: int | None = None, + total_device_memory: int | None = None, ): """Initialize memory pool for KV cache.""" # Set KV cache data type @@ -282,7 +281,7 @@ def init_memory_pool( raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." ) - logger.info(f"ModelRunner kv_cache_dtype: {self.kv_cache_dtype}") + logger.info("ModelRunner kv_cache_dtype: %s", self.kv_cache_dtype) # Profile maximum number of tokens self.max_total_num_tokens = self.profile_max_num_token(total_device_memory) @@ -307,9 +306,9 @@ def init_memory_pool( if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: logger.warning( - f"max_total_tokens={max_total_tokens} is larger than the profiled value " - f"{self.max_total_num_tokens}. " - f"Use the profiled value instead." + "max_total_tokens=%s is larger than the profiled value %s. Use the profiled value instead.", + max_total_tokens, + self.max_total_num_tokens, ) self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) @@ -325,7 +324,7 @@ def init_memory_pool( "Not enough memory. Please try to increase --mem-fraction-static." ) - logger.info(f"ModelRunner max_total_num_tokens: {self.max_total_num_tokens}") + logger.info("ModelRunner max_total_num_tokens: %s", self.max_total_num_tokens) # Create request to token pool if not already created if self.req_to_token_pool is None: @@ -438,14 +437,14 @@ def forward_idle( self, forward_batch: ForwardBatch, logits_metadata: LogitsMetadata, - ) -> Tuple[LogitsProcessorOutput, int]: + ) -> tuple[LogitsProcessorOutput, int]: raise NotImplementedError("forward_idle is not implemented") def forward( self, forward_batch: ForwardBatch, logits_metadata: LogitsMetadata, - ) -> Tuple[LogitsProcessorOutput, int]: + ) -> tuple[LogitsProcessorOutput, int]: self.forward_pass_id += 1 precision_tracer.start_batch_trace(forward_batch.bid) precision_tracer.set_current_forward_pass_id(self.forward_pass_id) @@ -455,7 +454,7 @@ def _forward_raw( self, forward_batch: ForwardBatch, logits_metadata: LogitsMetadata, - ) -> Tuple[LogitsProcessorOutput, int]: + ) -> tuple[LogitsProcessorOutput, int]: # for compatibility, 0.6.3 need to use use_mesh. set_mesh is not have __entry__ attribute. # on jax >=0.7.1, we need to use set_mesh. 
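profile_max_num_token above divides the KV-cache budget by a per-token cell size. Roughly, with made-up numbers and assuming the usual MHA cost of num_kv_heads x head_dim x 2 x layer_num x dtype bytes per token:

def profile_max_tokens(available_kv_cache_bytes: int, kv_head_num: int, head_dim: int,
                       layer_num: int, itemsize: int = 2) -> int:
    # Per-token cost of caching K and V across every layer (standard MHA layout, bf16 by default).
    cell_size = kv_head_num * head_dim * 2 * layer_num * itemsize
    return max(1, available_kv_cache_bytes // cell_size)


# 60 GiB of free HBM, 8 KV heads, head_dim 128, 32 layers  ->  491_520 tokens
assert profile_max_tokens(60 * 1024**3, 8, 128, 32) == 491_520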
try: @@ -504,7 +503,7 @@ def sample( class MockModelRunner(ModelRunner): def __init__( self, - model_config: Union[ModelConfig, MockModelConfig], + model_config: ModelConfig | MockModelConfig, rngs: nnx.Rngs = None, mesh: mesh_lib.Mesh = None, server_args: ServerArgs = None, diff --git a/python/sgl_jax/srt/model_loader/arch.py b/python/sgl_jax/srt/model_loader/arch.py index f90563411..09a335cbb 100644 --- a/python/sgl_jax/srt/model_loader/arch.py +++ b/python/sgl_jax/srt/model_loader/arch.py @@ -1,7 +1,7 @@ """Utilities for selecting and loading models.""" import logging -from typing import Any, Tuple +from typing import Any import transformers from transformers.dynamic_module_utils import get_class_from_dynamic_module @@ -46,8 +46,7 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str if model_config.model_impl == ModelImpl.TRANSFORMERS: if not model_module.is_backend_compatible(): raise ValueError( - f"The Transformers implementation of {arch} is not " - "compatible with SGLang." + f"The Transformers implementation of {arch} is not compatible with SGLang." ) architectures[i] = "TransformersForCausalLM" if model_config.model_impl == ModelImpl.AUTO: @@ -66,7 +65,7 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str return architectures -def get_model_architecture(model_config: ModelConfig) -> Tuple[Any, str]: +def get_model_architecture(model_config: ModelConfig) -> tuple[Any, str]: from sgl_jax.srt.models.registry import ModelRegistry architectures = getattr(model_config.hf_config, "architectures", []) diff --git a/python/sgl_jax/srt/model_loader/loader.py b/python/sgl_jax/srt/model_loader/loader.py index f70b45ee3..6b9033466 100644 --- a/python/sgl_jax/srt/model_loader/loader.py +++ b/python/sgl_jax/srt/model_loader/loader.py @@ -48,7 +48,7 @@ class JAXModelLoader(BaseModelLoader): @dataclasses.dataclass class JAXSource: model_or_path: str - revision: Optional[str] + revision: str | None @classmethod def init_new(cls, model_config: ModelConfig): @@ -108,8 +108,8 @@ def _get_model(self, model_class: Any, model_config: ModelConfig) -> nnx.Module: return model def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str] - ) -> Optional[str]: + self, model: str, revision: str | None + ) -> str | None: if get_bool_env_var("SGLANG_USE_MODELSCOPE"): # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. 
@@ -129,8 +129,8 @@ def _maybe_download_from_modelscope( return None def _prepare_weights( - self, model_name_or_path: str, revision: Optional[str] - ) -> Tuple[str, List[str]]: + self, model_name_or_path: str, revision: str | None + ) -> tuple[str, list[str]]: model_path = self._maybe_download_from_modelscope(model_name_or_path, revision) if model_path is not None: model_name_or_path = model_path diff --git a/python/sgl_jax/srt/models/llama.py b/python/sgl_jax/srt/models/llama.py index 9b1f76961..9006664da 100644 --- a/python/sgl_jax/srt/models/llama.py +++ b/python/sgl_jax/srt/models/llama.py @@ -17,7 +17,7 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any import jax import jax.numpy as jnp @@ -32,10 +32,7 @@ ) from sgl_jax.srt.layers.layernorm import RMSNorm from sgl_jax.srt.layers.linear import LinearBase -from sgl_jax.srt.layers.logits_processor import ( - LogitsMetadata, - LogitsProcessor, -) +from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor from sgl_jax.srt.layers.radix_attention import RadixAttention from sgl_jax.srt.mem_cache.memory_pool import KVCache from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch @@ -102,9 +99,9 @@ def __init__( num_kv_heads: int, layer_id: int = 0, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - head_dim: Optional[int] = None, - partial_rotary_factor: Optional[int] = None, + rope_scaling: dict[str, Any] | None = None, + head_dim: int | None = None, + partial_rotary_factor: int | None = None, rope_is_neox_style: bool = True, max_position_embeddings: int = 8192, dtype: jnp.dtype = jnp.bfloat16, @@ -271,8 +268,8 @@ def __call__( hidden_states: jax.Array, forward_batch: ForwardBatch, token_to_kv_pool: KVCache, - residual: Optional[jax.Array], - ) -> Tuple[jax.Array, jax.Array]: + residual: jax.Array | None, + ) -> tuple[jax.Array, jax.Array]: layer_callback_flag = [] if residual is None: residual = hidden_states @@ -392,7 +389,7 @@ def __init__( self.mesh = mesh self.config = config self.dtype = dtype - logger.info(f"LlamaForCausalLM config dtype: {self.dtype}") + logger.info("LlamaForCausalLM config dtype: %s", self.dtype) self.transformer = LlamaModel(config, dtype=self.dtype, rngs=rngs) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) self.logits_processor = LogitsProcessor( diff --git a/python/sgl_jax/srt/models/qwen.py b/python/sgl_jax/srt/models/qwen.py index 15ffcb2bc..96291e90e 100644 --- a/python/sgl_jax/srt/models/qwen.py +++ b/python/sgl_jax/srt/models/qwen.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Optional +from typing import Any import jax import jax.numpy as jnp @@ -76,7 +76,7 @@ def __init__( num_heads: int, max_position_embeddings: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, + rope_scaling: dict[str, Any] | None = None, layer_id: int = 0, dtype: jnp.dtype = jnp.float16, rngs: nnx.Rngs = None, @@ -304,7 +304,7 @@ def __init__( self.mesh = mesh self.config = config self.dtype = dtype - logger.info(f"QWenLMHeadModel config dtype: {self.dtype}") + logger.info("QWenLMHeadModel config dtype: %s", self.dtype) self.transformer = QWenModel(config, dtype=self.dtype, rngs=rngs) vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size, rngs=rngs) diff --git a/python/sgl_jax/srt/models/qwen2.py 
b/python/sgl_jax/srt/models/qwen2.py index 7a821214a..360182772 100644 --- a/python/sgl_jax/srt/models/qwen2.py +++ b/python/sgl_jax/srt/models/qwen2.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any import jax import jax.numpy as jnp @@ -77,8 +77,8 @@ def __init__( num_kv_heads: int, max_position_embeddings: int, rope_theta: float = 1000000, - rope_scaling: Optional[Dict[str, Any]] = None, - head_dim: Optional[int] = None, + rope_scaling: dict[str, Any] | None = None, + head_dim: int | None = None, layer_id: int = 0, dtype: jnp.dtype = jnp.bfloat16, rngs: nnx.Rngs = None, @@ -221,8 +221,8 @@ def __call__( hidden_states: jax.Array, forward_batch: ForwardBatch, token_to_kv_pool: KVCache, - residual: Optional[jax.Array] = None, - ) -> Tuple[jax.Array, jax.Array]: + residual: jax.Array | None = None, + ) -> tuple[jax.Array, jax.Array]: layer_callback_flag = [] if residual is None: residual = hidden_states @@ -254,7 +254,6 @@ def __init__( dtype: jnp.dtype = jnp.bfloat16, rngs: nnx.Rngs = None, ): - self.embed_tokens = Embed( num_embeddings=config.vocab_size, features=config.hidden_size, @@ -321,7 +320,7 @@ def __init__( self.mesh = mesh self.config = config self.dtype = dtype - logger.info(f"Qwen2ForCausalLM config dtype: {self.dtype}") + logger.info("Qwen2ForCausalLM config dtype: %s", self.dtype) self.transformer = Qwen2Model(config, dtype=self.dtype, rngs=rngs) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) self.logits_processor = LogitsProcessor( diff --git a/python/sgl_jax/srt/models/qwen3.py b/python/sgl_jax/srt/models/qwen3.py index 16aa6130b..38a3b4291 100644 --- a/python/sgl_jax/srt/models/qwen3.py +++ b/python/sgl_jax/srt/models/qwen3.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any import jax import jax.numpy as jnp @@ -30,8 +30,8 @@ def __init__( num_kv_heads: int, max_position_embeddings: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - head_dim: Optional[int] = None, + rope_scaling: dict[str, Any] | None = None, + head_dim: int | None = None, rms_norm_eps: float = None, layer_id: int = 0, attention_bias: bool = False, @@ -242,8 +242,8 @@ def __call__( hidden_states: jax.Array, forward_batch: ForwardBatch, token_to_kv_pool: KVCache, - residual: Optional[jax.Array] = None, - ) -> Tuple[jax.Array, jax.Array]: + residual: jax.Array | None = None, + ) -> tuple[jax.Array, jax.Array]: layer_callback_flag = [] if residual is None: residual = hidden_states @@ -289,7 +289,6 @@ def __init__( dtype: jnp.dtype = jnp.bfloat16, rngs: nnx.Rngs = None, ): - self.embed_tokens = Embed( num_embeddings=config.vocab_size, features=config.hidden_size, @@ -360,7 +359,7 @@ def __init__( self.mesh = mesh self.config = config self.dtype = dtype - logger.info(f"QWen3ForCausalLMModel config dtype: {self.dtype}") + logger.info("QWen3ForCausalLMModel config dtype: %s", self.dtype) self.transformer = QWen3Model(config, dtype=self.dtype, rngs=rngs) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) self.logits_processor = LogitsProcessor( diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py index faafedb57..62aee4033 100644 --- a/python/sgl_jax/srt/models/qwen3_moe.py +++ b/python/sgl_jax/srt/models/qwen3_moe.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any from flax import nnx from jax import jax @@ -31,8 +31,8 
@@ def __init__( num_kv_heads: int, max_position_embeddings: int, rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - head_dim: Optional[int] = None, + rope_scaling: dict[str, Any] | None = None, + head_dim: int | None = None, rms_norm_eps: float = None, layer_id: int = 0, attention_bias: bool = False, @@ -234,8 +234,8 @@ def __call__( hidden_states: jax.Array, forward_batch: ForwardBatch, token_to_kv_pool: KVCache, - residual: Optional[jax.Array] = None, - ) -> Tuple[jax.Array, jax.Array]: + residual: jax.Array | None = None, + ) -> tuple[jax.Array, jax.Array]: if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -344,7 +344,7 @@ def __init__( self.mesh = mesh self.config = config self.dtype = dtype - logger.info(f"QWen3MoeForCausalLMModel config dtype: {self.dtype}") + logger.info("QWen3MoeForCausalLMModel config dtype: %s", self.dtype) self.transformer = QWen3MoeModel(config, dtype=self.dtype, rngs=rngs, mesh=mesh) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) self.logits_processor = LogitsProcessor( diff --git a/python/sgl_jax/srt/models/registry.py b/python/sgl_jax/srt/models/registry.py index 22999affd..3a8c32a70 100644 --- a/python/sgl_jax/srt/models/registry.py +++ b/python/sgl_jax/srt/models/registry.py @@ -1,9 +1,10 @@ import importlib import logging import pkgutil +from collections.abc import Set as AbstractSet from dataclasses import dataclass, field from functools import lru_cache -from typing import AbstractSet, Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any logger = logging.getLogger(__name__) @@ -11,12 +12,12 @@ @dataclass class _ModelRegistry: # Keyed by model_arch - models: Dict[str, Union[Type[Any], str]] = field(default_factory=dict) + models: dict[str, type[Any] | str] = field(default_factory=dict) def get_supported_archs(self) -> AbstractSet[str]: return self.models.keys() - def _raise_for_unsupported(self, architectures: List[str]): + def _raise_for_unsupported(self, architectures: list[str]): all_supported_archs = self.get_supported_archs() if any(arch in all_supported_archs for arch in architectures): @@ -30,7 +31,7 @@ def _raise_for_unsupported(self, architectures: List[str]): f"Supported architectures: {all_supported_archs}" ) - def _try_load_model_cls(self, model_arch: str) -> Optional[Type[Any]]: + def _try_load_model_cls(self, model_arch: str) -> type[Any] | None: if model_arch not in self.models: return None @@ -38,8 +39,8 @@ def _try_load_model_cls(self, model_arch: str) -> Optional[Type[Any]]: def _normalize_archs( self, - architectures: Union[str, List[str]], - ) -> List[str]: + architectures: str | list[str], + ) -> list[str]: if isinstance(architectures, str): architectures = [architectures] if not architectures: @@ -57,8 +58,8 @@ def _normalize_archs( def resolve_model_cls( self, - architectures: Union[str, List[str]], - ) -> Tuple[Type[Any], str]: + architectures: str | list[str], + ) -> tuple[type[Any], str]: architectures = self._normalize_archs(architectures) for arch in architectures: @@ -69,7 +70,7 @@ def resolve_model_cls( return self._raise_for_unsupported(architectures) -@lru_cache() +@lru_cache def import_model_classes(): model_arch_name_to_cls = {} package_name = "sgl_jax.srt.models" @@ -79,7 +80,7 @@ def import_model_classes(): try: module = importlib.import_module(name) except Exception as e: - logger.warning(f"Ignore import error when loading {name}. 
" f"{e}") + logger.warning("Ignore import error when loading %s. %s", name, e) continue if hasattr(module, "EntryClass"): entry = module.EntryClass diff --git a/python/sgl_jax/srt/precision_tracer.py b/python/sgl_jax/srt/precision_tracer.py index 9c699cbec..bf05d6be8 100644 --- a/python/sgl_jax/srt/precision_tracer.py +++ b/python/sgl_jax/srt/precision_tracer.py @@ -1,10 +1,11 @@ +import contextlib import hashlib import json import logging import os import threading import time -from typing import Any, Dict, List, Optional +from typing import Any import jax import jax.numpy as jnp @@ -13,9 +14,7 @@ def _is_jax_array(obj): - if not hasattr(obj, "shape") or not hasattr(obj, "dtype"): - return False - return True + return hasattr(obj, "shape") and hasattr(obj, "dtype") class TensorJSONEncoder(json.JSONEncoder): @@ -122,18 +121,20 @@ def __init__(self): # metadata self._current_batch_id = None - self._records: Dict[str, PrecisionTracerRecord] = {} - self._batch_requests_mapping: Dict[ - int, List[PrecisionTracerRequestMetadata] + self._records: dict[str, PrecisionTracerRecord] = {} + self._batch_requests_mapping: dict[ + int, list[PrecisionTracerRequestMetadata] ] = {} - self._token_counters: Dict[str, int] = {} - self._last_forward_pass_id: Dict[str, int] = {} + self._token_counters: dict[str, int] = {} + self._last_forward_pass_id: dict[str, int] = {} self._current_forward_pass_id: int = -1 def set_enable_precision_tracer(self, enabled: bool): self._enable_precision_tracer = enabled - logger.info(f"Precision tracer globally {'enabled' if enabled else 'disabled'}") + logger.info( + "Precision tracer globally %s", "enabled" if enabled else "disabled" + ) def get_trace_active(self): with self.lock: @@ -181,8 +182,8 @@ def add_request_to_batch_requests_mapping( def start_trace( self, - req_num: Optional[int] = None, - output_file: Optional[str] = None, + req_num: int | None = None, + output_file: str | None = None, verbose_logging: bool = False, ): if not self._enable_precision_tracer: @@ -214,16 +215,16 @@ def start_trace( raise ValueError("output_file is required") self._trace_output_file = output_file - logger.info(f"Trace output file: {self._trace_output_file}") + logger.info("Trace output file: %s", self._trace_output_file) os.makedirs(os.path.dirname(self._trace_output_file), exist_ok=True) with open(self._trace_output_file, "w") as _: pass - logger.info(f"Request tracing started. Output: {self._trace_output_file}") + logger.info("Request tracing started. Output: %s", self._trace_output_file) if req_num: - logger.info(f"Will trace up to {req_num} requests") + logger.info("Will trace up to %s requests", req_num) if not verbose_logging: logger.info("Verbose console logging disabled during tracing") @@ -243,16 +244,18 @@ def stop_trace(self): json.dump(record_dict, f, cls=TensorJSONEncoder, ensure_ascii=False) f.write("\n") - logger.info(f"Saved {len(self._records)} request traces to: {output_file}") + logger.info( + "Saved %s request traces to: %s", len(self._records), output_file + ) except Exception as e: - logger.error(f"Error saving traces to {output_file}: {e}") + logger.error("Error saving traces to %s: %s", output_file, e) self._records.clear() self._batch_requests_mapping.clear() self._token_counters.clear() self._last_forward_pass_id.clear() - logger.info(f"Request tracing stopped. Traces saved to: {output_file}") + logger.info("Request tracing stopped. 
Traces saved to: %s", output_file) return output_file def start_batch_trace(self, batch_id: int): @@ -263,7 +266,7 @@ def start_batch_trace(self, batch_id: int): requests_in_batch = self._batch_requests_mapping.get(batch_id, []) if len(requests_in_batch) == 0: - logger.warning(f"Batch {batch_id} has no requests to trace") + logger.warning("Batch %s has no requests to trace", batch_id) return self._current_batch_id = batch_id @@ -293,7 +296,7 @@ def set_current_forward_pass_id(self, forward_pass_id: int): self._current_forward_pass_id = forward_pass_id def jit_pure_callback_record( - self, tensor: Any, name: str, stage: str, layer_id: Optional[int] = None + self, tensor: Any, name: str, stage: str, layer_id: int | None = None ) -> bool: if self._enable_precision_tracer: full_stage = ( @@ -304,7 +307,10 @@ def trace_callback(tensor): # Debug logging to check what stage is being passed if self._trace_active: logger.debug( - f"Recording tensor {name} with stage: {full_stage}, layer_id: {layer_id}" + "Recording tensor %s with stage: %s, layer_id: %s", + name, + full_stage, + layer_id, ) precision_tracer.record(tensor, name, full_stage) return True @@ -326,7 +332,7 @@ def record( return if tensor is None: - logger.info(f"[{stage}] {name}: None") + logger.info("[%s] %s: None", stage, name) return with self.lock: @@ -336,7 +342,7 @@ def record( current_batch_id = self._current_batch_id if len(request_in_batch) == 0: - logger.warning(f"Batch {current_batch_id} has no requests to trace") + logger.warning("Batch %s has no requests to trace", current_batch_id) return prisicion_infos = self._calculate_tensor_pricision_info( @@ -422,7 +428,7 @@ def record( current_token_group["records"].append(record_with_metadata) else: - logger.warning(f"Request {req_id} not found in records") + logger.warning("Request %s not found in records", req_id) continue for req_id, data in prisicion_infos.items(): @@ -436,8 +442,8 @@ def _calculate_tensor_pricision_info( tensor: Any, name: str, stage: str, - request_in_batch: List[PrecisionTracerRequestMetadata], - ) -> Dict[str, Any]: + request_in_batch: list[PrecisionTracerRequestMetadata], + ) -> dict[str, Any]: try: try: test_scalar = jnp.array(1.0) @@ -456,16 +462,16 @@ def _calculate_tensor_pricision_info( for idx, req_meta in enumerate(request_in_batch): req_id = req_meta.request_id - if req_meta.forward_mode == 1: - seq_len = req_meta.input_len - else: - seq_len = 1 + seq_len = req_meta.input_len if req_meta.forward_mode == 1 else 1 if current_idx + seq_len > total_batch_size: logger.error( - f"[TENSOR_DEBUG] ERROR: Request {req_id} requires {seq_len} tokens, " - f"but only {total_batch_size - current_idx} left in tensor of shape {tensor.shape}. " - f"Current position: {current_idx}" + "[TENSOR_DEBUG] ERROR: Request %s requires %s tokens, but only %s left in tensor of shape %s. 
Current position: %s", + req_id, + seq_len, + total_batch_size - current_idx, + tensor.shape, + current_idx, ) continue @@ -518,7 +524,7 @@ def _parse_layer_and_module(self, stage: str): module_type = "unknown" if self._trace_active: - logger.debug(f"Parsing stage: '{stage}'") + logger.debug("Parsing stage: '%s'", stage) if "_layer_id_" in stage: parts = stage.split("_layer_id_") @@ -528,12 +534,14 @@ def _parse_layer_and_module(self, stage: str): module_type = parts[0] if self._trace_active: logger.debug( - f"Parsed from _layer_id_ format: layer_id={layer_id}, module_type={module_type}" + "Parsed from _layer_id_ format: layer_id=%s, module_type=%s", + layer_id, + module_type, ) return layer_id, module_type except (ValueError, IndexError) as e: if self._trace_active: - logger.debug(f"Failed to parse _layer_id_ format: {e}") + logger.debug("Failed to parse _layer_id_ format: %s", e) pass if stage: @@ -560,10 +568,8 @@ def _parse_layer_and_module(self, stage: str): layer_match = re.search(r"L(\d+)", stage, re.IGNORECASE) if layer_match: - try: + with contextlib.suppress(ValueError): layer_id = int(layer_match.group(1)) - except ValueError: - pass return layer_id, module_type @@ -580,10 +586,7 @@ def _compute_stats( shape = tensor.shape dtype = str(tensor.dtype) - if tensor.size > 1: - std_val = float(jnp.std(tensor, ddof=0).item()) - else: - std_val = 0.0 + std_val = float(jnp.std(tensor, ddof=0).item()) if tensor.size > 1 else 0.0 stats = { "framework": "jax", @@ -669,7 +672,7 @@ def _traced_stats(self, name: str, stage: str, shape, dtype, layer_id, module_ty "tracing_context": True, } - def _verbose_logging_console(self, stats: Dict[str, Any]): + def _verbose_logging_console(self, stats: dict[str, Any]): if self._trace_active and not self._verbose_logging: return diff --git a/python/sgl_jax/srt/reasoning_parser.py b/python/sgl_jax/srt/reasoning_parser.py index 9e18554f1..3a8f14497 100644 --- a/python/sgl_jax/srt/reasoning_parser.py +++ b/python/sgl_jax/srt/reasoning_parser.py @@ -1,6 +1,3 @@ -from typing import Dict, Optional, Tuple, Type - - class StreamingParseResult: """Result of streaming incremental parsing.""" @@ -186,13 +183,13 @@ class ReasoningParser: If True, streams reasoning content as it arrives. 
""" - DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = { + DetectorMap: dict[str, type[BaseReasoningFormatDetector]] = { "deepseek-r1": DeepSeekR1Detector, "qwen3": Qwen3Detector, "kimi": KimiDetector, } - def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True): + def __init__(self, model_type: str | None = None, stream_reasoning: bool = True): if not model_type: raise ValueError("Model type must be specified") @@ -202,12 +199,12 @@ def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = Tr self.detector = detector_class(stream_reasoning=stream_reasoning) - def parse_non_stream(self, full_text: str) -> Tuple[str, str]: + def parse_non_stream(self, full_text: str) -> tuple[str, str]: """Non-streaming call: one-time parsing""" ret = self.detector.detect_and_parse(full_text) return ret.reasoning_text, ret.normal_text - def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, str]: + def parse_stream_chunk(self, chunk_text: str) -> tuple[str, str]: """Streaming call: incremental parsing""" ret = self.detector.parse_streaming_increment(chunk_text) return ret.reasoning_text, ret.normal_text diff --git a/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py b/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py index 46bf5efd0..ab2a04227 100644 --- a/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py +++ b/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py @@ -2,7 +2,7 @@ import abc import weakref -from typing import TYPE_CHECKING, Optional, Set, Type +from typing import TYPE_CHECKING import numpy as np @@ -15,7 +15,7 @@ def __init__( self, vocab_size: int, batch: ScheduleBatch, - penalizers: Set[Type["_BatchedPenalizer"]], + penalizers: set[type[_BatchedPenalizer]], ): self.vocab_size = vocab_size self._batch_ref = weakref.ref(batch) @@ -35,7 +35,7 @@ def batch(self) -> ScheduleBatch | None: return self._batch_ref() @batch.setter - def batch(self, value: Optional[ScheduleBatch]): + def batch(self, value: ScheduleBatch | None): if value is None: self._batch_ref = lambda: None else: @@ -132,7 +132,7 @@ def filter(self, keep_indices: np.ndarray): penalizer.teardown() self.is_required = is_required - def merge(self, their: "BatchedPenalizerOrchestrator"): + def merge(self, their: BatchedPenalizerOrchestrator): """ Merge the penalizers of another orchestrator into this one. @@ -189,7 +189,7 @@ def filter(self, keep_indices: np.ndarray): self._filter(keep_indices=keep_indices) - def merge(self, their: "_BatchedPenalizer"): + def merge(self, their: _BatchedPenalizer): if not self._is_prepared and not their._is_prepared: return @@ -228,7 +228,7 @@ def _filter(self, keep_indices: np.ndarray): pass @abc.abstractmethod - def _merge(self, their: "_BatchedPenalizer"): + def _merge(self, their: _BatchedPenalizer): """ Merge the penalizer with another penalizer. 
""" diff --git a/python/sgl_jax/srt/sampling/sampling_batch_info.py b/python/sgl_jax/srt/sampling/sampling_batch_info.py index 1f41b4350..f2846d48a 100644 --- a/python/sgl_jax/srt/sampling/sampling_batch_info.py +++ b/python/sgl_jax/srt/sampling/sampling_batch_info.py @@ -2,7 +2,7 @@ import dataclasses import logging -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.tree_util import register_pytree_node_class @@ -13,7 +13,7 @@ from sgl_jax.srt.utils.jax_utils import device_array if TYPE_CHECKING: - from sgl_jax.srt.managers.schedule_batch import ScheduleBatch, ModelWorkerBatch + from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch import threading @@ -33,8 +33,8 @@ class SamplingMetadata: # logprob return_logprob: bool - top_logprobs_nums: Optional[List[int]] - token_ids_logprobs: Optional[List[List[int]]] + top_logprobs_nums: list[int] | None + token_ids_logprobs: list[list[int]] | None # sample temperatures: jax.Array @@ -47,7 +47,7 @@ class SamplingMetadata: # penalty do_penalties: bool = False - linear_penalty: Optional[jax.Array] = None + linear_penalty: jax.Array | None = None def tree_flatten(self): children = ( @@ -353,12 +353,12 @@ class SamplingBatchInfo: need_min_p_sampling: bool = False # An event used for overlap schedule - sampling_info_done: Optional[threading.Event] = None + sampling_info_done: threading.Event | None = None - sampling_seeds: Optional[np.ndarray] = None + sampling_seeds: np.ndarray | None = None # Penalizer - penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None + penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator | None = None linear_penalty: np.ndarray = None @classmethod @@ -524,7 +524,7 @@ def filter_batch(self, keep_indices: np.ndarray): if value is not None: setattr(self, item, value[keep_indices]) - def merge_batch(self, other: "SamplingBatchInfo"): + def merge_batch(self, other: SamplingBatchInfo): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) # Note: because the __len()__ operator is defined on the temperatures tensor, # please make sure any merge operation with len(self) or len(other) is done before diff --git a/python/sgl_jax/srt/sampling/sampling_params.py b/python/sgl_jax/srt/sampling/sampling_params.py index 247422829..eb60057da 100644 --- a/python/sgl_jax/srt/sampling/sampling_params.py +++ b/python/sgl_jax/srt/sampling/sampling_params.py @@ -1,7 +1,5 @@ """Sampling parameters for text generation.""" -from typing import Dict, List, Optional, Union - from sgl_jax.srt.utils import get_bool_env_var _SAMPLING_EPS = 1e-6 @@ -21,8 +19,8 @@ class SamplingParams: def __init__( self, max_new_tokens: int = 128, - stop: Optional[Union[str, List[str]]] = None, - stop_token_ids: Optional[List[int]] = None, + stop: str | list[str] | None = None, + stop_token_ids: list[int] | None = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -32,17 +30,17 @@ def __init__( repetition_penalty: float = 1.0, min_new_tokens: int = 0, n: int = 1, - json_schema: Optional[str] = None, - regex: Optional[str] = None, - ebnf: Optional[str] = None, - structural_tag: Optional[str] = None, + json_schema: str | None = None, + regex: str | None = None, + ebnf: str | None = None, + structural_tag: str | None = None, ignore_eos: bool = False, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, no_stop_trim: bool = False, - stream_interval: Optional[int] 
= None, - logit_bias: Optional[Dict[str, float]] = None, - sampling_seed: Optional[int] = None, + stream_interval: int | None = None, + logit_bias: dict[str, float] | None = None, + sampling_seed: int | None = None, ) -> None: self.max_new_tokens = max_new_tokens self.stop_strs = stop @@ -101,22 +99,19 @@ def verify(self, vocab_size): ) if not -2.0 <= self.frequency_penalty <= 2.0: raise ValueError( - "frequency_penalty must be in [-2, 2], got " - f"{self.frequency_penalty}." + f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}." ) if not -2.0 <= self.presence_penalty <= 2.0: raise ValueError( - "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." + f"presence_penalty must be in [-2, 2], got {self.presence_penalty}." ) if not 0.0 <= self.repetition_penalty <= 2.0: raise ValueError( - "repetition_penalty must be in [0, 2], got " - f"{self.repetition_penalty}." + f"repetition_penalty must be in [0, 2], got {self.repetition_penalty}." ) - if not 0 <= self.min_new_tokens: + if not self.min_new_tokens >= 0: raise ValueError( - f"min_new_tokens must be in [0, max_new_tokens], got " - f"{self.min_new_tokens}." + f"min_new_tokens must be in [0, max_new_tokens], got {self.min_new_tokens}." ) if self.max_new_tokens is not None: if self.max_new_tokens < 0: @@ -132,8 +127,7 @@ def verify(self, vocab_size): for token_id in self.logit_bias: if not 0 <= int(token_id) < vocab_size: raise ValueError( - f"logit_bias must has keys in [0, {vocab_size - 1}], got " - f"{token_id}." + f"logit_bias must has keys in [0, {vocab_size - 1}], got {token_id}." ) grammars = [ self.json_schema, @@ -161,7 +155,7 @@ def normalize(self, tokenizer): stop_str_max_len = max(stop_str_max_len, len(stop_str)) self.stop_str_max_len = stop_str_max_len - def convert_to_dict(self) -> Dict: + def convert_to_dict(self) -> dict: # Start with a copy of all instance attributes result = {} for key, value in self.__dict__.items(): diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py index c91e604f3..33096a00a 100644 --- a/python/sgl_jax/srt/server_args.py +++ b/python/sgl_jax/srt/server_args.py @@ -6,7 +6,6 @@ import logging import os import tempfile -from typing import List, Optional import jax @@ -25,36 +24,36 @@ class ServerArgs: # Model and tokenizer model_path: str - tokenizer_path: Optional[str] = None + tokenizer_path: str | None = None tokenizer_mode: str = "auto" skip_tokenizer_init: bool = False load_format: str = "auto" model_loader_extra_config: str = "{}" trust_remote_code: bool = False - context_length: Optional[int] = None + context_length: int | None = None is_embedding: bool = False - revision: Optional[str] = None + revision: str | None = None model_impl: str = "auto" - model_layer_nums: Optional[int] = None + model_layer_nums: int | None = None # HTTP server host: str = "127.0.0.1" port: int = 30000 skip_server_warmup: bool = False - warmups: Optional[str] = None + warmups: str | None = None # Quantization and data type dtype: str = "auto" - quantization: Optional[str] = None - quantization_param_path: Optional[str] = None + quantization: str | None = None + quantization_param_path: str | None = None kv_cache_dtype: str = "auto" # Memory and scheduling - mem_fraction_static: Optional[float] = None - max_running_requests: Optional[int] = None - max_total_tokens: Optional[int] = None + mem_fraction_static: float | None = None + max_running_requests: int | None = None + max_total_tokens: int | None = None max_prefill_tokens: int = 16384 - 
chunked_prefill_size: Optional[int] = None + chunked_prefill_size: int | None = None enable_mixed_chunk: bool = False schedule_policy: str = "fcfs" schedule_conservativeness: float = 1.0 @@ -63,15 +62,15 @@ class ServerArgs: disable_hybrid_swa_memory: bool = False # Runtime options - device: Optional[str] = None + device: str | None = None tp_size: int = 1 stream_interval: int = 1 stream_output: bool = False - random_seed: Optional[int] = None - constrained_json_whitespace_pattern: Optional[str] = None + random_seed: int | None = None + constrained_json_whitespace_pattern: str | None = None watchdog_timeout: float = 300 - dist_timeout: Optional[int] = None # timeout for distributed initialization - download_dir: Optional[str] = None + dist_timeout: int | None = None # timeout for distributed initialization + download_dir: str | None = None sleep_on_idle: bool = False # Data parallel @@ -79,34 +78,34 @@ class ServerArgs: # Logging log_level: str = "info" - log_level_http: Optional[str] = None + log_level_http: str | None = None log_requests: bool = False log_requests_level: int = 0 - crash_dump_folder: Optional[str] = None + crash_dump_folder: str | None = None show_time_cost: bool = False - bucket_time_to_first_token: Optional[List[float]] = None - bucket_inter_token_latency: Optional[List[float]] = None - bucket_e2e_request_latency: Optional[List[float]] = None + bucket_time_to_first_token: list[float] | None = None + bucket_inter_token_latency: list[float] | None = None + bucket_e2e_request_latency: list[float] | None = None decode_log_interval: int = 40 enable_request_time_stats_logging: bool = False - kv_events_config: Optional[str] = None + kv_events_config: str | None = None # API related - api_key: Optional[str] = None - served_model_name: Optional[str] = None + api_key: str | None = None + served_model_name: str | None = None file_storage_path: str = "sglang_storage" enable_cache_report: bool = False - reasoning_parser: Optional[str] = None - tool_call_parser: Optional[str] = None + reasoning_parser: str | None = None + tool_call_parser: str | None = None # Multi-node distributed serving - dist_init_addr: Optional[str] = None + dist_init_addr: str | None = None nnodes: int = 1 node_rank: int = 0 # Model override args in JSON json_model_override_args: str = "{}" - preferred_sampling_params: Optional[str] = None + preferred_sampling_params: str | None = None # Optimization/debug options disable_radix_cache: bool = False @@ -121,12 +120,12 @@ class ServerArgs: xla_backend: str = "tpu" # Kernel backend - attention_backend: Optional[str] = "fa" + attention_backend: str | None = "fa" max_seq_len: int = 4096 - precompile_token_paddings: Optional[List[int]] = None - precompile_bs_paddings: Optional[List[int]] = None + precompile_token_paddings: list[int] | None = None + precompile_bs_paddings: list[int] | None = None disable_jax_precompile: bool = False @@ -178,14 +177,13 @@ def __post_init__(self): if is_remote_url(self.model_path): self.load_format = "remote" - if self.enable_precision_tracer: - if self.chunked_prefill_size is not None or self.chunked_prefill_size > 0: - logger.warning( - "Chunked prefill is enabled, but precision tracer is also enabled. " - "This may cause incorrect precision tracer results." - "Disabling chunked prefill." - ) - self.chunked_prefill_size = -1 + if self.enable_precision_tracer and ( + self.chunked_prefill_size is not None or self.chunked_prefill_size > 0 + ): + logger.warning( + "Chunked prefill is enabled, but precision tracer is also enabled. 
This may cause incorrect precision tracer results. Disabling chunked prefill." + ) + self.chunked_prefill_size = -1 os.environ["SGLANG_ENABLE_DETERMINISTIC_SAMPLING"] = ( "1" if self.enable_deterministic_sampling else "0" @@ -815,7 +813,7 @@ def check_server_args(self): ), "chunked_prefill_size must be divisible by page_size" -def prepare_server_args(argv: List[str]) -> ServerArgs: +def prepare_server_args(argv: list[str]) -> ServerArgs: """ Prepare the server arguments from the command line arguments. @@ -857,7 +855,7 @@ class PortArgs: metrics_ipc_name: str @staticmethod - def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": + def init_new(server_args, dp_rank: int | None = None) -> "PortArgs": if server_args.nnodes > 1: dist_init_addr = server_args.dist_init_addr.split(":") dist_init_host, dist_init_port = dist_init_addr diff --git a/python/sgl_jax/srt/utils/common_utils.py b/python/sgl_jax/srt/utils/common_utils.py index d09fc4d2d..1224c66b8 100644 --- a/python/sgl_jax/srt/utils/common_utils.py +++ b/python/sgl_jax/srt/utils/common_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib import ctypes import dataclasses import functools @@ -21,8 +22,9 @@ import time import traceback from collections import OrderedDict +from collections.abc import Callable, Sequence from pathlib import Path -from typing import Any, Callable, Optional, Set, Union, Sequence +from typing import Any import numpy as np import psutil @@ -47,7 +49,9 @@ def get_bool_env_var(name: str, default: str = "false") -> bool: if (value not in truthy_values) and (value not in falsy_values): if value not in _warned_bool_env_var_keys: logger.warning( - f"get_bool_env_var({name}) see non-understandable value={value} and treat as false" + "get_bool_env_var(%s) see non-understandable value=%s and treat as false", + name, + value, ) _warned_bool_env_var_keys.add(value) @@ -79,13 +83,11 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N for child in children: if child.pid == skip_pid: continue - try: + with contextlib.suppress(psutil.NoSuchProcess): child.kill() - except psutil.NoSuchProcess: - pass if include_parent: - try: + with contextlib.suppress(psutil.NoSuchProcess): if parent_pid == os.getpid(): itself.kill() sys.exit(0) @@ -95,8 +97,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), # so we send an additional signal to kill them. 
itself.send_signal(signal.SIGQUIT) - except psutil.NoSuchProcess: - pass def set_ulimit(target_soft_limit=65535): @@ -108,7 +108,7 @@ def set_ulimit(target_soft_limit=65535): try: resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: - logger.warning(f"Fail to set RLIMIT_NOFILE: {e}") + logger.warning("Fail to set RLIMIT_NOFILE: %s", e) # stack size resource_type = resource.RLIMIT_STACK @@ -120,7 +120,7 @@ def set_ulimit(target_soft_limit=65535): resource_type, (target_soft_limit_stack_size, current_hard) ) except ValueError as e: - logger.warning(f"Fail to set RLIMIT_STACK: {e}") + logger.warning("Fail to set RLIMIT_STACK: %s", e) def add_api_key_middleware(app, api_key: str): @@ -138,14 +138,13 @@ async def authentication(request, call_next): def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str): - if get_bool_env_var("SGLANG_USE_MODELSCOPE"): - if not os.path.exists(model_path): - from modelscope import snapshot_download + if get_bool_env_var("SGLANG_USE_MODELSCOPE") and not os.path.exists(model_path): + from modelscope import snapshot_download - model_path = snapshot_download(model_path) - tokenizer_path = snapshot_download( - tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"] - ) + model_path = snapshot_download(model_path) + tokenizer_path = snapshot_download( + tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"] + ) return model_path, tokenizer_path @@ -153,8 +152,7 @@ def configure_logger(server_args, prefix: str = ""): if SGLANG_LOGGING_CONFIG_PATH := os.getenv("SGLANG_LOGGING_CONFIG_PATH"): if not os.path.exists(SGLANG_LOGGING_CONFIG_PATH): raise Exception( - "Setting SGLANG_LOGGING_CONFIG_PATH from env with " - f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!" + f"Setting SGLANG_LOGGING_CONFIG_PATH from env with {SGLANG_LOGGING_CONFIG_PATH} but it does not exist!" ) with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file: custom_config = json.loads(file.read()) @@ -176,10 +174,7 @@ def get_zmq_socket( mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 - if total_mem > 32 and available_mem > 16: - buf_size = int(0.5 * 1024**3) - else: - buf_size = -1 + buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 socket = context.socket(socket_type) if endpoint.find("[") != -1: @@ -220,7 +215,7 @@ def delete_directory(dirpath): def dataclass_to_string_truncated( - data, max_length=2048, skip_names: Optional[Set[str]] = None + data, max_length=2048, skip_names: set[str] | None = None ): if skip_names is None: skip_names = set() @@ -276,9 +271,9 @@ def pyspy_dump_schedulers(): result = subprocess.run( cmd, shell=True, capture_output=True, text=True, check=True ) - logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}") + logger.error("Pyspy dump for PID %s:\n%s", pid, result.stdout) except subprocess.CalledProcessError as e: - logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}") + logger.error("Pyspy failed to dump PID %s.
Error: %s", pid, e.stderr) def kill_itself_when_parent_died(): @@ -344,16 +339,16 @@ async def health_generate(): try: loop = asyncio.get_running_loop() logger.info( - f"Dummy health check server scheduled on existing loop at {host}:{port}" + "Dummy health check server scheduled on existing loop at %s:%s", host, port ) loop.create_task(server.serve()) except RuntimeError: - logger.info(f"Starting dummy health check server at {host}:{port}") + logger.info("Starting dummy health check server at %s:%s", host, port) server.run() -def is_remote_url(url: Union[str, Path]) -> bool: +def is_remote_url(url: str | Path) -> bool: """ Check if the URL is a remote URL of the format: ://:/ @@ -378,17 +373,23 @@ def retry( return fn() except Exception as e: if try_index >= max_retry: - raise Exception("retry() exceed maximum number of retries.") + raise Exception("retry() exceed maximum number of retries.") from e if not should_retry(e): - raise Exception("retry() observe errors that should not be retried.") + raise Exception( + "retry() observe errors that should not be retried." + ) from e delay = min(initial_delay * (2**try_index), max_delay) * ( 0.75 + 0.25 * random.random() ) logger.warning( - f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}" + "retry() failed once (%sth try, maximum %s retries). Will delay %.2fs and retry. Error: %s", + try_index, + max_retry, + delay, + e, ) traceback.print_exc() @@ -413,7 +414,7 @@ def _to_hashable(o): ): return tuple(_to_hashable(v) for v in o) else: - raise TypeError(f"Cannot make hashable: {type(o)}") + raise TypeError(f"Cannot make hashable: {type(o)}") from None def decorator(func): cache = OrderedDict() @@ -441,5 +442,5 @@ def wrapper(*args, **kwargs): def cdiv(a, b): - assert b != 0, f"b is equal to 0, {b=}" + assert b != 0, f"b is equal to 0, b={b}" return (a + b - 1) // b diff --git a/python/sgl_jax/srt/utils/jax_utils.py b/python/sgl_jax/srt/utils/jax_utils.py index 2da16202b..79c145cf8 100644 --- a/python/sgl_jax/srt/utils/jax_utils.py +++ b/python/sgl_jax/srt/utils/jax_utils.py @@ -81,7 +81,6 @@ def get_available_device_memory(device, distributed=False, empty_cache=True): raise ValueError(f"Invalid device: {device}") if distributed: - # Use pmap to find the minimum available memory across all devices. 
mesh = jax.make_mesh((jax.process_count(), 4), ("node", "device")) diff --git a/python/sgl_jax/srt/utils/mesh_utils.py b/python/sgl_jax/srt/utils/mesh_utils.py index 87aa49c2e..bee8a5c7f 100644 --- a/python/sgl_jax/srt/utils/mesh_utils.py +++ b/python/sgl_jax/srt/utils/mesh_utils.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence import jax import numpy as np diff --git a/python/sgl_jax/srt/utils/weight_utils.py b/python/sgl_jax/srt/utils/weight_utils.py index c7d2647f8..2403ec93b 100644 --- a/python/sgl_jax/srt/utils/weight_utils.py +++ b/python/sgl_jax/srt/utils/weight_utils.py @@ -3,7 +3,6 @@ import math import os from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union import jax import jax.numpy as jnp @@ -20,10 +19,10 @@ @dataclass class WeightMapping: - target_path: Union[str, List[str]] - sharding: Optional[Tuple] = None + target_path: str | list[str] + sharding: tuple | None = None transpose: bool = False - reshape: Optional[Tuple] = None + reshape: tuple | None = None head_dim_padding: bool = False kv_head_padding: bool = False @@ -31,11 +30,12 @@ def __post_init__(self): if self.sharding is None: self.sharding = self._infer_default_sharding() - def _infer_default_sharding(self) -> Tuple: - if isinstance(self.target_path, list): - path = self.target_path[0] - else: - path = self.target_path + def _infer_default_sharding(self) -> tuple: + path = ( + self.target_path[0] + if isinstance(self.target_path, list) + else self.target_path + ) if any(pattern in path for pattern in ["embedding", "lm_head"]): return (None, None) @@ -91,7 +91,7 @@ def __init__( self.sharding_size = 1 def load_weights_from_safetensors( - self, weight_mappings: Dict[str, Union[str, List[str], WeightMapping]] + self, weight_mappings: dict[str, str | list[str] | WeightMapping] ): params = nnx.state(self.model) @@ -107,7 +107,8 @@ def load_weights_from_safetensors( expert_weights = {} logger.info( - f"WeightLoader: Will load layers 0 to {self.model_config.num_hidden_layers - 1}" + "WeightLoader: Will load layers 0 to %s", + self.model_config.num_hidden_layers - 1, ) for hf_key, hf_weight in self._iterate_weights(): @@ -119,12 +120,12 @@ def load_weights_from_safetensors( self._process_and_assign_weight(params, hf_key, hf_weight, mapping) elif "mlp.experts." 
in hf_key and hf_key.endswith(".weight"): if self._is_excluded_layer_weight(hf_key): - logger.debug(f"Skipping excluded MoE expert weight: {hf_key}") + logger.debug("Skipping excluded MoE expert weight: %s", hf_key) else: expert_weights[hf_key] = hf_weight.astype(self.dtype) else: if self._is_excluded_layer_weight(hf_key): - logger.debug(f"Skipping excluded layer weight: {hf_key}") + logger.debug("Skipping excluded layer weight: %s", hf_key) else: logger.warning(f"No mapping found for weight: {hf_key}") nnx.update(self.model, params) @@ -148,34 +149,43 @@ def _iterate_weights(self): filename = os.path.basename(st_file) pbar.set_postfix({"file": filename}) - with jax.default_device(jax.local_devices(backend="cpu")[0]): - with safe_open(st_file, framework="flax") as f: - needed_keys = [] - for name in f.keys(): - if not name.startswith("model.layers."): - needed_keys.append(name) - continue - - if not self._is_excluded_layer_weight(name): - needed_keys.append(name) - - if not needed_keys: - skipped_files += 1 - logger.debug( - f"Skipping {filename}: 0/{len(f.keys())} weights needed" - ) + with ( + jax.default_device(jax.local_devices(backend="cpu")[0]), + safe_open(st_file, framework="flax") as f, + ): + needed_keys = [] + for name in f: + if not name.startswith("model.layers."): + needed_keys.append(name) continue + if not self._is_excluded_layer_weight(name): + needed_keys.append(name) + + if not needed_keys: + skipped_files += 1 logger.debug( - f"Loading {filename}: {len(needed_keys)}/{len(f.keys())} weights needed" + "Skipping %s: 0/%s weights needed", + filename, + len(f.keys()), ) - for name in needed_keys: - weight_tensor = f.get_tensor(name) - yield name, weight_tensor + continue + + logger.debug( + "Loading %s: %s/%s weights needed", + filename, + len(needed_keys), + len(f.keys()), + ) + for name in needed_keys: + weight_tensor = f.get_tensor(name) + yield name, weight_tensor if skipped_files > 0: logger.info( - f"Memory optimization: Skipped {skipped_files}/{len(weights_files)} files with no needed weights" + "Memory optimization: Skipped %s/%s files with no needed weights", + skipped_files, + len(weights_files), ) def _process_and_assign_weight( @@ -217,11 +227,15 @@ def _handle_single_weight( try: model_param = self._get_param(params, jax_path) logger.debug( - f"Loading {hf_key} -> {jax_path}, shape: {processed_weight.shape}, transpose: {mapping.transpose}" + "Loading %s -> %s, shape: %s, transpose: %s", + hf_key, + jax_path, + processed_weight.shape, + mapping.transpose, ) model_param.value = sharded_weight except Exception as e: - logger.error(f"Failed to load {hf_key} -> {jax_path}: {str(e)}") + logger.error("Failed to load %s -> %s: %s", hf_key, jax_path, str(e)) raise def _handle_split_weight( @@ -264,7 +278,6 @@ def _split_qkv_weight( splits = [q_bias, k_bias, v_bias] else: - q_dim = self.num_heads * self.head_dim_original kv_dim = self.num_kv_heads * self.head_dim_original @@ -360,7 +373,7 @@ def _split_qkv_weight( model_param = self._get_param(params, jax_path) model_param.value = sharded_weight logger.debug( - f"Split {hf_key} -> {jax_path}, shape: {processed_weight.shape}" + "Split %s -> %s, shape: %s", hf_key, jax_path, processed_weight.shape ) def _shard_weight(self, weight: jax.Array, sharding: tuple) -> jax.Array: @@ -539,7 +552,9 @@ def _is_excluded_layer_weight(self, hf_key: str) -> bool: if is_excluded and not hasattr(self, "_debug_count"): logger.info( - f"DEBUG: Excluding layer {layer_num} >= {self.model_config.num_hidden_layers}" + "DEBUG: Excluding layer %s 
>= %s", + layer_num, + self.model_config.num_hidden_layers, ) self._debug_count = True @@ -548,8 +563,8 @@ def _is_excluded_layer_weight(self, hf_key: str) -> bool: def _process_moe_expert_weights( self, params: nnx.State, - moe_mappings: Dict[str, WeightMapping], - expert_weights: Dict[str, jax.Array], + moe_mappings: dict[str, WeightMapping], + expert_weights: dict[str, jax.Array], ): with tqdm( moe_mappings.items(), desc="[STACKING] MOE EXPERTS", unit="layer" @@ -562,7 +577,7 @@ def _process_moe_expert_weights( not isinstance(mapping.target_path, list) or len(mapping.target_path) < 2 ): - logger.warning(f"Invalid MoE mapping for {moe_key}") + logger.warning("Invalid MoE mapping for %s", moe_key) continue target_path = mapping.target_path[0] @@ -576,7 +591,7 @@ def _process_moe_expert_weights( weight = jnp.transpose(weight, (1, 0)) collected_weights.append(weight) else: - logger.warning(f"Missing expert weight: {expert_key}") + logger.warning("Missing expert weight: %s", expert_key) if len(collected_weights) == len(expert_keys): stacked_weight = jnp.stack(collected_weights, axis=0) @@ -590,5 +605,5 @@ def _process_moe_expert_weights( model_param.value = sharded_weight else: logger.error( - f"Could not collect all expert weights for {target_path}" + "Could not collect all expert weights for %s", target_path ) diff --git a/python/sgl_jax/test/mem_cache/test_kv_cache.py b/python/sgl_jax/test/mem_cache/test_kv_cache.py index c43830d2a..c3bcb6238 100644 --- a/python/sgl_jax/test/mem_cache/test_kv_cache.py +++ b/python/sgl_jax/test/mem_cache/test_kv_cache.py @@ -554,7 +554,7 @@ def test_optimize_contiguous_updates_corner_cases(self): ) for idx, cache_start, new_start, length in slices: print( - f" Slice at pos {idx}: cache[{cache_start}:{cache_start+length}] <- input[{new_start}:{new_start+length}]" + f" Slice at pos {idx}: cache[{cache_start}:{cache_start + length}] <- input[{new_start}:{new_start + length}]" ) # Verify slice doesn't exceed page_size diff --git a/python/sgl_jax/test/model_executor/test_model_runner.py b/python/sgl_jax/test/model_executor/test_model_runner.py index 6b65082f3..fc6e99fa4 100644 --- a/python/sgl_jax/test/model_executor/test_model_runner.py +++ b/python/sgl_jax/test/model_executor/test_model_runner.py @@ -131,7 +131,7 @@ def _get_tokenizer(self): print(f"Failed to load tokenizer from HuggingFace: {e}") raise RuntimeError( f"Could not load tokenizer from local path or HuggingFace: {e}" - ) + ) from e def _new_forward_batch(self, input_ids, positions): """Create a ForwardBatch for testing.""" diff --git a/python/sgl_jax/test/models/test_qwen_model.py b/python/sgl_jax/test/models/test_qwen_model.py index d2bcaa84f..2655fa07d 100644 --- a/python/sgl_jax/test/models/test_qwen_model.py +++ b/python/sgl_jax/test/models/test_qwen_model.py @@ -66,7 +66,7 @@ def _get_tokenizer(self): print(f"Failed to load tokenizer from HuggingFace: {e}") raise RuntimeError( f"Could not load tokenizer from local path or HuggingFace: {e}" - ) + ) from e def _setup_model(self): """Setup model using get_model_loader""" @@ -281,7 +281,7 @@ def _generate_random_questions(self, batch_size: int) -> list[str]: question = template.format(*fill_params) except (IndexError, ValueError): question = ( - f"Question {i+1}: Tell me about {random.choice(fill_words)}" + f"Question {i + 1}: Tell me about {random.choice(fill_words)}" ) else: question = template @@ -400,7 +400,7 @@ def test_qwen_model_forward(self, batch_size: int = None): print("\nSample questions:") for i, text in enumerate(input_texts[: 
min(5, len(input_texts))]): - print(f" {i+1}: '{text}'") + print(f" {i + 1}: '{text}'") if len(input_texts) > 5: print(f" ... and {len(input_texts) - 5} more questions") @@ -522,8 +522,8 @@ def test_qwen_model_forward(self, batch_size: int = None): print(f" Requests processed: {len(input_texts)}") print(f" Requests finished: {finished_count}/{len(input_texts)}") print(f" Average output length: {avg_output_length:.1f} characters") - print(f" Throughput: {len(input_texts)/total_time:.2f} requests/second") - print(f" Time per request: {total_time/len(input_texts)*1000:.2f} ms") + print(f" Throughput: {len(input_texts) / total_time:.2f} requests/second") + print(f" Time per request: {total_time / len(input_texts) * 1000:.2f} ms") # Print detailed results for small batches if len(input_texts) <= 10: diff --git a/python/sgl_jax/test/run_jax_loader_test.py b/python/sgl_jax/test/run_jax_loader_test.py index 6b3b226d2..6efce587e 100644 --- a/python/sgl_jax/test/run_jax_loader_test.py +++ b/python/sgl_jax/test/run_jax_loader_test.py @@ -22,10 +22,10 @@ """ import argparse +import importlib import os import subprocess import sys -import importlib from pathlib import Path diff --git a/python/sgl_jax/test/run_qwen3_moe_test.py b/python/sgl_jax/test/run_qwen3_moe_test.py index 0ecaf3d0e..4ac5523ee 100644 --- a/python/sgl_jax/test/run_qwen3_moe_test.py +++ b/python/sgl_jax/test/run_qwen3_moe_test.py @@ -25,10 +25,10 @@ """ import argparse +import importlib import os import subprocess import sys -import importlib from pathlib import Path os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" @@ -39,6 +39,7 @@ def check_jax_dependencies(): try: import jax import jax.numpy as jnp + importlib.util.find_spec("flax.nnx") print(f"✓ JAX version: {jax.__version__}") @@ -67,7 +68,9 @@ def check_sglang_dependencies(): try: importlib.util.find_spec("sgl_jax.srt.configs.load_config.LoadFormat") importlib.util.find_spec("sgl_jax.srt.model_loader.loader.JAXModelLoader") - importlib.util.find_spec("sgl_jax.srt.models.qwen3_moe.Qwen3MoeForCausalLMJaxModel") + importlib.util.find_spec( + "sgl_jax.srt.models.qwen3_moe.Qwen3MoeForCausalLMJaxModel" + ) print("✓ SGLang JAXModelLoader available") print("✓ Qwen3MoeForCausalLMJaxModel available") diff --git a/python/sgl_jax/test/run_qwen_test.py b/python/sgl_jax/test/run_qwen_test.py index d74ee587a..1ee29e2a5 100644 --- a/python/sgl_jax/test/run_qwen_test.py +++ b/python/sgl_jax/test/run_qwen_test.py @@ -1,8 +1,8 @@ import argparse +import importlib import os import subprocess import sys -import importlib from pathlib import Path @@ -11,6 +11,7 @@ def check_jax_dependencies(): try: import jax import jax.numpy as jnp + importlib.util.find_spec("flax.nnx") print(f"✓ JAX version: {jax.__version__}") diff --git a/python/sgl_jax/test/runners.py b/python/sgl_jax/test/runners.py index 37d25e4c4..43eeb81a3 100644 --- a/python/sgl_jax/test/runners.py +++ b/python/sgl_jax/test/runners.py @@ -14,8 +14,6 @@ import os - - DEFAULT_PROMPTS = [ "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", "The capital of the United Kingdom is", @@ -41,7 +39,7 @@ ] dirpath = os.path.dirname(__file__) -with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f: +with open(os.path.join(dirpath, "long_prompt.txt")) as f: long_prompt = f.read() DEFAULT_PROMPTS.append(long_prompt) diff --git a/python/sgl_jax/test/simple_eval_common.py b/python/sgl_jax/test/simple_eval_common.py index 5a60221ed..21600a42d 100644 --- a/python/sgl_jax/test/simple_eval_common.py +++ b/python/sgl_jax/test/simple_eval_common.py @@ -6,7 +6,7 @@ from collections import defaultdict from dataclasses import dataclass, field from multiprocessing.pool import ThreadPool -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import httpx import jinja2 @@ -23,8 +23,8 @@ ) -Message = Dict[str, Any] # keys role, content -MessageList = List[Message] +Message = dict[str, Any] # keys role, content +MessageList = list[Message] class SamplerBase: @@ -43,10 +43,10 @@ class EvalResult: Result of running an evaluation (usually consisting of many samples) """ - score: Optional[float] # top-line metric - metrics: Optional[Dict[str, float]] # other metrics - htmls: List[str] # strings of valid HTML - convos: List[MessageList] # sampled conversations + score: float | None # top-line metric + metrics: dict[str, float] | None # other metrics + htmls: list[str] # strings of valid HTML + convos: list[MessageList] # sampled conversations @dataclass @@ -55,10 +55,10 @@ class SingleEvalResult: Result of evaluating a single sample """ - score: Optional[float] - metrics: Dict[str, float] = field(default_factory=dict) - html: Optional[str] = None - convo: Optional[MessageList] = None # sampled conversation + score: float | None + metrics: dict[str, float] = field(default_factory=dict) + html: str | None = None + convo: MessageList | None = None # sampled conversation class Eval: @@ -88,8 +88,8 @@ class ChatCompletionSampler(SamplerBase): def __init__( self, base_url: str = None, - model: Optional[str] = None, - system_message: Optional[str] = None, + model: str | None = None, + system_message: str | None = None, temperature: float = 0.0, max_tokens: int = 2048, ): @@ -269,9 +269,9 @@ def _compute_stat(values: list, stat: str): def aggregate_results( - single_eval_results: List[SingleEvalResult], - default_stats: Tuple[str] = ("mean", "std"), - name2stats: Optional[Dict[str, Tuple[str]]] = None, + single_eval_results: list[SingleEvalResult], + default_stats: tuple[str] = ("mean", "std"), + name2stats: dict[str, tuple[str]] | None = None, ) -> EvalResult: """ Aggregate results from multiple evaluations into a single EvalResult. @@ -301,11 +301,11 @@ def aggregate_results( ) -def map_with_progress(f: callable, xs: List[Any], num_threads: int): +def map_with_progress(f: callable, xs: list[Any], num_threads: int): """ Apply f to each element of xs, using a ThreadPool, and show progress. 
""" - if os.getenv("debug"): + if os.getenv("DEBUG"): return list(map(f, tqdm(xs, total=len(xs)))) else: with ThreadPool(min(num_threads, len(xs))) as pool: @@ -421,7 +421,7 @@ def make_report(eval_result: EvalResult) -> str: ) -def make_report_from_example_htmls(htmls: List[str]): +def make_report_from_example_htmls(htmls: list[str]): """ Create a standalone HTML report from a list of example htmls """ @@ -439,20 +439,23 @@ def download_dataset(path, url): total_size = int(response.headers.get("content-length", 0)) block_size = 8192 - with open(path, "wb") as f, tqdm( - desc="Downloading", - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as progress_bar: + with ( + open(path, "wb") as f, + tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar, + ): for data in response.iter_content(block_size): size = f.write(data) progress_bar.update(size) print(f"Dataset downloaded and saved to {path}") except requests.RequestException as e: - raise Exception(f"Failed to download dataset: {e}") + raise Exception(f"Failed to download dataset: {e}") from e def set_ulimit(target_soft_limit=65535): diff --git a/python/sgl_jax/test/simple_eval_gpqa.py b/python/sgl_jax/test/simple_eval_gpqa.py index 0a834a595..e93eeecf4 100644 --- a/python/sgl_jax/test/simple_eval_gpqa.py +++ b/python/sgl_jax/test/simple_eval_gpqa.py @@ -8,7 +8,6 @@ import random import re -from typing import Optional import pandas @@ -28,7 +27,7 @@ class GPQAEval(Eval): def __init__( self, filename: str, - num_examples: Optional[int], + num_examples: int | None, num_threads: int, n_repeats: int = 1, ): diff --git a/python/sgl_jax/test/simple_eval_humaneval.py b/python/sgl_jax/test/simple_eval_humaneval.py index 1113bf4f7..abc34a498 100644 --- a/python/sgl_jax/test/simple_eval_humaneval.py +++ b/python/sgl_jax/test/simple_eval_humaneval.py @@ -9,8 +9,6 @@ import random import re from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, List, Optional - try: from human_eval.data import read_problems @@ -31,8 +29,8 @@ def evaluate_functional_correctness( - sample: Dict[str, str], - completions: List[str], + sample: dict[str, str], + completions: list[str], n_workers: int = 4, timeout: float = 3.0, ): @@ -59,10 +57,10 @@ def evaluate_functional_correctness( class HumanEval(Eval): def __init__( self, - num_examples: Optional[int], + num_examples: int | None, num_threads: int, num_samples_per_task: int = 5, - ks_passes: List[int] = [1, 2, 5], + ks_passes: tuple[int, ...] 
= (1, 2, 5), timeout: int = 120, ): self.seed = 0 @@ -89,7 +87,7 @@ def find_code(completion): ] # remove signature return extracted_answer - def fn(sample: Dict[str, str]): + def fn(sample: dict[str, str]): prompt_messages = [ sampler._pack_message( role="user", content=instruction + sample["prompt"] diff --git a/python/sgl_jax/test/simple_eval_math.py b/python/sgl_jax/test/simple_eval_math.py index 154b53f8b..3e213fec3 100644 --- a/python/sgl_jax/test/simple_eval_math.py +++ b/python/sgl_jax/test/simple_eval_math.py @@ -8,7 +8,6 @@ import random import re -from typing import Optional import pandas @@ -37,7 +36,7 @@ def __init__( self, filename: str, equality_checker: SamplerBase, - num_examples: Optional[int], + num_examples: int | None, num_threads: int, ): df = pandas.read_csv(filename) diff --git a/python/sgl_jax/test/simple_eval_mgsm.py b/python/sgl_jax/test/simple_eval_mgsm.py index 7dc6d32cb..5e5eed488 100644 --- a/python/sgl_jax/test/simple_eval_mgsm.py +++ b/python/sgl_jax/test/simple_eval_mgsm.py @@ -9,7 +9,6 @@ import re import urllib -from typing import Optional from sgl_jax.test import simple_eval_common as common from sgl_jax.test.simple_eval_common import ( @@ -139,7 +138,7 @@ def __init__( self, num_examples_per_lang: int = 250, # restrict to a subset of the data for debugging num_threads: int = 64, - languages: Optional[list[str]] = ALL_LANGUAGES, + languages: list[str] | None = ALL_LANGUAGES, ): if languages is None: languages = ALL_LANGUAGES diff --git a/python/sgl_jax/test/simple_eval_mmlu.py b/python/sgl_jax/test/simple_eval_mmlu.py index ccff7c888..df0cde7d9 100644 --- a/python/sgl_jax/test/simple_eval_mmlu.py +++ b/python/sgl_jax/test/simple_eval_mmlu.py @@ -8,7 +8,6 @@ import random import re -from typing import Optional import pandas @@ -85,7 +84,7 @@ class MMLUEval(Eval): - def __init__(self, filename: str, num_examples: Optional[int], num_threads: int): + def __init__(self, filename: str, num_examples: int | None, num_threads: int): df = pandas.read_csv(filename) examples = [row.to_dict() for _, row in df.iterrows()] if num_examples: diff --git a/python/sgl_jax/test/test_flashattention.py b/python/sgl_jax/test/test_flashattention.py index 8a5958eaa..f52d323a3 100644 --- a/python/sgl_jax/test/test_flashattention.py +++ b/python/sgl_jax/test/test_flashattention.py @@ -56,9 +56,9 @@ def create_qkv_cache( v = jnp.zeros((batched_aligned_kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16) # Fill in the actual data for each sequence with proper alignment - actual_pos = 0 aligned_pos = 0 - for seq_len in [kv_len for _, kv_len in lens]: + for actual_pos, (_, kv_len) in enumerate(lens): + seq_len = kv_len aligned_len = ((seq_len + page_size - 1) // page_size) * page_size # Generate data for this sequence @@ -77,7 +77,6 @@ def create_qkv_cache( k = k.at[aligned_pos : aligned_pos + seq_len].set(seq_k) v = v.at[aligned_pos : aligned_pos + seq_len].set(seq_v) - actual_pos += 1 aligned_pos += aligned_len return q, k, v @@ -290,11 +289,7 @@ def run_test(self, mode, lens, mode_args): # Create mock forward_batch num_heads, head_dim, num_kv_heads, page_size, dtype = mode_args - if dtype == jnp.bfloat16: - is_bf16 = True - else: - is_bf16 = False - + is_bf16 = dtype == jnp.bfloat16 forward_batch, token_to_kv_pool, q, k, v = create_test_data( mode, lens, diff --git a/python/sgl_jax/test/test_model_loader.py b/python/sgl_jax/test/test_model_loader.py index ac53bf44f..abc788eac 100644 --- a/python/sgl_jax/test/test_model_loader.py +++
b/python/sgl_jax/test/test_model_loader.py @@ -203,12 +203,10 @@ def _find_test_model_path(cls): # Look for directories that contain safetensors files for item in os.listdir(path): item_path = os.path.join(path, item) - if os.path.isdir(item_path): - # Check if this directory has safetensors files - if any( - f.endswith(".safetensors") for f in os.listdir(item_path) - ): - return item_path + if os.path.isdir(item_path) and any( + f.endswith(".safetensors") for f in os.listdir(item_path) + ): + return item_path return None @@ -591,10 +589,7 @@ def _get_param_by_path(self, params, path): keys = path.split(".") current = params for key in keys: - if key.isdigit(): - current = current[int(key)] - else: - current = current[key] + current = current[int(key)] if key.isdigit() else current[key] return current def test_multi_device_tensor_parallelism(self): @@ -776,7 +771,7 @@ def test_nonexistent_model_path(self): """Test handling of nonexistent model path.""" LoadConfig(load_format=LoadFormat.JAX) - with self.assertRaises(Exception): + with self.assertRaises((OSError, FileNotFoundError)): ModelConfig(model_path="/nonexistent/path", trust_remote_code=True) def test_empty_directory(self): @@ -784,7 +779,7 @@ def test_empty_directory(self): LoadConfig(load_format=LoadFormat.JAX) # This should fail when trying to create ModelConfig - with self.assertRaises(Exception): + with self.assertRaises(OSError): ModelConfig( model_path=self.temp_dir, trust_remote_code=True, # Empty directory diff --git a/python/sgl_jax/test/test_multi_process_radix_cache.py b/python/sgl_jax/test/test_multi_process_radix_cache.py index 736a9a54c..2223b7962 100644 --- a/python/sgl_jax/test/test_multi_process_radix_cache.py +++ b/python/sgl_jax/test/test_multi_process_radix_cache.py @@ -19,7 +19,7 @@ def print_cache_sharding_info(cache, mesh, req_pool, allocator, process_id): """Print cache-related sharding information""" - print(f"\n{'='*80}") + print(f"\n{'=' * 80}") print(f"[PROCESS {process_id}] RadixCache multiprocesssharding info") print(f"[PROCESS {process_id}] Local device count: {len(jax.local_devices())}") print(f"[PROCESS {process_id}] Global device count: {len(jax.devices())}") @@ -75,7 +75,7 @@ def print_sharding(obj, name, prefix=""): if hasattr(kv_cache, "kv_buffer") and kv_cache.kv_buffer: print_sharding(kv_cache.kv_buffer[0], "kv_cache.kv_buffer[0]", "allocator") - print(f"{'='*80}") + print(f"{'=' * 80}") def create_multi_process_radix_cache(process_id, tp_size=8): @@ -175,7 +175,7 @@ def test_basic_radix_cache_operations(cache, process_id): test_keys = [[1, 2, 3, 4, 5], [1, 2, 3, 6, 7], [10, 11, 12, 13, 14]] for i, key in enumerate(test_keys): - print(f"[PROCESS {process_id}] Inserting key {i+1}: {key}") + print(f"[PROCESS {process_id}] Inserting key {i + 1}: {key}") prefix_len = cache.insert(key) print(f"[PROCESS {process_id}] Prefix match length: {prefix_len}") @@ -253,11 +253,11 @@ def test_cross_process_isolation(cache, process_id): print(f"[PROCESS {process_id}] Inserting process-specific data:") for i, key in enumerate(process_specific_keys): - print(f"[PROCESS {process_id}] Inserting key{i+1}: {key}") + print(f"[PROCESS {process_id}] Inserting key{i + 1}: {key}") cache.insert(key) match_result = cache.match_prefix(key) print( - f"[PROCESS {process_id}] Key{i+1} match result: {len(match_result.device_indices)}" + f"[PROCESS {process_id}] Key{i + 1} match result: {len(match_result.device_indices)}" ) # Test cache status diff --git a/python/sgl_jax/test/test_sampler.py b/python/sgl_jax/test/test_sampler.py 
index fa9403f5b..6861d1356 100644 --- a/python/sgl_jax/test/test_sampler.py +++ b/python/sgl_jax/test/test_sampler.py @@ -8,7 +8,6 @@ class TestMultinomialWithSeed(unittest.TestCase): - def test_deterministic_sampling_with_same_seed(self): """Test that same (inputs, seed) pair always yields the same sample.""" # Setup test data diff --git a/python/sgl_jax/test/test_utils.py b/python/sgl_jax/test/test_utils.py index 2c6646cc5..a3ac4ed28 100644 --- a/python/sgl_jax/test/test_utils.py +++ b/python/sgl_jax/test/test_utils.py @@ -9,9 +9,9 @@ import threading import time import unittest -from contextlib import nullcontext +from collections.abc import Awaitable, Callable, Sequence +from contextlib import nullcontext, suppress from types import SimpleNamespace -from typing import Awaitable, Callable, Optional, Sequence import jax import numpy as np @@ -42,10 +42,7 @@ def is_in_ci(): return get_bool_env_var("SGLANG_IS_IN_CI") -if is_in_ci(): - DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5000 + 100 -else: - DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 7000 + 100 +DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5000 + 100 if is_in_ci() else 7000 + 100 DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}" mesh_axes = [ @@ -142,10 +139,10 @@ def popen_launch_server( model: str, base_url: str, timeout: float, - api_key: Optional[str] = None, - other_args: list[str] = [], - env: Optional[dict] = None, - return_stdout_stderr: Optional[tuple] = None, + api_key: str | None = None, + other_args: list[str] | None = None, + env: dict | None = None, + return_stdout_stderr: tuple | None = None, device: str = "tpu", pd_separated: bool = False, ): @@ -155,7 +152,7 @@ def popen_launch_server( device: Device type ("auto", "cuda", "rocm" or "cpu"). If "auto", will detect available platforms automatically. """ - other_args = list(other_args) + other_args = list(other_args) if other_args is not None else [] other_args += ["--device", str(device)] _, host, port = base_url.split(":") @@ -217,13 +214,11 @@ def popen_launch_server( start_time = time.perf_counter() with requests.Session() as session: while time.perf_counter() - start_time < timeout: - return_code = process.poll() if return_code is not None: # Server failed to start (non-zero exit code) or crashed raise Exception( - f"Server process exited with code {return_code}. " - "Check server logs for errors." + f"Server process exited with code {return_code}. Check server logs for errors." ) try: @@ -271,13 +266,11 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N for child in children: if child.pid == skip_pid: continue - try: + with suppress(psutil.NoSuchProcess): child.kill() - except psutil.NoSuchProcess: - pass if include_parent: - try: + with suppress(psutil.NoSuchProcess): if parent_pid == os.getpid(): itself.kill() sys.exit(0) @@ -287,8 +280,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), # so we send an additional signal to kill them. 
itself.send_signal(signal.SIGQUIT) - except psutil.NoSuchProcess: - pass def generate_server_args() -> ServerArgs: @@ -467,8 +458,8 @@ def run_bench_serving( need_warmup=False, seed: int = 0, device="auto", - background_task: Optional[Callable[[str, asyncio.Event], Awaitable[None]]] = None, - lora_name: Optional[str] = None, + background_task: Callable[[str, asyncio.Event], Awaitable[None]] | None = None, + lora_name: str | None = None, ): if device == "auto": device = "tpu" @@ -722,10 +713,11 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2): lcs_len = lcs(s1, s2) precision = lcs_len / len(s1) if len(s1) > 0 else 0 recall = lcs_len / len(s2) if len(s2) > 0 else 0 - if precision + recall > 0: - fmeasure = (2 * precision * recall) / (precision + recall) - else: - fmeasure = 0.0 + fmeasure = ( + (2 * precision * recall) / (precision + recall) + if precision + recall > 0 + else 0.0 + ) rouge_l_scores.append(fmeasure) return rouge_l_scores diff --git a/python/sgl_jax/tools/trace_diff.py b/python/sgl_jax/tools/trace_diff.py index 94d9ebc19..a6bcd8624 100755 --- a/python/sgl_jax/tools/trace_diff.py +++ b/python/sgl_jax/tools/trace_diff.py @@ -1,7 +1,6 @@ import argparse import json import sys -from typing import Dict, List, Optional, Tuple class Colors: @@ -20,16 +19,16 @@ class Colors: BG_YELLOW = "\033[48;5;178m" -def load_jsonl(file_path: str) -> List[Dict]: +def load_jsonl(file_path: str) -> list[dict]: try: - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: return [json.loads(line.strip()) for line in f if line.strip()] except Exception as e: print(f"Error loading {file_path}: {e}") return [] -def group_by_content_hash(traces: List[Dict]) -> Dict[str, List[Dict]]: +def group_by_content_hash(traces: list[dict]) -> dict[str, list[dict]]: groups = {} for trace in traces: content_hash = trace.get("content_hash") @@ -41,11 +40,11 @@ def group_by_content_hash(traces: List[Dict]) -> Dict[str, List[Dict]]: def compare_precision_records( - records1: Dict, - records2: Dict, + records1: dict, + records2: dict, tolerance: float = 1e-6, - max_decode_tokens: Optional[int] = None, -) -> Tuple[bool, List[str]]: + max_decode_tokens: int | None = None, +) -> tuple[bool, list[str]]: differences = [] all_match = True @@ -87,11 +86,11 @@ def category_sort_key(category): def compare_token_groups( category: str, - tokens1: List[Dict], - tokens2: List[Dict], + tokens1: list[dict], + tokens2: list[dict], tolerance: float = 1e-6, - max_decode_tokens: Optional[int] = None, -) -> Tuple[bool, List[str]]: + max_decode_tokens: int | None = None, +) -> tuple[bool, list[str]]: """Compare token groups within a category (prefill/decode)""" differences = [] all_match = True @@ -156,10 +155,10 @@ def compare_token_groups( def compare_token_records( category: str, token_idx: int, - records1: List[Dict], - records2: List[Dict], + records1: list[dict], + records2: list[dict], tolerance: float = 1e-6, -) -> Tuple[bool, List[str]]: +) -> tuple[bool, list[str]]: """Compare records within a single token group""" differences = [] all_match = True @@ -306,7 +305,7 @@ def group_records(records): return all_match, differences -def print_diff_header(content_hash: str, trace1: Dict, trace2: Dict): +def print_diff_header(content_hash: str, trace1: dict, trace2: dict): """Print a header for the diff section""" print(f"\n{Colors.BOLD}{Colors.CYAN}Content Hash: {content_hash}{Colors.RESET}") print( @@ -352,7 +351,7 @@ def format_comparison_result(message: str) -> str: return message 
-def print_tree_differences(differences: List[str]): +def print_tree_differences(differences: list[str]): """Print differences in a tree-like structure with color coding""" if not differences: print(f"{Colors.BG_GREEN}{Colors.WHITE} ALL MATCH {Colors.RESET}") @@ -406,14 +405,12 @@ def print_tree_differences(differences: List[str]): # Print tree structure - ensure prefill comes first def display_category_sort_key(category): - if category == "prefill": - return 0 - elif category == "decode": - return 1 - elif category == "root": - return 999 # root always last - else: - return 2 + order = { + "prefill": 0, + "decode": 1, + "root": 999, # root always last + } + return order.get(category, 2) for category in sorted(tree.keys(), key=display_category_sort_key): if category == "root": @@ -471,7 +468,7 @@ def parse_layer_module(layer_key): # Group by layer number first layers_by_num = {} - for layer_key in layer_groups.keys(): + for layer_key in layer_groups: layer_num, module_type = parse_layer_module(layer_key) if layer_num not in layers_by_num: layers_by_num[layer_num] = {} @@ -547,7 +544,7 @@ def parse_layer_module(layer_key): print(f" {formatted}") -def print_match_status(is_match: bool, differences: List[str]): +def print_match_status(is_match: bool, differences: list[str]): """Print match status with tree-like hierarchy""" print_tree_differences(differences) @@ -557,7 +554,7 @@ def compare_trace_files( file2: str, tolerance: float = 1e-6, show_matches: bool = False, - max_decode_tokens: Optional[int] = None, + max_decode_tokens: int | None = None, ) -> bool: """ Compare two JSONL trace files by content_hash with tree-structured output diff --git a/python/sgl_jax/utils.py b/python/sgl_jax/utils.py index 4cc9defb4..e8e9af438 100644 --- a/python/sgl_jax/utils.py +++ b/python/sgl_jax/utils.py @@ -1,13 +1,13 @@ import logging import traceback -from typing import Any, Callable, List, Tuple, Type - +from collections.abc import Callable +from typing import Any logger = logging.getLogger(__name__) class TypeBasedDispatcher: - def __init__(self, mapping: List[Tuple[Type, Callable]]): + def __init__(self, mapping: list[tuple[type, Callable]]): self._mapping = mapping def __call__(self, obj: Any): diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py index 8f7a3ed55..0bf57fedc 100644 --- a/test/srt/openai_server/basic/test_serving_chat.py +++ b/test/srt/openai_server/basic/test_serving_chat.py @@ -97,9 +97,12 @@ def setUp(self): # ------------- conversion tests ------------- def test_convert_to_internal_request_single(self): - with patch( - "sgl_jax.srt.entrypoints.openai.serving_chat.generate_chat_conv" - ) as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock: + with ( + patch( + "sgl_jax.srt.entrypoints.openai.serving_chat.generate_chat_conv" + ) as conv_mock, + patch.object(self.chat, "_process_messages") as proc_mock, + ): conv_ins = Mock() conv_ins.get_prompt.return_value = "Test prompt" conv_ins.image_data = conv_ins.audio_data = None diff --git a/test/srt/openai_server/validation/test_openai_server_params_validation.py b/test/srt/openai_server/validation/test_openai_server_params_validation.py index fc8c31236..f420e0608 100644 --- a/test/srt/openai_server/validation/test_openai_server_params_validation.py +++ b/test/srt/openai_server/validation/test_openai_server_params_validation.py @@ -127,7 +127,6 @@ def test_invalid_top_p_parameter(self): self.assertTrue(cm.exception.code, 400) def 
test_input_length_longer_than_context_length(self): - client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") # Will tokenize to more than context length long_text = "hello" * 1200 diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py index 1dca9bf89..b6c47bb3b 100644 --- a/test/srt/test_bench_one_batch.py +++ b/test/srt/test_bench_one_batch.py @@ -13,7 +13,6 @@ class TestBenchOneBatch(CustomTestCase): - def test_bs1_default(self): output_throughput = run_bench_offline_throughput( DEFAULT_MODEL_NAME_FOR_TEST, diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index 9a2d309e0..736dc6648 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -85,7 +85,7 @@ def test_mmlu(self): metrics = run_eval(args) if is_in_ci(): - write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n') + write_github_step_summary(f'### test_mmlu\n{metrics["score"]=:.4f}\n') print("mmlu metrics", metrics) self.assertGreater(metrics["score"], 0.43) @@ -120,9 +120,7 @@ def test_mgsm_en(self): metrics = run_eval(args) if is_in_ci(): - write_github_step_summary( - f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n' - ) + write_github_step_summary(f'### test_mgsm_en\n{metrics["score"]=:.4f}\n') print("mgsm en metrics", metrics) self.assertGreater(metrics["score"], 0.4) diff --git a/test/srt/test_features.py b/test/srt/test_features.py index 7144f8bf1..9e4bccdff 100644 --- a/test/srt/test_features.py +++ b/test/srt/test_features.py @@ -650,7 +650,6 @@ def test_combined_penalties(self): class TestNoOverlapSchedule(CustomTestCase): - @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST From de8573f303dad7233ad6d56453afc7fcb76f466c Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 17 Oct 2025 04:14:19 +0000 Subject: [PATCH 07/18] Fix --- .../sgl_jax/srt/layers/attention/native_backend.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/sgl_jax/srt/layers/attention/native_backend.py b/python/sgl_jax/srt/layers/attention/native_backend.py index 3193c2c64..bb7c5341a 100644 --- a/python/sgl_jax/srt/layers/attention/native_backend.py +++ b/python/sgl_jax/srt/layers/attention/native_backend.py @@ -5,7 +5,7 @@ from sgl_jax.srt.layers.attention.base_attn_backend import AttentionBackend from sgl_jax.srt.layers.radix_attention import AttentionType, RadixAttention from sgl_jax.srt.managers.schedule_batch import ModelWorkerBatch -from sgl_jax.srt.mem_cache.memory_pool import KVCache, merge_kv +from sgl_jax.srt.mem_cache.memory_pool import KVCache from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sgl_jax.srt.utils.jax_utils import is_tpu_runtime @@ -33,9 +33,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): - return cls( - num_attn_heads=aux_data["num_heads"], num_kv_heads=aux_data["num_kv_heads"] - ) + return cls(num_attn_heads=aux_data["num_heads"], num_kv_heads=aux_data["num_kv_heads"]) def get_forward_metadata(self, batch: ModelWorkerBatch): """Init the metadata for a forward pass and return it.""" @@ -62,9 +60,7 @@ def __call__( k, v, forward_batch, token_to_kv_pool, layer.layer_id ) - scale = ( - 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling - ) + scale = 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling is_causal = not ( forward_batch.forward_mode == ForwardMode.DECODE @@ -181,9 +177,7 @@ def 
forward_attention( else: # Already in multi-head format: [num_tokens, num_heads, head_dim] num_tokens, num_heads_input, head_dim = q.shape - assert ( - num_heads_input == num_heads - ), f"Expected {num_heads} heads, got {num_heads_input}" + assert num_heads_input == num_heads, f"Expected {num_heads} heads, got {num_heads_input}" hidden_size = num_heads * head_dim # Calculate hidden_size for proper reshaping q_heads = q From 93665bcdb57e8150c52ec7c3deb84b916ffbae94 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 17 Oct 2025 04:18:24 +0000 Subject: [PATCH 08/18] Fix --- python/sgl_jax/srt/utils/weight_utils.py | 97 ++++++------------------ 1 file changed, 24 insertions(+), 73 deletions(-) diff --git a/python/sgl_jax/srt/utils/weight_utils.py b/python/sgl_jax/srt/utils/weight_utils.py index 2403ec93b..2a5923198 100644 --- a/python/sgl_jax/srt/utils/weight_utils.py +++ b/python/sgl_jax/srt/utils/weight_utils.py @@ -31,11 +31,7 @@ def __post_init__(self): self.sharding = self._infer_default_sharding() def _infer_default_sharding(self) -> tuple: - path = ( - self.target_path[0] - if isinstance(self.target_path, list) - else self.target_path - ) + path = self.target_path[0] if isinstance(self.target_path, list) else self.target_path if any(pattern in path for pattern in ["embedding", "lm_head"]): return (None, None) @@ -154,7 +150,7 @@ def _iterate_weights(self): safe_open(st_file, framework="flax") as f, ): needed_keys = [] - for name in f: + for name in f.keys(): # noqa: SIM118 if not name.startswith("model.layers."): needed_keys.append(name) continue @@ -215,9 +211,7 @@ def _handle_single_weight( processed_weight = jnp.reshape(processed_weight, mapping.reshape) if mapping.head_dim_padding and self.head_dim_pad > 0: - processed_weight = self._apply_head_dim_padding( - processed_weight, hf_key, mapping - ) + processed_weight = self._apply_head_dim_padding(processed_weight, hf_key, mapping) if mapping.kv_head_padding: processed_weight = self._apply_kv_head_padding(processed_weight, hf_key) @@ -264,15 +258,11 @@ def _split_qkv_weight( q_bias = jnp.pad(q_bias, ((0, 0), (0, self.head_dim_pad))) q_bias = jnp.reshape(q_bias, (self.num_heads * self.head_dim,)) - k_bias = jnp.reshape( - k_bias, (self.num_kv_heads, self.head_dim_original) - ) + k_bias = jnp.reshape(k_bias, (self.num_kv_heads, self.head_dim_original)) k_bias = jnp.pad(k_bias, ((0, 0), (0, self.head_dim_pad))) k_bias = jnp.reshape(k_bias, (self.num_kv_heads * self.head_dim,)) - v_bias = jnp.reshape( - v_bias, (self.num_kv_heads, self.head_dim_original) - ) + v_bias = jnp.reshape(v_bias, (self.num_kv_heads, self.head_dim_original)) v_bias = jnp.pad(v_bias, ((0, 0), (0, self.head_dim_pad))) v_bias = jnp.reshape(v_bias, (self.num_kv_heads * self.head_dim,)) @@ -296,9 +286,7 @@ def _split_qkv_weight( q_weight, (self.hidden_size, self.num_heads, self.head_dim_original), ) - q_weight = jnp.pad( - q_weight, ((0, 0), (0, 0), (0, self.head_dim_pad)) - ) + q_weight = jnp.pad(q_weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) q_weight = jnp.reshape( q_weight, (self.hidden_size, self.num_heads * self.head_dim) ) @@ -307,9 +295,7 @@ def _split_qkv_weight( k_weight, (self.hidden_size, self.num_kv_heads, self.head_dim_original), ) - k_weight = jnp.pad( - k_weight, ((0, 0), (0, 0), (0, self.head_dim_pad)) - ) + k_weight = jnp.pad(k_weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) k_weight = jnp.reshape( k_weight, (self.hidden_size, self.num_kv_heads * self.head_dim) ) @@ -318,9 +304,7 @@ def _split_qkv_weight( v_weight, (self.hidden_size, 
self.num_kv_heads, self.head_dim_original), ) - v_weight = jnp.pad( - v_weight, ((0, 0), (0, 0), (0, self.head_dim_pad)) - ) + v_weight = jnp.pad(v_weight, ((0, 0), (0, 0), (0, self.head_dim_pad))) v_weight = jnp.reshape( v_weight, (self.hidden_size, self.num_kv_heads * self.head_dim) ) @@ -329,9 +313,7 @@ def _split_qkv_weight( q_weight, (self.num_heads, self.head_dim_original, self.hidden_size), ) - q_weight = jnp.pad( - q_weight, ((0, 0), (0, self.head_dim_pad), (0, 0)) - ) + q_weight = jnp.pad(q_weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) q_weight = jnp.reshape( q_weight, (self.num_heads * self.head_dim, self.hidden_size) ) @@ -340,9 +322,7 @@ def _split_qkv_weight( k_weight, (self.num_kv_heads, self.head_dim_original, self.hidden_size), ) - k_weight = jnp.pad( - k_weight, ((0, 0), (0, self.head_dim_pad), (0, 0)) - ) + k_weight = jnp.pad(k_weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) k_weight = jnp.reshape( k_weight, (self.num_kv_heads * self.head_dim, self.hidden_size) ) @@ -351,9 +331,7 @@ def _split_qkv_weight( v_weight, (self.num_kv_heads, self.head_dim_original, self.hidden_size), ) - v_weight = jnp.pad( - v_weight, ((0, 0), (0, self.head_dim_pad), (0, 0)) - ) + v_weight = jnp.pad(v_weight, ((0, 0), (0, self.head_dim_pad), (0, 0))) v_weight = jnp.reshape( v_weight, (self.num_kv_heads * self.head_dim, self.hidden_size) ) @@ -363,18 +341,14 @@ def _split_qkv_weight( for split_weight, jax_path in zip(splits, jax_paths): processed_weight = split_weight - if mapping.kv_head_padding and ( - "k_proj" in jax_path or "v_proj" in jax_path - ): + if mapping.kv_head_padding and ("k_proj" in jax_path or "v_proj" in jax_path): processed_weight = self._apply_kv_head_padding(processed_weight, hf_key) sharded_weight = self._shard_weight(processed_weight, mapping.sharding) model_param = self._get_param(params, jax_path) model_param.value = sharded_weight - logger.debug( - "Split %s -> %s, shape: %s", hf_key, jax_path, processed_weight.shape - ) + logger.debug("Split %s -> %s, shape: %s", hf_key, jax_path, processed_weight.shape) def _shard_weight(self, weight: jax.Array, sharding: tuple) -> jax.Array: if math.prod(self.mesh.axis_sizes) == 1: @@ -404,15 +378,11 @@ def _apply_head_dim_padding( if hf_key.endswith(".bias"): if any(proj in hf_key for proj in ["q_proj", "k_proj", "v_proj"]): if "q_proj" in hf_key: - reshaped = jnp.reshape( - weight, (self.num_heads, self.head_dim_original) - ) + reshaped = jnp.reshape(weight, (self.num_heads, self.head_dim_original)) padded = jnp.pad(reshaped, ((0, 0), (0, self.head_dim_pad))) return jnp.reshape(padded, (self.num_heads * self.head_dim,)) else: # k_proj or v_proj - reshaped = jnp.reshape( - weight, (self.num_kv_heads, self.head_dim_original) - ) + reshaped = jnp.reshape(weight, (self.num_kv_heads, self.head_dim_original)) padded = jnp.pad(reshaped, ((0, 0), (0, self.head_dim_pad))) return jnp.reshape(padded, (self.num_kv_heads * self.head_dim,)) else: @@ -429,9 +399,7 @@ def _apply_head_dim_padding( weight, (self.hidden_size, self.num_heads, self.head_dim_original), ) - padded = jnp.pad( - reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad)) - ) + padded = jnp.pad(reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad))) return jnp.reshape( padded, (self.hidden_size, self.num_heads * self.head_dim) ) @@ -444,9 +412,7 @@ def _apply_head_dim_padding( self.head_dim_original, ), ) - padded = jnp.pad( - reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad)) - ) + padded = jnp.pad(reshaped, ((0, 0), (0, 0), (0, self.head_dim_pad))) return jnp.reshape( padded, 
(self.hidden_size, self.num_kv_heads * self.head_dim), @@ -460,9 +426,7 @@ def _apply_head_dim_padding( reshaped, (self.num_heads, self.head_dim_original, self.hidden_size), ) - padded = jnp.pad( - padded_reshaped, ((0, 0), (0, self.head_dim_pad), (0, 0)) - ) + padded = jnp.pad(padded_reshaped, ((0, 0), (0, self.head_dim_pad), (0, 0))) return jnp.reshape( padded, (self.num_heads * self.head_dim, self.hidden_size) ) @@ -476,9 +440,7 @@ def _apply_kv_head_padding(self, weight: jax.Array, hf_key: str) -> jax.Array: proj in hf_key for proj in ["k_proj", "v_proj"] ) and self.model_config.needs_kv_head_replication(self.sharding_size): total_kv_heads = self.model_config.get_total_num_kv_heads() - num_replicas = self.model_config.get_num_kv_head_replicas( - self.sharding_size - ) + num_replicas = self.model_config.get_num_kv_head_replicas(self.sharding_size) padding_strategy = self.model_config.get_kv_padding_strategy() if padding_strategy == "replicate": @@ -531,9 +493,7 @@ def _apply_kv_head_padding(self, weight: jax.Array, hf_key: str) -> jax.Array: padding_size = target_size - current_size if padding_size > 0: - padding = jnp.zeros( - (weight.shape[0], padding_size), dtype=weight.dtype - ) + padding = jnp.zeros((weight.shape[0], padding_size), dtype=weight.dtype) weight = jnp.concatenate([weight, padding], axis=1) return weight @@ -566,17 +526,12 @@ def _process_moe_expert_weights( moe_mappings: dict[str, WeightMapping], expert_weights: dict[str, jax.Array], ): - with tqdm( - moe_mappings.items(), desc="[STACKING] MOE EXPERTS", unit="layer" - ) as pbar: + with tqdm(moe_mappings.items(), desc="[STACKING] MOE EXPERTS", unit="layer") as pbar: for moe_key, mapping in pbar: layer_name = moe_key.replace("__MOE_EXPERTS__", "") pbar.set_postfix({"layer": layer_name}) - if ( - not isinstance(mapping.target_path, list) - or len(mapping.target_path) < 2 - ): + if not isinstance(mapping.target_path, list) or len(mapping.target_path) < 2: logger.warning("Invalid MoE mapping for %s", moe_key) continue @@ -598,12 +553,8 @@ def _process_moe_expert_weights( device_experts = stacked_weight - sharded_weight = self._shard_weight( - device_experts, mapping.sharding - ) + sharded_weight = self._shard_weight(device_experts, mapping.sharding) model_param = self._get_param(params, target_path) model_param.value = sharded_weight else: - logger.error( - "Could not collect all expert weights for %s", target_path - ) + logger.error("Could not collect all expert weights for %s", target_path) From e322063bc49444085160e172a5982a0c39002f40 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 17 Oct 2025 04:27:10 +0000 Subject: [PATCH 09/18] Format --- .pre-commit-config.yaml | 4 + .../flash_attention/bench_flashattention.py | 5 +- .../flash_attention/get_block_spec_config.py | 9 +- benchmark/kernels/flash_attention/utils.py | 20 +-- .../megablox_gmm/bench_megablox_gmm.py | 12 +- .../update_kv_cache/bench_update_kv_cache.py | 12 +- benchmark/kernels/update_kv_cache/utils.py | 11 +- python/sgl_jax/bench_offline_throughput.py | 57 ++------ python/sgl_jax/bench_one_batch.py | 71 +++------ python/sgl_jax/bench_one_batch_server.py | 28 +--- python/sgl_jax/bench_serving.py | 121 ++++----------- python/sgl_jax/profiler.py | 8 +- python/sgl_jax/srt/configs/model_config.py | 62 ++------ python/sgl_jax/srt/conversation.py | 4 +- python/sgl_jax/srt/entrypoints/engine.py | 36 ++--- python/sgl_jax/srt/entrypoints/http_server.py | 104 ++++--------- .../srt/entrypoints/openai/protocol.py | 32 +--- .../srt/entrypoints/openai/serving_base.py | 4 
+- .../srt/entrypoints/openai/serving_chat.py | 57 ++------ .../entrypoints/openai/serving_completions.py | 42 ++---- .../srt/entrypoints/openai/usage_processor.py | 11 +- .../sgl_jax/srt/entrypoints/openai/utils.py | 4 +- python/sgl_jax/srt/hf_transformers_utils.py | 8 +- python/sgl_jax/srt/jinja_template_utils.py | 4 +- .../flash_attn_kernel/flash_attention.py | 69 +++------ .../attention/flashattention_backend.py | 24 +-- python/sgl_jax/srt/layers/embeddings.py | 15 +- .../layers/gmm/megablox_gmm_kernel/common.py | 6 +- .../srt/layers/gmm/megablox_gmm_kernel/gmm.py | 38 ++--- python/sgl_jax/srt/layers/logits_processor.py | 52 ++----- python/sgl_jax/srt/layers/moe.py | 70 +++------ python/sgl_jax/srt/layers/sampler.py | 8 +- .../srt/managers/detokenizer_manager.py | 13 +- python/sgl_jax/srt/managers/io_struct.py | 32 +--- python/sgl_jax/srt/managers/schedule_batch.py | 130 ++++------------- .../sgl_jax/srt/managers/schedule_policy.py | 52 ++----- python/sgl_jax/srt/managers/scheduler.py | 83 +++-------- .../scheduler_output_processor_mixin.py | 121 +++++---------- .../srt/managers/scheduler_profiler_mixing.py | 18 +-- .../sgl_jax/srt/managers/tokenizer_manager.py | 138 +++++------------- python/sgl_jax/srt/managers/tp_worker.py | 110 ++++---------- .../srt/managers/tp_worker_overlap_thread.py | 24 +-- python/sgl_jax/srt/managers/utils.py | 4 +- python/sgl_jax/srt/mem_cache/allocator.py | 26 ++-- python/sgl_jax/srt/mem_cache/memory_pool.py | 68 ++------- python/sgl_jax/srt/mem_cache/radix_cache.py | 20 +-- python/sgl_jax/srt/memory_profiler.py | 18 +-- .../srt/model_executor/model_runner.py | 73 +++------ python/sgl_jax/srt/model_loader/arch.py | 4 +- python/sgl_jax/srt/model_loader/loader.py | 20 +-- python/sgl_jax/srt/models/llama.py | 12 +- python/sgl_jax/srt/models/qwen.py | 4 +- python/sgl_jax/srt/models/qwen2.py | 4 +- python/sgl_jax/srt/models/qwen3.py | 4 +- python/sgl_jax/srt/models/qwen3_moe.py | 15 +- python/sgl_jax/srt/models/registry.py | 8 +- python/sgl_jax/srt/precision_tracer.py | 65 ++------- python/sgl_jax/srt/reasoning_parser.py | 4 +- .../sampling/penaltylib/frequency_penalty.py | 5 +- .../srt/sampling/penaltylib/min_new_tokens.py | 24 +-- .../srt/sampling/penaltylib/orchestrator.py | 9 +- .../sampling/penaltylib/presence_penalty.py | 9 +- .../srt/sampling/sampling_batch_info.py | 56 ++----- .../sgl_jax/srt/sampling/sampling_params.py | 25 +--- python/sgl_jax/srt/server_args.py | 18 +-- python/sgl_jax/srt/utils/common_utils.py | 40 ++--- python/sgl_jax/srt/utils/jax_utils.py | 8 +- python/sgl_jax/srt/utils/mesh_utils.py | 4 +- .../sgl_jax/test/mem_cache/test_kv_cache.py | 137 +++++------------ .../test/mem_cache/test_radix_cache.py | 24 +-- .../test/model_executor/test_model_runner.py | 40 ++--- python/sgl_jax/test/models/test_qwen_model.py | 64 +++----- python/sgl_jax/test/run_curl.py | 4 +- python/sgl_jax/test/run_eval.py | 20 +-- python/sgl_jax/test/run_jax_loader_test.py | 16 +- python/sgl_jax/test/run_qwen3_moe_test.py | 40 ++--- python/sgl_jax/test/run_qwen_test.py | 38 ++--- python/sgl_jax/test/simple_eval_common.py | 8 +- python/sgl_jax/test/simple_eval_gpqa.py | 4 +- python/sgl_jax/test/simple_eval_humaneval.py | 11 +- python/sgl_jax/test/simple_eval_math.py | 4 +- python/sgl_jax/test/simple_eval_mgsm.py | 8 +- python/sgl_jax/test/simple_eval_mmlu.py | 8 +- python/sgl_jax/test/test_flashattention.py | 52 ++----- python/sgl_jax/test/test_jax_model_loader.py | 38 ++--- python/sgl_jax/test/test_model_loader.py | 101 ++++--------- 
.../test/test_multi_process_model_loader.py | 4 +- .../test/test_multi_process_radix_cache.py | 51 ++----- python/sgl_jax/test/test_utils.py | 16 +- python/sgl_jax/tools/trace_diff.py | 20 +-- .../openai_server/basic/test_openai_server.py | 48 ++---- test/srt/openai_server/basic/test_protocol.py | 8 +- .../openai_server/basic/test_serving_chat.py | 20 +-- .../basic/test_serving_completions.py | 8 +- .../test_openai_server_params_validation.py | 4 +- test/srt/run_suite.py | 8 +- test/srt/test_features.py | 16 +- test/srt/test_srt_engine.py | 3 +- 98 files changed, 818 insertions(+), 2281 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26c4173fb..a3eff1439 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.13.3 hooks: + # Ruff is lint-only; formatting is handled by Black. + # Do not add ruff-format to avoid conflicts with Black. - id: ruff-check args: [--output-format, github, --fix] files: ^(python/|benchmark/|docs/|examples/) @@ -35,6 +37,8 @@ repos: rev: 24.10.0 hooks: - id: black-jupyter + args: ["--config", "python/pyproject.toml"] + # Black is the only formatter; keep Ruff formatting disabled. # exclude: > # (?x)^( # python/sgl_jax/srt/entrypoints/openai/serving_rerank\.py| diff --git a/benchmark/kernels/flash_attention/bench_flashattention.py b/benchmark/kernels/flash_attention/bench_flashattention.py index 45d28e098..c15777a1f 100644 --- a/benchmark/kernels/flash_attention/bench_flashattention.py +++ b/benchmark/kernels/flash_attention/bench_flashattention.py @@ -181,10 +181,7 @@ def main(): ) for max_num_batched_tokens in max_num_batched_tokens_config: - if ( - q_head_num < kv_head_num - or q_head_num % kv_head_num != 0 - ): + if q_head_num < kv_head_num or q_head_num % kv_head_num != 0: continue all_combinations.append( ( diff --git a/benchmark/kernels/flash_attention/get_block_spec_config.py b/benchmark/kernels/flash_attention/get_block_spec_config.py index ce7fa546f..18b6df3b7 100644 --- a/benchmark/kernels/flash_attention/get_block_spec_config.py +++ b/benchmark/kernels/flash_attention/get_block_spec_config.py @@ -177,10 +177,7 @@ def main(): for page_size in page_size_config: for max_kv_cache_tokens in max_kv_cache_tokens_config: for max_num_batched_tokens in max_num_batched_tokens_config: - if ( - q_head_num < kv_head_num - or q_head_num % kv_head_num != 0 - ): + if q_head_num < kv_head_num or q_head_num % kv_head_num != 0: continue all_combinations.append( ( @@ -215,9 +212,7 @@ def main(): ) in enumerate(all_combinations): best_output = inf best_config = None - for i, (num_kv_pages_per_blk, num_queries_per_block) in enumerate( - block_spec_configs - ): + for i, (num_kv_pages_per_blk, num_queries_per_block) in enumerate(block_spec_configs): try: ( flash_time, diff --git a/benchmark/kernels/flash_attention/utils.py b/benchmark/kernels/flash_attention/utils.py index 062715769..0e87f44e6 100644 --- a/benchmark/kernels/flash_attention/utils.py +++ b/benchmark/kernels/flash_attention/utils.py @@ -16,9 +16,7 @@ def create_kv_cache_data( return kv_cache -def create_qkv_data( - total_tokens, q_head_num, kv_head_num, head_dim, dtype=jnp.bfloat16, seed=42 -): +def create_qkv_data(total_tokens, q_head_num, kv_head_num, head_dim, dtype=jnp.bfloat16, seed=42): key = jax.random.PRNGKey(seed) keys = jax.random.split(key, 3) q = jax.random.normal(keys[0], (total_tokens, q_head_num, head_dim), dtype=dtype) @@ -27,14 +25,10 @@ def create_qkv_data( return q, k, 
v -def create_page_indices_data( - num_seqs, total_kv_tokens, seq_lens, max_context_len, page_size=128 -): +def create_page_indices_data(num_seqs, total_kv_tokens, seq_lens, max_context_len, page_size=128): cache_loc = jnp.arange(0, total_kv_tokens, dtype=jnp.int32) - cache_start_idx = jnp.concatenate( - [jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_lens)] - ) + cache_start_idx = jnp.concatenate([jnp.array([0], dtype=jnp.int32), jnp.cumsum(seq_lens)]) cache_loc_list = [] for i in range(num_seqs): @@ -130,9 +124,7 @@ def create_decode_uniform_data( ): batch_size = max_num_batched_tokens # hackly set prefix len to 2048-4096 for decode one seq in random - random_prefix_lens = jax.random.randint( - jax.random.PRNGKey(42), (batch_size,), 1024, 2048 - ) + random_prefix_lens = jax.random.randint(jax.random.PRNGKey(42), (batch_size,), 1024, 2048) seq_lens = random_prefix_lens + 1 cu_q_lens = jnp.concatenate( [ @@ -146,9 +138,7 @@ def create_decode_uniform_data( jnp.cumsum(seq_lens), ] ) - q, k, v = create_qkv_data( - batch_size, q_head_num, kv_head_num, head_dim, dtype, seed - ) + q, k, v = create_qkv_data(batch_size, q_head_num, kv_head_num, head_dim, dtype, seed) kv_cache = create_kv_cache_data( max_kv_cache_tokens, kv_head_num, diff --git a/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py b/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py index e2e04e665..0a11e7b04 100644 --- a/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py +++ b/benchmark/kernels/megablox_gmm/bench_megablox_gmm.py @@ -186,15 +186,9 @@ def main(): worst_config = max(results, key=lambda x: x["megablox_ms"]) print("-" * 80) - print( - f"Best performance: {best_config['config']} - {best_config['megablox_ms']:.2f} ms" - ) - print( - f"Worst performance: {worst_config['config']} - {worst_config['megablox_ms']:.2f} ms" - ) - print( - f"Speedup ratio: {worst_config['megablox_ms'] / best_config['megablox_ms']:.2f}x" - ) + print(f"Best performance: {best_config['config']} - {best_config['megablox_ms']:.2f} ms") + print(f"Worst performance: {worst_config['config']} - {worst_config['megablox_ms']:.2f} ms") + print(f"Speedup ratio: {worst_config['megablox_ms'] / best_config['megablox_ms']:.2f}x") if __name__ == "__main__": diff --git a/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py b/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py index e26651e49..eb9344063 100644 --- a/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py +++ b/benchmark/kernels/update_kv_cache/bench_update_kv_cache.py @@ -131,11 +131,9 @@ def main(): head_num, head_dim, ) - max_num_slices_per_block_config = get_num_slices_per_block( - new_value, cache, page_size - ) - random_cache_loc, slice_lens, new_value_start_loc, update_slices_num = ( - create_input_params(max_cache_len, new_value_len, page_size=page_size) + max_num_slices_per_block_config = get_num_slices_per_block(new_value, cache, page_size) + random_cache_loc, slice_lens, new_value_start_loc, update_slices_num = create_input_params( + max_cache_len, new_value_len, page_size=page_size ) print( @@ -160,9 +158,7 @@ def main(): if cost < min_cost: min_cost = cost fastest_num_slices_per_block = num_slices_per_block - print( - f"[num_slices_per_block={num_slices_per_block}] avg cost: {cost * 1000} ms" - ) + print(f"[num_slices_per_block={num_slices_per_block}] avg cost: {cost * 1000} ms") print( f"Fastest [num_slices_per_block={fastest_num_slices_per_block}] costs: {min_cost * 1000} ms" diff --git a/benchmark/kernels/update_kv_cache/utils.py 
b/benchmark/kernels/update_kv_cache/utils.py index 8a663756e..6d1aa9b5b 100644 --- a/benchmark/kernels/update_kv_cache/utils.py +++ b/benchmark/kernels/update_kv_cache/utils.py @@ -13,12 +13,8 @@ def create_bench_data( ): key = jax.random.PRNGKey(42) keys = jax.random.split(key, 3) - new_value = jax.random.normal( - keys[1], (new_kv_len, kv_head_num, head_dim), dtype=dtype - ) - cache = jax.random.normal( - keys[2], (cache_max_tokens, kv_head_num, head_dim), dtype=dtype - ) + new_value = jax.random.normal(keys[1], (new_kv_len, kv_head_num, head_dim), dtype=dtype) + cache = jax.random.normal(keys[2], (cache_max_tokens, kv_head_num, head_dim), dtype=dtype) return new_value, cache @@ -37,8 +33,7 @@ def create_random_cache_start_loc(cache_max_tokens, new_kv_len, page_size=128): new_value_page_num = cdiv(new_kv_len, page_size) max_cache_page_num = cdiv(cache_max_tokens, page_size) cache_start_loc = ( - jax.random.randint(key, (new_value_page_num,), 0, max_cache_page_num - 1) - * page_size + jax.random.randint(key, (new_value_page_num,), 0, max_cache_page_num - 1) * page_size ) return cache_start_loc diff --git a/python/sgl_jax/bench_offline_throughput.py b/python/sgl_jax/bench_offline_throughput.py index 08ca400ec..1fc811c69 100644 --- a/python/sgl_jax/bench_offline_throughput.py +++ b/python/sgl_jax/bench_offline_throughput.py @@ -62,9 +62,7 @@ class BenchArgs: @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--backend", type=str, default=BenchArgs.backend) - parser.add_argument( - "--result-filename", type=str, default=BenchArgs.result_filename - ) + parser.add_argument("--result-filename", type=str, default=BenchArgs.result_filename) parser.add_argument( "--dataset-name", type=str, @@ -72,9 +70,7 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=["sharegpt", "random", "generated-shared-prefix"], help="Name of the dataset to benchmark on.", ) - parser.add_argument( - "--dataset-path", type=str, default="", help="Path to the dataset." - ) + parser.add_argument("--dataset-path", type=str, default="", help="Path to the dataset.") parser.add_argument( "--num-prompts", type=int, @@ -222,9 +218,7 @@ def throughput_test_once( ] if profile: - assert ( - "SGLANG_TORCH_PROFILER_DIR" in os.environ - ), "Please set SGLANG_TORCH_PROFILER_DIR." + assert "SGLANG_TORCH_PROFILER_DIR" in os.environ, "Please set SGLANG_TORCH_PROFILER_DIR." 
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True) backend.start_profile() @@ -247,18 +241,11 @@ def throughput_test_once( measurement_results["total_output_tokens"] = sum( o["meta_info"]["completion_tokens"] for o in gen_out ) - measurement_results["request_throughput"] = ( - measurement_results["successful_requests"] / latency - ) - measurement_results["input_throughput"] = ( - measurement_results["total_input_tokens"] / latency - ) - measurement_results["output_throughput"] = ( - measurement_results["total_output_tokens"] / latency - ) + measurement_results["request_throughput"] = measurement_results["successful_requests"] / latency + measurement_results["input_throughput"] = measurement_results["total_input_tokens"] / latency + measurement_results["output_throughput"] = measurement_results["total_output_tokens"] / latency measurement_results["total_throughput"] = ( - measurement_results["total_input_tokens"] - + measurement_results["total_output_tokens"] + measurement_results["total_input_tokens"] + measurement_results["total_output_tokens"] ) / latency if inspect.isawaitable(server_info): @@ -370,41 +357,23 @@ def throughput_test( with open(bench_args.result_filename, "a") as fout: fout.write(json.dumps(result) + "\n") - print( - "\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=") - ) + print("\n{s:{c}^{n}}".format(s=" Offline Throughput Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Backend:", result["backend"])) print("{:<40} {:<10}".format("Successful requests:", result["successful_requests"])) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", result["total_latency"])) print("{:<40} {:<10}".format("Total input tokens:", result["total_input_tokens"])) - print( - "{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"]) - ) + print("{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])) print( "{:<40} {:<10.2f}".format( "Last generation throughput (tok/s):", result["last_gen_throughput"] ) ) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", result["request_throughput"])) + print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", result["input_throughput"])) print( - "{:<40} {:<10.2f}".format( - "Request throughput (req/s):", result["request_throughput"] - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Input token throughput (tok/s):", result["input_throughput"] - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Output token throughput (tok/s):", result["output_throughput"] - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Total token throughput (tok/s):", result["total_throughput"] - ) + "{:<40} {:<10.2f}".format("Output token throughput (tok/s):", result["output_throughput"]) ) + print("{:<40} {:<10.2f}".format("Total token throughput (tok/s):", result["total_throughput"])) print("=" * 50) return result diff --git a/python/sgl_jax/bench_one_batch.py b/python/sgl_jax/bench_one_batch.py index 0260ccd19..f569eca2b 100644 --- a/python/sgl_jax/bench_one_batch.py +++ b/python/sgl_jax/bench_one_batch.py @@ -85,18 +85,10 @@ class BenchArgs: @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) - parser.add_argument( - "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size - ) - parser.add_argument( - "--input-len", type=int, nargs="+", default=BenchArgs.input_len - ) - parser.add_argument( - "--output-len", type=int, nargs="+", default=BenchArgs.output_len - ) 
- parser.add_argument( - "--result-filename", type=str, default=BenchArgs.result_filename - ) + parser.add_argument("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size) + parser.add_argument("--input-len", type=int, nargs="+", default=BenchArgs.input_len) + parser.add_argument("--output-len", type=int, nargs="+", default=BenchArgs.output_len) + parser.add_argument("--result-filename", type=str, default=BenchArgs.result_filename) parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) parser.add_argument( @@ -118,9 +110,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def from_cli_args(cls, args: argparse.Namespace): # use the default value's type to cast the args into correct types. attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] - return cls( - **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs} - ) + return cls(**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}) def load_model(server_args, port_args, tp_rank): @@ -156,9 +146,7 @@ def load_model(server_args, port_args, tp_rank): try: jax_mh.sync_global_devices("load_model") except Exception as err: - logging.info( - "Could not sync global devices (expected in single-host): %s", err - ) + logging.info("Could not sync global devices (expected in single-host): %s", err) return model_runner, tokenizer @@ -194,15 +182,11 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer): return input_ids, reqs -def prepare_extend_inputs_for_correctness_test( - bench_args, input_ids, reqs, model_runner -): +def prepare_extend_inputs_for_correctness_test(bench_args, input_ids, reqs, model_runner): for i in range(len(reqs)): req = reqs[i] req.fill_ids += input_ids[i][bench_args.cut_len :] - req.prefix_indices = model_runner.req_to_token_pool.req_to_token[ - i, : bench_args.cut_len - ] + req.prefix_indices = model_runner.req_to_token_pool.req_to_token[i, : bench_args.cut_len] req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.logprob_start_len = len(req.origin_input_ids) - 1 return reqs @@ -212,9 +196,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32) sampling_params = SamplingParams( temperature=0, - max_new_tokens=( - BenchArgs.output_len[0] if isinstance(BenchArgs.output_len, tuple) else 16 - ), + max_new_tokens=(BenchArgs.output_len[0] if isinstance(BenchArgs.output_len, tuple) else 16), ) reqs = [] @@ -252,9 +234,7 @@ def extend(reqs, model_runner): token_needed = int(np.sum(np.array(batch.extend_lens, dtype=np.int64))) else: token_needed = int(np.sum(np.array(batch.seq_lens, dtype=np.int64))) - next_token_ids, next_token_logits = _run_forward_and_sample( - model_runner, batch, token_needed - ) + next_token_ids, next_token_logits = _run_forward_and_sample(model_runner, batch, token_needed) return next_token_ids, next_token_logits, batch @@ -264,9 +244,7 @@ def decode(input_token_ids, batch, model_runner): _maybe_prepare_mlp_sync_batch(batch, model_runner) # For decode, the token dimension equals current batch size bs_needed = len(batch.seq_lens) - next_token_ids, next_token_logits = _run_forward_and_sample( - model_runner, batch, bs_needed - ) + next_token_ids, next_token_logits = _run_forward_and_sample(model_runner, batch, bs_needed) return next_token_ids, next_token_logits @@ -287,8 +265,7 @@ def _run_forward_and_sample(model_runner, batch: ScheduleBatch, 
token_first_arg: bs_needed = len(batch.seq_lens) cache_loc_needed = int( np.sum( - ((np.array(batch.seq_lens, dtype=np.int64) + page_size - 1) // page_size) - * page_size + ((np.array(batch.seq_lens, dtype=np.int64) + page_size - 1) // page_size) * page_size ) ) @@ -297,9 +274,7 @@ def _run_forward_and_sample(model_runner, batch: ScheduleBatch, token_first_arg: ) # Prepare attention forward metadata (required by FlashAttention backend) - forward_metadata = model_runner.attn_backend.get_forward_metadata( - model_worker_batch - ) + forward_metadata = model_runner.attn_backend.get_forward_metadata(model_worker_batch) model_runner.attn_backend.forward_metadata = forward_metadata forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner) @@ -307,9 +282,7 @@ def _run_forward_and_sample(model_runner, batch: ScheduleBatch, token_first_arg: model_worker_batch, mesh=model_runner.mesh ) - logits_output, _ = model_runner.forward( - forward_batch, logits_metadata=logits_metadata - ) + logits_output, _ = model_runner.forward(forward_batch, logits_metadata=logits_metadata) pad_size = len(model_worker_batch.seq_lens) - model_worker_batch.real_bs sampling_metadata = SamplingMetadata.from_model_worker_batch( @@ -343,9 +316,7 @@ def correctness_test( rank_print(f"prefill logits (first half): {next_token_logits} \n") # Prepare extend inputs - reqs = prepare_extend_inputs_for_correctness_test( - bench_args, input_ids, reqs, model_runner - ) + reqs = prepare_extend_inputs_for_correctness_test(bench_args, input_ids, reqs, model_runner) # Extend (prefill w/ KV cache) next_token_ids, next_token_logits, batch = extend(reqs, model_runner) @@ -404,7 +375,9 @@ def latency_test_run_once( tot_latency = 0 if profile: - profile_dir = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.tb" + profile_dir = ( + f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.tb" + ) os.makedirs(profile_dir, exist_ok=True) jax_profiler.start_trace(profile_dir) @@ -416,9 +389,7 @@ def latency_test_run_once( prefill_latency = time.perf_counter() - tic tot_latency += prefill_latency throughput = input_len * batch_size / prefill_latency - rank_print( - f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s" - ) + rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s") measurement_results["prefill_latency"] = prefill_latency measurement_results["prefill_throughput"] = throughput @@ -453,9 +424,7 @@ def latency_test_run_once( measurement_results["median_decode_throughput"] = med_decode_throughput throughput = (input_len + output_len) * batch_size / tot_latency - rank_print( - f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s" - ) + rank_print(f"Total. 
latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s") measurement_results["total_latency"] = tot_latency measurement_results["overall_throughput"] = throughput return measurement_results diff --git a/python/sgl_jax/bench_one_batch_server.py b/python/sgl_jax/bench_one_batch_server.py index a4a5da29f..8e26c0c2a 100644 --- a/python/sgl_jax/bench_one_batch_server.py +++ b/python/sgl_jax/bench_one_batch_server.py @@ -49,15 +49,9 @@ class BenchArgs: @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--run-name", type=str, default=BenchArgs.run_name) - parser.add_argument( - "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size - ) - parser.add_argument( - "--input-len", type=int, nargs="+", default=BenchArgs.input_len - ) - parser.add_argument( - "--output-len", type=int, nargs="+", default=BenchArgs.output_len - ) + parser.add_argument("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size) + parser.add_argument("--input-len", type=int, nargs="+", default=BenchArgs.input_len) + parser.add_argument("--output-len", type=int, nargs="+", default=BenchArgs.output_len) parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) parser.add_argument("--return-logprob", action="store_true") parser.add_argument( @@ -70,9 +64,7 @@ def add_cli_args(parser: argparse.ArgumentParser): type=float, default=BenchArgs.input_len_step_percentage, ) - parser.add_argument( - "--result-filename", type=str, default=BenchArgs.result_filename - ) + parser.add_argument("--result-filename", type=str, default=BenchArgs.result_filename) parser.add_argument("--base-url", type=str, default=BenchArgs.base_url) parser.add_argument("--skip-warmup", action="store_true") parser.add_argument("--show-report", action="store_true") @@ -83,9 +75,7 @@ def add_cli_args(parser: argparse.ArgumentParser): def from_cli_args(cls, args: argparse.Namespace): # use the default value's type to cast the args into correct types. attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)] - return cls( - **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs} - ) + return cls(**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}) def launch_server_internal(server_args): @@ -160,9 +150,7 @@ def run_one_case( profile_link = None if profile: - profile_link: str = run_profile( - url, 3, ["CPU", "GPU"], None, None, profile_by_stage - ) + profile_link: str = run_profile(url, 3, ["CPU", "GPU"], None, None, profile_by_stage) tic = time.perf_counter() response = requests.post( @@ -336,9 +324,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): if not bench_args.show_report: return - summary = ( - f"\nInput lens: {bench_args.input_len}. Output lens: {bench_args.output_len}.\n" - ) + summary = f"\nInput lens: {bench_args.input_len}. 
Output lens: {bench_args.output_len}.\n" summary += "| batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) |" if bench_args.profile: diff --git a/python/sgl_jax/bench_serving.py b/python/sgl_jax/bench_serving.py index 7e5cbea99..e73a87e5d 100644 --- a/python/sgl_jax/bench_serving.py +++ b/python/sgl_jax/bench_serving.py @@ -214,9 +214,7 @@ async def async_request_openai_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post( - url=api_url, json=payload, headers=headers - ) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -295,9 +293,7 @@ async def async_request_truss( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post( - url=api_url, json=payload, headers=headers - ) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -382,9 +378,7 @@ async def async_request_sglang_generate( most_recent_timestamp = st last_output_len = 0 try: - async with session.post( - url=api_url, json=payload, headers=headers - ) as response: + async with session.post(url=api_url, json=payload, headers=headers) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() @@ -487,13 +481,10 @@ def get_model(pretrained_model_name_or_path: str) -> str: def get_tokenizer( pretrained_model_name_or_path: str, ) -> PreTrainedTokenizer | PreTrainedTokenizerFast: - assert ( - pretrained_model_name_or_path is not None - and pretrained_model_name_or_path != "" - ) - if pretrained_model_name_or_path.endswith( - ".json" - ) or pretrained_model_name_or_path.endswith(".model"): + assert pretrained_model_name_or_path is not None and pretrained_model_name_or_path != "" + if pretrained_model_name_or_path.endswith(".json") or pretrained_model_name_or_path.endswith( + ".model" + ): from sglang.srt.hf_transformers_utils import get_tokenizer return get_tokenizer(pretrained_model_name_or_path) @@ -502,9 +493,7 @@ def get_tokenizer( pretrained_model_name_or_path ): pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) - return AutoTokenizer.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=True - ) + return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) def get_dataset(args, tokenizer): @@ -650,9 +639,7 @@ def is_file_valid_json(path): json.load(f) return True except JSONDecodeError as e: - print( - f"{path} exists but json loading fails ({e=}), thus treat as invalid file" - ) + print(f"{path} exists but json loading fails ({e=}), thus treat as invalid file") return False @@ -710,9 +697,7 @@ def sample_mmmu_requests( sample_dataset = mmmu_dataset.select(indices) else: # Take first N - sample_dataset = mmmu_dataset.select( - range(min(num_requests, len(mmmu_dataset))) - ) + sample_dataset = mmmu_dataset.select(range(min(num_requests, len(mmmu_dataset)))) else: print(f"Dataset has less than {num_requests} examples, using all examples") sample_dataset = mmmu_dataset @@ -838,11 +823,7 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. 
prompt = dataset[i][0] if prompt_suffix: - prompt = ( - remove_suffix(prompt, ASSISTANT_SUFFIX) - + prompt_suffix - + ASSISTANT_SUFFIX - ) + prompt = remove_suffix(prompt, ASSISTANT_SUFFIX) + prompt_suffix + ASSISTANT_SUFFIX if apply_chat_template: prompt = tokenizer.apply_chat_template( @@ -856,9 +837,7 @@ def sample_sharegpt_requests( completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) prompt_len = len(prompt_token_ids) - output_len = ( - len(completion_token_ids) if fixed_output_len is None else fixed_output_len - ) + output_len = len(completion_token_ids) if fixed_output_len is None else fixed_output_len if prompt_len < 2 or output_len < 2: # Prune too short sequences. @@ -962,8 +941,7 @@ def sample_random_requests( input_requests = [] for i in range(num_prompts): input_content = [ - (offsets[i] + i + j) % tokenizer.vocab_size - for j in range(input_lens[i]) + (offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i]) ] if return_text: input_content = tokenizer.decode(input_content) @@ -1039,17 +1017,13 @@ def sample_generated_shared_prefix_requests( for group_idx in tqdm(range(num_groups), desc="Generating system prompt"): system_prompt = system_prompts[group_idx] - for prompt_idx in tqdm( - range(prompts_per_group), desc="Generating questions", leave=False - ): + for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions", leave=False): question = questions[group_idx * prompts_per_group + prompt_idx] full_prompt = f"{system_prompt}\n\n{question}" prompt_len = len(tokenizer.encode(full_prompt)) input_requests.append( - DatasetRow( - prompt=full_prompt, prompt_len=prompt_len, output_len=output_len - ) + DatasetRow(prompt=full_prompt, prompt_len=prompt_len, output_len=output_len) ) total_input_tokens += prompt_len total_output_tokens += output_len @@ -1150,8 +1124,7 @@ def calculate_metrics( output_throughput=sum(output_lens) / dur_s, output_throughput_retokenized=sum(retokenized_output_lens) / dur_s, total_throughput=(total_input + sum(output_lens)) / dur_s, - total_throughput_retokenized=(total_input + sum(retokenized_output_lens)) - / dur_s, + total_throughput_retokenized=(total_input + sum(retokenized_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend median_ttft_ms=np.median(ttfts or 0) * 1000, @@ -1215,9 +1188,7 @@ async def limited_request_func(request_func_input, pbar): # Use the first request for all warmup iterations test_request = input_requests[0] - lora_name = ( - lora_names[0] if lora_names is not None and len(lora_names) != 0 else None - ) + lora_name = lora_names[0] if lora_names is not None and len(lora_names) != 0 else None # Create the test input once test_input = RequestFuncInput( @@ -1234,9 +1205,7 @@ async def limited_request_func(request_func_input, pbar): # Run warmup requests warmup_tasks = [] for _ in range(warmup_requests): - warmup_tasks.append( - asyncio.create_task(request_func(request_func_input=test_input)) - ) + warmup_tasks.append(asyncio.create_task(request_func(request_func_input=test_input))) warmup_outputs = await asyncio.gather(*warmup_tasks) @@ -1260,9 +1229,7 @@ async def limited_request_func(request_func_input, pbar): # Start profiler if profile: print("Starting profiler...") - profile_output = await async_request_profile( - api_url=base_url + "/start_profile" - ) + profile_output = await async_request_profile(api_url=base_url + "/start_profile") if profile_output.success: print("Profiler started") @@ -1348,38 
+1315,16 @@ async def limited_request_func(request_func_input, pbar): "Total generated tokens (retokenized):", metrics.total_output_retokenized ) ) - print( - "{:<40} {:<10.2f}".format( - "Request throughput (req/s):", metrics.request_throughput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Input token throughput (tok/s):", metrics.input_throughput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Output token throughput (tok/s):", metrics.output_throughput - ) - ) - print( - "{:<40} {:<10.2f}".format( - "Total token throughput (tok/s):", metrics.total_throughput - ) - ) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) + print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", metrics.input_throughput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total token throughput (tok/s):", metrics.total_throughput)) print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) if accept_length: print("{:<40} {:<10.2f}".format("Accept length:", accept_length)) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) - print( - "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) - ) - print( - "{:<40} {:<10.2f}".format( - "Median E2E Latency (ms):", metrics.median_e2e_latency_ms - ) - ) + print("{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)) + print("{:<40} {:<10.2f}".format("Median E2E Latency (ms):", metrics.median_e2e_latency_ms)) print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-")) print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms)) @@ -1565,9 +1510,7 @@ def run_benchmark(args_: argparse.Namespace): if args.base_url else f"http://{args.host}:{args.port}/v1/models/model:predict" ) - base_url = ( - f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url - ) + base_url = f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url # Get model name if args.model is None: @@ -1582,9 +1525,7 @@ def run_benchmark(args_: argparse.Namespace): args.model = model_list[0]["id"] if model_list else None except Exception as e: print(f"Failed to fetch model from {model_url}. Error: {e}") - print( - "Please specify the correct host and port using `--host` and `--port`." - ) + print("Please specify the correct host and port using `--host` and `--port`.") sys.exit(1) if args.model is None: @@ -1664,9 +1605,7 @@ def __call__(self, parser, namespace, values, option_string=None): default=None, help="Server or API base url if not using http host and port.", ) - parser.add_argument( - "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." - ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0.") parser.add_argument( "--port", type=int, @@ -1679,9 +1618,7 @@ def __call__(self, parser, namespace, values, option_string=None): choices=["sharegpt", "random", "random-ids", "generated-shared-prefix", "mmmu"], help="Name of the dataset to benchmark on.", ) - parser.add_argument( - "--dataset-path", type=str, default="", help="Path to the dataset." 
- ) + parser.add_argument("--dataset-path", type=str, default="", help="Path to the dataset.") parser.add_argument( "--model", type=str, diff --git a/python/sgl_jax/profiler.py b/python/sgl_jax/profiler.py index c7fd18dd0..a599684d1 100644 --- a/python/sgl_jax/profiler.py +++ b/python/sgl_jax/profiler.py @@ -39,9 +39,7 @@ def _run_profile( output_dir.mkdir(exist_ok=True, parents=True) print(f"Dump profiling traces to {output_dir}") - print( - f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})" - ) + print(f"Waiting for {num_steps} steps and the trace to be flushed.... ({profile_by_stage=})") # Dump server args. file_path = Path(output_dir) / "server_args.json" @@ -77,9 +75,7 @@ def run_profile( profile_by_stage: bool = False, ): # step based profile will self terminate on num_steps constraints - link = _run_profile( - url, num_steps, activities, output_dir, profile_name, profile_by_stage - ) + link = _run_profile(url, num_steps, activities, output_dir, profile_name, profile_by_stage) return link diff --git a/python/sgl_jax/srt/configs/model_config.py b/python/sgl_jax/srt/configs/model_config.py index f0d90dd6d..72fd35bed 100644 --- a/python/sgl_jax/srt/configs/model_config.py +++ b/python/sgl_jax/srt/configs/model_config.py @@ -73,22 +73,15 @@ def __init__( ) self.hf_text_config = get_hf_text_config(self.hf_config) - self.attention_chunk_size = getattr( - self.hf_text_config, "attention_chunk_size", None - ) + self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) - if ( - is_draft_model - and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM" - ): + if is_draft_model and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM": self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": self.hf_config.architectures[0] = "MiMoMTP" # Check model type - self.is_generation = is_generation_model( - self.hf_config.architectures, is_embedding - ) + self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding) self.is_multimodal = False self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -96,9 +89,7 @@ def __init__( derived_context_len = get_context_length(self.hf_text_config) if context_length is not None: if context_length > derived_context_len: - if get_bool_env_var( - "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True" - ): + if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"): logger.warning( "Warning: User-specified context_length (%s) is greater than the derived context_length (%s). 
This may lead to incorrect model outputs or CUDA errors.", context_length, @@ -123,15 +114,11 @@ def __init__( self.attention_arch = AttentionArch.MHA self.num_attention_heads = self.hf_text_config.num_attention_heads - self.num_key_value_heads = getattr( - self.hf_text_config, "num_key_value_heads", None - ) + self.num_key_value_heads = getattr(self.hf_text_config, "num_key_value_heads", None) # for Dbrx and MPT models if self.hf_config.model_type in ["dbrx", "mpt"]: - self.num_key_value_heads = getattr( - self.hf_config.attn_config, "kv_n_heads", None - ) + self.num_key_value_heads = getattr(self.hf_config.attn_config, "kv_n_heads", None) if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads @@ -142,9 +129,7 @@ def __init__( # Override num_hidden_layers if model_layer_nums is specified if model_layer_nums is not None: if model_layer_nums <= 0: - raise ValueError( - f"model_layer_nums must be positive, got {model_layer_nums}" - ) + raise ValueError(f"model_layer_nums must be positive, got {model_layer_nums}") if model_layer_nums > self.num_hidden_layers: logger.warning( "model_layer_nums (%s) is greater than the original num_hidden_layers (%s). Using original value.", @@ -193,13 +178,10 @@ def get_total_num_kv_heads(self) -> int: # multi_query flag is ignored and we use n_head_kv for the number of # KV heads. falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] - new_decoder_arch_falcon = ( - self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False) + new_decoder_arch_falcon = self.hf_config.model_type in falcon_model_types and getattr( + self.hf_config, "new_decoder_architecture", False ) - if not new_decoder_arch_falcon and getattr( - self.hf_text_config, "multi_query", False - ): + if not new_decoder_arch_falcon and getattr(self.hf_text_config, "multi_query", False): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. 
return 1 @@ -277,15 +259,11 @@ def configure_for_tensor_parallel(self, tensor_parallel_size: int): # Handle cases where HF config doesn't have num_key_value_heads (MHA models) if hasattr(self.hf_text_config, "num_key_value_heads"): if not hasattr(self, "_original_hf_num_key_value_heads"): - self._original_hf_num_key_value_heads = ( - self.hf_text_config.num_key_value_heads - ) + self._original_hf_num_key_value_heads = self.hf_text_config.num_key_value_heads else: # For MHA models without this attribute, it equals num_attention_heads if not hasattr(self, "_original_hf_num_key_value_heads"): - self._original_hf_num_key_value_heads = ( - self.hf_text_config.num_attention_heads - ) + self._original_hf_num_key_value_heads = self.hf_text_config.num_attention_heads # CRITICAL: Set to TOTAL count for global sharding # JAX tensor parallel will automatically shard this across devices @@ -304,9 +282,7 @@ def get_original_kv_head_id(self, tp_rank: int, tensor_parallel_size: int) -> in from sgl_jax.srt.utils.jax_utils import get_original_kv_head_id total_num_kv_heads = self.get_total_num_kv_heads() - return get_original_kv_head_id( - tp_rank, total_num_kv_heads, tensor_parallel_size - ) + return get_original_kv_head_id(tp_rank, total_num_kv_heads, tensor_parallel_size) def is_gqa_model(self) -> bool: """Returns True if this is a Grouped Query Attention model.""" @@ -379,9 +355,7 @@ def _parse_quant_hf_config(self): if hf_api.file_exists(self.model_path, "hf_quant_config.json"): quant_cfg = modelopt_quant_config elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")): - quant_config_file = os.path.join( - self.model_path, "hf_quant_config.json" - ) + quant_config_file = os.path.join(self.model_path, "hf_quant_config.json") with open(quant_config_file) as f: quant_config_dict = json.load(f) json_quant_configs = quant_config_dict["quantization"] @@ -400,9 +374,7 @@ def get_hf_eos_token_id(self) -> set[int] | None: if eos_ids is None: eos_ids = set() if self.hf_generation_config: - generation_eos_ids = getattr( - self.hf_generation_config, "eos_token_id", None - ) + generation_eos_ids = getattr(self.hf_generation_config, "eos_token_id", None) if generation_eos_ids: generation_eos_ids = ( {generation_eos_ids} @@ -447,9 +419,7 @@ def _get_and_verify_dtype( if isinstance(config_dtype, str): config_dtype = _STR_DTYPE_TO_JAX_DTYPE.get(config_dtype) elif config_dtype is not None: - config_dtype = _STR_DTYPE_TO_JAX_DTYPE.get( - str(config_dtype).replace("torch.", ""), None - ) + config_dtype = _STR_DTYPE_TO_JAX_DTYPE.get(str(config_dtype).replace("torch.", ""), None) if config_dtype is None: config_dtype = jnp.float32 diff --git a/python/sgl_jax/srt/conversation.py b/python/sgl_jax/srt/conversation.py index 15b7ba9c4..0b6e24e2b 100644 --- a/python/sgl_jax/srt/conversation.py +++ b/python/sgl_jax/srt/conversation.py @@ -71,9 +71,7 @@ def clear(self): def register_conv_template(template: Conversation, override: bool = False): """Register a new conversation template.""" if not override: - assert ( - template.name not in chat_templates - ), f"{template.name} has been registered." + assert template.name not in chat_templates, f"{template.name} has been registered." 
chat_templates[template.name] = template diff --git a/python/sgl_jax/srt/entrypoints/engine.py b/python/sgl_jax/srt/entrypoints/engine.py index f8de4c8d8..bee350f34 100644 --- a/python/sgl_jax/srt/entrypoints/engine.py +++ b/python/sgl_jax/srt/entrypoints/engine.py @@ -95,11 +95,9 @@ def __init__(self, **kwargs): logger.info("server_args=%s", server_args) # Launch subprocesses or threads - tokenizer_manager, template_manager, scheduler_info = ( - _launch_subprocesses_or_threads( - server_args=server_args, - port_args=self.port_args, - ) + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses_or_threads( + server_args=server_args, + port_args=self.port_args, ) self.server_args = server_args self.tokenizer_manager = tokenizer_manager @@ -107,9 +105,7 @@ def __init__(self, **kwargs): self.scheduler_info = scheduler_info self.default_sampling_params: dict[str, Any] | None = None context = zmq.Context(2) - self.send_to_rpc = get_zmq_socket( - context, zmq.DEALER, self.port_args.rpc_ipc_name, True - ) + self.send_to_rpc = get_zmq_socket(context, zmq.DEALER, self.port_args.rpc_ipc_name, True) def generate( self, @@ -271,9 +267,7 @@ def stop_profile(self): def get_server_info(self): loop = asyncio.get_event_loop() - internal_states = loop.run_until_complete( - self.tokenizer_manager.get_internal_state() - ) + internal_states = loop.run_until_complete(self.tokenizer_manager.get_internal_state()) return { **dataclasses.asdict(self.tokenizer_manager.server_args), **self.scheduler_info, @@ -284,16 +278,12 @@ def get_server_info(self): def release_memory_occupation(self, tags: list[str] | None = None): obj = ReleaseMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() - return loop.run_until_complete( - self.tokenizer_manager.release_memory_occupation(obj, None) - ) + return loop.run_until_complete(self.tokenizer_manager.release_memory_occupation(obj, None)) def resume_memory_occupation(self, tags: list[str] | None = None): obj = ResumeMemoryOccupationReqInput(tags=tags) loop = asyncio.get_event_loop() - return loop.run_until_complete( - self.tokenizer_manager.resume_memory_occupation(obj, None) - ) + return loop.run_until_complete(self.tokenizer_manager.resume_memory_occupation(obj, None)) def score( self, @@ -429,9 +419,7 @@ def sigchld_handler(signum, frame): # The child processes will send SIGQUIT to this process when any error happens # This process then clean up the whole process tree def sigquit_handler(signum, frame): - logger.error( - "Received sigquit from a child process. It usually means the child failed." - ) + logger.error("Received sigquit from a child process. It usually means the child failed.") kill_process_tree(os.getpid()) signal.signal(signal.SIGQUIT, sigquit_handler) @@ -540,9 +528,7 @@ def _launch_subprocesses( raise if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) + raise RuntimeError("Initialization failed. 
Please see the error messages above.") scheduler_infos.append(data) # Assume all schedulers have the same scheduler_info @@ -594,9 +580,7 @@ def _launch_threads( for thread in scheduler_threads: thread.join() - logger.error( - "Scheduler or DataParallelController %s terminated", thread.name - ) + logger.error("Scheduler or DataParallelController %s terminated", thread.name) return None, None, None # Launch detokenizer thread diff --git a/python/sgl_jax/srt/entrypoints/http_server.py b/python/sgl_jax/srt/entrypoints/http_server.py index d18a01fb5..ea3f58ea2 100644 --- a/python/sgl_jax/srt/entrypoints/http_server.py +++ b/python/sgl_jax/srt/entrypoints/http_server.py @@ -112,12 +112,8 @@ async def lifespan(fast_api_app: FastAPI): fast_api_app.state.openai_serving_embedding = OpenAIServingEmbedding( _global_state.tokenizer_manager, _global_state.template_manager ) - fast_api_app.state.openai_serving_score = OpenAIServingScore( - _global_state.tokenizer_manager - ) - fast_api_app.state.openai_serving_rerank = OpenAIServingRerank( - _global_state.tokenizer_manager - ) + fast_api_app.state.openai_serving_score = OpenAIServingScore(_global_state.tokenizer_manager) + fast_api_app.state.openai_serving_rerank = OpenAIServingRerank(_global_state.tokenizer_manager) server_args: ServerArgs = fast_api_app.server_args if server_args.warmups is not None: @@ -151,9 +147,7 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE exc_str = str(exc) errors_str = str(exc.errors()) - message = ( - f"{exc_str} {errors_str}" if errors_str and errors_str != exc_str else exc_str - ) + message = f"{exc_str} {errors_str}" if errors_str and errors_str != exc_str else exc_str err = ErrorResponse( message=message, @@ -288,22 +282,12 @@ async def generate_request(obj: GenerateReqInput, request: Request): async def stream_results() -> AsyncIterator[bytes]: try: - async for out in _global_state.tokenizer_manager.generate_request( - obj, request - ): - yield ( - b"data: " - + orjson.dumps(out, option=orjson.OPT_NON_STR_KEYS) - + b"\n\n" - ) + async for out in _global_state.tokenizer_manager.generate_request(obj, request): + yield (b"data: " + orjson.dumps(out, option=orjson.OPT_NON_STR_KEYS) + b"\n\n") except ValueError as e: out = {"error": {"message": str(e)}} logger.error("[http_server] Error: %s", e) - yield ( - b"data: " - + orjson.dumps(out, option=orjson.OPT_NON_STR_KEYS) - + b"\n\n" - ) + yield (b"data: " + orjson.dumps(out, option=orjson.OPT_NON_STR_KEYS) + b"\n\n") yield b"data: [DONE]\n\n" return StreamingResponse( @@ -313,9 +297,7 @@ async def stream_results() -> AsyncIterator[bytes]: ) else: try: - ret = await _global_state.tokenizer_manager.generate_request( - obj, request - ).__anext__() + ret = await _global_state.tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: logger.error("[http_server] Error: %s", e) @@ -337,9 +319,7 @@ async def generate_from_file_request(file: UploadFile, request: Request): ) try: - ret = await _global_state.tokenizer_manager.generate_request( - obj, request - ).__anext__() + ret = await _global_state.tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: logger.error("Error: %s", e) @@ -350,9 +330,7 @@ async def generate_from_file_request(file: UploadFile, request: Request): async def encode_request(obj: EmbeddingReqInput, request: Request): """Handle an embedding request.""" try: - ret = await _global_state.tokenizer_manager.generate_request( - obj, request - 
).__anext__() + ret = await _global_state.tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: return _create_error_response(e) @@ -362,9 +340,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): async def classify_request(obj: EmbeddingReqInput, request: Request): """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" try: - ret = await _global_state.tokenizer_manager.generate_request( - obj, request - ).__anext__() + ret = await _global_state.tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: return _create_error_response(e) @@ -431,9 +407,7 @@ async def start_trace_async(obj: StartTraceReqInput | None = None): else: timestamp = int(time.time()) unique_suffix = random.randint(1000, 9999) - output_file = ( - f"debug_outputs/request_traces_{timestamp}_{unique_suffix}.jsonl" - ) + output_file = f"debug_outputs/request_traces_{timestamp}_{unique_suffix}.jsonl" precision_tracer.start_trace(req_num=obj.req_num, output_file=output_file) logger.info("[HTTP] Sending trace state to scheduler...") @@ -551,9 +525,7 @@ async def trace_status_async(obj: TraceStatusReqInput | None = None): @app.api_route("/release_memory_occupation", methods=["GET", "POST"]) -async def release_memory_occupation( - obj: ReleaseMemoryOccupationReqInput, request: Request -): +async def release_memory_occupation(obj: ReleaseMemoryOccupationReqInput, request: Request): """Release GPU memory occupation temporarily.""" try: await _global_state.tokenizer_manager.release_memory_occupation(obj, request) @@ -562,9 +534,7 @@ async def release_memory_occupation( @app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) -async def resume_memory_occupation( - obj: ResumeMemoryOccupationReqInput, request: Request -): +async def resume_memory_occupation(obj: ResumeMemoryOccupationReqInput, request: Request): """Resume GPU memory occupation.""" try: await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) @@ -607,9 +577,7 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request): async def abort_request(obj: AbortReq, request: Request): """Abort a request.""" try: - _global_state.tokenizer_manager.abort_request( - rid=obj.rid, abort_all=obj.abort_all - ) + _global_state.tokenizer_manager.abort_request(rid=obj.rid, abort_all=obj.abort_all) return Response(status_code=200) except Exception as e: return _create_error_response(e) @@ -629,9 +597,7 @@ async def parse_function_call_request(obj: ParseFunctionCallReq, request: Reques # 3) Organize the response content response_data = { "normal_text": normal_text, - "calls": [ - call.model_dump() for call in calls - ], # Convert pydantic objects to dictionaries + "calls": [call.model_dump() for call in calls], # Convert pydantic objects to dictionaries } return ORJSONResponse(content=response_data, status_code=200) @@ -689,13 +655,9 @@ async def openai_v1_completions(request: CompletionRequest, raw_request: Request @app.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)]) -async def openai_v1_chat_completions( - request: ChatCompletionRequest, raw_request: Request -): +async def openai_v1_chat_completions(request: ChatCompletionRequest, raw_request: Request): """OpenAI-compatible chat completion endpoint.""" - return await raw_request.app.state.openai_serving_chat.handle_request( - request, raw_request - ) + return await 
raw_request.app.state.openai_serving_chat.handle_request(request, raw_request) @app.post( @@ -705,9 +667,7 @@ async def openai_v1_chat_completions( ) async def openai_v1_embeddings(request: EmbeddingRequest, raw_request: Request): """OpenAI-compatible embeddings endpoint.""" - return await raw_request.app.state.openai_serving_embedding.handle_request( - request, raw_request - ) + return await raw_request.app.state.openai_serving_embedding.handle_request(request, raw_request) @app.get("/v1/models", response_class=ORJSONResponse) @@ -759,37 +719,25 @@ async def sagemaker_health() -> Response: @app.post("/invocations") -async def sagemaker_chat_completions( - request: ChatCompletionRequest, raw_request: Request -): +async def sagemaker_chat_completions(request: ChatCompletionRequest, raw_request: Request): """OpenAI-compatible chat completion endpoint.""" - return await raw_request.app.state.openai_serving_chat.handle_request( - request, raw_request - ) + return await raw_request.app.state.openai_serving_chat.handle_request(request, raw_request) @app.post("/v1/score", dependencies=[Depends(validate_json_request)]) async def v1_score_request(request: ScoringRequest, raw_request: Request): """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation.""" - return await raw_request.app.state.openai_serving_score.handle_request( - request, raw_request - ) + return await raw_request.app.state.openai_serving_score.handle_request(request, raw_request) -@app.api_route( - "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)] -) +@app.api_route("/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]) async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request): """Endpoint for reranking documents based on query relevance.""" - return await raw_request.app.state.openai_serving_rerank.handle_request( - request, raw_request - ) + return await raw_request.app.state.openai_serving_rerank.handle_request(request, raw_request) def _create_error_response(e): - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) + return ORJSONResponse({"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST) def launch_server( @@ -815,8 +763,8 @@ def launch_server( # Initialize precision tracer enable state in HTTP server process precision_tracer.set_enable_precision_tracer(server_args.enable_precision_tracer) - tokenizer_manager, template_manager, scheduler_info = ( - _launch_subprocesses_or_threads(server_args=server_args, port_args=None) + tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses_or_threads( + server_args=server_args, port_args=None ) ## don't expose scheduler in server mode if "scheduler" in scheduler_info: diff --git a/python/sgl_jax/srt/entrypoints/openai/protocol.py b/python/sgl_jax/srt/entrypoints/openai/protocol.py index 10b594915..e8d4ca154 100644 --- a/python/sgl_jax/srt/entrypoints/openai/protocol.py +++ b/python/sgl_jax/srt/entrypoints/openai/protocol.py @@ -81,9 +81,7 @@ class JsonSchemaResponseFormat(BaseModel): class FileRequest(BaseModel): # https://platform.openai.com/docs/api-reference/files/create file: bytes # The File object (not file name) to be uploaded - purpose: str = ( - "batch" # The intended purpose of the uploaded file, default is "batch" - ) + purpose: str = "batch" # The intended purpose of the uploaded file, default is "batch" class FileResponse(BaseModel): @@ -102,9 +100,7 @@ class 
FileDeleteResponse(BaseModel): class BatchRequest(BaseModel): - input_file_id: ( - str # The ID of an uploaded file that contains requests for the new batch - ) + input_file_id: str # The ID of an uploaded file that contains requests for the new batch endpoint: str # The endpoint to be used for all requests in the batch completion_window: str # The time frame within which the batch should be processed metadata: dict | None = None # Optional custom metadata for the batch @@ -321,9 +317,7 @@ class ChatCompletionMessageUserParam(BaseModel): content: str | list[ChatCompletionMessageContentPart] -ChatCompletionMessageParam = ( - ChatCompletionMessageGenericParam | ChatCompletionMessageUserParam -) +ChatCompletionMessageParam = ChatCompletionMessageGenericParam | ChatCompletionMessageUserParam class ResponseFormat(BaseModel): @@ -456,10 +450,7 @@ class ChatCompletionResponseChoice(BaseModel): message: ChatMessage logprobs: LogProbs | ChoiceLogprobs | None = None finish_reason: ( - Literal[ - "stop", "length", "tool_calls", "content_filter", "function_call", "abort" - ] - | None + Literal["stop", "length", "tool_calls", "content_filter", "function_call", "abort"] | None ) = None matched_stop: None | int | str = None hidden_states: object | None = None @@ -501,10 +492,7 @@ class ChatCompletionResponseStreamChoice(BaseModel): delta: DeltaMessage logprobs: LogProbs | ChoiceLogprobs | None = None finish_reason: ( - Literal[ - "stop", "length", "tool_calls", "content_filter", "function_call", "abort" - ] - | None + Literal["stop", "length", "tool_calls", "content_filter", "function_call", "abort"] | None ) = None matched_stop: None | int | str = None @@ -523,9 +511,7 @@ class MultimodalEmbeddingInput(BaseModel): image: str | None = None -EmbeddingInput = ( - list[int] | list[list[int]] | str | list[str] | list[MultimodalEmbeddingInput] -) +EmbeddingInput = list[int] | list[list[int]] | str | list[str] | list[MultimodalEmbeddingInput] class EmbeddingRequest(BaseModel): @@ -587,11 +573,7 @@ class RerankResponse(BaseModel): OpenAIServingRequest = ( - ChatCompletionRequest - | CompletionRequest - | EmbeddingRequest - | ScoringRequest - | V1RerankReqInput + ChatCompletionRequest | CompletionRequest | EmbeddingRequest | ScoringRequest | V1RerankReqInput ) diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_base.py b/python/sgl_jax/srt/entrypoints/openai/serving_base.py index 773815735..a80fc5d2d 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_base.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_base.py @@ -32,9 +32,7 @@ async def handle_request( return self.create_error_response(error_msg) # Convert to internal format - adapted_request, processed_request = self._convert_to_internal_request( - request - ) + adapted_request, processed_request = self._convert_to_internal_request(request) # Note(Xinyuan): raw_request below is only used for detecting the connection of the client if hasattr(request, "stream") and request.stream: diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_chat.py b/python/sgl_jax/srt/entrypoints/openai/serving_chat.py index 9fa64a0ed..b626175c6 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_chat.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_chat.py @@ -46,9 +46,7 @@ class OpenAIServingChat(OpenAIServingBase): """Handler for /v1/chat/completions requests""" - def __init__( - self, tokenizer_manager: TokenizerManager, template_manager: TemplateManager - ): + def __init__(self, tokenizer_manager: TokenizerManager, template_manager: 
TemplateManager): super().__init__(tokenizer_manager) self.template_manager = template_manager @@ -152,9 +150,7 @@ def _apply_jinja_template( modalities, ) - if "tool_calls" in processed_msg and isinstance( - processed_msg.get("tool_calls"), list - ): + if "tool_calls" in processed_msg and isinstance(processed_msg.get("tool_calls"), list): for call in processed_msg["tool_calls"]: try: if "arguments" in call["function"] and isinstance( @@ -165,9 +161,7 @@ def _apply_jinja_template( ) except json.JSONDecodeError as e: # Log a warning or error if JSON parsing fails for arguments - logger.warning( - "Failed to parse tool call arguments as JSON: %s", e - ) + logger.warning("Failed to parse tool call arguments as JSON: %s", e) # Decide whether to continue or raise the exception based on desired behavior continue # Or raise e if strict parsing is required openai_compatible_messages.append(processed_msg) @@ -210,11 +204,7 @@ def _apply_jinja_template( # This except branch will be triggered when the chosen model # has a different tools input format that is not compatible # with openAI's apply_chat_template tool_call format, like Mistral. - tools = ( - [t if "function" in t else {"function": t} for t in tools] - if tools - else None - ) + tools = [t if "function" in t else {"function": t} for t in tools] if tools else None prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, @@ -343,9 +333,7 @@ def _build_sampling_params( pass elif request.response_format and request.response_format.type == "json_object": sampling_params["json_schema"] = '{"type": "object"}' - elif ( - request.response_format and request.response_format.type == "structural_tag" - ): + elif request.response_format and request.response_format.type == "structural_tag": # sampling_params["structural_tag"] = convert_json_schema_to_str( # request.response_format.model_dump(by_alias=True) # ) @@ -424,9 +412,7 @@ async def _generate_chat_stream( choice_logprobs = self._process_streaming_logprobs( content, n_prev_tokens.get(index, 0) ) - n_prev_tokens[index] = len( - content["meta_info"]["output_token_logprobs"] - ) + n_prev_tokens[index] = len(content["meta_info"]["output_token_logprobs"]) finish_reason = content["meta_info"]["finish_reason"] finish_reason_type = finish_reason["type"] if finish_reason else None @@ -505,8 +491,7 @@ async def _generate_chat_stream( delta=DeltaMessage(content=delta if delta else None), finish_reason=( None - if request.stream_options - and request.stream_options.include_usage + if request.stream_options and request.stream_options.include_usage else finish_reason_type ), matched_stop=( @@ -550,9 +535,7 @@ async def _generate_chat_stream( for index, choice_hidden_states in hidden_states.items(): if choice_hidden_states: last_token_hidden_states = ( - choice_hidden_states[-1] - if len(choice_hidden_states) > 1 - else [] + choice_hidden_states[-1] if len(choice_hidden_states) > 1 else [] ) hidden_states_chunk = ChatCompletionStreamResponse( id=content["meta_info"]["id"], @@ -560,9 +543,7 @@ async def _generate_chat_stream( choices=[ ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage( - hidden_states=last_token_hidden_states - ), + delta=DeltaMessage(hidden_states=last_token_hidden_states), finish_reason=finish_reason_type, ) ], @@ -645,9 +626,7 @@ def _build_chat_response( reasoning_parser = self.tokenizer_manager.server_args.reasoning_parser if reasoning_parser and request.separate_reasoning: try: - parser = ReasoningParser( - 
model_type=reasoning_parser, stream_reasoning=False - ) + parser = ReasoningParser(model_type=reasoning_parser, stream_reasoning=False) reasoning_text, text = parser.parse_non_stream(text) except Exception as e: logger.error("Reasoning parsing error: %s", e) @@ -710,18 +689,14 @@ def _process_logprobs_tokens( """ token_logprobs = [] - for token_idx, (token, logprob) in enumerate( - zip(logprobs.tokens, logprobs.token_logprobs) - ): + for token_idx, (token, logprob) in enumerate(zip(logprobs.tokens, logprobs.token_logprobs)): token_bytes = list(token.encode("utf-8")) top_logprobs = [] if logprobs.top_logprobs: # - Non-streaming (use_token_index=True): uses token_idx for full data # - Streaming (use_token_index=False): uses index 0 for pre-sliced data top_logprobs_idx = token_idx if use_token_index else 0 - for top_token, top_logprob in logprobs.top_logprobs[ - top_logprobs_idx - ].items(): + for top_token, top_logprob in logprobs.top_logprobs[top_logprobs_idx].items(): top_token_bytes = list(top_token.encode("utf-8")) top_logprobs.append( TopLogprob( @@ -788,12 +763,8 @@ def _process_streaming_logprobs( ) -> ChoiceLogprobs: """Process logprobs for streaming response""" logprobs = to_openai_style_logprobs( - output_token_logprobs=content["meta_info"]["output_token_logprobs"][ - n_prev_token: - ], - output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[ - n_prev_token: - ], + output_token_logprobs=content["meta_info"]["output_token_logprobs"][n_prev_token:], + output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[n_prev_token:], ) token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False) diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_completions.py b/python/sgl_jax/srt/entrypoints/openai/serving_completions.py index aa3c2e606..36159bedb 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_completions.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_completions.py @@ -65,9 +65,7 @@ def _convert_to_internal_request( sampling_params = self._build_sampling_params(request) # Determine prompt format - if isinstance(prompt, str) or ( - isinstance(prompt, list) and isinstance(prompt[0], str) - ): + if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], str)): prompt_kwargs = {"text": prompt} else: prompt_kwargs = {"input_ids": prompt} @@ -171,9 +169,7 @@ async def _generate_completion_stream( if request.logprobs is not None: # The first chunk and echo is enabled. 
if not stream_buffer and request.echo: - input_token_logprobs = content["meta_info"][ - "input_token_logprobs" - ] + input_token_logprobs = content["meta_info"]["input_token_logprobs"] input_top_logprobs = content["meta_info"]["input_top_logprobs"] else: input_token_logprobs = None @@ -183,16 +179,14 @@ async def _generate_completion_stream( logprobs = to_openai_style_logprobs( input_token_logprobs=input_token_logprobs, input_top_logprobs=input_top_logprobs, - output_token_logprobs=content["meta_info"][ - "output_token_logprobs" - ][n_prev_token:], + output_token_logprobs=content["meta_info"]["output_token_logprobs"][ + n_prev_token: + ], output_top_logprobs=content["meta_info"]["output_top_logprobs"][ n_prev_token: ], ) - n_prev_tokens[index] = len( - content["meta_info"]["output_token_logprobs"] - ) + n_prev_tokens[index] = len(content["meta_info"]["output_token_logprobs"]) # Generate delta delta = text[len(stream_buffer) :] @@ -224,9 +218,7 @@ async def _generate_completion_stream( for index, choice_hidden_states in hidden_states.items(): if choice_hidden_states: last_token_hidden_states = ( - choice_hidden_states[-1] - if len(choice_hidden_states) > 1 - else [] + choice_hidden_states[-1] if len(choice_hidden_states) > 1 else [] ) hidden_states_chunk = CompletionStreamResponse( id=content["meta_info"]["id"], @@ -277,9 +269,7 @@ async def _handle_non_streaming_request( ) -> CompletionResponse | ErrorResponse | ORJSONResponse: """Handle non-streaming completion request""" try: - generator = self.tokenizer_manager.generate_request( - adapted_request, raw_request - ) + generator = self.tokenizer_manager.generate_request(adapted_request, raw_request) ret = await generator.__anext__() except ValueError as e: return self.create_error_response(str(e)) @@ -332,9 +322,7 @@ def _build_completion_response( logprobs = to_openai_style_logprobs( input_token_logprobs=input_token_logprobs, input_top_logprobs=input_top_logprobs, - output_token_logprobs=ret_item["meta_info"][ - "output_token_logprobs" - ], + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], ) @@ -385,9 +373,7 @@ def _get_echo_text(self, request: CompletionRequest, index: int) -> str: return self.tokenizer_manager.tokenizer.decode( request.prompt, skip_special_tokens=True ) - elif isinstance(request.prompt[0], list) and isinstance( - request.prompt[0][0], int - ): + elif isinstance(request.prompt[0], list) and isinstance(request.prompt[0][0], int): # for the case of multiple token ids prompts return self.tokenizer_manager.tokenizer.decode( request.prompt[index // request.n], @@ -403,17 +389,13 @@ def _prepare_echo_prompts(self, request: CompletionRequest) -> list[str]: elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): # for the case of multiple token ids prompts return [ - self.tokenizer_manager.tokenizer.decode( - prompt, skip_special_tokens=True - ) + self.tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True) for prompt in request.prompt ] elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): # for the case of single token ids prompt return [ - self.tokenizer_manager.tokenizer.decode( - request.prompt, skip_special_tokens=True - ) + self.tokenizer_manager.tokenizer.decode(request.prompt, skip_special_tokens=True) ] else: # for the case of single str prompt diff --git a/python/sgl_jax/srt/entrypoints/openai/usage_processor.py b/python/sgl_jax/srt/entrypoints/openai/usage_processor.py 
index dfd2d7957..7fdf00e9b 100644 --- a/python/sgl_jax/srt/entrypoints/openai/usage_processor.py +++ b/python/sgl_jax/srt/entrypoints/openai/usage_processor.py @@ -24,15 +24,12 @@ def calculate_response_usage( completion_tokens = sum(r["meta_info"]["completion_tokens"] for r in responses) prompt_tokens = sum( - responses[i]["meta_info"]["prompt_tokens"] - for i in range(0, len(responses), n_choices) + responses[i]["meta_info"]["prompt_tokens"] for i in range(0, len(responses), n_choices) ) cached_details = None if enable_cache_report: - cached_total = sum( - r["meta_info"].get("cached_tokens", 0) for r in responses - ) + cached_total = sum(r["meta_info"].get("cached_tokens", 0) for r in responses) cached_details = UsageProcessor._details_if_cached(cached_total) return UsageProcessor.calculate_token_usage( @@ -50,9 +47,7 @@ def calculate_streaming_usage( enable_cache_report: bool = False, ) -> UsageInfo: # index % n_choices == 0 marks the first choice of a prompt - total_prompt_tokens = sum( - tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0 - ) + total_prompt_tokens = sum(tok for idx, tok in prompt_tokens.items() if idx % n_choices == 0) total_completion_tokens = sum(completion_tokens.values()) cached_details = ( diff --git a/python/sgl_jax/srt/entrypoints/openai/utils.py b/python/sgl_jax/srt/entrypoints/openai/utils.py index 478fa2640..dad0a048d 100644 --- a/python/sgl_jax/srt/entrypoints/openai/utils.py +++ b/python/sgl_jax/srt/entrypoints/openai/utils.py @@ -29,9 +29,7 @@ def append_token_logprobs(token_logprobs): def append_top_logprobs(top_logprobs): for tokens in top_logprobs: if tokens is not None: - ret_logprobs.top_logprobs.append( - {token[2]: token[0] for token in tokens} - ) + ret_logprobs.top_logprobs.append({token[2]: token[0] for token in tokens}) else: ret_logprobs.top_logprobs.append(None) diff --git a/python/sgl_jax/srt/hf_transformers_utils.py b/python/sgl_jax/srt/hf_transformers_utils.py index fbc9a653d..20f968d07 100644 --- a/python/sgl_jax/srt/hf_transformers_utils.py +++ b/python/sgl_jax/srt/hf_transformers_utils.py @@ -50,9 +50,7 @@ def get_hf_text_config(config: PretrainedConfig): # qwen2.5 omni thinker_config = config.thinker_config if hasattr(thinker_config, "text_config"): - thinker_config.text_config.torch_dtype = getattr( - thinker_config, "torch_dtype", None - ) + thinker_config.text_config.torch_dtype = getattr(thinker_config, "torch_dtype", None) return thinker_config.text_config return thinker_config else: @@ -283,9 +281,7 @@ def get_processor( def attach_additional_stop_token_ids(tokenizer): # Special handling for stop token <|eom_id|> generated by llama 3 tool use. if "<|eom_id|>" in tokenizer.get_added_vocab(): - tokenizer.additional_stop_token_ids = set( - [tokenizer.get_added_vocab()["<|eom_id|>"]] - ) + tokenizer.additional_stop_token_ids = set([tokenizer.get_added_vocab()["<|eom_id|>"]]) else: tokenizer.additional_stop_token_ids = None diff --git a/python/sgl_jax/srt/jinja_template_utils.py b/python/sgl_jax/srt/jinja_template_utils.py index daa89301c..5fbe9ad75 100644 --- a/python/sgl_jax/srt/jinja_template_utils.py +++ b/python/sgl_jax/srt/jinja_template_utils.py @@ -61,9 +61,7 @@ def process_content_for_template_format( # Keep other content as-is (text, etc.) 
processed_content_parts.append(chunk) - new_msg = { - k: v for k, v in msg_dict.items() if v is not None and k != "content" - } + new_msg = {k: v for k, v in msg_dict.items() if v is not None and k != "content"} new_msg["content"] = processed_content_parts return new_msg diff --git a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py index 39fb6cc58..84f58476a 100644 --- a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py +++ b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/flash_attention.py @@ -59,9 +59,9 @@ def ref_ragged_paged_attention_fused( indices = page_indices[i] q = queries[q_start:q_end] - kv_fused = kv_pages_fused[indices, :, :, :].reshape( - -1, num_kv_heads_interleaved, head_dim - )[:kv_len] + kv_fused = kv_pages_fused[indices, :, :, :].reshape(-1, num_kv_heads_interleaved, head_dim)[ + :kv_len + ] # Head format: [K1, V1, K2, V2, ...] k = kv_fused[:, 0::2, :] # indices 0, 2, 4, ... @@ -263,9 +263,7 @@ def _ragged_paged_attention_kernel( bq_sz, ): assert q_hbm_ref.shape == o_hbm_ref.shape - assert ( - q_hbm_ref.shape[-1] == kv_cache_fused_hbm_ref.shape[-1] - ) # head_dim should match + assert q_hbm_ref.shape[-1] == kv_cache_fused_hbm_ref.shape[-1] # head_dim should match ( actual_num_kv_heads, max_num_tokens, @@ -326,9 +324,7 @@ def _fetch_bkv(seq_idx, bkv_idx, bkv_sem_idx, *, wait=False): kv_left_frm_cache = jnp.maximum(kv_left - q_len, 0) kv_left_frm_new = kv_left - kv_left_frm_cache bkv_p_frm_cache = jnp.minimum(cdiv(kv_left_frm_cache, page_size), bkv_p) - bkv_sz_frm_new = jnp.minimum( - jnp.maximum(bkv_sz - kv_left_frm_cache, 0), kv_left_frm_new - ) + bkv_sz_frm_new = jnp.minimum(jnp.maximum(bkv_sz - kv_left_frm_cache, 0), kv_left_frm_new) start_kv_page_idx = cdiv(cu_kv_lens_ref[seq_idx], page_size) page_indices_offset = start_kv_page_idx + kv_p_start @@ -486,9 +482,7 @@ def load_bq(bq_sem_idx, kv_head_idx, *, actual_bq_sz=bq_sz): .at[bq_sem_idx, kv_head_idx] .reshape(bq_sz * num_q_heads_per_kv_head_per_packing, head_dim) ) - return pltpu.bitcast( - q_ref[: actual_bq_sz * num_q_heads_per_kv_head_per_packing], q_dtype - ) + return pltpu.bitcast(q_ref[: actual_bq_sz * num_q_heads_per_kv_head_per_packing], q_dtype) def strided_load(ref, start, step, *, dtype=None): assert get_dtype_packing(ref.dtype) == 1 @@ -511,9 +505,7 @@ def strided_load_bkv_fused(bkv_sem_idx, start, step, *, bkv_bitmask): step //= kv_packing kv_ref = ( - bkv_fused_x2_ref.bitcast(jnp.uint32) - .at[bkv_sem_idx] - .reshape(bkv_sz * step, head_dim) + bkv_fused_x2_ref.bitcast(jnp.uint32).at[bkv_sem_idx].reshape(bkv_sz * step, head_dim) ) def _mask_kv(k, v): @@ -547,9 +539,9 @@ def broadcast_minor(src, shape): assert src.shape[-1] % 128 == 0 target_minor = align_to(shape[-1], src.shape[-1]) # no-op concatenation. 
- return jnp.concatenate( - [src for _ in range(target_minor // src.shape[-1])], axis=-1 - )[..., : shape[-1]] + return jnp.concatenate([src for _ in range(target_minor // src.shape[-1])], axis=-1)[ + ..., : shape[-1] + ] def process(static_q_len=None): num_bkv = cdiv(kv_len, bkv_sz) @@ -692,12 +684,9 @@ def flash_attention(q_batch, k_batch, v_batch): kv_len - q_len + bq_idx * bq_sz - + lax.broadcasted_iota(jnp.int32, s.shape, 1) - // num_q_heads_per_kv_head - ) - k_span = bkv_idx * bkv_sz + lax.broadcasted_iota( - jnp.int32, s.shape, 2 + + lax.broadcasted_iota(jnp.int32, s.shape, 1) // num_q_heads_per_kv_head ) + k_span = bkv_idx * bkv_sz + lax.broadcasted_iota(jnp.int32, s.shape, 2) mask = q_span < k_span if sliding_window is not None: @@ -714,9 +703,7 @@ def flash_attention(q_batch, k_batch, v_batch): head_acc_ref = acc_ref.at[head_idx, : q_batch.shape[1]] def load_with_init(ref, init_val): - return jnp.where( - bkv_idx == 0, jnp.full_like(ref, init_val), ref[...] - ) + return jnp.where(bkv_idx == 0, jnp.full_like(ref, init_val), ref[...]) s_head = s[head_idx] s_head_rowmax = jnp.max(s_head, axis=1, keepdims=True) @@ -975,13 +962,9 @@ def static_validate_inputs_fused( if k.shape != v.shape: raise ValueError(f"Expected {k.shape=} to be equal to {v.shape=}") if not (q.shape[0] == k.shape[0] == v.shape[0]): - raise ValueError( - f"Expected {q.shape[0]=} to be equal to {k.shape[0]=} and {v.shape[0]=}" - ) + raise ValueError(f"Expected {q.shape[0]=} to be equal to {k.shape[0]=} and {v.shape[0]=}") if not (q.shape[2] == k.shape[2] == v.shape[2]): - raise ValueError( - f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}" - ) + raise ValueError(f"Expected {q.shape[2]=} to be equal to {k.shape[2]=} and {v.shape[2]=}") actual_head_dim = q.shape[2] actual_num_q_heads = q.shape[1] @@ -994,9 +977,7 @@ def static_validate_inputs_fused( # Validate fused KV cache if len(kv_cache_fused.shape) != 4: - raise ValueError( - f"Expected 4D kv_cache_fused, got shape {kv_cache_fused.shape}" - ) + raise ValueError(f"Expected 4D kv_cache_fused, got shape {kv_cache_fused.shape}") _, page_size, cache_num_kv_heads_interleaved, head_dim = kv_cache_fused.shape @@ -1012,9 +993,7 @@ def static_validate_inputs_fused( ) if head_dim != align_to(actual_head_dim, 128): - raise ValueError( - f"Expected {head_dim=} is equal to {align_to(actual_head_dim, 128)=}" - ) + raise ValueError(f"Expected {head_dim=} is equal to {align_to(actual_head_dim, 128)=}") if not (kv_cache_fused.dtype == k.dtype == v.dtype): raise ValueError( @@ -1025,11 +1004,7 @@ def static_validate_inputs_fused( raise ValueError(f"Expected {kv_cache_fused.dtype=} to be a floating point.") if not ( - jnp.int32 - == kv_lens.dtype - == page_indices.dtype - == cu_q_lens.dtype - == distribution.dtype + jnp.int32 == kv_lens.dtype == page_indices.dtype == cu_q_lens.dtype == distribution.dtype ): raise ValueError( f"Expected int32 dtype for {kv_lens.dtype=}, {page_indices.dtype=}," @@ -1325,9 +1300,7 @@ def ragged_paged_attention( ) ) - output, updated_kv_cache_fused = kernel( - *scalar_prefetches, q, kv, kv_cache_fused_processed - ) + output, updated_kv_cache_fused = kernel(*scalar_prefetches, q, kv, kv_cache_fused_processed) return ( prepare_outputs(output, actual_num_q_heads_per_kv_head, actual_head_dim), prepare_updated_kv_cache_fused( @@ -1403,9 +1376,7 @@ def prepare_updated_kv_cache_fused( head_dim, ) = kv_cache_fused.shape - actual_num_kv_heads_interleaved = ( - actual_num_kv_heads * 2 - ) # Head interleaving: K1,V1,K2,V2,... 
+ actual_num_kv_heads_interleaved = actual_num_kv_heads * 2 # Head interleaving: K1,V1,K2,V2,... return kv_cache_fused.reshape( -1, num_kv_heads_interleaved_packed * kv_packing, diff --git a/python/sgl_jax/srt/layers/attention/flashattention_backend.py b/python/sgl_jax/srt/layers/attention/flashattention_backend.py index b3a1d2a9f..1e9f1f53b 100644 --- a/python/sgl_jax/srt/layers/attention/flashattention_backend.py +++ b/python/sgl_jax/srt/layers/attention/flashattention_backend.py @@ -136,9 +136,7 @@ def get_forward_metadata(self, batch: ModelWorkerBatch): distribution = np.array([0, 0, num_seqs.item()], dtype=np.int32) elif batch.forward_mode == ForwardMode.EXTEND: # All sequences are prefill mode - distribution = np.array( - [0, num_seqs.item(), num_seqs.item()], dtype=np.int32 - ) + distribution = np.array([0, num_seqs.item(), num_seqs.item()], dtype=np.int32) else: raise ValueError(f"Invalid forward mode: {batch.forward_mode}") @@ -151,9 +149,7 @@ def get_forward_metadata(self, batch: ModelWorkerBatch): metadata.distribution, ) = device_array( (num_seqs, cu_q_lens, cu_kv_lens, page_indices, seq_lens, distribution), - sharding=( - NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None - ), + sharding=(NamedSharding(self.mesh, P()) if jax.process_count() == 1 else None), ) return metadata @@ -202,28 +198,20 @@ def __call__( Returns: Output tensor of shape [total_tokens, hidden_size] """ - kv_cache_fused = self._get_fused_kv_cache( - forward_batch, token_to_kv_pool, layer.layer_id - ) + kv_cache_fused = self._get_fused_kv_cache(forward_batch, token_to_kv_pool, layer.layer_id) - scale = ( - 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling - ) + scale = 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling # Prepare fused KV cache for paged format: [num_pages, page_size, num_kv_heads * 2, head_dim] total_tokens = kv_cache_fused.shape[0] num_pages = total_tokens // self.page_size - kv_cache_fused_paged = kv_cache_fused.reshape( - num_pages, self.page_size, -1, self.head_dim - ) + kv_cache_fused_paged = kv_cache_fused.reshape(num_pages, self.page_size, -1, self.head_dim) in_specs = ( P(None, self.kv_partition_axis), # queries P(None, self.kv_partition_axis), # keys (new tokens) P(None, self.kv_partition_axis), # values (new tokens) - P( - None, None, self.kv_partition_axis, None - ), # kv_cache_fused (head interleaved) + P(None, None, self.kv_partition_axis, None), # kv_cache_fused (head interleaved) P(), # kv_lens P(), # page_indices P(), # cu_q_lens diff --git a/python/sgl_jax/srt/layers/embeddings.py b/python/sgl_jax/srt/layers/embeddings.py index c8de05835..b7cb374ce 100644 --- a/python/sgl_jax/srt/layers/embeddings.py +++ b/python/sgl_jax/srt/layers/embeddings.py @@ -87,9 +87,7 @@ def __call__(self, inputs: jax.Array) -> jax.Array: raise ValueError("Input type must be an integer or unsigned integer.") # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. - (embedding,) = self.promote_dtype( - (self.embedding.value,), dtype=self.dtype, inexact=False - ) + (embedding,) = self.promote_dtype((self.embedding.value,), dtype=self.dtype, inexact=False) if self.num_embeddings == 1: return jnp.broadcast_to(embedding, inputs.shape + (self.features,)) return jnp.take(embedding, inputs, axis=0) @@ -107,9 +105,7 @@ def attend(self, query: jax.Array) -> jax.Array: Commonly used for weight-sharing between embeddings and logit transform in NLP models. 
""" - query, embedding = self.promote_dtype( - (query, self.embedding.value), dtype=self.dtype - ) + query, embedding = self.promote_dtype((query, self.embedding.value), dtype=self.dtype) return jnp.dot(query, embedding.T) @@ -235,8 +231,7 @@ def __call__( def _compute_inv_freq(self, base: int | float) -> jax.Array: """Compute the inverse frequency.""" inv_freq = 1.0 / ( - base - ** (jnp.arange(0, self.rotary_dim, 2, dtype=jnp.float32) / self.rotary_dim) + base ** (jnp.arange(0, self.rotary_dim, 2, dtype=jnp.float32) / self.rotary_dim) ) return inv_freq @@ -268,9 +263,7 @@ def __init__( self.low_freq_factor = low_freq_factor self.high_freq_factor = high_freq_factor self.orig_max_position = orig_max_position - super().__init__( - head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype - ) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) def _compute_inv_freq(self, base: int | float) -> jax.Array: inv_freqs = super()._compute_inv_freq(base) diff --git a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/common.py b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/common.py index b33213c76..8745fd139 100644 --- a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/common.py +++ b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/common.py @@ -41,11 +41,7 @@ def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype: """A type to which both input should be adapted to before dot product.""" # bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed # input precision, we need to convert bf16 argument to fp32 beforehand. - if ( - supports_bfloat16_matmul() - and lhs.dtype == jnp.bfloat16 - and rhs.dtype == jnp.bfloat16 - ): + if supports_bfloat16_matmul() and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16: return jnp.bfloat16 else: return jnp.float32 diff --git a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py index 030e305db..7f856279f 100644 --- a/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py +++ b/python/sgl_jax/srt/layers/gmm/megablox_gmm_kernel/gmm.py @@ -34,9 +34,7 @@ def _validate_args( # Validate 'group_sizes'. if group_sizes.dtype != jnp.int32: - raise ValueError( - f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}." - ) + raise ValueError(f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}.") return lhs, group_sizes, common.select_input_dtype(lhs, rhs) @@ -198,9 +196,7 @@ def make_group_metadata( partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, group_offsets[:-1] // tm) - tile_visits = ( - jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1 - ) + tile_visits = jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1 # Create the m-dimension tile ids for each grid index based on the visit # counts for each tile. @@ -230,9 +226,7 @@ def make_group_metadata( return (group_offsets, group_ids, m_tile_ids), num_tiles -def _get_group_size( - *, grid_id: jnp.ndarray, group_metadata: GroupMetadata -) -> jnp.ndarray: +def _get_group_size(*, grid_id: jnp.ndarray, group_metadata: GroupMetadata) -> jnp.ndarray: """Calculate the number of rows in the current group.""" group_offsets, group_ids = group_metadata[:2] group_id = group_ids[grid_id] @@ -324,15 +318,11 @@ def gmm( group_offset = jnp.array([0], dtype=jnp.int32) else: if group_offset.shape: - raise ValueError( - f"group_offset must be a ()-shaped array. 
Got: {group_offset.shape}." - ) + raise ValueError(f"group_offset must be a ()-shaped array. Got: {group_offset.shape}.") group_offset = group_offset[None] num_current_groups = rhs.shape[0] num_total_groups = group_sizes.shape[0] - lhs, group_sizes, input_dtype = _validate_args( - lhs=lhs, rhs=rhs, group_sizes=group_sizes - ) + lhs, group_sizes, input_dtype = _validate_args(lhs=lhs, rhs=rhs, group_sizes=group_sizes) # Gather shape information. m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2]) @@ -362,9 +352,7 @@ def gmm( visit_empty_groups=False, ) - dot_general_dims = ( - (((1,), (1,)), ((), ())) if transpose_rhs else (((1,), (0,)), ((), ())) - ) + dot_general_dims = (((1,), (1,)), ((), ())) if transpose_rhs else (((1,), (0,)), ((), ())) def kernel( group_metadata, @@ -389,9 +377,7 @@ def _zero_acc(): prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) is_first_processed_group = grid_id == 0 m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[prev_grid_id] - first_time_seeing_out = jnp.logical_or( - is_first_processed_group, m_tile_changed - ) + first_time_seeing_out = jnp.logical_or(is_first_processed_group, m_tile_changed) @pl.when(first_time_seeing_out) def _init_out(): @@ -414,9 +400,9 @@ def _store_accum(): tn=tn, ) to_store = acc_scratch[...] - out[...] = jax.lax.select( - mask[...], to_store, out[...].astype(jnp.float32) - ).astype(preferred_element_type) + out[...] = jax.lax.select(mask[...], to_store, out[...].astype(jnp.float32)).astype( + preferred_element_type + ) def _accum(is_last_k_tile): if is_last_k_tile: @@ -491,9 +477,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): max_active_tiles = group_metadata[1].size bytes_accessed = (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes flops = 2 * m * k * n - cost_estimate = pl.CostEstimate( - flops=flops, bytes_accessed=bytes_accessed, transcendentals=0 - ) + cost_estimate = pl.CostEstimate(flops=flops, bytes_accessed=bytes_accessed, transcendentals=0) call_gmm = pl.pallas_call( kernel, diff --git a/python/sgl_jax/srt/layers/logits_processor.py b/python/sgl_jax/srt/layers/logits_processor.py index 521fb1212..7f1ec8aa7 100644 --- a/python/sgl_jax/srt/layers/logits_processor.py +++ b/python/sgl_jax/srt/layers/logits_processor.py @@ -76,12 +76,8 @@ def tree_unflatten(cls, aux_data, children): obj.next_token_top_logprobs_val = aux_data["next_token_top_logprobs_val"] obj.next_token_top_logprobs_idx = aux_data["next_token_top_logprobs_idx"] - obj.next_token_token_ids_logprobs_val = aux_data[ - "next_token_token_ids_logprobs_val" - ] - obj.next_token_token_ids_logprobs_idx = aux_data[ - "next_token_token_ids_logprobs_idx" - ] + obj.next_token_token_ids_logprobs_val = aux_data["next_token_token_ids_logprobs_val"] + obj.next_token_token_ids_logprobs_idx = aux_data["next_token_token_ids_logprobs_idx"] obj.input_top_logprobs_val = aux_data["input_top_logprobs_val"] obj.input_top_logprobs_idx = aux_data["input_top_logprobs_idx"] obj.input_token_ids_logprobs_val = aux_data["input_token_ids_logprobs_val"] @@ -175,9 +171,7 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None): extend_seq_lens_cpu = batch.extend_seq_lens.tolist() extend_return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) - extend_token_ids_logprob = any( - x is not None for x in batch.token_ids_logprobs - ) + extend_token_ids_logprob = any(x is not None for x in batch.token_ids_logprobs) extend_return_logprob = False extend_logprob_pruned_lens_cpu = [] for extend_len, start_len in zip( @@ 
-188,9 +182,7 @@ def from_model_worker_batch(cls, batch: ModelWorkerBatch, mesh: Mesh = None): extend_return_logprob = True extend_logprob_pruned_lens_cpu.append(extend_len - start_len) else: - extend_return_logprob = extend_return_top_logprob = ( - extend_token_ids_logprob - ) = False + extend_return_logprob = extend_return_top_logprob = extend_token_ids_logprob = False extend_logprob_pruned_lens_cpu = extend_seq_lens_cpu = None sharding = NamedSharding(mesh, P()) if jax.process_count() == 1 else None @@ -232,10 +224,7 @@ def __call__( pruned_states = hidden_states sample_indices = None input_logprob_indices = None - elif ( - logits_metadata.forward_mode.is_extend() - and not logits_metadata.extend_return_logprob - ): + elif logits_metadata.forward_mode.is_extend() and not logits_metadata.extend_return_logprob: last_index = jnp.cumsum(logits_metadata.extend_seq_lens, axis=0) - 1 pruned_states = hidden_states[last_index] sample_indices = None @@ -288,9 +277,7 @@ def __call__( # Compute logits for both input and sampled tokens. logits = self._get_logits(pruned_states, self.lm_head) - sampled_logits = ( - logits[sample_indices] if sample_indices is not None else logits - ) + sampled_logits = logits[sample_indices] if sample_indices is not None else logits hidden_states_to_store: jax.Array | None = None if logits_metadata.capture_hidden_mode.need_capture(): @@ -300,9 +287,7 @@ def __call__( # Get the last token hidden states. If sample_indices is None, # pruned states only contain the last tokens already. hidden_states_to_store = ( - pruned_states[sample_indices] - if sample_indices is not None - else pruned_states + pruned_states[sample_indices] if sample_indices is not None else pruned_states ) else: raise AssertionError() @@ -372,9 +357,7 @@ def __call__( ) @staticmethod - def get_token_ids_logprobs( - all_logprobs: jax.Array, logits_metadata: LogitsMetadata - ): + def get_token_ids_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetadata): input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], [] pt = 0 for token_ids, pruned_len in zip( @@ -413,12 +396,8 @@ def get_top_logprobs(all_logprobs: jax.Array, logits_metadata: LogitsMetadata): input_top_logprobs_idx.append([]) continue - input_top_logprobs_val.append( - [values[pt + j][:k] for j in range(pruned_len)] - ) - input_top_logprobs_idx.append( - [indices[pt + j][:k] for j in range(pruned_len)] - ) + input_top_logprobs_val.append([values[pt + j][:k] for j in range(pruned_len)]) + input_top_logprobs_idx.append([indices[pt + j][:k] for j in range(pruned_len)]) pt += pruned_len return input_top_logprobs_val, input_top_logprobs_idx @@ -439,10 +418,7 @@ def compute_temp_top_p_normalized_logprobs( # Normalize logprobs if top_p normalization is enabled # NOTE: only normalize logprobs when top_p is set and not equal to 1.0 - if ( - logits_metadata.top_p_normalized_logprobs - and (logits_metadata.top_p != 1.0).any() - ): + if logits_metadata.top_p_normalized_logprobs and (logits_metadata.top_p != 1.0).any(): from sgl_jax.srt.layers.sampler import top_p_normalize_probs_jax probs = jnp.softmax(last_logits, axis=-1) @@ -471,10 +447,6 @@ def _get_logits( logits = jnp.dot(hidden_states, embedding.T) - logits = ( - logits[:, : self.vocab_size] - if logits.ndim > 1 - else logits[: self.vocab_size] - ) + logits = logits[:, : self.vocab_size] if logits.ndim > 1 else logits[: self.vocab_size] return logits diff --git a/python/sgl_jax/srt/layers/moe.py b/python/sgl_jax/srt/layers/moe.py index e220cf657..32e77acab 100644 --- 
a/python/sgl_jax/srt/layers/moe.py +++ b/python/sgl_jax/srt/layers/moe.py @@ -66,9 +66,7 @@ def __init__( if self.use_bias: bias_shape = self.features - bias_axes = ( - self.kernel_axes[-len(self.features) :] if self.kernel_axes else () - ) + bias_axes = self.kernel_axes[-len(self.features) :] if self.kernel_axes else () self.bias = nnx.Param( nnx.with_partitioning(nnx.initializers.zeros_init(), bias_axes)( jax.random.PRNGKey(0), bias_shape, self.weight_dtype @@ -200,17 +198,13 @@ def _internal_moe_computation( expert_shard_id = data_index * tensor_size + tensor_index # topk - top_k_logits, top_k_indices = jax.lax.top_k( - router_logits, self.num_experts_per_tok + top_k_logits, top_k_indices = jax.lax.top_k(router_logits, self.num_experts_per_tok) + top_k_weights = jax.nn.softmax(top_k_logits.astype(jnp.bfloat16), axis=-1).astype( + self.dtype ) - top_k_weights = jax.nn.softmax( - top_k_logits.astype(jnp.bfloat16), axis=-1 - ).astype(self.dtype) # ep moe norm_topk_prob=true - top_k_weights = top_k_weights / jnp.sum( - top_k_weights, axis=-1, keepdims=True - ) + top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True) if hidden_states.ndim == 2: total_tokens = hidden_states.shape[0] @@ -219,16 +213,14 @@ def _internal_moe_computation( batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] total_tokens = batch_size * seq_len # Permute - x, sorted_selected_experts, weights, group_sizes, selected_experts = ( - self._permute(hidden_states, top_k_indices, top_k_weights) + x, sorted_selected_experts, weights, group_sizes, selected_experts = self._permute( + hidden_states, top_k_indices, top_k_weights ) # EP Dispatch if self.expert_parallel_size > 1: - x, local_group_sizes, selected_experts = ( - self._expert_all_to_all_dispatch( - x, selected_experts, expert_shard_id - ) + x, local_group_sizes, selected_experts = self._expert_all_to_all_dispatch( + x, selected_experts, expert_shard_id ) else: local_group_sizes = group_sizes @@ -278,9 +270,7 @@ def _gmm_compute_with_sharded_weights( self, x, local_group_sizes, selected_experts, w0_kernel, w1_kernel, wo_kernel ): if x.shape[0] == 0: - empty_output = jnp.zeros( - (0, wo_kernel.shape[-1]), dtype=x.dtype - ) # (0, hidden_dim) + empty_output = jnp.zeros((0, wo_kernel.shape[-1]), dtype=x.dtype) # (0, hidden_dim) return empty_output m, k = x.shape[0], x.shape[1] @@ -331,12 +321,8 @@ def _gmm_compute_with_sharded_weights( return intermediate_output def _single_device_forward(self, inputs, router_logits): - top_k_logits, top_k_indices = jax.lax.top_k( - router_logits, self.num_experts_per_tok - ) - top_k_weights = jax.nn.softmax( - top_k_logits.astype(jnp.float32), axis=-1 - ).astype(self.dtype) + top_k_logits, top_k_indices = jax.lax.top_k(router_logits, self.num_experts_per_tok) + top_k_weights = jax.nn.softmax(top_k_logits.astype(jnp.float32), axis=-1).astype(self.dtype) top_k_weights = top_k_weights / jnp.sum(top_k_weights, axis=-1, keepdims=True) @@ -394,26 +380,20 @@ def _expert_all_to_all_dispatch(self, data, sorted_experts, expert_shard_id): valid_experts_for_bincount = jnp.where( valid_expert_mask, local_experts_extracted, local_expert_size ) - local_group_sizes = jnp.bincount( - valid_experts_for_bincount, length=local_expert_size - ) + local_group_sizes = jnp.bincount(valid_experts_for_bincount, length=local_expert_size) return local_data, local_group_sizes, local_experts_extracted def _get_all_to_all_params(self, group_sizes, shard_id): input_offsets = jnp.zeros(self.expert_parallel_size, 
dtype=group_sizes.dtype) send_sizes = jnp.repeat(group_sizes[shard_id], self.expert_parallel_size) - output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(group_sizes[:-1])))[ - shard_id - ] + output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(group_sizes[:-1])))[shard_id] output_offsets = jnp.repeat(output_offset, self.expert_parallel_size) recv_sizes = group_sizes return input_offsets, send_sizes, output_offsets, recv_sizes - def _expert_all_to_all_collect( - self, data, global_group_sizes, expert_shard_id, target_size - ): + def _expert_all_to_all_collect(self, data, global_group_sizes, expert_shard_id, target_size): # Calculate the number of tokens to be handled by each device. reshaped_group_sizes = global_group_sizes.reshape( self.expert_parallel_size, self.experts_per_device @@ -421,8 +401,8 @@ def _expert_all_to_all_collect( tokens_per_device = jnp.sum(reshaped_group_sizes, axis=1) # Get parameters for ragged_all_to_all - input_offsets, send_sizes, output_offsets, recv_sizes = ( - self._get_all_to_all_params(tokens_per_device, expert_shard_id) + input_offsets, send_sizes, output_offsets, recv_sizes = self._get_all_to_all_params( + tokens_per_device, expert_shard_id ) # Create output shape buffer @@ -455,9 +435,7 @@ def _permute(self, inputs, top_k_indices, top_k_weights): sorted_selected_experts = jnp.argsort(flatten_selected_experts) sorted_indices = sorted_selected_experts // self.num_experts_per_tok - sorted_inputs = jnp.take(inputs_2d, indices=sorted_indices, axis=0).astype( - self.dtype - ) + sorted_inputs = jnp.take(inputs_2d, indices=sorted_indices, axis=0).astype(self.dtype) group_sizes = jnp.bincount(flatten_selected_experts, length=self.num_experts) @@ -476,9 +454,7 @@ def _permute(self, inputs, top_k_indices, top_k_weights): sorted_experts, ) - def _unpermute( - self, intermediate, sorted_selected_experts, weights, batch_size, seq_len - ): + def _unpermute(self, intermediate, sorted_selected_experts, weights, batch_size, seq_len): expected_tokens = sorted_selected_experts.shape[0] actual_tokens = intermediate.shape[0] @@ -487,9 +463,7 @@ def _unpermute( intermediate = intermediate[:expected_tokens] else: padding_size = expected_tokens - actual_tokens - padding = jnp.zeros( - (padding_size, intermediate.shape[1]), dtype=intermediate.dtype - ) + padding = jnp.zeros((padding_size, intermediate.shape[1]), dtype=intermediate.dtype) intermediate = jnp.concatenate([intermediate, padding], axis=0) argsort_indices = jnp.argsort(sorted_selected_experts) @@ -497,9 +471,7 @@ def _unpermute( total_tokens = weights.shape[0] * weights.shape[1] // self.num_experts_per_tok - reshaped_weights = jnp.reshape( - weights, (total_tokens, self.num_experts_per_tok) - ) + reshaped_weights = jnp.reshape(weights, (total_tokens, self.num_experts_per_tok)) reshaped_intermediate = jnp.reshape( unsort_intermediate, (total_tokens, self.num_experts_per_tok, -1), diff --git a/python/sgl_jax/srt/layers/sampler.py b/python/sgl_jax/srt/layers/sampler.py index 5d5ba642d..2710e2245 100644 --- a/python/sgl_jax/srt/layers/sampler.py +++ b/python/sgl_jax/srt/layers/sampler.py @@ -38,9 +38,7 @@ def _regular_sampling(self, operands): ), f"Temperature batch size {temperatures_shape[0]} doesn't match logits batch size {logits_batch_size}" # Post process logits - processed_logits = jnp.divide(logits, sampling_metadata.temperatures).astype( - logits.dtype - ) + processed_logits = jnp.divide(logits, sampling_metadata.temperatures).astype(logits.dtype) probs = jax.nn.softmax(processed_logits, axis=-1) @@ 
-132,9 +130,7 @@ def _apply_min_tokens_penalty(self, operands): return logits + stop_penalty.astype(logits.dtype) - def apply_penalties( - self, logits: jax.Array, sampling_metadata: SamplingMetadata - ) -> jax.Array: + def apply_penalties(self, logits: jax.Array, sampling_metadata: SamplingMetadata) -> jax.Array: """ Apply penalties to logits with JIT-optimized tensor operations using lax.cond. diff --git a/python/sgl_jax/srt/managers/detokenizer_manager.py b/python/sgl_jax/srt/managers/detokenizer_manager.py index 526aad7da..39ddabb15 100644 --- a/python/sgl_jax/srt/managers/detokenizer_manager.py +++ b/python/sgl_jax/srt/managers/detokenizer_manager.py @@ -91,9 +91,7 @@ def event_loop(self): # if recv_obj is not None: self.send_to_tokenizer.send_pyobj(output) - def trim_matched_stop( - self, output: str | list[int], finished_reason: dict, no_stop_trim: bool - ): + def trim_matched_stop(self, output: str | list[int], finished_reason: dict, no_stop_trim: bool): if no_stop_trim or not finished_reason: return output @@ -158,10 +156,7 @@ def deep_flatten(lst): if hasattr(lst, "__len__") and len(lst) == 1: return [int(lst.item())] elif hasattr(lst, "__iter__"): - return [ - int(x.item() if hasattr(x, "item") else x) - for x in lst - ] + return [int(x.item() if hasattr(x, "item") else x) for x in lst] else: return [int(lst.item())] else: @@ -178,9 +173,7 @@ def deep_flatten(lst): result.append(int(item.item())) elif hasattr(item, "__iter__"): for x in item: - result.append( - int(x.item() if hasattr(x, "item") else x) - ) + result.append(int(x.item() if hasattr(x, "item") else x)) else: result.append(int(item.item())) else: diff --git a/python/sgl_jax/srt/managers/io_struct.py b/python/sgl_jax/srt/managers/io_struct.py index dd63aff44..f8e23535c 100644 --- a/python/sgl_jax/srt/managers/io_struct.py +++ b/python/sgl_jax/srt/managers/io_struct.py @@ -235,12 +235,8 @@ def _expand_inputs(self, num): raise ValueError("Text should be a list for batch processing.") self.text = self.text * self.parallel_sample_num elif self.input_ids is not None: - if not isinstance(self.input_ids, list) or not isinstance( - self.input_ids[0], list - ): - raise ValueError( - "input_ids should be a list of lists for batch processing." 
- ) + if not isinstance(self.input_ids, list) or not isinstance(self.input_ids[0], list): + raise ValueError("input_ids should be a list of lists for batch processing.") self.input_ids = self.input_ids * self.parallel_sample_num elif self.input_embeds is not None: if not isinstance(self.input_embeds, list): @@ -302,21 +298,13 @@ def normalize_param(param, default_value, param_name): return [param] * num else: if self.parallel_sample_num > 1: - raise ValueError( - f"Cannot use list {param_name} with parallel_sample_num > 1" - ) + raise ValueError(f"Cannot use list {param_name} with parallel_sample_num > 1") return param # Normalize each logprob parameter - self.return_logprob = normalize_param( - self.return_logprob, False, "return_logprob" - ) - self.logprob_start_len = normalize_param( - self.logprob_start_len, -1, "logprob_start_len" - ) - self.top_logprobs_num = normalize_param( - self.top_logprobs_num, 0, "top_logprobs_num" - ) + self.return_logprob = normalize_param(self.return_logprob, False, "return_logprob") + self.logprob_start_len = normalize_param(self.logprob_start_len, -1, "logprob_start_len") + self.top_logprobs_num = normalize_param(self.top_logprobs_num, 0, "top_logprobs_num") # Handle token_ids_logprob specially due to its nested structure if not self.token_ids_logprob: # covers both None and [] @@ -324,13 +312,9 @@ def normalize_param(param, default_value, param_name): elif not isinstance(self.token_ids_logprob, list): self.token_ids_logprob = [[self.token_ids_logprob] for _ in range(num)] elif not isinstance(self.token_ids_logprob[0], list): - self.token_ids_logprob = [ - copy.deepcopy(self.token_ids_logprob) for _ in range(num) - ] + self.token_ids_logprob = [copy.deepcopy(self.token_ids_logprob) for _ in range(num)] elif self.parallel_sample_num > 1: - raise ValueError( - "Cannot use list token_ids_logprob with parallel_sample_num > 1" - ) + raise ValueError("Cannot use list token_ids_logprob with parallel_sample_num > 1") def regenerate_rid(self): """Generate a new request ID and return it.""" diff --git a/python/sgl_jax/srt/managers/schedule_batch.py b/python/sgl_jax/srt/managers/schedule_batch.py index 5846877a8..05e2e752d 100644 --- a/python/sgl_jax/srt/managers/schedule_batch.py +++ b/python/sgl_jax/srt/managers/schedule_batch.py @@ -322,9 +322,7 @@ def init_incremental_detokenize(self): if first_iter: self.read_offset = len(self.origin_input_ids_unpadded) - self.surr_offset = max( - self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0 - ) + self.surr_offset = max(self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0) all_ids = self.origin_input_ids_unpadded + self.output_ids return all_ids[self.surr_offset :], self.read_offset - self.surr_offset @@ -340,9 +338,7 @@ def check_finished(self): return if len(self.output_ids) >= self.sampling_params.max_new_tokens: - self.finished_reason = FINISH_LENGTH( - length=self.sampling_params.max_new_tokens - ) + self.finished_reason = FINISH_LENGTH(length=self.sampling_params.max_new_tokens) return last_token_id = self.output_ids[-1] @@ -358,27 +354,19 @@ def check_finished(self): if self.eos_token_ids: if any(hasattr(token_id, "item") for token_id in self.eos_token_ids): self.eos_token_ids = { - ( - int(token_id.item()) - if hasattr(token_id, "item") - else int(token_id) - ) + (int(token_id.item()) if hasattr(token_id, "item") else int(token_id)) for token_id in self.eos_token_ids } matched_eos |= last_token_id in self.eos_token_ids if self.tokenizer is not None: matched_eos |= last_token_id == 
self.tokenizer.eos_token_id if self.tokenizer.additional_stop_token_ids: - matched_eos |= ( - last_token_id in self.tokenizer.additional_stop_token_ids - ) + matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids if matched_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) return - if self.vocab_size is not None and ( - last_token_id > self.vocab_size or last_token_id < 0 - ): + if self.vocab_size is not None and (last_token_id > self.vocab_size or last_token_id < 0): if self.sampling_params.stop_token_ids: self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids)) elif self.eos_token_ids: @@ -414,9 +402,7 @@ def set_finish_with_abort(self, error_msg: str): # set it to one token to skip the long prefill self.origin_input_ids = [0] self.return_logprob = False - self.finished_reason = FINISH_ABORT( - error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" - ) + self.finished_reason = FINISH_ABORT(error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError") def __repr__(self): return ( @@ -527,9 +513,7 @@ def init_new( has_stream=any(req.stream for req in reqs), chunked_req=chunked_req, mesh=mesh, - is_prefill_only=all( - req.sampling_params.max_new_tokens == 0 for req in reqs - ), + is_prefill_only=all(req.sampling_params.max_new_tokens == 0 for req in reqs), ) def batch_size(self): @@ -581,10 +565,7 @@ def alloc_paged_token_slots_extend( extend_num_tokens: int, backup_state: bool = False, ): - num_tokens = ( - extend_num_tokens - + len(seq_lens) * self.token_to_kv_pool_allocator.page_size - ) + num_tokens = extend_num_tokens + len(seq_lens) * self.token_to_kv_pool_allocator.page_size self._evict_tree_cache_if_needed(num_tokens) if backup_state: @@ -645,9 +626,7 @@ def mix_with_running(self, running_batch: ScheduleBatch): req.extend_input_len = 1 input_ids = jnp.concatenate([self.input_ids, running_batch.input_ids]) - out_cache_loc = jnp.concatenate( - [self.out_cache_loc, running_batch.out_cache_loc] - ) + out_cache_loc = jnp.concatenate([self.out_cache_loc, running_batch.out_cache_loc]) self.merge_batch(running_batch) self.input_ids = input_ids @@ -656,10 +635,7 @@ def mix_with_running(self, running_batch: ScheduleBatch): delta = 0 if self.enable_overlap else -1 # NOTE: prefix_indices is what has been cached, but we don't cache each decode step self.prefix_lens.extend( - [ - len(r.origin_input_ids) + len(r.output_ids) + delta - for r in running_batch.reqs - ] + [len(r.origin_input_ids) + len(r.output_ids) + delta for r in running_batch.reqs] ) self.extend_lens.extend([1] * running_bs) self.extend_num_tokens += running_bs @@ -695,9 +671,7 @@ def prepare_for_extend(self): prefix_indices = req.prefix_indices if pre_len > 0: # note: prefix_indices has to locate on device, or will meet Received incompatible devices for jitted computation - self.req_to_token_pool.write( - (req.req_pool_idx, slice(0, pre_len)), prefix_indices - ) + self.req_to_token_pool.write((req.req_pool_idx, slice(0, pre_len)), prefix_indices) req.cached_tokens += pre_len - req.already_computed req.already_computed = seq_len @@ -735,20 +709,14 @@ def prepare_for_extend(self): if global_start_idx < req.logprob_start_len: global_start_idx = req.logprob_start_len - logprob_token_ids = req.origin_input_ids[ - global_start_idx + 1 : global_end_idx + 1 - ] + logprob_token_ids = req.origin_input_ids[global_start_idx + 1 : global_end_idx + 1] extend_input_logprob_token_ids.extend(logprob_token_ids) # We will need req.extend_input_len - req.extend_logprob_start_len number of # tokens, and 
logprob_token_ids is for input logprob, so pad the rest of them by 0. extend_input_logprob_token_ids.extend( [0] - * ( - req.extend_input_len - - req.extend_logprob_start_len - - len(logprob_token_ids) - ) + * (req.extend_input_len - req.extend_logprob_start_len - len(logprob_token_ids)) ) if self.return_logprob: @@ -847,10 +815,7 @@ def _get_available_size(): retracted_reqs = [] seq_lens_cpu = self.seq_lens first_iter = True - while ( - _get_available_size() < get_required_tokens(len(sorted_indices)) - or first_iter - ): + while _get_available_size() < get_required_tokens(len(sorted_indices)) or first_iter: if len(sorted_indices) == 1: # Corner case: only one request left assert ( @@ -930,22 +895,14 @@ def prepare_for_decode(self): # TODO: this can be slow, optimize this. delayed_output_ids = np.array( [ - ( - req.output_ids[-1] - if len(req.output_ids) - else req.origin_input_ids[-1] - ) + (req.output_ids[-1] if len(req.output_ids) else req.origin_input_ids[-1]) for req in self.reqs ], dtype=np.int64, ) - self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( - delayed_output_ids - ) + self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(delayed_output_ids) else: - self.sampling_info.penalizer_orchestrator.cumulate_output_tokens( - self.output_ids - ) + self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.output_ids) # Update fields self.input_ids = self.output_ids @@ -964,9 +921,7 @@ def prepare_for_decode(self): if self.token_to_kv_pool_allocator.page_size == 1: self.out_cache_loc = self.alloc_token_slots(bs) else: - last_loc = self.req_to_token_pool.req_to_token[ - self.req_pool_indices, self.seq_lens - 2 - ] + last_loc = self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens - 2] self.out_cache_loc = self.alloc_paged_token_slots_decode( self.seq_lens.tolist(), last_loc.tolist(), @@ -988,8 +943,7 @@ def filter_batch( keep_indices = [ i for i in range(len(self.reqs)) - if not self.reqs[i].finished() - and self.reqs[i] not in chunked_req_to_exclude + if not self.reqs[i].finished() and self.reqs[i] not in chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: @@ -1006,9 +960,7 @@ def filter_batch( self.seq_lens = self.seq_lens[keep_indices] self.out_cache_loc = None self.seq_lens_sum = self.seq_lens.sum().item() - self.output_ids = ( - self.output_ids[keep_indices] if self.output_ids is not None else None - ) + self.output_ids = self.output_ids[keep_indices] if self.output_ids is not None else None self.return_logprob = any(req.return_logprob for req in self.reqs) if self.return_logprob: self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices] @@ -1027,9 +979,7 @@ def merge_batch(self, other: ScheduleBatch): # needs to be called with pre-merged Batch.reqs. 
self.sampling_info.merge_batch(other.sampling_info) - self.req_pool_indices = np.concat( - [self.req_pool_indices, other.req_pool_indices] - ) + self.req_pool_indices = np.concat([self.req_pool_indices, other.req_pool_indices]) self.seq_lens = np.concat([self.seq_lens, other.seq_lens]) self.out_cache_loc = None self.seq_lens_sum += other.seq_lens_sum @@ -1085,9 +1035,7 @@ def get_model_worker_batch( seq_lens_cpu = self.seq_lens real_bs = len(seq_lens_cpu) req_pool_indices_cpu = self.req_pool_indices - token_indices_with_all_reqs = self.req_to_token_pool.req_to_token[ - self.req_pool_indices - ] + token_indices_with_all_reqs = self.req_to_token_pool.req_to_token[self.req_pool_indices] # padding seq # extend & decode: input_ids, positions, out_cache_loc, cache_loc @@ -1125,9 +1073,7 @@ def get_model_worker_batch( if self.forward_mode.is_extend(): # For prefill: create positions for each token in sequences # Calculate total tokens without padding first - total_tokens_before_padding = sum( - [extend_len for extend_len in self.extend_lens] - ) + total_tokens_before_padding = sum([extend_len for extend_len in self.extend_lens]) positions_cpu = np.concatenate( [ np.arange(prefix_len, seq_len, dtype=seq_lens_cpu.dtype) @@ -1178,9 +1124,7 @@ def get_model_worker_batch( valid_seq_lens = seq_lens_cpu[valid_mask] # Calculate aligned lengths for all valid sequences at once - aligned_lengths = ( - (valid_seq_lens + page_size - 1) // page_size - ) * page_size + aligned_lengths = ((valid_seq_lens + page_size - 1) // page_size) * page_size total_aligned_length = np.sum(aligned_lengths) # Pre-allocate the result array @@ -1192,9 +1136,9 @@ def get_model_worker_batch( zip(valid_indices, valid_seq_lens, aligned_lengths) ): # Copy the actual data - cache_loc_flat[offset : offset + seq_len] = ( - token_indices_with_all_reqs[seq_idx, :seq_len] - ) + cache_loc_flat[offset : offset + seq_len] = token_indices_with_all_reqs[ + seq_idx, :seq_len + ] # Padding is already zero from initialization offset += aligned_len @@ -1228,9 +1172,7 @@ def get_model_worker_batch( [extend_start_loc[-1] + extend_seq_lens[-1]] * bs_padding_size, dtype=extend_start_loc.dtype, ) - extend_start_loc = np.concat( - [extend_start_loc, invalid_extend_start_loc], axis=0 - ) + extend_start_loc = np.concat([extend_start_loc, invalid_extend_start_loc], axis=0) invalid_extend_prefix_lens = np.array( [0] * bs_padding_size, dtype=extend_prefix_lens.dtype ) @@ -1240,16 +1182,12 @@ def get_model_worker_batch( invalid_extend_seq_lens = np.array( [0] * bs_padding_size, dtype=extend_seq_lens.dtype ) - extend_seq_lens = np.concat( - [extend_seq_lens, invalid_extend_seq_lens], axis=0 - ) + extend_seq_lens = np.concat([extend_seq_lens, invalid_extend_seq_lens], axis=0) else: invalid_extend_start_loc = np.array( [len(seq_lens_cpu)] * bs_padding_size, dtype=extend_start_loc.dtype ) - extend_start_loc = np.concat( - [extend_start_loc, invalid_extend_start_loc], axis=0 - ) + extend_start_loc = np.concat([extend_start_loc, invalid_extend_start_loc], axis=0) if precision_tracer.get_trace_active(): self._generate_trace_info(real_bs, bid) @@ -1272,9 +1210,7 @@ def get_model_worker_batch( extend_prefix_lens=( extend_prefix_lens if self.forward_mode == ForwardMode.EXTEND else None ), - extend_seq_lens=( - extend_seq_lens if self.forward_mode == ForwardMode.EXTEND else None - ), + extend_seq_lens=(extend_seq_lens if self.forward_mode == ForwardMode.EXTEND else None), extend_logprob_start_lens=extend_logprob_start_lens, 
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids, real_bs=real_bs, @@ -1296,9 +1232,7 @@ def _generate_trace_info(self, real_bs: int, bid: int) -> list[str]: precision_tracer.add_request_to_batch_requests_mapping( bid, - PrecisionTracerRequestMetadata( - req.rid, input_ids_to_trace, self.forward_mode - ), + PrecisionTracerRequestMetadata(req.rid, input_ids_to_trace, self.forward_mode), ) if self.forward_mode == ForwardMode.EXTEND: precision_tracer.add_request_counter() diff --git a/python/sgl_jax/srt/managers/schedule_policy.py b/python/sgl_jax/srt/managers/schedule_policy.py index dcc174eec..4173bac73 100644 --- a/python/sgl_jax/srt/managers/schedule_policy.py +++ b/python/sgl_jax/srt/managers/schedule_policy.py @@ -87,13 +87,9 @@ def calc_priority(self, waiting_queue: list[Req]) -> bool: prefix_computed = False if isinstance(policy, CacheAwarePolicy): prefix_computed = True - temporary_deprioritized = self._compute_prefix_matches( - waiting_queue, policy - ) + temporary_deprioritized = self._compute_prefix_matches(waiting_queue, policy) if policy == CacheAwarePolicy.LPM: - SchedulePolicy._sort_by_longest_prefix( - waiting_queue, temporary_deprioritized - ) + SchedulePolicy._sort_by_longest_prefix(waiting_queue, temporary_deprioritized) elif policy == CacheAwarePolicy.DFS_WEIGHT: SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache) else: @@ -116,9 +112,7 @@ def _determine_active_policy(self, waiting_queue: list[Req]) -> Policy: return CacheAgnosticPolicy.FCFS return self.policy - def _validate_and_adjust_policy( - self, policy: str, tree_cache: BasePrefixCache - ) -> Policy: + def _validate_and_adjust_policy(self, policy: str, tree_cache: BasePrefixCache) -> Policy: """ Validates the policy and adjusts it if necessary based on tree cache settings. """ @@ -160,10 +154,8 @@ def _compute_prefix_matches( # threshold means we cannot use in-batch prefix caching for short prefixes. # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). 
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: - in_batch_matching_prefixes, _, _, _ = ( - self.waiting_queue_radix_tree.match_prefix( - rid=r.rid, key=prefix_ids - ) + in_batch_matching_prefixes, _, _, _ = self.waiting_queue_radix_tree.match_prefix( + rid=r.rid, key=prefix_ids ) if ( len(in_batch_matching_prefixes) @@ -184,16 +176,12 @@ def _sort_by_longest_prefix( """Sorts the waiting queue based on the longest prefix match.""" waiting_queue.sort( key=lambda r: ( - -len(r.prefix_indices) - if r.rid not in temporary_deprioritized - else float("inf") + -len(r.prefix_indices) if r.rid not in temporary_deprioritized else float("inf") ) ) @staticmethod - def _sort_by_dfs_weight( - waiting_queue: list[Req], tree_cache: BasePrefixCache - ) -> None: + def _sort_by_dfs_weight(waiting_queue: list[Req], tree_cache: BasePrefixCache) -> None: """Sorts the waiting queue based on a depth-first search weighting.""" last_node_to_reqs = defaultdict(list) for req in waiting_queue: @@ -238,9 +226,7 @@ def _get_dfs_priority( childs = [child for child in cur_node.children.values()] childs.sort(key=lambda x: -node_to_priority[x]) for child in childs: - SchedulePolicy._get_dfs_priority( - child, node_to_priority, last_node_to_reqs, q - ) + SchedulePolicy._get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) q.extend(last_node_to_reqs[cur_node]) @@ -296,8 +282,7 @@ def __init__( @property def rem_total_tokens(self): available_and_evictable = ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() + self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size() ) return available_and_evictable - self.rem_total_token_offset @@ -305,8 +290,7 @@ def rem_total_tokens(self): @property def cur_rem_tokens(self): available_and_evictable = ( - self.token_to_kv_pool_allocator.available_size() - + self.tree_cache.evictable_size() + self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size() ) return available_and_evictable - self.cur_rem_token_offset @@ -344,9 +328,7 @@ def add_chunked_req(self, req: Req): # Return if chunked prefill not finished return req if truncated else None - def _update_prefill_budget( - self, prefix_len: int, extend_input_len: int, max_new_tokens: int - ): + def _update_prefill_budget(self, prefix_len: int, extend_input_len: int, max_new_tokens: int): extend_input_len = self.ceil_paged_tokens(extend_input_len) self.rem_total_token_offset += extend_input_len + max_new_tokens @@ -373,12 +355,8 @@ def add_one_req_ignore_eos(self, req: Req): return AddReqResult.NO_TOKEN def add_req_state(r, insert_sort=False): - new_token_ratio = ( - 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio - ) - tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len( - r.output_ids - ) + new_token_ratio = 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio + tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(r.output_ids) tokens_occupied = len(r.origin_input_ids) + len(r.output_ids) if tokens_left <= 0: @@ -405,9 +383,7 @@ def add_req_state(r, insert_sort=False): else: add_req_state(req, insert_sort=True) - cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens( - req.extend_input_len - ) + cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(req.extend_input_len) tokens_freed = 0 for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): # tokens_left gives a reservative calculation as the last token is not stored diff --git 
a/python/sgl_jax/srt/managers/scheduler.py b/python/sgl_jax/srt/managers/scheduler.py index 5255987fc..806db7126 100644 --- a/python/sgl_jax/srt/managers/scheduler.py +++ b/python/sgl_jax/srt/managers/scheduler.py @@ -156,13 +156,9 @@ def __init__( context, zmq.PUSH, port_args.detokenizer_ipc_name, False ) - self.recv_from_rpc = get_zmq_socket( - context, zmq.DEALER, port_args.rpc_ipc_name, False - ) + self.recv_from_rpc = get_zmq_socket(context, zmq.DEALER, port_args.rpc_ipc_name, False) if self.nnodes > 1: - self.publisher = get_zmq_socket( - context, zmq.PUB, self.pub_sub_addr, bind=True - ) + self.publisher = get_zmq_socket(context, zmq.PUB, self.pub_sub_addr, bind=True) self.publisher_sync = get_zmq_socket( context, zmq.REP, self.pub_sub_sync_addr, bind=True ) @@ -173,9 +169,7 @@ def __init__( self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None) self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) if self.nnodes > 1: - self.subscriber = get_zmq_socket( - context, zmq.SUB, self.pub_sub_addr, bind=False - ) + self.subscriber = get_zmq_socket(context, zmq.SUB, self.pub_sub_addr, bind=False) self.subscriber.setsockopt(zmq.SUBSCRIBE, b"") self.subscriber.setsockopt(zmq.RCVTIMEO, 5000) self.subscriber_sync = get_zmq_socket( @@ -194,9 +188,7 @@ def __init__( # init distribution if self.nnodes > 1: - jax.distributed.initialize( - server_args.dist_init_addr, self.nnodes, self.node_rank - ) + jax.distributed.initialize(server_args.dist_init_addr, self.nnodes, self.node_rank) self.mesh = create_device_mesh( ici_parallelism=[-1, self.tp_size, 1], dcn_parallelism=[1, 1, 1] ) @@ -262,17 +254,13 @@ def __init__( self.schedule_policy, self.tree_cache, ) - assert ( - server_args.schedule_conservativeness >= 0 - ), "Invalid schedule_conservativeness" + assert server_args.schedule_conservativeness >= 0, "Invalid schedule_conservativeness" self.init_new_token_ratio = min( - global_config.default_init_new_token_ratio - * server_args.schedule_conservativeness, + global_config.default_init_new_token_ratio * server_args.schedule_conservativeness, 1.0, ) self.min_new_token_ratio = min( - self.init_new_token_ratio - * global_config.default_min_new_token_ratio_factor, + self.init_new_token_ratio * global_config.default_min_new_token_ratio_factor, 1.0, ) self.new_token_ratio_decay = ( @@ -327,9 +315,7 @@ def sync_pub(self): else: self.publisher_sync.send_string("NACK") except zmq.Again: - logger.error( - "[Publisher %s] Fails to synchronize due to timeout", self.node_rank - ) + logger.error("[Publisher %s] Fails to synchronize due to timeout", self.node_rank) return False except Exception as e: logger.error("[Publisher %s] Encounters error: %s", self.node_rank, e) @@ -353,9 +339,7 @@ def sync_sub(self): ) return False except Exception as e: - logger.error( - "[Subscriber %s] Fails to synchronize with error: %s", self.node_rank, e - ) + logger.error("[Subscriber %s] Fails to synchronize with error: %s", self.node_rank, e) return False def sync_pub_sub(self): @@ -379,14 +363,9 @@ def init_tokenizer(self): def init_memory_pool_and_cache(self): server_args = self.server_args - self.req_to_token_pool, self.token_to_kv_pool_allocator = ( - self.tp_worker.get_memory_pool() - ) + self.req_to_token_pool, self.token_to_kv_pool_allocator = self.tp_worker.get_memory_pool() - if ( - server_args.chunked_prefill_size is not None - and server_args.disable_radix_cache - ): + if server_args.chunked_prefill_size is not None and server_args.disable_radix_cache: self.tree_cache = ChunkCache( 
req_to_token_pool=self.req_to_token_pool, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, @@ -509,9 +488,7 @@ def broadcast_pyobj(self, recv_reqs): else: recv_reqs = self.run_subscriber() if recv_reqs is None: - raise ReceiveDataError( - f"[Subscriber {self.node_rank}] Fails to receive data" - ) + raise ReceiveDataError(f"[Subscriber {self.node_rank}] Fails to receive data") return recv_reqs def recv_requests(self) -> list[Req]: @@ -604,9 +581,7 @@ def get_internal_state(self, recv_req: GetInternalStateReq): ret = dict(global_server_args_dict) ret["last_gen_throughput"] = self.last_gen_throughput ret["memory_usage"] = { - "kvcache": round( - self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2 - ), + "kvcache": round(self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2), "token_capacity": int(self.max_total_num_tokens), } @@ -638,9 +613,7 @@ def set_internal_state(self, recv_req: SetInternalStateReq): precision_tracer._request_counter = 0 precision_tracer._completed_requests_count = 0 precision_tracer._request_traces = {} - logger.info( - "[SCHEDULER] Reset request_counter, completed_count and traces" - ) + logger.info("[SCHEDULER] Reset request_counter, completed_count and traces") if "max_requests" in tracer_config: precision_tracer._max_requests = tracer_config["max_requests"] @@ -656,9 +629,7 @@ def set_internal_state(self, recv_req: SetInternalStateReq): precision_tracer._trace_output_file, ) - logger.info( - "[SCHEDULER] Precision tracer state updated: %s", tracer_config - ) + logger.info("[SCHEDULER] Precision tracer state updated: %s", tracer_config) except Exception as e: success = False @@ -723,9 +694,7 @@ def get_next_batch_to_run(self) -> ScheduleBatch | None: # Filter batch last_bs = self.last_batch.batch_size() - self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) - ) + self.last_batch.filter_batch(chunked_req_to_exclude=list(chunked_req_to_exclude)) if self.last_batch.batch_size() < last_bs: self.running_batch.batch_is_full = False @@ -807,9 +776,7 @@ def get_new_batch_prefill(self) -> ScheduleBatch | None: if len(can_run_list) == 0: return None - self.waiting_queue = [ - x for x in self.waiting_queue if x not in set(can_run_list) - ] + self.waiting_queue = [x for x in self.waiting_queue if x not in set(can_run_list)] if adder.new_chunked_req is not None: assert self.chunked_req is None @@ -926,16 +893,12 @@ def run_batch(self, batch: ScheduleBatch) -> GenerationBatchResult: model_worker_batch, self.tp_worker.get_model_runner() ) logits_output, next_token_ids, cache_miss_count = ( - self.tp_worker.forward_batch_generation( - model_worker_batch, sampling_metadata=None - ) + self.tp_worker.forward_batch_generation(model_worker_batch, sampling_metadata=None) ) next_token_ids = next_token_ids[: model_worker_batch.real_bs] else: logits_output, next_token_ids_device, cache_miss_count = ( - self.tp_worker.forward_batch_generation( - model_worker_batch, sampling_metadata=None - ) + self.tp_worker.forward_batch_generation(model_worker_batch, sampling_metadata=None) ) next_token_ids = np.array(jax.device_get(next_token_ids_device))[ : model_worker_batch.real_bs @@ -952,9 +915,7 @@ def run_batch(self, batch: ScheduleBatch) -> GenerationBatchResult: else: extend_input_len_per_req = None if batch.return_logprob: - extend_logprob_start_len_per_req = [ - req.extend_logprob_start_len for req in batch.reqs - ] + extend_logprob_start_len_per_req = [req.extend_logprob_start_len for req in batch.reqs] else: extend_logprob_start_len_per_req 
= None @@ -1052,9 +1013,7 @@ def abort_request(self, recv_req: AbortReq): reqs = self.running_batch.reqs + self.cur_batch.reqs for req in reqs: - if not req.finished() and ( - recv_req.abort_all or req.rid.startswith(recv_req.rid) - ): + if not req.finished() and (recv_req.abort_all or req.rid.startswith(recv_req.rid)): # Abort method 3: set `to_abort=True` # The request will still run one decode forward pass. # Then we reuse all existing code to clean up the KV cache allocation. diff --git a/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py b/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py index 975cfb759..cc4b004a2 100644 --- a/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sgl_jax/srt/managers/scheduler_output_processor_mixin.py @@ -147,15 +147,11 @@ def process_batch_result_prefill( batch.cache_miss_count = cache_miss_count if batch.cache_miss_count > 0: - logger.info( - "Prefill batch. #bid: %s, #cache_miss: %s", result.bid, cache_miss_count - ) + logger.info("Prefill batch. #bid: %s, #cache_miss: %s", result.bid, cache_miss_count) self.set_next_batch_sampling_info_done(batch) - self.stream_output( - batch.reqs, batch.return_logprob, skip_stream_req, cache_miss_count - ) + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req, cache_miss_count) def process_batch_result_decode( self: Scheduler, @@ -178,9 +174,9 @@ def process_batch_result_decode( else: # spec decoding handles output logprobs inside verify process. if batch.return_logprob: - next_token_logprobs = jax.device_get( - logits_output.next_token_logprobs - ).astype(float) + next_token_logprobs = jax.device_get(logits_output.next_token_logprobs).astype( + float + ) # batch.output_ids = np.array(next_token_ids, dtype=np.int32) @@ -197,12 +193,8 @@ def process_batch_result_decode( if self.page_size == 1: self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1]) else: - if ( - len(req.origin_input_ids) + len(req.output_ids) - 1 - ) % self.page_size == 0: - self.token_to_kv_pool_allocator.free( - batch.out_cache_loc[i : i + 1] - ) + if (len(req.origin_input_ids) + len(req.output_ids) - 1) % self.page_size == 0: + self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1]) continue req.output_ids.append(next_token_id) @@ -232,12 +224,8 @@ def process_batch_result_decode( req.output_token_logprobs_val.append(next_token_logprobs[i]) req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: - req.output_top_logprobs_val.append( - logits_output.next_token_top_logprobs_val[i] - ) - req.output_top_logprobs_idx.append( - logits_output.next_token_top_logprobs_idx[i] - ) + req.output_top_logprobs_val.append(logits_output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(logits_output.next_token_top_logprobs_idx[i]) if req.token_ids_logprob is not None: req.output_token_ids_logprobs_val.append( logits_output.next_token_token_ids_logprobs_val[i] @@ -247,9 +235,7 @@ def process_batch_result_decode( ) self.set_next_batch_sampling_info_done(batch) - self.stream_output( - batch.reqs, batch.return_logprob, cache_miss_count=cache_miss_count - ) + self.stream_output(batch.reqs, batch.return_logprob, cache_miss_count=cache_miss_count) self.token_to_kv_pool_allocator.free_group_end() self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) @@ -304,9 +290,7 @@ def add_input_logprob_return_values( # Important for the performance. 
assert isinstance(output.input_token_logprobs, tuple) input_token_logprobs: tuple[int] = output.input_token_logprobs - input_token_logprobs = input_token_logprobs[ - logprob_pt : logprob_pt + num_input_logprobs - ] + input_token_logprobs = input_token_logprobs[logprob_pt : logprob_pt + num_input_logprobs] req.input_token_logprobs.extend(input_token_logprobs) if req.top_logprobs_num > 0: @@ -314,12 +298,8 @@ def add_input_logprob_return_values( req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i]) if req.token_ids_logprob is not None: - req.temp_input_token_ids_logprobs_val.append( - output.input_token_ids_logprobs_val[i] - ) - req.temp_input_token_ids_logprobs_idx.append( - output.input_token_ids_logprobs_idx[i] - ) + req.temp_input_token_ids_logprobs_val.append(output.input_token_ids_logprobs_val[i]) + req.temp_input_token_ids_logprobs_idx.append(output.input_token_ids_logprobs_idx[i]) if last_prefill_chunk: input_token_logprobs = req.input_token_logprobs @@ -341,8 +321,7 @@ def add_input_logprob_return_values( # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. input_token_logprobs_idx = [ - x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_logprobs_idx + x if x < self.model_config.vocab_size - 1 else 0 for x in input_token_logprobs_idx ] req.input_token_logprobs_idx = input_token_logprobs_idx @@ -417,12 +396,8 @@ def add_logprob_return_values( req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) if req.token_ids_logprob is not None: - req.output_token_ids_logprobs_val.append( - output.next_token_token_ids_logprobs_val[i] - ) - req.output_token_ids_logprobs_idx.append( - output.next_token_token_ids_logprobs_idx[i] - ) + req.output_token_ids_logprobs_val.append(output.next_token_token_ids_logprobs_val[i]) + req.output_token_ids_logprobs_idx.append(output.next_token_token_ids_logprobs_idx[i]) return num_input_logprobs @@ -474,15 +449,13 @@ def stream_output_generation( output_token_ids_logprobs_val = [] output_token_ids_logprobs_idx = [] else: - input_token_logprobs_val = input_token_logprobs_idx = ( - output_token_logprobs_val - ) = output_token_logprobs_idx = input_top_logprobs_val = ( - input_top_logprobs_idx - ) = output_top_logprobs_val = output_top_logprobs_idx = ( - input_token_ids_logprobs_val - ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = ( - output_token_ids_logprobs_idx - ) = None + input_token_logprobs_val = input_token_logprobs_idx = output_token_logprobs_val = ( + output_token_logprobs_idx + ) = input_top_logprobs_val = input_top_logprobs_idx = output_top_logprobs_val = ( + output_top_logprobs_idx + ) = input_token_ids_logprobs_val = input_token_ids_logprobs_idx = ( + output_token_ids_logprobs_val + ) = output_token_ids_logprobs_idx = None for req in reqs: if req is skip_req: @@ -497,24 +470,18 @@ def stream_output_generation( should_output = True else: if req.stream: - stream_interval = ( - req.sampling_params.stream_interval or self.stream_interval - ) + stream_interval = req.sampling_params.stream_interval or self.stream_interval should_output = ( len(req.output_ids) % stream_interval == 1 if stream_interval > 1 else len(req.output_ids) % stream_interval == 0 ) else: - should_output = ( - len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0 - ) + should_output = len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0 if should_output: send_token_offset = req.send_token_offset - send_output_token_logprobs_offset = ( - 
req.send_output_token_logprobs_offset - ) + send_output_token_logprobs_offset = req.send_output_token_logprobs_offset if isinstance(req.rid, list): # if rid is a list, extend the list to rids rids.extend(req.rid) @@ -548,12 +515,8 @@ def stream_output_generation( input_token_logprobs_idx.append(req.input_token_logprobs_idx) input_top_logprobs_val.append(req.input_top_logprobs_val) input_top_logprobs_idx.append(req.input_top_logprobs_idx) - input_token_ids_logprobs_val.append( - req.input_token_ids_logprobs_val - ) - input_token_ids_logprobs_idx.append( - req.input_token_ids_logprobs_idx - ) + input_token_ids_logprobs_val.append(req.input_token_ids_logprobs_val) + input_token_ids_logprobs_idx.append(req.input_token_ids_logprobs_idx) req.input_logprob_sent = True else: input_token_logprobs_val.append([]) @@ -565,38 +528,24 @@ def stream_output_generation( if req.return_logprob: output_token_logprobs_val.append( - req.output_token_logprobs_val[ - send_output_token_logprobs_offset: - ] + req.output_token_logprobs_val[send_output_token_logprobs_offset:] ) output_token_logprobs_idx.append( - req.output_token_logprobs_idx[ - send_output_token_logprobs_offset: - ] + req.output_token_logprobs_idx[send_output_token_logprobs_offset:] ) output_top_logprobs_val.append( - req.output_top_logprobs_val[ - send_output_token_logprobs_offset: - ] + req.output_top_logprobs_val[send_output_token_logprobs_offset:] ) output_top_logprobs_idx.append( - req.output_top_logprobs_idx[ - send_output_token_logprobs_offset: - ] + req.output_top_logprobs_idx[send_output_token_logprobs_offset:] ) output_token_ids_logprobs_val.append( - req.output_token_ids_logprobs_val[ - send_output_token_logprobs_offset: - ] + req.output_token_ids_logprobs_val[send_output_token_logprobs_offset:] ) output_token_ids_logprobs_idx.append( - req.output_token_ids_logprobs_idx[ - send_output_token_logprobs_offset: - ] - ) - req.send_output_token_logprobs_offset = len( - req.output_token_logprobs_val + req.output_token_ids_logprobs_idx[send_output_token_logprobs_offset:] ) + req.send_output_token_logprobs_offset = len(req.output_token_logprobs_val) else: output_token_logprobs_val.append([]) output_token_logprobs_idx.append([]) diff --git a/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py b/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py index 5d7714983..781852381 100644 --- a/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py +++ b/python/sgl_jax/srt/managers/scheduler_profiler_mixing.py @@ -55,9 +55,7 @@ def start_profile( if num_steps: self.profile_steps = num_steps if start_step: - self.profiler_target_forward_ct = ( - self.profiler_start_forward_ct + num_steps - ) + self.profiler_target_forward_ct = self.profiler_start_forward_ct + num_steps else: self.profiler_target_forward_ct = self.forward_ct + num_steps else: @@ -111,18 +109,10 @@ def stop_profile(self) -> ProfileReqOutput | None: return ProfileReqOutput(success=True, message="Succeeded.") def _profile_batch_predicate(self, batch): - if ( - self.profiler_target_forward_ct - and self.profiler_target_forward_ct <= self.forward_ct - ): + if self.profiler_target_forward_ct and self.profiler_target_forward_ct <= self.forward_ct: self.stop_profile() - if ( - self.profiler_start_forward_ct - and self.profiler_start_forward_ct == self.forward_ct - ): - self.start_profile( - self.profiler_output_dir, None, self.profile_steps, self.profile_id - ) + if self.profiler_start_forward_ct and self.profiler_start_forward_ct == self.forward_ct: + 
self.start_profile(self.profiler_output_dir, None, self.profile_steps, self.profile_id) def profile(self, recv_req: ProfileReq): if recv_req.type == ProfileReqType.START_PROFILE: diff --git a/python/sgl_jax/srt/managers/tokenizer_manager.py b/python/sgl_jax/srt/managers/tokenizer_manager.py index 124087c6f..939eff751 100644 --- a/python/sgl_jax/srt/managers/tokenizer_manager.py +++ b/python/sgl_jax/srt/managers/tokenizer_manager.py @@ -183,12 +183,8 @@ def __init__( self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) - self.flush_cache_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) - self.profile_communicator = _Communicator( - self.send_to_scheduler, server_args.dp_size - ) + self.flush_cache_communicator = _Communicator(self.send_to_scheduler, server_args.dp_size) + self.profile_communicator = _Communicator(self.send_to_scheduler, server_args.dp_size) self.get_internal_state_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -268,9 +264,7 @@ async def generate_request( async for response in self._wait_one_response(obj, state, request): yield response else: - async for response in self._handle_batch_request( - obj, request, created_time - ): + async for response in self._handle_batch_request(obj, request, created_time): yield response async def _tokenize_one_request( @@ -307,10 +301,7 @@ def _validate_one_request( # Check total tokens (input + max_new_tokens) max_new_tokens = obj.sampling_params.get("max_new_tokens") - if ( - max_new_tokens is not None - and (max_new_tokens + input_token_num) >= self.context_len - ): + if max_new_tokens is not None and (max_new_tokens + input_token_num) >= self.context_len: total_tokens = max_new_tokens + input_token_num error_msg = ( f"Requested token count exceeds the model's maximum context length " @@ -321,9 +312,7 @@ def _validate_one_request( ) raise ValueError(error_msg) - def _validate_input_ids_in_vocab( - self, input_ids: list[int], vocab_size: int - ) -> None: + def _validate_input_ids_in_vocab(self, input_ids: list[int], vocab_size: int) -> None: if any(id >= vocab_size for id in input_ids): raise ValueError( f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})." 
@@ -382,9 +371,7 @@ async def _batch_tokenize_and_process( for i, req in enumerate(requests): # self._validate_token_len(obj[i], input_ids_list[i]) tokenized_objs.append( - self._create_tokenized_object( - req, req.text, input_ids_list[i], None, None - ) + self._create_tokenized_object(req, req.text, input_ids_list[i], None, None) ) logger.debug("Completed batch processing for %s requests", batch_size) return tokenized_objs @@ -554,9 +541,7 @@ async def _handle_batch_request( rid_to_index = {rid: i for i, rid in enumerate(rids)} task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators} while task_map: - done, _ = await asyncio.wait( - task_map.keys(), return_when=asyncio.FIRST_COMPLETED - ) + done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED) for task in done: gen = task_map.pop(task) @@ -635,9 +620,7 @@ async def resume_memory_occupation( self.auto_create_handle_loop() await self.resume_memory_occupation_communicator(obj) - async def open_session( - self, obj: OpenSessionReqInput, request: fastapi.Request | None = None - ): + async def open_session(self, obj: OpenSessionReqInput, request: fastapi.Request | None = None): self.auto_create_handle_loop() if obj.session_id is None: @@ -659,9 +642,7 @@ async def close_session( async def get_internal_state(self) -> list[dict[Any, Any]]: req = GetInternalStateReq() - responses: list[GetInternalStateReqOutput] = ( - await self.get_internal_state_communicator(req) - ) + responses: list[GetInternalStateReqOutput] = await self.get_internal_state_communicator(req) # Many DP ranks return [res.internal_state for res in responses] @@ -672,13 +653,9 @@ async def get_load(self) -> dict: self.current_load = internal_state[0]["load"] return {"load": self.current_load} - async def set_internal_state( - self, obj: SetInternalStateReq - ) -> SetInternalStateReqOutput: + async def set_internal_state(self, obj: SetInternalStateReq) -> SetInternalStateReqOutput: self.auto_create_handle_loop() - responses: list[SetInternalStateReqOutput] = ( - await self.set_internal_state_communicator(obj) - ) + responses: list[SetInternalStateReqOutput] = await self.set_internal_state_communicator(obj) return ( responses[0] if responses @@ -738,9 +715,7 @@ def get_log_request_metadata(self): elif self.log_requests_level == 3: max_length = 1 << 30 else: - raise ValueError( - f"Invalid --log-requests-level: {self.log_requests_level=}" - ) + raise ValueError(f"Invalid --log-requests-level: {self.log_requests_level=}") return max_length, skip_names, out_skip_names def configure_logging(self, obj: ConfigureLoggingReq): @@ -777,9 +752,7 @@ def auto_create_handle_loop(self): self.no_create_loop = True loop = asyncio.get_event_loop() - self.asyncio_tasks.add( - loop.create_task(print_exception_wrapper(self.handle_loop)) - ) + self.asyncio_tasks.add(loop.create_task(print_exception_wrapper(self.handle_loop))) self.event_loop = loop @@ -789,18 +762,14 @@ def auto_create_handle_loop(self): signal_handler = SignalHandler(self) loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler) # Update the signal handler for the process. It overrides the sigquit handler in the launch phase. - loop.add_signal_handler( - signal.SIGQUIT, signal_handler.running_phase_sigquit_handler - ) + loop.add_signal_handler(signal.SIGQUIT, signal_handler.running_phase_sigquit_handler) else: logger.warning( "Signal handler is not added because the tokenizer manager is " "not in the main thread. 
This disables graceful shutdown of the " "tokenizer manager when SIGTERM is received." ) - self.asyncio_tasks.add( - loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) - ) + self.asyncio_tasks.add(loop.create_task(print_exception_wrapper(self.sigterm_watchdog))) def dump_requests_before_crash(self): if self.crash_dump_performed: @@ -824,9 +793,7 @@ def dump_requests_before_crash(self): unfinished_requests = [] for rid, state in self.rid_to_state.items(): if not state.finished: - unfinished_requests.append( - (state.obj, {}, state.created_time, time.time()) - ) + unfinished_requests.append((state.obj, {}, state.created_time, time.time())) if unfinished_requests: data_to_dump.extend(unfinished_requests) @@ -926,8 +893,7 @@ def _handle_batch_output( state, state.obj.top_logprobs_num, state.obj.token_ids_logprob, - state.obj.return_text_in_logprobs - and not self.server_args.skip_tokenizer_init, + state.obj.return_text_in_logprobs and not self.server_args.skip_tokenizer_init, recv_obj, i, ) @@ -1002,18 +968,10 @@ def convert_logprob_style( if recv_obj.input_token_logprobs_val is None: return if len(recv_obj.input_token_logprobs_val) > 0: - state.input_token_logprobs_val.extend( - recv_obj.input_token_logprobs_val[recv_obj_index] - ) - state.input_token_logprobs_idx.extend( - recv_obj.input_token_logprobs_idx[recv_obj_index] - ) - state.output_token_logprobs_val.extend( - recv_obj.output_token_logprobs_val[recv_obj_index] - ) - state.output_token_logprobs_idx.extend( - recv_obj.output_token_logprobs_idx[recv_obj_index] - ) + state.input_token_logprobs_val.extend(recv_obj.input_token_logprobs_val[recv_obj_index]) + state.input_token_logprobs_idx.extend(recv_obj.input_token_logprobs_idx[recv_obj_index]) + state.output_token_logprobs_val.extend(recv_obj.output_token_logprobs_val[recv_obj_index]) + state.output_token_logprobs_idx.extend(recv_obj.output_token_logprobs_idx[recv_obj_index]) meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( state.input_token_logprobs_val, state.input_token_logprobs_idx, @@ -1027,18 +985,10 @@ def convert_logprob_style( if top_logprobs_num > 0: if len(recv_obj.input_top_logprobs_val) > 0: - state.input_top_logprobs_val.extend( - recv_obj.input_top_logprobs_val[recv_obj_index] - ) - state.input_top_logprobs_idx.extend( - recv_obj.input_top_logprobs_idx[recv_obj_index] - ) - state.output_top_logprobs_val.extend( - recv_obj.output_top_logprobs_val[recv_obj_index] - ) - state.output_top_logprobs_idx.extend( - recv_obj.output_top_logprobs_idx[recv_obj_index] - ) + state.input_top_logprobs_val.extend(recv_obj.input_top_logprobs_val[recv_obj_index]) + state.input_top_logprobs_idx.extend(recv_obj.input_top_logprobs_idx[recv_obj_index]) + state.output_top_logprobs_val.extend(recv_obj.output_top_logprobs_val[recv_obj_index]) + state.output_top_logprobs_idx.extend(recv_obj.output_top_logprobs_idx[recv_obj_index]) meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( state.input_top_logprobs_val, state.input_top_logprobs_idx, @@ -1069,12 +1019,10 @@ def convert_logprob_style( state.input_token_ids_logprobs_idx, return_text_in_logprobs, ) - meta_info["output_token_ids_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - state.output_token_ids_logprobs_val, - state.output_token_ids_logprobs_idx, - return_text_in_logprobs, - ) + meta_info["output_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens( + state.output_token_ids_logprobs_val, + state.output_token_ids_logprobs_idx, + return_text_in_logprobs, ) def 
detokenize_logprob_tokens( @@ -1113,9 +1061,7 @@ def detokenize_top_logprobs_tokens( return ret def dump_requests(self, state: ReqState, out_dict: dict): - self.dump_request_list.append( - (state.obj, out_dict, state.created_time, time.time()) - ) + self.dump_request_list.append((state.obj, out_dict, state.created_time, time.time())) if len(self.dump_request_list) >= self.dump_requests_threshold: filename = os.path.join( @@ -1142,9 +1088,7 @@ def background_task(): def record_request_for_crash_dump(self, state: ReqState, out_dict: dict): current_time = time.time() - self.crash_dump_request_list.append( - (state.obj, out_dict, state.created_time, current_time) - ) + self.crash_dump_request_list.append((state.obj, out_dict, state.created_time, current_time)) # Remove requests older than 5 minutes based on finish time while ( self.crash_dump_request_list @@ -1236,9 +1180,7 @@ async def score_request( sampling_params={"max_new_tokens": 1}, ) else: - raise ValueError( - "Invalid combination of query/items types for score_request." - ) + raise ValueError("Invalid combination of query/items types for score_request.") results = await self.generate_request(batch_request, request).__anext__() scores = [] @@ -1246,25 +1188,19 @@ async def score_request( for result in results: # Get logprobs for each token logprobs = {} - for logprob, token_id, _ in result["meta_info"].get( - "output_token_ids_logprobs", [] - )[0]: + for logprob, token_id, _ in result["meta_info"].get("output_token_ids_logprobs", [])[0]: if token_id in label_token_ids: logprobs[token_id] = logprob # Get scores in order of label_token_ids - score_list = [ - logprobs.get(token_id, float("-inf")) for token_id in label_token_ids - ] + score_list = [logprobs.get(token_id, float("-inf")) for token_id in label_token_ids] # Apply softmax to logprobs if needed if apply_softmax: score_list = jax.nn.softmax(jnp.asarray(score_list), axis=0).tolist() else: # Convert logprobs to probabilities if not using softmax - score_list = [ - math.exp(x) if x != float("-inf") else 0.0 for x in score_list - ] + score_list = [math.exp(x) if x != float("-inf") else 0.0 for x in score_list] scores.append(score_list) @@ -1300,9 +1236,7 @@ def sigterm_handler(self, signum=None, frame=None): self.tokenizer_manager.gracefully_exit = True def running_phase_sigquit_handler(self, signum=None, frame=None): - logger.error( - "Received sigquit from a child process. It usually means the child failed." - ) + logger.error("Received sigquit from a child process. It usually means the child failed.") self.tokenizer_manager.dump_requests_before_crash() kill_process_tree(os.getpid()) diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py index 1be5ada87..0e36a95cb 100644 --- a/python/sgl_jax/srt/managers/tp_worker.py +++ b/python/sgl_jax/srt/managers/tp_worker.py @@ -63,9 +63,7 @@ def __init__( # Each process may have different random_seed. After broadcast, all processes will have the same random_seed. 
if server_args.random_seed is None: with jax.default_device(jax.local_devices()[0]): - seed_to_broadcast = ( - server_args.random_seed if jax.process_index() == 0 else 0 - ) + seed_to_broadcast = server_args.random_seed if jax.process_index() == 0 else 0 self.random_seed = broadcast_one_to_all(seed_to_broadcast).item() else: self.random_seed = server_args.random_seed @@ -127,9 +125,7 @@ def __init__( self.max_total_num_tokens - 1, ) self.max_req_input_len = self.max_req_len - 5 - assert ( - self.max_req_len > 0 and self.max_req_input_len > 0 - ), "Memory pool size is too small" + assert self.max_req_len > 0 and self.max_req_input_len > 0, "Memory pool size is too small" # Sync random seed across TP workers # Each process may have different random_seed. After broadcast, all processes will have the same random_seed. @@ -138,9 +134,7 @@ def __init__( # A reference make this class has the same member as TpModelWorkerClient self.worker = self - self.max_padded_batch_size, self.max_padded_num_tokens = ( - self.get_max_padded_size() - ) + self.max_padded_batch_size, self.max_padded_num_tokens = self.get_max_padded_size() # precompile self.precompile_token_paddings = server_args.precompile_token_paddings @@ -168,9 +162,7 @@ def __init__( # padding cache_loc_paddings # note: the length of following two cache_loc_paddings must keep the same to length of separate bs_paddings. self.precompile_cache_loc_paddings = [ - (item * self.max_req_len + self.page_size - 1) - // self.page_size - * self.page_size + (item * self.max_req_len + self.page_size - 1) // self.page_size * self.page_size for item in self.precompile_bs_paddings ] @@ -180,10 +172,7 @@ def normalize_token_paddings(self): if self.precompile_token_paddings is None: self.precompile_token_paddings = PRECOMPILE_DEFAULT_TOKEN_PADDINGS for item in self.precompile_token_paddings: - if ( - item >= self.max_padded_batch_size - and item <= self.max_padded_num_tokens - ): + if item >= self.max_padded_batch_size and item <= self.max_padded_num_tokens: normalized_token_paddings.append(item) normalized_token_paddings.sort() @@ -217,9 +206,7 @@ def precompile_extend(self, future_token_ids_map=None): bs, num_tokens = pair[0], pair[1] pbar.set_postfix(bs=bs, tokens=num_tokens) if bs > num_tokens: - logger.warning( - "bs=%s > num_tokens=%s, skip this pair", bs, num_tokens - ) + logger.warning("bs=%s > num_tokens=%s, skip this pair", bs, num_tokens) continue model_worker_batch = self.generate_model_worker_batch( bs, @@ -234,16 +221,12 @@ def precompile_extend(self, future_token_ids_map=None): model_worker_batch, self.model_runner ) if future_token_ids_map is not None: - model_worker_batch.forward_batch.input_ids = ( - resolve_future_token_ids( - model_worker_batch.forward_batch.input_ids, - future_token_ids_map, - ) + model_worker_batch.forward_batch.input_ids = resolve_future_token_ids( + model_worker_batch.forward_batch.input_ids, + future_token_ids_map, ) - self.forward_batch_generation( - model_worker_batch, None, False, sampling_metadata - ) + self.forward_batch_generation(model_worker_batch, None, False, sampling_metadata) end_time = time.perf_counter() logger.info("[EXTEND] Precompile finished in %.0f secs", end_time - start_time) @@ -254,16 +237,12 @@ def precompile_decode(self, future_token_ids_map=None): self.precompile_bs_paddings, ) - with tqdm( - self.precompile_bs_paddings, desc="[DECODE] PRECOMPILE", leave=False - ) as pbar: + with tqdm(self.precompile_bs_paddings, desc="[DECODE] PRECOMPILE", leave=False) as pbar: for bs in pbar: 
pbar.set_postfix(bs=bs) # use same page aligned with precompile cache_loc_paddings aligned_cache_loc_size = ( - (bs * self.max_req_len + self.page_size - 1) - // self.page_size - * self.page_size + (bs * self.max_req_len + self.page_size - 1) // self.page_size * self.page_size ) model_worker_batch = self.generate_model_worker_batch( bs, @@ -278,11 +257,9 @@ def precompile_decode(self, future_token_ids_map=None): model_worker_batch, self.model_runner ) if future_token_ids_map is not None: - model_worker_batch.forward_batch.input_ids = ( - resolve_future_token_ids( - model_worker_batch.forward_batch.input_ids, - future_token_ids_map, - ) + model_worker_batch.forward_batch.input_ids = resolve_future_token_ids( + model_worker_batch.forward_batch.input_ids, + future_token_ids_map, ) _, next_token_ids, _ = self.forward_batch_generation( model_worker_batch, None, False, sampling_metadata @@ -301,17 +278,13 @@ def precompile_penalties(self, future_token_ids_map=None): self.precompile_bs_paddings, ) - with tqdm( - self.precompile_bs_paddings, desc="[PENALTIES] PRECOMPILE", leave=False - ) as pbar: + with tqdm(self.precompile_bs_paddings, desc="[PENALTIES] PRECOMPILE", leave=False) as pbar: for bs in pbar: pbar.set_postfix(bs=bs) # Create model worker batch aligned_cache_loc_size = ( - (bs * self.max_req_len + self.page_size - 1) - // self.page_size - * self.page_size + (bs * self.max_req_len + self.page_size - 1) // self.page_size * self.page_size ) model_worker_batch = self.generate_model_worker_batch( bs, @@ -323,10 +296,8 @@ def precompile_penalties(self, future_token_ids_map=None): # Create sampling metadata with all penalty shapes for comprehensive compilation # This ensures JAX compiles all penalty application branches in lax.cond - sampling_metadata = ( - SamplingMetadata.from_model_worker_batch_for_precompile( - model_worker_batch, 0, self.mesh - ) + sampling_metadata = SamplingMetadata.from_model_worker_batch_for_precompile( + model_worker_batch, 0, self.mesh ) # Initialize forward batch @@ -335,11 +306,9 @@ def precompile_penalties(self, future_token_ids_map=None): ) if future_token_ids_map is not None: - model_worker_batch.forward_batch.input_ids = ( - resolve_future_token_ids( - model_worker_batch.forward_batch.input_ids, - future_token_ids_map, - ) + model_worker_batch.forward_batch.input_ids = resolve_future_token_ids( + model_worker_batch.forward_batch.input_ids, + future_token_ids_map, ) # Run forward with penalty application @@ -351,15 +320,11 @@ def precompile_penalties(self, future_token_ids_map=None): set_future_token_ids(future_token_ids_map, 0, next_token_ids) end_time = time.perf_counter() - logger.info( - "[PENALTIES] Precompile finished in %.0f secs", end_time - start_time - ) + logger.info("[PENALTIES] Precompile finished in %.0f secs", end_time - start_time) def set_forward_metadata(self, model_worker_batch: ModelWorkerBatch): self.model_runner.attn_backend.forward_metadata = ( - self.worker.model_runner.attn_backend.get_forward_metadata( - model_worker_batch - ) + self.worker.model_runner.attn_backend.get_forward_metadata(model_worker_batch) ) def get_max_padded_size(self): @@ -373,10 +338,7 @@ def get_max_padded_size(self): # Use chunked prefill size if enabled (> 0), otherwise use max prefill tokens # Take minimum with max_prefill_tokens as upper bound max_padded_num_tokens = self.max_prefill_tokens - if ( - self.chunked_prefill_size > 0 - and max_padded_num_tokens > self.chunked_prefill_size - ): + if self.chunked_prefill_size > 0 and max_padded_num_tokens > 
self.chunked_prefill_size: max_padded_num_tokens = self.chunked_prefill_size # Batch size is constrained by both max_running_requests and available tokens divide by page_size @@ -420,9 +382,7 @@ def generate_model_worker_batch( real_bs=bs, req_pool_indices=np.arange(bs, dtype=np.int32), seq_lens=np.array([1] * bs, dtype=np.int32), - out_cache_loc=np.concat( - [valid_out_cache_loc, invalid_out_cache_loc], axis=0 - ), + out_cache_loc=np.concat([valid_out_cache_loc, invalid_out_cache_loc], axis=0), return_logprob=False, sampling_info=SamplingBatchInfo.generate_for_precompile( bs, self.model_config.vocab_size, do_penalties=do_penalties @@ -431,9 +391,7 @@ def generate_model_worker_batch( positions=np.concat([valid_positions, invalid_positions], axis=0), extend_start_loc=np.arange(bs, dtype=np.int64), cache_loc=np.concat([valid_cache_loc, invalid_cache_loc], axis=0), - extend_prefix_lens=( - np.array([0] * bs) if mode == ForwardMode.EXTEND else None - ), + extend_prefix_lens=(np.array([0] * bs) if mode == ForwardMode.EXTEND else None), extend_seq_lens=np.array([1] * bs) if mode == ForwardMode.EXTEND else None, top_logprobs_nums=None, token_ids_logprobs=None, @@ -483,10 +441,8 @@ def forward_batch_generation( forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) if forward_metadata is None: - forward_metadata = ( - self.worker.model_runner.attn_backend.get_forward_metadata( - model_worker_batch - ) + forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata( + model_worker_batch ) if sampling_metadata is None: @@ -514,9 +470,7 @@ def forward_batch_generation( ) logits_output, cache_miss_count = self.model_runner.forward( forward_batch, - logits_metadata=LogitsMetadata.from_model_worker_batch( - model_worker_batch, self.mesh - ), + logits_metadata=LogitsMetadata.from_model_worker_batch(model_worker_batch, self.mesh), ) if launch_done is not None: @@ -590,9 +544,7 @@ def __init__( self.max_total_num_tokens - 1, ) self.max_req_input_len = self.max_req_len - 5 - assert ( - self.max_req_len > 0 and self.max_req_input_len > 0 - ), "Memory pool size is too small" + assert self.max_req_len > 0 and self.max_req_input_len > 0, "Memory pool size is too small" # A reference make this class has the same member as TpModelWorkerClient self.worker = self diff --git a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py index 9530d13af..c781b952a 100644 --- a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py +++ b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py @@ -38,9 +38,7 @@ def __init__( # Init future mappings self.future_token_ids_ct = 0 self.future_token_ids_limit = self.max_running_requests * 3 - self.future_token_ids_map = jnp.zeros( - (self.max_running_requests * 5,), dtype=jnp.int32 - ) + self.future_token_ids_map = jnp.zeros((self.max_running_requests * 5,), dtype=jnp.int32) self.mesh = mesh sharding = NamedSharding(mesh, PartitionSpec(None)) self.future_token_ids_map = jax.device_put(self.future_token_ids_map, sharding) @@ -112,13 +110,11 @@ def forward_thread_func_(self): ) # Run forward - logits_output, next_token_ids, cache_miss_count = ( - self.worker.forward_batch_generation( - model_worker_batch, - model_worker_batch.launch_done, - sampling_metadata=sampling_metadata, - forward_metadata=forward_metadata, - ) + logits_output, next_token_ids, cache_miss_count = self.worker.forward_batch_generation( + model_worker_batch, + model_worker_batch.launch_done, + 
sampling_metadata=sampling_metadata, + forward_metadata=forward_metadata, ) # Update the future token ids map @@ -126,9 +122,7 @@ def forward_thread_func_(self): self.future_token_ids_map, future_token_ids_ct, next_token_ids ) - self.output_queue.put( - (None, logits_output, next_token_ids, cache_miss_count) - ) + self.output_queue.put((None, logits_output, next_token_ids, cache_miss_count)) def resolve_last_batch_result(self, launch_done: threading.Event | None = None): """ @@ -197,9 +191,7 @@ def forward_batch_generation( -1, dtype=np.int32, ) - self.future_token_ids_ct = ( - self.future_token_ids_ct + bs - ) % self.future_token_ids_limit + self.future_token_ids_ct = (self.future_token_ids_ct + bs) % self.future_token_ids_limit return None, future_next_token_ids, 0 def run_precompile(self): diff --git a/python/sgl_jax/srt/managers/utils.py b/python/sgl_jax/srt/managers/utils.py index 19c5d8e4f..2c2aa9ee0 100644 --- a/python/sgl_jax/srt/managers/utils.py +++ b/python/sgl_jax/srt/managers/utils.py @@ -53,6 +53,4 @@ def resolve_future_token_ids(input_ids, future_token_ids_map): @jax.jit def set_future_token_ids(future_token_ids_map, future_token_ids_ct, next_token_ids): start_indices = (future_token_ids_ct + 1,) - return jax.lax.dynamic_update_slice( - future_token_ids_map, next_token_ids, start_indices - ) + return jax.lax.dynamic_update_slice(future_token_ids_map, next_token_ids, start_indices) diff --git a/python/sgl_jax/srt/mem_cache/allocator.py b/python/sgl_jax/srt/mem_cache/allocator.py index ec27689c2..3d1dc3592 100644 --- a/python/sgl_jax/srt/mem_cache/allocator.py +++ b/python/sgl_jax/srt/mem_cache/allocator.py @@ -144,9 +144,7 @@ def __init__( def alloc(self, need_size: int) -> np.ndarray | None: # page-aligned allocation, returning contiguous indices of pages - assert ( - need_size % self.page_size == 0 - ), "The allocation size should be page-aligned" + assert need_size % self.page_size == 0, "The allocation size should be page-aligned" num_pages = need_size // self.page_size if num_pages > len(self.free_pages): @@ -218,12 +216,8 @@ def alloc_extend( part1_size = min(extend_len, current_page_capacity - pre_len) if part1_size > 0: - part1_indices = np.arange( - last_loc + 1, last_loc + 1 + part1_size, dtype=np.int32 - ) - out_indices[current_output_idx : current_output_idx + part1_size] = ( - part1_indices - ) + part1_indices = np.arange(last_loc + 1, last_loc + 1 + part1_size, dtype=np.int32) + out_indices[current_output_idx : current_output_idx + part1_size] = part1_indices current_output_idx += part1_size remaining_tokens = extend_len - part1_size @@ -240,9 +234,9 @@ def alloc_extend( part2_indices = np.arange( page_start, page_start + self.page_size, dtype=np.int32 ) - out_indices[ - current_output_idx : current_output_idx + self.page_size - ] = part2_indices + out_indices[current_output_idx : current_output_idx + self.page_size] = ( + part2_indices + ) current_output_idx += self.page_size page_idx += 1 @@ -250,12 +244,10 @@ def alloc_extend( remaining_tokens -= part2_size if remaining_tokens > 0: page_start = allocated_pages[page_idx] * self.page_size - part3_indices = np.arange( - page_start, page_start + remaining_tokens, dtype=np.int32 + part3_indices = np.arange(page_start, page_start + remaining_tokens, dtype=np.int32) + out_indices[current_output_idx : current_output_idx + remaining_tokens] = ( + part3_indices ) - out_indices[ - current_output_idx : current_output_idx + remaining_tokens - ] = part3_indices current_output_idx += remaining_tokens page_idx += 1 # 
page_idx is the number of new pages allocated diff --git a/python/sgl_jax/srt/mem_cache/memory_pool.py b/python/sgl_jax/srt/mem_cache/memory_pool.py index 60d4acbc3..f3dde6344 100644 --- a/python/sgl_jax/srt/mem_cache/memory_pool.py +++ b/python/sgl_jax/srt/mem_cache/memory_pool.py @@ -17,9 +17,7 @@ def merge_kv(k: jax.Array, v: jax.Array) -> jax.Array: - assert ( - k.shape == v.shape - ), f"k and v must have same shape, got {k.shape} vs {v.shape}" + assert k.shape == v.shape, f"k and v must have same shape, got {k.shape} vs {v.shape}" num_tokens, num_kv_heads, head_dim = k.shape @@ -215,9 +213,7 @@ def __init__( start_layer: int | None = None, end_layer: int | None = None, ): - super().__init__( - size, page_size, dtype, layer_num, mesh, start_layer, end_layer - ) + super().__init__(size, page_size, dtype, layer_num, mesh, start_layer, end_layer) self.head_num = head_num self.head_dim = head_dim self.kv_partition_axis = "tensor" @@ -390,9 +386,7 @@ def set_kv_buffer( kv_partition_axis=self.kv_partition_axis, ) - def get_kv_data( - self, layer_id: int, indices: jnp.ndarray - ) -> tuple[jnp.ndarray, jnp.ndarray]: + def get_kv_data(self, layer_id: int, indices: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: """Get KV data at specified positions""" layer_idx = layer_id - self.start_layer fused_kv_data = self.kv_buffer[layer_idx][indices] @@ -418,9 +412,7 @@ def load_cpu_copy(self, kv_cache_host, indices): # Merge k and v into fused format fused_kv_host = merge_kv(k_host, v_host) fused_kv_device = jax.device_put(fused_kv_host, self.kv_sharding) - self.kv_buffer[layer_id] = ( - self.kv_buffer[layer_id].at[indices].set(fused_kv_device) - ) + self.kv_buffer[layer_id] = self.kv_buffer[layer_id].at[indices].set(fused_kv_device) def move_kv_cache(self, tgt_loc: jnp.ndarray, src_loc: jnp.ndarray): """Move fused KV cache from source locations to target locations""" @@ -428,9 +420,7 @@ def move_kv_cache(self, tgt_loc: jnp.ndarray, src_loc: jnp.ndarray): # Get fused KV data from source locations fused_kv_data = self.kv_buffer[layer_id][src_loc] # Set data to target locations - self.kv_buffer[layer_id] = ( - self.kv_buffer[layer_id].at[tgt_loc].set(fused_kv_data) - ) + self.kv_buffer[layer_id] = self.kv_buffer[layer_id].at[tgt_loc].set(fused_kv_data) def clear_cache(self, indices: jnp.ndarray): """Clear fused KV cache at specified indices""" @@ -454,9 +444,7 @@ def set_kv_buffer_legacy( N = self.kv_buffer[layer_idx].shape[0] safe_loc = jnp.where(loc >= 0, loc, jnp.int32(N)) # for jax function - updated_layer = ( - self.kv_buffer[layer_idx].at[safe_loc].set(fused_kv, mode="drop") - ) + updated_layer = self.kv_buffer[layer_idx].at[safe_loc].set(fused_kv, mode="drop") return updated_layer @@ -587,17 +575,13 @@ def get_num_slices_per_block(new_kv: jax.Array, kv_cache: jax.Array, page_size=1 kv_head_num = new_kv.shape[1] head_dim = new_kv.shape[2] - max_num_slices_per_block = VMEM_SIZE // ( - bytes_per_element * page_size * kv_head_num * head_dim - ) + max_num_slices_per_block = VMEM_SIZE // (bytes_per_element * page_size * kv_head_num * head_dim) assert ( max_num_slices_per_block > 0 ), f"max_num_slices_per_block={max_num_slices_per_block} is not greater than 0" return ( - total_num_token - if total_num_token < max_num_slices_per_block - else max_num_slices_per_block + total_num_token if total_num_token < max_num_slices_per_block else max_num_slices_per_block ) @@ -619,13 +603,9 @@ def kv_cache_update( ): @jax.shard_map( in_specs=( - P( - None, kv_partition_axis, None - ), # new_kv - consistent with KV 
cache sharding + P(None, kv_partition_axis, None), # new_kv - consistent with KV cache sharding P(None, None), # slices - P( - None, kv_partition_axis, None - ), # kv_cache - consistent with KV cache sharding + P(None, kv_partition_axis, None), # kv_cache - consistent with KV cache sharding P(None), # num_kv_update_slices ), out_specs=P( @@ -879,13 +859,7 @@ def find_value(lst, target_num) -> int: hd_val = find_value(head_dim_config, head_dim) ps_val = find_value(page_size_config, page_size) - if ( - hn_val != -1 - and mcl_val != -1 - and nkl_val != -1 - and hd_val != -1 - and ps_val != -1 - ): + if hn_val != -1 and mcl_val != -1 and nkl_val != -1 and hd_val != -1 and ps_val != -1: return best_num_slices_per_block_config[ f"hn_{hn_val}_mcl_{mcl_val}_nvl_{nkl_val}_hd_{hd_val}_ps_{ps_val}" ] @@ -912,9 +886,7 @@ def __init__( start_layer: int | None = None, end_layer: int | None = None, ): - super().__init__( - size, page_size, dtype, layer_num, mesh, start_layer, end_layer - ) + super().__init__(size, page_size, dtype, layer_num, mesh, start_layer, end_layer) self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.kv_partition_axis = kv_partition_axis @@ -988,15 +960,11 @@ def get_kv_buffer(self, layer_id: int) -> tuple[jnp.ndarray, jnp.ndarray]: - v_buffer contains the qk_rope_head_dim portion """ layer_idx = layer_id - self.start_layer - mla_kv = self.kv_buffer[ - layer_idx - ] # [cache_size, 1, kv_lora_rank + qk_rope_head_dim] + mla_kv = self.kv_buffer[layer_idx] # [cache_size, 1, kv_lora_rank + qk_rope_head_dim] # Split MLA KV buffer into K and V components for native attention k_buffer = mla_kv[:, :, : self.kv_lora_rank] # [cache_size, 1, kv_lora_rank] - v_buffer = mla_kv[ - :, :, self.kv_lora_rank : - ] # [cache_size, 1, qk_rope_head_dim] + v_buffer = mla_kv[:, :, self.kv_lora_rank :] # [cache_size, 1, qk_rope_head_dim] return k_buffer, v_buffer @@ -1023,9 +991,7 @@ def set_mla_kv_buffer( layer_idx = layer_id - self.start_layer # Concatenate nope and rope components cache_k_combined = jnp.concatenate([cache_k_nope, cache_k_rope], axis=-1) - self.kv_buffer[layer_idx] = ( - self.kv_buffer[layer_idx].at[loc].set(cache_k_combined) - ) + self.kv_buffer[layer_idx] = self.kv_buffer[layer_idx].at[loc].set(cache_k_combined) def get_cpu_copy(self, indices): """Get CPU copy of KV cache for specified indices""" @@ -1040,9 +1006,7 @@ def load_cpu_copy(self, kv_cache_host, indices): for layer_id in range(self.layer_num): kv_host = kv_cache_host[layer_id] kv_device = jax.device_put(kv_host, self.kv_sharding) - self.kv_buffer[layer_id] = ( - self.kv_buffer[layer_id].at[indices].set(kv_device) - ) + self.kv_buffer[layer_id] = self.kv_buffer[layer_id].at[indices].set(kv_device) best_num_slices_per_block_config = { diff --git a/python/sgl_jax/srt/mem_cache/radix_cache.py b/python/sgl_jax/srt/mem_cache/radix_cache.py index e55c6c154..56d55058b 100644 --- a/python/sgl_jax/srt/mem_cache/radix_cache.py +++ b/python/sgl_jax/srt/mem_cache/radix_cache.py @@ -105,9 +105,7 @@ def __init__( if self.page_size == 1: self.key_match_fn = _key_match_page_size1 - self.get_child_key_fn = lambda key: ( - int(key[0]) if hasattr(key[0], "item") else key[0] - ) + self.get_child_key_fn = lambda key: (int(key[0]) if hasattr(key[0], "item") else key[0]) else: self.key_match_fn = partial(_key_match_paged, page_size=page_size) # Ensure returning hashable types, convert numpy arrays to Python native types @@ -204,12 +202,8 @@ def cache_finished_req(self, req): page_aligned_kv_indices = kv_indices # Radix 
Cache takes over one reference from memory pool - new_prefix_len = self.insert( - token_ids[:page_aligned_len], page_aligned_kv_indices - ) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) + new_prefix_len = self.insert(token_ids[:page_aligned_len], page_aligned_kv_indices) + self.token_to_kv_pool_allocator.free(kv_indices[len(req.prefix_indices) : new_prefix_len]) # Remove request slot and release cache lock self.req_to_token_pool.free(req.req_pool_idx) @@ -233,9 +227,7 @@ def cache_unfinished_req(self, req): # Radix Cache takes over one reference from memory pool new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) + self.token_to_kv_pool_allocator.free(kv_indices[len(req.prefix_indices) : new_prefix_len]) # Prefix indices may have been updated, reuse them new_match_result = self.match_prefix(page_aligned_token_ids) @@ -253,9 +245,7 @@ def cache_unfinished_req(self, req): # `req.prefix_indices` will be used later in `PrefillAdder::add_chunked_req` if self.page_size != 1: # create array on CPU - req.prefix_indices = np.concat( - [new_indices, kv_indices[len(new_indices) :]] - ) + req.prefix_indices = np.concat([new_indices, kv_indices[len(new_indices) :]]) # with jax.default_device(self.cpu_device): # kv_indices_cpu = jax.device_put(kv_indices, self.cpu_device) # req.prefix_indices = jnp.concatenate( diff --git a/python/sgl_jax/srt/memory_profiler.py b/python/sgl_jax/srt/memory_profiler.py index 4ec3086ff..c79bd21b7 100644 --- a/python/sgl_jax/srt/memory_profiler.py +++ b/python/sgl_jax/srt/memory_profiler.py @@ -104,11 +104,7 @@ def _save_memory_snapshot(filename: str, condition: bool = True): return try: - output_path = ( - os.path.join(_config.output_dir, filename) - if _config.output_dir - else filename - ) + output_path = os.path.join(_config.output_dir, filename) if _config.output_dir else filename jax.profiler.save_device_memory_profile(output_path) except Exception as e: logger.warning("Failed to save memory snapshot %s: %s", filename, e) @@ -201,9 +197,7 @@ def _create_memory_report( ) for name, info in sorted_tensors: - percentage = ( - (info["memory_mb"] / total_memory) * 100 if total_memory > 0 else 0 - ) + percentage = (info["memory_mb"] / total_memory) * 100 if total_memory > 0 else 0 shape_str = "x".join(map(str, info["shape"])) f.write( f"{name:<25}: {info['memory_mb']:>8.2f} MB ({percentage:>5.1f}%) " @@ -231,9 +225,7 @@ def _create_memory_report( with open(json_report_path, "w") as f: json.dump(json_report, f, indent=2) - logger.debug( - " Generated memory reports: %s, %s", report_path, json_report_path - ) + logger.debug(" Generated memory reports: %s, %s", report_path, json_report_path) except Exception as e: logger.warning("Failed to create memory report for %s: %s", stage, e) @@ -359,9 +351,7 @@ def move_reports_to_output_dir(): os.rename(prof_file, dest_path) moved_files.append(dest_path) - for report_file in glob.glob("memory_report_*.txt") + glob.glob( - "memory_report_*.json" - ): + for report_file in glob.glob("memory_report_*.txt") + glob.glob("memory_report_*.json"): if not report_file.startswith(_config.output_dir): dest_path = os.path.join(_config.output_dir, report_file) os.rename(report_file, dest_path) diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py index 3615226bf..faaa92c14 100644 --- 
a/python/sgl_jax/srt/model_executor/model_runner.py +++ b/python/sgl_jax/srt/model_executor/model_runner.py @@ -71,9 +71,7 @@ def __init__( self.mesh = mesh # model args self.num_attn_heads = model_config.num_attention_heads - self.num_kv_heads = model_config.get_total_num_kv_heads_with_replication( - tp_size - ) + self.num_kv_heads = model_config.get_total_num_kv_heads_with_replication(tp_size) self.rngs = rngs self.tp_size = tp_size @@ -101,9 +99,7 @@ def __init__( ) # Initialize precision tracer enable state - precision_tracer.set_enable_precision_tracer( - server_args.enable_precision_tracer - ) + precision_tracer.set_enable_precision_tracer(server_args.enable_precision_tracer) # If it is a draft model, tp_group can be different self.initialize() @@ -141,9 +137,7 @@ def initialize_jit(self): model_def, model_state = nnx.split(self.model) model_state_leaves, model_state_def = jax.tree_util.tree_flatten(model_state) sampler_def, sampler_state = nnx.split(self.sampler) - sampler_state_leaves, sampler_state_def = jax.tree_util.tree_flatten( - sampler_state - ) + sampler_state_leaves, sampler_state_def = jax.tree_util.tree_flatten(sampler_state) @partial( jax.jit, @@ -158,17 +152,13 @@ def jitted_run_model( token_to_kv_pool, logits_metadata, ): - model_state = jax.tree_util.tree_unflatten( - model_state_def, model_state_leaves - ) + model_state = jax.tree_util.tree_unflatten(model_state_def, model_state_leaves) model = nnx.merge(model_def, model_state) return model(forward_batch, token_to_kv_pool, logits_metadata) @partial(jax.jit, static_argnames=["sampler_state_def"]) def jitted_sampler(sampler_def, sampler_state_def, sampler_state_leaves, *args): - model_state = jax.tree_util.tree_unflatten( - sampler_state_def, sampler_state_leaves - ) + model_state = jax.tree_util.tree_unflatten(sampler_state_def, sampler_state_leaves) sampler = nnx.merge(sampler_def, model_state) return sampler(*args) @@ -190,9 +180,7 @@ def run_model_wrapper(forward_batch, logits_metadata): ) def get_available_device_memory(self): - min_available_device_memory = get_available_device_memory( - self.device, distributed=False - ) + min_available_device_memory = get_available_device_memory(self.device, distributed=False) # Check memory for tensor parallelism local_device_memory = get_available_device_memory(self.device) @@ -222,9 +210,7 @@ def load_model(self): self.dtype = self.model_config.dtype self.start_layer = getattr(self.model, "start_layer", 0) - self.end_layer = getattr( - self.model, "end_layer", self.model_config.num_hidden_layers - ) + self.end_layer = getattr(self.model, "end_layer", self.model_config.num_hidden_layers) self.num_effective_layers = self.end_layer - self.start_layer def profile_max_num_token(self, total_device_memory: int): @@ -240,9 +226,7 @@ def profile_max_num_token(self, total_device_memory: int): ) if available_kv_cache_bytes <= 0: - raise RuntimeError( - "Not enough memory. Please try to increase --mem-fraction-static." - ) + raise RuntimeError("Not enough memory. Please try to increase --mem-fraction-static.") cell_size = ( self.model_config.get_num_kv_heads(self.tp_size) @@ -278,9 +262,7 @@ def init_memory_pool( elif self.server_args.kv_cache_dtype == "bf16": self.kv_cache_dtype = jnp.bfloat16 else: - raise ValueError( - f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." 
- ) + raise ValueError(f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}.") logger.info("ModelRunner kv_cache_dtype: %s", self.kv_cache_dtype) # Profile maximum number of tokens self.max_total_num_tokens = self.profile_max_num_token(total_device_memory) @@ -289,9 +271,7 @@ def init_memory_pool( if max_num_reqs is None: max_num_reqs = min( max( - int( - self.max_total_num_tokens / self.model_config.context_len * 512 - ), + int(self.max_total_num_tokens / self.model_config.context_len * 512), 2048, ), 4096, @@ -314,15 +294,11 @@ def init_memory_pool( # Align to page size self.max_total_num_tokens = ( - self.max_total_num_tokens - // self.server_args.page_size - * self.server_args.page_size + self.max_total_num_tokens // self.server_args.page_size * self.server_args.page_size ) if self.max_total_num_tokens <= 0: - raise RuntimeError( - "Not enough memory. Please try to increase --mem-fraction-static." - ) + raise RuntimeError("Not enough memory. Please try to increase --mem-fraction-static.") logger.info("ModelRunner max_total_num_tokens: %s", self.max_total_num_tokens) @@ -339,9 +315,7 @@ def init_memory_pool( size=self.max_total_num_tokens, page_size=self.page_size, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_total_num_kv_heads_with_replication( - self.tp_size - ), + head_num=self.model_config.get_total_num_kv_heads_with_replication(self.tp_size), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, mesh=self.mesh, @@ -393,9 +367,7 @@ def _get_attention_backend(self): mesh=self.mesh, ) else: - raise ValueError( - f"Unsupported attention backend: {self.server_args.attention_backend}" - ) + raise ValueError(f"Unsupported attention backend: {self.server_args.attention_backend}") def _forward( self, @@ -406,9 +378,7 @@ def _forward( import jax._src.test_util as jtu with jtu.count_pjit_cpp_cache_miss() as count: - output, layers_kv_fused, _ = self.jitted_run_model( - forward_batch, logits_metadata - ) + output, layers_kv_fused, _ = self.jitted_run_model(forward_batch, logits_metadata) cache_miss_count = count() self._set_kv_cache_after_forward(layers_kv_fused, forward_batch) @@ -465,10 +435,7 @@ def _forward_raw( except AttributeError: ctx = self.mesh with ctx: - if ( - forward_batch.forward_mode.is_decode() - or forward_batch.forward_mode.is_extend() - ): + if forward_batch.forward_mode.is_decode() or forward_batch.forward_mode.is_extend(): ret = self._forward(forward_batch, logits_metadata) elif forward_batch.forward_mode.is_idle(): ret = self.forward_idle(forward_batch, logits_metadata) @@ -516,9 +483,7 @@ def __init__( self.num_attn_heads = model_config.num_heads self.rngs = rngs else: - self.num_kv_heads = model_config.get_total_num_kv_heads_with_replication( - self.tp_size - ) + self.num_kv_heads = model_config.get_total_num_kv_heads_with_replication(self.tp_size) self.num_attn_heads = model_config.num_attention_heads self.rngs = rngs @@ -552,9 +517,7 @@ def __init__( size=self.max_total_num_tokens, page_size=self.page_size, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_total_num_kv_heads_with_replication( - self.tp_size - ), + head_num=self.model_config.get_total_num_kv_heads_with_replication(self.tp_size), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, mesh=mesh, diff --git a/python/sgl_jax/srt/model_loader/arch.py b/python/sgl_jax/srt/model_loader/arch.py index 09a335cbb..11a4ba36f 100644 --- a/python/sgl_jax/srt/model_loader/arch.py +++ 
b/python/sgl_jax/srt/model_loader/arch.py @@ -15,9 +15,7 @@ def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str for i, arch in enumerate(architectures): if arch == "TransformersForCausalLM": continue - auto_map: dict[str, str] = ( - getattr(model_config.hf_config, "auto_map", None) or dict() - ) + auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", None) or dict() # Make sure that config class is always initialized before model class, # otherwise the model class won't be able to access the config class, # the expected auto_map should have correct order like: diff --git a/python/sgl_jax/srt/model_loader/loader.py b/python/sgl_jax/srt/model_loader/loader.py index 6b9033466..16643ffbc 100644 --- a/python/sgl_jax/srt/model_loader/loader.py +++ b/python/sgl_jax/srt/model_loader/loader.py @@ -57,9 +57,7 @@ def init_new(cls, model_config: ModelConfig): model_config.revision, ) - def __init__( - self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh - ): + def __init__(self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh): super().__init__(load_config) self.rng = rngs self.mesh = mesh @@ -107,9 +105,7 @@ def _get_model(self, model_class: Any, model_config: ModelConfig) -> nnx.Module: model.load_weights(model_config, self.rng.default.key.value) return model - def _maybe_download_from_modelscope( - self, model: str, revision: str | None - ) -> str | None: + def _maybe_download_from_modelscope(self, model: str, revision: str | None) -> str | None: if get_bool_env_var("SGLANG_USE_MODELSCOPE"): # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. @@ -157,9 +153,7 @@ def _prepare_weights( class JAXDummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values for JAX models.""" - def __init__( - self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh - ): + def __init__(self, load_config: LoadConfig, rngs: jax.Array, mesh: jax.sharding.Mesh): super().__init__(load_config) if load_config.model_loader_extra_config: raise ValueError( @@ -235,9 +229,7 @@ def _preserve_rope_caches(path, old, new): return old return new - new_params = jax.tree_util.tree_map_with_path( - _preserve_rope_caches, params, new_params - ) + new_params = jax.tree_util.tree_map_with_path(_preserve_rope_caches, params, new_params) nnx.update(model, new_params) def load_model( @@ -249,9 +241,7 @@ def load_model( model_class = self._initialize_model(model_config) def create_model(rng: nnx.Rngs): - model = model_class( - model_config.hf_config, model_config.dtype, rng, self.mesh - ) + model = model_class(model_config.hf_config, model_config.dtype, rng, self.mesh) state = nnx.state(model) pspecs = nnx.get_partition_spec(state) sharded_state = jax.lax.with_sharding_constraint(state, pspecs) diff --git a/python/sgl_jax/srt/models/llama.py b/python/sgl_jax/srt/models/llama.py index 9006664da..9026e4db3 100644 --- a/python/sgl_jax/srt/models/llama.py +++ b/python/sgl_jax/srt/models/llama.py @@ -210,9 +210,7 @@ def __init__( self.layer_id = layer_id rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None - ): + if rope_scaling is not None and getattr(config, "original_max_position_embeddings", None): rope_scaling["original_max_position_embeddings"] = ( config.original_max_position_embeddings ) @@ -220,9 +218,7 @@ def __init__( 
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias # Support internlm/internlm-7b with bias - attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False - ) + attention_bias = getattr(config, "attention_bias", False) or getattr(config, "bias", False) head_dim = getattr(config, "head_dim", None) self.self_attn = LlamaAttention( @@ -392,9 +388,7 @@ def __init__( logger.info("LlamaForCausalLM config dtype: %s", self.dtype) self.transformer = LlamaModel(config, dtype=self.dtype, rngs=rngs) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) - self.logits_processor = LogitsProcessor( - config.vocab_size, self.lm_head, self.mesh - ) + self.logits_processor = LogitsProcessor(config.vocab_size, self.lm_head, self.mesh) def load_weights(self, model_config: ModelConfig, rng_key: jax.Array): self.rng = nnx.Rngs(rng_key) diff --git a/python/sgl_jax/srt/models/qwen.py b/python/sgl_jax/srt/models/qwen.py index 96291e90e..9a475519b 100644 --- a/python/sgl_jax/srt/models/qwen.py +++ b/python/sgl_jax/srt/models/qwen.py @@ -409,9 +409,7 @@ def __call__( token_to_kv_pool: KVCache, logits_metadata: LogitsMetadata, ): - hidden_states, layers_kv_fused = self.transformer( - forward_batch, token_to_kv_pool - ) + hidden_states, layers_kv_fused = self.transformer(forward_batch, token_to_kv_pool) output = self.logits_processor(hidden_states, logits_metadata) return output, layers_kv_fused, True diff --git a/python/sgl_jax/srt/models/qwen2.py b/python/sgl_jax/srt/models/qwen2.py index 360182772..58f150925 100644 --- a/python/sgl_jax/srt/models/qwen2.py +++ b/python/sgl_jax/srt/models/qwen2.py @@ -323,9 +323,7 @@ def __init__( logger.info("Qwen2ForCausalLM config dtype: %s", self.dtype) self.transformer = Qwen2Model(config, dtype=self.dtype, rngs=rngs) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) - self.logits_processor = LogitsProcessor( - config.vocab_size, self.lm_head, self.mesh - ) + self.logits_processor = LogitsProcessor(config.vocab_size, self.lm_head, self.mesh) def load_weights(self, model_config: ModelConfig, rng_key: jax.Array): self.rng = nnx.Rngs(rng_key) diff --git a/python/sgl_jax/srt/models/qwen3.py b/python/sgl_jax/srt/models/qwen3.py index 38a3b4291..69dfed99a 100644 --- a/python/sgl_jax/srt/models/qwen3.py +++ b/python/sgl_jax/srt/models/qwen3.py @@ -362,9 +362,7 @@ def __init__( logger.info("QWen3ForCausalLMModel config dtype: %s", self.dtype) self.transformer = QWen3Model(config, dtype=self.dtype, rngs=rngs) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) - self.logits_processor = LogitsProcessor( - config.vocab_size, self.lm_head, self.mesh - ) + self.logits_processor = LogitsProcessor(config.vocab_size, self.lm_head, self.mesh) def load_weights(self, model_config: ModelConfig, rng_key: jax.Array): self.rng = nnx.Rngs(rng_key) diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py index 62aee4033..600256d7e 100644 --- a/python/sgl_jax/srt/models/qwen3_moe.py +++ b/python/sgl_jax/srt/models/qwen3_moe.py @@ -185,9 +185,7 @@ def __init__( num_experts = getattr(config, "num_experts", 128) num_experts_per_tok = getattr(config, "num_experts_per_tok", 8) moe_intermediate_size = getattr(config, "moe_intermediate_size", 768) - expert_parallel_size = mesh.shape.get("data", 1) * mesh.shape.get( - "tensor", 1 - ) + expert_parallel_size = 
mesh.shape.get("data", 1) * mesh.shape.get("tensor", 1) self.moe_gate = GateLogit( input_size=config.hidden_size, features=num_experts, @@ -347,9 +345,7 @@ def __init__( logger.info("QWen3MoeForCausalLMModel config dtype: %s", self.dtype) self.transformer = QWen3MoeModel(config, dtype=self.dtype, rngs=rngs, mesh=mesh) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, rngs=rngs) - self.logits_processor = LogitsProcessor( - config.vocab_size, self.lm_head, self.mesh - ) + self.logits_processor = LogitsProcessor(config.vocab_size, self.lm_head, self.mesh) def load_weights(self, model_config: ModelConfig, rng_key: jax.Array): self.rng = nnx.Rngs(rng_key) @@ -501,8 +497,7 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict "down_proj": "wo", }[expert_type] expert_keys = [ - f"{prefix}.mlp.experts.{i}.{expert_type}.weight" - for i in range(num_experts) + f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) ] mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( @@ -519,9 +514,7 @@ def __call__( token_to_kv_pool: KVCache, logits_metadata: LogitsMetadata, ): - hidden_states, layers_kv_fused = self.transformer( - forward_batch, token_to_kv_pool - ) + hidden_states, layers_kv_fused = self.transformer(forward_batch, token_to_kv_pool) output = self.logits_processor(hidden_states, logits_metadata) return output, layers_kv_fused, True diff --git a/python/sgl_jax/srt/models/registry.py b/python/sgl_jax/srt/models/registry.py index 3a8c32a70..2690d7f1c 100644 --- a/python/sgl_jax/srt/models/registry.py +++ b/python/sgl_jax/srt/models/registry.py @@ -47,9 +47,7 @@ def _normalize_archs( logger.warning("No model architectures are specified") # filter out support architectures - normalized_arch = list( - filter(lambda model: model in self.models, architectures) - ) + normalized_arch = list(filter(lambda model: model in self.models, architectures)) # make sure Transformers backend is put at the last as a fallback if len(normalized_arch) != len(architectures): @@ -84,9 +82,7 @@ def import_model_classes(): continue if hasattr(module, "EntryClass"): entry = module.EntryClass - if isinstance( - entry, list - ): # To support multiple model classes in one module + if isinstance(entry, list): # To support multiple model classes in one module for tmp in entry: assert ( tmp.__name__ not in model_arch_name_to_cls diff --git a/python/sgl_jax/srt/precision_tracer.py b/python/sgl_jax/srt/precision_tracer.py index bf05d6be8..375ec6382 100644 --- a/python/sgl_jax/srt/precision_tracer.py +++ b/python/sgl_jax/srt/precision_tracer.py @@ -25,11 +25,7 @@ def default(self, obj): "__tensor_type__": "jax", "shape": list(obj.shape), "dtype": str(obj.dtype), - "data": ( - obj.tolist() - if obj.size < 100 - else f"" - ), + "data": (obj.tolist() if obj.size < 100 else f""), } except Exception: return { @@ -47,9 +43,7 @@ def default(self, obj): class PrecisionTracerRequestMetadata: def __init__(self, request_id, request_input_ids, forward_mode): self.request_id = request_id - self.input_hash = hashlib.md5( - str(request_input_ids).encode("utf-8") - ).hexdigest()[:16] + self.input_hash = hashlib.md5(str(request_input_ids).encode("utf-8")).hexdigest()[:16] self.request_input_ids = request_input_ids self.input_len = len(request_input_ids) self.forward_mode = forward_mode @@ -122,9 +116,7 @@ def __init__(self): # metadata self._current_batch_id = None self._records: dict[str, PrecisionTracerRecord] = {} - self._batch_requests_mapping: dict[ - int, 
list[PrecisionTracerRequestMetadata] - ] = {} + self._batch_requests_mapping: dict[int, list[PrecisionTracerRequestMetadata]] = {} self._token_counters: dict[str, int] = {} self._last_forward_pass_id: dict[str, int] = {} @@ -132,9 +124,7 @@ def __init__(self): def set_enable_precision_tracer(self, enabled: bool): self._enable_precision_tracer = enabled - logger.info( - "Precision tracer globally %s", "enabled" if enabled else "disabled" - ) + logger.info("Precision tracer globally %s", "enabled" if enabled else "disabled") def get_trace_active(self): with self.lock: @@ -168,8 +158,7 @@ def set_end_time_and_duration(self, request_id: str): with self.lock: self._records[request_id].end_time = time.time() self._records[request_id].duration = ( - self._records[request_id].end_time - - self._records[request_id].start_time + self._records[request_id].end_time - self._records[request_id].start_time ) def add_request_to_batch_requests_mapping( @@ -187,9 +176,7 @@ def start_trace( verbose_logging: bool = False, ): if not self._enable_precision_tracer: - logger.warning( - "Precision tracer is disabled. Enable with --enable-precision-tracer" - ) + logger.warning("Precision tracer is disabled. Enable with --enable-precision-tracer") return if self._trace_active: @@ -205,9 +192,7 @@ def start_trace( self._current_batch_id = None self._records: dict[str, PrecisionTracerRecord] = {} - self._batch_requests_mapping: dict[ - int, list[PrecisionTracerRequestMetadata] - ] = {} + self._batch_requests_mapping: dict[int, list[PrecisionTracerRequestMetadata]] = {} self._token_counters: dict[str, int] = {} self._last_forward_pass_id: dict[str, int] = {} @@ -244,9 +229,7 @@ def stop_trace(self): json.dump(record_dict, f, cls=TensorJSONEncoder, ensure_ascii=False) f.write("\n") - logger.info( - "Saved %s request traces to: %s", len(self._records), output_file - ) + logger.info("Saved %s request traces to: %s", len(self._records), output_file) except Exception as e: logger.error("Error saving traces to %s: %s", output_file, e) @@ -299,9 +282,7 @@ def jit_pure_callback_record( self, tensor: Any, name: str, stage: str, layer_id: int | None = None ) -> bool: if self._enable_precision_tracer: - full_stage = ( - f"{stage}_layer_id_{layer_id}" if layer_id is not None else stage - ) + full_stage = f"{stage}_layer_id_{layer_id}" if layer_id is not None else stage def trace_callback(tensor): # Debug logging to check what stage is being passed @@ -336,9 +317,7 @@ def record( return with self.lock: - request_in_batch = self._batch_requests_mapping.get( - self._current_batch_id, [] - ) + request_in_batch = self._batch_requests_mapping.get(self._current_batch_id, []) current_batch_id = self._current_batch_id if len(request_in_batch) == 0: @@ -384,9 +363,7 @@ def record( else: # For decode, use forward_pass_id to determine when to start new token group current_token_idx = self._token_counters.get(req_id, 0) - last_forward_pass_id = self._last_forward_pass_id.get( - req_id, -1 - ) + last_forward_pass_id = self._last_forward_pass_id.get(req_id, -1) # Check if this is a new forward pass (new inference step) is_new_forward_pass = ( @@ -398,9 +375,7 @@ def record( if is_new_forward_pass and last_forward_pass_id != -1: current_token_idx = self._token_counters.get(req_id, 0) + 1 self._token_counters[req_id] = current_token_idx - self._last_forward_pass_id[req_id] = ( - self._current_forward_pass_id - ) + self._last_forward_pass_id[req_id] = self._current_forward_pass_id # Look for existing token group at current position 
current_token_group = None @@ -419,9 +394,7 @@ def record( precision_records[category].append(current_token_group) # Update forward pass id for first record of this token if hasattr(self, "_current_forward_pass_id"): - self._last_forward_pass_id[req_id] = ( - self._current_forward_pass_id - ) + self._last_forward_pass_id[req_id] = self._current_forward_pass_id # Add record to the token group record_with_metadata = data.copy() @@ -507,12 +480,8 @@ def _calculate_tensor_pricision_info( "framework": "jax", "name": name, "stage": stage, - "shape": ( - tuple(tensor.shape) if hasattr(tensor, "shape") else "unknown" - ), - "dtype": ( - str(tensor.dtype) if hasattr(tensor, "dtype") else "unknown" - ), + "shape": (tuple(tensor.shape) if hasattr(tensor, "shape") else "unknown"), + "dtype": (str(tensor.dtype) if hasattr(tensor, "dtype") else "unknown"), "error": str(e), "layer_id": "unknown", "module_type": "unknown", @@ -679,9 +648,7 @@ def _verbose_logging_console(self, stats: dict[str, Any]): req_info = "" if "request_id" in stats: req_id_short = ( - stats["request_id"][:8] - if len(stats["request_id"]) > 8 - else stats["request_id"] + stats["request_id"][:8] if len(stats["request_id"]) > 8 else stats["request_id"] ) req_info = f"[Req:{req_id_short}]" diff --git a/python/sgl_jax/srt/reasoning_parser.py b/python/sgl_jax/srt/reasoning_parser.py index 3a8f14497..2592f641a 100644 --- a/python/sgl_jax/srt/reasoning_parser.py +++ b/python/sgl_jax/srt/reasoning_parser.py @@ -46,9 +46,7 @@ def detect_and_parse(self, text: str) -> StreamingParseResult: reasoning_text = splits[0] normal_text = splits[1].strip() - return StreamingParseResult( - normal_text=normal_text, reasoning_text=reasoning_text - ) + return StreamingParseResult(normal_text=normal_text, reasoning_text=reasoning_text) def parse_streaming_increment(self, new_text: str) -> StreamingParseResult: """ diff --git a/python/sgl_jax/srt/sampling/penaltylib/frequency_penalty.py b/python/sgl_jax/srt/sampling/penaltylib/frequency_penalty.py index abf4c7a37..6c83555a4 100644 --- a/python/sgl_jax/srt/sampling/penaltylib/frequency_penalty.py +++ b/python/sgl_jax/srt/sampling/penaltylib/frequency_penalty.py @@ -16,10 +16,7 @@ def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self._is_prepared = False def _is_required(self) -> bool: - return any( - req.sampling_params.frequency_penalty != 0.0 - for req in self.orchestrator.reqs() - ) + return any(req.sampling_params.frequency_penalty != 0.0 for req in self.orchestrator.reqs()) def _prepare(self): # Only keep the frequency penalty values, not the large penalty array diff --git a/python/sgl_jax/srt/sampling/penaltylib/min_new_tokens.py b/python/sgl_jax/srt/sampling/penaltylib/min_new_tokens.py index 479042615..53cbf084a 100644 --- a/python/sgl_jax/srt/sampling/penaltylib/min_new_tokens.py +++ b/python/sgl_jax/srt/sampling/penaltylib/min_new_tokens.py @@ -12,15 +12,11 @@ def pad_sequence(sequences, batch_first=True, padding_value=0): """ max_len = max(len(seq) for seq in sequences) if batch_first: - padded = np.full( - (len(sequences), max_len), padding_value, dtype=sequences[0].dtype - ) + padded = np.full((len(sequences), max_len), padding_value, dtype=sequences[0].dtype) for i, seq in enumerate(sequences): padded[i, : len(seq)] = seq else: - padded = np.full( - (max_len, len(sequences)), padding_value, dtype=sequences[0].dtype - ) + padded = np.full((max_len, len(sequences)), padding_value, dtype=sequences[0].dtype) for i, seq in enumerate(sequences): padded[: len(seq), i] = seq return padded 
@@ -36,17 +32,13 @@ def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self._is_prepared = False def _is_required(self) -> bool: - return any( - req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs() - ) + return any(req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()) def _prepare(self): min_new_tokens_list = [ req.sampling_params.min_new_tokens for req in self.orchestrator.reqs() ] - self.min_new_tokens = np.expand_dims( - np.array(min_new_tokens_list, dtype=np.int32), axis=1 - ) + self.min_new_tokens = np.expand_dims(np.array(min_new_tokens_list, dtype=np.int32), axis=1) # Store stop token sequences without creating large penalty array self.stop_token_sequences = [] @@ -59,9 +51,7 @@ def _prepare(self): if req.tokenizer.eos_token_id is not None: stop_tokens.add(req.tokenizer.eos_token_id) - self.stop_token_sequences.append( - np.array(list(stop_tokens), dtype=np.int64) - ) + self.stop_token_sequences.append(np.array(list(stop_tokens), dtype=np.int64)) self.len_output_tokens = np.zeros( (len(self.orchestrator.reqs()), 1), @@ -112,9 +102,7 @@ def _filter(self, keep_indices: np.ndarray): self.len_output_tokens = self.len_output_tokens[keep_indices] def _merge(self, their: "BatchedMinNewTokensPenalizer"): - self.min_new_tokens = np.concatenate( - [self.min_new_tokens, their.min_new_tokens], axis=0 - ) + self.min_new_tokens = np.concatenate([self.min_new_tokens, their.min_new_tokens], axis=0) self.stop_token_sequences.extend(their.stop_token_sequences) self.len_output_tokens = np.concatenate( [self.len_output_tokens, their.len_output_tokens], axis=0 diff --git a/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py b/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py index ab2a04227..832955377 100644 --- a/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py +++ b/python/sgl_jax/srt/sampling/penaltylib/orchestrator.py @@ -74,9 +74,7 @@ def apply(self) -> np.ndarray | None: ) # Get active penalizers - active_penalizers = [ - (type(p), p) for p in self.penalizers.values() if p.is_prepared() - ] + active_penalizers = [(type(p), p) for p in self.penalizers.values() if p.is_prepared()] if len(active_penalizers) == 0: return None @@ -94,10 +92,7 @@ def apply(self) -> np.ndarray | None: result = None for penalty_type in penalty_order: - if ( - penalty_type in self.penalizers - and self.penalizers[penalty_type].is_prepared() - ): + if penalty_type in self.penalizers and self.penalizers[penalty_type].is_prepared(): penalty_values = self.penalizers[penalty_type].compute_penalty() if result is None: result = penalty_values.copy() diff --git a/python/sgl_jax/srt/sampling/penaltylib/presence_penalty.py b/python/sgl_jax/srt/sampling/penaltylib/presence_penalty.py index 4d93b004b..8debd9286 100644 --- a/python/sgl_jax/srt/sampling/penaltylib/presence_penalty.py +++ b/python/sgl_jax/srt/sampling/penaltylib/presence_penalty.py @@ -16,10 +16,7 @@ def __init__(self, orchestrator: BatchedPenalizerOrchestrator): self._is_prepared = False def _is_required(self) -> bool: - return any( - req.sampling_params.presence_penalty != 0.0 - for req in self.orchestrator.reqs() - ) + return any(req.sampling_params.presence_penalty != 0.0 for req in self.orchestrator.reqs()) def _prepare(self): # Only keep the presence penalty values, not the large penalty array @@ -56,6 +53,4 @@ def _merge(self, their: "BatchedPresencePenalizer"): self.presence_penalties = np.concatenate( [self.presence_penalties, their.presence_penalties], axis=0 ) - self.token_presence = 
-            [self.token_presence, their.token_presence], axis=0
-        )
+        self.token_presence = np.concatenate([self.token_presence, their.token_presence], axis=0)
diff --git a/python/sgl_jax/srt/sampling/sampling_batch_info.py b/python/sgl_jax/srt/sampling/sampling_batch_info.py
index f2846d48a..dee50e674 100644
--- a/python/sgl_jax/srt/sampling/sampling_batch_info.py
+++ b/python/sgl_jax/srt/sampling/sampling_batch_info.py
@@ -96,15 +96,11 @@ def from_model_worker_batch(
        pad_size: int = 0,
        mesh: Mesh = None,
    ) -> SamplingMetadata:
-        sharding = (
-            NamedSharding(mesh, PartitionSpec()) if jax.process_count() == 1 else None
-        )
+        sharding = NamedSharding(mesh, PartitionSpec()) if jax.process_count() == 1 else None
        padded_temperatures = np.concat(
            [
                batch.sampling_info.temperatures,
-                np.array(
-                    [1.0] * pad_size, dtype=batch.sampling_info.temperatures.dtype
-                ),
+                np.array([1.0] * pad_size, dtype=batch.sampling_info.temperatures.dtype),
            ]
        ).reshape(-1, 1)
        padded_top_ps = np.concat(
@@ -135,17 +131,13 @@ def from_model_worker_batch(
                ),
            ]
        )
-            sampling_seeds_device = device_array(
-                padded_sampling_seeds, sharding=sharding
-            )
+            sampling_seeds_device = device_array(padded_sampling_seeds, sharding=sharding)
        else:
            sampling_seeds_device = None

-        (temperatures_device, top_ps_device, top_ks_device, min_ps_device) = (
-            device_array(
-                (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps),
-                sharding=sharding,
-            )
+        (temperatures_device, top_ps_device, top_ks_device, min_ps_device) = device_array(
+            (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps),
+            sharding=sharding,
        )

        # Extract penalty information from penalizer orchestrator
@@ -165,9 +157,7 @@ def from_model_worker_batch(
                    (pad_size, original_linear_penalty.shape[1]),
                    dtype=original_linear_penalty.dtype,
                )
-                padded_linear_penalty = np.concat(
-                    [original_linear_penalty, pad_rows], axis=0
-                )
+                padded_linear_penalty = np.concat([original_linear_penalty, pad_rows], axis=0)
            else:
                padded_linear_penalty = original_linear_penalty

@@ -190,9 +180,7 @@ def from_model_worker_batch(
                    (pad_size, original_linear_penalty.shape[1]),
                    dtype=original_linear_penalty.dtype,
                )
-                padded_linear_penalty = np.concat(
-                    [original_linear_penalty, pad_rows], axis=0
-                )
+                padded_linear_penalty = np.concat([original_linear_penalty, pad_rows], axis=0)
            else:
                padded_linear_penalty = original_linear_penalty

@@ -229,15 +217,11 @@ def from_model_worker_batch_for_precompile(
        shapes for all penalty types to ensure comprehensive compilation coverage.
""" # Basic sampling parameters (same as original method) - sharding = ( - NamedSharding(mesh, PartitionSpec()) if jax.process_count() == 1 else None - ) + sharding = NamedSharding(mesh, PartitionSpec()) if jax.process_count() == 1 else None padded_temperatures = np.concat( [ batch.sampling_info.temperatures, - np.array( - [1.0] * pad_size, dtype=batch.sampling_info.temperatures.dtype - ), + np.array([1.0] * pad_size, dtype=batch.sampling_info.temperatures.dtype), ] ).reshape(-1, 1) padded_top_ps = np.concat( @@ -274,11 +258,9 @@ def from_model_worker_batch_for_precompile( else: sampling_seeds_device = None - (temperatures_device, top_ps_device, top_ks_device, min_ps_device) = ( - device_array( - (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps), - sharding=sharding, - ) + (temperatures_device, top_ps_device, top_ks_device, min_ps_device) = device_array( + (padded_temperatures, padded_top_ps, padded_top_ks, padded_min_ps), + sharding=sharding, ) if batch.sampling_info.sampling_seeds is not None: @@ -301,9 +283,7 @@ def from_model_worker_batch_for_precompile( # Calculate batch size and vocab size batch_size = len(batch.sampling_info.temperatures) + pad_size vocab_size = batch.sampling_info.vocab_size - padded_linear_penalty = ( - jnp.ones((batch_size, vocab_size), dtype=jnp.float32) * 0.2 - ) + padded_linear_penalty = jnp.ones((batch_size, vocab_size), dtype=jnp.float32) * 0.2 (linear_penalty_device,) = device_array( (padded_linear_penalty,), @@ -368,17 +348,13 @@ def _get_global_server_args_dict(cls): return global_server_args_dict @classmethod - def generate_for_precompile( - cls, bs: int, vocab_size: int = 32000, do_penalties: bool = False - ): + def generate_for_precompile(cls, bs: int, vocab_size: int = 32000, do_penalties: bool = False): temperatures = np.array([0.6 for _ in range(bs)], dtype=np.float32) top_ps = np.array([0.9 for _ in range(bs)], dtype=np.float32) top_ks = np.array([30 for _ in range(bs)], dtype=np.int32) min_ps = np.array([0.6 for _ in range(bs)], dtype=np.float32) if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_SAMPLING"): - sampling_seeds = np.array( - [DEFAULT_SAMPLING_SEED for _ in range(bs)], dtype=np.int32 - ) + sampling_seeds = np.array([DEFAULT_SAMPLING_SEED for _ in range(bs)], dtype=np.int32) else: sampling_seeds = None diff --git a/python/sgl_jax/srt/sampling/sampling_params.py b/python/sgl_jax/srt/sampling/sampling_params.py index eb60057da..500cb1f3c 100644 --- a/python/sgl_jax/srt/sampling/sampling_params.py +++ b/python/sgl_jax/srt/sampling/sampling_params.py @@ -68,10 +68,7 @@ def __init__( self.stream_interval = stream_interval self.logit_bias = logit_bias # Used for deterministic sampling - if ( - get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_SAMPLING") - and sampling_seed is None - ): + if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_SAMPLING") and sampling_seed is None: # If deterministic sampling is enabled and sampling_seed is not set, use the default seed sampling_seed = DEFAULT_SAMPLING_SEED self.sampling_seed = sampling_seed @@ -86,25 +83,17 @@ def __init__( def verify(self, vocab_size): if self.temperature < 0.0: - raise ValueError( - f"temperature must be non-negative, got {self.temperature}." 
- ) + raise ValueError(f"temperature must be non-negative, got {self.temperature}.") if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") if not 0.0 <= self.min_p <= 1.0: raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") if self.top_k < 1 or self.top_k == -1: - raise ValueError( - f"top_k must be -1 (disable) or at least 1, got {self.top_k}." - ) + raise ValueError(f"top_k must be -1 (disable) or at least 1, got {self.top_k}.") if not -2.0 <= self.frequency_penalty <= 2.0: - raise ValueError( - f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}." - ) + raise ValueError(f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}.") if not -2.0 <= self.presence_penalty <= 2.0: - raise ValueError( - f"presence_penalty must be in [-2, 2], got {self.presence_penalty}." - ) + raise ValueError(f"presence_penalty must be in [-2, 2], got {self.presence_penalty}.") if not 0.0 <= self.repetition_penalty <= 2.0: raise ValueError( f"repetition_penalty must be in [0, 2], got {self.repetition_penalty}." @@ -115,9 +104,7 @@ def verify(self, vocab_size): ) if self.max_new_tokens is not None: if self.max_new_tokens < 0: - raise ValueError( - f"max_new_tokens must be at least 0, got {self.max_new_tokens}." - ) + raise ValueError(f"max_new_tokens must be at least 0, got {self.max_new_tokens}.") if not self.min_new_tokens <= self.max_new_tokens: raise ValueError( f"min_new_tokens must be in [0, max_new_tokens({self.max_new_tokens})], got " diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py index 33096a00a..96290019c 100644 --- a/python/sgl_jax/srt/server_args.py +++ b/python/sgl_jax/srt/server_args.py @@ -169,9 +169,9 @@ def __post_init__(self): self.chunked_prefill_size = 4096 # GGUF - if ( - self.load_format == "auto" or self.load_format == "gguf" - ) and check_gguf_file(self.model_path): + if (self.load_format == "auto" or self.load_format == "gguf") and check_gguf_file( + self.model_path + ): self.quantization = self.load_format = "gguf" if is_remote_url(self.model_path): @@ -801,9 +801,7 @@ def get_hf_config(self): return hf_config def check_server_args(self): - assert ( - self.tp_size - ) % self.nnodes == 0, "tp_size must be divisible by number of nodes" + assert (self.tp_size) % self.nnodes == 0, "tp_size must be divisible by number of nodes" # Check chunked prefill # Skip validation if chunked prefill is disabled (i.e., size <= 0). 
@@ -868,13 +866,9 @@ def init_new(server_args, dp_rank: int | None = None) -> "PortArgs": rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", pub_sub_addr=( - f"tcp://{dist_init_host}:{port_base + 4}" - if server_args.nnodes > 1 - else None + f"tcp://{dist_init_host}:{port_base + 4}" if server_args.nnodes > 1 else None ), pub_sub_sync_addr=( - f"tcp://{dist_init_host}:{port_base + 5}" - if server_args.nnodes > 1 - else None + f"tcp://{dist_init_host}:{port_base + 5}" if server_args.nnodes > 1 else None ), ) diff --git a/python/sgl_jax/srt/utils/common_utils.py b/python/sgl_jax/srt/utils/common_utils.py index 1224c66b8..e622573cc 100644 --- a/python/sgl_jax/srt/utils/common_utils.py +++ b/python/sgl_jax/srt/utils/common_utils.py @@ -116,9 +116,7 @@ def set_ulimit(target_soft_limit=65535): target_soft_limit_stack_size = 1024 * target_soft_limit if current_soft < target_soft_limit_stack_size: try: - resource.setrlimit( - resource_type, (target_soft_limit_stack_size, current_hard) - ) + resource.setrlimit(resource_type, (target_soft_limit_stack_size, current_hard)) except ValueError as e: logger.warning("Fail to set RLIMIT_STACK: %s", e) @@ -168,9 +166,7 @@ def configure_logger(server_args, prefix: str = ""): ) -def get_zmq_socket( - context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool -): +def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool): mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 @@ -214,9 +210,7 @@ def delete_directory(dirpath): print(f"Warning: {dirpath} : {e.strerror}") -def dataclass_to_string_truncated( - data, max_length=2048, skip_names: set[str] | None = None -): +def dataclass_to_string_truncated(data, max_length=2048, skip_names: set[str] | None = None): if skip_names is None: skip_names = set() if isinstance(data, str): @@ -268,9 +262,7 @@ def pyspy_dump_schedulers(): pid = psutil.Process().pid # Command to run py-spy with the PID cmd = f"py-spy dump --pid {pid}" - result = subprocess.run( - cmd, shell=True, capture_output=True, text=True, check=True - ) + result = subprocess.run(cmd, shell=True, capture_output=True, text=True, check=True) logger.error("Pyspy dump for PID %s:\n%s", pid, result.stdout) except subprocess.CalledProcessError as e: logger.error("Pyspy failed to dump PID %s. Error: %s", pid, e.stderr) @@ -289,9 +281,7 @@ def kill_itself_when_parent_died(): def set_uvicorn_logging_configs(): from uvicorn.config import LOGGING_CONFIG - LOGGING_CONFIG["formatters"]["default"][ - "fmt" - ] = "[%(asctime)s] %(levelprefix)s %(message)s" + LOGGING_CONFIG["formatters"]["default"]["fmt"] = "[%(asctime)s] %(levelprefix)s %(message)s" LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" LOGGING_CONFIG["formatters"]["access"][ "fmt" @@ -338,9 +328,7 @@ async def health_generate(): try: loop = asyncio.get_running_loop() - logger.info( - "Dummy health check server scheduled on existing loop at %s:%s", host, port - ) + logger.info("Dummy health check server scheduled on existing loop at %s:%s", host, port) loop.create_task(server.serve()) except RuntimeError: @@ -376,13 +364,9 @@ def retry( raise Exception("retry() exceed maximum number of retries.") from e if not should_retry(e): - raise Exception( - "retry() observe errors that should not be retried." 
- ) from e + raise Exception("retry() observe errors that should not be retried.") from e - delay = min(initial_delay * (2**try_index), max_delay) * ( - 0.75 + 0.25 * random.random() - ) + delay = min(initial_delay * (2**try_index), max_delay) * (0.75 + 0.25 * random.random()) logger.warning( "retry() failed once (%sth try, maximum %s retries). Will delay %.2fs and retry. Error: %s", @@ -404,9 +388,7 @@ def _to_hashable(o): except TypeError: # Not hashable; convert based on type if isinstance(o, (dict)): - return frozenset( - (_to_hashable(k), _to_hashable(v)) for k, v in o.items() - ) + return frozenset((_to_hashable(k), _to_hashable(v)) for k, v in o.items()) elif isinstance(o, set): return frozenset(_to_hashable(v) for v in o) elif isinstance(o, (list, tuple)) or ( @@ -422,9 +404,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): h_args = tuple(_to_hashable(a) for a in args) - h_kwargs = frozenset( - (_to_hashable(k), _to_hashable(v)) for k, v in kwargs.items() - ) + h_kwargs = frozenset((_to_hashable(k), _to_hashable(v)) for k, v in kwargs.items()) key = (h_args, h_kwargs) if key in cache: cache.move_to_end(key) diff --git a/python/sgl_jax/srt/utils/jax_utils.py b/python/sgl_jax/srt/utils/jax_utils.py index 79c145cf8..ad9e8ad7f 100644 --- a/python/sgl_jax/srt/utils/jax_utils.py +++ b/python/sgl_jax/srt/utils/jax_utils.py @@ -61,9 +61,7 @@ def get_available_device_memory(device, distributed=False, empty_cache=True): elif device in ("gpu", "cuda"): if empty_cache: jax.clear_caches() - devices = [ - d for d in jax.local_devices() if getattr(d, "platform", None) == "gpu" - ] + devices = [d for d in jax.local_devices() if getattr(d, "platform", None) == "gpu"] if not devices: raise RuntimeError("No GPU devices found by JAX") avail = [] @@ -84,9 +82,7 @@ def get_available_device_memory(device, distributed=False, empty_cache=True): # Use pmap to find the minimum available memory across all devices. 
mesh = jax.make_mesh((jax.process_count(), 4), ("node", "device")) - @jax.shard_map( - mesh=mesh, in_specs=PartitionSpec(None), out_specs=PartitionSpec(None) - ) + @jax.shard_map(mesh=mesh, in_specs=PartitionSpec(None), out_specs=PartitionSpec(None)) def _get_available_memory_distributed(a): return jax.lax.pmin(a, axis_name="node") diff --git a/python/sgl_jax/srt/utils/mesh_utils.py b/python/sgl_jax/srt/utils/mesh_utils.py index bee8a5c7f..ecfc77871 100644 --- a/python/sgl_jax/srt/utils/mesh_utils.py +++ b/python/sgl_jax/srt/utils/mesh_utils.py @@ -36,9 +36,7 @@ def create_device_mesh( return mesh -def fill_unspecified_parallelism( - parallelism: Sequence[int], num_devices: int -) -> Sequence[int]: +def fill_unspecified_parallelism(parallelism: Sequence[int], num_devices: int) -> Sequence[int]: if -1 not in parallelism: return parallelism diff --git a/python/sgl_jax/test/mem_cache/test_kv_cache.py b/python/sgl_jax/test/mem_cache/test_kv_cache.py index c3bcb6238..01432152d 100644 --- a/python/sgl_jax/test/mem_cache/test_kv_cache.py +++ b/python/sgl_jax/test/mem_cache/test_kv_cache.py @@ -71,10 +71,7 @@ def generate_test_data(self, total_tokens: int, add_padding: bool = False): loc = jnp.zeros(total_tokens, dtype=jnp.int32) loc = loc.at[::2].set(all_locs[: total_tokens // 2]) loc = loc.at[1::2].set( - all_locs[ - total_tokens // 2 : total_tokens // 2 - + (total_tokens - total_tokens // 2) - ] + all_locs[total_tokens // 2 : total_tokens // 2 + (total_tokens - total_tokens // 2)] ) else: # All valid tokens @@ -104,13 +101,9 @@ def expected_update_kv_cache(self, k, v, loc, k_cache, v_cache): def test_kv_cache_update_page_size_1(self): """Test KV cache update with page_size=1.""" total_tokens = 16 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=False - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False) - updated_k_cache, updated_v_cache = update_kv_cache( - k, v, loc, k_cache, v_cache, page_size=1 - ) + updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=1) # Expected result expected_k_cache, expected_v_cache = self.expected_update_kv_cache( @@ -123,13 +116,9 @@ def test_kv_cache_update_page_size_1(self): def test_kv_cache_update_page_size_1_with_padding(self): """Test KV cache update with page_size=1 and padding tokens.""" total_tokens = 12 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=True - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True) - updated_k_cache, updated_v_cache = update_kv_cache( - k, v, loc, k_cache, v_cache, page_size=1 - ) + updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=1) # Expected result (should ignore padding tokens where loc == -1) expected_k_cache, expected_v_cache = self.expected_update_kv_cache( @@ -150,14 +139,10 @@ def test_kv_cache_update_page_size_1_with_padding(self): def test_kv_cache_update_page_size_4(self): """Test KV cache update with page_size=4.""" total_tokens = 16 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=False - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False) # Test with page_size=4 - updated_k_cache, updated_v_cache = update_kv_cache( - k, v, loc, k_cache, v_cache, page_size=4 - ) + updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=4) # Expected result expected_k_cache, expected_v_cache = 
self.expected_update_kv_cache( @@ -171,14 +156,10 @@ def test_kv_cache_update_page_size_4(self): def test_kv_cache_update_page_size_4_with_padding(self): """Test KV cache update with page_size=4 and padding tokens.""" total_tokens = 12 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=True - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True) # Test with page_size=4 - updated_k_cache, updated_v_cache = update_kv_cache( - k, v, loc, k_cache, v_cache, page_size=4 - ) + updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=4) # Expected result (should ignore padding tokens where loc == -1) expected_k_cache, expected_v_cache = self.expected_update_kv_cache( @@ -191,14 +172,10 @@ def test_kv_cache_update_page_size_4_with_padding(self): def test_kv_cache_update_page_size_8_contiguous(self): """Test KV cache update with page_size=8 and contiguous locations.""" total_tokens = 16 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=False - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False) # Test with page_size=8 - updated_k_cache, updated_v_cache = update_kv_cache( - k, v, loc, k_cache, v_cache, page_size=8 - ) + updated_k_cache, updated_v_cache = update_kv_cache(k, v, loc, k_cache, v_cache, page_size=8) # Expected result expected_k_cache, expected_v_cache = self.expected_update_kv_cache( @@ -211,9 +188,7 @@ def test_kv_cache_update_page_size_8_contiguous(self): def test_all_padding_tokens(self): """Test case where all tokens are padding tokens.""" total_tokens = 4 - k, v, _, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=False - ) + k, v, _, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False) # Make all tokens padding loc = jnp.full((total_tokens,), -1, dtype=jnp.int32) @@ -234,15 +209,13 @@ def test_update_kv_cache_logic_page_size_1(self): from sgl_jax.srt.mem_cache.memory_pool import _optimize_contiguous_updates total_tokens = 8 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=False - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False) # Test the optimization logic for page_size=1 page_size = 1 if page_size > 1: - kv_cache_locs, new_kv_locs, slice_lens, num_slices = ( - _optimize_contiguous_updates(loc, page_size) + kv_cache_locs, new_kv_locs, slice_lens, num_slices = _optimize_contiguous_updates( + loc, page_size ) else: # Use original logic for page_size = 1: one slice per token @@ -270,12 +243,8 @@ def test_update_kv_cache_logic_page_size_1(self): # Update cache for j in range(length): - updated_k_cache = updated_k_cache.at[cache_start + j].set( - k[new_start + j] - ) - updated_v_cache = updated_v_cache.at[cache_start + j].set( - v[new_start + j] - ) + updated_k_cache = updated_k_cache.at[cache_start + j].set(k[new_start + j]) + updated_v_cache = updated_v_cache.at[cache_start + j].set(v[new_start + j]) # Expected result expected_k_cache, expected_v_cache = self.expected_update_kv_cache( @@ -290,14 +259,12 @@ def test_update_kv_cache_logic_page_size_4(self): from sgl_jax.srt.mem_cache.memory_pool import _optimize_contiguous_updates total_tokens = 16 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=False - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=False) # Test the optimization logic for page_size=4 page_size = 4 - 
kv_cache_locs, new_kv_locs, slice_lens, num_slices = ( - _optimize_contiguous_updates(loc, page_size) + kv_cache_locs, new_kv_locs, slice_lens, num_slices = _optimize_contiguous_updates( + loc, page_size ) # Verify the slice logic makes sense @@ -312,9 +279,7 @@ def test_update_kv_cache_logic_page_size_4(self): self.assertEqual(total_processed, non_padding_count) # For contiguous tokens with page_size=4, expect 4 slices of length 4 each - actual_slices = [ - (i, slice_lens[i]) for i in range(num_slices) if slice_lens[i] > 0 - ] + actual_slices = [(i, slice_lens[i]) for i in range(num_slices) if slice_lens[i] > 0] expected_slices = [(0, 4), (4, 4), (8, 4), (12, 4)] self.assertEqual(len(actual_slices), len(expected_slices)) for (actual_i, actual_len), (expected_i, expected_len) in zip( @@ -335,12 +300,8 @@ def test_update_kv_cache_logic_page_size_4(self): # Update cache for j in range(length): - updated_k_cache = updated_k_cache.at[cache_start + j].set( - k[new_start + j] - ) - updated_v_cache = updated_v_cache.at[cache_start + j].set( - v[new_start + j] - ) + updated_k_cache = updated_k_cache.at[cache_start + j].set(k[new_start + j]) + updated_v_cache = updated_v_cache.at[cache_start + j].set(v[new_start + j]) # For this test, we'll verify that the processed tokens are correctly updated # rather than expecting all tokens to be processed due to the optimization bug @@ -357,26 +318,20 @@ def test_update_kv_cache_logic_page_size_4(self): actual_k_val = updated_k_cache[cache_start + j] actual_v_val = updated_v_cache[cache_start + j] - self.assertTrue( - jnp.allclose(actual_k_val, expected_k_val, rtol=1e-5) - ) - self.assertTrue( - jnp.allclose(actual_v_val, expected_v_val, rtol=1e-5) - ) + self.assertTrue(jnp.allclose(actual_k_val, expected_k_val, rtol=1e-5)) + self.assertTrue(jnp.allclose(actual_v_val, expected_v_val, rtol=1e-5)) def test_update_kv_cache_logic_page_size_8_with_padding(self): """Test KV cache update logic with page_size=8 and padding using optimization functions.""" from sgl_jax.srt.mem_cache.memory_pool import _optimize_contiguous_updates total_tokens = 20 - k, v, loc, k_cache, v_cache = self.generate_test_data( - total_tokens, add_padding=True - ) + k, v, loc, k_cache, v_cache = self.generate_test_data(total_tokens, add_padding=True) # Test the optimization logic for page_size=8 with padding page_size = 8 - kv_cache_locs, new_kv_locs, slice_lens, num_slices = ( - _optimize_contiguous_updates(loc, page_size) + kv_cache_locs, new_kv_locs, slice_lens, num_slices = _optimize_contiguous_updates( + loc, page_size ) # Verify the slice logic handles padding correctly @@ -399,12 +354,8 @@ def test_update_kv_cache_logic_page_size_8_with_padding(self): # Update cache for j in range(length): - updated_k_cache = updated_k_cache.at[cache_start + j].set( - k[new_start + j] - ) - updated_v_cache = updated_v_cache.at[cache_start + j].set( - v[new_start + j] - ) + updated_k_cache = updated_k_cache.at[cache_start + j].set(k[new_start + j]) + updated_v_cache = updated_v_cache.at[cache_start + j].set(v[new_start + j]) # Expected result (should ignore padding tokens where loc == -1) expected_k_cache, expected_v_cache = self.expected_update_kv_cache( @@ -481,9 +432,7 @@ def test_kv_cache_update_multiple_segments_with_padding(self): cache_pos = 11 + i input_pos = i self.assertTrue( - jnp.allclose( - updated_k_cache[cache_pos], k[input_pos], rtol=1e-5 - ) + jnp.allclose(updated_k_cache[cache_pos], k[input_pos], rtol=1e-5) ) # Segment 2: cache locations 22-25 @@ -491,9 +440,7 @@ def 
test_kv_cache_update_multiple_segments_with_padding(self): cache_pos = 22 + i input_pos = 7 + i self.assertTrue( - jnp.allclose( - updated_k_cache[cache_pos], k[input_pos], rtol=1e-5 - ) + jnp.allclose(updated_k_cache[cache_pos], k[input_pos], rtol=1e-5) ) # Segment 3: cache locations 30-39 @@ -501,9 +448,7 @@ def test_kv_cache_update_multiple_segments_with_padding(self): cache_pos = 30 + i input_pos = 11 + i self.assertTrue( - jnp.allclose( - updated_k_cache[cache_pos], k[input_pos], rtol=1e-5 - ) + jnp.allclose(updated_k_cache[cache_pos], k[input_pos], rtol=1e-5) ) print(f" ✓ page_size={page_size} passed") @@ -529,8 +474,8 @@ def test_optimize_contiguous_updates_corner_cases(self): for page_size in [1, 2, 4, 8]: with self.subTest(case=1, page_size=page_size): - kv_cache_locs, new_kv_locs, slice_lens, num_slices = ( - _optimize_contiguous_updates(loc, page_size) + kv_cache_locs, new_kv_locs, slice_lens, num_slices = _optimize_contiguous_updates( + loc, page_size ) # Verify all valid tokens are processed @@ -571,8 +516,8 @@ def test_optimize_contiguous_updates_corner_cases(self): for page_size in [1, 2, 4]: with self.subTest(case=2, page_size=page_size): - kv_cache_locs, new_kv_locs, slice_lens, num_slices = ( - _optimize_contiguous_updates(loc2, page_size) + kv_cache_locs, new_kv_locs, slice_lens, num_slices = _optimize_contiguous_updates( + loc2, page_size ) total_processed = jnp.sum(slice_lens) @@ -594,14 +539,12 @@ def test_optimize_contiguous_updates_corner_cases(self): for page_size in [1, 4]: with self.subTest(case=3, page_size=page_size): - kv_cache_locs, new_kv_locs, slice_lens, num_slices = ( - _optimize_contiguous_updates(loc3, page_size) + kv_cache_locs, new_kv_locs, slice_lens, num_slices = _optimize_contiguous_updates( + loc3, page_size ) total_processed = jnp.sum(slice_lens) - self.assertEqual( - total_processed, 0, "Should process 0 tokens when all are padding" - ) + self.assertEqual(total_processed, 0, "Should process 0 tokens when all are padding") # No slices should be created slices = [i for i in range(num_slices) if slice_lens[i] > 0] diff --git a/python/sgl_jax/test/mem_cache/test_radix_cache.py b/python/sgl_jax/test/mem_cache/test_radix_cache.py index 82370e850..6429940a4 100644 --- a/python/sgl_jax/test/mem_cache/test_radix_cache.py +++ b/python/sgl_jax/test/mem_cache/test_radix_cache.py @@ -134,9 +134,7 @@ def _create_radix_cache(self, mesh, req_pool, allocator, **kwargs): def _print_cache_sharding_info(self, cache, mesh, req_pool, allocator): print("\n" + "=" * 60) - print( - f"[MESH INFO] device number: {len(self.devices)}, Mesh axis: {mesh.axis_names}" - ) + print(f"[MESH INFO] device number: {len(self.devices)}, Mesh axis: {mesh.axis_names}") print(f"[MESH INFO] Mesh device layout: {mesh.devices.shape}") print(f"[MESH INFO] Mesh: {mesh}") @@ -442,9 +440,7 @@ def test_get_cached_kv_without_value(self): # since there is no actual KV data, get_cpu_copy will use token values as indices # this will return data at the corresponding position in the KV cache (usually zero values, because the cache is initialized to zero) # verify that the returned data is not empty, but contains zero value data - self.assertEqual( - kv_data.shape[1], matched_len, "returned KV data length should match" - ) + self.assertEqual(kv_data.shape[1], matched_len, "returned KV data length should match") # verify data content - should be all zero (because KV cache is initialized to zero) kv_data_cpu = jax.device_get(kv_data) @@ -469,9 +465,7 @@ def test_empty_key_handling(self): def 
test_kv_cache_events(self): mesh, req_pool, allocator = self._create_auto_device_setup() - cache = self._create_radix_cache( - mesh, req_pool, allocator, enable_kv_cache_events=True - ) + cache = self._create_radix_cache(mesh, req_pool, allocator, enable_kv_cache_events=True) # test event queue events = cache.take_events() @@ -540,9 +534,7 @@ def test_device_consistency(self): # verify device type (should be CPU) if hasattr(device_indices, "device"): device_str = str(device_indices.device) - self.assertIn( - "cpu", device_str.lower(), f"Expected CPU device, got: {device_str}" - ) + self.assertIn("cpu", device_str.lower(), f"Expected CPU device, got: {device_str}") # check array content correctness self.assertEqual(len(device_indices), len(key)) @@ -570,9 +562,7 @@ def test_cross_device_operations(self): device_indices = match_result.device_indices if hasattr(device_indices, "device"): device_str = str(device_indices.device) - self.assertIn( - "cpu", device_str.lower(), f"Expected CPU device, got: {device_str}" - ) + self.assertIn("cpu", device_str.lower(), f"Expected CPU device, got: {device_str}") # verify content self.assertEqual(len(device_indices), len(prefix_key)) @@ -692,9 +682,7 @@ def setUp(self): def _print_cache_sharding_info(self, cache, mesh, req_pool, allocator): """print cache related sharding information""" print("\n" + "=" * 60) - print( - f"[MESH INFO] device count: {len(self.devices)}, Mesh axis: {mesh.axis_names}" - ) + print(f"[MESH INFO] device count: {len(self.devices)}, Mesh axis: {mesh.axis_names}") print(f"[MESH INFO] Mesh device layout: {mesh.devices.shape}") print(f"[MESH INFO] Mesh: {mesh}") diff --git a/python/sgl_jax/test/model_executor/test_model_runner.py b/python/sgl_jax/test/model_executor/test_model_runner.py index fc6e99fa4..fdc492cd8 100644 --- a/python/sgl_jax/test/model_executor/test_model_runner.py +++ b/python/sgl_jax/test/model_executor/test_model_runner.py @@ -32,9 +32,7 @@ def setUp(self): """Set up ModelRunner""" num_processes = int(os.environ.get("SGL_JAX_NUM_PROCESSES", 1)) process_id = int(os.environ.get("SGL_JAX_PROCESS_ID", 0)) - coordinator_address = os.environ.get( - "SGL_JAX_COORDINATOR_ADDRESS", "localhost:10000" - ) + coordinator_address = os.environ.get("SGL_JAX_COORDINATOR_ADDRESS", "localhost:10000") if num_processes > 1: jax.distributed.initialize( coordinator_address=coordinator_address, @@ -108,25 +106,19 @@ def _get_tokenizer(self): # Check if it's a local path and has tokenizer files if model_path.exists(): tokenizer_files = ["tokenizer_config.json"] - has_tokenizer = any( - (model_path / file).exists() for file in tokenizer_files - ) + has_tokenizer = any((model_path / file).exists() for file in tokenizer_files) if has_tokenizer: print(f"Using local tokenizer from: {model_path}") try: - return AutoTokenizer.from_pretrained( - str(model_path), trust_remote_code=True - ) + return AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) except Exception as e: print(f" Failed to load local tokenizer: {e}") # Use HuggingFace model with network error handling try: print(f"Loading tokenizer from HuggingFace: {self.model_path}") - return AutoTokenizer.from_pretrained( - self.model_path, trust_remote_code=True - ) + return AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) except Exception as e: print(f"Failed to load tokenizer from HuggingFace: {e}") raise RuntimeError( @@ -137,9 +129,7 @@ def _new_forward_batch(self, input_ids, positions): """Create a ForwardBatch for testing.""" total_tokens = 
sum(len(ids) for ids in input_ids) req_pool_indices = self.model_runner.req_to_token_pool.alloc(len(input_ids)) - cache_loc_index = self.model_runner.token_to_kv_pool_allocator.alloc( - total_tokens - ) + cache_loc_index = self.model_runner.token_to_kv_pool_allocator.alloc(total_tokens) # out_cache_loc = self.model_runner.token_to_kv_pool_allocator.alloc(len(input_ids)) # write to req_to_token_pool @@ -175,9 +165,7 @@ def _new_forward_batch(self, input_ids, positions): def _update_forward_batch(self, forward_batch: ForwardBatch, output_ids: jax.Array): """Update the forward batch with the next token ids.""" - out_cache_loc = self.model_runner.token_to_kv_pool_allocator.alloc( - len(output_ids) - ) + out_cache_loc = self.model_runner.token_to_kv_pool_allocator.alloc(len(output_ids)) forward_batch.forward_mode = ForwardMode.DECODE forward_batch.input_ids = output_ids.flatten() @@ -197,9 +185,7 @@ def _update_forward_batch(self, forward_batch: ForwardBatch, output_ids: jax.Arr ) forward_batch.out_cache_loc = jnp.array(out_cache_loc) - forward_batch.seq_lens = jnp.array( - [seq_len + 1 for seq_len in forward_batch.seq_lens] - ) + forward_batch.seq_lens = jnp.array([seq_len + 1 for seq_len in forward_batch.seq_lens]) token_indices_with_all_reqs = self.model_runner.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices @@ -241,9 +227,7 @@ def test_forward(self): extend_output.next_token_logits.shape, (1, self.model_config.vocab_size) ) # (batch_size, vocab_size) - print( - f" Extend phase completed. Output shape: {extend_output.next_token_logits.shape}" - ) + print(f" Extend phase completed. Output shape: {extend_output.next_token_logits.shape}") # Step 2: Multiple decode phases (generation) # Continue from the extend batch for proper KV cache continuity @@ -267,9 +251,7 @@ def test_forward(self): with self.mesh: decode_output = self.model_runner.forward( current_batch, - LogitsMetadata.from_model_worker_batch( - model_worker_batch, self.mesh - ), + LogitsMetadata.from_model_worker_batch(model_worker_batch, self.mesh), ) decode_outputs.append(decode_output) @@ -300,9 +282,7 @@ def test_forward(self): print("Precision trace not saved") # Verify all decode outputs have consistent shapes for output in decode_outputs: - self.assertEqual( - output.next_token_logits.shape, (1, self.model_config.vocab_size) - ) + self.assertEqual(output.next_token_logits.shape, (1, self.model_config.vocab_size)) self.assertEqual(output.next_token_logits.dtype, jnp.bfloat16) self.assertEqual(current_token.shape, (1, 1)) # (batch_size, 1) # Assertions for final verification diff --git a/python/sgl_jax/test/models/test_qwen_model.py b/python/sgl_jax/test/models/test_qwen_model.py index 2655fa07d..8859189a0 100644 --- a/python/sgl_jax/test/models/test_qwen_model.py +++ b/python/sgl_jax/test/models/test_qwen_model.py @@ -21,9 +21,7 @@ class TestQwenModel(unittest.TestCase): """Test cases for the Qwen model.""" def setUp(self): - self.mesh = create_device_mesh( - ici_parallelism=[-1, 1, 1], dcn_parallelism=[1, 1, 1] - ) + self.mesh = create_device_mesh(ici_parallelism=[-1, 1, 1], dcn_parallelism=[1, 1, 1]) # Model path for local model and tokenizer self.test_model_path = os.environ.get( "MODEL_PATH", "Qwen/Qwen-7B" @@ -43,25 +41,19 @@ def _get_tokenizer(self): # Check if it's a local path and has tokenizer files if model_path.exists(): tokenizer_files = ["tokenizer_config.json"] - has_tokenizer = any( - (model_path / file).exists() for file in tokenizer_files - ) + has_tokenizer = any((model_path / file).exists() for 
file in tokenizer_files) if has_tokenizer: print(f"Using local tokenizer from: {model_path}") try: - return AutoTokenizer.from_pretrained( - str(model_path), trust_remote_code=True - ) + return AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) except Exception as e: print(f" Failed to load local tokenizer: {e}") # Use HuggingFace model with network error handling try: print(f"Loading tokenizer from HuggingFace: {self.test_model_path}") - return AutoTokenizer.from_pretrained( - self.test_model_path, trust_remote_code=True - ) + return AutoTokenizer.from_pretrained(self.test_model_path, trust_remote_code=True) except Exception as e: print(f"Failed to load tokenizer from HuggingFace: {e}") raise RuntimeError( @@ -273,16 +265,12 @@ def _generate_random_questions(self, batch_size: int) -> list[str]: if "{}" in template: param_count = template.count("{}") - fill_params = random.sample( - fill_words, min(param_count, len(fill_words)) - ) + fill_params = random.sample(fill_words, min(param_count, len(fill_words))) try: question = template.format(*fill_params) except (IndexError, ValueError): - question = ( - f"Question {i + 1}: Tell me about {random.choice(fill_words)}" - ) + question = f"Question {i + 1}: Tell me about {random.choice(fill_words)}" else: question = template @@ -315,9 +303,7 @@ def _update_forward_batch( # Check if this request should finish BEFORE updating sequences if self._is_finished(current_token_id, tokenizer): - print( - f" Request {orig_idx} will be removed from batch (token: {current_token_id})" - ) + print(f" Request {orig_idx} will be removed from batch (token: {current_token_id})") finished_requests.add(orig_idx) continue @@ -336,9 +322,7 @@ def _update_forward_batch( forward_batch.seq_lens = jnp.array(new_seq_lens, dtype=jnp.int32) # update cache loc - out_cache_start_loc = ( - max(item for sublist in new_cache_loc for item in sublist) + 1 - ) + out_cache_start_loc = max(item for sublist in new_cache_loc for item in sublist) + 1 forward_batch.out_cache_loc = jnp.arange( out_cache_start_loc, out_cache_start_loc + forward_batch.batch_size, @@ -382,9 +366,7 @@ def test_qwen_model_forward(self, batch_size: int = None): """ print("Testing Qwen model generation...") model = self._setup_model() - jax_profiling_dir = os.environ.get( - "JAX_TRACE_PROFILING_DIR", "/tmp/jax_profiling" - ) + jax_profiling_dir = os.environ.get("JAX_TRACE_PROFILING_DIR", "/tmp/jax_profiling") batch_size = int(os.environ.get("BATCH_SIZE", 10)) with self.mesh, jax_trace_context(jax_profiling_dir): sampler = Sampler(rngs=nnx.Rngs(0)) @@ -409,8 +391,8 @@ def test_qwen_model_forward(self, batch_size: int = None): start_time = time.time() - input_ids_array, actual_seq_lens, forward_batch = ( - self._create_batch_from_texts(model.config, input_texts, tokenizer) + input_ids_array, actual_seq_lens, forward_batch = self._create_batch_from_texts( + model.config, input_texts, tokenizer ) print("\n Batch Processing Info:") @@ -448,9 +430,7 @@ def test_qwen_model_forward(self, batch_size: int = None): print(f"Active requests: {forward_batch.batch_size}") # Forward pass - y = model( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) + y = model(forward_batch.input_ids, forward_batch.positions, forward_batch) # Sample next token for each active sequence next_token_ids = sampler( @@ -468,12 +448,8 @@ def test_qwen_model_forward(self, batch_size: int = None): print(f"Generated tokens: {next_token_ids.tolist()}") for batch_idx, token_id in enumerate(next_token_ids): - 
decoded_token = tokenizer.decode( - int(token_id[0]), skip_special_tokens=False - ) - final_results[original_indices[batch_idx]][ - "output" - ] += decoded_token + decoded_token = tokenizer.decode(int(token_id[0]), skip_special_tokens=False) + final_results[original_indices[batch_idx]]["output"] += decoded_token if len(input_texts) <= 10 and (iteration % 5 == 0): print( @@ -512,9 +488,9 @@ def test_qwen_model_forward(self, batch_size: int = None): total_time = end_time - start_time finished_count = sum(1 for r in final_results.values() if r["finished"]) - avg_output_length = sum( - len(r["output"]) for r in final_results.values() - ) / len(final_results) + avg_output_length = sum(len(r["output"]) for r in final_results.values()) / len( + final_results + ) print("\n === Generation Results Summary ===") print("Performance Metrics:") @@ -530,11 +506,7 @@ def test_qwen_model_forward(self, batch_size: int = None): print("\nDetailed Results:") for i in range(len(input_texts)): result = final_results[i] - status = ( - " Finished" - if result["finished"] - else "⏰ Max iterations reached" - ) + status = " Finished" if result["finished"] else "⏰ Max iterations reached" print(f"\nRequest {i} ({status}):") print(f" Input: '{result['input']}'") print(f" Output: '{result['output']}'") diff --git a/python/sgl_jax/test/run_curl.py b/python/sgl_jax/test/run_curl.py index ce59b621e..7c42ed872 100644 --- a/python/sgl_jax/test/run_curl.py +++ b/python/sgl_jax/test/run_curl.py @@ -31,9 +31,7 @@ def run_curl(args): print(f"Payload: {json.dumps(payload, indent=2)}") try: - response = requests.post( - f"{base_url}/generate", json=payload, headers=headers, timeout=30 - ) + response = requests.post(f"{base_url}/generate", json=payload, headers=headers, timeout=30) print(f"Status Code: {response.status_code}") diff --git a/python/sgl_jax/test/run_eval.py b/python/sgl_jax/test/run_eval.py index d8baf8ccf..214f669a2 100644 --- a/python/sgl_jax/test/run_eval.py +++ b/python/sgl_jax/test/run_eval.py @@ -21,9 +21,7 @@ def run_eval(args): if "OPENAI_API_KEY" not in os.environ: os.environ["OPENAI_API_KEY"] = "EMPTY" - base_url = ( - f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" - ) + base_url = f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1" if args.eval_name == "mmlu": from sgl_jax.test.simple_eval_mmlu import MMLUEval @@ -35,12 +33,8 @@ def run_eval(args): equality_checker = ChatCompletionSampler(model="gpt-4-turbo") - filename = ( - "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv" - ) - eval_obj = MathEval( - filename, equality_checker, args.num_examples, args.num_threads - ) + filename = "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv" + eval_obj = MathEval(filename, equality_checker, args.num_examples, args.num_threads) elif args.eval_name == "mgsm": from sgl_jax.test.simple_eval_mgsm import MGSMEval @@ -52,9 +46,7 @@ def run_eval(args): elif args.eval_name == "gpqa": from sgl_jax.test.simple_eval_gpqa import GPQAEval - filename = ( - "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv" - ) + filename = "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv" eval_obj = GPQAEval(filename, args.num_examples, args.num_threads) elif args.eval_name == "humaneval": from sgl_jax.test.simple_eval_humaneval import HumanEval @@ -109,9 +101,7 @@ def run_eval(args): default=None, help="Server or API base url if not using http host and port.", ) - parser.add_argument( - 
"--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0." - ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0.") parser.add_argument( "--port", type=int, diff --git a/python/sgl_jax/test/run_jax_loader_test.py b/python/sgl_jax/test/run_jax_loader_test.py index 6efce587e..5dcb16068 100644 --- a/python/sgl_jax/test/run_jax_loader_test.py +++ b/python/sgl_jax/test/run_jax_loader_test.py @@ -135,17 +135,11 @@ def main(): help="Specific test method to run (e.g., test_load_model_with_custom_path)", ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Verbose test output" - ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose test output") - parser.add_argument( - "--check-jax", action="store_true", help="Check JAX dependencies and exit" - ) + parser.add_argument("--check-jax", action="store_true", help="Check JAX dependencies and exit") - parser.add_argument( - "--check-deps", action="store_true", help="Check all dependencies and exit" - ) + parser.add_argument("--check-deps", action="store_true", help="Check all dependencies and exit") parser.add_argument( "--create-sample", @@ -192,9 +186,7 @@ def main(): print("\nRunning JAXModelLoader tests...") - success = run_tests( - test_name=args.test, model_path=args.model_path, verbose=args.verbose - ) + success = run_tests(test_name=args.test, model_path=args.model_path, verbose=args.verbose) if success: print("\n✓ All tests passed!") diff --git a/python/sgl_jax/test/run_qwen3_moe_test.py b/python/sgl_jax/test/run_qwen3_moe_test.py index 4ac5523ee..5b443018a 100644 --- a/python/sgl_jax/test/run_qwen3_moe_test.py +++ b/python/sgl_jax/test/run_qwen3_moe_test.py @@ -68,9 +68,7 @@ def check_sglang_dependencies(): try: importlib.util.find_spec("sgl_jax.srt.configs.load_config.LoadFormat") importlib.util.find_spec("sgl_jax.srt.model_loader.loader.JAXModelLoader") - importlib.util.find_spec( - "sgl_jax.srt.models.qwen3_moe.Qwen3MoeForCausalLMJaxModel" - ) + importlib.util.find_spec("sgl_jax.srt.models.qwen3_moe.Qwen3MoeForCausalLMJaxModel") print("✓ SGLang JAXModelLoader available") print("✓ Qwen3MoeForCausalLMJaxModel available") @@ -165,26 +163,24 @@ def create_sample_qwen3_moe_model(output_dir): mock_weights = { "model": { "embed_tokens": { - "kernel": np.random.randn( - config["vocab_size"], config["hidden_size"] - ).astype(np.float32) + "kernel": np.random.randn(config["vocab_size"], config["hidden_size"]).astype( + np.float32 + ) }, "layers": {}, "norm": {"scale": np.ones(config["hidden_size"], dtype=np.float32)}, }, "lm_head": { - "kernel": np.random.randn( - config["hidden_size"], config["vocab_size"] - ).astype(np.float32) + "kernel": np.random.randn(config["hidden_size"], config["vocab_size"]).astype( + np.float32 + ) }, } # Add layer weights with MoE structure for i in range(config["num_hidden_layers"]): layer_weights = { - "input_layernorm": { - "scale": np.ones(config["hidden_size"], dtype=np.float32) - }, + "input_layernorm": {"scale": np.ones(config["hidden_size"], dtype=np.float32)}, "self_attn": { "q_proj": { "kernel": np.random.randn( @@ -321,26 +317,18 @@ def main(): help="Specific test method to run (use --list-tests to see available tests)", ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Verbose test output" - ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose test output") - parser.add_argument( - "--check-jax", action="store_true", help="Check JAX dependencies and exit" - ) + 
parser.add_argument("--check-jax", action="store_true", help="Check JAX dependencies and exit") - parser.add_argument( - "--check-deps", action="store_true", help="Check all dependencies and exit" - ) + parser.add_argument("--check-deps", action="store_true", help="Check all dependencies and exit") parser.add_argument( "--create-sample", help="Create a sample Qwen3 MoE JAX model directory at the specified path", ) - parser.add_argument( - "--list-tests", action="store_true", help="List all available test methods" - ) + parser.add_argument("--list-tests", action="store_true", help="List all available test methods") args = parser.parse_args() @@ -397,9 +385,7 @@ def main(): print("\nRunning Qwen3 MoE JAXModelLoader tests...") - success = run_tests( - test_name=args.test, model_path=args.model_path, verbose=args.verbose - ) + success = run_tests(test_name=args.test, model_path=args.model_path, verbose=args.verbose) if success: print("\n✓ All Qwen3 MoE tests passed!") diff --git a/python/sgl_jax/test/run_qwen_test.py b/python/sgl_jax/test/run_qwen_test.py index 1ee29e2a5..9310c02dd 100644 --- a/python/sgl_jax/test/run_qwen_test.py +++ b/python/sgl_jax/test/run_qwen_test.py @@ -62,9 +62,7 @@ def check_transformers_dependencies(): return False -def run_tests( - test_name=None, model_path=None, verbose=False, enable_precision_tracer=False -): +def run_tests(test_name=None, model_path=None, verbose=False, enable_precision_tracer=False): """Run the QWen JAXModelLoader tests""" env = os.environ.copy() if model_path: @@ -127,22 +125,22 @@ def create_sample_qwen_model(output_dir): mock_weights = { "transformer": { "embed_tokens": { - "kernel": np.random.randn( - config["vocab_size"], config["hidden_size"] - ).astype(np.float32) + "kernel": np.random.randn(config["vocab_size"], config["hidden_size"]).astype( + np.float32 + ) }, "h": {}, "ln_f": {"scale": np.ones(config["hidden_size"], dtype=np.float32)}, }, "lm_head": { - "kernel": np.random.randn( - config["hidden_size"], config["vocab_size"] - ).astype(np.float32) + "kernel": np.random.randn(config["hidden_size"], config["vocab_size"]).astype( + np.float32 + ) }, "logits_processor": { - "kernel": np.random.randn( - config["hidden_size"], config["vocab_size"] - ).astype(np.float32) + "kernel": np.random.randn(config["hidden_size"], config["vocab_size"]).astype( + np.float32 + ) }, } @@ -233,26 +231,18 @@ def main(): help="Specific test method to run (use --list-tests to see available tests)", ) - parser.add_argument( - "--verbose", "-v", action="store_true", help="Verbose test output" - ) + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose test output") - parser.add_argument( - "--check-jax", action="store_true", help="Check JAX dependencies and exit" - ) + parser.add_argument("--check-jax", action="store_true", help="Check JAX dependencies and exit") - parser.add_argument( - "--check-deps", action="store_true", help="Check all dependencies and exit" - ) + parser.add_argument("--check-deps", action="store_true", help="Check all dependencies and exit") parser.add_argument( "--create-sample", help="Create a sample QWen JAX model directory at the specified path", ) - parser.add_argument( - "--list-tests", action="store_true", help="List all available test methods" - ) + parser.add_argument("--list-tests", action="store_true", help="List all available test methods") parser.add_argument( "--enable-precision-tracer", action="store_true", diff --git a/python/sgl_jax/test/simple_eval_common.py 
b/python/sgl_jax/test/simple_eval_common.py index 21600a42d..918c30ee4 100644 --- a/python/sgl_jax/test/simple_eval_common.py +++ b/python/sgl_jax/test/simple_eval_common.py @@ -127,9 +127,7 @@ def _pack_message(self, role: str, content: Any): def __call__(self, message_list: MessageList) -> str: if self.system_message: - message_list = [ - self._pack_message("system", self.system_message) - ] + message_list + message_list = [self._pack_message("system", self.system_message)] + message_list trial = 0 while True: try: @@ -425,9 +423,7 @@ def make_report_from_example_htmls(htmls: list[str]): """ Create a standalone HTML report from a list of example htmls """ - return jinja_env.from_string(_report_template).render( - score=None, metrics={}, htmls=htmls - ) + return jinja_env.from_string(_report_template).render(score=None, metrics={}, htmls=htmls) def download_dataset(path, url): diff --git a/python/sgl_jax/test/simple_eval_gpqa.py b/python/sgl_jax/test/simple_eval_gpqa.py index e93eeecf4..bbc946b30 100644 --- a/python/sgl_jax/test/simple_eval_gpqa.py +++ b/python/sgl_jax/test/simple_eval_gpqa.py @@ -38,9 +38,7 @@ def __init__( assert n_repeats == 1, "n_repeats only supported for num_examples" examples = rng.sample(examples, num_examples) examples = examples * n_repeats - examples = [ - example | {"permutation": rng.sample(range(4), 4)} for example in examples - ] + examples = [example | {"permutation": rng.sample(range(4), 4)} for example in examples] self.examples = examples self.n_repeats = n_repeats self.num_threads = num_threads diff --git a/python/sgl_jax/test/simple_eval_humaneval.py b/python/sgl_jax/test/simple_eval_humaneval.py index abc34a498..fc6219f26 100644 --- a/python/sgl_jax/test/simple_eval_humaneval.py +++ b/python/sgl_jax/test/simple_eval_humaneval.py @@ -89,13 +89,10 @@ def find_code(completion): def fn(sample: dict[str, str]): prompt_messages = [ - sampler._pack_message( - role="user", content=instruction + sample["prompt"] - ) + sampler._pack_message(role="user", content=instruction + sample["prompt"]) ] completions = [ - find_code(sampler(prompt_messages)) - for _ in range(self._num_samples_per_task) + find_code(sampler(prompt_messages)) for _ in range(self._num_samples_per_task) ] results = evaluate_functional_correctness(sample, completions) total = len(results) @@ -123,7 +120,5 @@ def fn(sample: dict[str, str]): }, ) - results = common.map_with_progress( - fn, self.examples, num_threads=self._num_threads - ) + results = common.map_with_progress(fn, self.examples, num_threads=self._num_threads) return common.aggregate_results(results) diff --git a/python/sgl_jax/test/simple_eval_math.py b/python/sgl_jax/test/simple_eval_math.py index 3e213fec3..a742de069 100644 --- a/python/sgl_jax/test/simple_eval_math.py +++ b/python/sgl_jax/test/simple_eval_math.py @@ -55,9 +55,7 @@ def fn(row: dict): response_text = sampler(prompt_messages) match = re.search(ANSWER_PATTERN, response_text) extracted_answer = match.group(1) if match else None - score = float( - check_equality(self.equality_checker, row["Answer"], extracted_answer) - ) + score = float(check_equality(self.equality_checker, row["Answer"], extracted_answer)) html = common.jinja_env.from_string(HTML_JINJA).render( prompt_messages=prompt_messages, next_message=dict(content=response_text, role="assistant"), diff --git a/python/sgl_jax/test/simple_eval_mgsm.py b/python/sgl_jax/test/simple_eval_mgsm.py index 5e5eed488..1717af35a 100644 --- a/python/sgl_jax/test/simple_eval_mgsm.py +++ 
b/python/sgl_jax/test/simple_eval_mgsm.py @@ -162,9 +162,7 @@ def __init__( def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(example: dict[str, str]): language = example["lang"] - latin_language = ( - "group_latin" if language in LATIN_LANGUAGES else "group_non_latin" - ) + latin_language = "group_latin" if language in LATIN_LANGUAGES else "group_non_latin" correct_answer = example["targets"] instructoin = LANG_TO_INSTRUCTIONS[language] prompt_messages = [ @@ -196,7 +194,5 @@ def fn(example: dict[str, str]): metrics={language: score, latin_language: score}, ) - results = common.map_with_progress( - fn, self.examples, num_threads=self._num_threads - ) + results = common.map_with_progress(fn, self.examples, num_threads=self._num_threads) return common.aggregate_results(results, default_stats=("mean", "std")) diff --git a/python/sgl_jax/test/simple_eval_mmlu.py b/python/sgl_jax/test/simple_eval_mmlu.py index df0cde7d9..149ac65e3 100644 --- a/python/sgl_jax/test/simple_eval_mmlu.py +++ b/python/sgl_jax/test/simple_eval_mmlu.py @@ -95,9 +95,7 @@ def __init__(self, filename: str, num_examples: int | None, num_threads: int): def __call__(self, sampler: SamplerBase) -> EvalResult: def fn(row: dict): prompt_messages = [ - sampler._pack_message( - content=format_multichoice_question(row), role="user" - ) + sampler._pack_message(content=format_multichoice_question(row), role="user") ] response_text = sampler(prompt_messages) match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text) @@ -112,9 +110,7 @@ def fn(row: dict): ) convo = prompt_messages + [dict(content=response_text, role="assistant")] category = subject2category.get(row["Subject"], "other") - return SingleEvalResult( - html=html, score=score, metrics={category: score}, convo=convo - ) + return SingleEvalResult(html=html, score=score, metrics={category: score}, convo=convo) results = common.map_with_progress(fn, self.examples, self.num_threads) return common.aggregate_results(results) diff --git a/python/sgl_jax/test/test_flashattention.py b/python/sgl_jax/test/test_flashattention.py index f52d323a3..1cc595bad 100644 --- a/python/sgl_jax/test/test_flashattention.py +++ b/python/sgl_jax/test/test_flashattention.py @@ -85,9 +85,7 @@ def create_qkv_cache( def write_prefix_tokens_for_kv(forward_batch, token_to_kv_pool: KVCache, lens, k, v): page_size = forward_batch.attn_backend.page_size # Use aligned positions for k/v indexing since k/v arrays are created with alignment gaps - aligned_seq_lens = ( - (forward_batch.seq_lens + page_size - 1) // page_size - ) * page_size + aligned_seq_lens = ((forward_batch.seq_lens + page_size - 1) // page_size) * page_size aligned_cache_loc_idx = jnp.concatenate( [jnp.array([0], dtype=jnp.int32), jnp.cumsum(aligned_seq_lens)] ) @@ -172,9 +170,7 @@ def align_to_size(lst, size, value=0): for i, (_, kv_len) in enumerate(lens): # Create token indices for this sequence based on actual k/v storage position - seq_token_indices = list( - range(current_aligned_pos, current_aligned_pos + kv_len) - ) + seq_token_indices = list(range(current_aligned_pos, current_aligned_pos + kv_len)) # Apply alignment padding to this sequence aligned_seq_indices = align_to_size(seq_token_indices, page_size, 0) cache_loc_flat.extend(aligned_seq_indices) @@ -338,9 +334,7 @@ def run_test(self, mode, lens, mode_args): padding_size = 4096 cache_loc_list = [] - aligned_seq_lens = ( - (forward_batch.seq_lens + page_size - 1) // page_size - ) * page_size + aligned_seq_lens = ((forward_batch.seq_lens + page_size - 1) // page_size) 
* page_size cache_start_loc = jnp.concatenate( [jnp.zeros(1, dtype=jnp.int32), jnp.cumsum(aligned_seq_lens)] ) @@ -377,9 +371,7 @@ def jit_attn(q, k, v, forward_batch, token_to_kv_pool: KVCache): return out # run - jax_output, _ = jit_attn( - q_shard, extend_k, extend_v, forward_batch, token_to_kv_pool - ) + jax_output, _ = jit_attn(q_shard, extend_k, extend_v, forward_batch, token_to_kv_pool) jax.block_until_ready(jax_output) rtol = 2e-2 # Relative tolerance @@ -424,9 +416,7 @@ def jit_attn(q, k, v, forward_batch, token_to_kv_pool: KVCache): ) # Check how many tokens have large differences - large_diff_tokens = int( - np.sum(np.max(diff.reshape(num_tokens, -1), axis=1) > 0.1) - ) + large_diff_tokens = int(np.sum(np.max(diff.reshape(num_tokens, -1), axis=1) > 0.1)) print(f"Tokens with max diff > 0.1: {large_diff_tokens}/{num_tokens}") are_close = np.allclose( @@ -455,9 +445,7 @@ def test_mha_prefill_accuracy_page_size_1(self): (512, 1024), ] - self.run_test( - "prefill", lens, (num_heads, head_dim, num_kv_heads, 1, jnp.bfloat16) - ) + self.run_test("prefill", lens, (num_heads, head_dim, num_kv_heads, 1, jnp.bfloat16)) def test_mha_decode_accuracy_page_size_1(self): """Test JAX attention accuracy against native fa""" @@ -477,9 +465,7 @@ def test_mha_decode_accuracy_page_size_1(self): (1, 1025), ] - self.run_test( - "decode", lens, (num_heads, head_dim, num_kv_heads, 1, jnp.bfloat16) - ) + self.run_test("decode", lens, (num_heads, head_dim, num_kv_heads, 1, jnp.bfloat16)) def test_mha_prefill_accuracy_page_size_8(self): """ @@ -495,9 +481,7 @@ def test_mha_prefill_accuracy_page_size_8(self): (5, 33), (5, 5), ] - self.run_test( - "prefill", lens, (num_heads, head_dim, num_kv_heads, 8, jnp.bfloat16) - ) + self.run_test("prefill", lens, (num_heads, head_dim, num_kv_heads, 8, jnp.bfloat16)) def test_mha_decode_accuracy_page_size_8(self): """Test JAX attention accuracy against native fa""" @@ -510,9 +494,7 @@ def test_mha_decode_accuracy_page_size_8(self): (1, 6), (1, 5), ] - self.run_test( - "decode", lens, (num_heads, head_dim, num_kv_heads, 8, jnp.bfloat16) - ) + self.run_test("decode", lens, (num_heads, head_dim, num_kv_heads, 8, jnp.bfloat16)) def test_mha_prefill_accuracy_page_size_64(self): """Test JAX attention accuracy against PyTorch reference""" @@ -530,9 +512,7 @@ def test_mha_prefill_accuracy_page_size_64(self): (123, 522), (1, 511), ] - self.run_test( - "prefill", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16) - ) + self.run_test("prefill", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16)) def test_mha_decode_accuracy_page_size_64(self): """Test JAX attention accuracy against native fa""" @@ -551,9 +531,7 @@ def test_mha_decode_accuracy_page_size_64(self): (1, 1024), (1, 1025), ] - self.run_test( - "decode", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16) - ) + self.run_test("decode", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16)) def test_gqa_prefill_accuracy_page_size_64(self): """Test JAX attention accuracy against PyTorch reference""" @@ -571,9 +549,7 @@ def test_gqa_prefill_accuracy_page_size_64(self): (123, 522), (1, 511), ] - self.run_test( - "prefill", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16) - ) + self.run_test("prefill", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16)) def test_gqa_decode_accuracy_page_size_64(self): """Test JAX attention accuracy against native fa""" @@ -593,9 +569,7 @@ def test_gqa_decode_accuracy_page_size_64(self): (1, 1025), ] - self.run_test( - "decode", lens, 
(num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16) - ) + self.run_test("decode", lens, (num_heads, head_dim, num_kv_heads, 64, jnp.bfloat16)) if __name__ == "__main__": diff --git a/python/sgl_jax/test/test_jax_model_loader.py b/python/sgl_jax/test/test_jax_model_loader.py index b5f9369e6..20e098af9 100644 --- a/python/sgl_jax/test/test_jax_model_loader.py +++ b/python/sgl_jax/test/test_jax_model_loader.py @@ -49,17 +49,13 @@ def _print_pytree_structure(self, pytree, prefix="", max_depth=10, current_depth print(f"{prefix}dict ({len(pytree)} keys):") for key, value in pytree.items(): print(f"{prefix} {key}:") - self._print_pytree_structure( - value, prefix + " ", max_depth, current_depth + 1 - ) + self._print_pytree_structure(value, prefix + " ", max_depth, current_depth + 1) elif isinstance(pytree, (list, tuple)): type_name = "list" if isinstance(pytree, list) else "tuple" print(f"{prefix}{type_name} ({len(pytree)} items):") for i, item in enumerate(pytree): print(f"{prefix} [{i}]:") - self._print_pytree_structure( - item, prefix + " ", max_depth, current_depth + 1 - ) + self._print_pytree_structure(item, prefix + " ", max_depth, current_depth + 1) elif hasattr(pytree, "shape") and hasattr(pytree, "dtype"): shape = pytree.shape dtype = pytree.dtype @@ -72,11 +68,7 @@ def _print_pytree_structure(self, pytree, prefix="", max_depth=10, current_depth memory_bytes = pytree.nbytes else: # Estimate memory usage - dtype_size = ( - pytree.dtype.itemsize - if hasattr(pytree.dtype, "itemsize") - else 4 - ) + dtype_size = pytree.dtype.itemsize if hasattr(pytree.dtype, "itemsize") else 4 memory_bytes = size * dtype_size if memory_bytes >= 1024**3: # GB @@ -121,7 +113,9 @@ def _print_pytree_structure(self, pytree, prefix="", max_depth=10, current_depth # For 1D tensors, show head and tail 3 elements if shape[0] <= 6: preview_data = pytree[:] - data_preview = f"\n{prefix} All {shape[0]} elements:\n{prefix} {preview_data}" + data_preview = ( + f"\n{prefix} All {shape[0]} elements:\n{prefix} {preview_data}" + ) else: head_elements = pytree[:3] tail_elements = pytree[-3:] @@ -154,9 +148,7 @@ class TestJAXModelLoader(CustomTestCase): """Test cases for JAXModelLoader""" def setUp(self): - self.mesh = create_device_mesh( - ici_parallelism=[-1, 1, 1], dcn_parallelism=[1, 1, 1] - ) + self.mesh = create_device_mesh(ici_parallelism=[-1, 1, 1], dcn_parallelism=[1, 1, 1]) self.test_model_path = os.environ.get("MODEL_PATH", "/tmp/test_jax_model") self.load_config = LoadConfig(load_format=LoadFormat.JAX) @@ -185,9 +177,7 @@ def test_prepare_jax_weights_local_path(self): """Test preparing JAX weights from local path""" loader = JAXModelLoader(self.load_config) - hf_folder, hf_weights_files = loader._prepare_jax_weights( - self.mock_model_path, None - ) + hf_folder, hf_weights_files = loader._prepare_jax_weights(self.mock_model_path, None) self.assertEqual(hf_folder, self.mock_model_path) self.assertEqual(len(hf_weights_files), 1) @@ -210,21 +200,15 @@ def test_load_model_with_real_path(self): if not os.path.exists(self.test_model_path): self.skipTest(f"Real model path {self.test_model_path} not found") - msgpack_files = [ - f for f in os.listdir(self.test_model_path) if f.endswith(".msgpack") - ] + msgpack_files = [f for f in os.listdir(self.test_model_path) if f.endswith(".msgpack")] if not msgpack_files: self.skipTest(f"No .msgpack files found in {self.test_model_path}") - model_config = ModelConfig( - model_path=self.test_model_path, model_override_args="{}" - ) + model_config = 
ModelConfig(model_path=self.test_model_path, model_override_args="{}") loader = JAXModelLoader(self.load_config) - with patch( - "sgl_jax.srt.model_loader.loader.get_model_architecture" - ) as mock_arch: + with patch("sgl_jax.srt.model_loader.loader.get_model_architecture") as mock_arch: mock_arch.return_value = (MockJAXModel, None) model = loader.load_model( model_config=model_config, diff --git a/python/sgl_jax/test/test_model_loader.py b/python/sgl_jax/test/test_model_loader.py index abc788eac..c266a00a8 100644 --- a/python/sgl_jax/test/test_model_loader.py +++ b/python/sgl_jax/test/test_model_loader.py @@ -51,9 +51,7 @@ def setUp(self): self.temp_dir = tempfile.mkdtemp() # Load config - self.load_config = LoadConfig( - load_format=LoadFormat.JAX, download_dir=self.temp_dir - ) + self.load_config = LoadConfig(load_format=LoadFormat.JAX, download_dir=self.temp_dir) def test_jax_model_loader_init(self): """Test JAXModelLoader initialization.""" @@ -90,13 +88,9 @@ def test_multi_device_environment_setup(self): print(f" Device details: {[str(d) for d in devices[:8]]}") # Verify we have simulated multiple devices - if "--xla_force_host_platform_device_count=8" in os.environ.get( - "XLA_FLAGS", "" - ): + if "--xla_force_host_platform_device_count=8" in os.environ.get("XLA_FLAGS", ""): print("PASS: Multi-device simulation properly configured") - self.assertGreaterEqual( - len(devices), 2, "Should have at least 2 simulated devices" - ) + self.assertGreaterEqual(len(devices), 2, "Should have at least 2 simulated devices") else: print("WARNING: Multi-device simulation not configured") @@ -116,9 +110,7 @@ def test_multi_device_environment_setup(self): # Verify all devices are CPU (as expected in test environment) for device in devices: - self.assertEqual( - device.platform, "cpu", f"Expected CPU device, got {device.platform}" - ) + self.assertEqual(device.platform, "cpu", f"Expected CPU device, got {device.platform}") print("PASS: Multi-device environment validation completed!") @@ -233,9 +225,7 @@ def setUp(self): # Single device fallback self.mesh = Mesh(devices, ("tensor",)) print(f" Using single-device mesh: {self.mesh}") - print( - f"[MESH CHECK] mesh shape: {self.mesh.shape}, mesh devices: {self.mesh.devices}" - ) + print(f"[MESH CHECK] mesh shape: {self.mesh.shape}, mesh devices: {self.mesh.devices}") # Initialize RNG self.rng = nnx.Rngs(42) @@ -274,9 +264,7 @@ def test_qwen_model_instantiation(self): ) # Create QWen model instance - model = QWenLMHeadModel( - model_config, model_config.dtype, self.rng, self.mesh - ) + model = QWenLMHeadModel(model_config, model_config.dtype, self.rng, self.mesh) self.assertIsInstance(model, QWenLMHeadModel) self.assertEqual(model.config, model_config) @@ -295,9 +283,7 @@ def test_safetensor_files_detection(self): if file.endswith(".safetensors"): safetensor_files.append(file) - self.assertGreater( - len(safetensor_files), 0, "No safetensor files found in model directory" - ) + self.assertGreater(len(safetensor_files), 0, "No safetensor files found in model directory") print(f"PASS: Found {len(safetensor_files)} safetensor files:") for f in safetensor_files[:5]: # Show first 5 files @@ -311,9 +297,7 @@ def test_weight_loading_process(self): ) # Create QWen model instance - model = QWenLMHeadModel( - model_config, model_config.dtype, self.rng, self.mesh - ) + model = QWenLMHeadModel(model_config, model_config.dtype, self.rng, self.mesh) # Print the actual parameter structure of the model try: @@ -372,9 +356,7 @@ def _print_param_structure(self, params, 
prefix="", max_depth=2, current_depth=0 print(f" {current_prefix}: {value.value.shape}") elif isinstance(value, dict): print(f" {current_prefix}/ (dict)") - self._print_param_structure( - value, current_prefix, max_depth, current_depth + 1 - ) + self._print_param_structure(value, current_prefix, max_depth, current_depth + 1) else: print(f" {current_prefix}: {type(value)}") @@ -392,9 +374,7 @@ def test_model_actual_structure_debug(self): print(f" Attention heads: {model_config.num_attention_heads}") # Create QWen model instance - model = QWenLMHeadModel( - model_config, model_config.dtype, self.rng, self.mesh - ) + model = QWenLMHeadModel(model_config, model_config.dtype, self.rng, self.mesh) # Check what attributes the model actually has print(" Model Attributes:") @@ -450,9 +430,7 @@ def print_sharding(params, prefix=""): else: if hasattr(v, "sharding"): print(f"[SHARDING] {prefix}: sharding={v.sharding}") - for i, shard in enumerate( - getattr(v, "addressable_shards", []) - ): + for i, shard in enumerate(getattr(v, "addressable_shards", [])): print( f" [SHARD] idx={i}, device={shard.device}, index={getattr(shard, 'index', None)}, shape={getattr(shard.data, 'shape', None)}" ) @@ -493,9 +471,7 @@ def test_model_parameter_structure_validation(self): ) # Create QWen model instance - model = QWenLMHeadModel( - model_config, model_config.dtype, self.rng, self.mesh - ) + model = QWenLMHeadModel(model_config, model_config.dtype, self.rng, self.mesh) print(" Validating Model Parameter Structure:") print(f" Model type: {type(model).__name__}") @@ -543,9 +519,7 @@ def test_model_parameter_structure_validation(self): traceback.print_exc() - def _print_param_structure_detailed( - self, params, prefix="", max_depth=3, current_depth=0 - ): + def _print_param_structure_detailed(self, params, prefix="", max_depth=3, current_depth=0): """Helper function to print detailed parameter structure.""" if current_depth >= max_depth: return @@ -555,14 +529,10 @@ def _print_param_structure_detailed( current_prefix = f"{prefix}.{key}" if prefix else key if hasattr(value, "value") and hasattr(value.value, "shape"): # This is a parameter with shape - print( - f" {current_prefix}: {value.value.shape} ({type(value).__name__})" - ) + print(f" {current_prefix}: {value.value.shape} ({type(value).__name__})") elif isinstance(value, dict): print(f" {current_prefix}/ (dict with {len(value)} keys)") - if ( - current_depth < max_depth - 1 - ): # Only recurse if not at max depth + if current_depth < max_depth - 1: # Only recurse if not at max depth self._print_param_structure_detailed( value, current_prefix, max_depth, current_depth + 1 ) @@ -572,9 +542,7 @@ def _print_param_structure_detailed( for i, item in enumerate(value[:3]): # Show first 3 items item_prefix = f"{current_prefix}.{i}" if hasattr(item, "value") and hasattr(item.value, "shape"): - print( - f" {item_prefix}: {item.value.shape} ({type(item).__name__})" - ) + print(f" {item_prefix}: {item.value.shape} ({type(item).__name__})") elif isinstance(item, dict): print(f" {item_prefix}/ (dict)") if current_depth < max_depth - 1: @@ -603,9 +571,7 @@ def test_multi_device_tensor_parallelism(self): model_path=self.model_path, trust_remote_code=True, dtype="bfloat16" ) - print( - f"🔄 Testing multi-device tensor parallelism with {len(devices[:4])} devices..." 
- ) + print(f"🔄 Testing multi-device tensor parallelism with {len(devices[:4])} devices...") # Create loader with multi-device mesh loader = get_model_loader(self.load_config, self.rng, self.mesh) @@ -613,9 +579,7 @@ def test_multi_device_tensor_parallelism(self): # Load model model = loader.load_model(model_config=model_config) - print( - f"PASS: Model loaded successfully on {len(self.mesh.devices)} devices" - ) + print(f"PASS: Model loaded successfully on {len(self.mesh.devices)} devices") # Check if weights are properly sharded state = nnx.state(model) @@ -643,17 +607,13 @@ def test_multi_device_tensor_parallelism(self): ) # Verify sharding is actually distributed - unique_devices = set( - shard.device for shard in weight.addressable_shards - ) + unique_devices = set(shard.device for shard in weight.addressable_shards) if len(unique_devices) > 1: print( f" PASS: Weight is distributed across {len(unique_devices)} devices" ) else: - print( - f" WARNING: Weight is on single device: {unique_devices}" - ) + print(f" WARNING: Weight is on single device: {unique_devices}") except Exception as e: print(f" ERROR: Could not check {param_path}: {e}") @@ -671,9 +631,7 @@ def test_multi_device_tensor_parallelism(self): keyword in str(e).lower() for keyword in ["weight", "tensor", "shape", "mapping", "jit"] ): - print( - "PASS: Test passed: Multi-device setup reached weight loading stage" - ) + print("PASS: Test passed: Multi-device setup reached weight loading stage") else: self.fail(f"Unexpected error in multi-device test: {e}") @@ -681,9 +639,7 @@ def test_tensor_parallel_computation(self): """Test that computation works correctly with tensor parallelism.""" devices = jax.devices() if len(devices) < 2: - self.skipTest( - "Tensor parallel computation test requires at least 2 devices" - ) + self.skipTest("Tensor parallel computation test requires at least 2 devices") try: model_config = ModelConfig( @@ -704,9 +660,7 @@ def test_tensor_parallel_computation(self): # Test that we can access parameters and they're properly sharded state = nnx.state(model) - embed_param = self._get_param_by_path( - state, "transformer.embed_tokens.embedding" - ) + embed_param = self._get_param_by_path(state, "transformer.embed_tokens.embedding") if hasattr(embed_param, "value"): embed_weight = embed_param.value @@ -716,9 +670,7 @@ def test_tensor_parallel_computation(self): # Verify we can perform operations on sharded weights # Simple operation to test sharding works weight_sum = jnp.sum(embed_weight, axis=0) - print( - f" PASS: Successfully computed sum over sharded weight: {weight_sum.shape}" - ) + print(f" PASS: Successfully computed sum over sharded weight: {weight_sum.shape}") print(f" 📊 Sum result sharding: {weight_sum.sharding}") print("PASS: Tensor parallel computation test completed!") @@ -729,10 +681,7 @@ def test_tensor_parallel_computation(self): traceback.print_exc() - if any( - keyword in str(e).lower() - for keyword in ["weight", "tensor", "shape", "jit"] - ): + if any(keyword in str(e).lower() for keyword in ["weight", "tensor", "shape", "jit"]): print("PASS: Test passed: Computation test reached expected stage") else: self.fail(f"Unexpected error in computation test: {e}") diff --git a/python/sgl_jax/test/test_multi_process_model_loader.py b/python/sgl_jax/test/test_multi_process_model_loader.py index a635d0a0d..6fe7480ab 100644 --- a/python/sgl_jax/test/test_multi_process_model_loader.py +++ b/python/sgl_jax/test/test_multi_process_model_loader.py @@ -53,9 +53,7 @@ def main(): ) model_path = 
os.environ.get("TEST_MODEL_PATH", "./test_models/your_model_dir") - model_config = ModelConfig( - model_path=model_path, trust_remote_code=True, dtype="bfloat16" - ) + model_config = ModelConfig(model_path=model_path, trust_remote_code=True, dtype="bfloat16") load_config = LoadConfig(load_format=LoadFormat.JAX) rng = nnx.Rngs(42) diff --git a/python/sgl_jax/test/test_multi_process_radix_cache.py b/python/sgl_jax/test/test_multi_process_radix_cache.py index 2223b7962..d240459d3 100644 --- a/python/sgl_jax/test/test_multi_process_radix_cache.py +++ b/python/sgl_jax/test/test_multi_process_radix_cache.py @@ -24,9 +24,7 @@ def print_cache_sharding_info(cache, mesh, req_pool, allocator, process_id): print(f"[PROCESS {process_id}] Local device count: {len(jax.local_devices())}") print(f"[PROCESS {process_id}] Global device count: {len(jax.devices())}") print(f"[PROCESS {process_id}] Mesh axes: {mesh.axis_names}") - print( - f"[PROCESS {process_id}] Local mesh device layout: {mesh.local_mesh.devices.shape}" - ) + print(f"[PROCESS {process_id}] Local mesh device layout: {mesh.local_mesh.devices.shape}") print(f"[PROCESS {process_id}] Global mesh device layout: {mesh.devices.shape}") print(f"[PROCESS {process_id}] Local mesh: {mesh.local_mesh}") print(f"[PROCESS {process_id}] Global mesh: {mesh}") @@ -36,9 +34,7 @@ def print_sharding(obj, name, prefix=""): full_name = f"{prefix}.{name}" if prefix else name if hasattr(obj, "sharding") and obj.sharding is not None: - print( - f"[PROCESS {process_id}] [SHARDING] {full_name}: sharding={obj.sharding}" - ) + print(f"[PROCESS {process_id}] [SHARDING] {full_name}: sharding={obj.sharding}") if hasattr(obj, "addressable_shards"): for i, shard in enumerate(obj.addressable_shards): print( @@ -49,9 +45,7 @@ def print_sharding(obj, name, prefix=""): f"[PROCESS {process_id}] [SHARDING] {full_name}: Unsharded, shape={obj.shape}, device={getattr(obj, 'device', 'unknown')}" ) else: - print( - f"[PROCESS {process_id}] [SHARDING] {full_name}: Non-JAX array, type={type(obj)}" - ) + print(f"[PROCESS {process_id}] [SHARDING] {full_name}: Non-JAX array, type={type(obj)}") # Print RadixCache's sharding information if hasattr(cache, "kv_cache_sharding"): @@ -59,9 +53,7 @@ def print_sharding(obj, name, prefix=""): f"[PROCESS {process_id}] [CACHE] KV cache sharding strategy: {cache.kv_cache_sharding}" ) if hasattr(cache, "token_sharding"): - print( - f"[PROCESS {process_id}] [CACHE] Token sharding strategy: {cache.token_sharding}" - ) + print(f"[PROCESS {process_id}] [CACHE] Token sharding strategy: {cache.token_sharding}") # Print ReqToTokenPool's sharding information print_sharding(req_pool.req_to_token, "req_to_token_pool.req_to_token") @@ -114,9 +106,7 @@ def create_multi_process_radix_cache(process_id, tp_size=8): import numpy as np - devices_reshaped = np.array(local_devices).reshape( - local_data_size, local_tensor_size - ) + devices_reshaped = np.array(local_devices).reshape(local_data_size, local_tensor_size) mesh = Mesh(devices_reshaped, ("data", "tensor")) else: # If local devices are insufficient, use single axis @@ -181,15 +171,11 @@ def test_basic_radix_cache_operations(cache, process_id): # Test matching match_result = cache.match_prefix(key) - print( - f"[PROCESS {process_id}] Match result length: {len(match_result.device_indices)}" - ) + print(f"[PROCESS {process_id}] Match result length: {len(match_result.device_indices)}") # Test getting KV data kv_data, matched_len = cache.get_cached_kv(key) - print( - f"[PROCESS {process_id}] KV data shape: 
{kv_data.shape}, match length: {matched_len}" - ) + print(f"[PROCESS {process_id}] KV data shape: {kv_data.shape}, match length: {matched_len}") # Test cache size print(f"[PROCESS {process_id}] Cache status:") @@ -214,12 +200,7 @@ def test_memory_usage(cache, process_id, pool_size_per_device): # Theoretical KV cache size per device theoretical_kv_size = ( - pool_size_per_device - * kv_head_num - * head_dim - * 2 - * layer_num - * bytes_per_element + pool_size_per_device * kv_head_num * head_dim * 2 * layer_num * bytes_per_element ) # 2 for K and V theoretical_kv_size_gb = theoretical_kv_size / (1024**3) @@ -256,9 +237,7 @@ def test_cross_process_isolation(cache, process_id): print(f"[PROCESS {process_id}] Inserting key{i + 1}: {key}") cache.insert(key) match_result = cache.match_prefix(key) - print( - f"[PROCESS {process_id}] Key{i + 1} match result: {len(match_result.device_indices)}" - ) + print(f"[PROCESS {process_id}] Key{i + 1} match result: {len(match_result.device_indices)}") # Test cache status print(f"[PROCESS {process_id}] Process-specific cache status:") @@ -277,12 +256,8 @@ def main(): coordinator_address = os.environ.get("COORDINATOR_ADDRESS", "localhost:12345") tp_size = int(os.environ.get("TP_SIZE", "8")) - print( - f"[PROCESS {process_id}] Starting multi-process environment initialization..." - ) - print( - f"[PROCESS {process_id}] Process ID: {process_id}, total processes: {num_processes}" - ) + print(f"[PROCESS {process_id}] Starting multi-process environment initialization...") + print(f"[PROCESS {process_id}] Process ID: {process_id}, total processes: {num_processes}") print(f"[PROCESS {process_id}] Coordinator address: {coordinator_address}") print(f"[PROCESS {process_id}] TP size: {tp_size}") @@ -306,9 +281,7 @@ def main(): try: # Create multi-process RadixCache - cache, mesh, req_pool, allocator = create_multi_process_radix_cache( - process_id, tp_size - ) + cache, mesh, req_pool, allocator = create_multi_process_radix_cache(process_id, tp_size) # Print sharding information print_cache_sharding_info(cache, mesh, req_pool, allocator, process_id) diff --git a/python/sgl_jax/test/test_utils.py b/python/sgl_jax/test/test_utils.py index a3ac4ed28..b22dcf54c 100644 --- a/python/sgl_jax/test/test_utils.py +++ b/python/sgl_jax/test/test_utils.py @@ -83,9 +83,7 @@ def create_device_mesh( return mesh -def fill_unspecified_parallelism( - parallelism: Sequence[int], num_devices: int -) -> Sequence[int]: +def fill_unspecified_parallelism(parallelism: Sequence[int], num_devices: int) -> Sequence[int]: if -1 not in parallelism: return parallelism @@ -126,9 +124,7 @@ def jax_trace_context(log_dir: str): class CustomTestCase(unittest.TestCase): def _callTestMethod(self, method): - max_retry = int( - os.environ.get("SGLANG_TEST_MAX_RETRY", "1" if is_in_ci() else "0") - ) + max_retry = int(os.environ.get("SGLANG_TEST_MAX_RETRY", "1" if is_in_ci() else "0")) retry( lambda: super(CustomTestCase, self)._callTestMethod(method), max_retry=max_retry, @@ -562,9 +558,7 @@ def run_bench_one_batch(model, other_args): # Return prefill_latency, decode_throughput, decode_latency prefill_line = output.split("\n")[-9] decode_line = output.split("\n")[-3] - pattern = ( - r"latency: (?P\d+\.\d+).*?throughput:\s*(?P\d+\.\d+)" - ) + pattern = r"latency: (?P\d+\.\d+).*?throughput:\s*(?P\d+\.\d+)" match = re.search(pattern, prefill_line) if match: prefill_latency = float(match.group("latency")) @@ -714,9 +708,7 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2): precision = lcs_len / len(s1) if 
len(s1) > 0 else 0 recall = lcs_len / len(s2) if len(s2) > 0 else 0 fmeasure = ( - (2 * precision * recall) / (precision + recall) - if precision + recall > 0 - else 0.0 + (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0.0 ) rouge_l_scores.append(fmeasure) diff --git a/python/sgl_jax/tools/trace_diff.py b/python/sgl_jax/tools/trace_diff.py index a6bcd8624..6be5d451d 100755 --- a/python/sgl_jax/tools/trace_diff.py +++ b/python/sgl_jax/tools/trace_diff.py @@ -106,9 +106,7 @@ def compare_token_groups( ) if len(tokens1) != len(tokens2): - differences.append( - f" {category}: Token count mismatch: {len(tokens1)} vs {len(tokens2)}" - ) + differences.append(f" {category}: Token count mismatch: {len(tokens1)} vs {len(tokens2)}") all_match = False # Continue comparing up to the shorter length min_length = min(len(tokens1), len(tokens2)) @@ -268,9 +266,7 @@ def group_records(records): all_match = False else: # Compare token-level stats - for ts_idx, (ts1, ts2) in enumerate( - zip(token_stats1, token_stats2) - ): + for ts_idx, (ts1, ts2) in enumerate(zip(token_stats1, token_stats2)): for ts_field in ["min", "max", "mean", "std", "value"]: ts_val1, ts_val2 = ts1.get(ts_field), ts2.get(ts_field) if ts_val1 is not None and ts_val2 is not None: @@ -381,9 +377,7 @@ def print_tree_differences(differences: list[str]): category_part = path.split("[")[0].strip() rest = path.split("]", 1)[1].strip() if "]" in path else "" token_part = ( - path.split("[")[1].split("]")[0] - if "[" in path and "]" in path - else "0" + path.split("[")[1].split("]")[0] if "[" in path and "]" in path else "0" ) if category_part not in tree: @@ -418,9 +412,7 @@ def display_category_sort_key(category): print(f"\n{Colors.BOLD}{category.upper()}:{Colors.RESET}") tokens = tree[category] - for token_idx in sorted( - tokens.keys(), key=lambda x: int(x) if x.isdigit() else 999 - ): + for token_idx in sorted(tokens.keys(), key=lambda x: int(x) if x.isdigit() else 999): token_diffs = tokens[token_idx] if len(token_diffs) > 0: print(f" Token[{token_idx}]:") @@ -527,9 +519,7 @@ def parse_layer_module(layer_key): total_items = len(mismatches) + len(matches) + len(others) shown_items = ( - min(3, len(mismatches)) - + min(3, len(matches)) - + min(2, len(others)) + min(3, len(mismatches)) + min(3, len(matches)) + min(2, len(others)) ) if total_items > shown_items: print( diff --git a/test/srt/openai_server/basic/test_openai_server.py b/test/srt/openai_server/basic/test_openai_server.py index 4c20c51a1..da3bddb1d 100644 --- a/test/srt/openai_server/basic/test_openai_server.py +++ b/test/srt/openai_server/basic/test_openai_server.py @@ -63,17 +63,13 @@ def setUpClass(cls): }, ) cls.base_url += "/v1" - cls.tokenizer = get_tokenizer( - DEFAULT_SMALL_MODEL_NAME_FOR_TEST, trust_remote_code=True - ) + cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST, trust_remote_code=True) @classmethod def tearDownClass(cls): kill_process_tree(cls.process.pid) - def run_completion( - self, echo, logprobs, use_list_input, parallel_sample_num, token_input - ): + def run_completion(self, echo, logprobs, use_list_input, parallel_sample_num, token_input): client = openai.Client(api_key=self.api_key, base_url=self.base_url) prompt = "The capital of France is" if token_input: @@ -182,9 +178,7 @@ def run_completion_stream( assert isinstance( response.choices[0].logprobs.top_logprobs[0], dict ), f"top_logprobs was not a dictionary" - ret_num_top_logprobs = len( - response.choices[0].logprobs.top_logprobs[0] - ) + 
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[0]) # FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0" @@ -199,9 +193,7 @@ def run_completion_stream( assert response.created, f"no created in response" for index in [i for i in range(parallel_sample_num * num_choices)]: - assert not is_firsts.get( - index, True - ), f"index {index} is not found in the response" + assert not is_firsts.get(index, True), f"index {index} is not found in the response" def run_chat_completion(self, logprobs, parallel_sample_num): client = openai.Client(api_key=self.api_key, base_url=self.base_url) @@ -221,16 +213,10 @@ def run_chat_completion(self, logprobs, parallel_sample_num): ) if logprobs: - assert isinstance( - response.choices[0].logprobs.content[0].top_logprobs[0].token, str - ) + assert isinstance(response.choices[0].logprobs.content[0].top_logprobs[0].token, str) - ret_num_top_logprobs = len( - response.choices[0].logprobs.content[0].top_logprobs - ) - assert ( - ret_num_top_logprobs == logprobs - ), f"{ret_num_top_logprobs} vs {logprobs}" + ret_num_top_logprobs = len(response.choices[0].logprobs.content[0].top_logprobs) + assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" assert len(response.choices) == parallel_sample_num assert response.choices[0].message.role == "assistant" @@ -277,9 +263,7 @@ def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): data = response.choices[0].delta if is_firsts.get(index, True): - assert ( - data.role == "assistant" - ), f"data.role was not 'assistant' for first chunk" + assert data.role == "assistant", f"data.role was not 'assistant' for first chunk" is_firsts[index] = False continue @@ -291,12 +275,8 @@ def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): assert isinstance( response.choices[0].logprobs.content[0].top_logprobs, list ), f"top_logprobs was not a list" - ret_num_top_logprobs = len( - response.choices[0].logprobs.content[0].top_logprobs - ) - assert ( - ret_num_top_logprobs == logprobs - ), f"{ret_num_top_logprobs} vs {logprobs}" + ret_num_top_logprobs = len(response.choices[0].logprobs.content[0].top_logprobs) + assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}" assert ( isinstance(data.content, str) @@ -308,15 +288,11 @@ def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): assert response.created for index in [i for i in range(parallel_sample_num)]: - assert not is_firsts.get( - index, True - ), f"index {index} is not found in the response" + assert not is_firsts.get(index, True), f"index {index} is not found in the response" # Verify that each choice gets exactly one finish_reason chunk for index in range(parallel_sample_num): - assert ( - index in finish_reason_counts - ), f"No finish_reason found for index {index}" + assert index in finish_reason_counts, f"No finish_reason found for index {index}" assert ( finish_reason_counts[index] == 1 ), f"Expected 1 finish_reason chunk for index {index}, got {finish_reason_counts[index]}" diff --git a/test/srt/openai_server/basic/test_protocol.py b/test/srt/openai_server/basic/test_protocol.py index a0bff3523..e539591cb 100644 --- a/test/srt/openai_server/basic/test_protocol.py +++ b/test/srt/openai_server/basic/test_protocol.py @@ 
-169,9 +169,7 @@ def test_chat_completion_tool_choice_validation(self): "function": {"name": "test_func", "description": "Test function"}, } ] - request2 = ChatCompletionRequest( - model="test-model", messages=messages, tools=tools - ) + request2 = ChatCompletionRequest(model="test-model", messages=messages, tools=tools) self.assertEqual(request2.tool_choice, "auto") def test_chat_completion_sglang_extensions(self): @@ -245,9 +243,7 @@ def test_invalid_tool_choice_type(self): """Test invalid tool choice type""" messages = [{"role": "user", "content": "Hello"}] with self.assertRaises(ValidationError): - ChatCompletionRequest( - model="test-model", messages=messages, tool_choice=123 - ) + ChatCompletionRequest(model="test-model", messages=messages, tool_choice=123) def test_negative_token_limits(self): """Test negative token limits""" diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py index 0bf57fedc..b29891c3a 100644 --- a/test/srt/openai_server/basic/test_serving_chat.py +++ b/test/srt/openai_server/basic/test_serving_chat.py @@ -98,9 +98,7 @@ def setUp(self): # ------------- conversion tests ------------- def test_convert_to_internal_request_single(self): with ( - patch( - "sgl_jax.srt.entrypoints.openai.serving_chat.generate_chat_conv" - ) as conv_mock, + patch("sgl_jax.srt.entrypoints.openai.serving_chat.generate_chat_conv") as conv_mock, patch.object(self.chat, "_process_messages") as proc_mock, ): conv_ins = Mock() @@ -134,18 +132,14 @@ def test_stop_str_isolation_between_requests(self): # Mock conversation template with initial stop_str initial_stop_str = ["\n"] - with patch( - "sgl_jax.srt.entrypoints.openai.serving_chat.generate_chat_conv" - ) as conv_mock: + with patch("sgl_jax.srt.entrypoints.openai.serving_chat.generate_chat_conv") as conv_mock: # Create a mock conversation object that will be returned by generate_chat_conv conv_ins = Mock() conv_ins.get_prompt.return_value = "Test prompt" conv_ins.image_data = None conv_ins.audio_data = None conv_ins.modalities = [] - conv_ins.stop_str = ( - initial_stop_str.copy() - ) # Template's default stop strings + conv_ins.stop_str = initial_stop_str.copy() # Template's default stop strings conv_mock.return_value = conv_ins # First request with additional stop string @@ -243,9 +237,7 @@ async def test_unstreamed_tool_args_completion(self): # Should return a chunk with remaining arguments self.assertIsNotNone(result, "Should return chunk with remaining arguments") self.assertIn('"arguments":', result, "Should contain arguments field") - self.assertIn( - ', "unit": "celsius"}', result, "Should contain remaining arguments" - ) + self.assertIn(', "unit": "celsius"}', result, "Should contain remaining arguments") self.assertIn( '"finish_reason":null', result, @@ -324,9 +316,7 @@ async def test_unstreamed_tool_args_no_parser_data(self): ) # Should return None since there's no parser data - self.assertIsNone( - result, "Should return None when parser has no tool call data" - ) + self.assertIsNone(result, "Should return None when parser has no tool call data") if __name__ == "__main__": diff --git a/test/srt/openai_server/basic/test_serving_completions.py b/test/srt/openai_server/basic/test_serving_completions.py index abfa79896..4f9896717 100644 --- a/test/srt/openai_server/basic/test_serving_completions.py +++ b/test/srt/openai_server/basic/test_serving_completions.py @@ -63,9 +63,7 @@ def test_echo_with_string_prompt_streaming(self): self.assertEqual(self.sc._get_echo_text(req, 
0), "Hello") def test_echo_with_list_of_strings_streaming(self): - req = CompletionRequest( - model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1 - ) + req = CompletionRequest(model="x", prompt=["A", "B"], max_tokens=1, echo=True, n=1) self.assertEqual(self.sc._get_echo_text(req, 0), "A") self.assertEqual(self.sc._get_echo_text(req, 1), "B") @@ -75,9 +73,7 @@ def test_echo_with_token_ids_streaming(self): self.assertEqual(self.sc._get_echo_text(req, 0), "decoded_prompt") def test_echo_with_multiple_token_ids_streaming(self): - req = CompletionRequest( - model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1 - ) + req = CompletionRequest(model="x", prompt=[[1, 2], [3, 4]], max_tokens=1, echo=True, n=1) self.sc.tokenizer_manager.tokenizer.decode.return_value = "decoded" self.assertEqual(self.sc._get_echo_text(req, 0), "decoded") diff --git a/test/srt/openai_server/validation/test_openai_server_params_validation.py b/test/srt/openai_server/validation/test_openai_server_params_validation.py index f420e0608..44d302279 100644 --- a/test/srt/openai_server/validation/test_openai_server_params_validation.py +++ b/test/srt/openai_server/validation/test_openai_server_params_validation.py @@ -103,9 +103,7 @@ def test_malformed_json_request(self): timeout=10, ) # return 400 rather than failed - self.assertEqual( - response.status_code, 400, f"Expected 400, got {response.status_code}" - ) + self.assertEqual(response.status_code, 400, f"Expected 400, got {response.status_code}") except Exception as e: self.fail(f"Server should handle malformed JSON gracefully, but got: {e}") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5d86a4253..7bd1bcce0 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -115,12 +115,8 @@ def run_one_file(filename): return process.returncode try: - ret_code = run_with_timeout( - run_one_file, args=(filename,), timeout=timeout_per_file - ) - assert ( - ret_code == 0 - ), f"expected return code 0, but {filename} returned {ret_code}" + ret_code = run_with_timeout(run_one_file, args=(filename,), timeout=timeout_per_file) + assert ret_code == 0, f"expected return code 0, but {filename} returned {ret_code}" except TimeoutError: kill_process_tree(process.pid) time.sleep(5) diff --git a/test/srt/test_features.py b/test/srt/test_features.py index 9e4bccdff..6025950ca 100644 --- a/test/srt/test_features.py +++ b/test/srt/test_features.py @@ -124,9 +124,7 @@ def test_abort_all(self): ) for future in as_completed(futures): - self.assertEqual( - future.result()["meta_info"]["finish_reason"]["type"], "abort" - ) + self.assertEqual(future.result()["meta_info"]["finish_reason"]["type"], "abort") def test_cache_miss_prefill(self): args = SimpleNamespace( @@ -424,12 +422,8 @@ def test_logprobs(self): continue for j, pair in enumerate(subitem): real_prob, real_token = pair[0], pair[1] - self.assertEqual( - real_prob, expected_input_token_ids_logprobs[i - 1][j][0] - ) - self.assertEqual( - real_token, expected_input_token_ids_logprobs[i - 1][j][1] - ) + self.assertEqual(real_prob, expected_input_token_ids_logprobs[i - 1][j][0]) + self.assertEqual(real_token, expected_input_token_ids_logprobs[i - 1][j][1]) expected_output_token_ids_logprobs = [ [ @@ -522,9 +516,7 @@ def test_logprobs(self): for j, pair in enumerate(subitem): real_prob, real_token = pair[0], pair[1] self.assertEqual(real_prob, expected_output_token_ids_logprobs[i][j][0]) - self.assertEqual( - real_token, expected_output_token_ids_logprobs[i][j][1] - ) + self.assertEqual(real_token, 
expected_output_token_ids_logprobs[i][j][1]) def test_frequency_penalty(self): """Test frequency penalty functionality.""" diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 411fe00cf..2972406ec 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -54,8 +54,7 @@ def tokenize(self, input_string: str) -> List[int]: ) eos_tok = ( [tokenizer.eos_token_id] - if tokenizer.eos_token_id is not None - and input_ids[-1] != tokenizer.eos_token_id + if tokenizer.eos_token_id is not None and input_ids[-1] != tokenizer.eos_token_id else [] ) return bos_tok + input_ids + eos_tok From 7be0ab96d82d5369973d487405844a355e789b15 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 17 Oct 2025 04:35:03 +0000 Subject: [PATCH 10/18] Fix --- python/pyproject.toml | 4 ++-- python/sgl_jax/tools/trace_diff.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index a490e9485..914685bf5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -91,10 +91,10 @@ select = [ "G", ] ignore = [ + # line too long handled by Black + "E501", # star imports "F405", "F403", - # line too long - "E501", # lambda expression assignment "E731", # zip without `strict=` diff --git a/python/sgl_jax/tools/trace_diff.py b/python/sgl_jax/tools/trace_diff.py index 6be5d451d..5c9cb329f 100755 --- a/python/sgl_jax/tools/trace_diff.py +++ b/python/sgl_jax/tools/trace_diff.py @@ -523,7 +523,9 @@ def parse_layer_module(layer_key): ) if total_items > shown_items: print( - f" {Colors.YELLOW}... and {total_items - shown_items} more{Colors.RESET}" + " " + f"{Colors.YELLOW}... and {total_items - shown_items} more" + f"{Colors.RESET}" ) # Print root-level differences @@ -628,7 +630,8 @@ def compare_trace_files( traces = groups1[content_hash] trace = traces[0] print( - f" {Colors.YELLOW}-{Colors.RESET} {content_hash} (Request ID: {trace.get('request_id', 'N/A')})" + f" {Colors.YELLOW}-{Colors.RESET} {content_hash} " + f"(Request ID: {trace.get('request_id', 'N/A')})" ) if only_in_2: @@ -637,7 +640,8 @@ def compare_trace_files( traces = groups2[content_hash] trace = traces[0] print( - f" {Colors.YELLOW}+{Colors.RESET} {content_hash} (Request ID: {trace.get('request_id', 'N/A')})" + f" {Colors.YELLOW}+{Colors.RESET} {content_hash} " + f"(Request ID: {trace.get('request_id', 'N/A')})" ) # Summary From b3099668686027c0b22080b5571dcb4f9506b2d0 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 17 Oct 2025 04:52:59 +0000 Subject: [PATCH 11/18] Fix var --- python/sgl_jax/srt/entrypoints/openai/serving_chat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sgl_jax/srt/entrypoints/openai/serving_chat.py b/python/sgl_jax/srt/entrypoints/openai/serving_chat.py index b626175c6..6af84f3e8 100644 --- a/python/sgl_jax/srt/entrypoints/openai/serving_chat.py +++ b/python/sgl_jax/srt/entrypoints/openai/serving_chat.py @@ -167,6 +167,7 @@ def _apply_jinja_template( openai_compatible_messages.append(processed_msg) # Handle assistant prefix for continue_final_message + assistant_prefix = None if ( openai_compatible_messages and openai_compatible_messages[-1]["role"] == "assistant" From 8e268d15f615068e15164e22504683befd83c1e0 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sat, 18 Oct 2025 01:48:15 +0000 Subject: [PATCH 12/18] Fix --- python/sgl_jax/srt/layers/binary_search.py | 10 ++---- python/sgl_jax/srt/layers/embeddings.py | 4 +-- python/sgl_jax/srt/layers/layernorm.py | 31 +++++++++---------- python/sgl_jax/srt/layers/linear.py | 
6 ++-- .../sgl_jax/srt/managers/tokenizer_manager.py | 9 ++---- python/sgl_jax/srt/model_loader/loader.py | 7 +---- python/sgl_jax/srt/models/llama.py | 6 +--- python/sgl_jax/srt/utils/jax_utils.py | 10 ++---- python/sgl_jax/srt/utils/weight_utils.py | 2 +- 9 files changed, 30 insertions(+), 55 deletions(-) diff --git a/python/sgl_jax/srt/layers/binary_search.py b/python/sgl_jax/srt/layers/binary_search.py index 8806d853c..b349e7b39 100644 --- a/python/sgl_jax/srt/layers/binary_search.py +++ b/python/sgl_jax/srt/layers/binary_search.py @@ -20,16 +20,14 @@ distributions. """ -from typing import Callable, Sequence +from collections.abc import Callable, Sequence import jax from jax import lax from jax import numpy as jnp -def int32_bsearch( - batch_shape: Sequence[int], predicate: Callable[[jnp.ndarray], jnp.ndarray] -): +def int32_bsearch(batch_shape: Sequence[int], predicate: Callable[[jnp.ndarray], jnp.ndarray]): """Batched binary search over int32 values. For each element of the batch, search for the largest int32 (closest to @@ -54,9 +52,7 @@ def int32_bsearch( # bits. we use uint32 due to numpy promotion/casting rules. midpoint = current_bits predicate_satisfied = predicate(midpoint) - current_bits = current_bits | jnp.where( - predicate_satisfied, jnp.uint32(1 << 31), jnp.uint32(0) - ) + current_bits = current_bits | jnp.where(predicate_satisfied, jnp.uint32(1 << 31), jnp.uint32(0)) del midpoint, predicate_satisfied def loop_body(i, current_bits): diff --git a/python/sgl_jax/srt/layers/embeddings.py b/python/sgl_jax/srt/layers/embeddings.py index b7cb374ce..2ab9fcd51 100644 --- a/python/sgl_jax/srt/layers/embeddings.py +++ b/python/sgl_jax/srt/layers/embeddings.py @@ -190,9 +190,7 @@ def __init__( self.is_neox_style = is_neox_style self.dtype = dtype - inv_freq_np = 1.0 / ( - base ** (np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim) - ) + inv_freq_np = 1.0 / (base ** (np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim)) self._inv_freq_np = inv_freq_np # shape: (rotary_dim // 2,) def __call__( diff --git a/python/sgl_jax/srt/layers/layernorm.py b/python/sgl_jax/srt/layers/layernorm.py index 3ff067b31..2db3b5332 100644 --- a/python/sgl_jax/srt/layers/layernorm.py +++ b/python/sgl_jax/srt/layers/layernorm.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, Optional, Tuple +from collections.abc import Iterable +from typing import Any import jax import jax.numpy as jnp @@ -9,7 +10,7 @@ from jax import lax -def _canonicalize_axes(rank: int, axes: Axes) -> Tuple[int, ...]: +def _canonicalize_axes(rank: int, axes: Axes) -> tuple[int, ...]: """Returns a tuple of deduplicated, sorted, and positive axes.""" if not isinstance(axes, Iterable): axes = (axes,) @@ -30,13 +31,13 @@ def __init__( num_features: int, *, epsilon: float = 1e-6, - dtype: Optional[Dtype] = None, + dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, use_scale: bool = True, scale_init: Initializer = initializers.ones, reduction_axes: Axes = -1, feature_axes: Axes = -1, - axis_name: Optional[str] = None, + axis_name: str | None = None, axis_index_groups: Any = None, use_fast_variance: bool = True, rngs: rnglib.Rngs, @@ -45,9 +46,7 @@ def __init__( self.scale: nnx.Param[jax.Array] | None if use_scale: - self.scale = nnx.Param( - scale_init(jax.random.PRNGKey(0), feature_shape, param_dtype) - ) + self.scale = nnx.Param(scale_init(jax.random.PRNGKey(0), feature_shape, param_dtype)) else: self.scale = None @@ -63,7 +62,7 @@ def __init__( self.axis_index_groups = axis_index_groups self.use_fast_variance 
= use_fast_variance - def __call__(self, x, mask: Optional[jax.Array] = None): + def __call__(self, x, mask: jax.Array | None = None): mean, var = _compute_stats( x, self.reduction_axes, @@ -91,12 +90,12 @@ def __call__(self, x, mask: Optional[jax.Array] = None): def _compute_stats( x: Array, axes: Axes, - dtype: Optional[Dtype], - axis_name: Optional[str] = None, + dtype: Dtype | None, + axis_name: str | None = None, axis_index_groups: Any = None, use_mean: bool = True, use_fast_variance: bool = True, - mask: Optional[Array] = None, + mask: Array | None = None, ): if dtype is None: dtype = jnp.result_type(x) @@ -130,9 +129,7 @@ def maybe_distributed_mean(*xs, mask=None): var = jnp.maximum(0.0, mu2 - _abs_sq(mu)) else: mu = maybe_distributed_mean(x, mask=mask) - var = maybe_distributed_mean( - _abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask - ) + var = maybe_distributed_mean(_abs_sq(x - jnp.expand_dims(mu, axes)), mask=mask) else: var = maybe_distributed_mean(_abs_sq(x), mask=mask) mu = jnp.zeros_like(var) @@ -143,11 +140,11 @@ def _normalize( x: Array, mean: Array, var: Array, - scale: Optional[Array], - bias: Optional[Array], + scale: Array | None, + bias: Array | None, reduction_axes: Axes, feature_axes: Axes, - dtype: Optional[Dtype], + dtype: Dtype | None, epsilon: float, ): reduction_axes = _canonicalize_axes(x.ndim, reduction_axes) diff --git a/python/sgl_jax/srt/layers/linear.py b/python/sgl_jax/srt/layers/linear.py index 5e1a9f4b0..4de5ddb6f 100644 --- a/python/sgl_jax/srt/layers/linear.py +++ b/python/sgl_jax/srt/layers/linear.py @@ -48,9 +48,9 @@ def __init__( ) if use_bias: self.bias = nnx.Param( - nnx.with_partitioning( - nnx.initializers.zeros_init(), (kernel_axes[-1],) - )(jax.random.PRNGKey(0), (output_size,), params_dtype) + nnx.with_partitioning(nnx.initializers.zeros_init(), (kernel_axes[-1],))( + jax.random.PRNGKey(0), (output_size,), params_dtype + ) ) else: self.bias = None diff --git a/python/sgl_jax/srt/managers/tokenizer_manager.py b/python/sgl_jax/srt/managers/tokenizer_manager.py index 939eff751..e997da2bc 100644 --- a/python/sgl_jax/srt/managers/tokenizer_manager.py +++ b/python/sgl_jax/srt/managers/tokenizer_manager.py @@ -16,7 +16,7 @@ from collections import deque from datetime import datetime from http import HTTPStatus -from typing import Any, Generic, TypeVar +from typing import Any import fastapi import jax @@ -417,7 +417,7 @@ async def _wait_one_response( while True: try: await asyncio.wait_for(state.event.wait(), timeout=self.wait_timeout) - except asyncio.TimeoutError: + except TimeoutError: if request is not None and await request.is_disconnected(): # Abort the request for disconnected requests (non-streaming, waiting queue) self.abort_request(obj.rid) @@ -1241,10 +1241,7 @@ def running_phase_sigquit_handler(self, signum=None, frame=None): kill_process_tree(os.getpid()) -T = TypeVar("T") - - -class _Communicator(Generic[T]): +class _Communicator[T]: """Note: The communicator now only run up to 1 in-flight request at any time.""" def __init__(self, sender, fan_out: int): diff --git a/python/sgl_jax/srt/model_loader/loader.py b/python/sgl_jax/srt/model_loader/loader.py index 16643ffbc..c5834b0f1 100644 --- a/python/sgl_jax/srt/model_loader/loader.py +++ b/python/sgl_jax/srt/model_loader/loader.py @@ -1,24 +1,19 @@ import dataclasses import logging import os -import time from abc import ABC, abstractmethod -from functools import partial -from typing import Any, List, Optional, Tuple +from typing import Any -import flax.linen as nn import 
huggingface_hub import jax import jax.numpy as jnp import numpy as np from flax import nnx -from jax.sharding import NamedSharding from sgl_jax.srt.configs.load_config import LoadConfig, LoadFormat from sgl_jax.srt.configs.model_config import ModelConfig from sgl_jax.srt.model_loader.arch import get_model_architecture from sgl_jax.srt.utils.common_utils import get_bool_env_var -from sgl_jax.srt.utils.jax_utils import print_memory logger = logging.getLogger(__name__) diff --git a/python/sgl_jax/srt/models/llama.py b/python/sgl_jax/srt/models/llama.py index 9026e4db3..715b07195 100644 --- a/python/sgl_jax/srt/models/llama.py +++ b/python/sgl_jax/srt/models/llama.py @@ -25,11 +25,7 @@ from transformers import LlamaConfig, PretrainedConfig from sgl_jax.srt.configs.model_config import ModelConfig -from sgl_jax.srt.layers.embeddings import ( - Embed, - ParallelLMHead, - get_rope, -) +from sgl_jax.srt.layers.embeddings import Embed, ParallelLMHead, get_rope from sgl_jax.srt.layers.layernorm import RMSNorm from sgl_jax.srt.layers.linear import LinearBase from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor diff --git a/python/sgl_jax/srt/utils/jax_utils.py b/python/sgl_jax/srt/utils/jax_utils.py index ad9e8ad7f..65e543e23 100644 --- a/python/sgl_jax/srt/utils/jax_utils.py +++ b/python/sgl_jax/srt/utils/jax_utils.py @@ -120,11 +120,7 @@ def print_memory(stage_name): memory = get_memory_usage() print(f"\n[{stage_name}] Memory usage:") for device, usage in memory.items(): - print( - f" {device}: {usage}GB" - if isinstance(usage, float) - else f" {device}: {usage}" - ) + print(f" {device}: {usage}GB" if isinstance(usage, float) else f" {device}: {usage}") return memory @@ -136,8 +132,8 @@ def get_memory_usage(): try: device_stats = device.memory_stats() stats[f"device_{i}"] = device_stats.get("bytes_in_use", 0) / (1024**3) - except: + except Exception: stats[f"device_{i}"] = "N/A" return stats - except: + except Exception: return {f"device_{i}": "N/A" for i in range(len(jax.devices()))} diff --git a/python/sgl_jax/srt/utils/weight_utils.py b/python/sgl_jax/srt/utils/weight_utils.py index 2a5923198..9c0b6ba17 100644 --- a/python/sgl_jax/srt/utils/weight_utils.py +++ b/python/sgl_jax/srt/utils/weight_utils.py @@ -123,7 +123,7 @@ def load_weights_from_safetensors( if self._is_excluded_layer_weight(hf_key): logger.debug("Skipping excluded layer weight: %s", hf_key) else: - logger.warning(f"No mapping found for weight: {hf_key}") + logger.warning("No mapping found for weight: %s", hf_key) nnx.update(self.model, params) if moe_mappings: From 6eeba235b687f6e6293811965678a60e92f65b33 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sat, 18 Oct 2025 01:52:29 +0000 Subject: [PATCH 13/18] Update lint version --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 1190280de..1f4592d86 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,7 +11,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.9' + python-version: '3.12' - name: Install pre-commit hook run: | From 00b4c3d12410b440cd4e7fe79f214d3f2175b5eb Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sat, 18 Oct 2025 05:02:13 +0000 Subject: [PATCH 14/18] Reformat --- python/sgl_jax/srt/model_loader/loader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sgl_jax/srt/model_loader/loader.py b/python/sgl_jax/srt/model_loader/loader.py index 
c5834b0f1..237c68e75 100644 --- a/python/sgl_jax/srt/model_loader/loader.py +++ b/python/sgl_jax/srt/model_loader/loader.py @@ -92,9 +92,7 @@ def _initialize_model(self, model_config: ModelConfig) -> Any: def _get_model(self, model_class: Any, model_config: ModelConfig) -> nnx.Module: with self.mesh: model = nnx.eval_shape( - lambda: model_class( - model_config.hf_config, model_config.dtype, self.rng, self.mesh - ) + lambda: model_class(model_config.hf_config, model_config.dtype, self.rng, self.mesh) ) model.load_weights(model_config, self.rng.default.key.value) From 7fac20269251f12b216ca89fe45a509d6ecd6acb Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 20 Oct 2025 02:40:24 +0000 Subject: [PATCH 15/18] Recover --- .../flash_attn_kernel/tuned_block_sizes.py | 4 +- .../srt/layers/attention/native_backend.py | 6 +- python/sgl_jax/srt/mem_cache/memory_pool.py | 87 +++++++++++++++++-- 3 files changed, 87 insertions(+), 10 deletions(-) diff --git a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py index 1ff0b5ff7..124a5f357 100644 --- a/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py +++ b/python/sgl_jax/srt/layers/attention/flash_attn_kernel/tuned_block_sizes.py @@ -142,9 +142,9 @@ ("bfloat16", "bfloat16", 4, 4, 128, 256, 64): (8, 16), ("bfloat16", "bfloat16", 4, 4, 128, 256, 128): (8, 4), ("bfloat16", "bfloat16", 4, 4, 128, 256, 256): (8, 1), - ("bfloat16", "bfloat16", 4, 4, 128, 256, 512): (1, 128), + ("bfloat16", "bfloat16", 4, 4, 128, 256, 512): (4, 128), ("bfloat16", "bfloat16", 4, 4, 128, 256, 1024): (4, 128), - ("bfloat16", "bfloat16", 4, 4, 128, 256, 2048): (4, 128), + ("bfloat16", "bfloat16", 4, 4, 128, 256, 2048): (1, 128), ("bfloat16", "bfloat16", 4, 4, 128, 256, 4096): (8, 128), ("bfloat16", "bfloat16", 4, 4, 128, 256, 8192): (8, 128), ("bfloat16", "bfloat16", 8, 2, 128, 64, 1): (1, 16), diff --git a/python/sgl_jax/srt/layers/attention/native_backend.py b/python/sgl_jax/srt/layers/attention/native_backend.py index bb7c5341a..fa389aae3 100644 --- a/python/sgl_jax/srt/layers/attention/native_backend.py +++ b/python/sgl_jax/srt/layers/attention/native_backend.py @@ -62,10 +62,12 @@ def __call__( scale = 1.0 / jnp.sqrt(layer.head_dim) if layer.scaling is None else layer.scaling - is_causal = not ( + is_causal = True + if ( forward_batch.forward_mode == ForwardMode.DECODE or layer.attn_type == AttentionType.ENCODER_ONLY - ) + ): + is_causal = False attn_output = forward_attention( q, diff --git a/python/sgl_jax/srt/mem_cache/memory_pool.py b/python/sgl_jax/srt/mem_cache/memory_pool.py index f3dde6344..9edf0f7ba 100644 --- a/python/sgl_jax/srt/mem_cache/memory_pool.py +++ b/python/sgl_jax/srt/mem_cache/memory_pool.py @@ -1152,12 +1152,87 @@ def load_cpu_copy(self, kv_cache_host, indices): "hn_32_mcl_320000_nvl_16384_hd_128_ps_128": 4, "hn_32_mcl_640000_nvl_1024_hd_128_ps_128": 512, "hn_32_mcl_640000_nvl_2048_hd_128_ps_128": 4096, - "hn_32_mcl_640000_nvl_4096_hd_128_ps_128": 4096, - "hn_32_mcl_640000_nvl_9182_hd_128_ps_128": 256, - "hn_32_mcl_640000_nvl_16384_hd_128_ps_128": 128, + "hn_32_mcl_640000_nvl_4096_hd_128_ps_128": 4, + "hn_32_mcl_640000_nvl_9182_hd_128_ps_128": 2, + "hn_32_mcl_640000_nvl_16384_hd_128_ps_128": 1024, "hn_32_mcl_1280000_nvl_1024_hd_128_ps_128": 32, "hn_32_mcl_1280000_nvl_2048_hd_128_ps_128": 2048, - "hn_32_mcl_1280000_nvl_4096_hd_128_ps_128": 8, - "hn_32_mcl_1280000_nvl_9182_hd_128_ps_128": 4, - 
"hn_32_mcl_1280000_nvl_16384_hd_128_ps_128": 2, + "hn_32_mcl_1280000_nvl_4096_hd_128_ps_128": 128, + "hn_32_mcl_1280000_nvl_9182_hd_128_ps_128": 1024, + "hn_32_mcl_1280000_nvl_16384_hd_128_ps_128": 1024, + "hn_8_mcl_80000_nvl_1024_hd_128_ps_256": 2, + "hn_8_mcl_80000_nvl_2048_hd_128_ps_256": 4, + "hn_8_mcl_80000_nvl_4096_hd_128_ps_256": 2, + "hn_8_mcl_80000_nvl_9182_hd_128_ps_256": 512, + "hn_8_mcl_80000_nvl_16384_hd_128_ps_256": 2048, + "hn_8_mcl_160000_nvl_1024_hd_128_ps_256": 4, + "hn_8_mcl_160000_nvl_2048_hd_128_ps_256": 256, + "hn_8_mcl_160000_nvl_4096_hd_128_ps_256": 128, + "hn_8_mcl_160000_nvl_9182_hd_128_ps_256": 2048, + "hn_8_mcl_160000_nvl_16384_hd_128_ps_256": 2048, + "hn_8_mcl_320000_nvl_1024_hd_128_ps_256": 4, + "hn_8_mcl_320000_nvl_2048_hd_128_ps_256": 1024, + "hn_8_mcl_320000_nvl_4096_hd_128_ps_256": 64, + "hn_8_mcl_320000_nvl_9182_hd_128_ps_256": 16, + "hn_8_mcl_320000_nvl_16384_hd_128_ps_256": 512, + "hn_8_mcl_640000_nvl_1024_hd_128_ps_256": 1024, + "hn_8_mcl_640000_nvl_2048_hd_128_ps_256": 8, + "hn_8_mcl_640000_nvl_4096_hd_128_ps_256": 16, + "hn_8_mcl_640000_nvl_9182_hd_128_ps_256": 16, + "hn_8_mcl_640000_nvl_16384_hd_128_ps_256": 4096, + "hn_8_mcl_1280000_nvl_1024_hd_128_ps_256": 64, + "hn_8_mcl_1280000_nvl_2048_hd_128_ps_256": 2, + "hn_8_mcl_1280000_nvl_4096_hd_128_ps_256": 2048, + "hn_8_mcl_1280000_nvl_9182_hd_128_ps_256": 1024, + "hn_8_mcl_1280000_nvl_16384_hd_128_ps_256": 128, + "hn_16_mcl_80000_nvl_1024_hd_128_ps_256": 2, + "hn_16_mcl_80000_nvl_2048_hd_128_ps_256": 16, + "hn_16_mcl_80000_nvl_4096_hd_128_ps_256": 64, + "hn_16_mcl_80000_nvl_9182_hd_128_ps_256": 256, + "hn_16_mcl_80000_nvl_16384_hd_128_ps_256": 16, + "hn_16_mcl_160000_nvl_1024_hd_128_ps_256": 4, + "hn_16_mcl_160000_nvl_2048_hd_128_ps_256": 2, + "hn_16_mcl_160000_nvl_4096_hd_128_ps_256": 128, + "hn_16_mcl_160000_nvl_9182_hd_128_ps_256": 16, + "hn_16_mcl_160000_nvl_16384_hd_128_ps_256": 8, + "hn_16_mcl_320000_nvl_1024_hd_128_ps_256": 16, + "hn_16_mcl_320000_nvl_2048_hd_128_ps_256": 8, + "hn_16_mcl_320000_nvl_4096_hd_128_ps_256": 4, + "hn_16_mcl_320000_nvl_9182_hd_128_ps_256": 8, + "hn_16_mcl_320000_nvl_16384_hd_128_ps_256": 8, + "hn_16_mcl_640000_nvl_1024_hd_128_ps_256": 512, + "hn_16_mcl_640000_nvl_2048_hd_128_ps_256": 1024, + "hn_16_mcl_640000_nvl_4096_hd_128_ps_256": 2048, + "hn_16_mcl_640000_nvl_9182_hd_128_ps_256": 4096, + "hn_16_mcl_640000_nvl_16384_hd_128_ps_256": 32, + "hn_16_mcl_1280000_nvl_1024_hd_128_ps_256": 4, + "hn_16_mcl_1280000_nvl_2048_hd_128_ps_256": 2, + "hn_16_mcl_1280000_nvl_4096_hd_128_ps_256": 1024, + "hn_16_mcl_1280000_nvl_9182_hd_128_ps_256": 2048, + "hn_16_mcl_1280000_nvl_16384_hd_128_ps_256": 16, + "hn_32_mcl_80000_nvl_1024_hd_128_ps_256": 4, + "hn_32_mcl_80000_nvl_2048_hd_128_ps_256": 256, + "hn_32_mcl_80000_nvl_4096_hd_128_ps_256": 4096, + "hn_32_mcl_80000_nvl_9182_hd_128_ps_256": 128, + "hn_32_mcl_80000_nvl_16384_hd_128_ps_256": 512, + "hn_32_mcl_160000_nvl_1024_hd_128_ps_256": 64, + "hn_32_mcl_160000_nvl_2048_hd_128_ps_256": 4096, + "hn_32_mcl_160000_nvl_4096_hd_128_ps_256": 4096, + "hn_32_mcl_160000_nvl_9182_hd_128_ps_256": 256, + "hn_32_mcl_160000_nvl_16384_hd_128_ps_256": 128, + "hn_32_mcl_320000_nvl_1024_hd_128_ps_256": 4, + "hn_32_mcl_320000_nvl_2048_hd_128_ps_256": 64, + "hn_32_mcl_320000_nvl_4096_hd_128_ps_256": 1024, + "hn_32_mcl_320000_nvl_9182_hd_128_ps_256": 256, + "hn_32_mcl_320000_nvl_16384_hd_128_ps_256": 32, + "hn_32_mcl_640000_nvl_1024_hd_128_ps_256": 256, + "hn_32_mcl_640000_nvl_2048_hd_128_ps_256": 8, + "hn_32_mcl_640000_nvl_4096_hd_128_ps_256": 64, + 
"hn_32_mcl_640000_nvl_9182_hd_128_ps_256": 32, + "hn_32_mcl_640000_nvl_16384_hd_128_ps_256": 32, + "hn_32_mcl_1280000_nvl_1024_hd_128_ps_256": 256, + "hn_32_mcl_1280000_nvl_2048_hd_128_ps_256": 2048, + "hn_32_mcl_1280000_nvl_4096_hd_128_ps_256": 8, + "hn_32_mcl_1280000_nvl_9182_hd_128_ps_256": 4, + "hn_32_mcl_1280000_nvl_16384_hd_128_ps_256": 2, } From 0f76da3c3c6d3d93c4d6644a76d7891e416f0b89 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 20 Oct 2025 02:44:38 +0000 Subject: [PATCH 16/18] Fix --- python/sgl_jax/srt/entrypoints/engine.py | 4 +--- python/sgl_jax/srt/managers/scheduler.py | 2 +- python/sgl_jax/srt/managers/tp_worker_overlap_thread.py | 2 +- python/sgl_jax/srt/sampling/sampling_batch_info.py | 4 +--- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/python/sgl_jax/srt/entrypoints/engine.py b/python/sgl_jax/srt/entrypoints/engine.py index bee350f34..0563e6108 100644 --- a/python/sgl_jax/srt/entrypoints/engine.py +++ b/python/sgl_jax/srt/entrypoints/engine.py @@ -606,9 +606,7 @@ def _launch_threads( # Wait for the model to finish loading for i in range(len(scheduler_infos)): if scheduler_infos[i]["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) + raise RuntimeError("Initialization failed. Please see the error messages above.") # Assume all schedulers have the same scheduler_info assert len(scheduler_infos) > 0, "scheduler_infos is empty" diff --git a/python/sgl_jax/srt/managers/scheduler.py b/python/sgl_jax/srt/managers/scheduler.py index 806db7126..8fd944a5d 100644 --- a/python/sgl_jax/srt/managers/scheduler.py +++ b/python/sgl_jax/srt/managers/scheduler.py @@ -1085,7 +1085,7 @@ def run_scheduler_loop_thread_after_create( } except Exception: traceback = get_exception_traceback() - logger.error(f"Scheduler hit an exception: {traceback}") + logger.error("Scheduler hit an exception: %s", traceback) current_process.send_signal(signal.SIGQUIT) diff --git a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py index c781b952a..be43e7740 100644 --- a/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py +++ b/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py @@ -48,7 +48,7 @@ def __init__( # JAX handles device execution automatically, no need for explicit streams self.forward_thread = threading.Thread( target=self.forward_thread_func, - daemon=True if server_args.enable_single_process else False, + daemon=bool(server_args.enable_single_process), ) self.forward_thread.start() self.parent_process = psutil.Process().parent() diff --git a/python/sgl_jax/srt/sampling/sampling_batch_info.py b/python/sgl_jax/srt/sampling/sampling_batch_info.py index dee50e674..2393949b5 100644 --- a/python/sgl_jax/srt/sampling/sampling_batch_info.py +++ b/python/sgl_jax/srt/sampling/sampling_batch_info.py @@ -252,9 +252,7 @@ def from_model_worker_batch_for_precompile( ), ] ) - sampling_seeds_device = device_array( - padded_sampling_seeds, sharding=sharding - ) + sampling_seeds_device = device_array(padded_sampling_seeds, sharding=sharding) else: sampling_seeds_device = None From 3d05c7216fea463e28756cc71f5f6c83abc08455 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 20 Oct 2025 03:04:10 +0000 Subject: [PATCH 17/18] fix indent --- python/sgl_jax/srt/configs/model_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sgl_jax/srt/configs/model_config.py b/python/sgl_jax/srt/configs/model_config.py index 72fd35bed..ca56da293 100644 --- 
a/python/sgl_jax/srt/configs/model_config.py
+++ b/python/sgl_jax/srt/configs/model_config.py
@@ -455,7 +455,7 @@ def _get_and_verify_dtype(
     else:
         # Casting between float16 and bfloat16 is allowed with a warning.
         logger.warning("Casting %s to %s.", config_dtype, jax_dtype)
-        return jax_dtype
+    return jax_dtype
 
 
 def is_generation_model(model_architectures: list[str], is_embedding: bool = False):

From 030348a4680f87fec5b1ddb8d2aedd462eb13bc1 Mon Sep 17 00:00:00 2001
From: chhzh123
Date: Mon, 20 Oct 2025 04:05:38 +0000
Subject: [PATCH 18/18] Fix

---
 python/sgl_jax/srt/layers/logits_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sgl_jax/srt/layers/logits_processor.py b/python/sgl_jax/srt/layers/logits_processor.py
index 7f1ec8aa7..3fc489db6 100644
--- a/python/sgl_jax/srt/layers/logits_processor.py
+++ b/python/sgl_jax/srt/layers/logits_processor.py
@@ -290,7 +290,7 @@ def __call__(
             pruned_states[sample_indices] if sample_indices is not None else pruned_states
         )
         else:
-            raise AssertionError()
+            raise AssertionError("This branch should not be reached")
 
         if not logits_metadata.extend_return_logprob:
             # Decode mode or extend mode without return_logprob.
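
The `_Communicator` rewrite in PATCH 12 drops the module-level `TypeVar`/`Generic` pair in favor of PEP 695 type-parameter syntax, and PATCH 13 raises the lint workflow to Python 3.12, which that syntax requires. Below is a minimal sketch of the two equivalent spellings; the `Box` class is hypothetical and not part of the repository:

```python
from typing import Generic, TypeVar

# Pre-3.12 spelling: declare a TypeVar and inherit from Generic.
T = TypeVar("T")


class BoxLegacy(Generic[T]):
    def __init__(self, value: T) -> None:
        self.value = value


# PEP 695 spelling (Python 3.12+): the type parameter is declared inline,
# so no TypeVar or Generic import is needed.
class Box[T]:
    def __init__(self, value: T) -> None:
        self.value = value


if __name__ == "__main__":
    print(BoxLegacy[int](1).value)  # 1
    print(Box[str]("hello").value)  # hello
```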
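
Several hunks (for example in `weight_utils.py` and `scheduler.py`) replace f-string log messages with printf-style arguments, in line with the logging rules (`G`) selected in `pyproject.toml` in PATCH 10. A small, self-contained sketch of the difference, using a hypothetical logger name:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("weight_loader")

hf_key = "model.layers.0.self_attn.q_proj.weight"

# Eager: the f-string is rendered even when the record is filtered out.
logger.warning(f"No mapping found for weight: {hf_key}")

# Lazy: interpolation happens only if the record is actually emitted, and
# the constant format string makes log aggregation easier.
logger.warning("No mapping found for weight: %s", hf_key)
```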
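
PATCH 11 adds `assistant_prefix = None` in `serving_chat.py` before the conditional that may assign it, so a later read cannot raise `UnboundLocalError` when the condition is false. A reduced sketch of the failure mode and the fix, with hypothetical names:

```python
def build_prefix(messages: list[dict], continue_final_message: bool) -> str | None:
    # Without this default, returning assistant_prefix raises UnboundLocalError
    # whenever the condition below is false.
    assistant_prefix = None
    if messages and messages[-1]["role"] == "assistant" and continue_final_message:
        assistant_prefix = messages[-1]["content"]
    return assistant_prefix


if __name__ == "__main__":
    print(build_prefix([{"role": "user", "content": "hi"}], True))          # None
    print(build_prefix([{"role": "assistant", "content": "Sure,"}], True))  # Sure,
```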