toms748_scan doesn't work with JAX backend #2466

Open
@TC01

Description

Summary

Hello; perhaps this is already known, but I thought I'd file a bug report just in case. While testing the upper_limits API, I found that the example given in the documentation does not work with the JAX backend: it fails with a TypeError about an unhashable array type (see the traceback under Actual Results). With the numpy backend, as shown in the documentation, it runs fine.

I see this both on EL7 in an ATLAS environment (StatAnalysis,0.3,latest) and on my own desktop (Fedora 38). In both cases I have the same pyhf version (0.7.6), with jax[cpu]==0.4.26 installed manually on top of that.

I should add that things work fine with JAX if I use the version of upper_limits where I pass in a range of mu values to scan, so I guess some extra type conversion is needed to go from the JAX array type to a Python float or some other hashable type?
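As a quick illustration of which built-in types could serve as a dictionary key here: a plain float or a tuple is hashable, while a list is not. The helper name `is_hashable` below is illustrative and not part of pyhf.

```python
def is_hashable(value):
    """Return True if `value` can be used as a dict key or set member."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


float_ok = is_hashable(1.5)          # plain Python float
tuple_ok = is_hashable((1.5, 2.0))   # tuple of floats
list_ok = is_hashable([1.5, 2.0])    # list of floats

print(float_ok, tuple_ok, list_ok)
```

So a conversion to a float (for a single POI value) looks like the more plausible candidate than a conversion to a list.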

OS / Environment

# Linux
$ cat /etc/os-release
NAME="Fedora Linux"
VERSION="38 (Thirty Eight)"
ID=fedora
VERSION_ID=38
VERSION_CODENAME=""
PLATFORM_ID="platform:f38"
PRETTY_NAME="Fedora Linux 38 (Thirty Eight)"
ANSI_COLOR="0;38;2;60;110;180"
LOGO=fedora-logo-icon
CPE_NAME="cpe:/o:fedoraproject:fedora:38"
DEFAULT_HOSTNAME="fedora"
HOME_URL="https://fedoraproject.org/"
DOCUMENTATION_URL="https://docs.fedoraproject.org/en-US/fedora/f38/system-administrators-guide/"
SUPPORT_URL="https://ask.fedoraproject.org/"
BUG_REPORT_URL="https://bugzilla.redhat.com/"
REDHAT_BUGZILLA_PRODUCT="Fedora"
REDHAT_BUGZILLA_PRODUCT_VERSION=38
REDHAT_SUPPORT_PRODUCT="Fedora"
REDHAT_SUPPORT_PRODUCT_VERSION=38
SUPPORT_END=2024-05-14

Steps to Reproduce

Install pyhf and JAX through pip; then try to run the example in the documentation, but with the JAX backend instead of numpy:

import numpy as np
import pyhf
pyhf.set_backend("jax")  # the documentation example uses "numpy" here
model = pyhf.simplemodels.uncorrelated_background(
    signal=[12.0, 11.0], bkg=[50.0, 52.0], bkg_uncertainty=[3.0, 7.0]
)
observations = [51, 48]
data = pyhf.tensorlib.astensor(observations + model.config.auxdata)
obs_limit, exp_limits = pyhf.infer.intervals.upper_limits.toms748_scan(
    data, model, 0., 5., rtol=0.01
)

Expected Results

Ideally the example would run without crashing (as it does with the numpy backend).

Actual Results

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 130, in toms748_scan
    toms748(f, bounds_low, bounds_up, args=(level, 0), k=2, xtol=atol, rtol=rtol)
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1374, in toms748
    result = solver.solve(f, a, b, args=args, k=k, xtol=xtol, rtol=rtol,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1229, in solve
    fc = self._callf(c)
         ^^^^^^^^^^^^^^
  File "/usr/lib64/python3.11/site-packages/scipy/optimize/_zeros_py.py", line 1083, in _callf
    fx = self.f(x, *self.args)
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 95, in f
    f_cached(poi)[0] - level
    ^^^^^^^^^^^^^
  File "/home/bjr/.local/lib/python3.11/site-packages/pyhf/infer/intervals/upper_limits.py", line 80, in f_cached
    if poi not in cache:
       ^^^^^^^^^^^^^^^^
TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'
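
The traceback ends at the dictionary membership test `poi not in cache`, which has to hash `poi`. One possible fix, sketched here under assumptions, is to coerce the POI to a Python float before using it as a cache key. `UnhashableScalar` is a hypothetical stand-in for a scalar JAX array (so the sketch runs without JAX installed), and `make_cached` and `expensive` are illustrative names rather than pyhf's actual code. The sketch also assumes the POI is a scalar that `float()` accepts.

```python
class UnhashableScalar:
    """Hypothetical stand-in for a scalar backend array:
    convertible to float, but not hashable."""

    __hash__ = None  # mimics array types that disable hashing

    def __init__(self, value):
        self._value = float(value)

    def __float__(self):
        return self._value


def make_cached(func):
    """Wrap `func` so results are cached under a plain-float key."""
    cache = {}

    def cached(poi):
        key = float(poi)  # coerce before the dict lookup; floats are hashable
        if key not in cache:
            cache[key] = func(key)
        return cache[key]

    return cached, cache


def expensive(poi):
    # Placeholder for the real hypothesis test evaluated at `poi`
    return poi * poi


cached, cache = make_cached(expensive)
poi = UnhashableScalar(1.5)

# Using the stand-in directly as a key fails, as in the traceback above.
try:
    poi in {}
    raised = False
except TypeError:
    raised = True

first = cached(poi)   # works: the key is float(poi)
second = cached(1.5)  # same value, served from the cache
```

Keying on `float(poi)` also means that a backend scalar and a Python float with the same value share a single cache entry.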

pyhf Version

$ pyhf --version
pyhf, version 0.7.6

Metadata

Labels

bug: Something isn't working
