Skip to content

Commit e2f184e

Browse files
committed
fix tests
1 parent 1ffb392 commit e2f184e

File tree

2 files changed

+47
-39
lines changed

2 files changed

+47
-39
lines changed

xinference/core/tests/test_restart_supervisor.py

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,17 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import asyncio
1615
import multiprocessing
16+
import time
1717
from typing import Dict, Optional
1818

19-
import pytest
2019
import xoscar as xo
2120

22-
from ...core.supervisor import SupervisorActor
21+
from ...api import restful_api
22+
from ...client import Client
2323

2424

25-
# test restart supervisor
26-
@pytest.mark.asyncio
27-
async def test_restart_supervisor():
25+
def test_restart_supervisor():
2826
from ...deploy.supervisor import run_in_subprocess as supervisor_run_in_subprocess
2927
from ...deploy.worker import main as _start_worker
3028

@@ -39,50 +37,59 @@ def worker_run_in_subprocess(
3937
return p
4038

4139
# start supervisor
42-
supervisor_address = f"localhost:{xo.utils.get_next_port()}"
40+
web_port, supervisor_port = xo.utils.get_next_port(), xo.utils.get_next_port()
41+
supervisor_address = f"127.0.0.1:{supervisor_port}"
4342
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
43+
rest_api_proc = multiprocessing.Process(
44+
target=restful_api.run,
45+
kwargs=dict(
46+
supervisor_address=supervisor_address, host="127.0.0.1", port=web_port
47+
),
48+
)
49+
rest_api_proc.start()
4450

45-
await asyncio.sleep(5)
51+
time.sleep(5)
4652

4753
# start worker
48-
worker_run_in_subprocess(
49-
address=f"localhost:{xo.utils.get_next_port()}",
54+
proc_worker = worker_run_in_subprocess(
55+
address=f"127.0.0.1:{xo.utils.get_next_port()}",
5056
supervisor_address=supervisor_address,
5157
)
5258

53-
await asyncio.sleep(10)
54-
55-
# load model
56-
supervisor_ref = await xo.actor_ref(
57-
supervisor_address, SupervisorActor.default_uid()
58-
)
59+
time.sleep(10)
5960

60-
model_uid = "qwen1.5-chat"
61-
await supervisor_ref.launch_builtin_model(
62-
model_uid=model_uid,
63-
model_name="qwen1.5-chat",
64-
model_size_in_billions="0_5",
65-
quantization="q4_0",
66-
model_engine="llama.cpp",
67-
)
61+
client = Client(f"http://127.0.0.1:{web_port}")
6862

69-
# query replica info
70-
model_replica_info = await supervisor_ref.describe_model(model_uid)
63+
try:
64+
model_uid = "qwen1.5-chat"
65+
client.launch_model(
66+
model_uid=model_uid,
67+
model_name="qwen1.5-chat",
68+
model_size_in_billions="0_5",
69+
quantization="q4_0",
70+
model_engine="llama.cpp",
71+
)
7172

72-
# kill supervisor
73-
proc_supervisor.terminate()
74-
proc_supervisor.join()
73+
# query replica info
74+
model_replica_info = client.describe_model(model_uid)
75+
assert model_replica_info is not None
7576

76-
# restart supervisor
77-
proc_supervisor = supervisor_run_in_subprocess(supervisor_address)
77+
# kill supervisor
78+
proc_supervisor.terminate()
79+
proc_supervisor.join()
7880

79-
await asyncio.sleep(5)
81+
# restart supervisor
82+
supervisor_run_in_subprocess(supervisor_address)
8083

81-
supervisor_ref = await xo.actor_ref(
82-
supervisor_address, SupervisorActor.default_uid()
83-
)
84+
time.sleep(5)
8485

85-
# check replica info
86-
model_replic_info_check = await supervisor_ref.describe_model(model_uid)
86+
# check replica info
87+
model_replic_info_check = client.describe_model(model_uid)
88+
assert model_replica_info["replica"] == model_replic_info_check["replica"]
8789

88-
assert model_replica_info["replica"] == model_replic_info_check["replica"]
90+
finally:
91+
client.abort_cluster()
92+
proc_supervisor.terminate()
93+
proc_worker.terminate()
94+
proc_supervisor.join()
95+
proc_worker.join()

xinference/deploy/supervisor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333

3434

3535
async def _start_supervisor(address: str, logging_conf: Optional[Dict] = None):
36-
logging.config.dictConfig(logging_conf) # type: ignore
36+
if logging_conf:
37+
logging.config.dictConfig(logging_conf) # type: ignore
3738

3839
pool = None
3940
try:

0 commit comments

Comments
 (0)