Commit 82a97d6

Fix beam search in GPU clients (#531)
Fixes #503.
1 parent 47d50e1 commit 82a97d6

File tree

2 files changed: +24, -36 lines

.github/workflows/run-tests.yaml

Lines changed: 9 additions & 13 deletions
@@ -48,7 +48,6 @@ jobs:
 export MODEL_NAME="${{ matrix.model }}"
 export REF_NAME="${{ matrix.model }}"
 export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"
-export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"

 # [Step 1] Set up a tiny test swarm (see https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm)

@@ -61,27 +60,25 @@ jobs:
 until [ -s bootstrap.log ]; do sleep 5; done # wait for DHT init

-python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 5 \
-    --mean_balance_check_period 10 \
-    --initial_peers $INITIAL_PEERS --throughput 1 &> server1.log &
+export RUN_SERVER="python -m petals.cli.run_server $MODEL_NAME \
+    --device cpu --torch_dtype float32 --initial_peers $INITIAL_PEERS"
+export TENSOR_PARALLEL_ARGS="${{ matrix.model == 'bigscience/bloom-560m' && '--tensor_parallel_devices cpu cpu' || '' }}"
+
+$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 5 --throughput 1 --mean_balance_check_period 10 &> server1.log &
 SERVER1_PID=$!
 # ^-- rebalacing test: this server chooses blocks 0:5, then sees a gap in the swarm and moves there

 sleep 10 # wait for the 1st server to choose blocks

-python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --block_indices 0:5 \
-    --identity_path tests/server2.id \
-    --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+$RUN_SERVER --adapters $ADAPTER_NAME --block_indices 0:5 --throughput 1 --identity_path tests/server2.id &> server2.log &
 SERVER2_PID=$!

-python -m petals.cli.run_server $MODEL_NAME --adapters $ADAPTER_NAME --torch_dtype float32 --num_blocks 14 \
-    --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 \
-    --initial_peers $INITIAL_PEERS --throughput auto &> server3.log &
+$RUN_SERVER --adapters $ADAPTER_NAME --num_blocks 14 --throughput auto \
+    --attn_cache_tokens 2048 --max_chunk_size_bytes 1024 &> server3.log &
 SERVER3_PID=$!
 # ^-- chunking test

-python -m petals.cli.run_server $MODEL_NAME $TENSOR_PARALLEL_ARGS --torch_dtype float32 --block_indices 0:2 \
-    --initial_peers $INITIAL_PEERS --throughput auto &> server4.log &
+$RUN_SERVER $TENSOR_PARALLEL_ARGS --block_indices 0:2 --throughput auto &> server4.log &
 SERVER4_PID=$!
 # ^-- tensor parallelism test (not compatible with adapters yet)

@@ -121,4 +118,3 @@ jobs:
 # [Step 4] Clean up

 kill -s SIGINT $BOOTSTRAP_PID $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
-echo "Done!"

src/petals/client/inference_session.py

Lines changed: 15 additions & 23 deletions
@@ -84,12 +84,7 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
             break  # this message means "done sending"

     def step(
-        self,
-        inputs: torch.Tensor,
-        prompts: Optional[torch.Tensor] = None,
-        hypo_ids: Optional[torch.Tensor] = None,
-        *,
-        step_id: str,
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs

@@ -114,21 +109,6 @@ def step(
         else:
             inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further

-        if prompts is None or is_dummy(prompts):
-            prompts = DUMMY
-        else:
-            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
-            assert prompts.shape[0] == self.num_blocks
-            assert prompts.shape[1] in (inputs.shape[0], 1)
-            assert prompts.shape[2] <= inputs.shape[1]
-            assert prompts.shape[3] == inputs.shape[2]
-
-        if hypo_ids is None or is_dummy(hypo_ids):
-            hypo_ids = DUMMY_INT64
-        else:
-            assert len(hypo_ids) == len(inputs)
-            assert hypo_ids.dtype == torch.int64
-
         # serialize inputs and put them into the queue
         input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)

@@ -275,7 +255,9 @@ def __enter__(self) -> "InferenceSession":
         assert not self._closed and not self._server_sessions
         return self

-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

@@ -285,11 +267,21 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         else:
             assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
             assert prompts.shape[0] == self.num_blocks
+            assert prompts.shape[1] in (inputs.shape[0], 1)
+            assert prompts.shape[2] <= inputs.shape[1]
+            assert prompts.shape[3] == inputs.shape[2]
+
+        if hypo_ids is None or is_dummy(hypo_ids):
+            hypo_ids = DUMMY_INT64
+        else:
+            assert len(hypo_ids) == len(inputs)
+            assert hypo_ids.dtype == torch.int64

         inputs_device = inputs.device
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
         step_id = str(uuid.uuid4())

         n_input_tokens = inputs.shape[1]

@@ -310,7 +302,7 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:

             server_session = self._server_sessions[server_idx]
             inputs = server_session.step(
-                inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
+                inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
             )

             server_idx += 1
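Taken together, the Python changes move hypo_ids handling (defaulting to DUMMY_INT64, the shape/dtype checks, and the transfer to CPU) from the per-server _ServerInferenceSession.step into the top-level InferenceSession.step, so beam-search hypothesis indices from a GPU client go through the same device handling as inputs and prompts. Below is a minimal, self-contained sketch of what hypo_ids represent; the cache shape and the reordering step are illustrative assumptions, not Petals internals:

```python
import torch

# During beam search, each row of the next input chunk continues one of the
# previously cached hypotheses; hypo_ids[i] says which one.
past_cache = torch.randn(4, 7, 64)  # [num_hypotheses, seq_len, hidden] -- shape is illustrative
hypo_ids = torch.tensor([0, 0, 2, 3], dtype=torch.int64)  # beams 0 and 1 both fork from hypothesis 0

# Conceptually, the serving side reorders its attention cache with this index
# before appending new tokens, so every beam attends to the correct prefix.
reordered_cache = past_cache.index_select(0, hypo_ids)
assert reordered_cache.shape[0] == len(hypo_ids)

# The asserts added in InferenceSession.step encode the client-side contract:
# one int64 entry per input row (len(hypo_ids) == len(inputs), dtype == torch.int64).
```

In end-to-end use, this path is exercised when a GPU client generates with more than one beam; the added hypo_ids.cpu() call ensures the indices are serialized from CPU alongside inputs and prompts.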
