import numpy as np
from server.engine.resource_manager import ResourceManager
-from server.engine.task_queue_manager import (TaskQueueManager,
-                                              launch_queue_service)
+from server.engine.task_queue_manager import TaskQueueManager, launch_queue_service
from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor
from server.utils import model_server_logger
@@ -34,10 +33,11 @@ class Engine(object):
    """
    Engine Class
    """
+
    def __init__(self, cfg, token_processor):
        self.cfg = cfg
        # Master node only
-        if self.cfg.nnode == 1 or self.cfg.host_ip == os.getenv('POD_0_IP', '127.0.0.1'):
+        if self.cfg.nnode == 1 or self.cfg.host_ip == os.getenv("POD_0_IP", "127.0.0.1"):
            self.queue_service = self._start_tasks_queue_service()
            self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
        self.resource_manager = ResourceManager(self.cfg)
@@ -53,8 +53,10 @@ def start(self):
        initialize engine and start sub services
        """
        assert not self.is_started, "The engine is already started.!"
+        msg_queue_id_str = os.getenv("INFERENCE_MSG_QUEUE_ID", str(os.getpid()))
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = msg_queue_id_str
        start_time = time.time()
-
+
        self.token_processor.tasks_queue = self.tasks_queue
        self.infer_proc = self._start_infer_service()
        model_server_logger.info("Waiting infer processes ready...")
@@ -80,17 +82,18 @@ def warmup(self):
        """
        # get eos_token_id
        from server.data.processor import DataProcessor
+
        eos_token_ids = DataProcessor().get_eos_tokens()

-        # construct test tasks
+        # construct test tasks
        res_task = []
        for j in range(2 * self.cfg.max_batch_size):
            data = {
                "input_ids": [5],
                "req_id": j,
                "max_dec_len": self.cfg.dec_len_limit,
                "min_dec_len": int(self.cfg.dec_len_limit * 0.5) + 1,
-                "eos_token_ids": eos_token_ids
+                "eos_token_ids": eos_token_ids,
            }
            res_task.append(data)
        for j in range(2 * self.cfg.max_prefill_batch):
@@ -99,7 +102,7 @@ def warmup(self):
                "req_id": j + 2 * self.cfg.max_batch_size,
                "max_dec_len": 1,
                "min_dec_len": 1,
-                "eos_token_ids": eos_token_ids
+                "eos_token_ids": eos_token_ids,
            }
            res_task.append(data)

@@ -130,8 +133,9 @@ def insert_tasks(self, tasks):

        available_batch = np.sum(self.resource_manager.stop_flags)
        if len(tasks) > available_batch:
-            model_server_logger.error("Inserting batch:{} exceeds the available batch:{}.".format(
-                len(tasks), available_batch))
+            model_server_logger.error(
+                "Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
+            )
            model_server_logger.error("The exceeded part will be ignored!")
            tasks = tasks[:available_batch]

@@ -140,21 +144,23 @@ def insert_tasks(self, tasks):
            input_token_num = len(tasks[i]["input_ids"])
            if input_token_num >= self.cfg.max_seq_len - 1:
                model_server_logger.warning(f"{req_id}: Input length:{input_token_num}, exceed the limits.")
-                tasks[i]["input_ids"] = tasks[i]["input_ids"][:self.cfg.max_seq_len - 1]
+                tasks[i]["input_ids"] = tasks[i]["input_ids"][: self.cfg.max_seq_len - 1]
            if "seq_len" in tasks[i] and "max_dec_len" not in tasks[i]:
                tasks[i]["max_dec_len"] = tasks[i]["seq_len"]

            # max_dec_len + input_token_num > MAX_SEQ_LEN
            if input_token_num + tasks[i]["max_dec_len"] > self.cfg.max_seq_len:
                tasks[i]["max_dec_len"] = self.cfg.max_seq_len - input_token_num
