diff --git a/.dep-versions b/.dep-versions index a36582d4fd..ce8033697b 100644 --- a/.dep-versions +++ b/.dep-versions @@ -2,9 +2,9 @@ # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} jax=0.6.2 -stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5 -llvm=f8cb7987c64dcffb72414a40560055cb717dbf74 -enzyme=v0.0.186 +stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d +llvm=113f01aa82d055410f22a9d03b3468fa68600589 +enzyme=v0.0.203 # Always remove custom PL/LQ versions before release. diff --git a/mlir/Makefile b/mlir/Makefile index 8fc76e11e7..4628b99bd7 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -138,7 +138,7 @@ enzyme: -DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \ -DCMAKE_POLICY_DEFAULT_CMP0116=NEW - cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21 + cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22 .PHONY: plugin plugin: diff --git a/mlir/include/Catalyst/IR/CatalystOps.td b/mlir/include/Catalyst/IR/CatalystOps.td index 12daf4f6e9..c3c60fc840 100644 --- a/mlir/include/Catalyst/IR/CatalystOps.td +++ b/mlir/include/Catalyst/IR/CatalystOps.td @@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("id", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs())); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults())); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/include/Gradient/IR/GradientOps.td b/mlir/include/Gradient/IR/GradientOps.td index fb81419b99..75905049aa 100644 --- a/mlir/include/Gradient/IR/GradientOps.td +++ b/mlir/include/Gradient/IR/GradientOps.td @@ -28,7 +28,7 @@ include "Gradient/IR/GradientInterfaces.td" def GradOp : Gradient_Op<"grad", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, GradientOpInterface ]> { let summary = "Compute the gradient of a function."; @@ -287,7 +287,7 @@ def ForwardOp : Gradient_Op<"forward", Then: followed by the original return type, if any. - + since there is none, then: %returnTy = { %tape } @@ -302,7 +302,7 @@ def ForwardOp : Gradient_Op<"forward", One thing that was found experimentally and through tests in Enzyme is that the tape can also be a pointer. We use this in the case when there is no tape to return. Instead of returning an empty struct, we return a null pointer that is just never dereferenced. - + }]; let arguments = (ins @@ -320,8 +320,18 @@ def ForwardOp : Gradient_Op<"forward", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr(""))); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("tape", $_builder.getI64IntegerAttr(0)); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// @@ -358,7 +368,6 @@ def ReverseOp : Gradient_Op<"reverse", %returnTy = { %tape } - }]; let arguments = (ins @@ -376,8 +385,18 @@ def ReverseOp : Gradient_Op<"reverse", let builders = [OpBuilder<(ins "mlir::StringRef":$name, "mlir::FunctionType":$type, - CArg<"mlir::ArrayRef", "{}">:$attrs) - >]; + CArg<"mlir::ArrayRef", "{}">:$attrs), [{ + $_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + $_builder.getStringAttr(name)); + $_state.addAttribute("function_type", mlir::TypeAttr::get(type)); + $_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr(""))); + $_state.addAttribute("argc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("resc", $_builder.getI64IntegerAttr(0)); + $_state.addAttribute("tape", $_builder.getI64IntegerAttr(0)); + $_state.attributes.append(attrs.begin(), attrs.end()); + $_state.addRegion(); + }]> + ]; let extraClassDeclaration = [{ //===------------------------------------------------------------------===// diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index c3a82be9ec..5ec6857426 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -27,6 +27,9 @@ set(LIBS ${translation_libs} ExternalStablehloLib MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses + MLIRRegisterAllExtensions MLIRCatalyst catalyst-transforms MLIRQuantum diff --git a/mlir/lib/Driver/Pipelines.cpp b/mlir/lib/Driver/Pipelines.cpp index ccf4e4ab4f..564b2c2f07 100644 --- a/mlir/lib/Driver/Pipelines.cpp +++ b/mlir/lib/Driver/Pipelines.cpp @@ -14,9 +14,22 @@ #include -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/InitAllDialects.h" -#include "mlir/InitAllPasses.h" +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "stablehlo/conversions/linalg/transforms/Passes.h" diff --git a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp index 54cbc85aac..4397aafab2 100644 --- a/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp @@ -124,7 +124,7 @@ void TensorType2MemrefType(const TypeRange &inTypes, SmallVector &converte } } -static BaseMemRefType +static bufferization::BufferLikeType getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, const bufferization::BufferizationOptions &options) { @@ -134,7 +134,7 @@ getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index, BaseMemRefType memrefType = options.functionArgTypeConverterFn( tensorType, *options.defaultMemorySpaceFn(tensorType), nullptr, options); - return memrefType; + return cast(memrefType); } static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp) @@ -402,10 +402,10 @@ struct ForwardOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, - const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - SmallVector &invocationStack) const + FailureOr + getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + SmallVector &invocationStack) const { // The getBufferType() method is called on either BlockArguments or OpResults. // https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L506 @@ -526,10 +526,10 @@ struct ReverseOpInterface return {}; } - FailureOr getBufferType(Operation *op, Value value, - const bufferization::BufferizationOptions &options, - const bufferization::BufferizationState &state, - SmallVector &invocationStack) const + FailureOr + getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options, + const bufferization::BufferizationState &state, + SmallVector &invocationStack) const { // See comment on the getBufferType() method on forward op. auto reverseOp = cast(op); diff --git a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp index 672819d3e7..ad0e26cdf0 100644 --- a/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/SymbolTable.h" #include "Gradient/Utils/EinsumLinalgGeneric.h" @@ -60,8 +61,6 @@ template std::vector _tovec(const T &x) LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rewriter) const { - MLIRContext *ctx = getContext(); - Location loc = op.getLoc(); auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices()); @@ -159,12 +158,9 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew } else { assert(acc.value().getType() == res.getType()); - - auto add_op = rewriter.create( - loc, res.getType(), ValueRange({acc.value(), res}), acc.value(), - linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add), - linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed)); - acc = add_op.getResultTensors()[0]; + auto addOp = rewriter.create( + loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + acc = addOp.getResultTensors()[0]; } } assert(acc.has_value()); @@ -181,8 +177,6 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rewriter) const { - MLIRContext *ctx = getContext(); - Location loc = op.getLoc(); auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices()); @@ -278,11 +272,9 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew else { assert(acc.value().getType() == res.getType()); - auto add_op = rewriter.create( - loc, res.getType(), ValueRange({acc.value(), res}), acc.value(), - linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add), - linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed)); - acc = add_op.getResultTensors()[0]; + auto addOp = rewriter.create( + loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()}); + acc = addOp.getResultTensors()[0]; } } assert(acc.has_value()); diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index f4a7c2e727..507480ef00 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -5,6 +5,7 @@ set(LIBS ${conversion_libs} ExternalStablehloLib MLIRLspServerLib + MLIRRegisterAllDialects MLIRCatalyst MLIRQuantum MLIRQEC diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index 10c6ed5a0f..617398b03c 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -7,6 +7,8 @@ set(LIBS ${extension_libs} ExternalStablehloLib MLIROptLib + MLIRRegisterAllDialects + MLIRRegisterAllPasses MLIRCatalyst catalyst-transforms catalyst-stablehlo-transforms diff --git a/setup.py b/setup.py index 84ce5c873a..6015976547 100644 --- a/setup.py +++ b/setup.py @@ -102,6 +102,15 @@ def parse_dep_versions(): return results +def is_git_commit_hash(version_string): + """Check if a version string is a git commit hash (40 character hex string).""" + if version_string is None: + return False + return len(version_string) == 40 and all( + c in "0123456789abcdef" for c in version_string.lower() + ) + + dep_versions = parse_dep_versions() jax_version = dep_versions.get("jax") pl_version = dep_versions.get("pennylane") @@ -110,28 +119,79 @@ def parse_dep_versions(): pl_min_release = "0.43.0" lq_min_release = pl_min_release +# Handle PennyLane version - support both release versions and git commit hashes if pl_version is not None: - pennylane_dep = f"pennylane=={pl_version}" # use TestPyPI wheels, git is not allowed on PyPI + if is_git_commit_hash(pl_version): + # For git commits, install from git source + pennylane_dep = f"pennylane @ git+https://github.com/PennyLaneAI/pennylane.git@{pl_version}" + print("=" * 80) + print("WARNING: PennyLane is being installed from a git commit.") + print(f"Commit: {pl_version}") + print("=" * 80) + else: + # For release versions, use standard version specifier + pennylane_dep = ( + f"pennylane=={pl_version}" # use TestPyPI wheels, git is not allowed on PyPI + ) else: pennylane_dep = f"pennylane>={pl_min_release}" + +# Handle Lightning version - support both release versions and git commit hashes if lq_version is not None: - lightning_dep = f"pennylane-lightning=={lq_version}" # use TestPyPI wheels to avoid rebuild - kokkos_dep = f"pennylane-lightning-kokkos=={lq_version}" + if is_git_commit_hash(lq_version): + # For git commits, install from git source + lightning_dep = f"pennylane-lightning @ git+https://github.com/PennyLaneAI/pennylane-lightning.git@{lq_version}" + kokkos_dep = "" # Kokkos not available from git + print("=" * 80) + print("WARNING: PennyLane-Lightning is being installed from a git commit.") + print(f"Commit: {lq_version}") + print("Note: pennylane-lightning-kokkos is not available when installing from git.") + print("=" * 80) + else: + # For release versions, use standard version specifier + lightning_dep = f"pennylane-lightning=={lq_version}" # use TestPyPI wheels to avoid rebuild + kokkos_dep = f"pennylane-lightning-kokkos=={lq_version}" else: lightning_dep = f"pennylane-lightning>={lq_min_release}" kokkos_dep = "" +# Handle JAX version - support both release versions and git commit hashes +if jax_version is not None: + if is_git_commit_hash(jax_version): + # For git commits, only specify jax from git source + # Note: When installing from git, jaxlib must be installed separately + jax_dep = f"jax @ git+https://github.com/google/jax.git@{jax_version}" + # Don't add jaxlib to requirements when using git - it needs to be installed separately + jaxlib_dep = None + print("=" * 80) + print("WARNING: JAX is being installed from a git commit.") + print("You may need to install a compatible jaxlib version separately:") + print(" pip install jaxlib==...") + print("Or build jaxlib from the same commit if needed.") + print("=" * 80) + else: + # For release versions, use standard version specifier + jax_dep = f"jax=={jax_version}" + jaxlib_dep = f"jaxlib=={jax_version}" +else: + # Fallback if no JAX version is specified + jax_dep = "jax" + jaxlib_dep = "jaxlib" + requirements = [ pennylane_dep, lightning_dep, kokkos_dep, - f"jax=={jax_version}", - f"jaxlib=={jax_version}", + jax_dep, "numpy!=2.0.0", "scipy-openblas32>=0.3.26", # symbol and library name "diastatic-malt>=2.15.2", ] +# Add jaxlib only if it's not None (i.e., not using git commit) +if jaxlib_dep is not None: + requirements.insert(4, jaxlib_dep) + entry_points = { "pennylane.plugins": [ "oqc.cloud = catalyst.third_party.oqc:OQCDevice",