diff --git a/.github/actions/setup_environment/action.yml b/.github/actions/setup_environment/action.yml index 83cd460b..a1dd7f9e 100644 --- a/.github/actions/setup_environment/action.yml +++ b/.github/actions/setup_environment/action.yml @@ -25,7 +25,7 @@ runs: uses: astral-sh/setup-uv@v4 with: version: "latest" - python-version: "3.10" + python-version: "3.12" - name: Install skypilot and login shell: bash run: | diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 1f6bd873..57f25775 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -37,9 +37,10 @@ jobs: - name: Run test timeout-minutes: 120 run: | - uv venv && uv pip install -e python/ + uv venv --python 3.12 && source .venv/bin/activate && uv pip install -e python/ bash scripts/killall_sglang.sh uv run python test/srt/run_suite.py --suite per-commit-tpu-v6e-1 + deactivate e2e-test-4-tpu: if: github.event.pull_request.draft == false @@ -52,9 +53,10 @@ jobs: - name: Run test timeout-minutes: 120 run: | - uv venv && uv pip install -e python/ + uv venv --python 3.12 && source .venv/bin/activate && uv pip install -e python/ bash scripts/killall_sglang.sh uv run python test/srt/run_suite.py --suite per-commit-tpu-v6e-4 + deactivate # performance-test-1-tpu: # if: github.event.pull_request.draft == false diff --git a/python/pyproject.toml b/python/pyproject.toml index 26916d95..3c53bd0d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "fastapi~=0.116.1", "flax~=0.10.7", "huggingface-hub~=0.34.3", - "jax[tpu]~=0.6.2", + "jax[tpu]~=0.7.2", "jinja2~=3.1.6", "modelscope~=1.28.2", "msgpack-python~=0.5.6", diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py index 39f6a693..eefa2085 100644 --- a/python/sgl_jax/srt/model_executor/model_runner.py +++ b/python/sgl_jax/srt/model_executor/model_runner.py @@ -421,10 +421,7 @@ def _forward_raw( forward_batch: ForwardBatch, logits_metadata: LogitsMetadata, ) -> Tuple[LogitsProcessorOutput, int]: - # for compatibility, 0.6.3 need to use use_mesh. set_mesh is not have __entry__ attribute. - # on jax 0.7.1, we need to use set_mesh. - # with jax.sharding.set_mesh(self.mesh): - with jax.sharding.use_mesh(self.mesh): + with jax.sharding.set_mesh(self.mesh): if ( forward_batch.forward_mode.is_decode() or forward_batch.forward_mode.is_extend()