Prototyping / Research on using JAX for future versions #990

@jpn--

Moving a platform like ActivitySim from Pandas to JAX would be a pretty big architectural shift, but it could unlock some unique advantages for large-scale activity-based modeling.

Advantages


1. Massive Speedups via XLA Compilation

  • What JAX does: JAX can transform NumPy-like operations into optimized code that runs through XLA (Accelerated Linear Algebra).
  • Why it matters for ActivitySim: Pandas is row/column oriented and highly flexible, but it’s not optimized for large-scale, repeated numerical kernels. In ActivitySim, large portions of runtime are spent in vectorized choice model evaluations, probability calculations, and sampling. Compiling these operations via XLA could yield significant speed improvements.
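As a rough illustration of the kind of kernel this refers to, the sketch below jit-compiles a multinomial logit probability calculation and a choice draw. The function names and the synthetic utilities are assumptions made for this example; they are not taken from ActivitySim or from the prototype.

```python
import jax
import jax.numpy as jnp


@jax.jit
def choice_probabilities(utilities):
    """Multinomial logit probabilities, one row per chooser."""
    # softmax over the alternatives axis (last axis)
    return jax.nn.softmax(utilities, axis=-1)


@jax.jit
def sample_choices(key, utilities):
    """Draw one alternative index per chooser."""
    # categorical expects unnormalized log-probabilities, i.e. the utilities
    return jax.random.categorical(key, utilities, axis=-1)


# Synthetic utilities: 1000 choosers, 5 alternatives (illustrative only)
utilities = jax.random.normal(jax.random.PRNGKey(0), (1000, 5))
probs = choice_probabilities(utilities)
choices = sample_choices(jax.random.PRNGKey(1), utilities)
```

Both functions are traced once per input shape and then reused, which is where any speedup over eager array code would come from.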

2. GPU/TPU Acceleration

  • What JAX does: JAX lets you run the same code on CPU, GPU, or TPU with essentially no changes.
  • Why it matters for ActivitySim: Currently, ActivitySim is CPU-bound, and performance tuning relies on parallelization (multithreading, multiprocessing, Dask). With JAX, computationally heavy components like utility calculations across millions of alternatives and households could be shifted to GPUs/TPUs, which could cut runtimes for regional-scale models dramatically. This would almost certainly require serious work to shrink the memory footprint of skim data so that it fits in device memory, and the ultimate runtime benefits, while possibly huge, are highly uncertain.
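A minimal sketch of what "the same code on any device" looks like in practice. The utility function and array shapes are made up for illustration; the script runs unchanged on a CPU-only machine and would pick up a GPU or TPU if one were visible to JAX.

```python
import jax
import jax.numpy as jnp

# First device JAX can see: a CPU, GPU, or TPU depending on the install
device = jax.devices()[0]
print(device.platform)


@jax.jit
def total_utility(coefficients, attributes):
    # attributes: (n_choosers, n_alts, n_vars); coefficients: (n_vars,)
    return attributes @ coefficients


# Place the inputs on the chosen device; the computation follows the data
attributes = jax.device_put(jnp.ones((4, 3, 2)), device)
coefficients = jax.device_put(jnp.array([0.5, -1.0]), device)
utilities = total_utility(coefficients, attributes)
```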

3. Automatic Differentiation (Autograd)

  • What JAX does: Most numerical operations in JAX can be differentiated automatically with jax.grad. Discrete steps such as sampling or argmax have no useful gradient.

  • Why it matters for ActivitySim:

    • Estimation workflows (maximum likelihood estimation, gradient-based calibration) could be easily implemented, since gradients can be computed directly rather than numerically approximated. Instead of porting out data in "estimation data bundles", estimation and auto-calibration could happen directly in the same library used for application.
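To make the estimation point concrete, here is a small sketch of a multinomial logit negative log-likelihood whose gradient comes from jax.grad rather than finite differences. The data and the function name are invented for this example.

```python
import jax
import jax.numpy as jnp


def neg_log_likelihood(beta, attributes, chosen):
    utilities = attributes @ beta                        # (n_obs, n_alts)
    log_probs = jax.nn.log_softmax(utilities, axis=-1)
    # log-probability of the alternative each observation actually chose
    picked = jnp.take_along_axis(log_probs, chosen[:, None], axis=1)
    return -picked.sum()


# Gradient with respect to the first argument (beta), compiled with jit
grad_fn = jax.jit(jax.grad(neg_log_likelihood))

# Toy data: 2 observations, 2 alternatives, 2 variables
attributes = jnp.array([[[1.0, 0.0], [0.0, 1.0]],
                        [[0.5, 2.0], [1.5, 0.0]]])
chosen = jnp.array([0, 1])
beta = jnp.zeros(2)
gradient = grad_fn(beta, attributes, chosen)
```

The resulting gradient could be handed to any gradient-based optimizer; which optimizer to use is left open here.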

4. Memory Efficiency & Scaling

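  • What JAX does: Under jit, XLA fuses operations and plans buffer reuse, which can eliminate many of the intermediate arrays that eager Pandas or NumPy code materializes. (JAX dispatches work asynchronously, but it does not evaluate lazily.)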
  • Why it matters for ActivitySim: Many current bottlenecks are due to Pandas memory overhead (e.g., holding multiple large DataFrames in memory during model steps). JAX arrays carry no index or per-column metadata, and compilation can reduce intermediate allocation, which may mean lower peak RAM usage and more efficient handling of very large populations.

5. Unified Interface with Scientific ML Ecosystem

  • What JAX does: JAX has become the backbone of many scientific machine learning and probabilistic programming frameworks (e.g., NumPyro, Flax, BlackJAX).

  • Why it matters for ActivitySim:

    • Opens the door to machine learning hybrids (e.g., neural network choice models, reinforcement learning-based accessibility components).
    • Easier integration with modern inference tools for calibration and scenario testing.

6. Lower Development Risk

@jpn-- already has the bones for a prototype available, which can be used to concretely demonstrate some of these benefits and offer a reasonable starting point for future development with a well defined initial fixed cost.


Downsides

1. Loss of Pandas Data Model Flexibility

  • Pandas is built for heterogeneous, labeled data with rich indexing, joins, and groupby operations — all of which ActivitySim uses heavily to represent persons, households, tours, trips, and land use.

  • A JAX array, by contrast, is a homogeneous ndarray (no column labels, one dtype). A table with mixed dtypes can be held as a dict (pytree) of arrays, but indexing, joins, and groupby would have to come from custom array-based data management or external libraries, which can:

    • Increase development overhead.
    • Make the code harder to read/maintain.
    • Reduce accessibility for practitioners who are comfortable with Pandas.
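For a sense of what "custom array-based data management" might look like, the sketch below stores each table as a dict of column arrays and replaces a merge with a positional gather. The column names and values are hypothetical, and this is only one possible layout.

```python
import jax
import jax.numpy as jnp

# Each "table" is a dict of equal-length arrays; dtypes differ per column
households = {
    "income": jnp.array([40_000.0, 85_000.0, 120_000.0]),
    "auto_ownership": jnp.array([0, 1, 2]),
}
persons = {
    "household_idx": jnp.array([0, 0, 1, 2, 2]),   # row position, not a label
    "age": jnp.array([34, 8, 51, 29, 31]),
    "is_worker": jnp.array([True, False, True, True, False]),
}


@jax.jit
def attach_household_income(persons, households):
    # A gather by position stands in for a pandas merge on household id
    income = households["income"][persons["household_idx"]]
    return {**persons, "hh_income": income}


persons = attach_household_income(persons, households)
```

Note what is lost: the join key has to be a row position maintained by hand, and nothing checks that the columns stay aligned.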

2. Debugging and Development Overhead

  • JAX uses just-in-time compilation (jit), which can obscure stack traces and make debugging more difficult.
  • Inside jitted functions, Python control flow cannot branch on traced array values, and Python side effects (printing, mutating globals) run only while the function is being traced. Common idioms must be rewritten with jnp.where, jax.lax.cond, or jax.lax.scan.
  • Developers used to Pandas and NumPy will face a steeper learning curve.
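A short sketch of the kind of rewrite involved. The disutility coefficients and the 30-minute threshold are arbitrary values chosen for the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def naive_disutility(minutes):
    # Fails under jit: `minutes` is a tracer with no concrete value,
    # so Python cannot decide which branch to take.
    if minutes > 30.0:
        return -0.05 * minutes
    return -0.02 * minutes


try:
    naive_disutility(jnp.array(45.0))
    failed = False
except TypeError:            # JAX's tracer-to-bool error subclasses TypeError
    failed = True


@jax.jit
def travel_time_disutility(minutes):
    # jnp.where evaluates both branches and selects elementwise
    return jnp.where(minutes > 30.0, -0.05 * minutes, -0.02 * minutes)


values = travel_time_disutility(jnp.array([10.0, 45.0]))
```

The jnp.where version also vectorizes over all choosers at once, which the Python `if` never could.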

3. Longer Development Time for Equivalent Features

  • Many operations that are one-liners in Pandas (e.g., merge, groupby, fillna) require much more verbose code in JAX or must be re-implemented.
  • This could slow down both core development and model customization by agencies.
  • In short: expressiveness is traded for performance, which may not be worth it for all model components.
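As an illustration of the verbosity gap, here are array-level stand-ins for a groupby sum and a fillna. The data are invented; the pandas expressions in the comments are the one-liners being replaced.

```python
import jax
import jax.numpy as jnp

household_idx = jnp.array([0, 0, 1, 2, 2])    # person -> household row position
is_worker = jnp.array([1, 1, 1, 0, 0])

# pandas: persons.groupby("household_id")["is_worker"].sum()
workers_per_household = jax.ops.segment_sum(
    is_worker, household_idx, num_segments=3
)

# pandas: income.fillna(0.0)
income = jnp.array([52_000.0, jnp.nan, 61_000.0])
income_filled = jnp.nan_to_num(income, nan=0.0)
```

The number of groups has to be supplied up front (and must be static under jit), whereas pandas discovers the groups from the data.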

4. Ecosystem Mismatch

  • JAX is numerical computing first, not tabular data management. There’s no native equivalent of Dask, GeoPandas, or SQL-like joins that integrate seamlessly with JAX arrays.
  • This means you’d likely need a hybrid architecture: Pandas/Polars/Arrow for data prep + JAX for numerical kernels. That adds complexity.
  • Also, the transportation modeling ecosystem (e.g., UrbanSim, ActivitySim extensions, estimation toolchains) is Pandas-centric. Going all-in on JAX could isolate ActivitySim from those tools (although few deep integrations currently exist).

5. Hardware and Runtime Risks

  • GPU/TPU acceleration isn’t free — it requires specialized hardware, drivers, and cluster environments. Many MPOs/consultants may not have access to these, limiting portability.
  • Compilation overhead: for small workloads, JAX’s jit compilation time can outweigh performance gains. ActivitySim has many modular steps, so careful caching strategies would be needed.
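The recompilation behavior can be seen with a small experiment. The counter below relies on the fact that Python side effects inside a jitted function run only while it is being traced; the function itself is a placeholder.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled_utility(x):
    global trace_count
    trace_count += 1          # executes only during tracing, not on cached calls
    return 0.5 * x


scaled_utility(jnp.ones(100))   # first call: trace and compile
scaled_utility(jnp.ones(100))   # same shape and dtype: cached executable reused
scaled_utility(jnp.ones(200))   # new shape: traced and compiled again
print(trace_count)
```

Because every new input shape triggers another compile, one possible mitigation would be padding chunks to a small set of fixed sizes, though whether that suits ActivitySim's chunking is an open question.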

6. Maintenance and Community Risks

  • Pandas is a mature, widely adopted standard with well over a decade of production use.
  • JAX is powerful but still evolving, with API changes and some rough edges. This introduces risk for long-term support and backward compatibility.
  • Recruiting contributors and power users is harder if the system requires JAX expertise instead of Pandas fluency.

7. Not All of ActivitySim’s Workloads Fit JAX Perfectly

  • Some parts of ActivitySim are inherently data wrangling and I/O heavy, not raw numerics. JAX will bring little benefit there.
  • Only the core numerical kernels (choice models, sampling, utility calculations) map cleanly to JAX. A full rewrite may be overkill compared to a hybrid approach.
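A hybrid layout could look roughly like the sketch below: pandas keeps the labels and does the reshaping, and only a dense float array crosses into the jitted kernel. The table, column names, and coefficients are hypothetical and do not come from ActivitySim.

```python
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

# Hypothetical long-format table: one row per (person, alternative)
alternatives = pd.DataFrame({
    "person_id": [1, 1, 2, 2],
    "alt": ["drive", "transit", "drive", "transit"],
    "time": [20.0, 35.0, 15.0, 40.0],
    "cost": [4.0, 2.5, 3.0, 2.5],
})


@jax.jit
def logit_probs(attributes, beta):
    # attributes: (n_persons, n_alts, n_vars); beta: (n_vars,)
    return jax.nn.softmax(attributes @ beta, axis=-1)


# pandas side: select numeric columns and reshape to (persons, alts, vars)
attributes = alternatives[["time", "cost"]].to_numpy().reshape(2, 2, 2)
beta = jnp.array([-0.05, -0.3])     # illustrative coefficients
probs = logit_probs(jnp.asarray(attributes), beta)

# back to pandas for labeled output
alternatives["prob"] = np.asarray(probs).reshape(-1)
```

The reshape assumes rows are sorted by person and then alternative, with the same alternatives for every person; a real implementation would need to enforce that.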
