Commit 5ce4f1a

Store (start_block, end_block) in each DHT record for reliability (#510)
This PR fixes gaps in the DHT server info caused by unavailable DHT keys. Each record now stores the server's whole (start_block, end_block) range, so a single DHT key is enough to learn about all blocks hosted by a server, and the info stays visible unless all of that server's keys become unavailable. This PR also refactors the `petals.client.routing` and `petals.server.block_selection` modules to use the common `compute_spans()` function (defined in `petals.utils.dht`) and the `RemoteSpanInfo` class (defined in `petals.data_structures`).
1 parent 1586216 commit 5ce4f1a
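
The new helper in `petals/utils/dht.py` is not shown in this diff. Below is a minimal sketch of how such a `compute_spans()` could work, assuming it keeps one `RemoteSpanInfo` per peer, filters servers whose state is below `min_state`, and trusts the advertised `start_block`/`end_block` when present; the real implementation may differ in details.

```python
from typing import Dict, List

from hivemind import PeerID

from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState, parse_uid


def compute_spans(module_infos: List[RemoteModuleInfo], *, min_state: ServerState) -> Dict[PeerID, RemoteSpanInfo]:
    # Index of the first requested block, so advertised (start_block, end_block)
    # can be clipped to the requested range
    block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
    num_blocks = len(module_infos)

    spans = {}
    for block_idx, module_info in enumerate(module_infos):
        for peer_id, server_info in sorted(module_info.servers.items()):
            if server_info.state.value < min_state.value:
                continue

            if peer_id not in spans or spans[peer_id].state.value < server_info.state.value:
                spans[peer_id] = RemoteSpanInfo(
                    peer_id=peer_id, start=block_idx, end=block_idx + 1, server_info=server_info
                )
                if server_info.start_block is not None and server_info.end_block is not None:
                    # A single surviving record is enough to recover the server's whole range
                    spans[peer_id].start = max(server_info.start_block - block_offset, 0)
                    spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
            elif spans[peer_id].state == server_info.state:
                spans[peer_id].start = min(spans[peer_id].start, block_idx)
                spans[peer_id].end = max(spans[peer_id].end, block_idx + 1)
    return spans
```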

File tree

6 files changed: +119 −137 lines changed
Lines changed: 13 additions & 52 deletions
@@ -1,17 +1,15 @@
 import dataclasses
 import time
-from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+from typing import Iterable, List, Optional, Tuple

 from hivemind import get_logger

 from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans

 logger = get_logger(__name__)


-T = TypeVar("T")
-
-
 @dataclasses.dataclass
 class RemoteSequenceInfo:
     """
@@ -30,7 +28,7 @@ class RemoteSequenceInfo:
     last_updated_time: Optional[float]

     @classmethod
-    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+    def make_empty(cls, block_uids: Iterable[ModuleUID]) -> "RemoteSequenceInfo":
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
@@ -39,68 +37,31 @@ def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)
         block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
-        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
+        spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
         return RemoteSequenceInfo(
             block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
         )

     def __len__(self):
         return len(self.block_uids)

-    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+    def update_(self, new_block_infos: List[RemoteModuleInfo]):
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.debug(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-                continue
-            if not info.servers:
-                logger.debug(f"Found no active peers for block {uid}")
-                continue
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-                continue
+            assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
             self.block_infos[block_index].servers = info.servers

-        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+        self.spans_by_priority, self.spans_containing_block = self._sort_spans(self.block_infos)
         self.last_updated_time = time.perf_counter()

     @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            if info is not None:
-                for peer_id, server_info in info.servers.items():
-                    if server_info.state != ServerState.ONLINE:
-                        continue
-                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id,
-                            start=block_index,
-                            end=block_index + 1,
-                            server_info=server_info,
-                        )
-                    else:  # peer_id in active_spans
-                        active_spans[peer_id].end = block_index + 1
-
-            for peer_id in list(active_spans.keys()):
-                if (
-                    info is None
-                    or peer_id not in info.servers
-                    or info.servers[peer_id].state != ServerState.ONLINE
-                    or block_index == len(block_infos) - 1
-                ):
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans, f"spans: {active_spans}"
-
-        closed_spans.sort(key=lambda span: span.length, reverse=True)
+    def _sort_spans(block_infos: List[RemoteModuleInfo]):
+        spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
+        spans_by_priority.sort(key=lambda span: span.length, reverse=True)

-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
+        spans_containing_block = tuple([] for _ in range(len(block_infos)))
+        for span in spans_by_priority:
             for block_index in range(span.start, span.end):
                 spans_containing_block[block_index].append(span)

-        return closed_spans, spans_containing_block
+        return spans_by_priority, spans_containing_block
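
As a toy illustration of the shape `_sort_spans()` returns (hypothetical peers and block ranges, with tuples standing in for `RemoteSpanInfo` objects): spans come out longest-first, and each block index maps to every span covering it.

```python
# Hypothetical peers and block ranges; tuples stand in for RemoteSpanInfo objects
spans = {"peerA": (0, 4), "peerB": (2, 3)}  # peer -> (start, end)

spans_by_priority = sorted(spans.items(), key=lambda kv: kv[1][1] - kv[1][0], reverse=True)
spans_containing_block = tuple([] for _ in range(4))
for peer, (start, end) in spans_by_priority:
    for block_index in range(start, end):
        spans_containing_block[block_index].append(peer)

print([peer for peer, _ in spans_by_priority])   # ['peerA', 'peerB']
print([len(s) for s in spans_containing_block])  # [1, 1, 2, 1]
```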

src/petals/client/routing/sequence_manager.py

Lines changed: 0 additions & 4 deletions
@@ -117,7 +117,6 @@ def __init__(
         if state.sequence_info.last_updated_time is not None:
             assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
-        self._need_latest_infos = True

     @staticmethod
     def _peer_ids_to_set(peer_ids: Optional[Sequence[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
@@ -346,9 +345,6 @@ def _update(self):
         )

         for block_info in new_block_infos:
-            if not block_info:
-                continue
-
             # Apply allow and block lists
             block_info.servers = {
                 peer_id: server_info

src/petals/data_structures.py

Lines changed: 26 additions & 9 deletions
@@ -11,18 +11,15 @@
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"


-class ServerState(Enum):
-    OFFLINE = 0
-    JOINING = 1
-    ONLINE = 2
-
-
-RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+def parse_uid(uid: ModuleUID) -> Tuple[str, int]:
+    assert CHAIN_DELIMITER not in uid, "parse_uid() does not support chained UIDs"
+    dht_prefix, index = uid.split(UID_DELIMITER)
+    return dht_prefix, int(index)


 @pydantic.dataclasses.dataclass
 class ModelInfo:
-    num_blocks: int
+    num_blocks: pydantic.conint(ge=1, strict=True)
     repository: Optional[str] = None

     def to_dict(self) -> dict:
@@ -33,11 +30,23 @@ def from_dict(cls, source: dict):
         return cls(**source)


+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
     throughput: RPS

+    start_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+    end_block: Optional[pydantic.conint(ge=0, strict=True)] = None
+
     public_name: Optional[str] = None
     version: Optional[str] = None

@@ -83,9 +92,17 @@ class RemoteSpanInfo:
     server_info: ServerInfo

     @property
-    def length(self):
+    def length(self) -> int:
         return self.end - self.start

+    @property
+    def state(self) -> ServerState:
+        return self.server_info.state
+
+    @property
+    def throughput(self) -> float:
+        return self.server_info.throughput
+

 RPCInfo = Dict[str, Any]


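A hedged usage sketch of the new pieces above (the model prefix and the numbers are made up, and the remaining `ServerInfo` fields are assumed to keep their defaults):

```python
from petals.data_structures import ServerInfo, ServerState, parse_uid

# parse_uid() splits a single (non-chained) UID into its DHT prefix and block index
assert parse_uid("mymodel-hf.3") == ("mymodel-hf", 3)

# A server now advertises its whole block range in every DHT record it publishes,
# so one surviving key is enough to recover which blocks it hosts
info = ServerInfo(state=ServerState.ONLINE, throughput=100.0, start_block=0, end_block=18)
assert info.end_block - info.start_block == 18
```
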
src/petals/server/block_selection.py

Lines changed: 24 additions & 48 deletions
@@ -1,74 +1,50 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List

 import numpy as np
 from hivemind import PeerID, get_logger

-from petals.data_structures import RemoteModuleInfo, ServerState
-
-__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
+from petals.data_structures import RemoteModuleInfo, RemoteSpanInfo, ServerState
+from petals.utils.dht import compute_spans

 logger = get_logger(__name__)


-@dataclass
-class Span:
-    start: int
-    end: int
-    throughput: float
-    state: ServerState
-
-    @property
-    def length(self):
-        return self.end - self.start
-
-    def move_to(self, new_start: int) -> None:
-        self.start, self.end = new_start, new_start + self.length
-
-
-def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
-    spans = {}
-    throughputs = np.zeros(len(module_infos))
-    for block, module in enumerate(module_infos):
-        if module is None:
-            continue
-
-        # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
-        # If the order were not defined, we would get slightly different values due to floating point errors,
-        # which may cause excess block replacements.
-        for peer_id, server in sorted(module.servers.items()):
-            if server.state == ServerState.OFFLINE:
-                continue
+def compute_throughputs(spans: Dict[PeerID, RemoteSpanInfo], *, total_blocks: int) -> np.ndarray:
+    # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
+    # If the order were not defined, we would get slightly different values due to floating point errors,
+    # which may cause excess block replacements.

-            if peer_id in spans:
-                spans[peer_id].start = min(spans[peer_id].start, block)
-                spans[peer_id].end = max(spans[peer_id].start, block + 1)
-            else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
-
-            throughputs[block] += server.throughput
-
-    return spans, throughputs
+    throughputs = np.zeros(total_blocks)
+    for span in sorted(spans.values(), key=lambda span: span.peer_id):
+        throughputs[span.start : span.end] += span.throughput
+    return throughputs


 def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
     options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
     return min(options)[-1]


-def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    _, throughputs = compute_spans(module_infos)
+def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
+
     start = _choose_best_start(throughputs, num_blocks)
     return list(range(start, start + num_blocks))


+def _move_span(span: RemoteSpanInfo, new_start: int):
+    span.start, span.end = new_start, new_start + span.length
+
+
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
 ) -> bool:
     if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)

-    spans, throughputs = compute_spans(module_infos)
+    spans = compute_spans(module_infos, min_state=ServerState.JOINING)
+    throughputs = compute_throughputs(spans, total_blocks=len(module_infos))
     initial_throughput = throughputs.min()
     eps = 1e-3

@@ -88,7 +64,7 @@ def should_choose_other_blocks(
         return False  # This server is on its best place already

     throughputs[local_span.start : local_span.end] += local_span.throughput * eps
-    local_span.move_to(new_start)
+    _move_span(local_span, new_start)
     throughputs[local_span.start : local_span.end] += local_span.throughput

     moved = True
@@ -105,7 +81,7 @@ def should_choose_other_blocks(

         throughputs[span.start : span.end] += span.throughput * eps
         if span.start != new_start:
-            span.move_to(new_start)
+            _move_span(span, new_start)
             moved = True
         throughputs[span.start : span.end] += span.throughput


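For intuition, a toy run of the start-selection rule used above (made-up throughput numbers): a joining server compares the per-block throughput sums over every window of `num_blocks` and picks the least-covered one.

```python
import numpy as np

throughputs = np.array([300.0, 100.0, 100.0, 250.0])  # summed throughput of servers covering each block
num_blocks = 2

options = ((sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1))
best_start = min(options)[-1]
print(best_start)  # 1, i.e. the server would pick blocks [1, 2], the least-covered window
```
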
src/petals/server/server.py

Lines changed: 15 additions & 12 deletions
@@ -23,7 +23,7 @@

 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState, parse_uid
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -220,11 +220,10 @@ def __init__(
         num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
-                first_block_index, last_block_index = block_indices.split(":")
-                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
+                start_block, end_block = [int(index.strip()) for index in block_indices.split(":")]
             except Exception as e:
                 raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
-            block_indices = range(first_block_index, last_block_index)
+            block_indices = range(start_block, end_block)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks

@@ -703,11 +702,16 @@ def __init__(
         self.expiration = expiration
         self.trigger = threading.Event()

+        self.dht_prefix = parse_uid(module_uids[0])[0]
+        block_indices = [parse_uid(uid)[1] for uid in module_uids]
+        self.server_info.start_block = min(block_indices)
+        self.server_info.end_block = max(block_indices) + 1
+
         self.max_pinged = max_pinged
-        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
-        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
-        start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{i}"
+            for i in range(self.server_info.start_block + 1, self.server_info.end_block + 1)
+        ]
         self.ping_aggregator = PingAggregator(self.dht)

     def run(self) -> None:
@@ -755,12 +759,11 @@ def announce(self, state: ServerState) -> None:

     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
         module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
-        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
         pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
         pinged_servers.discard(self.dht.peer_id)
-        if module_infos[-1] is not None:
-            # Sample servers hosting the block after the last one (most likely continuations) separately
-            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        # Sample servers hosting the block after the last one (most likely continuations) separately
+        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers))


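To make the announcer change above concrete, here is a toy calculation with a hypothetical prefix, mirroring the logic that now stores the hosted range in `server_info` and builds `next_uids` one block past that range:

```python
module_uids = ["mymodel-hf.3", "mymodel-hf.4", "mymodel-hf.5"]  # hypothetical UIDs of the hosted blocks

block_indices = [int(uid.split(".")[-1]) for uid in module_uids]
start_block, end_block = min(block_indices), max(block_indices) + 1  # (3, 6), now stored in ServerInfo

# Servers hosting these blocks are the most likely continuations of our span
next_uids = [f"mymodel-hf.{i}" for i in range(start_block + 1, end_block + 1)]
print(next_uids)  # ['mymodel-hf.4', 'mymodel-hf.5', 'mymodel-hf.6']
```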