@@ -170,7 +170,7 @@ def __init__(
170
170
self .num_blocks = num_blocks
171
171
172
172
# Initialize the free blocks.
173
- self .free_blocks : BlockTable = []
173
+ self .free_blocks : List [ PhysicalTokenBlock ] = []
174
174
for i in range (num_blocks ):
175
175
block = PhysicalTokenBlock (device = device ,
176
176
block_number = i ,
@@ -256,6 +256,7 @@ def __init__(
256
256
Device .CPU , block_size , num_cpu_blocks )
257
257
# Mapping: seq_id -> BlockTable.
258
258
self .block_tables : Dict [int , BlockTable ] = {}
259
+
259
260
# Mapping: req_id -> BlockTable
260
261
# Note that each SequenceGroup has a unique
261
262
# request ID
@@ -299,7 +300,7 @@ def _allocate_sequence(self, \
299
300
# Allocate new physical token blocks that will store the prompt tokens.
300
301
num_prompt_blocks = seq .n_blocks
301
302
302
- block_table : BlockTable = []
303
+ block_table : BlockTable = BlockTable ()
303
304
for logical_idx in range (num_prompt_blocks ):
304
305
if (self .block_sliding_window is not None
305
306
and logical_idx >= self .block_sliding_window ):
@@ -326,15 +327,19 @@ def allocate(self, seq_group: SequenceGroup) -> None:
326
327
#
327
328
# NOTE: Here we assume that all sequences in the group have the same
328
329
# decoder prompt.
329
- seq = seq_group .get_seqs (status = SequenceStatus .WAITING )[0 ]
330
+ wait_seqs = seq_group .get_seqs (status = SequenceStatus .WAITING )
331
+ seq = wait_seqs [0 ]
330
332
block_table : BlockTable = \
331
333
self ._allocate_sequence (seq ,
332
334
seq_group .num_seqs (),
333
335
is_encoder_decoder )
334
336
335
337
# Assign the self-attention block tables for each sequence.
336
- for seq in seq_group .get_seqs (status = SequenceStatus .WAITING ):
337
- self .block_tables [seq .seq_id ] = block_table .copy ()
338
+ if len (wait_seqs ) == 1 :
339
+ self .block_tables [wait_seqs [0 ].seq_id ] = block_table
340
+ else :
341
+ for seq in seq_group .get_seqs (status = SequenceStatus .WAITING ):
342
+ self .block_tables [seq .seq_id ] = block_table .copy ()
338
343
339
344
# Allocate encoder sequence
340
345
if is_encoder_decoder :
@@ -476,6 +481,7 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
476
481
return
477
482
src_block_table = self .block_tables [parent_seq .seq_id ]
478
483
self .block_tables [child_seq .seq_id ] = src_block_table .copy ()
484
+
479
485
# When using a sliding window, blocks will be eventually reused.
480
486
# In this case the block tables will contain repeated blocks.
481
487
# When forking, we must make sure that each block's `ref_count`
@@ -527,7 +533,7 @@ def _swap_block_table(
527
533
dest_allocator : BlockAllocatorBase ,
528
534
mapping : Dict [PhysicalTokenBlock ,
529
535
PhysicalTokenBlock ]) -> BlockTable :
530
- new_block_table = []
536
+ new_block_table : BlockTable = BlockTable ()
531
537
532
538
for from_block in block_table :
533
539
if from_block in mapping :
@@ -553,8 +559,7 @@ def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
553
559
for seq in seq_group .get_seqs (status = SequenceStatus .SWAPPED ):
554
560
self .block_tables [seq .seq_id ] = \
555
561
self ._swap_block_table (self .block_tables [seq .seq_id ],
556
- self .cpu_allocator ,
557
- self .gpu_allocator ,
562
+ self .cpu_allocator , self .gpu_allocator ,
558
563
mapping )
559
564
560
565
if seq_group .is_encoder_decoder ():
@@ -580,8 +585,7 @@ def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
580
585
for seq in seq_group .get_seqs (status = SequenceStatus .RUNNING ):
581
586
self .block_tables [seq .seq_id ] = \
582
587
self ._swap_block_table (self .block_tables [seq .seq_id ],
583
- self .gpu_allocator ,
584
- self .cpu_allocator ,
588
+ self .gpu_allocator , self .cpu_allocator ,
585
589
mapping )
586
590
587
591
if seq_group .is_encoder_decoder ():
@@ -636,8 +640,7 @@ def reset(self) -> None:
636
640
self .cross_block_tables .clear ()
637
641
638
642
def get_block_table (self , seq : Sequence ) -> List [int ]:
639
- block_table = self .block_tables [seq .seq_id ]
640
- return [block .block_number for block in block_table ]
643
+ return self .block_tables [seq .seq_id ].ids ()
641
644
642
645
def get_cross_block_table (self , seq_group : SequenceGroup ) -> List [int ]:
643
646
block_table = self .cross_block_tables [seq_group .request_id ]
0 commit comments