Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/setup_environment/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ runs:
uses: astral-sh/setup-uv@v4
with:
version: "latest"
python-version: "3.10"
python-version: "3.12"
- name: Install skypilot and login
shell: bash
run: |
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/pr-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ jobs:
- name: Run test
timeout-minutes: 120
run: |
uv venv && uv pip install -e python/
uv venv --python 3.12 && source .venv/bin/activate && uv pip install -e python/
bash scripts/killall_sglang.sh
uv run python test/srt/run_suite.py --suite per-commit-tpu-v6e-1
deactivate

e2e-test-4-tpu:
if: github.event.pull_request.draft == false
Expand All @@ -52,9 +53,10 @@ jobs:
- name: Run test
timeout-minutes: 120
run: |
uv venv && uv pip install -e python/
uv venv --python 3.12 && source .venv/bin/activate && uv pip install -e python/
bash scripts/killall_sglang.sh
uv run python test/srt/run_suite.py --suite per-commit-tpu-v6e-4
deactivate

# performance-test-1-tpu:
# if: github.event.pull_request.draft == false
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies = [
"fastapi~=0.116.1",
"flax~=0.10.7",
"huggingface-hub~=0.34.3",
"jax[tpu]~=0.6.2",
"jax[tpu]~=0.7.2",
"jinja2~=3.1.6",
"modelscope~=1.28.2",
"msgpack-python~=0.5.6",
Expand Down
5 changes: 1 addition & 4 deletions python/sgl_jax/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,7 @@ def _forward_raw(
forward_batch: ForwardBatch,
logits_metadata: LogitsMetadata,
) -> Tuple[LogitsProcessorOutput, int]:
# for compatibility, 0.6.3 need to use use_mesh. set_mesh is not have __entry__ attribute.
# on jax 0.7.1, we need to use set_mesh.
# with jax.sharding.set_mesh(self.mesh):
with jax.sharding.use_mesh(self.mesh):
with jax.sharding.set_mesh(self.mesh):
if (
forward_batch.forward_mode.is_decode()
or forward_batch.forward_mode.is_extend()
Expand Down
Loading