33from __future__ import annotations
44
55import itertools
6- from collections .abc import Callable , Hashable , Iterable , Mapping
6+ from collections .abc import Callable , Hashable , Iterable , Mapping , MutableSet
77from collections .abc import Set as AbstractSet
88from typing import Generic , TypeVar
99
@@ -300,9 +300,13 @@ def ecatch(self, f: Callable[_S, _T], *args: _S.args, **kwargs: _S.kwargs) -> _T
300300 raise self .decode_err (e ) from None
301301
302302
303- def _infer_layers_impl (gd : nx .DiGraph [_V ]) -> Mapping [_V , int ]:
304- """Fix flow layers one by one depending on order constraints."""
305- pred = {u : set (gd .predecessors (u )) for u in gd .nodes }
303+ def _infer_layers_impl (pred : Mapping [_V , MutableSet [_V ]], succ : Mapping [_V , AbstractSet [_V ]]) -> Mapping [_V , int ]:
304+ """Fix flow layers one by one depending on order constraints.
305+
306+ Notes
307+ -----
308+ :py:obj:`pred` is mutated in-place.
309+ """
306310 work = {u for u , pu in pred .items () if not pu }
307311 ret : dict [_V , int ] = {}
308312 for l_now in itertools .count ():
@@ -311,13 +315,13 @@ def _infer_layers_impl(gd: nx.DiGraph[_V]) -> Mapping[_V, int]:
311315 next_work : set [_V ] = set ()
312316 for u in work :
313317 ret [u ] = l_now
314- for v in gd . successors ( u ) :
318+ for v in succ [ u ] :
315319 ent = pred [v ]
316320 ent .discard (u )
317321 if not ent :
318322 next_work .add (v )
319323 work = next_work
320- if len (ret ) != len (gd ):
324+ if len (ret ) != len (succ ):
321325 msg = "Failed to determine layer for all nodes."
322326 raise ValueError (msg )
323327 return ret
@@ -377,15 +381,17 @@ def infer_layers(
377381 -----
378382 This function operates in Pauli flow mode only when :py:obj`pplane` is explicitly given.
379383 """
380- gd : nx .DiGraph [_V ] = nx .DiGraph ()
381- gd .add_nodes_from (g .nodes )
382384 special = _special_edges (g , anyflow , pplane )
385+ pred : dict [_V , set [_V ]] = {u : set () for u in g .nodes }
386+ succ : dict [_V , set [_V ]] = {u : set () for u in g .nodes }
383387 for u , fu_ in anyflow .items ():
384388 fu = fu_ if isinstance (fu_ , AbstractSet ) else {fu_ }
385389 fu_odd = odd_neighbors (g , fu )
386390 for v in itertools .chain (fu , fu_odd ):
387391 if u == v or (u , v ) in special :
388392 continue
389- gd .add_edge (u , v )
390- gd = gd .reverse ()
391- return _infer_layers_impl (gd )
393+ # Reversed
394+ pred [u ].add (v )
395+ succ [v ].add (u )
396+ # MEMO: `pred` is invalidated by `_infer_layers_impl`
397+ return _infer_layers_impl (pred , succ )
0 commit comments