
Tweak get_var_by_name #1440

Open
@ricardoV94

Description


This function has an unused argument: ids.

More importantly, it also looks for matches in inner graphs (the `inner_outputs` of any `Op` that implements `HasInnerGraph`). I would make this behavior optional and non-default, since variables in inner graphs are not really variables of the outer graph.

```python
# Imports needed to run this excerpt outside its module
from collections.abc import Iterable
from typing import cast

from pytensor.graph.basic import Variable, walk


def get_var_by_name(
    graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR"
) -> tuple[Variable, ...]:
    r"""Get variables in a graph using their names.

    Parameters
    ----------
    graphs:
        The graph, or graphs, to search.
    target_var_id:
        The name to match against either ``Variable.name`` or
        ``Variable.auto_name``.

    Returns
    -------
    A ``tuple`` containing all the `Variable`\s that match `target_var_id`.
    """
    from pytensor.graph.op import HasInnerGraph

    def expand(r) -> list[Variable] | None:
        if not r.owner:
            return None

        res = list(r.owner.inputs)

        # Inner-graph outputs are always traversed, so variables that only
        # exist inside e.g. a Scan body can be returned as matches.
        if isinstance(r.owner.op, HasInnerGraph):
            res.extend(r.owner.op.inner_outputs)

        return res

    results: tuple[Variable, ...] = ()
    # Depth-first walk (bfs=False) over the graph
    for var in walk(graphs, expand, False):
        var = cast(Variable, var)
        if target_var_id == var.name or target_var_id == var.auto_name:
            results += (var,)

    return results
```
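A minimal sketch of the proposed behavior, runnable without PyTensor. It assumes a hypothetical keyword-only flag `search_inner_graphs` (not an existing PyTensor parameter) and uses toy stand-in classes (`ToyVariable`, `ToyApply`, `ToyOp`, `ToyInnerGraphOp`, all hypothetical) that model only the attributes the search touches, instead of PyTensor's real `Variable`, `Apply` and `HasInnerGraph`. The unused `ids` argument is dropped, and `inner_outputs` are followed only when the flag is set.

```python
from __future__ import annotations

import itertools
from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass, field

_counter = itertools.count()


# Hypothetical toy stand-ins, NOT PyTensor classes. eq=False keeps
# identity-based equality, like graph variables.
@dataclass(eq=False)
class ToyVariable:
    name: str | None = None
    owner: ToyApply | None = None
    auto_name: str = field(default_factory=lambda: f"auto_{next(_counter)}")


@dataclass(eq=False)
class ToyOp:
    pass


@dataclass(eq=False)
class ToyInnerGraphOp(ToyOp):
    # Plays the role of an Op implementing HasInnerGraph
    inner_outputs: list[ToyVariable] = field(default_factory=list)


@dataclass(eq=False)
class ToyApply:
    op: ToyOp
    inputs: list[ToyVariable]


def walk_dfs(
    roots: Iterable[ToyVariable],
    expand: Callable[[ToyVariable], list[ToyVariable] | None],
) -> Iterator[ToyVariable]:
    """Depth-first traversal that yields each node once (by identity)."""
    stack = list(reversed(list(roots)))
    seen: set[int] = set()
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        yield node
        children = expand(node)
        if children:
            # Reverse so children are visited in their original order
            stack.extend(reversed(children))


def get_var_by_name(
    graphs: Iterable[ToyVariable],
    target_var_id: str,
    *,
    search_inner_graphs: bool = False,  # hypothetical flag, off by default
) -> tuple[ToyVariable, ...]:
    def expand(r: ToyVariable) -> list[ToyVariable] | None:
        if r.owner is None:
            return None
        res = list(r.owner.inputs)
        # Only descend into inner graphs when explicitly requested
        if search_inner_graphs and isinstance(r.owner.op, ToyInnerGraphOp):
            res.extend(r.owner.op.inner_outputs)
        return res

    return tuple(
        var
        for var in walk_dfs(graphs, expand)
        if target_var_id in (var.name, var.auto_name)
    )


# Outer graph: y = loop(x), where the loop's inner graph also has a var named "x"
x = ToyVariable("x")
inner_x = ToyVariable("x")  # exists only inside the inner graph
inner_out = ToyVariable("y_inner", owner=ToyApply(ToyOp(), [inner_x]))
loop = ToyApply(ToyInnerGraphOp(inner_outputs=[inner_out]), [x])
y = ToyVariable("y", owner=loop)

print(get_var_by_name([y], "x") == (x,))  # only the outer x
print(len(get_var_by_name([y], "x", search_inner_graphs=True)))  # both x's
```

Making the flag keyword-only means any old caller that still passes `ids` positionally gets a `TypeError` instead of silently turning inner-graph search on.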
