@@ -52,7 +52,7 @@ def __init__(
52
52
self .stepped = False
53
53
self .closed = False
54
54
55
- self ._position = 0
55
+ self .position = 0
56
56
self .history = None # Used in case of server failures to regenerate attention caches on new servers
57
57
self .next_session = None
58
58
@@ -102,12 +102,11 @@ def step(
102
102
n_input_tokens = inputs .shape [1 ]
103
103
if self .history is None :
104
104
self .history = inputs
105
- elif self .history .shape [1 ] == self ._position :
105
+ elif self .history .shape [1 ] == self .position :
106
106
self .history = torch .cat ([self .history , inputs [:, - n_input_tokens :]], dim = 1 )
107
- assert self .history .shape [1 ] == self ._position + n_input_tokens , (
108
- f"Broken input cache: span={ self .span } shape={ self .history .shape } "
109
- f"position={ self ._position } n_input_tokens={ n_input_tokens } "
110
- )
107
+ assert (
108
+ self .history .shape [1 ] == self .position + n_input_tokens
109
+ ), f"Broken input cache: { self .span = } { self .history .shape = } { self .position = } { n_input_tokens = } "
111
110
112
111
if not self .stepped :
113
112
inputs = self .history # Pass full inputs including prefix
@@ -169,7 +168,7 @@ def step(
169
168
outputs [0 ].shape == inputs .shape
170
169
), f"output activation shape is different from input shape: { outputs [0 ].shape } != { inputs .shape } "
171
170
172
- self ._position += n_input_tokens
171
+ self .position += n_input_tokens
173
172
174
173
return outputs [0 ]
175
174
@@ -359,6 +358,10 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
359
358
# If there is a failed span, this code replaces it, otherwise it just adds new ones
360
359
if server_idx < n_prev_spans :
361
360
updated_sessions [0 ].history = self ._server_sessions [server_idx ].history
361
+ updated_sessions [0 ].position = self ._position
362
+ assert (
363
+ updated_sessions [0 ].history .shape [1 ] == self ._position
364
+ ), f"Broken input cache: { updated_sessions [0 ].history .shape [1 ]= } { self ._position = } "
362
365
self ._server_sessions [server_idx : server_idx + 1 ] = updated_sessions
363
366
364
367
# Update links to the next server session for direct server-to-server communication via rpc_push()
0 commit comments