Skip to content

Commit a91c41e

Browse files
michaelosthegetwiecki
authored andcommitted
Deprecate Model.update_start_vals method
1 parent dae6c56 commit a91c41e

File tree

2 files changed

+17
-87
lines changed

2 files changed

+17
-87
lines changed

pymc/model.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,25 +1522,10 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
15221522
conditional on the values of `b` and stored in `b`.
15231523
15241524
"""
1525-
# TODO FIXME XXX: If we're going to incrementally update transformed
1526-
# variables, we should do it in topological order.
1527-
for a_name, a_value in tuple(a.items()):
1528-
# If the name is a random variable, get its value variable and
1529-
# potentially transform it
1530-
var = self.named_vars.get(a_name, None)
1531-
value_var = self.rvs_to_values.get(var, None)
1532-
if value_var:
1533-
transform = getattr(value_var.tag, "transform", None)
1534-
if transform:
1535-
fval_graph = transform.forward(var, a_value)
1536-
(fval_graph,), _ = rvs_to_value_vars((fval_graph,), apply_transforms=True)
1537-
fval_graph_inputs = {i: b[i.name] for i in inputvars(fval_graph) if i.name in b}
1538-
rv_var_value = fval_graph.eval(fval_graph_inputs)
1539-
# Why are these transformed values stored in `b`? They're
1540-
# not going to be used to update `a`.
1541-
b[value_var.name] = rv_var_value
1542-
1543-
a.update({k: v for k, v in b.items() if k not in a})
1525+
raise DeprecationWarning(
1526+
"The `Model.update_start_vals` method was removed."
1527+
" To change initial values you may set the items of `Model.initial_values` directly."
1528+
)
15441529

15451530
def eval_rv_shapes(self) -> Dict[str, Tuple[int, ...]]:
15461531
"""Evaluates shapes of untransformed AND transformed free variables.

pymc/tests/test_model.py

Lines changed: 13 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -533,68 +533,6 @@ def test_point_logps():
533533
assert "a" in logp_vals.keys()
534534

535535

536-
class TestUpdateStartVals(SeededTest):
537-
def setup_method(self):
538-
super().setup_method()
539-
540-
def test_soft_update_all_present(self):
541-
model = pm.Model()
542-
start = {"a": 1, "b": 2}
543-
test_point = {"a": 3, "b": 4}
544-
model.update_start_vals(start, test_point)
545-
assert start == {"a": 1, "b": 2}
546-
547-
def test_soft_update_one_missing(self):
548-
model = pm.Model()
549-
start = {
550-
"a": 1,
551-
}
552-
test_point = {"a": 3, "b": 4}
553-
model.update_start_vals(start, test_point)
554-
assert start == {"a": 1, "b": 4}
555-
556-
def test_soft_update_empty(self):
557-
model = pm.Model()
558-
start = {}
559-
test_point = {"a": 3, "b": 4}
560-
model.update_start_vals(start, test_point)
561-
assert start == test_point
562-
563-
def test_soft_update_transformed(self):
564-
with pm.Model() as model:
565-
pm.Exponential("a", 1)
566-
start = {"a": 2.0}
567-
test_point = {"a_log__": 0}
568-
model.update_start_vals(start, test_point)
569-
assert_almost_equal(np.exp(start["a_log__"]), start["a"])
570-
571-
def test_soft_update_parent(self):
572-
with pm.Model() as model:
573-
a = pm.Uniform("a", lower=0.0, upper=1.0)
574-
b = pm.Uniform("b", lower=2.0, upper=3.0)
575-
pm.Uniform("lower", lower=a, upper=3.0)
576-
pm.Uniform("upper", lower=0.0, upper=b)
577-
pm.Uniform("interv", lower=a, upper=b)
578-
579-
initial_point = {
580-
"a_interval__": np.array(0.0, dtype=aesara.config.floatX),
581-
"b_interval__": np.array(0.0, dtype=aesara.config.floatX),
582-
"lower_interval__": np.array(0.0, dtype=aesara.config.floatX),
583-
"upper_interval__": np.array(0.0, dtype=aesara.config.floatX),
584-
"interv_interval__": np.array(0.0, dtype=aesara.config.floatX),
585-
}
586-
start = {"a": 0.3, "b": 2.1, "lower": 1.4, "upper": 1.4, "interv": 1.4}
587-
test_point = {
588-
"lower_interval__": -0.3746934494414109,
589-
"upper_interval__": 0.693147180559945,
590-
"interv_interval__": 0.4519851237430569,
591-
}
592-
model.update_start_vals(start, initial_point)
593-
assert_almost_equal(start["lower_interval__"], test_point["lower_interval__"])
594-
assert_almost_equal(start["upper_interval__"], test_point["upper_interval__"])
595-
assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"])
596-
597-
598536
class TestShapeEvaluation:
599537
def test_eval_rv_shapes(self):
600538
with pm.Model(
@@ -626,17 +564,21 @@ def test_valid_start_point(self):
626564
a = pm.Uniform("a", lower=0.0, upper=1.0)
627565
b = pm.Uniform("b", lower=2.0, upper=3.0)
628566

629-
start = {"a": 0.3, "b": 2.1}
630-
model.update_start_vals(start, model.initial_point)
567+
start = {
568+
"a_interval__": model.rvs_to_values[a].tag.transform.forward(a, 0.3).eval(),
569+
"b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(),
570+
}
631571
model.check_start_vals(start)
632572

633573
def test_invalid_start_point(self):
634574
with pm.Model() as model:
635575
a = pm.Uniform("a", lower=0.0, upper=1.0)
636576
b = pm.Uniform("b", lower=2.0, upper=3.0)
637577

638-
start = {"a": np.nan, "b": np.nan}
639-
model.update_start_vals(start, model.initial_point)
578+
start = {
579+
"a_interval__": np.nan,
580+
"b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(),
581+
}
640582
with pytest.raises(pm.exceptions.SamplingError):
641583
model.check_start_vals(start)
642584

@@ -645,8 +587,11 @@ def test_invalid_variable_name(self):
645587
a = pm.Uniform("a", lower=0.0, upper=1.0)
646588
b = pm.Uniform("b", lower=2.0, upper=3.0)
647589

648-
start = {"a": 0.3, "b": 2.1, "c": 1.0}
649-
model.update_start_vals(start, model.initial_point)
590+
start = {
591+
"a_interval__": model.rvs_to_values[a].tag.transform.forward(a, 0.3).eval(),
592+
"b_interval__": model.rvs_to_values[b].tag.transform.forward(b, 2.1).eval(),
593+
"c": 1.0,
594+
}
650595
with pytest.raises(KeyError):
651596
model.check_start_vals(start)
652597

0 commit comments

Comments
 (0)