@@ -73,11 +73,14 @@ def ref_paged_attn(
73
73
@pytest .mark .parametrize ("dtype" , DTYPES )
74
74
@pytest .mark .parametrize ("soft_cap" , [None , 30.0 , 50.0 ])
75
75
@torch .inference_mode
76
- def test_flashinfer_decode_with_paged_kv (kv_lens : List [int ],
77
- num_heads : Tuple [int ,
78
- int ], head_size : int ,
79
- dtype : torch .dtype , block_size : int ,
80
- soft_cap : Optional [float ]) -> None :
76
+ def test_flashinfer_decode_with_paged_kv (
77
+ kv_lens : List [int ],
78
+ num_heads : Tuple [int , int ],
79
+ head_size : int ,
80
+ dtype : torch .dtype ,
81
+ block_size : int ,
82
+ soft_cap : Optional [float ],
83
+ ) -> None :
81
84
torch .set_default_device ("cuda" )
82
85
torch .cuda .manual_seed_all (0 )
83
86
num_seqs = len (kv_lens )
@@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
88
91
scale = head_size ** - 0.5
89
92
90
93
query = torch .randn (num_seqs , num_query_heads , head_size , dtype = dtype )
94
+
91
95
key_value_cache = torch .randn (NUM_BLOCKS ,
92
96
2 ,
93
97
block_size ,
@@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
125
129
wrapper = flashinfer .\
126
130
BatchDecodeWithPagedKVCacheWrapper (workspace_buffer , "NHD" ,
127
131
use_tensor_cores = (
128
- (num_query_heads // num_kv_heads ) not in ( 1 , 2 , 4 , 8 ) )
132
+ (num_query_heads // num_kv_heads ) > 4 )
129
133
)
130
134
wrapper .begin_forward (kv_indptr ,
131
135
kv_indices ,
@@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
249
253
soft_cap = soft_cap )
250
254
torch .testing .assert_close (output , ref_output , atol = 1e-2 , rtol = 1e-2 ), \
251
255
f"{ torch .max (torch .abs (output - ref_output ))} "
256
+
257
+
258
+ @pytest .mark .parametrize ("seq_lens" , [[(1 , 132 ), (5 , 18 )]])
259
+ @pytest .mark .parametrize ("num_heads" , [(32 , 8 ), (6 , 1 )])
260
+ @pytest .mark .parametrize ("head_size" , HEAD_SIZES )
261
+ @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
262
+ @pytest .mark .parametrize ("dtype" , DTYPES )
263
+ @pytest .mark .parametrize ("soft_cap" , [None , 30.0 , 50.0 ])
264
+ def test_flashinfer_prefill_with_paged_fp8_kv (
265
+ seq_lens : List [Tuple [int , int ]], num_heads : Tuple [int , int ],
266
+ head_size : int , dtype : torch .dtype , block_size : int ,
267
+ soft_cap : Optional [float ]) -> None :
268
+ torch .set_default_device ("cuda" )
269
+ torch .cuda .manual_seed_all (0 )
270
+ num_seqs = len (seq_lens )
271
+ query_lens = [x [0 ] for x in seq_lens ]
272
+ kv_lens = [x [1 ] for x in seq_lens ]
273
+ num_query_heads = num_heads [0 ]
274
+ num_kv_heads = num_heads [1 ]
275
+ assert num_query_heads % num_kv_heads == 0
276
+ max_kv_len = max (kv_lens )
277
+ scale = head_size ** - 0.5
278
+
279
+ kv_cache_dtype = torch .float8_e4m3fn
280
+
281
+ query = torch .randn (sum (query_lens ),
282
+ num_query_heads ,
283
+ head_size ,
284
+ dtype = dtype )
285
+ NUM_BLOCKS_FP8 = 2048
286
+ key_value_cache = torch .randn (NUM_BLOCKS_FP8 ,
287
+ 2 ,
288
+ block_size ,
289
+ num_kv_heads ,
290
+ head_size ,
291
+ dtype = dtype )
292
+ key_cache , value_cache = torch .chunk (key_value_cache , 2 , dim = 1 )
293
+ key_cache /= head_size ** 0.5
294
+ value_cache /= head_size ** 0.5
295
+
296
+ k_scale = key_cache .amax ().item () / 448.0
297
+ v_scale = value_cache .amax ().item () / 448.0
298
+
299
+ kv_cache_fp8 = torch .cat ([key_cache / k_scale , value_cache / v_scale ],
300
+ dim = 1 ).to (kv_cache_dtype )
301
+
302
+ assert (kv_cache_fp8 .shape == key_value_cache .shape )
303
+ max_num_blocks_per_seq = (max_kv_len + block_size - 1 ) // block_size
304
+ block_tables = torch .randint (0 ,
305
+ NUM_BLOCKS_FP8 ,
306
+ (num_seqs , max_num_blocks_per_seq ),
307
+ dtype = torch .int32 )
308
+
309
+ qo_indptr = [0 ]
310
+ kv_indptr = [0 ]
311
+ kv_indices = []
312
+ kv_last_page_lens = []
313
+ for i in range (num_seqs ):
314
+ seq_len = kv_lens [i ]
315
+ assert seq_len > 0
316
+ num_blocks = (seq_len + block_size - 1 ) // block_size
317
+ kv_indices .extend (block_tables [i , :num_blocks ])
318
+ kv_indptr .append (kv_indptr [- 1 ] + num_blocks )
319
+ kv_last_page_len = seq_len % block_size
320
+ if kv_last_page_len == 0 :
321
+ kv_last_page_len = block_size
322
+ kv_last_page_lens .append (kv_last_page_len )
323
+ qo_indptr .append (qo_indptr [- 1 ] + query_lens [i ])
324
+
325
+ qo_indptr = torch .tensor (qo_indptr , dtype = torch .int32 )
326
+ kv_indptr = torch .tensor (kv_indptr , dtype = torch .int32 )
327
+ kv_indices = torch .tensor (kv_indices , dtype = torch .int32 )
328
+ kv_last_page_lens = torch .tensor (kv_last_page_lens , dtype = torch .int32 )
329
+
330
+ workspace_buffer = torch .empty (128 * 1024 * 1024 , dtype = torch .int8 )
331
+ wrapper = flashinfer .BatchPrefillWithPagedKVCacheWrapper (
332
+ workspace_buffer , "NHD" )
333
+ wrapper .begin_forward (
334
+ qo_indptr ,
335
+ kv_indptr ,
336
+ kv_indices ,
337
+ kv_last_page_lens ,
338
+ num_query_heads ,
339
+ num_kv_heads ,
340
+ head_size ,
341
+ block_size ,
342
+ )
343
+
344
+ output = wrapper .forward (query ,
345
+ kv_cache_fp8 ,
346
+ logits_soft_cap = soft_cap ,
347
+ k_scale = k_scale ,
348
+ v_scale = v_scale )
349
+
350
+ ref_output = ref_paged_attn (query = query ,
351
+ key_cache = key_cache .squeeze (1 ),
352
+ value_cache = value_cache .squeeze (1 ),
353
+ query_lens = query_lens ,
354
+ kv_lens = kv_lens ,
355
+ block_tables = block_tables ,
356
+ scale = scale ,
357
+ soft_cap = soft_cap )
358
+ del query
359
+ del block_tables
360
+ # verify prefill fp8
361
+ torch .testing .assert_close (output , ref_output , atol = 1e-2 , rtol = 1e-2 ), \
362
+ f"{ torch .max (torch .abs (output - ref_output ))} "
363
+
364
+
365
+ @pytest .mark .parametrize ("kv_lens" , [[1328 , 18 , 463 ], [1 , 54 , 293 , 70 ]])
366
+ @pytest .mark .parametrize ("num_heads" , [(32 , 8 ), (64 , 8 ), (6 , 1 )])
367
+ @pytest .mark .parametrize ("head_size" , HEAD_SIZES )
368
+ @pytest .mark .parametrize ("block_size" , BLOCK_SIZES )
369
+ @pytest .mark .parametrize ("dtype" , DTYPES )
370
+ @pytest .mark .parametrize ("soft_cap" , [None , 30.0 , 50.0 ])
371
+ @torch .inference_mode
372
+ def test_flashinfer_decode_with_paged_fp8_kv (
373
+ kv_lens : List [int ],
374
+ num_heads : Tuple [int , int ],
375
+ head_size : int ,
376
+ dtype : torch .dtype ,
377
+ block_size : int ,
378
+ soft_cap : Optional [float ],
379
+ ) -> None :
380
+ # test doesn't work for num_heads = (16,16)
381
+ torch .set_default_device ("cuda" )
382
+ torch .cuda .manual_seed_all (0 )
383
+ num_seqs = len (kv_lens )
384
+ num_query_heads = num_heads [0 ]
385
+ num_kv_heads = num_heads [1 ]
386
+ assert num_query_heads % num_kv_heads == 0
387
+ max_kv_len = max (kv_lens )
388
+ scale = head_size ** - 0.5
389
+ use_tensor_cores = (num_query_heads // num_kv_heads ) > 4
390
+ kv_cache_dtype = torch .float8_e4m3fn
391
+
392
+ query = torch .randn (num_seqs , num_query_heads , head_size , dtype = dtype )
393
+ NUM_BLOCKS_FP8 = 2048
394
+ key_value_cache = torch .randn (NUM_BLOCKS_FP8 ,
395
+ 2 ,
396
+ block_size ,
397
+ num_kv_heads ,
398
+ head_size ,
399
+ dtype = dtype )
400
+ key_cache , value_cache = torch .chunk (key_value_cache , 2 , dim = 1 )
401
+ key_cache /= head_size ** 0.5
402
+ value_cache /= head_size ** 0.5
403
+
404
+ k_scale = key_cache .amax ().item () / 448.0
405
+ v_scale = value_cache .amax ().item () / 448.0
406
+
407
+ key_cache_fp8 = (key_cache / k_scale ).to (kv_cache_dtype )
408
+ value_cache_fp8 = (value_cache / v_scale ).to (kv_cache_dtype )
409
+ assert (key_cache_fp8 .shape [1 ] == 1 and value_cache_fp8 .shape [1 ] == 1 )
410
+ kv_cache_fp8 = torch .cat ([key_cache_fp8 , value_cache_fp8 ], dim = 1 )
411
+
412
+ max_num_blocks_per_seq = (max_kv_len + block_size - 1 ) // block_size
413
+ block_tables = torch .randint (0 ,
414
+ NUM_BLOCKS_FP8 ,
415
+ (num_seqs , max_num_blocks_per_seq ),
416
+ dtype = torch .int32 )
417
+
418
+ kv_indptr = [0 ]
419
+ kv_indices = []
420
+ kv_last_page_lens = []
421
+ for i in range (num_seqs ):
422
+ seq_len = kv_lens [i ]
423
+ assert seq_len > 0
424
+ num_blocks = (seq_len + block_size - 1 ) // block_size
425
+ kv_indices .extend (block_tables [i , :num_blocks ])
426
+ kv_indptr .append (kv_indptr [- 1 ] + num_blocks )
427
+ kv_last_page_len = seq_len % block_size
428
+ if kv_last_page_len == 0 :
429
+ kv_last_page_len = block_size
430
+ kv_last_page_lens .append (kv_last_page_len )
431
+
432
+ kv_indptr = torch .tensor (kv_indptr , dtype = torch .int32 )
433
+ kv_indices = torch .tensor (kv_indices , dtype = torch .int32 )
434
+ kv_last_page_lens = torch .tensor (kv_last_page_lens , dtype = torch .int32 )
435
+
436
+ workspace_buffer = torch .empty (128 * 1024 * 1024 , dtype = torch .int8 )
437
+ wrapper = flashinfer .\
438
+ BatchDecodeWithPagedKVCacheWrapper (workspace_buffer , "NHD" ,
439
+ use_tensor_cores = use_tensor_cores )
440
+ wrapper .begin_forward (kv_indptr ,
441
+ kv_indices ,
442
+ kv_last_page_lens ,
443
+ num_query_heads ,
444
+ num_kv_heads ,
445
+ head_size ,
446
+ block_size ,
447
+ "NONE" ,
448
+ data_type = dtype )
449
+ output = wrapper .forward (query ,
450
+ kv_cache_fp8 ,
451
+ logits_soft_cap = soft_cap ,
452
+ k_scale = k_scale ,
453
+ v_scale = v_scale )
454
+ key_cache = key_value_cache [:, 0 , :, :, :].squeeze (1 )
455
+ value_cache = key_value_cache [:, 1 , :, :, :].squeeze (1 )
456
+
457
+ ref_output = ref_paged_attn (query = query ,
458
+ key_cache = key_cache ,
459
+ value_cache = value_cache ,
460
+ query_lens = [1 ] * num_seqs ,
461
+ kv_lens = kv_lens ,
462
+ block_tables = block_tables ,
463
+ scale = scale ,
464
+ soft_cap = soft_cap )
465
+ # Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
466
+ torch .testing .assert_close (output , ref_output , atol = 2e-2 , rtol = 1e-2 ), \
467
+ f"{ torch .max (torch .abs (output - ref_output ))} "
0 commit comments