Skip to content

Crashes with imaginary numbers #6

@thomasaarholt

Description

@thomasaarholt

I tried converting a complex-valued SymPy expression to JAX with `sympy2jax.SymbolicModule`, and it failed with a `KeyError` (full traceback below).

Here is a minimal working example. `I` is SymPy's imaginary unit (the constant `sympy.I`, an instance of `sympy.core.numbers.ImaginaryUnit`), and `1j` is Python's equivalent. Both produce the same error.

from sympy import symbols, I
import sympy2jax

x = symbols("x")

expr = x*I # or x*1j

sympy2jax.SymbolicModule(expr)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    212 try:
--> 213     return memodict[expr]
    214 except KeyError:

KeyError: I*x

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    212 try:
--> 213     return memodict[expr]
    214 except KeyError:

KeyError: I

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:180, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    179 try:
--> 180     self._func = func_lookup[expr.func]
    181 except KeyError as e:

KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/Users/thomas/Documents/vilde.ipynb Cell 6 in <cell line: 8>()
      [4](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=3) x = symbols("x")
      [6](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=5) expr = x*I # or x*1j
----> [8](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=7) sympy2jax.SymbolicModule(expr)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:257, in SymbolicModule.__init__(self, expressions, extra_funcs, make_array, **kwargs)
    250     self.has_extra_funcs = True
    251 _convert = ft.partial(
    252     _sympy_to_node,
    253     memodict=dict(),
    254     func_lookup=lookup,
    255     make_array=make_array,
    256 )
--> 257 self.nodes = jax.tree_map(_convert, expressions)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in tree_map(f, tree, is_leaf, *rest)
    203 leaves, treedef = tree_flatten(tree, is_leaf)
    204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in <genexpr>(.0)
    203 leaves, treedef = tree_flatten(tree, is_leaf)
    204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    222     out = _Rational(expr, make_array)
    223 else:
--> 224     out = _Func(expr, memodict, func_lookup, make_array)
    225 memodict[expr] = out
    226 return out

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:183, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    181 except KeyError as e:
    182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
--> 183 self._args = [
    184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:184, in <listcomp>(.0)
    181 except KeyError as e:
    182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
    183 self._args = [
--> 184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
    222     out = _Rational(expr, make_array)
    223 else:
--> 224     out = _Func(expr, memodict, func_lookup, make_array)
    225 memodict[expr] = out
    226 return out

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
    129 object.__setattr__(self, "__class__", initable_cls)
    130 try:
--> 131     cls.__init__(self, *args, **kwargs)
    132 finally:
    133     object.__setattr__(self, "__class__", cls)

File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:182, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
    180     self._func = func_lookup[expr.func]
    181 except KeyError as e:
--> 182     raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
    183 self._args = [
    184     _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
    185 ]

KeyError: "Unsupported Sympy type <class 'sympy.core.numbers.ImaginaryUnit'>"

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions