Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0cb451b
save
paul0403 Oct 15, 2025
eecee28
basic case works with all hardcodes
paul0403 Oct 20, 2025
8c65cb0
cleaner diff
paul0403 Oct 20, 2025
15d5e90
clean up some hard-coded things
paul0403 Oct 20, 2025
012ac42
codefactor
paul0403 Oct 20, 2025
81c1348
codefactor
paul0403 Oct 20, 2025
c736395
Merge remote-tracking branch 'origin/main' into paul0403/multiple_reg…
paul0403 Oct 20, 2025
1bf8e09
Merge remote-tracking branch 'origin/main' into paul0403/multiple_reg…
paul0403 Oct 20, 2025
5a88e6f
tests pass
paul0403 Oct 21, 2025
fb26363
multiple dynregs into for loop
paul0403 Oct 21, 2025
ac4bfb1
revert pytest "class"-ify: the diff looks abysmal
paul0403 Oct 21, 2025
86f2370
codefactor
paul0403 Oct 21, 2025
7dcd477
Merge remote-tracking branch 'origin/main' into paul0403/multiple_reg…
paul0403 Oct 21, 2025
44e6bf1
small merge
paul0403 Oct 21, 2025
3689ded
factor out a util
paul0403 Oct 22, 2025
1d62d21
Merge remote-tracking branch 'origin/main' into paul0403/multiple_reg…
paul0403 Oct 22, 2025
7b5baf7
cond
paul0403 Oct 22, 2025
8487276
whoops
paul0403 Oct 22, 2025
cbc4ee9
while loop
paul0403 Oct 22, 2025
be8353a
subroutine
paul0403 Oct 22, 2025
f79c47a
Merge remote-tracking branch 'origin/main' into paul0403/multiple_reg…
paul0403 Oct 23, 2025
a0746a8
codefactor
paul0403 Oct 23, 2025
8dbadfb
diff
paul0403 Oct 23, 2025
030b529
add decomp test
paul0403 Oct 23, 2025
e6c318b
changelog
paul0403 Oct 24, 2025
7120872
adjoint error msg
paul0403 Oct 24, 2025
3a9ea00
docs update
paul0403 Oct 24, 2025
3cef8b4
Merge remote-tracking branch 'origin/main' into paul0403/multiple_reg…
paul0403 Oct 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 110 additions & 27 deletions frontend/catalyst/from_plxpr/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,70 @@
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim

from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter
from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder
from catalyst.from_plxpr.qubit_handler import (
QubitHandler,
QubitIndexRecorder,
_get_dynamically_allocated_qregs,
)
from catalyst.jax_extras import jaxpr_pad_consts
from catalyst.jax_primitives import cond_p, for_p, while_p


def _calling_convention(interpreter, closed_jaxpr, *args_plus_qreg):
# The last arg is the scope argument for the body jaxpr
*args, qreg = args_plus_qreg
def _calling_convention(
interpreter, closed_jaxpr, *args_plus_qregs, outer_dynqreg_handlers=(), return_qreg=True
):
# Arg structure (all args are tracers, since this function is to be `make_jaxpr`'d):
# Regular args, then dynamically allocated qregs, then global qreg
# TODO: merge dynamically allocaed qregs into regular args?
# But this is tricky, since qreg arguments need all the SSA value semantics conversion infra
# and are different from the regular plain arguments.
*args_plus_dynqregs, global_qreg = args_plus_qregs
num_dynamic_alloced_qregs = len(outer_dynqreg_handlers)
args, dynalloced_qregs = (
args_plus_dynqregs[: len(args_plus_dynqregs) - num_dynamic_alloced_qregs],
args_plus_dynqregs[len(args_plus_dynqregs) - num_dynamic_alloced_qregs :],
)

# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(interpreter)
converter.qubit_index_recorder = QubitIndexRecorder()
init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg

# add dynamic qregs to recorder
qreg_map = {}
dyn_qreg_handlers = []
for dyn_qreg, outer_dynqreg_handler in zip(
dynalloced_qregs, outer_dynqreg_handlers, strict=True
):
dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder)
dyn_qreg_handlers.append(dyn_qreg_handler)

# plxpr global wire index does not change across scopes
# So scope arg dynamic qregs need to have the same root hash as their corresponding
# qreg tracers outside
dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash

# Each qreg argument of the subscope corresponds to a qreg from the outer scope
qreg_map[outer_dynqreg_handler] = dyn_qreg_handler

