Skip to content

Commit acb6e22

Browse files
authored
【Inference】Fix Bug set_preids_token_penalty_multi_scores (#10492)
* fix
* check model_inputs[pre_ids] update in dynamic insert
1 parent 409fc40 commit acb6e22

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

csrc/gpu/set_preids_token_penalty_multi_scores.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ __global__ void set_preids_token_penalty_multi_scores_kernel(const bool *stop_fl
4141
T *logits_now = logits + bi * length;
4242
int tid = threadIdx.x;
4343

44-
if (tid < bs && !stop_flags[tid]) {
45-
int64_t *pre_ids_now = pre_ids + tid * length_id;
46-
const int64_t *input_ids_now = input_ids + tid * length_input_ids;
47-
const int seq_len_dec = seq_lens_decoder[tid];
48-
const int seq_len_enc = seq_lens_encoder[tid];
44+
if (bi < bs && !stop_flags[bi]) {
45+
int64_t *pre_ids_now = pre_ids + bi * length_id;
46+
const int64_t *input_ids_now = input_ids + bi * length_input_ids;
47+
const int seq_len_dec = seq_lens_decoder[bi];
48+
const int seq_len_enc = seq_lens_encoder[bi];
4949
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
5050

5151
const int step_idx_now = step_idx[bi];

llm/predict/predictor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,8 @@ def insert_task(self, pos, task_id, repeat_num):
13111311
self.model_inputs["stop_flags"][pos] = False
13121312
self.model_inputs["result_id"][pos][0] = task_id
13131313
self.model_inputs["step_idx"][pos, 0] = 1
1314+
self.model_inputs["pre_ids"][pos][0] = self.input_ids[query_id][-1]
1315+
self.model_inputs["pre_ids"][pos][1:] = -1
13141316
self.model_inputs["not_need_stop"][0] = True
13151317

13161318
num_prefill_blocks = length // self.block_size

0 commit comments

Comments
 (0)