Skip to content

Commit 5b9ffc5

Browse files
committed
modified unit test
1 parent e8dde63 commit 5b9ffc5

File tree

3 files changed

+49
-146
lines changed

3 files changed

+49
-146
lines changed

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def min_p_sampling(
163163
"""
164164
min_p_sampling
165165
"""
166-
if paddle.count_nonzero(min_p_arr)==0:
166+
if paddle.count_nonzero(min_p_arr) == 0:
167167
return probs
168168
else:
169169
if current_platform.is_cuda():
@@ -172,6 +172,6 @@ def min_p_sampling(
172172
else:
173173
max_probabilities = paddle.amax(probs,axis=-1,keepdim=True)
174174
adjusted_min_p = max_probabilities * min_p_arr
175-
invalid_token_mask = probs < adjusted_min_p
175+
invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1])
176176
probs= paddle.where(invalid_token_mask,paddle.full_like(probs,0.0),probs)
177177
return probs

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def __init__(self):
176176
self.forward = self.forward_cuda
177177
else:
178178
raise NotImplementedError()
179-
self.step=0
180179

181180
self.processor = SamplerProcessor()
182181

@@ -286,7 +285,7 @@ def forward_cuda(
286285
sampled_token_ids=next_tokens,
287286
logprobs_tensors=logprobs_tensors,
288287
)
289-
self.step+=1
288+
290289
return sampler_output
291290

292291

test/layers/test_min_p.py

Lines changed: 46 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -13,117 +13,48 @@
1313
# limitations under the License.
1414

1515

16-
import matplotlib.pyplot as plt
1716
import numpy as np
1817
import paddle
1918
import paddle.nn.functional as F
20-
from tqdm import tqdm
2119

2220
from fastdeploy.model_executor.ops.gpu import min_p_sampling
2321

2422
sample_time = 1000000
2523
vocab_size = 1000
2624
min_p_value = 0.5
2725
batch_size = 3
28-
batch_min_p_values = [0.1, 0.5, 0.9]
29-
batch_min_p_values2=[0,3,0,0,0.4]
30-
31-
32-
def compress(data):
33-
new_data = np.array([0, 0, 0], dtype=float)
34-
new_data[0] = data[0]
35-
new_data[1] = data[1]
36-
new_data[2] = np.sum(data[2:])
37-
return new_data
38-
39-
40-
def plot_bar_chart(data1, data2, data3, title, request_idx=None):
41-
plt.figure(figsize=(6, 6))
42-
bar_width = 0.2
43-
idx = np.arange(len(data1)).astype(float)
44-
45-
bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', label='Original Probability', alpha=0.9)
46-
bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='Sampled Probability', alpha=0.9)
47-
bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='Normalized Original Probability', alpha=0.9)
48-
49-
plt.bar_label(bars1, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='black')
50-
plt.bar_label(bars2, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='red')
51-
plt.bar_label(bars3, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='blue')
52-
53-
full_title = title if request_idx is None else f"{title} (min_p={batch_min_p_values[request_idx]})"
54-
plt.title(full_title, fontsize=14)
55-
plt.xlabel("Index", fontsize=12)
56-
plt.ylabel("Probability", fontsize=12)
57-
plt.ylim(0, 1.1)
58-
plt.xlim(-1, 3)
59-
plt.xticks(range(0, 3, 1))
60-
plt.legend(fontsize=10)
61-
plt.grid(axis='y', linestyle='--', alpha=0.5)
62-
output_path = f"{title.replace(' ', '_')}{'' if request_idx is None else f'_req{request_idx}'}.png"
63-
plt.savefig(output_path, dpi=300, bbox_inches='tight')
64-
plt.clf()
65-
66-
def plot_low_prob_curve(low_prob_token_probs, sample_time, title, request_idx=None):
67-
plt.figure(figsize=(6, 6))
68-
plt.plot(np.arange(0, sample_time), low_prob_token_probs, marker='', linestyle='-', linewidth=1, color='blue')
69-
plt.xlabel('Sample Times')
70-
plt.ylabel('Probability')
71-
full_title = 'Probability of Low-Probability Tokens' if request_idx is None else f"Low-Probability Tokens (min_p={batch_min_p_values[request_idx]})"
72-
plt.title(full_title)
73-
plt.grid(alpha=0.3)
74-
output_path = f"{title.replace(' ', '_')}_low_prob{'' if request_idx is None else f'_req{request_idx}'}.png"
75-
plt.savefig(output_path, dpi=300, bbox_inches='tight')
76-
plt.clf()
26+
batch_min_p_values = [0.1, 0.0, 0.9]
27+
7728

