Skip to content

Improve dot lift rewrites #1471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jun 13, 2025

This PR was motivated by the partial jacobian computation example in JAX discussed in jax-ml/jax#5904 (comment)

After #1228 it's actually easier to do this sort of optimization in PyTensor since there's no scan to worry about. We already have a bunch of rewrites to lift subtensor operations through elemwise and dots, but we did not have to lift it through blockwise (and blockwise dot - aka matmul). This PR addresses this.

Some notes on each commit:

  1. Do constant_folding in python mode. This is not related to this PR but I noticed a test was taking 10x longer than the others just because there was a simple constant folding operation being triggered in the rewrites, and the whole c-cache was being loaded. This incurs a one time penalty that's pretty large. For users, not interested in the C backend at all, there's no reason to involve the machinery. One single python eval should be pretty fast anyway.

  2. Simplified local_upcast_elemwise. This rewrite was too complex and wasteful, in that it wrapped constants in symbolic expand_dims / alloc + cast. I just do it in numpy directly. This reduces the number of rewrite iterations.

  3. Bunch of improvements to rewrites. Including lifting index operations on the batch dimensions of blockwise, and expanding the dot subtensor lift to work with the Blockwise case. This rewrite predates Blockwise. Others are self-explanatory.

  4. Canonicalize matvec, vecmat, vecdot internally to all use matmul (i.e., Blockwise of 2x2 dot operation). This makes things simpler for our rewrites, because we only need to worry about one case.

  5. The pre-existing test_local_batched_matmul_to_core_matmul rewrite was extend to better address cases of batched matvec, vecmat, and vecdot (batch dimensions are moved to the core dimension). It now moves non-ovelapping batch dimensions of both inputs to their core dimensions. It further tries to avoid reshape (needed when combining multiple batch/core dimensions), so that subtensor_lift rewrites mentioned above can work fine through them.

Benchmark result added in the last commit:
(Note that vectorize=True goes from underperforming (28ms) to overperforming (.37 ms).

Before
------------------------------------------------------------------------------------------------- benchmark: 2 tests ------------------------------------------------------------------------------------------------
Name (time in ms)                                        Min                Max               Mean            StdDev             Median               IQR            Outliers       OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_partial_jacobian[vectorize=False]      1.9453 (1.0)       2.8201 (1.0)       2.2296 (1.0)      0.0963 (1.0)       2.2031 (1.0)      0.0855 (1.0)         52;25  448.5095 (1.0)         421           1
test_benchmark_partial_jacobian[vectorize=True]      28.8122 (14.81)    36.9261 (13.09)    34.1470 (15.32)    2.3973 (24.90)    34.8889 (15.84)    2.6797 (31.35)         8;1   29.2851 (0.07)         21           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After
--------------------------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------------------------
Name (time in us)                                           Min                   Max                  Mean             StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_partial_jacobian[vectorize=True]        345.7980 (1.0)        658.8850 (1.0)        370.9925 (1.0)      41.1362 (1.0)        357.2400 (1.0)      16.9117 (1.0)         24;34  2,695.4724 (1.0)         287           1
test_benchmark_partial_jacobian[vectorize=False]     2,148.9270 (6.21)     3,062.8910 (4.65)     2,215.2234 (5.97)     77.6787 (1.89)     2,194.7940 (6.14)     44.7890 (2.65)        33;34    451.4217 (0.17)        496           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

vectorized jacobian code before:

Subtensor{:stop, :stop} [id A] shape=(5, 5) 9
 ├─ DimShuffle{order=[1,0]} [id B] shape=(1000, 1000) 8
 │  └─ Reshape{3} [id C] shape=(1000, 1000, 1) 7
 │     ├─ Dot22 [id D] shape=(1000, 1000) 6
 │     │  ├─ [[0.903246 ... 74841955]] [id E] shape=(1000, 1000)
 │     │  └─ Reshape{2} [id F] shape=(1000, 1000) 5
 │     │     ├─ True_div [id G] shape=(1000, 1000, 1) 4
 │     │     │  ├─ [[[0.0005] ... [0.0005]]] [id H] shape=(1000, 1000, 1)
 │     │     │  └─ Composite{sqrt((0.001 * i0))} [id I] shape=(1000, 1, 1) 3
 │     │     │     └─ ExpandDims{axes=[1, 2]} [id J] shape=(1000, 1, 1) 2
 │     │     │        └─ CGemv{inplace} [id K] shape=(1000,) 1
 │     │     │           ├─ AllocEmpty{dtype='float64'} [id L] shape=(1000,) 0
 │     │     │           │  └─ 1000 [id M] shape=()
 │     │     │           ├─ 1.0 [id N] shape=()
 │     │     │           ├─ [[0.903246 ... 74841955]] [id O] shape=(1000, 1000)
 │     │     │           ├─ x [id P] shape=(?,)
 │     │     │           └─ 0.0 [id Q] shape=()
 │     │     └─ [1000   -1] [id R] shape=(2,)
 │     └─ [1000 1000    1] [id S] shape=(3,)
 ├─ 5 [id T] shape=()
 └─ 5 [id T] shape=()

and after:

Dot22 [id A] shape=(5, 5) 5
 ├─ True_div [id B] shape=(5, 1000) 4
 │  ├─ [[0.0005 0 ... 0.    ]] [id C] shape=(5, 1000)
 │  └─ Composite{sqrt((0.001 * i0))} [id D] shape=(1, 1000) 3
 │     └─ ExpandDims{axis=0} [id E] shape=(1, 1000) 2
 │        └─ CGemv{inplace} [id F] shape=(1000,) 1
 │           ├─ AllocEmpty{dtype='float64'} [id G] shape=(1000,) 0
 │           │  └─ 1000 [id H] shape=()
 │           ├─ 1.0 [id I] shape=()
 │           ├─ [[0.903246 ... 74841955]] [id J] shape=(1000, 1000)
 │           ├─ x [id K] shape=(?,)
 │           └─ 0.0 [id L] shape=()
 └─ [[0.903246 ... 45926986]] [id M] shape=(1000, 5)

📚 Documentation preview 📚: https://pytensor--1471.org.readthedocs.build/en/1471/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant