@@ -52,7 +52,7 @@ def __init__(
52
52
self .stepped = False
53
53
self .closed = False
54
54
55
- self ._position = 0
55
+ self .position = 0
56
56
self .history = None # Used in case of server failures to regenerate attention caches on new servers
57
57
self .next_session = None
58
58
@@ -97,12 +97,11 @@ def step(
97
97
n_input_tokens = inputs .shape [1 ]
98
98
if self .history is None :
99
99
self .history = inputs
100
- elif self .history .shape [1 ] == self ._position :
100
+ elif self .history .shape [1 ] == self .position :
101
101
self .history = torch .cat ([self .history , inputs [:, - n_input_tokens :]], dim = 1 )
102
- assert self .history .shape [1 ] == self ._position + n_input_tokens , (
103
- f"Broken input cache: span={ self .span } shape={ self .history .shape } "
104
- f"position={ self ._position } n_input_tokens={ n_input_tokens } "
105
- )
102
+ assert (
103
+ self .history .shape [1 ] == self .position + n_input_tokens
104
+ ), f"Broken input cache: { self .span = } { self .history .shape = } { self .position = } { n_input_tokens = } "
106
105
107
106
if not self .stepped :
108
107
inputs = self .history # Pass full inputs including prefix
@@ -154,7 +153,7 @@ def step(
154
153
outputs [0 ].shape == inputs .shape
155
154
), f"output activation shape is different from input shape: { outputs [0 ].shape } != { inputs .shape } "
156
155
157
- self ._position += n_input_tokens
156
+ self .position += n_input_tokens
158
157
159
158
return outputs [0 ]
160
159
@@ -356,6 +355,10 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
356
355
# If there is a failed span, this code replaces it, otherwise it just adds new ones
357
356
if server_idx < n_prev_spans :
358
357
updated_sessions [0 ].history = self ._server_sessions [server_idx ].history
358
+ updated_sessions [0 ].position = self ._position
359
+ assert (
360
+ updated_sessions [0 ].history .shape [1 ] == self ._position
361
+ ), f"Broken input cache: { updated_sessions [0 ].history .shape = } { self ._position = } "
359
362
self ._server_sessions [server_idx : server_idx + 1 ] = updated_sessions
360
363
361
364
# Update links to the next server session for direct server-to-server communication via rpc_push()
0 commit comments