Skip to content

Commit 38731ad

Browse files
ferrinericardoV94
authored andcommitted
move clone_replace to a separate file
1 parent 2445327 commit 38731ad

File tree

14 files changed

+213
-206
lines changed

14 files changed

+213
-206
lines changed

pytensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
7373
__api_version__ = 1
7474

7575
# isort: off
76-
from pytensor.graph.basic import Variable, clone_replace
76+
from pytensor.graph.basic import Variable
77+
from pytensor.graph.replace import clone_replace
7778

7879
# isort: on
7980

pytensor/compile/builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
Constant,
1717
NominalVariable,
1818
Variable,
19-
clone_replace,
2019
graph_inputs,
2120
io_connection_pattern,
2221
)
2322
from pytensor.graph.fg import FunctionGraph
2423
from pytensor.graph.null_type import NullType
2524
from pytensor.graph.op import HasInnerGraph, Op
25+
from pytensor.graph.replace import clone_replace
2626
from pytensor.graph.rewriting.basic import in2out, node_rewriter
2727
from pytensor.graph.utils import MissingInputError
2828
from pytensor.tensor.rewriting.shape import ShapeFeature

pytensor/graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
Constant,
88
graph_inputs,
99
clone,
10-
clone_replace,
1110
ancestors,
1211
)
12+
from pytensor.graph.replace import clone_replace
1313
from pytensor.graph.op import Op
1414
from pytensor.graph.type import Type
1515
from pytensor.graph.fg import FunctionGraph

