Skip to content

Commit 3944eb3

Browse files
committed
.WIP model memoization
1 parent 3f7efa1 commit 3944eb3

File tree

16 files changed

+239
-281
lines changed

16 files changed

+239
-281
lines changed

pymc/backends/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@
3030
)
3131

3232
import numpy as np
33-
import pytensor
3433

3534
from pymc.backends.report import SamplerReport
3635
from pymc.model import modelcontext
37-
from pymc.pytensorf import compile
3836
from pymc.util import get_var_name
3937

4038
logger = logging.getLogger(__name__)
@@ -171,10 +169,14 @@ def __init__(
171169

172170
if fn is None:
173171
# borrow=True avoids deepcopy when inputs=output which is the case for untransformed value variables
174-
fn = compile(
175-
inputs=[pytensor.In(v, borrow=True) for v in model.value_vars],
176-
outputs=[pytensor.Out(v, borrow=True) for v in vars],
172+
fn = model.compile_fn(
173+
inputs=model.value_vars,
174+
outputs=vars,
177175
on_unused_input="ignore",
176+
random_seed=False,
177+
borrow_inputs=True,
178+
borrow_outputs=True,
179+
wrap_point_fn=False,
178180
)
179181
fn.trust_input = True
180182

pymc/initial_point.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@
2828
from pymc.pytensorf import (
2929
SeedSequenceSeed,
3030
compile,
31-
find_rng_nodes,
3231
replace_rng_nodes,
33-
reseed_rngs,
32+
seed_compiled_function,
3433
toposort_replace,
3534
)
3635
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name
@@ -167,7 +166,12 @@ def make_initial_point_fn(
167166
# Replace original rng shared variables so that we don't mess with them
168167
# when calling the final seeded function
169168
initial_values = replace_rng_nodes(initial_values)
170-
func = compile(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
169+
func = compile(
170+
inputs=[],
171+
outputs=initial_values,
172+
mode=pytensor.compile.mode.FAST_COMPILE,
173+
random_seed=False,
174+
)
171175

172176
varnames = []
173177
for var in model.free_RVs:
@@ -179,11 +183,9 @@ def make_initial_point_fn(
179183
varnames.append(name)
180184

181185
def make_seeded_function(func):
182-
rngs = find_rng_nodes(func.maker.fgraph.outputs)
183-
184186
@functools.wraps(func)
185187
def inner(seed, *args, **kwargs):
186-
reseed_rngs(rngs, seed)
188+
seed_compiled_function(func, seed)
187189
values = func(*args, **kwargs)
188190
return dict(zip(varnames, values))
189191

0 commit comments

Comments
 (0)