Skip to content

Commit 9289d93

Browse files
committed
Fix retries during inference
1 parent 3f70ab6 commit 9289d93

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/petals/client/inference_session.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
self.stepped = False
5353
self.closed = False
5454

55-
self._position = 0
55+
self.position = 0
5656
self.history = None # Used in case of server failures to regenerate attention caches on new servers
5757
self.next_session = None
5858

@@ -97,12 +97,11 @@ def step(
9797
n_input_tokens = inputs.shape[1]
9898
if self.history is None:
9999
self.history = inputs
100-
elif self.history.shape[1] == self._position:
100+
elif self.history.shape[1] == self.position:
101101
self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
102-
assert self.history.shape[1] == self._position + n_input_tokens, (
103-
f"Broken input cache: span={self.span} shape={self.history.shape} "
104-
f"position={self._position} n_input_tokens={n_input_tokens}"
105-
)
102+
assert (
103+
self.history.shape[1] == self.position + n_input_tokens
104+
), f"Broken input cache: {self.span=} {self.history.shape=} {self.position=} {n_input_tokens=}"
106105

107106
if not self.stepped:
108107
inputs = self.history # Pass full inputs including prefix
@@ -154,7 +153,7 @@ def step(
154153
outputs[0].shape == inputs.shape
155154
), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"
156155

157-
self._position += n_input_tokens
156+
self.position += n_input_tokens
158157

159158
return outputs[0]
160159

@@ -356,6 +355,10 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
356355
# If there is a failed span, this code replaces it, otherwise it just adds new ones
357356
if server_idx < n_prev_spans:
358357
updated_sessions[0].history = self._server_sessions[server_idx].history
358+
updated_sessions[0].position = self._position
359+
assert (
360+
updated_sessions[0].history.shape[1] == self._position
361+
), f"Broken input cache: {updated_sessions[0].history.shape=} {self._position=}"
359362
self._server_sessions[server_idx : server_idx + 1] = updated_sessions
360363

361364
# Update links to the next server session for direct server-to-server communication via rpc_push()

0 commit comments

Comments
 (0)