# The new interpreter's recorder needs to be updated to include the qreg args
# of this scope, instead of the outer qregs
if qreg_map:
for k, outer_dynqreg_handler in interpreter.qubit_index_recorder.map.items():
converter.qubit_index_recorder[k] = qreg_map[outer_dynqreg_handler]

retvals = converter(closed_jaxpr, *args)
if not return_qreg:
return retvals

init_qreg.insert_all_dangling_qubits()

# Return all registers
for dyn_qreg_handler in dyn_qreg_handlers:
dyn_qreg_handler.insert_all_dangling_qubits()
retvals.append(dyn_qreg_handler.get())
return *retvals, converter.init_qreg.get()


Expand Down Expand Up @@ -89,7 +135,18 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
args = plxpr_invals[args_slice]
self.init_qreg.insert_all_dangling_qubits()
args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args

dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)

# Add the qregs to the args
args_plus_qreg = [
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
]

converted_jaxpr_branches = []
all_consts = []

Expand All @@ -102,7 +159,9 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)

f = partial(_calling_convention, self, closed_jaxpr)
f = partial(
_calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs
)
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)

all_consts += converted_jaxpr_branch.consts
Expand All @@ -111,6 +170,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]]

# Build Catalyst compatible input values
# strip global wire indices of dynamic wires
all_consts = tuple(const for const in all_consts if const not in dynalloced_wire_global_indices)
cond_invals = [*predicate, *all_consts, *args_plus_qreg]

# Perform the binding
Expand All @@ -120,9 +181,12 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
nimplicit_outputs=None,
)

# We assume the last output value is the returned qreg.
# Output structure:
# First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())
for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals
Expand Down Expand Up @@ -192,25 +256,38 @@ def handle_for_loop(

# Add the iteration start and the qreg to the args
self.init_qreg.insert_all_dangling_qubits()

dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)

start_plus_args_plus_qreg = [
start,
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
]

consts = plxpr_invals[consts_slice]

jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)

f = partial(_calling_convention, self, jaxpr)
f = partial(
_calling_convention,
self,
jaxpr,
outer_dynqreg_handlers=dynalloced_qregs,
)
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg)

converted_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
)

# Build Catalyst compatible input values
# strip global wire indices of dynamic wires
new_consts = converted_jaxpr_branch.consts
new_consts = tuple(const for const in new_consts if const not in dynalloced_wire_global_indices)
for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg]

# Config additional for loop settings
Expand All @@ -226,10 +303,14 @@ def handle_for_loop(
preserve_dimensions=True,
)

# We assume the last output value is the returned qreg.
# Output structure:
# First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())

for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals

Expand Down Expand Up @@ -288,14 +369,21 @@ def handle_while_loop(
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
self.init_qreg.insert_all_dangling_qubits()
dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)
consts_body = plxpr_invals[body_slice]
consts_cond = plxpr_invals[cond_slice]
args = plxpr_invals[args_slice]
args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args
args_plus_qreg = [
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
] # Add the qreg to the args

jaxpr = ClosedJaxpr(jaxpr_body_fn, consts_body)

f = partial(_calling_convention, self, jaxpr)
f = partial(_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs)
converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr

converted_body_closed_jaxpr_branch = ClosedJaxpr(
Expand All @@ -306,30 +394,22 @@ def handle_while_loop(
# We need to be able to handle arbitrary plxpr here.
# But we want to be able to create a state where:
# * We do not pass the quantum register as an argument.

# So let's just remove the quantum register here at the end

jaxpr = ClosedJaxpr(jaxpr_cond_fn, consts_cond)

def remove_qreg(*args_plus_qreg):
# The last arg is the scope argument for the body jaxpr
*args, qreg = args_plus_qreg

# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(self)
converter.qubit_index_recorder = QubitIndexRecorder()
init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg

return converter(jaxpr, *args)
f_remove_qreg = partial(
_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs, return_qreg=False
)

converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg).jaxpr
converted_cond_jaxpr_branch = jax.make_jaxpr(f_remove_qreg)(*args_plus_qreg).jaxpr
converted_cond_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_cond_jaxpr_branch), ()
)

# Build Catalyst compatible input values
consts_body = tuple(
const for const in consts_body if const not in dynalloced_wire_global_indices
)
while_loop_invals = [*consts_cond, *consts_body, *args_plus_qreg]

# Perform the binding
Expand All @@ -347,5 +427,8 @@ def remove_qreg(*args_plus_qreg):
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())

for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals
Loading