Open
Description
It would be great to have a utility function that checks whether a starry map is positive semidefinite, i.e. whether its intensity is non-negative everywhere on the surface.
starry already offers something like this through map.minimize(), but it is quite slow, and I think it could be improved by using gradient-based optimization instead of the scipy.optimize.minimize call that starry uses internally.
I've been trying to write such a function with jax.scipy.optimize.minimize, but I am getting an error:
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jax/_src/scipy/optimize/minimize.py:106, in minimize.<locals>.<lambda>()
104 raise TypeError(msg.format(args))
--> 106 fun_with_args = lambda x: fun(x, *args)
108 if method.lower() == 'bfgs':
Cell In[153], line 22
21 lat, lon = coord
---> 22 return surface.intensity(lat, lon)
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jaxoplanet/starry/surface.py:227, in intensity()
225 x, y, z = rotation.apply(jnp.array([x, y, z]).T).T
--> 227 return self._intensity(x, y, z)
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jaxoplanet/starry/surface.py:170, in _intensity()
169 def _intensity(self, x, y, z, theta=None, rv=False):
--> 170 pT = self._poly_basis(rv)(x, y, z)
171 Ry = left_project(self.ydeg, self.inc, self.obl, theta, 0.0, self.y.todense())
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py:340, in wrapped()
339 vectorized_func = api.vmap(vectorized_func, in_axes)
--> 340 result = vectorized_func(*squeezed_args)
342 if not dims_to_expand:
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jax/_src/numpy/vectorize.py:140, in wrapped()
139 def wrapped(*args):
--> 140 out = func(*args)
141 out_shapes = map(jnp.shape, out if isinstance(out, tuple) else [out])
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jaxoplanet/starry/core/basis.py:372, in impl()
371 if len(inds):
--> 372 return p.at[np.array(inds)].multiply(z)
373 else:
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:550, in multiply()
543 """Pure equivalent of ``x[idx] *= y``.
544
545 Returns the value of ``x`` that would result from the NumPy-style
(...)
548 See :mod:`jax.ops` for details.
549 """
--> 550 return scatter._scatter_update(self.array, self.index, values,
551 lax.scatter_mul,
552 indices_are_sorted=indices_are_sorted,
553 unique_indices=unique_indices,
554 mode=mode)
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jax/_src/ops/scatter.py:76, in _scatter_update()
75 treedef, static_idx, dynamic_idx = jnp._split_index_for_jit(idx, x.shape)
---> 76 return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
77 indices_are_sorted, unique_indices, mode,
78 normalize_indices)
File ~/opt/anaconda3/envs/jax/lib/python3.11/site-packages/jax/_src/ops/scatter.py:127, in _scatter_impl()
122 dnums = lax.ScatterDimensionNumbers(
123 update_window_dims=indexer.dnums.offset_dims,
124 inserted_window_dims=indexer.dnums.collapsed_slice_dims,
125 scatter_dims_to_operand_dims=indexer.dnums.start_index_map
126 )
--> 127 out = scatter_op(
128 x, indexer.gather_indices, y, dnums,
129 indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
...
2156 ad_util.zeros_like_jaxval(x), i, g, dimension_numbers=dimension_numbers,
2157 indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
2158 mode=mode))
NotImplementedError: scatter_mul gradients are only implemented if `unique_indices=True`
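The error message points at the `unique_indices` flag of the `.at[...].multiply(...)` call shown in the traceback. Below is a minimal sketch of two possible workarounds on a toy array, not on jaxoplanet's actual basis code. The names `p`, `inds`, `scaled_sum` and `scaled_sum_where` are hypothetical stand-ins, and the sketch assumes the index list contains no duplicates, which has not been checked against `basis.py`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for the array and static index list in the traceback.
p = jnp.array([1.0, 2.0, 3.0, 4.0])
inds = np.array([1, 3])  # assumed to contain no duplicates


def scaled_sum(z):
    # unique_indices=True promises JAX that no index repeats, which is the
    # condition named in the NotImplementedError for scatter_mul gradients.
    return p.at[inds].multiply(z, unique_indices=True).sum()


def scaled_sum_where(z):
    # Alternative that avoids scatter altogether: build a static boolean mask
    # with NumPy and pick between the scaled and unscaled entries.
    mask = np.zeros(p.shape, dtype=bool)
    mask[inds] = True
    return jnp.where(mask, p * z, p).sum()


# Both functions equal 1 + 2*z + 3 + 4*z, so d/dz should be 2 + 4.
grad_scatter = jax.grad(scaled_sum)(2.0)
grad_where = jax.grad(scaled_sum_where)(2.0)
print(grad_scatter, grad_where)
```

If the indices can repeat, passing `unique_indices=True` would be incorrect, and the mask-based variant is probably the safer of the two.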
shishirdholakia