-                model_server_logger.warning("Force max_dec_len to be {} for req_id={}.".format(
-                    tasks[i]["max_dec_len"], tasks[i]["req_id"]))
+                model_server_logger.warning(
+                    "Force max_dec_len to be {} for req_id={}.".format(tasks[i]["max_dec_len"], tasks[i]["req_id"])
+                )

            # min_dec_len + input_token_num > MAX_SEQ_LEN
            if input_token_num + tasks[i]["min_dec_len"] > self.cfg.max_seq_len:
                tasks[i]["min_dec_len"] = self.cfg.max_seq_len - input_token_num
-                model_server_logger.warning("Force min_dec_len to be {} for req_id={}.".format(
-                    tasks[i]["min_dec_len"], tasks[i]["req_id"]))
+                model_server_logger.warning(
+                    "Force min_dec_len to be {} for req_id={}.".format(tasks[i]["min_dec_len"], tasks[i]["req_id"])
+                )

        tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
        if not tasks:
@@ -292,9 +298,7 @@ def _init_engine_flags(self):
        self.shm_flag_ready = shared_memory.SharedMemory(
            create=True, size=flag_array.nbytes, name=self.cfg.get_unique_name("shm_flag_infer_ready")
        )
-        self.flag_ready_array = np.ndarray(
-            flag_array.shape, dtype=flag_array.dtype, buffer=self.shm_flag_ready.buf
-        )
+        self.flag_ready_array = np.ndarray(flag_array.shape, dtype=flag_array.dtype, buffer=self.shm_flag_ready.buf)
        self.flag_ready_array[:] = 0

        # broadcast flag for engine
@@ -324,19 +328,22 @@ def _init_engine_flags(self):
            tmp = shared_memory.SharedMemory(
                create=False,
                size=has_block_step_flag_array.nbytes,
-                name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+                name=self.cfg.get_unique_name("shm_flag_has_block_step"),
+            )
            tmp.close()
            tmp.unlink()
        except:
            pass
        self.shm_flag_has_block_step = shared_memory.SharedMemory(
            create=True,
            size=has_block_step_flag_array.nbytes,
-            name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+            name=self.cfg.get_unique_name("shm_flag_has_block_step"),
+        )
        self.flag_has_block_step_array = np.ndarray(
            has_block_step_flag_array.shape,
            dtype=has_block_step_flag_array.dtype,
-            buffer=self.shm_flag_has_block_step.buf)
+            buffer=self.shm_flag_has_block_step.buf,
+        )
        self.flag_has_block_step_array[:] = 0

    def _exit_sub_services(self):
@@ -362,8 +369,9 @@ def _start_tasks_queue_service(self):
        if p.is_alive():
            model_server_logger.info("start tasks queue service successfully")
        else:
-            error_msg = "Failed to start tasks queue service, please check " \
-                        "the log/task_queue_manager.log for details"
+            error_msg = (
+                "Failed to start tasks queue service, please check " "the log/task_queue_manager.log for details"
+            )
            model_server_logger.info(error_msg)
            raise Exception(error_msg)
        return p
@@ -380,14 +388,16 @@ def _start_gpu_infer_service(self):
        pd_cmd = "python3 -m paddle.distributed.launch "
        py_script = os.path.join(current_dir_path, "infer.py")

-        arguments = (f" --nnodes {str(self.cfg.nnode)}"
-                     f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
-                     f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
-                     f" --max_dec_len {self.cfg.max_dec_len}"
-                     f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
-                     f" --use_cache_kv_int8 {self.cfg.use_cache_kv_int8}"
-                     f" --enc_dec_block_num {self.cfg.enc_dec_block_num}"
-                     f" --block_ratio {self.cfg.block_ratio} --dtype {self.cfg.dtype}")
+        arguments = (
+            f" --nnodes {str(self.cfg.nnode)}"
+            f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
+            f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
+            f" --max_dec_len {self.cfg.max_dec_len}"
+            f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
+            f" --use_cache_kv_int8 {self.cfg.use_cache_kv_int8}"
+            f" --enc_dec_block_num {self.cfg.enc_dec_block_num}"
+            f" --block_ratio {self.cfg.block_ratio} --dtype {self.cfg.dtype}"
+        )
        if self.cfg.nnode > 1:
            pd_cmd = pd_cmd + f" --ips {self.cfg.ips}"
        pd_cmd = pd_cmd + arguments + " >log/launch_infer.log 2>&1"