7829
# min_p:0.5:FastDeploy
79-
def fastdeploy_min_p_sampling():
30+
def min_p_sampling_cpu(min_p):
8031
logits = paddle.ones(shape=[1, vocab_size], dtype="float32")
8132
logits[0][0] = 10
8233
logits[0][1] = 8
8334
low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2)
8435
logits[0][2:] = low_prob_tensor
8536

86-
probs = F.softmax(logits)
87-
min_p = paddle.to_tensor([min_p_value], dtype="float32")
88-
89-
max_prob = probs.max().item()
90-
threshold = max_prob * min_p.item()
91-
allowed_tokens = paddle.where(probs[0] >= threshold)[0].numpy()
92-
93-
sample_freq = [0] * vocab_size
94-
low_prob_token_times = 0
95-
low_prob_token_probs = []
96-
97-
for i in tqdm(range(sample_time), desc="FastDeploy Sampling"):
98-
ids = min_p_sampling(probs, min_p, seed=-1)
99-
sample_freq[ids.item()] += 1
100-
if ids.item() >= 2:
101-
low_prob_token_times += 1
102-
low_prob_token_probs.append(low_prob_token_times / (i + 1))
103-
104-
sample_freq = np.array(sample_freq, dtype=float) / sample_time
105-
low_prob_token_probs = np.array(low_prob_token_probs, dtype=float)
37+
probs=F.softmax(logits)
38+
max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
39+
adjusted_min_p = max_probabilities * min_p.reshape([-1, 1])
40+
invalid_token_mask = probs < adjusted_min_p
41+
probs = paddle.where(invalid_token_mask,paddle.full_like(probs,0.0), probs)
42+
return probs
10643

107-
ori_data1 = probs.numpy().reshape(-1)
108-
data1 = compress(ori_data1)
109-
data2 = compress(sample_freq)
110-
111-
allowed_probs = probs[0, allowed_tokens].numpy()
112-
norm_scale = np.sum(allowed_probs)
113-
data3 = np.zeros_like(data1)
114-
for idx in allowed_tokens:
115-
if idx < 2:
116-
data3[idx] = ori_data1[idx] / norm_scale
117-
else:
118-
data3[2] += ori_data1[idx] / norm_scale
119-
120-
plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_sampling]")
121-
plot_low_prob_curve(low_prob_token_probs, sample_time, "FastDeploy[min_p_sampling]")
44+
# min_p:0.5:FastDeploy
45+
def fastdeploy_min_p_sampling(min_p):
46+
logits = paddle.ones(shape=[1, vocab_size], dtype="float32")
47+
logits[0][0] = 10
48+
logits[0][1] = 8
49+
low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2)
50+
logits[0][2:] = low_prob_tensor
12251

123-
return data2, data3
52+
probs = F.softmax(logits)
53+
probs= min_p_sampling(probs, min_p)
54+
return probs
12455

12556

126-
# batch:[0.1.0,5,0.9]:FastDeploy
57+
# batch:[0.1.0.0,0.9]:FastDeploy
12758
def fastdeploy_batch_min_p_sampling(batch_size, min_p_values):
12859
logits = paddle.ones(shape=[batch_size, vocab_size], dtype="float32")
12960
for b in range(batch_size):
@@ -134,68 +65,41 @@ def fastdeploy_batch_min_p_sampling(batch_size, min_p_values):
13465
probs = F.softmax(logits, axis=-1)
13566
min_p_arr = paddle.to_tensor(min_p_values, dtype="float32")
13667

