@@ -3,6 +3,7 @@
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
@@ -21,7 +22,7 @@
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -259,6 +260,9 @@ def __init__(
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
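Note: `ModelInfo` is imported from `petals.data_structures` above. For readers of this diff, here is a minimal sketch of such a structure, assuming only the pieces this commit actually uses (`num_blocks`, `repository`, and a `to_dict()` helper); the real definition in `petals.data_structures` may differ:

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ModelInfo:
    """Metadata a server announces about the model it serves (sketch inferred from this diff)."""

    num_blocks: int  # total number of transformer blocks in the model
    repository: Optional[str] = None  # e.g. a Hugging Face URL; left as None for local checkpoints

    def to_dict(self) -> dict:
        # Plain dict, so the value can be serialized and stored in the DHT
        return dataclasses.asdict(self)
```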
@@ -330,6 +334,7 @@ def run(self):
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
@@ -436,6 +441,7 @@ def create(
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
@@ -463,6 +469,7 @@ def create(
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -671,6 +678,7 @@ def __init__(
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -683,6 +691,7 @@ def __init__(
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
         self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
@@ -693,10 +702,10 @@ def __init__(
         self.trigger = threading.Event()
 
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -720,6 +729,13 @@ def run(self) -> None:
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0:
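With this change, every public server periodically republishes its model's metadata under the `_petals.models` DHT key, using the model's DHT prefix as the subkey. Below is a minimal sketch of how a client in the same swarm could read that registry back; the hivemind `DHT` construction and the direct `ModelInfo(**...)` reconstruction are illustrative assumptions, not part of this commit:

```python
from hivemind import DHT

from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModelInfo

# Join the public swarm as a lightweight client (serves no blocks itself).
dht = DHT(initial_peers=PUBLIC_INITIAL_PEERS, client_mode=True, start=True)

# "_petals.models" maps each non-private dht_prefix to the dict produced by
# ModelInfo.to_dict() in the announcer loop above.
registry = dht.get("_petals.models", latest=True)
if registry is not None:
    for dht_prefix, entry in registry.value.items():
        info = ModelInfo(**entry.value)  # assumes the stored dict matches ModelInfo's fields
        print(dht_prefix, info.num_blocks, info.repository)

dht.shutdown()
```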