Skip to content

Commit ee90522

Browse files
committed
⚡ Remove intermediate graph allocation
1 parent aeb0e8a commit ee90522

1 file changed

Lines changed: 17 additions & 11 deletions

File tree

python/swiflow/_common.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import itertools
6-
from collections.abc import Callable, Hashable, Iterable, Mapping
6+
from collections.abc import Callable, Hashable, Iterable, Mapping, MutableSet
77
from collections.abc import Set as AbstractSet
88
from typing import Generic, TypeVar
99

@@ -300,9 +300,13 @@ def ecatch(self, f: Callable[_S, _T], *args: _S.args, **kwargs: _S.kwargs) -> _T
300300
raise self.decode_err(e) from None
301301

302302

303-
def _infer_layers_impl(gd: nx.DiGraph[_V]) -> Mapping[_V, int]:
304-
"""Fix flow layers one by one depending on order constraints."""
305-
pred = {u: set(gd.predecessors(u)) for u in gd.nodes}
303+
def _infer_layers_impl(pred: Mapping[_V, MutableSet[_V]], succ: Mapping[_V, AbstractSet[_V]]) -> Mapping[_V, int]:
304+
"""Fix flow layers one by one depending on order constraints.
305+
306+
Notes
307+
-----
308+
:py:obj:`pred` is mutated in-place.
309+
"""
306310
work = {u for u, pu in pred.items() if not pu}
307311
ret: dict[_V, int] = {}
308312
for l_now in itertools.count():
@@ -311,13 +315,13 @@ def _infer_layers_impl(gd: nx.DiGraph[_V]) -> Mapping[_V, int]:
311315
next_work: set[_V] = set()
312316
for u in work:
313317
ret[u] = l_now
314-
for v in gd.successors(u):
318+
for v in succ[u]:
315319
ent = pred[v]
316320
ent.discard(u)
317321
if not ent:
318322
next_work.add(v)
319323
work = next_work
320-
if len(ret) != len(gd):
324+
if len(ret) != len(succ):
321325
msg = "Failed to determine layer for all nodes."
322326
raise ValueError(msg)
323327
return ret
@@ -377,15 +381,17 @@ def infer_layers(
377381
-----
378382
This function operates in Pauli flow mode only when :py:obj`pplane` is explicitly given.
379383
"""
380-
gd: nx.DiGraph[_V] = nx.DiGraph()
381-
gd.add_nodes_from(g.nodes)
382384
special = _special_edges(g, anyflow, pplane)
385+
pred: dict[_V, set[_V]] = {u: set() for u in g.nodes}
386+
succ: dict[_V, set[_V]] = {u: set() for u in g.nodes}
383387
for u, fu_ in anyflow.items():
384388
fu = fu_ if isinstance(fu_, AbstractSet) else {fu_}
385389
fu_odd = odd_neighbors(g, fu)
386390
for v in itertools.chain(fu, fu_odd):
387391
if u == v or (u, v) in special:
388392
continue
389-
gd.add_edge(u, v)
390-
gd = gd.reverse()
391-
return _infer_layers_impl(gd)
393+
# Reversed
394+
pred[u].add(v)
395+
succ[v].add(u)
396+
# MEMO: `pred` is invalidated by `_infer_layers_impl`
397+
return _infer_layers_impl(pred, succ)

0 commit comments

Comments
 (0)