
Commit f6be1f7

Merge commit with 2 parents: 9027270 + 6ef6bf5

File tree: 10 files changed (+85, -43 lines)


src/petals/client/inference_session.py

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
         n_prev_spans = len(self._server_sessions)
         update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
         if attempt_no >= 1:
-            logger.info(
+            logger.debug(
                 f"Due to a server failure, remote attention caches "
                 f"from block {block_idx} to {update_end} will be regenerated"
             )

src/petals/client/remote_generation.py

Lines changed: 3 additions & 3 deletions
@@ -69,6 +69,8 @@ def generate(
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
         self._fix_generate_kwargs(kwargs)
+        if inputs is None:
+            inputs = kwargs.pop("input_ids", None)
 
         if session is not None:
             # If a session specified explicitly, use it
@@ -125,7 +127,7 @@ def generate(
         return result
 
     @staticmethod
-    def _fix_generate_kwargs(kwargs: dict) -> dict:
+    def _fix_generate_kwargs(kwargs: dict):
         # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
         if "max_length" in kwargs and kwargs["max_length"] is None:
             del kwargs["max_length"]
@@ -135,8 +137,6 @@ def _fix_generate_kwargs(kwargs: dict) -> dict:
         if isinstance(do_sample, int):
             kwargs["do_sample"] = bool(do_sample)
 
-        return kwargs
-
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
         return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
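Note: with this change, `generate()` also accepts the prompt via the standard HF `input_ids` keyword instead of only the first positional argument. A minimal sketch of the new fallback in isolation (a hypothetical standalone helper, not part of the Petals API):

from typing import Optional

import torch


def resolve_inputs(inputs: Optional[torch.Tensor] = None, **kwargs) -> Optional[torch.Tensor]:
    # Mirrors the added lines: fall back to the HF-style `input_ids` kwarg
    if inputs is None:
        inputs = kwargs.pop("input_ids", None)
    return inputs


prompt = torch.tensor([[1, 2, 3]])
assert torch.equal(resolve_inputs(prompt), prompt)            # positional, as before
assert torch.equal(resolve_inputs(input_ids=prompt), prompt)  # keyword, newly supported
assert resolve_inputs() is None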

src/petals/data_structures.py

Lines changed: 13 additions & 0 deletions
@@ -20,6 +20,19 @@ class ServerState(Enum):
 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
 
 
+@pydantic.dataclasses.dataclass
+class ModelInfo:
+    num_blocks: int
+    repository: Optional[str] = None
+
+    def to_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    @classmethod
+    def from_dict(cls, source: dict):
+        return cls(**source)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
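The new `ModelInfo` record is what servers announce to the DHT (see the `server.py` changes below). It round-trips through a plain dict so it can be stored as a DHT value. A usage sketch; the concrete values are illustrative, not taken from this commit:

from petals.data_structures import ModelInfo

info = ModelInfo(num_blocks=70, repository="https://huggingface.co/bigscience/bloom")
payload = info.to_dict()                 # {'num_blocks': 70, 'repository': 'https://huggingface.co/bigscience/bloom'}
restored = ModelInfo.from_dict(payload)  # re-validated by the pydantic dataclass
assert restored == info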

src/petals/models/bloom/config.py

Lines changed: 1 addition & 0 deletions
@@ -30,5 +30,6 @@ def from_pretrained(
         if loading_from_repo and dht_prefix is None:
             # We need "-petals" for backward compatibility with Petals < 1.2.0
             dht_prefix = str(model_name_or_path) + "-petals"
+            dht_prefix = dht_prefix.replace(".", "-")
         logger.info(f"Using DHT prefix: {dht_prefix}")
         return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)

src/petals/models/llama/config.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ def from_pretrained(
         if loading_from_repo and dht_prefix is None:
             dht_prefix = str(model_name_or_path)
             dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
+            dht_prefix = dht_prefix.replace(".", "-")
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
         logger.info(f"Using DHT prefix: {dht_prefix}")

src/petals/server/memory_cache.py

Lines changed: 7 additions & 4 deletions
@@ -31,7 +31,7 @@ def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[fl
         self.max_alloc_timeout = max_alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
-        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
@@ -138,7 +138,8 @@ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float])
         start_time = time.perf_counter()
         loop = asyncio.get_event_loop()
 
-        self.enqueued_size_bytes += alloc_size
+        with self._enqueued_size.get_lock():
+            self._enqueued_size.value += alloc_size
         allocated = False
         try:
             context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
@@ -155,13 +156,15 @@ async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float])
                 await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
 
             allocated = True
-            self.enqueued_size_bytes -= alloc_size
+            with self._enqueued_size.get_lock():
+                self._enqueued_size.value -= alloc_size
             yield
         except asyncio.TimeoutError:
             raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
         finally:
             if not allocated:
-                self.enqueued_size_bytes -= alloc_size
+                with self._enqueued_size.get_lock():
+                    self._enqueued_size.value -= alloc_size
 
     def _free(self, alloc_size: int, alloc_task: asyncio.Task):
         if alloc_task.exception() is not None:
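`_enqueued_size` is now created with `lock=True`, and every read-modify-write goes through `get_lock()`, so concurrent updates from different processes (or executor threads) cannot lose increments. The same pattern in isolation, as a minimal self-contained sketch:

import ctypes
import multiprocessing as mp


def add_chunks(counter, n, chunk):
    for _ in range(n):
        with counter.get_lock():   # += on a shared Value is not atomic without this
            counter.value += chunk


if __name__ == "__main__":
    enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
    workers = [mp.Process(target=add_chunks, args=(enqueued_size, 1000, 3)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    assert enqueued_size.value == 4 * 1000 * 3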

src/petals/server/server.py

Lines changed: 24 additions & 5 deletions
@@ -3,6 +3,7 @@
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
@@ -21,7 +22,7 @@
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -259,11 +260,15 @@ def __init__(
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
+        self.module_container = None
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
@@ -329,6 +334,7 @@ def run(self):
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
             server_info=self.server_info,
+            model_info=self.model_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
             min_batch_size=self.min_batch_size,
@@ -377,7 +383,7 @@ def run(self):
             self._clean_memory_and_fds()
 
     def _clean_memory_and_fds(self):
-        del self.module_container
+        self.module_container = None
         gc.collect()  # In particular, this closes unused file descriptors
 
         if self.device.type == "cuda":
@@ -410,8 +416,10 @@ def _should_choose_other_blocks(self) -> bool:
         module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = 5):
         self.stop.set()
+        if self.module_container is not None and self.module_container.is_alive():
+            self.module_container.join(timeout)
 
         if self.reachability_protocol is not None:
             self.reachability_protocol.shutdown()
@@ -433,6 +441,7 @@ def create(
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -460,6 +469,7 @@ def create(
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -668,6 +678,7 @@ def __init__(
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -680,6 +691,7 @@ def __init__(
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
         self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@@ -690,10 +702,10 @@ def __init__(
         self.trigger = threading.Event()
 
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -717,6 +729,13 @@ def run(self) -> None:
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
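Servers with a public (non-underscore) DHT prefix now periodically publish their `ModelInfo` under the dictionary key `_petals.models`, with the prefix as the subkey. A client-side sketch of reading that record back; the return structure of `dht.get()` for dictionary-valued keys is an assumption here, not something shown in this diff:

from hivemind import DHT

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModelInfo

dht = DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)

entry = dht.get("_petals.models", latest=True)
if entry is not None:
    # each server stores its model under its own dht_prefix subkey
    for dht_prefix, record in entry.value.items():
        info = ModelInfo.from_dict(record.value)
        print(dht_prefix, info.num_blocks, info.repository)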

src/petals/server/task_pool.py

Lines changed: 14 additions & 29 deletions
@@ -9,7 +9,6 @@
 
 import torch
 from hivemind import get_logger
-from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils.mpfuture import ALL_STATES, MPFuture
 
 logger = get_logger(__name__)
@@ -27,7 +26,7 @@ def uid(self) -> int:
         return self.future._uid
 
 
-class PrioritizedTaskPool(TaskPoolBase):
+class PrioritizedTaskPool(threading.Thread):
     """
     Aggregates requests from multiple ConnectionHandler instances, orders them for processing in Runtime, then
     returns results (or exception) to the corresponding ConnectionHandler. Runs a background process.
@@ -57,52 +56,41 @@ def __init__(
         daemon=True,
         start=False,
     ):
-        super().__init__(process_func, daemon=daemon, name=name)
+        super().__init__(daemon=daemon, name=name)
+        self.process_func = process_func
+        # the lower the priority is, the more urgent it is to process this pool
+        self._priority = mp.Value(ctypes.c_double, 1.0)
+
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.device = device
 
         self.submitted_tasks = mp.SimpleQueue()  # interaction with ConnectionHandlers
         self._ordered_tasks = PriorityQueue()  # interaction with Runtime - only valid inside Runtime
 
-        self._prioritizer_thread = threading.Thread(
-            name=self.name + "_prioritizer",
-            target=self._prioritize_tasks,
-            args=[self.submitted_tasks, self._ordered_tasks],
-            daemon=True,
-        )
         self._dispatched_tasks = {}
         self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
         self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
         self.priority = float("inf"), float("inf")  # (first task priority, first task timestamp)
 
-        self._stop = mp.Event()
         if start:
             self.start()
 
-    @staticmethod
-    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
+    def run(self):
         """Read tasks from incoming queue and put them into a local priority queue"""
         while True:
-            task = submitted_tasks.get()
+            task = self.submitted_tasks.get()
             if task is None:
                 logger.debug("Shutting down prioritizer thread")
                 break
 
-            ordered_tasks.put(task, block=True)
-
-    def start(self):
-        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
-        self._prioritizer_thread.start()
-        super().start()
+            self._ordered_tasks.put(task, block=True)
 
-    def shutdown(self, timeout: float = 3):
-        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
-        self._stop.set()
+    def terminate(self):
+        """An alias for hivemind.Runtime that assumes that each TaskPool is a process"""
+        self.shutdown()
 
-        self.join(timeout)
-        if self.is_alive():
-            logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
-            self.terminate()
+    def shutdown(self):
+        self.submitted_tasks.put(None)  # Shuts down self.run()
 
     def submit_task(self, *args: Any, priority: float = 0.0) -> MPFuture:
         """Add task to this pool's queue, return Future for its output"""
@@ -163,9 +151,6 @@ def send_exception_from_runtime(self, uid: int, exception: BaseException):
         else:
             task.future.set_exception(exception)
 
-    def run(self, *args, **kwargs):
-        self._stop.wait()
-
     @property
     def empty(self):
         return not self.batch_receiver.poll()
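`PrioritizedTaskPool` is now a plain `threading.Thread`: `run()` drains `submitted_tasks` into the local priority queue and stops at the `None` sentinel that `shutdown()` enqueues, while `terminate()` remains only as an alias for `hivemind.Runtime`. A stripped-down sketch of that drain-until-sentinel pattern (not the Petals class itself):

import multiprocessing as mp
import threading
from queue import PriorityQueue


class DrainThread(threading.Thread):
    """Move items from a cross-process queue into a local priority queue until a None sentinel arrives."""

    def __init__(self):
        super().__init__(daemon=True)
        self.submitted = mp.SimpleQueue()  # filled by other processes / handlers
        self.ordered = PriorityQueue()     # consumed locally, lowest priority value first

    def run(self):
        while True:
            item = self.submitted.get()
            if item is None:               # sentinel: exit gracefully
                break
            self.ordered.put(item, block=True)

    def shutdown(self):
        self.submitted.put(None)


pool = DrainThread()
pool.start()
pool.submitted.put((0.5, "urgent"))
pool.submitted.put((2.0, "later"))
pool.shutdown()
pool.join()
assert pool.ordered.get() == (0.5, "urgent")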

tests/test_full_model.py

Lines changed: 20 additions & 0 deletions
@@ -149,3 +149,23 @@ def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, n
     outputs = make_generate_calls(model, inputs, **options)
     ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
+
+
+@pytest.mark.forked
+def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
+    assert inputs.keys() == {"input_ids", "attention_mask"}
+
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
+    assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
+
+    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
+        outputs = torch.cat(
+            [
+                model.generate(**inputs, max_new_tokens=2),
+                model.generate(None, max_new_tokens=max_new_tokens - 2),
+            ],
+            dim=1,
+        )
+        assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"

tests/test_remote_sequential.py

Lines changed: 1 addition & 1 deletion
@@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
     assert intermediate_prompts_ref.grad is not None
    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
