
Commit 6ef6bf5

Create model index in DHT (#491)
This PR creates an index of the models hosted in the swarm. It lets us see which custom models users run and display them at https://health.petals.dev as "not officially supported" models.
1 parent 6bb3f54 commit 6ef6bf5

3 files changed: +33 -4 lines

src/petals/client/inference_session.py

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) ->
         n_prev_spans = len(self._server_sessions)
         update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
         if attempt_no >= 1:
-            logger.info(
+            logger.debug(
                 f"Due to a server failure, remote attention caches "
                 f"from block {block_idx} to {update_end} will be regenerated"
             )

src/petals/data_structures.py

Lines changed: 13 additions & 0 deletions
@@ -20,6 +20,19 @@ class ServerState(Enum):
 RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
 
 
+@pydantic.dataclasses.dataclass
+class ModelInfo:
+    num_blocks: int
+    repository: Optional[str] = None
+
+    def to_dict(self) -> dict:
+        return dataclasses.asdict(self)
+
+    @classmethod
+    def from_dict(cls, source: dict):
+        return cls(**source)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
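As a quick illustration (not part of the commit; the repository URL below is only an example), the dict form produced by to_dict() is what later gets published to the DHT, and from_dict() restores it on the reading side:

    from petals.data_structures import ModelInfo

    info = ModelInfo(num_blocks=32, repository="https://huggingface.co/bigscience/bloom-7b1")
    payload = info.to_dict()                 # plain dict, safe to store as a DHT value
    restored = ModelInfo.from_dict(payload)  # round-trips back into a ModelInfo instance
    assert restored == info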

src/petals/server/server.py

Lines changed: 19 additions & 3 deletions
@@ -3,6 +3,7 @@
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
@@ -21,7 +22,7 @@
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -259,6 +260,9 @@ def __init__(
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
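The isdir check above separates a Hugging Face model ID from a model loaded out of a local directory, so only publicly resolvable models get a repository link and locally converted checkpoints stay anonymous in the index. A minimal sketch of that branch in isolation (the helper name is made up):

    import os

    def repository_for(converted_model_name_or_path: str):
        # On a server this is either a Hugging Face repo id (e.g. "bigscience/bloom-7b1")
        # or an existing local checkpoint directory; only the former maps to a public URL.
        if os.path.isdir(converted_model_name_or_path):
            return None
        return "https://huggingface.co/" + converted_model_name_or_path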
@@ -330,6 +334,7 @@ def run(self):
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
@@ -436,6 +441,7 @@ def create(
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -463,6 +469,7 @@ def create(
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -671,6 +678,7 @@ def __init__(
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -683,6 +691,7 @@ def __init__(
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
         self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@@ -693,10 +702,10 @@ def __init__(
         self.trigger = threading.Event()
 
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
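For context (not part of the diff): module UIDs have the form "<dht_prefix><UID_DELIMITER><block_index>", so the renamed attribute simply keeps the model's DHT prefix around for reuse when storing the model index. A rough sketch with a made-up prefix, assuming UID_DELIMITER is "." as in petals.data_structures:

    UID_DELIMITER = "."  # assumed to match petals.data_structures

    module_uids = ["my-model-petals.3", "my-model-petals.4", "my-model-petals.5"]
    dht_prefix = module_uids[0].split(UID_DELIMITER)[0]  # "my-model-petals"
    block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
    start_block, end_block = min(block_indices), max(block_indices) + 1
    # next_uids names the blocks immediately following each served block
    next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
    assert next_uids == ["my-model-petals.4", "my-model-petals.5", "my-model-petals.6"]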
@@ -720,6 +729,13 @@ def run(self) -> None:
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
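This is the write side of the model index described in the PR description: each announcer re-publishes its model under the shared "_petals.models" key, keyed by dht_prefix, before the entry expires. A rough sketch of the read side that a monitor such as https://health.petals.dev could use (not code from this commit; it assumes hivemind's DHT.get returns subkeyed entries as a dict of ValueWithExpiration records):

    from hivemind import DHT

    from petals.constants import PUBLIC_INITIAL_PEERS
    from petals.data_structures import ModelInfo

    # Join the public swarm as a lightweight client and fetch the model index.
    dht = DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)
    entry = dht.get("_petals.models", latest=True)
    if entry is not None:
        for dht_prefix, subentry in entry.value.items():
            info = ModelInfo.from_dict(subentry.value)
            print(f"{dht_prefix}: {info.num_blocks} blocks, repository={info.repository}")
    dht.shutdown()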
