
BUG: AdvancedSubTensor with None and integer indices raises a logprob error instead of silently failing #7762



Open
lucianopaz opened this issue Apr 23, 2025 · 3 comments


@lucianopaz
Contributor

lucianopaz commented Apr 23, 2025

Describe the issue:

I just ran into a logprob rewrite error with an AdvancedSubtensor op that mixed None entries and int32 indices. This wasn't actually a mixture model, but logprob found the op, tried to apply its rewrite rules, and raised an error instead of just failing silently. The problem seems to come from the condition in `find_measurable_index_mixture` (pymc/logprob/mixture.py, see the traceback below): it guards against slice constants but not against a None constant.
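To make the failure mode concrete, here is a minimal pure-Python sketch that mimics the shape of the condition shown in the traceback. It does not import PyTensor: the `SimpleNamespace` objects are hypothetical stand-ins for an integer index vector (which has `dtype` and `type.broadcastable`) and for a NoneConst-like constant (which has no `dtype`), and the function names are made up for illustration. The guarded variant simply skips indexers without a `dtype`, which is one possible way to treat None the way slice constants are already treated; it is not the actual patch.

```python
from types import SimpleNamespace

# Hypothetical stand-ins, NOT real PyTensor objects.
# An integer index vector exposes .dtype and .type.broadcastable ...
int_vector_index = SimpleNamespace(
    dtype="int64", type=SimpleNamespace(broadcastable=(False,))
)
# ... while a NoneConst-like constant has a .type but no .dtype attribute.
none_const_index = SimpleNamespace(data=None, type=SimpleNamespace())


def has_int_array_index_unguarded(indices_list):
    """Mimics the failing condition: reads .dtype on every indexer."""
    return any(
        indices.dtype.startswith("int")
        and sum(1 - b for b in indices.type.broadcastable) > 0
        for indices in indices_list
    )


def has_int_array_index_guarded(indices_list):
    """Same condition, but skips indexers that carry no dtype (None/slice-like)."""
    return any(
        indices.dtype.startswith("int")
        and sum(1 - b for b in indices.type.broadcastable) > 0
        for indices in indices_list
        if hasattr(indices, "dtype")
    )


# Same indexer order as the failing node: AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])
node_indices = [none_const_index, int_vector_index]

try:
    has_int_array_index_unguarded(node_indices)
except AttributeError as err:
    print("unguarded raised:", err)

print("guarded:", has_int_array_index_guarded(node_indices))  # guarded: True
```

Note that the order matters in the unguarded version: `any` stops at the first truthy entry, so the crash only happens because the None indexer comes before the integer vector, exactly as in the failing node.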

Reproducible code example:

```python
import numpy as np
import pymc as pm


obs = np.random.default_rng().normal(size=(7, 4))
with pm.Model():
    inds = np.arange(obs.shape[1])
    a = pm.Normal("a", shape=10)
    b = pm.Deterministic("b", a[None, inds])
    c = pm.Normal("c", mu=b, sigma=1, observed=obs)
    pm.sample()
```
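For readers unfamiliar with the indexing in the reproducer, the following NumPy-only sketch (no PyMC or PyTensor involved) shows what `a[None, inds]` computes: a new leading axis combined with an integer-array index, giving shape (1, 4), which then broadcasts against the (7, 4) observations. In PyTensor the same expression becomes the `AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])` node that appears in the error log. The concrete values of `a` are arbitrary placeholders.

```python
import numpy as np

a = np.arange(10.0)   # placeholder values for the length-10 "a" in the reproducer
inds = np.arange(4)   # integer index vector, as in the reproducer
b = a[None, inds]     # new axis mixed with an integer-array index
print(b.shape)        # (1, 4)

obs = np.zeros((7, 4))
print(np.broadcast_shapes(b.shape, obs.shape))  # (7, 4)
```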

Error message:

```
ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: find_measurable_index_mixture
ERROR (pytensor.graph.rewriting.basic): node: AdvancedSubtensor(a, NoneConst{None}, [0 1 2 3])
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "pytensor/graph/rewriting/basic.py", line 1913, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "pytensor/graph/rewriting/basic.py", line 1085, in transform
    return self.fn(fgraph, node)
  File "pymc/logprob/mixture.py", line 291, in find_measurable_index_mixture
    if any(
  File "pymc/logprob/mixture.py", line 292, in <genexpr>
    indices.dtype.startswith("int") and sum(1 - b for b in indices.type.broadcastable) > 0
AttributeError: 'Constant' object has no attribute 'dtype'. Did you mean: 'type'?
```

Sampling still works fine, because the rewrite was supposed to fail and return None anyway.

PyMC version information:

GitHub main

Context for the issue:

This doesn't really affect anything; it just confuses regular users, who see the error traceback from the rewrite and get alarmed. It would be more elegant to handle this extra indexer type the same way slice constants are handled.

@lucianopaz lucianopaz changed the title BUG: <Please write a comprehensive title after the 'BUG: ' prefix> BUG: AdvancedSubTensor with None and integer indices raises a logprob error instead of silently failing Apr 23, 2025
@Hashcode-Ankit

Hashcode-Ankit commented Apr 26, 2025

Hi @lucianopaz, just to understand the expected outcome: should the check be skipped when the index is a constant or a None value?

In your case the index is a constant, and the condition tries to read its dtype, which isn't possible here.

Thanks

@lucianopaz
Contributor Author

lucianopaz commented Apr 26, 2025

Hi @Hashcode-Ankit. The line I quoted above also needs to check whether the indices are None or NoneConst. That way, the rewrite will return None when the indexing mixes integer indices with slices or new axes.
If you look through the code base, you'll see that rewrites have a bunch of conditions that check whether the rewrite can be applied to the inputs. When the conditions fail, the rewrite returns None; when they succeed, it returns the modified graph or node. By adding the extra check to that condition, we are explicitly telling pytensor that the rewrite can't work when the indexing operation mixes integer indices with other basic indexing (slices or new axes).
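A minimal sketch of that convention, again in plain Python with hypothetical stand-ins rather than PyTensor's rewriting API: the function checks its preconditions and returns None when the indexers mix an integer array with basic indexers (a new axis or a slice), and only otherwise builds a replacement. `rewrite_sketch`, `build_replacement`, and the stub objects are illustrative names, not PyMC or PyTensor code, and the other preconditions the real rewrite checks are not reproduced here; only the None-versus-replacement pattern is.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for indexers (not PyTensor objects).
INT_VECTOR = SimpleNamespace(dtype="int64", type=SimpleNamespace(broadcastable=(False,)))
NEW_AXIS = SimpleNamespace(data=None, type=SimpleNamespace())           # NoneConst-like
FULL_SLICE = SimpleNamespace(data=slice(None), type=SimpleNamespace())  # slice-constant-like


def is_tensor_index(idx):
    # Only tensor indexers carry a dtype; None and slice constants do not.
    return hasattr(idx, "dtype")


def is_int_array_index(idx):
    # Non-scalar integer index: integer dtype and at least one non-broadcastable dim.
    return (
        is_tensor_index(idx)
        and idx.dtype.startswith("int")
        and sum(1 - b for b in idx.type.broadcastable) > 0
    )


def rewrite_sketch(indices, build_replacement):
    """Return None if the rewrite cannot apply, else the replacement."""
    has_int_array = any(is_int_array_index(i) for i in indices)
    has_basic = any(not is_tensor_index(i) for i in indices)
    if has_int_array and has_basic:
        # Integer arrays mixed with slices / new axes: bail out quietly
        # instead of raising, so the rewriter simply moves on.
        return None
    return build_replacement(indices)


print(rewrite_sketch([NEW_AXIS, INT_VECTOR], build_replacement=list))    # None
print(rewrite_sketch([FULL_SLICE, INT_VECTOR], build_replacement=list))  # None
print(len(rewrite_sketch([INT_VECTOR], build_replacement=list)))         # 1
```

Returning None rather than raising is what keeps the rewriter from logging a "Rewrite failure" traceback: a None result just means "this rewrite does not apply to this node".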

@Hashcode-Ankit
Copy link

Hi, I have raised a PR for this; could you take a look?

Thanks
