1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import asyncio
1615import multiprocessing
16+ import time
1717from typing import Dict , Optional
1818
19- import pytest
2019import xoscar as xo
2120
22- from ...core .supervisor import SupervisorActor
21+ from ...api import restful_api
22+ from ...client import Client
2323
2424
25- # test restart supervisor
26- @pytest .mark .asyncio
27- async def test_restart_supervisor ():
25+ def test_restart_supervisor ():
2826 from ...deploy .supervisor import run_in_subprocess as supervisor_run_in_subprocess
2927 from ...deploy .worker import main as _start_worker
3028
@@ -39,50 +37,59 @@ def worker_run_in_subprocess(
3937 return p
4038
4139 # start supervisor
42- supervisor_address = f"localhost:{ xo .utils .get_next_port ()} "
40+ web_port , supervisor_port = xo .utils .get_next_port (), xo .utils .get_next_port ()
41+ supervisor_address = f"127.0.0.1:{ supervisor_port } "
4342 proc_supervisor = supervisor_run_in_subprocess (supervisor_address )
43+ rest_api_proc = multiprocessing .Process (
44+ target = restful_api .run ,
45+ kwargs = dict (
46+ supervisor_address = supervisor_address , host = "127.0.0.1" , port = web_port
47+ ),
48+ )
49+ rest_api_proc .start ()
4450
45- await asyncio .sleep (5 )
51+ time .sleep (5 )
4652
4753 # start worker
48- worker_run_in_subprocess (
49- address = f"localhost :{ xo .utils .get_next_port ()} " ,
54+ proc_worker = worker_run_in_subprocess (
55+ address = f"127.0.0.1 :{ xo .utils .get_next_port ()} " ,
5056 supervisor_address = supervisor_address ,
5157 )
5258
53- await asyncio .sleep (10 )
54-
55- # load model
56- supervisor_ref = await xo .actor_ref (
57- supervisor_address , SupervisorActor .default_uid ()
58- )
59+ time .sleep (10 )
5960
60- model_uid = "qwen1.5-chat"
61- await supervisor_ref .launch_builtin_model (
62- model_uid = model_uid ,
63- model_name = "qwen1.5-chat" ,
64- model_size_in_billions = "0_5" ,
65- quantization = "q4_0" ,
66- model_engine = "llama.cpp" ,
67- )
61+ client = Client (f"http://127.0.0.1:{ web_port } " )
6862
69- # query replica info
70- model_replica_info = await supervisor_ref .describe_model (model_uid )
63+ try :
64+ model_uid = "qwen1.5-chat"
65+ client .launch_model (
66+ model_uid = model_uid ,
67+ model_name = "qwen1.5-chat" ,
68+ model_size_in_billions = "0_5" ,
69+ quantization = "q4_0" ,
70+ model_engine = "llama.cpp" ,
71+ )
7172
72- # kill supervisor
73- proc_supervisor . terminate ( )
74- proc_supervisor . join ()
73+ # query replica info
74+ model_replica_info = client . describe_model ( model_uid )
75+ assert model_replica_info is not None
7576
76- # restart supervisor
77- proc_supervisor = supervisor_run_in_subprocess (supervisor_address )
77+ # kill supervisor
78+ proc_supervisor .terminate ()
79+ proc_supervisor .join ()
7880
79- await asyncio .sleep (5 )
81+ # restart supervisor
82+ supervisor_run_in_subprocess (supervisor_address )
8083
81- supervisor_ref = await xo .actor_ref (
82- supervisor_address , SupervisorActor .default_uid ()
83- )
84+ time .sleep (5 )
8485
85- # check replica info
86- model_replic_info_check = await supervisor_ref .describe_model (model_uid )
86+ # check replica info
87+ model_replic_info_check = client .describe_model (model_uid )
88+ assert model_replica_info ["replica" ] == model_replic_info_check ["replica" ]
8789
88- assert model_replica_info ["replica" ] == model_replic_info_check ["replica" ]
90+ finally :
91+ client .abort_cluster ()
92+ proc_supervisor .terminate ()
93+ proc_worker .terminate ()
94+ proc_supervisor .join ()
95+ proc_worker .join ()
0 commit comments