diff --git a/doc/dev/sharp_bits.rst b/doc/dev/sharp_bits.rst
index 7afddb242..e4179fa11 100644
--- a/doc/dev/sharp_bits.rst
+++ b/doc/dev/sharp_bits.rst
@@ -1303,49 +1303,18 @@ Currently, however, this is not the case for the following functionalities.
of the initial ones. This will cause a performance difference, specifically in
memory usage, when using dynamic wire allocations with and without Catalyst.
- - Wires allocated outside of an MLIR region cannot be used inside the region.
- This includes control flow (``if`` statements, ``for`` loops and ``while`` loops),
- ``qml.adjoint()``, and subroutines. For example,
+ - Dynamically allocated wires cannot be used in quantum adjoints yet.
.. code-block:: python
qml.capture.enable()
- @qjit(autograph=True)
- @qml.qnode(qml.device("lightning.qubit", wires=3))
- def circuit(c):
-
+ @qml.qjit
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ def circuit():
with qml.allocate(1) as q:
- if c:
- qml.X(q[0])
- else:
- qml.Z(q[0])
-
- return qml.probs(wires=[0, 1, 2])
-
- >>> print(circuit(True))
- NotImplementedError: Dynamically allocated wires in a parent scope cannot be
- used in a child scope yet. Please consider dynamical allocation inside the
- child scope.
-
- A workaround is to move the allocations into the regions themselves:
-
- .. code-block:: python
-
- qml.capture.enable()
-
- @qjit(autograph=True)
- @qml.qnode(qml.device("lightning.qubit", wires=3))
- def circuit(c):
-
- if c:
- with qml.allocate(1) as q:
- qml.X(q[0])
- else:
- with qml.allocate(1) as q:
- qml.Z(q[0])
-
- return qml.probs(wires=[0, 1, 2])
+ qml.adjoint(qml.X)(wires=q[0])
+ return qml.probs(wires=[0])
- >>> print(circuit(True))
- [1. 0. 0. 0. 0. 0. 0. 0.]
+ >>> print(circuit())
+ NotImplementedError: Dynamically allocated wires cannot be used in quantum adjoints yet.
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index c8707eee0..368b3b191 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -4,6 +4,9 @@
Improvements ðŸ›
+* Dynamically allocated wires can now be passed into control flow and subroutines.
+ [(#2130)](https://github.com/PennyLaneAI/catalyst/pull/2130)
+
Breaking changes 💔
Deprecations 👋
diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py
index 29a27998a..bd264849f 100644
--- a/frontend/catalyst/from_plxpr/control_flow.py
+++ b/frontend/catalyst/from_plxpr/control_flow.py
@@ -26,24 +26,70 @@
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim
from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter
-from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder
+from catalyst.from_plxpr.qubit_handler import (
+ QubitHandler,
+ QubitIndexRecorder,
+ _get_dynamically_allocated_qregs,
+)
from catalyst.jax_extras import jaxpr_pad_consts
from catalyst.jax_primitives import cond_p, for_p, while_p
-def _calling_convention(interpreter, closed_jaxpr, *args_plus_qreg):
- # The last arg is the scope argument for the body jaxpr
- *args, qreg = args_plus_qreg
+def _calling_convention(
+ interpreter, closed_jaxpr, *args_plus_qregs, outer_dynqreg_handlers=(), return_qreg=True
+):
+ # Arg structure (all args are tracers, since this function is to be `make_jaxpr`'d):
+ # Regular args, then dynamically allocated qregs, then global qreg
+ # TODO: merge dynamically allocaed qregs into regular args?
+ # But this is tricky, since qreg arguments need all the SSA value semantics conversion infra
+ # and are different from the regular plain arguments.
+ *args_plus_dynqregs, global_qreg = args_plus_qregs
+ num_dynamic_alloced_qregs = len(outer_dynqreg_handlers)
+ args, dynalloced_qregs = (
+ args_plus_dynqregs[: len(args_plus_dynqregs) - num_dynamic_alloced_qregs],
+ args_plus_dynqregs[len(args_plus_dynqregs) - num_dynamic_alloced_qregs :],
+ )
# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(interpreter)
converter.qubit_index_recorder = QubitIndexRecorder()
- init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
+ init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg
+ # add dynamic qregs to recorder
+ qreg_map = {}
+ dyn_qreg_handlers = []
+ for dyn_qreg, outer_dynqreg_handler in zip(
+ dynalloced_qregs, outer_dynqreg_handlers, strict=True
+ ):
+ dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder)
+ dyn_qreg_handlers.append(dyn_qreg_handler)
+
+ # plxpr global wire index does not change across scopes
+ # So scope arg dynamic qregs need to have the same root hash as their corresponding
+ # qreg tracers outside
+ dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash
+
+ # Each qreg argument of the subscope corresponds to a qreg from the outer scope
+ qreg_map[outer_dynqreg_handler] = dyn_qreg_handler
+
+ # The new interpreter's recorder needs to be updated to include the qreg args
+ # of this scope, instead of the outer qregs
+ if qreg_map:
+ for k, outer_dynqreg_handler in interpreter.qubit_index_recorder.map.items():
+ converter.qubit_index_recorder[k] = qreg_map[outer_dynqreg_handler]
+
retvals = converter(closed_jaxpr, *args)
+ if not return_qreg:
+ return retvals
+
init_qreg.insert_all_dangling_qubits()
+
+ # Return all registers
+ for dyn_qreg_handler in dyn_qreg_handlers:
+ dyn_qreg_handler.insert_all_dangling_qubits()
+ retvals.append(dyn_qreg_handler.get())
return *retvals, converter.init_qreg.get()
@@ -89,7 +135,18 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
args = plxpr_invals[args_slice]
self.init_qreg.insert_all_dangling_qubits()
- args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args
+
+ dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
+ plxpr_invals, self.qubit_index_recorder, self.init_qreg
+ )
+
+ # Add the qregs to the args
+ args_plus_qreg = [
+ *args,
+ *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
+ self.init_qreg.get(),
+ ]
+
converted_jaxpr_branches = []
all_consts = []
@@ -102,7 +159,9 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)
- f = partial(_calling_convention, self, closed_jaxpr)
+ f = partial(
+ _calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs
+ )
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)
all_consts += converted_jaxpr_branch.consts
@@ -111,6 +170,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]]
# Build Catalyst compatible input values
+ # strip global wire indices of dynamic wires
+ all_consts = tuple(const for const in all_consts if const not in dynalloced_wire_global_indices)
cond_invals = [*predicate, *all_consts, *args_plus_qreg]
# Perform the binding
@@ -120,9 +181,12 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
nimplicit_outputs=None,
)
- # We assume the last output value is the returned qreg.
+ # Output structure:
+ # First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())
+ for dyn_qreg in reversed(dynalloced_qregs):
+ dyn_qreg.set(outvals.pop())
# Return only the output values that match the plxpr output values
return outvals
@@ -192,9 +256,15 @@ def handle_for_loop(
# Add the iteration start and the qreg to the args
self.init_qreg.insert_all_dangling_qubits()
+
+ dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
+ plxpr_invals, self.qubit_index_recorder, self.init_qreg
+ )
+
start_plus_args_plus_qreg = [
start,
*args,
+ *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
]
@@ -202,7 +272,12 @@ def handle_for_loop(
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)
- f = partial(_calling_convention, self, jaxpr)
+ f = partial(
+ _calling_convention,
+ self,
+ jaxpr,
+ outer_dynqreg_handlers=dynalloced_qregs,
+ )
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg)
converted_closed_jaxpr_branch = ClosedJaxpr(
@@ -210,7 +285,9 @@ def handle_for_loop(
)
# Build Catalyst compatible input values
+ # strip global wire indices of dynamic wires
new_consts = converted_jaxpr_branch.consts
+ new_consts = tuple(const for const in new_consts if const not in dynalloced_wire_global_indices)
for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg]
# Config additional for loop settings
@@ -226,10 +303,14 @@ def handle_for_loop(
preserve_dimensions=True,
)
- # We assume the last output value is the returned qreg.
+ # Output structure:
+ # First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())
+ for dyn_qreg in reversed(dynalloced_qregs):
+ dyn_qreg.set(outvals.pop())
+
# Return only the output values that match the plxpr output values
return outvals
@@ -288,14 +369,21 @@ def handle_while_loop(
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
self.init_qreg.insert_all_dangling_qubits()
+ dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
+ plxpr_invals, self.qubit_index_recorder, self.init_qreg
+ )
consts_body = plxpr_invals[body_slice]
consts_cond = plxpr_invals[cond_slice]
args = plxpr_invals[args_slice]
- args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args
+ args_plus_qreg = [
+ *args,
+ *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
+ self.init_qreg.get(),
+ ] # Add the qreg to the args
jaxpr = ClosedJaxpr(jaxpr_body_fn, consts_body)
- f = partial(_calling_convention, self, jaxpr)
+ f = partial(_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs)
converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr
converted_body_closed_jaxpr_branch = ClosedJaxpr(
@@ -306,30 +394,22 @@ def handle_while_loop(
# We need to be able to handle arbitrary plxpr here.
# But we want to be able to create a state where:
# * We do not pass the quantum register as an argument.
-
# So let's just remove the quantum register here at the end
-
jaxpr = ClosedJaxpr(jaxpr_cond_fn, consts_cond)
- def remove_qreg(*args_plus_qreg):
- # The last arg is the scope argument for the body jaxpr
- *args, qreg = args_plus_qreg
-
- # Launch a new interpreter for the body region
- # A new interpreter's root qreg value needs a new recorder
- converter = copy(self)
- converter.qubit_index_recorder = QubitIndexRecorder()
- init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
- converter.init_qreg = init_qreg
-
- return converter(jaxpr, *args)
+ f_remove_qreg = partial(
+ _calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs, return_qreg=False
+ )
- converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg).jaxpr
+ converted_cond_jaxpr_branch = jax.make_jaxpr(f_remove_qreg)(*args_plus_qreg).jaxpr
converted_cond_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_cond_jaxpr_branch), ()
)
# Build Catalyst compatible input values
+ consts_body = tuple(
+ const for const in consts_body if const not in dynalloced_wire_global_indices
+ )
while_loop_invals = [*consts_cond, *consts_body, *args_plus_qreg]
# Perform the binding
@@ -347,5 +427,8 @@ def remove_qreg(*args_plus_qreg):
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())
+ for dyn_qreg in reversed(dynalloced_qregs):
+ dyn_qreg.set(outvals.pop())
+
# Return only the output values that match the plxpr output values
return outvals
diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py
index 25b6cb90d..e79238f82 100644
--- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py
+++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py
@@ -17,6 +17,7 @@
# pylint: disable=protected-access
import textwrap
from copy import copy
+from functools import partial
import jax
import jax.numpy as jnp
@@ -66,6 +67,7 @@
from .qubit_handler import (
QubitHandler,
QubitIndexRecorder,
+ _get_dynamically_allocated_qregs,
get_in_qubit_values,
is_dynamically_allocated_wire,
)
@@ -316,22 +318,69 @@ def interpret_counts(self, *wires, all_outcomes):
return keys, vals
+def _subroutine_kernel(
+ interpreter,
+ jaxpr,
+ *qregs_plus_args,
+ outer_dynqreg_handlers=(),
+ dynalloced_wire_global_indices=(),
+ wire_label_arg_to_tracer_arg_index=(),
+):
+ global_qreg, *dynqregs_plus_args = qregs_plus_args
+ num_dynamic_alloced_qregs = len(outer_dynqreg_handlers)
+ dynalloced_qregs, args = (
+ dynqregs_plus_args[:num_dynamic_alloced_qregs],
+ dynqregs_plus_args[num_dynamic_alloced_qregs:],
+ )
+
+ # Launch a new interpreter for the body region
+ # A new interpreter's root qreg value needs a new recorder
+ converter = copy(interpreter)
+ converter.qubit_index_recorder = QubitIndexRecorder()
+ init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder)
+ converter.init_qreg = init_qreg
+
+ # add dynamic qregs to recorder
+ qreg_map = {}
+ dyn_qreg_handlers = []
+ for dyn_qreg, outer_dynqreg_handler, global_wire_index in zip(
+ dynalloced_qregs, outer_dynqreg_handlers, dynalloced_wire_global_indices, strict=True
+ ):
+ dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder)
+ dyn_qreg_handlers.append(dyn_qreg_handler)
+
+ # plxpr global wire index does not change across scopes
+ # So scope arg dynamic qregs need to have the same root hash as their corresponding
+ # qreg tracers outside
+ dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash
+
+ # Each qreg argument of the subscope corresponds to a qreg from the outer scope
+ qreg_map[args[wire_label_arg_to_tracer_arg_index[global_wire_index]]] = dyn_qreg_handler
+
+ # The new interpreter's recorder needs to be updated to include the qreg args
+ # of this scope, instead of the outer qregs
+ for arg in args:
+ if arg in qreg_map:
+ converter.qubit_index_recorder[arg] = qreg_map[arg]
+
+ retvals = converter(jaxpr, *args)
+
+ init_qreg.insert_all_dangling_qubits()
+
+ # Return all registers
+ for dyn_qreg_handler in reversed(dyn_qreg_handlers):
+ dyn_qreg_handler.insert_all_dangling_qubits()
+ retvals.insert(0, dyn_qreg_handler.get())
+
+ return converter.init_qreg.get(), *retvals
+
+
@PLxPRToQuantumJaxprInterpreter.register_primitive(quantum_subroutine_p)
def handle_subroutine(self, *args, **kwargs):
"""
Transform the subroutine from PLxPR into JAXPR with quantum primitives.
"""
- if any(is_dynamically_allocated_wire(arg) for arg in args):
- raise NotImplementedError(
- textwrap.dedent(
- """
- Dynamically allocated wires in a parent scope cannot be used in a child
- scope yet. Please consider dynamical allocation inside the child scope.
- """
- )
- )
-
backup = dict(self.init_qreg)
self.init_qreg.insert_all_dangling_qubits()
@@ -339,20 +388,32 @@ def handle_subroutine(self, *args, **kwargs):
plxpr = kwargs["jaxpr"]
transformed = self.subroutine_cache.get(plxpr)
- def wrapper(qreg, *args):
- # Launch a new interpreter for the new subroutine region
- # A new interpreter's root qreg value needs a new recorder
- converter = copy(self)
- converter.qubit_index_recorder = QubitIndexRecorder()
- init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
- converter.init_qreg = init_qreg
+ dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
+ args, self.qubit_index_recorder, self.init_qreg
+ )
- retvals = converter(plxpr, *args)
- converter.init_qreg.insert_all_dangling_qubits()
- return converter.init_qreg.get(), *retvals
+ # Convert global wire indices into local indices
+ new_args = ()
+ wire_label_arg_to_tracer_arg_index = {}
+ for i, arg in enumerate(args):
+ if arg in dynalloced_wire_global_indices:
+ wire_label_arg_to_tracer_arg_index[arg] = i
+ new_args += (self.qubit_index_recorder[arg].global_index_to_local_index(arg),)
+ else:
+ new_args += (arg,)
if not transformed:
- converted_closed_jaxpr_branch = jax.make_jaxpr(wrapper)(self.init_qreg.get(), *args)
+ f = partial(
+ _subroutine_kernel,
+ self,
+ plxpr,
+ outer_dynqreg_handlers=dynalloced_qregs,
+ dynalloced_wire_global_indices=dynalloced_wire_global_indices,
+ wire_label_arg_to_tracer_arg_index=wire_label_arg_to_tracer_arg_index,
+ )
+ converted_closed_jaxpr_branch = jax.make_jaxpr(f)(
+ self.init_qreg.get(), *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], *args
+ )
self.subroutine_cache[plxpr] = converted_closed_jaxpr_branch
else:
converted_closed_jaxpr_branch = transformed
@@ -361,12 +422,13 @@ def wrapper(qreg, *args):
# is just pjit_p with a different name.
vals_out = quantum_subroutine_p.bind(
self.init_qreg.get(),
- *args,
+ *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
+ *new_args,
jaxpr=converted_closed_jaxpr_branch,
- in_shardings=(UNSPECIFIED, *kwargs["in_shardings"]),
- out_shardings=(UNSPECIFIED, *kwargs["out_shardings"]),
- in_layouts=(None, *kwargs["in_layouts"]),
- out_layouts=(None, *kwargs["out_layouts"]),
+ in_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["in_shardings"]),
+ out_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["out_shardings"]),
+ in_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["in_layouts"]),
+ out_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["out_layouts"]),
donated_invars=kwargs["donated_invars"],
ctx_mesh=kwargs["ctx_mesh"],
name=kwargs["name"],
@@ -376,7 +438,9 @@ def wrapper(qreg, *args):
)
self.init_qreg.set(vals_out[0])
- vals_out = vals_out[1:]
+ for i, dyn_qreg in enumerate(dynalloced_qregs):
+ dyn_qreg.set(vals_out[i + 1])
+ vals_out = vals_out[len(dynalloced_qregs) + 1 :]
for orig_wire in backup.keys():
self.init_qreg.extract(orig_wire)
@@ -565,6 +629,12 @@ def handle_adjoint_transform(
n_consts,
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the adjoint primitive"""
+
+ if any(is_dynamically_allocated_wire(arg) for arg in plxpr_invals):
+ raise NotImplementedError(
+ "Dynamically allocated wires cannot be used in quantum adjoints yet."
+ )
+
assert jaxpr is not None
consts = plxpr_invals[:n_consts]
args = plxpr_invals[n_consts:]
diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py
index f9adb55e6..1739b3841 100644
--- a/frontend/catalyst/from_plxpr/qubit_handler.py
+++ b/frontend/catalyst/from_plxpr/qubit_handler.py
@@ -68,8 +68,6 @@
qubit SSA values on its wires?
"""
-import textwrap
-
from catalyst.jax_extras import DynamicJaxprTracer
from catalyst.jax_primitives import AbstractQbit, AbstractQreg, qextract_p, qinsert_p
from catalyst.utils.exceptions import CompileError
@@ -422,21 +420,6 @@ def get_in_qubit_values(
if not qubit_index_recorder.contains(w):
# First time the global wire index w is encountered
# Need to extract from fallback qreg
- # TODO: this can now only be from the global qreg, because right now in from_plxpr
- # conversion, subscopes (control flow, adjoint, ...) can only take in the global
- # qreg as the final scope argument. They cannot take an arbitrary number of qreg
- # values yet.
- # Supporting multiple registers requires refactoring the from_plxpr conversion's
- # implementation.
- if is_dynamically_allocated_wire(w):
- raise NotImplementedError(
- textwrap.dedent(
- """
- Dynamically allocated wires in a parent scope cannot be used in a child
- scope yet. Please consider dynamical allocation inside the child scope.
- """
- )
- )
in_qubits.append(fallback_qreg[fallback_qreg.global_index_to_local_index(w)])
in_qregs.append(fallback_qreg)
@@ -446,3 +429,28 @@ def get_in_qubit_values(
in_qubits.append(in_qreg[in_qreg.global_index_to_local_index(w)])
return in_qregs, in_qubits
+
+
+def _get_dynamically_allocated_qregs(plxpr_invals, qubit_index_recorder, init_qreg):
+ """
+ Get the potential dynamically allocated register values that are visible to a jaxpr.
+
+ Note that dynamically allocated wires have their qreg tracer's id as the global wire index
+ so the sub jaxpr takes that id in as a "const", since it is clousure from the target wire
+ of gates/measurements/...
+ We need to remove that const, so we also let this util return these global indices.
+ """
+ dynalloced_qregs = []
+ dynalloced_wire_global_indices = []
+ for inval in plxpr_invals:
+ if (
+ isinstance(inval, int)
+ and qubit_index_recorder.contains(inval)
+ and qubit_index_recorder[inval] is not init_qreg
+ ):
+ dyn_qreg = qubit_index_recorder[inval]
+ dyn_qreg.insert_all_dangling_qubits()
+ dynalloced_qregs.append(dyn_qreg)
+ dynalloced_wire_global_indices.append(inval)
+
+ return dynalloced_qregs, dynalloced_wire_global_indices
diff --git a/frontend/test/lit/test_decomposition.py b/frontend/test/lit/test_decomposition.py
index 619f83964..cd567e99c 100644
--- a/frontend/test/lit/test_decomposition.py
+++ b/frontend/test/lit/test_decomposition.py
@@ -1199,3 +1199,42 @@ def circuit_26(x: float, y: float, z: float):
test_decompose_lowering_params_ordering()
+
+
+def test_decomposition_rule_with_allocation():
+ """Test decomposition rule with dynamic qubit allocation"""
+
+ qml.capture.enable()
+
+ @decomposition_rule(is_qreg=True)
+ def Hadamard0_with_alloc(wire: WiresLike):
+ with qml.allocate(1) as q:
+ qml.X(q[0])
+ qml.CNOT(wires=[q[0], wire])
+
+ @qml.qjit
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ # CHECK: module @circuit_27
+ def circuit_27():
+ Hadamard0_with_alloc(int)
+ return qml.probs()
+
+ # CHECK: func.func public @Hadamard0_with_alloc(%arg0: !quantum.reg, %arg1: tensor) -> !quantum.reg
+ # CHECK: [[dynalloc_qreg:%.+]] = quantum.alloc( 1)
+ # CHECK: [[dynalloc_bit0:%.+]] = quantum.extract [[dynalloc_qreg]][ 0]
+ # CHECK: [[xout:%.+]] = quantum.custom "PauliX"() [[dynalloc_bit0]]
+ # CHECK: [[detensor:%.+]] = tensor.extract %arg1[]
+ # CHECK: [[glob_bit:%.+]] = quantum.extract %arg0[[[detensor]]]
+ # CHECK: [[cnot_out:%.+]]:2 = quantum.custom "CNOT"() [[xout]], [[glob_bit]]
+ # CHECK: [[dynalloc_qreg_inserted:%.+]] = quantum.insert [[dynalloc_qreg]][ 0], [[cnot_out]]#0
+ # CHECK: quantum.dealloc [[dynalloc_qreg_inserted]] : !quantum.reg
+ # CHECK: [[detensor:%.+]] = tensor.extract %arg1[]
+ # CHECK: [[glob_insert:%.+]] = quantum.insert %arg0[[[detensor]]], [[cnot_out]]#1
+ # CHECK: return [[glob_insert]] : !quantum.reg
+
+ print(circuit_27.mlir)
+
+ qml.capture.disable()
+
+
+test_decomposition_rule_with_allocation()
diff --git a/frontend/test/lit/test_dynamic_qubit_allocation.py b/frontend/test/lit/test_dynamic_qubit_allocation.py
index bb94a1bc6..0b1dc7eb6 100644
--- a/frontend/test/lit/test_dynamic_qubit_allocation.py
+++ b/frontend/test/lit/test_dynamic_qubit_allocation.py
@@ -21,10 +21,10 @@
import pennylane as qml
from catalyst import qjit
-from catalyst.jax_primitives import qalloc_p, qdealloc_qb_p, qextract_p
+from catalyst.jax_primitives import qalloc_p, qdealloc_qb_p, qextract_p, subroutine
-@qjit
+@qjit(target="mlir")
def test_single_qubit_dealloc():
"""
Unit test for the single qubit dealloc primitive's lowerings.
@@ -95,7 +95,7 @@ def test_basic_dynalloc():
print(test_basic_dynalloc.mlir)
-@qjit(autograph=True)
+@qjit(autograph=True, target="mlir")
@qml.qnode(qml.device("lightning.qubit", wires=3))
def test_measure_with_reset():
"""
@@ -125,4 +125,175 @@ def test_measure_with_reset():
print(test_measure_with_reset.mlir)
+@qjit(autograph=True, target="mlir")
+@qml.qnode(qml.device("lightning.qubit", wires=2))
+def test_pass_reg_into_forloop():
+ """
+ Test using a dynamically allocated resgister from inside a subscope.
+ """
+
+ # CHECK: [[global_reg:%.+]] = quantum.alloc( 2)
+ # CHECK: [[dyn_reg:%.+]] = quantum.alloc( 1)
+ # CHECK: [[for_out:%.+]]:2 = scf.for %arg0 = {{.+}} to {{.+}} step {{.+}} iter_args
+ # CHECK-SAME: (%arg1 = [[dyn_reg]], %arg2 = [[global_reg]]) -> (!quantum.reg, !quantum.reg) {
+ # CHECK: [[x_in:%.+]] = quantum.extract %arg1[ 0]
+ # CHECK: [[x_out:%.+]] = quantum.custom "PauliX"() [[x_in]]
+ # CHECK: [[cnot_in:%.+]] = quantum.extract %arg2[ 0]
+ # CHECK: [[cnot_out:%.+]]:2 = quantum.custom "CNOT"() [[x_out]], [[cnot_in]]
+ # CHECK: [[global_reg_yield:%.+]] = quantum.insert %arg2[ 0], [[cnot_out]]#1
+ # CHECK: [[dyn_reg_yield:%.+]] = quantum.insert %arg1[ 0], [[cnot_out]]#0
+ # CHECK: scf.yield [[dyn_reg_yield]], [[global_reg_yield]] : !quantum.reg, !quantum.reg
+ # CHECK: quantum.dealloc [[for_out]]#0 : !quantum.reg
+
+ with qml.allocate(1) as q:
+ for _ in range(3):
+ qml.X(wires=q[0])
+ qml.CNOT(wires=[q[0], 0])
+
+ # CHECK: [[global_bit0:%.+]] = quantum.extract [[for_out]]#1[ 0]
+ # CHECK: [[global_bit1:%.+]] = quantum.extract [[for_out]]#1[ 1]
+ # CHECK: [[obs:%.+]] = quantum.compbasis qubits [[global_bit0]], [[global_bit1]] : !quantum.obs
+ # CHECK: {{.+}} = quantum.probs [[obs]] : tensor<4xf64>
+ return qml.probs(wires=[0, 1])
+
+
+print(test_pass_reg_into_forloop.mlir)
+
+
+@qjit(autograph=True, target="mlir")
+@qml.qnode(qml.device("lightning.qubit", wires=3))
+def test_pass_multiple_regs_into_forloop():
+ """
+ Test using multiple dynamically allocated resgisters from inside a subscope.
+ """
+
+ # CHECK: [[global_reg:%.+]] = quantum.alloc( 3)
+ # CHECK: [[q1:%.+]] = quantum.alloc( 1)
+ # CHECK: [[q2:%.+]] = quantum.alloc( 2)
+ # CHECK: [[for_out:%.+]]:3 = scf.for %arg0 = {{.+}} to {{.+}} step {{.+}} iter_args
+ # CHECK-SAME: (%arg1 = [[q1]], %arg2 = [[q2]], %arg3 = [[global_reg]])
+ # CHECK-SAME: -> (!quantum.reg, !quantum.reg, !quantum.reg) {
+ # CHECK: [[q1_0:%.+]] = quantum.extract %arg1[ 0]
+ # CHECK: [[glob_0:%.+]] = quantum.extract %arg3[ 0]
+ # CHECK: [[cnot_out0:%.+]]:2 = quantum.custom "CNOT"() [[q1_0]], [[glob_0]]
+ # CHECK: [[q2_1:%.+]] = quantum.extract %arg2[ 1]
+ # CHECK: [[glob_1:%.+]] = quantum.extract %arg3[ 1]
+ # CHECK: [[cnot_out1:%.+]]:2 = quantum.custom "CNOT"() [[q2_1]], [[glob_1]]
+ # CHECK: [[glob_ins:%.+]] = quantum.insert %arg3[ 0], [[cnot_out0]]#1
+ # CHECK: [[glob_yield:%.+]] = quantum.insert [[glob_ins]][ 1], [[cnot_out1]]#1
+ # CHECK: [[q1_yield:%.+]] = quantum.insert %arg1[ 0], [[cnot_out0]]#0
+ # CHECK: [[q2_yield:%.+]] = quantum.insert %arg2[ 1], [[cnot_out1]]#0
+ # CHECK: scf.yield [[q1_yield]], [[q2_yield]], [[glob_yield]]
+ # CHECK-SAME: : !quantum.reg, !quantum.reg, !quantum.reg
+ # CHECK: quantum.dealloc [[for_out]]#1 : !quantum.reg
+ # CHECK: quantum.dealloc [[for_out]]#0 : !quantum.reg
+
+ with qml.allocate(1) as q1:
+ with qml.allocate(2) as q2:
+ for _ in range(3):
+ qml.CNOT(wires=[q1[0], 0])
+ qml.CNOT(wires=[q2[1], 1])
+
+ return qml.probs(wires=[0, 1])
+
+
+print(test_pass_multiple_regs_into_forloop.mlir)
+
+
+@qjit(autograph=True, target="mlir")
+@qml.qnode(qml.device("lightning.qubit", wires=2))
+def test_pass_multiple_regs_into_whileloop(N: int):
+ """
+ Test using multiple dynamically allocated resgisters from inside a while loop.
+ """
+
+ # CHECK: [[global_reg:%.+]] = quantum.alloc( 2)
+ # CHECK: [[q1:%.+]] = quantum.alloc( 1)
+ # CHECK: [[q2:%.+]] = quantum.alloc( 4)
+ # CHECK: [[while_out:%.+]]:4 = scf.while (%arg1 = {{%.+}}, %arg2 = [[q1]], %arg3 = [[q2]],
+ # CHECK-SAME: %arg4 = [[global_reg]]) : (tensor, !quantum.reg, !quantum.reg, !quantum.reg)
+ # CHECK-SAME: -> (tensor, !quantum.reg, !quantum.reg, !quantum.reg) {
+ # CHECK: stablehlo.compare LT, %arg1, %arg0
+ # CHECK: scf.condition({{%.+}}) %arg1, %arg2, %arg3, %arg4
+ # CHECK: } do {
+ # CHECK: ^bb0(%arg1: tensor, %arg2: !quantum.reg, %arg3: !quantum.reg, %arg4: !quantum.reg
+ # CHECK: [[q1_0:%.+]] = quantum.extract %arg2[ 0]
+ # CHECK: [[glob_1:%.+]] = quantum.extract %arg4[ 1]
+ # CHECK: [[cnot_out0:%.+]]:2 = quantum.custom "CNOT"() [[q1_0]], [[glob_1]]
+ # CHECK: [[q2_0:%.+]] = quantum.extract %arg3[ 0]
+ # CHECK: [[cnot_out1:%.+]]:2 = quantum.custom "CNOT"() [[q2_0]], [[cnot_out0]]#1
+ # CHECK: [[i:%.+]] = stablehlo.add %arg1, {{%.+}}
+ # CHECK: [[glob_yield:%.+]] = quantum.insert %arg4[ 1], [[cnot_out1]]#1
+ # CHECK: [[q1_yield:%.+]] = quantum.insert %arg2[ 0], [[cnot_out0]]#0
+ # CHECK: [[q2_yield:%.+]] = quantum.insert %arg3[ 0], [[cnot_out1]]#0
+ # CHECK: scf.yield [[i]], [[q1_yield]], [[q2_yield]], [[glob_yield]]
+ # CHECK: }
+ # CHECK: quantum.dealloc [[while_out]]#2
+ # CHECK: quantum.dealloc [[while_out]]#1
+
+ i = 0
+ with qml.allocate(1) as q1:
+ with qml.allocate(4) as q2:
+ while i < N:
+ qml.CNOT(wires=[q1[0], 1])
+ qml.CNOT(wires=[q2[0], 1])
+ i += 1
+
+ return qml.probs(wires=[0, 1])
+
+
+print(test_pass_multiple_regs_into_whileloop.mlir)
+
+
+def test_quantum_subroutine():
+ """
+ Test passing dynamically allocated wires into a quantum subroutine.
+ """
+
+ @subroutine
+ def flip(w1, w2, theta):
+ qml.X(w1)
+ qml.X(w2)
+ qml.ctrl(qml.RX, (w1, w2))(theta, wires=0)
+
+ # CHECK: [[angle:%.+]] = stablehlo.constant dense<1.230000e+00>
+ # CHECK: [[one:%.+]] = stablehlo.constant dense<1>
+ # CHECK: [[zero:%.+]] = stablehlo.constant dense<0>
+ # CHECK: [[global_qreg:%.+]] = quantum.alloc( 1)
+ # CHECK: [[q1:%.+]] = quantum.alloc( 2)
+ # CHECK: [[q2:%.+]] = quantum.alloc( 3)
+ # CHECK: {{%.+}}:3 = call @flip([[global_qreg]], [[q1]], [[q2]], [[zero]], [[one]], [[angle]])
+ # CHECK-SAME: (!quantum.reg, !quantum.reg, !quantum.reg, tensor, tensor, tensor)
+ # CHECK-SAME: -> (!quantum.reg, !quantum.reg, !quantum.reg)
+
+ @qjit(target="mlir")
+ @qml.qnode(qml.device("lightning.qubit", wires=1))
+ def circuit():
+ with qml.allocate(2) as q1:
+ with qml.allocate(3) as q2:
+ flip(q1[0], q2[1], 1.23)
+ return qml.probs(wires=[0])
+
+ # CHECK: func.func private @flip(
+ # CHECK: [[zero:%.+]] = tensor.extract %arg3[]
+ # CHECK: [[q1_0:%.+]] = quantum.extract %arg1[[[zero]]]
+ # CHECK: [[x1_out:%.+]] = quantum.custom "PauliX"() [[q1_0]]
+ # CHECK: [[one:%.+]] = tensor.extract %arg4[]
+ # CHECK: [[q2_1:%.+]] = quantum.extract %arg2[[[one]]]
+ # CHECK: [[x2_out:%.+]] = quantum.custom "PauliX"() [[q2_1]]
+ # CHECK: [[glob_0:%.+]] = quantum.extract %arg0[ 0]
+ # CHECK: [[angle:%.+]] = tensor.extract %arg5[]
+ # CHECK: [[rx_out:%.+]], [[rx_ctrl_out:%.+]]:2 = quantum.custom "RX"([[angle]]) [[glob_0]]
+ # CHECK-SAME: ctrls([[x1_out]], [[x2_out]])
+ # CHECK: [[glob_re:%.+]] = quantum.insert %arg0[ 0], [[rx_out]]
+ # CHECK: [[q2_re:%.+]] = quantum.insert %arg2[{{%.+}}], [[rx_ctrl_out]]#1
+ # CHECK: [[q1_re:%.+]] = quantum.insert %arg1[{{%.+}}], [[rx_ctrl_out]]#0
+ # CHECK: return [[glob_re]], [[q1_re]], [[q2_re]] : !quantum.reg, !quantum.reg, !quantum.reg
+
+ print(circuit.mlir)
+
+
+test_quantum_subroutine()
+
+
qml.capture.disable()
diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py
index c1d62544a..4674c04a0 100644
--- a/frontend/test/pytest/test_dynamic_qubit_allocation.py
+++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py
@@ -235,6 +235,31 @@ def circuit(c):
assert np.allclose(expected, observed)
+@pytest.mark.usefixtures("use_capture")
+@pytest.mark.parametrize("cond, expected", [(True, [0, 1, 0, 0]), (False, [1, 0, 0, 0])])
+def test_dynamic_wire_alloc_cond_outside(cond, expected, backend):
+ """
+ Test passing dynamically allocated wires into a cond.
+ """
+
+ @qjit(autograph=True)
+ @qml.qnode(qml.device(backend, wires=2))
+ def circuit(c):
+ with qml.allocate(1) as q1:
+ with qml.allocate(1) as q2:
+ qml.X(q1[0])
+ if c:
+ qml.CNOT(wires=[q1[0], 1]) # |01>
+ else:
+ qml.CNOT(wires=[q2[0], 1]) # |00>
+
+ return qml.probs(wires=[0, 1])
+
+ observed = circuit(cond)
+
+ assert np.allclose(expected, observed)
+
+
@pytest.mark.usefixtures("use_capture")
@pytest.mark.parametrize(
"num_iter, expected", [(3, [0, 0, 1, 0, 0, 0, 0, 0]), (4, [1, 0, 0, 0, 0, 0, 0, 0])]
@@ -260,6 +285,51 @@ def circuit(N):
assert np.allclose(expected, observed)
+@pytest.mark.usefixtures("use_capture")
+def test_dynamic_wire_alloc_forloop_outside(backend):
+ """
+ Test passing dynamically allocated wires into a for loop.
+ """
+
+ @qjit(autograph=True)
+ @qml.qnode(qml.device(backend, wires=1))
+ def circuit():
+ with qml.allocate(1) as q:
+ qml.X(wires=q[0])
+ for _ in range(3):
+ qml.CNOT(wires=[q[0], 0])
+
+ return qml.probs(wires=[0])
+
+ observed = circuit()
+ expected = [0, 1]
+
+ assert np.allclose(expected, observed)
+
+
+@pytest.mark.usefixtures("use_capture")
+def test_dynamic_wire_alloc_forloop_outside_multiple_regs(backend):
+ """
+ Test using multiple dynamically allocated registers from inside for loop.
+ """
+
+ @qjit(autograph=True)
+ @qml.qnode(qml.device(backend, wires=1))
+ def circuit():
+ with qml.allocate(1) as q1:
+ with qml.allocate(1) as q2:
+ for _ in range(3):
+ qml.CNOT(wires=[q1[0], 0])
+ qml.CNOT(wires=[q2[0], 0])
+
+ return qml.probs(wires=[0])
+
+ observed = circuit()
+ expected = [1, 0]
+
+ assert np.allclose(expected, observed)
+
+
@pytest.mark.usefixtures("use_capture")
@pytest.mark.parametrize(
"num_iter, expected", [(3, [0, 0, 1, 0, 0, 0, 0, 0]), (4, [1, 0, 0, 0, 0, 0, 0, 0])]
@@ -287,6 +357,83 @@ def circuit(N):
assert np.allclose(expected, observed)
+@pytest.mark.usefixtures("use_capture")
+@pytest.mark.parametrize("num_iter, expected", [(3, [0, 1, 0, 0]), (4, [1, 0, 0, 0])])
+def test_dynamic_wire_alloc_whileloop_outside(num_iter, expected, backend):
+ """
+ Test passing dynamically allocated wires into a while loop.
+ """
+
+ @qjit(autograph=True)
+ @qml.qnode(qml.device(backend, wires=2))
+ def circuit(N):
+ i = 0
+ with qml.allocate(1) as q1:
+ with qml.allocate(1) as q2:
+ qml.X(q1[0])
+ while i < N:
+ qml.CNOT(wires=[q1[0], 1])
+ qml.CNOT(wires=[q2[0], 1])
+ i += 1
+
+ return qml.probs(wires=[0, 1])
+
+ observed = circuit(num_iter)
+
+ assert np.allclose(expected, observed)
+
+
+@pytest.mark.usefixtures("use_capture")
+@pytest.mark.parametrize("flip_again, expected", [(True, [1, 0]), (False, [0, 1])])
+def test_subroutine(flip_again, expected, backend):
+ """
+ Test passing dynamically allocated wires into a subroutine.
+ """
+
+ @subroutine
+ def flip(w):
+ qml.X(w)
+ qml.CNOT(wires=[w, 0])
+
+ @qjit
+ @qml.qnode(qml.device(backend, wires=1))
+ def circuit():
+ with qml.allocate(1) as q1:
+ with qml.allocate(1) as q2:
+ flip(q1[0])
+ if flip_again:
+ flip(q2[0])
+ return qml.probs(wires=[0])
+
+ observed = circuit()
+ assert np.allclose(expected, observed)
+
+
+@pytest.mark.usefixtures("use_capture")
+def test_subroutine_multiple_args(backend):
+ """
+ Test passing dynamically allocated wires into a subroutine with multiple arguments.
+ """
+
+ @subroutine
+ def flip(w1, w2, theta):
+ qml.X(w1)
+ qml.X(w2)
+ qml.ctrl(qml.RX, (w1, w2))(theta, wires=0)
+
+ @qjit
+ @qml.qnode(qml.device(backend, wires=1))
+ def circuit():
+ with qml.allocate(1) as q1:
+ with qml.allocate(2) as q2:
+ flip(q1[0], q2[1], jnp.pi)
+ return qml.probs(wires=[0])
+
+ observed = circuit()
+ expected = [0, 1]
+ assert np.allclose(expected, observed)
+
+
def test_no_capture(backend):
"""
Test error message when used without capture.
@@ -372,59 +519,21 @@ def circuit():
@pytest.mark.usefixtures("use_capture")
-def test_unsupported_cross_scope_registers(backend):
+def test_unsupported_adjoint(backend):
"""
- Scope jaxprs in Catalyst cannot take multiple registers yet.
- Test that an error is raised when a dynamically allocated register in an outside scope
- is being used from an inside scope.
+ Test that an error is raised when a dynamically allocated wire is passed into a adjoint.
"""
with pytest.raises(
NotImplementedError,
- match=textwrap.dedent(
- """
- Dynamically allocated wires in a parent scope cannot be used in a child
- scope yet. Please consider dynamical allocation inside the child scope.
- """
- ),
+ match="Dynamically allocated wires cannot be used in quantum adjoints yet.",
):
- @qjit(autograph=True)
- @qml.qnode(qml.device(backend, wires=3))
- def circuit():
- wires = qml.allocate(3)
-
- for _ in range(3):
- qml.X(wires=wires[0])
-
- return qml.probs(wires=[0, 1, 2])
-
-
-@pytest.mark.usefixtures("use_capture")
-def test_unsupported_subroutine(backend):
- """
- Test that an error is raised when a dynamically allocated wire is passed into a subroutine.
- """
-
- with pytest.raises(
- NotImplementedError,
- match=textwrap.dedent(
- """
- Dynamically allocated wires in a parent scope cannot be used in a child
- scope yet. Please consider dynamical allocation inside the child scope.
- """
- ),
- ):
-
- @subroutine
- def sub(_):
- pass
-
@qjit
@qml.qnode(qml.device(backend, wires=2))
def circuit():
with qml.allocate(1) as q:
- sub(q[0])
+ qml.adjoint(qml.X)(q[0])
return qml.probs(wires=[0, 1])