pytensor/graph/basic.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,53 +1151,6 @@ def clone_get_equiv(
11511151
return memo
11521152

11531153

1154-
def clone_replace(
1155-
output: Collection[Variable],
1156-
replace: Optional[
1157-
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
1158-
] = None,
1159-
**rebuild_kwds,
1160-
) -> List[Variable]:
1161-
"""Clone a graph and replace subgraphs within it.
1162-
1163-
It returns a copy of the initial subgraph with the corresponding
1164-
substitutions.
1165-
1166-
Parameters
1167-
----------
1168-
output
1169-
PyTensor expression that represents the computational graph.
1170-
replace
1171-
Dictionary describing which subgraphs should be replaced by what.
1172-
rebuild_kwds
1173-
Keywords to `rebuild_collect_shared`.
1174-
1175-
"""
1176-
from pytensor.compile.function.pfunc import rebuild_collect_shared
1177-
1178-
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
1179-
if isinstance(replace, dict):
1180-
items = list(replace.items())
1181-
elif isinstance(replace, (list, tuple)):
1182-
items = replace
1183-
elif replace is None:
1184-
items = []
1185-
else:
1186-
raise ValueError(
1187-
"replace is neither a dictionary, list, "
1188-
f"tuple or None ! The value provided is {replace},"
1189-
f"of type {type(replace)}"
1190-
)
1191-
tmp_replace = [(x, x.type()) for x, y in items]
1192-
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
1193-
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
1194-
1195-
# TODO Explain why we call it twice ?!
1196-
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
1197-
1198-
return cast(List[Variable], outs)
1199-
1200-
12011154
def general_toposort(
12021155
outputs: Iterable[T],
12031156
deps: Callable[[T], Union[OrderedSet, List[T]]],

pytensor/graph/replace.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import (
2+
Collection,
3+
Dict,
4+
Iterable,
5+
List,
6+
Optional,
7+
Sequence,
8+
Tuple,
9+
Union,
10+
cast,
11+
)
12+
13+
from pytensor.graph.basic import Constant, Variable
14+
15+
16+
def clone_replace(
17+
output: Collection[Variable],
18+
replace: Optional[
19+
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
20+
] = None,
21+
**rebuild_kwds,
22+
) -> List[Variable]:
23+
"""Clone a graph and replace subgraphs within it.
24+
25+
It returns a copy of the initial subgraph with the corresponding
26+
substitutions.
27+
28+
Parameters
29+
----------
30+
output
31+
PyTensor expression that represents the computational graph.
32+
replace
33+
Dictionary describing which subgraphs should be replaced by what.
34+
rebuild_kwds
35+
Keywords to `rebuild_collect_shared`.
36+
37+
"""
38+
from pytensor.compile.function.pfunc import rebuild_collect_shared
39+
40+
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
41+
if isinstance(replace, dict):
42+
items = list(replace.items())
43+
elif isinstance(replace, (list, tuple)):
44+
items = replace
45+
elif replace is None:
46+
items = []
47+
else:
48+
raise ValueError(
49+
"replace is neither a dictionary, list, "
50+
f"tuple or None ! The value provided is {replace},"
51+
f"of type {type(replace)}"
52+
)
53+
tmp_replace = [(x, x.type()) for x, y in items]
54+
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
55+
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
56+
57+
# TODO Explain why we call it twice ?!
58+
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
59+
60+
return cast(List[Variable], outs)

pytensor/ifelse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from pytensor import as_symbolic
2121
from pytensor.compile import optdb
2222
from pytensor.configdefaults import config
23-
from pytensor.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
23+
from pytensor.graph.basic import Apply, Variable, is_in_ancestors
2424
from pytensor.graph.op import _NoPythonOp
25+
from pytensor.graph.replace import clone_replace
2526
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
2627
from pytensor.graph.type import HasDataType, HasShape
2728
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast

pytensor/scan/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from pytensor.compile.function.pfunc import construct_pfunc_ins_and_outs
77
from pytensor.compile.sharedvalue import SharedVariable, collect_new_shareds
88
from pytensor.configdefaults import config
9-
from pytensor.graph.basic import Constant, Variable, clone_replace, graph_inputs
9+
from pytensor.graph.basic import Constant, Variable, graph_inputs
1010
from pytensor.graph.op import get_test_value
11+
from pytensor.graph.replace import clone_replace
1112
from pytensor.graph.utils import MissingInputError, TestValueError
1213
from pytensor.scan.op import Scan, ScanInfo
1314
from pytensor.scan.utils import expand_empty, safe_new, until

pytensor/scan/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@
6565
from pytensor.graph.basic import (
6666
Apply,
6767
Variable,
68-
clone_replace,
6968
equal_computations,
7069
graph_inputs,
7170
io_connection_pattern,
7271
)
7372
from pytensor.graph.features import NoOutputFromInplace
7473
from pytensor.graph.op import HasInnerGraph, Op
74+
from pytensor.graph.replace import clone_replace
7575
from pytensor.graph.utils import InconsistencyError, MissingInputError
7676
from pytensor.link.c.basic import CLinker
7777
from pytensor.link.c.exceptions import MissingGXX

pytensor/scan/rewriting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Apply,
1919
Constant,
2020
Variable,
21-
clone_replace,
2221
equal_computations,
2322
graph_inputs,
2423
io_toposort,
@@ -28,6 +27,7 @@
2827
from pytensor.graph.features import ReplaceValidate
2928
from pytensor.graph.fg import FunctionGraph
3029
from pytensor.graph.op import compute_test_value
30+
from pytensor.graph.replace import clone_replace
3131
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
3232
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
3333
from pytensor.graph.type import HasShape

pytensor/scan/utils.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,9 @@
1414
from pytensor import tensor as at
1515
from pytensor.compile.profiling import ProfileStats
1616
from pytensor.configdefaults import config
17-
from pytensor.graph.basic import (
18-
Constant,
19-
Variable,
20-
clone_replace,
21-
equal_computations,
22-
graph_inputs,
23-
)
17+
from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs
2418
from pytensor.graph.op import get_test_value
19+
from pytensor.graph.replace import clone_replace
2520
from pytensor.graph.type import HasDataType
2621
from pytensor.graph.utils import TestValueError
2722
from pytensor.tensor.basic import AllocEmpty, cast

0 commit comments

Comments
 (0)