Skip to content

Conversation

daniprec
Copy link
Contributor

@daniprec daniprec commented Jun 4, 2025

Purpose

This PR investigates whether JAX is responsible for server instability when running the script.

Changes

  • Replaced all JAX-related operations with NumPy equivalents
  • Removed GPU-specific logic to enforce CPU execution

Context
The server consistently crashes during execution, and this branch isolates the issue by removing JAX and GPU dependencies.

Next steps
If stability improves, this suggests JAX or GPU usage may be the root cause.

⚠️ This is an exploratory branch. Not intended for production unless validated.

@daniprec daniprec self-assigned this Jun 4, 2025
@daniprec daniprec added the bug Something isn't working label Jun 4, 2025
@daniprec daniprec requested a review from Copilot June 4, 2025 15:07
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR removes JAX and GPU-specific logic to isolate a server crash, replacing all JAX operations with NumPy equivalents and enforcing CPU execution.

  • Swapped out jax.numpy for numpy across tests, scripts, and core modules.
  • Stripped GPU configuration and JAX-specific decorators/imports.
  • Updated dependencies, removing JAX and related packages.

Reviewed Changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated no comments.

Show a summary per file
File Description
tests/test_vectorfield.py Switched imports and assertions from jnp to np
tests/test_land.py Replaced jnp array operations with np equivalents
tests/test_cmaes.py Updated test data and type checks to NumPy
setup.py Removed JAX and jaxlib from install_requires
scripts/single_run.py Replaced jnp with np and dropped GPU logic
scripts/results_on_literature_vector_fields.py Swapped JAX usage for NumPy array ops
scripts/results_land_avoidance.py Replaced JAX imports and calls with NumPy
scripts/plot_comptime_vs_popsize.py Removed JAX imports; uses NumPy
scripts/parameter_search.py Replaced JAX cache clear with np.clear_caches
scripts/draw_fms.py Converted JAX random and array ops to NumPy
routetools/vectorfield.py Replaced JAX types, removed @jit
routetools/plot.py Switched plot inputs from jnp to np
routetools/land.py Updated land logic to use NumPy and SciPy
routetools/fms.py Replaced JAX gradient code with NumPy (incomplete imports)
routetools/cost.py Swapped cost computations to NumPy
routetools/config.py Updated config parsing to use NumPy
routetools/cmaes.py Converted Bézier and CMA-ES routines to NumPy
pyproject.toml Removed JAX from project dependencies
Comments suppressed due to low confidence (4)

scripts/draw_fms.py:40

  • The .at[...,].set() accessor is a JAX feature and not supported by NumPy arrays. Replace this with standard NumPy indexing and assignment, e.g., routes[0, 1:99, 0] = routes[0, 1:99, 0] + w * rng.normal(size=(98,)).
routes = routes.at[0, 1:99, 0].set(

routetools/fms.py:4

  • The imports for grad, jacfwd, jacrev, and vmap were removed, but those functions are still referenced later (e.g., in hessian and FMS loops), causing undefined name errors. You should either reintroduce the necessary JAX (or alternative) imports or fully refactor those calls to use a compatible autodiff library.
import numpy as np

scripts/parameter_search.py:180

  • NumPy does not provide a clear_caches function. This call will raise an AttributeError; remove it or replace it with the correct cache-clearing mechanism (if needed).
np.clear_caches()

scripts/draw_fms.py:38

  • NumPy's random module does not have a PRNGKey function. For reproducible random draws, use a Generator instance, e.g., rng = np.random.default_rng(0) and then rng.normal(...).
key = np.random.PRNGKey(0)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant