
Commit 47de2c6

[infer] get_output and save_output recover input without msg_queue_id (#10574)
* check
* fix input_text dy_insert
* check code
1 parent 8205e3d commit 47de2c6
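
In effect, both custom ops drop the `msg_queue_id` tensor input and instead agree on the queue id through the `INFERENCE_MSG_QUEUE_ID` environment variable: the Python layer seeds it with the current process id when it is unset, and the C++ ops read it back, falling back to `1` if it is absent. A minimal Python sketch of that resolution logic, mirroring the diffs below for illustration only (the helper name is not part of the commit):

```python
import os

def resolve_msg_queue_id() -> int:
    """Illustrative mirror of the queue-id handshake introduced by this commit."""
    # predictor.py / engine.py: seed the variable with the process id if unset.
    os.environ.setdefault("INFERENCE_MSG_QUEUE_ID", str(os.getpid()))
    # get_output.cc / save_with_output_msg.cc: read it, defaulting to 1 when absent.
    return int(os.environ.get("INFERENCE_MSG_QUEUE_ID", "1"))

print(resolve_msg_queue_id())
```

Both ops then derive the System V key with `ftok("./", msg_queue_id)`, so producer and consumer attach to the same message queue as long as they share the environment and working directory.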

File tree

10 files changed: +89 additions, -61 deletions


csrc/gpu/cpp_extensions.cu

Lines changed: 0 additions & 2 deletions
@@ -238,11 +238,9 @@ std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
 
 void SaveOutMmsg(const paddle::Tensor& x,
                  const paddle::Tensor& not_need_stop, // cpu
-                 const paddle::Tensor& msg_queue_id, // cpu
                  int64_t rank_id);
 
 void GetOutput(const paddle::Tensor& x,
-               const paddle::Tensor& msg_queue_id, // cpu
               int64_t rank_id,
               bool wait_flag);
 
csrc/gpu/get_output.cc

Lines changed: 16 additions & 4 deletions
@@ -20,21 +20,33 @@
 #include "paddle/extension.h"
 
 #define MAX_BSZ 512
+// #define GET_OUTPUT_DEBUG
 
 struct msgdata {
   long mtype;
   int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
 };
 
 void GetOutput(const paddle::Tensor& x,
-               const paddle::Tensor& msg_queue_id,
                int64_t rank_id,
                bool wait_flag) {
   if (rank_id > 0) return;
 
   static struct msgdata msg_rcv;
-  int queue_id_val = msg_queue_id.data<int>()[0];
-  static key_t key = ftok("./", queue_id_val);
+  int msg_queue_id = 1;
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(
+        inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+#ifdef GET_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
+#endif
+    msg_queue_id = inference_msg_queue_id_from_env;
+  }
+  static key_t key = ftok("./", msg_queue_id);
 
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
@@ -62,7 +74,7 @@ void GetOutput(const paddle::Tensor& x,
 }
 
 PD_BUILD_OP(get_output)
-    .Inputs({"x", "msg_queue_id"})
+    .Inputs({"x"})
     .Attrs({"rank_id: int64_t",
             "wait_flag: bool"})
     .Outputs({"x_out"})

csrc/gpu/save_with_output_msg.cc

Lines changed: 21 additions & 4 deletions
@@ -20,6 +20,7 @@
 #include "paddle/extension.h"
 
 #define MAX_BSZ 512
+// #define SAVE_WITH_OUTPUT_DEBUG
 
 struct msgdata {
   long mtype;
@@ -28,16 +29,32 @@ struct msgdata {
 
 void SaveOutMmsg(const paddle::Tensor& x,
                  const paddle::Tensor& not_need_stop, // cpu
-                 const paddle::Tensor& msg_queue_id, // cpu
                  int64_t rank_id) {
   if (rank_id > 0) return;
   auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
   int64_t *x_data = x_cpu.data<int64_t>();
   auto not_need_stop_data = not_need_stop.data<bool>()[0];
 
   static struct msgdata msg_sed;
-  int queue_id_val = msg_queue_id.data<int>()[0];
-  static key_t key = ftok("./", queue_id_val);
+  int msg_queue_id = 1;
+  if (const char* inference_msg_queue_id_env_p =
+          std::getenv("INFERENCE_MSG_QUEUE_ID")) {
+    std::string inference_msg_queue_id_env_str(
+        inference_msg_queue_id_env_p);
+    int inference_msg_queue_id_from_env =
+        std::stoi(inference_msg_queue_id_env_str);
+    msg_queue_id = inference_msg_queue_id_from_env;
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Your INFERENCE_MSG_QUEUE_ID is: "
+              << inference_msg_queue_id_from_env << std::endl;
+#endif
+  } else {
+#ifdef SAVE_WITH_OUTPUT_DEBUG
+    std::cout << "Failed to got INFERENCE_MSG_QUEUE_ID at env, use default."
+              << std::endl;
+#endif
+  }
+  static key_t key = ftok("./", msg_queue_id);
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
   msg_sed.mtype = 1;
@@ -54,7 +71,7 @@ void SaveOutMmsg(const paddle::Tensor& x,
 }
 
 PD_BUILD_OP(save_output)
-    .Inputs({"x", "not_need_stop", "msg_queue_id"})
+    .Inputs({"x", "not_need_stop"})
     .Attrs({"rank_id: int64_t"})
    .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})

llm/docs/predict/best_practices.md

Lines changed: 3 additions & 0 deletions
@@ -44,3 +44,6 @@ PaddleNLP provides a variety of environment variables for optimizing inference performance and resource usage.
 
 **General optimizations for custom operators**
 - `DYNAMIC_INFERENCE_MODE`: whether custom operators are invoked via pybind during dynamic-graph inference; defaults to True.
+
+**Other**
+- `INFERENCE_MSG_QUEUE_ID`: message queue id for multi-instance serving; defaults to the process id.
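
Since the variable now controls which System V queue an instance uses, here is a hedged usage sketch for running two instances side by side (the instance count, queue ids, and entry-script name are illustrative, not from the commit):

```python
import os
import subprocess

# Illustrative only: give each inference instance its own queue id before
# launching it, so concurrent instances do not read each other's tokens.
for idx, device in enumerate(["0", "1"]):
    env = os.environ.copy()
    env["INFERENCE_MSG_QUEUE_ID"] = str(idx + 1)          # any distinct positive ids
    env["CUDA_VISIBLE_DEVICES"] = device                  # hypothetical per-instance GPU
    subprocess.Popen(["python", "predict.py"], env=env)   # hypothetical entry point
```

If the variable is left unset, each instance falls back to its own process id, as documented above.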

llm/predict/predictor.py

Lines changed: 5 additions & 6 deletions
@@ -848,7 +848,8 @@ def __init__(
 
         self.pre_cache_length = 0
 
-        self.msg_queue_id = os.getpid()
+        msg_queue_id_str = os.getenv("INFERENCE_MSG_QUEUE_ID", str(os.getpid()))
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = msg_queue_id_str
 
         if config.export_precache:
             pre_cache_npy = np.load(config.prefix_path)
@@ -948,7 +949,6 @@ def init_model_inputs(self, config: PredictorArgument):
         )
         self.model_inputs["bad_tokens"] = paddle.to_tensor([-1], dtype="int64")
         self.model_inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool")
-        self.model_inputs["msg_queue_id"] = paddle.full(shape=[1], fill_value=self.msg_queue_id, dtype="int32").cpu()
 
         # bloom model needs src_mask and tgt_mask!
         if "bloom" in self.architectures:
@@ -1185,7 +1185,7 @@ def predict_via_mq(self, input_texts: list[str], return_tokens=False):
 
         read_res_process = mp.Process(
             target=read_res_func,
-            args=[self.model_name_or_path, tensor_queue, result_queue, done_event, self.model_inputs["msg_queue_id"]],
+            args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
         )
         if self.tensor_parallel_rank == 0:
             read_res_process.start()
@@ -1311,7 +1311,7 @@ def insert_task(self, pos, task_id, repeat_num):
         self.model_inputs["stop_flags"][pos] = False
         self.model_inputs["result_id"][pos][0] = task_id
         self.model_inputs["step_idx"][pos, 0] = 1
-        self.model_inputs["pre_ids"][pos][0] = self.input_ids[query_id][-1]
+        self.model_inputs["pre_ids"][pos][0] = np.array(self.input_ids[query_id][-1])
         self.model_inputs["pre_ids"][pos][1:] = -1
         self.model_inputs["not_need_stop"][0] = True
 
@@ -1477,7 +1477,6 @@ def predict_dy_insert(
                 task_queue,
                 result_queue,
                 done_event,
-                self.model_inputs["msg_queue_id"],
                 len(self.input_ids),
                 detokenize,
             ],
@@ -1716,7 +1715,7 @@ def predict_via_mq(self, input_texts: list[str], return_tokens=False):
 
         read_res_process = mp.Process(
             target=read_res_func,
-            args=[self.model_name_or_path, tensor_queue, result_queue, done_event, self.model_inputs["msg_queue_id"]],
+            args=[self.model_name_or_path, tensor_queue, result_queue, done_event],
         )
         if self.tensor_parallel_rank == 0:
             read_res_process.start()

llm/server/server/server/engine/engine.py

Lines changed: 40 additions & 30 deletions
@@ -24,8 +24,7 @@
 
 import numpy as np
 from server.engine.resource_manager import ResourceManager
-from server.engine.task_queue_manager import (TaskQueueManager,
-                                              launch_queue_service)
+from server.engine.task_queue_manager import TaskQueueManager, launch_queue_service
 from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor
 from server.utils import model_server_logger
 
@@ -34,10 +33,11 @@ class Engine(object):
     """
     Engine Class
     """
+
     def __init__(self, cfg, token_processor):
         self.cfg = cfg
         # Master node only
-        if self.cfg.nnode == 1 or self.cfg.host_ip == os.getenv('POD_0_IP', '127.0.0.1'):
+        if self.cfg.nnode == 1 or self.cfg.host_ip == os.getenv("POD_0_IP", "127.0.0.1"):
            self.queue_service = self._start_tasks_queue_service()
            self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
         self.resource_manager = ResourceManager(self.cfg)
@@ -53,8 +53,10 @@ def start(self):
         initialize engine and start sub services
         """
         assert not self.is_started, "The engine is already started.!"
+        msg_queue_id_str = os.getenv("INFERENCE_MSG_QUEUE_ID", str(os.getpid()))
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = msg_queue_id_str
         start_time = time.time()
-
+
         self.token_processor.tasks_queue = self.tasks_queue
         self.infer_proc = self._start_infer_service()
         model_server_logger.info("Waiting infer processes ready...")
@@ -80,17 +82,18 @@ def warmup(self):
         """
         # get eos_token_id
         from server.data.processor import DataProcessor
+
         eos_token_ids = DataProcessor().get_eos_tokens()
 
-        # construct test tasks
+        # construct test tasks
         res_task = []
         for j in range(2 * self.cfg.max_batch_size):
             data = {
                 "input_ids": [5],
                 "req_id": j,
                 "max_dec_len": self.cfg.dec_len_limit,
                 "min_dec_len": int(self.cfg.dec_len_limit * 0.5) + 1,
-                "eos_token_ids": eos_token_ids
+                "eos_token_ids": eos_token_ids,
             }
             res_task.append(data)
         for j in range(2 * self.cfg.max_prefill_batch):
@@ -99,7 +102,7 @@ def warmup(self):
                 "req_id": j + 2 * self.cfg.max_batch_size,
                 "max_dec_len": 1,
                 "min_dec_len": 1,
-                "eos_token_ids": eos_token_ids
+                "eos_token_ids": eos_token_ids,
             }
             res_task.append(data)
 
@@ -130,8 +133,9 @@ def insert_tasks(self, tasks):
 
         available_batch = np.sum(self.resource_manager.stop_flags)
         if len(tasks) > available_batch:
-            model_server_logger.error("Inserting batch:{} exceeds the available batch:{}.".format(
-                len(tasks), available_batch))
+            model_server_logger.error(
+                "Inserting batch:{} exceeds the available batch:{}.".format(len(tasks), available_batch)
+            )
             model_server_logger.error("The exceeded part will be ignored!")
             tasks = tasks[:available_batch]
 
@@ -140,21 +144,23 @@ def insert_tasks(self, tasks):
             input_token_num = len(tasks[i]["input_ids"])
             if input_token_num >= self.cfg.max_seq_len - 1:
                 model_server_logger.warning(f"{req_id}: Input length:{input_token_num}, exceed the limits.")
-                tasks[i]["input_ids"] = tasks[i]["input_ids"][:self.cfg.max_seq_len - 1]
+                tasks[i]["input_ids"] = tasks[i]["input_ids"][: self.cfg.max_seq_len - 1]
             if "seq_len" in tasks[i] and "max_dec_len" not in tasks[i]:
                 tasks[i]["max_dec_len"] = tasks[i]["seq_len"]
 
             # max_dec_len + input_token_num > MAX_SEQ_LEN
             if input_token_num + tasks[i]["max_dec_len"] > self.cfg.max_seq_len:
                 tasks[i]["max_dec_len"] = self.cfg.max_seq_len - input_token_num
-                model_server_logger.warning("Force max_dec_len to be {} for req_id={}.".format(
-                    tasks[i]["max_dec_len"], tasks[i]["req_id"]))
+                model_server_logger.warning(
+                    "Force max_dec_len to be {} for req_id={}.".format(tasks[i]["max_dec_len"], tasks[i]["req_id"])
+                )
 
             # min_dec_len + input_token_num > MAX_SEQ_LEN
             if input_token_num + tasks[i]["min_dec_len"] > self.cfg.max_seq_len:
                 tasks[i]["min_dec_len"] = self.cfg.max_seq_len - input_token_num
-                model_server_logger.warning("Force min_dec_len to be {} for req_id={}.".format(
-                    tasks[i]["min_dec_len"], tasks[i]["req_id"]))
+                model_server_logger.warning(
+                    "Force min_dec_len to be {} for req_id={}.".format(tasks[i]["min_dec_len"], tasks[i]["req_id"])
+                )
 
         tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
         if not tasks:
@@ -292,9 +298,7 @@ def _init_engine_flags(self):
         self.shm_flag_ready = shared_memory.SharedMemory(
             create=True, size=flag_array.nbytes, name=self.cfg.get_unique_name("shm_flag_infer_ready")
         )
-        self.flag_ready_array = np.ndarray(
-            flag_array.shape, dtype=flag_array.dtype, buffer=self.shm_flag_ready.buf
-        )
+        self.flag_ready_array = np.ndarray(flag_array.shape, dtype=flag_array.dtype, buffer=self.shm_flag_ready.buf)
         self.flag_ready_array[:] = 0
 
         # broadcast flag for engine
@@ -324,19 +328,22 @@ def _init_engine_flags(self):
             tmp = shared_memory.SharedMemory(
                 create=False,
                 size=has_block_step_flag_array.nbytes,
-                name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+                name=self.cfg.get_unique_name("shm_flag_has_block_step"),
+            )
             tmp.close()
             tmp.unlink()
         except:
             pass
         self.shm_flag_has_block_step = shared_memory.SharedMemory(
             create=True,
             size=has_block_step_flag_array.nbytes,
-            name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+            name=self.cfg.get_unique_name("shm_flag_has_block_step"),
+        )
         self.flag_has_block_step_array = np.ndarray(
             has_block_step_flag_array.shape,
             dtype=has_block_step_flag_array.dtype,
-            buffer=self.shm_flag_has_block_step.buf)
+            buffer=self.shm_flag_has_block_step.buf,
+        )
         self.flag_has_block_step_array[:] = 0
 
     def _exit_sub_services(self):
@@ -362,8 +369,9 @@ def _start_tasks_queue_service(self):
         if p.is_alive():
             model_server_logger.info("start tasks queue service successfully")
         else:
-            error_msg = "Failed to start tasks queue service, please check " \
-                        "the log/task_queue_manager.log for details"
+            error_msg = (
+                "Failed to start tasks queue service, please check " "the log/task_queue_manager.log for details"
+            )
             model_server_logger.info(error_msg)
             raise Exception(error_msg)
         return p
@@ -380,14 +388,16 @@ def _start_gpu_infer_service(self):
         pd_cmd = "python3 -m paddle.distributed.launch "
         py_script = os.path.join(current_dir_path, "infer.py")
 
-        arguments = (f" --nnodes {str(self.cfg.nnode)}"
-                     f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
-                     f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
-                     f" --max_dec_len {self.cfg.max_dec_len}"
-                     f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
-                     f" --use_cache_kv_int8 {self.cfg.use_cache_kv_int8}"
-                     f" --enc_dec_block_num {self.cfg.enc_dec_block_num}"
-                     f" --block_ratio {self.cfg.block_ratio} --dtype {self.cfg.dtype}")
+        arguments = (
+            f" --nnodes {str(self.cfg.nnode)}"
+            f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
+            f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
+            f" --max_dec_len {self.cfg.max_dec_len}"
+            f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
+            f" --use_cache_kv_int8 {self.cfg.use_cache_kv_int8}"
+            f" --enc_dec_block_num {self.cfg.enc_dec_block_num}"
+            f" --block_ratio {self.cfg.block_ratio} --dtype {self.cfg.dtype}"
+        )
         if self.cfg.nnode > 1:
             pd_cmd = pd_cmd + f" --ips {self.cfg.ips}"
         pd_cmd = pd_cmd + arguments + " >log/launch_infer.log 2>&1"
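
Seeding `os.environ` inside `Engine.start()` is enough for the workers because `_start_gpu_infer_service` launches `infer.py` as a child process (through `paddle.distributed.launch`), and child processes inherit the parent's environment. A minimal standard-library illustration of that assumption (the inline child code is hypothetical):

```python
import os
import subprocess
import sys

# Parent: fix the queue id once, before spawning any worker.
os.environ.setdefault("INFERENCE_MSG_QUEUE_ID", str(os.getpid()))

# Children inherit the parent's environment by default, so the custom ops
# running inside the worker resolve the same INFERENCE_MSG_QUEUE_ID.
child = 'import os; print("worker sees:", os.environ["INFERENCE_MSG_QUEUE_ID"])'
subprocess.run([sys.executable, "-c", child], check=True)
```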

llm/server/server/server/engine/infer.py

Lines changed: 1 addition & 2 deletions
@@ -279,7 +279,6 @@ def init_inputs(self):
         self.share_inputs["input_ids"] = paddle.full(
             shape=[self.args.max_batch_size, self.args.max_seq_len], fill_value=self.pad_token_id, dtype="int64"
         )
-        self.share_inputs["msg_queue_id"] = paddle.full(shape=[1], fill_value=1, dtype="int32").cpu()
         self.share_inputs["top_p"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=self.top_p, dtype="float32"
         )
@@ -743,7 +742,7 @@ def _init_predictor(self):
             config.set_xpu_device_id(device_id)
             xpu_config = paddle.inference.XpuConfig()
             xpu_config.device_id = device_id
-            xpu_config.l3_size = 0
+            xpu_config.l3_size = 0
             xpu_config.l3_autotune_size = 0
             config.set_xpu_config(xpu_config)
             config.switch_ir_optim(True)

llm/server/server/server/engine/token_processor.py

Lines changed: 2 additions & 3 deletions
@@ -47,7 +47,6 @@ def __init__(self, cfg):
         self.tokens_counter = Counter()
 
         self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None"
-        self.msg_queue_id = paddle.full(shape=[1], fill_value=1, dtype="int32")
         if self.is_speculate_decoding:
             self.output_tokens = paddle.full(
                 shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64"
@@ -97,7 +96,7 @@ def process_sampling_results(self):
                 if self.is_speculate_decoding:
                     speculate_get_output(self.output_tokens, rank_id, is_blocking)
                 else:
-                    get_output(self.output_tokens, self.msg_queue_id, rank_id, is_blocking)
+                    get_output(self.output_tokens, rank_id, is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
@@ -281,7 +280,7 @@ def process_sampling_results(self):
                 if self.is_speculate_decoding:
                     speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
                 else:
-                    get_output(self.output_tokens, self.msg_queue_id, rank_id, self._is_blocking)
+                    get_output(self.output_tokens, rank_id, self._is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