137-
allowed_tokens_list = []
138-
for b in range(batch_size):
139-
max_prob = probs[b].max().item()
140-
threshold = max_prob * min_p_values[b]
141-
allowed_tokens = paddle.where(probs[b] >= threshold)[0].numpy()
142-
allowed_tokens_list.append(allowed_tokens)
143-
144-
sample_freq = [np.zeros(vocab_size, dtype=float) for _ in range(batch_size)]
145-
low_prob_token_times = [0] * batch_size
146-
low_prob_token_probs = [[] for _ in range(batch_size)]
147-
148-
for i in tqdm(range(sample_time), desc="FastDeploy Batch Sampling"):
149-
ids = min_p_sampling(probs, min_p_arr, seed=-1)
150-
for b in range(batch_size):
151-
sample_freq[b][ids[b].item()] += 1
152-
if ids[b].item() >= 2:
153-
low_prob_token_times[b] += 1
154-
low_prob_token_probs[b].append(low_prob_token_times[b] / (i + 1))
155-
156-
data2_list = []
157-
data3_list = []
158-
for b in range(batch_size):
159-
sample_freq_b = sample_freq[b] / sample_time
160-
low_prob_token_probs[b] = np.array(low_prob_token_probs[b], dtype=float)
161-
162-
ori_data1 = probs[b].numpy()
163-
data1 = compress(ori_data1)
164-
data2 = compress(sample_freq_b)
165-
data2_list.append(data2)
68+
probs = min_p_sampling(probs, min_p_arr)
16669

167-
allowed_probs = probs[b, allowed_tokens_list[b]].numpy()
168-
norm_scale = np.sum(allowed_probs)
169-
data3 = np.zeros_like(data1)
170-
for idx in allowed_tokens_list[b]:
171-
if idx < 2:
172-
data3[idx] = ori_data1[idx] / norm_scale
173-
else:
174-
data3[2] += ori_data1[idx] / norm_scale
175-
data3_list.append(data3)
70+
return probs
17671

177-
plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_batch_sampling]", b)
178-
plot_low_prob_curve(low_prob_token_probs[b], sample_time, "FastDeploy[min_p_batch_sampling]", b)
72+
def compare_results(probs,probs_cpu,atol=1e-6,rtol=1e-6):
73+
probs_np = probs.numpy()
74+
probs_cpu_np = probs_cpu.numpy()
75+
try:
76+
np.testing.assert_allclose(
77+
probs_np,
78+
probs_cpu_np,
79+
rtol=rtol,
80+
atol=atol,
81+
)
82+
print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu")
83+
except AssertionError as e:
84+
raise AssertionError(
85+
f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}")
17986

180-
return data2_list, data3_list
18187

18288

18389
def main():
90+
# min_p:0.5:FastDeploy
91+
min_p = paddle.to_tensor([min_p_value],dtype="float32")
18492
print("Running single min_p sampling (min_p=0.5)...")
185-
data2_fastdeploy, data3_fastdeploy = fastdeploy_min_p_sampling()
186-
187-
print("\nFastDeploy Single Request Results:")
188-
print(f"Sampled Probability: {data2_fastdeploy}")
189-
print(f"Theoretical Normalized Probability: {data3_fastdeploy}")
93+
probs = fastdeploy_min_p_sampling(min_p)
94+
probs_cpu = min_p_sampling_cpu(min_p)
95+
compare_results(probs,probs_cpu)
19096

191-
print("\nRunning batch min_p sampling (min_p=[0.1, 0.5, 0.9])...")
192-
data2_fd_batch, data3_fd_batch = fastdeploy_batch_min_p_sampling(batch_size, batch_min_p_values)
97+
# batch:[0.1.0.0,0.9]:FastDeploy
98+
batch_min_p = paddle.to_tensor(batch_min_p_values,dtype="float32")
99+
batch_probs = fastdeploy_batch_min_p_sampling(batch_size,batch_min_p)
100+
batch_probs_cpu = min_p_sampling_cpu(batch_min_p)
101+
compare_results(batch_probs,batch_probs_cpu)
193102

194-
data2_fd_batch,data3_fd_batch = fastdeploy_batch_min_p_sampling(batch_size,batch_min_p_values2)
195-
196-
for b in range(batch_size):
197-
print(f"\nBatch Request {b} (min_p={batch_min_p_values[b]}):")
198-
print(f"FastDeploy - Sampled: {data2_fd_batch[b]}, Normalized: {data3_fd_batch[b]}")
199103

200104
if __name__ == "__main__":
201105
if paddle.device.is_compiled_with_cuda():

0 commit comments

Comments
 (0)