Merged
Changes from 8 commits
2 changes: 1 addition & 1 deletion src/backend/server/scheduler_manage.py
@@ -247,7 +247,7 @@ def get_schedule_status(self):
# todo rebalance status
status = (
NODE_STATUS_AVAILABLE
if self.scheduler.layer_allocator.has_full_active_pipeline()
if self.scheduler.layer_allocator.has_full_pipeline(active_only=True)
else NODE_STATUS_WAITING
)
logger.debug(f"SchedulerManage status queried: {status}")
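The old has_full_active_pipeline() call becomes a flag on has_full_pipeline() (see the allocator changes below). A minimal usage sketch of both call forms, assuming the scheduler object and NODE_STATUS_* constants from this repo:

# Sketch only: the consolidated check. The scheduler/allocator wiring is
# assumed from the surrounding codebase, not shown here.
allocator = scheduler.layer_allocator
ready_any = allocator.has_full_pipeline()                   # count all nodes
ready_live = allocator.has_full_pipeline(active_only=True)  # live nodes only
status = NODE_STATUS_AVAILABLE if ready_live else NODE_STATUS_WAITING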
6 changes: 5 additions & 1 deletion src/parallax/p2p/server.py
@@ -570,6 +570,7 @@ def _announcer_thread():
try:
while not self.stop_event.is_set():
# Announce the range ID
should_sleep = True
try:
if self.scheduler_peer_id is not None:
response_future = self.scheduler_stub.node_update(
@@ -616,6 +617,8 @@ def _announcer_thread():
"Layer allocation updated. Executor will reload on next check. "
"Status set to INITIALIZING to prevent new requests."
)
# Skip sleep to immediately send next heartbeat with new status
should_sleep = False
else:
logger.warning(f"Heartbeat response: {response}")
else:
@@ -637,7 +640,8 @@ def _announcer_thread():
f"Failed to announce {self.prefix_id}_{self.lattica.peer_id()}: {e}"
)

time.sleep(10)
if should_sleep:
time.sleep(10)
except Exception as e:
logger.exception(f"Module announcer thread error: {e}")

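The should_sleep change means a heartbeat that triggers a layer reallocation is followed immediately by another heartbeat, instead of after the usual 10 s back-off, so the scheduler sees the INITIALIZING status as soon as possible. A self-contained sketch of the pattern; send_heartbeat() is a hypothetical stand-in for scheduler_stub.node_update(), and the dict-shaped response is an assumption:

import threading
import time

def announcer_loop(stop_event: threading.Event, send_heartbeat) -> None:
    while not stop_event.is_set():
        should_sleep = True  # default: normal 10 s heartbeat cadence
        try:
            response = send_heartbeat()
            if response and response.get("layer_allocation_changed"):
                # Re-announce right away so the scheduler sees the new
                # status before routing further requests to this node.
                should_sleep = False
        except Exception as exc:  # announce failures must not kill the thread
            print(f"announce failed: {exc}")
        if should_sleep:
            time.sleep(10)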
54 changes: 30 additions & 24 deletions src/scheduling/layer_allocation.py
@@ -216,6 +216,7 @@ def join(self, node: Node) -> None:
logger.debug("Joining node dynamically: %s", node.node_id)
self.declare(node)
lightest_layer = self.get_lightest_layer()
logger.debug("Lightest layer: %s", lightest_layer)
if lightest_layer is None:
raise ValueError("No layers to assign")

@@ -529,39 +530,44 @@ def _adjust_end_layer_for_tail(self, node: Node, proposed_start_layer: int) -> i

return end_layer

def has_full_pipeline(self) -> bool:
def has_full_pipeline(self, active_only: bool = False) -> bool:
"""Return True if there exists at least one pipeline covering [0, num_total_layers).

Checks whether we can chain contiguous node allocations starting at 0 to reach L.
This requires that there exists at least one node starting at layer 0 and a chain
of contiguous node ranges that reaches num_total_layers.
"""
total_layers = self.num_total_layers
layer_count: Dict[int, int] = {}
for _, (s, e) in self.node_allocation.items():
if s is None or e is None:
continue
for layer in range(s, e):
layer_count[layer] = layer_count.get(layer, 0) + 1

for layer in range(total_layers):
if layer not in layer_count or layer_count[layer] == 0:
return False
return True

def has_full_active_pipeline(self) -> bool:
"""Return True if there exists at least one active pipeline covering [0, num_total_layers)."""
total_layers = self.num_total_layers
layer_count: Dict[int, int] = {}
# Build index of nodes by start_layer
start_to_nodes: Dict[int, List[Node]] = {}
for node_id, (s, e) in self.node_allocation.items():
if self.node_id_to_node[node_id].is_active is False:
continue
if s is None or e is None:
continue
for layer in range(s, e):
layer_count[layer] = layer_count.get(layer, 0) + 1
for layer in range(total_layers):
if layer not in layer_count or layer_count[layer] == 0:
return False
return True
node = self.node_id_to_node.get(node_id)
if node is None or (active_only and not node.is_active):
continue
start_to_nodes.setdefault(s, []).append(node)

# Must have at least one node starting at layer 0
if not start_to_nodes.get(0):
return False

# DFS to check if we can reach total_layers from any head node
def can_reach_target(current_end: int) -> bool:
if current_end >= total_layers:
return current_end == total_layers

for nxt in start_to_nodes.get(current_end, []):
if nxt.end_layer and nxt.end_layer > current_end:
if can_reach_target(nxt.end_layer):
return True
return False

return any(
head.end_layer and can_reach_target(head.end_layer)
for head in start_to_nodes.get(0, [])
)

def layer_replication_stats(self) -> Tuple[int, int, float]:
"""Return (min, max, avg) number of nodes hosting each layer.
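The behavioral change here is the key one in this PR: the old implementation only counted per-layer coverage, while the new DFS requires shards to chain exactly, each node's end_layer matching the next node's start_layer. A self-contained sketch of the difference, using plain (start, end) tuples as a simplified stand-in for the repo's Node objects, with an allocation the old check accepts but the new one rejects:

from typing import Dict, List, Tuple

def covers_all_layers(allocs: List[Tuple[int, int]], total: int) -> bool:
    """Old-style check: every layer hosted by at least one node."""
    hosted = set()
    for s, e in allocs:
        hosted.update(range(s, e))
    return all(layer in hosted for layer in range(total))

def has_chained_pipeline(allocs: List[Tuple[int, int]], total: int) -> bool:
    """New-style check: shards must chain end-to-start from 0 to total."""
    start_to_ends: Dict[int, List[int]] = {}
    for s, e in allocs:
        start_to_ends.setdefault(s, []).append(e)

    def reach(end: int) -> bool:
        if end >= total:
            return end == total
        return any(reach(nxt) for nxt in start_to_ends.get(end, []) if nxt > end)

    return any(reach(e) for e in start_to_ends.get(0, []))

# [0,3) and [2,5) cover every layer of a 5-layer model, but no shard starts
# exactly at layer 3, so there is no valid hand-off chain between nodes.
allocs = [(0, 3), (2, 5)]
assert covers_all_layers(allocs, 5)
assert not has_chained_pipeline(allocs, 5)

Recursion depth is bounded by the number of distinct shard boundaries, which stays small relative to layer counts, so the DFS in the diff is cheap in practice.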
12 changes: 6 additions & 6 deletions src/scheduling/model_info.py
@@ -184,12 +184,12 @@ def decoder_layer_io_bytes(
ffn_params *= self.num_local_experts
kv_cache_size = 0

logger.debug(
"Model Info ffn_params=%d, kv_cache_size=%d, attention_params=%d",
ffn_params,
kv_cache_size,
attention_params,
)
# logger.debug(
# "Model Info ffn_params=%d, kv_cache_size=%d, attention_params=%d",
# ffn_params,
# kv_cache_size,
# attention_params,
# )
return round(ffn_params + kv_cache_size + attention_params)

def lm_head_flops(self, target_seq_len: int = 1) -> int:
10 changes: 5 additions & 5 deletions src/scheduling/node.py
@@ -280,11 +280,11 @@ def get_decoder_layer_capacity(
if not (include_input_embed and self.model_info.tie_embedding):
available_memory_bytes -= self.model_info.embedding_io_bytes

logger.debug(
"Node available_memory_bytes=%d, decoder_layer_io_bytes=%d",
available_memory_bytes,
self.model_info.decoder_layer_io_bytes(roofline=False),
)
# logger.debug(
# "Node available_memory_bytes=%d, decoder_layer_io_bytes=%d",
# available_memory_bytes,
# self.model_info.decoder_layer_io_bytes(roofline=False),
# )
if self.hardware.device == "mlx":
# For mlx, consider mlx bit factor
return floor(
11 changes: 8 additions & 3 deletions src/scheduling/scheduler.py
@@ -151,7 +151,6 @@ def list_node_allocations(self) -> List[Tuple[str, int, int]]:
"""List the allocations of all nodes."""
return self.layer_allocator.list_node_allocations()

# Warm-up and re-shard
def _run_warmup_and_truncate(self) -> None:
"""Run a brief warm-up to detect truncation points and shrink shards.

@@ -316,13 +315,19 @@ def leave(self, node_id: str) -> None:
f"Mixed assignment detected ({manual_count} manual, {total_count - manual_count} automatic); skipping rebalance"
)
else:
# All nodes are automatic, proceed with rebalance
# All nodes are automatic, try to recover pipeline through adjustment first
self._bootstrapped = False
self._bootstrapped_event.clear()
for n in self.nodes:
if n.start_layer is not None and n.end_layer is not None:
self.layer_allocator.deallocate(n)
self.layer_allocator.global_allocation()
success = self.layer_allocator.global_allocation()
if not success:
logger.warning("Global rebalance failed to produce a full pipeline")
else:
logger.debug("Global rebalance completed successfully")
self._bootstrapped = True
self._bootstrapped_event.set()

with self._node_count_cv:
self._node_count_cv.notify_all()
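Pulled out of the interleaved diff, the recovery flow after a node leaves now gates readiness on the rebalance result. A hedged sketch, with names mirroring the diff and the allocator, node list, and logger assumed from the surrounding Scheduler class:

def rebalance_after_leave(self) -> None:
    # Drop readiness before touching allocations so no request races
    # a half-rebuilt pipeline.
    self._bootstrapped = False
    self._bootstrapped_event.clear()
    for n in self.nodes:
        if n.start_layer is not None and n.end_layer is not None:
            self.layer_allocator.deallocate(n)
    if self.layer_allocator.global_allocation():
        # Only re-arm readiness once a complete pipeline exists again.
        self._bootstrapped = True
        self._bootstrapped_event.set()
    else:
        logger.warning("Global rebalance failed to produce a full pipeline")

On failure, _bootstrapped stays False and waiters on _bootstrapped_event keep blocking, which matches the WAITING status the scheduler_manage.py change reports when no full active pipeline exists.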