Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
enzyme=v0.0.186
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203

# Always remove custom PL/LQ versions before release.

Expand Down
2 changes: 1 addition & 1 deletion mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ enzyme:
-DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \
-DCMAKE_POLICY_DEFAULT_CMP0116=NEW

cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21
cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22

.PHONY: plugin
plugin:
Expand Down
13 changes: 11 additions & 2 deletions mlir/include/Catalyst/IR/CatalystOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("id", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs()));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults()));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down
35 changes: 27 additions & 8 deletions mlir/include/Gradient/IR/GradientOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ include "Gradient/IR/GradientInterfaces.td"

def GradOp : Gradient_Op<"grad", [
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
GradientOpInterface
]> {
let summary = "Compute the gradient of a function.";
Expand Down Expand Up @@ -287,7 +287,7 @@ def ForwardOp : Gradient_Op<"forward",
Then:

followed by the original return type, if any.

since there is none, then:

%returnTy = { %tape }
Expand All @@ -302,7 +302,7 @@ def ForwardOp : Gradient_Op<"forward",
One thing that was found experimentally and through tests in Enzyme is that the tape can also be a pointer.
We use this in the case when there is no tape to return. Instead of returning an empty struct, we return a null
pointer that is just never dereferenced.

}];

let arguments = (ins
Expand All @@ -320,8 +320,18 @@ def ForwardOp : Gradient_Op<"forward",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr("")));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("tape", $_builder.getI64IntegerAttr(0));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -358,7 +368,6 @@ def ReverseOp : Gradient_Op<"reverse",

%returnTy = { %tape }


}];

let arguments = (ins
Expand All @@ -376,8 +385,18 @@ def ReverseOp : Gradient_Op<"reverse",

let builders = [OpBuilder<(ins
"mlir::StringRef":$name, "mlir::FunctionType":$type,
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
>];
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
$_builder.getStringAttr(name));
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
$_state.addAttribute("implementation", mlir::FlatSymbolRefAttr::get($_builder.getStringAttr("")));
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(0));
$_state.addAttribute("tape", $_builder.getI64IntegerAttr(0));
$_state.attributes.append(attrs.begin(), attrs.end());
$_state.addRegion();
}]>
];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ set(LIBS
${translation_libs}
ExternalStablehloLib
MLIROptLib
MLIRRegisterAllDialects
MLIRRegisterAllPasses
MLIRRegisterAllExtensions
MLIRCatalyst
catalyst-transforms
MLIRQuantum
Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Driver/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@

#include <memory>

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "stablehlo/conversions/linalg/transforms/Passes.h"
Expand Down
20 changes: 10 additions & 10 deletions mlir/lib/Gradient/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void TensorType2MemrefType(const TypeRange &inTypes, SmallVector<Type> &converte
}
}

static BaseMemRefType
static bufferization::BufferLikeType
getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index,
const bufferization::BufferizationOptions &options)
{
Expand All @@ -134,7 +134,7 @@ getBufferizedFunctionArgType(FunctionOpInterface funcOp, int64_t index,
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
tensorType, *options.defaultMemorySpaceFn(tensorType), nullptr, options);

return memrefType;
return cast<bufferization::BufferLikeType>(memrefType);
}

static ReturnOp getAssumedUniqueReturnOp(FunctionOpInterface funcOp)
Expand Down Expand Up @@ -402,10 +402,10 @@ struct ForwardOpInterface
return {};
}

FailureOr<BaseMemRefType> getBufferType(Operation *op, Value value,
const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
{
// The getBufferType() method is called on either BlockArguments or OpResults.
// https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td#L506
Expand Down Expand Up @@ -526,10 +526,10 @@ struct ReverseOpInterface
return {};
}

FailureOr<BaseMemRefType> getBufferType(Operation *op, Value value,
const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const bufferization::BufferizationOptions &options,
const bufferization::BufferizationState &state,
SmallVector<Value> &invocationStack) const
{
// See comment on the getBufferType() method on forward op.
auto reverseOp = cast<ReverseOp>(op);
Expand Down
22 changes: 7 additions & 15 deletions mlir/lib/Gradient/Transforms/GradMethods/JVPVJPPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/SymbolTable.h"

#include "Gradient/Utils/EinsumLinalgGeneric.h"
Expand Down Expand Up @@ -60,8 +61,6 @@ template <class T> std::vector<int64_t> _tovec(const T &x)

LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rewriter) const
{
MLIRContext *ctx = getContext();

Location loc = op.getLoc();

auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices());
Expand Down Expand Up @@ -159,12 +158,9 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew
}
else {
assert(acc.value().getType() == res.getType());

auto add_op = rewriter.create<linalg::ElemwiseBinaryOp>(
loc, res.getType(), ValueRange({acc.value(), res}), acc.value(),
linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add),
linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed));
acc = add_op.getResultTensors()[0];
auto addOp = rewriter.create<linalg::AddOp>(
loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()});
acc = addOp.getResultTensors()[0];
}
}
assert(acc.has_value());
Expand All @@ -181,8 +177,6 @@ LogicalResult JVPLoweringPattern::matchAndRewrite(JVPOp op, PatternRewriter &rew

LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rewriter) const
{
MLIRContext *ctx = getContext();

Location loc = op.getLoc();

auto func_diff_operand_indices = computeDiffArgIndices(op.getDiffArgIndices());
Expand Down Expand Up @@ -278,11 +272,9 @@ LogicalResult VJPLoweringPattern::matchAndRewrite(VJPOp op, PatternRewriter &rew
else {
assert(acc.value().getType() == res.getType());

auto add_op = rewriter.create<linalg::ElemwiseBinaryOp>(
loc, res.getType(), ValueRange({acc.value(), res}), acc.value(),
linalg::BinaryFnAttr::get(ctx, linalg::BinaryFn::add),
linalg::TypeFnAttr::get(ctx, linalg::TypeFn::cast_signed));
acc = add_op.getResultTensors()[0];
auto addOp = rewriter.create<linalg::AddOp>(
loc, res.getType(), ValueRange{acc.value(), res}, ValueRange{acc.value()});
acc = addOp.getResultTensors()[0];
}
}
assert(acc.has_value());
Expand Down
1 change: 1 addition & 0 deletions mlir/tools/quantum-lsp-server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(LIBS
${conversion_libs}
ExternalStablehloLib
MLIRLspServerLib
MLIRRegisterAllDialects
MLIRCatalyst
MLIRQuantum
MLIRQEC
Expand Down
2 changes: 2 additions & 0 deletions mlir/tools/quantum-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ set(LIBS
${extension_libs}
ExternalStablehloLib
MLIROptLib
MLIRRegisterAllDialects
MLIRRegisterAllPasses
MLIRCatalyst
catalyst-transforms
catalyst-stablehlo-transforms
Expand Down
70 changes: 65 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ def parse_dep_versions():
return results


def is_git_commit_hash(version_string):
"""Check if a version string is a git commit hash (40 character hex string)."""
if version_string is None:
return False
return len(version_string) == 40 and all(
c in "0123456789abcdef" for c in version_string.lower()
)


dep_versions = parse_dep_versions()
jax_version = dep_versions.get("jax")
pl_version = dep_versions.get("pennylane")
Expand All @@ -110,28 +119,79 @@ def parse_dep_versions():
pl_min_release = "0.43.0"
lq_min_release = pl_min_release

# Handle PennyLane version - support both release versions and git commit hashes
if pl_version is not None:
pennylane_dep = f"pennylane=={pl_version}" # use TestPyPI wheels, git is not allowed on PyPI
if is_git_commit_hash(pl_version):
# For git commits, install from git source
pennylane_dep = f"pennylane @ git+https://github.com/PennyLaneAI/pennylane.git@{pl_version}"
print("=" * 80)
print("WARNING: PennyLane is being installed from a git commit.")
print(f"Commit: {pl_version}")
print("=" * 80)
else:
# For release versions, use standard version specifier
pennylane_dep = (
f"pennylane=={pl_version}" # use TestPyPI wheels, git is not allowed on PyPI
)
else:
pennylane_dep = f"pennylane>={pl_min_release}"

# Handle Lightning version - support both release versions and git commit hashes
if lq_version is not None:
lightning_dep = f"pennylane-lightning=={lq_version}" # use TestPyPI wheels to avoid rebuild
kokkos_dep = f"pennylane-lightning-kokkos=={lq_version}"
if is_git_commit_hash(lq_version):
# For git commits, install from git source
lightning_dep = f"pennylane-lightning @ git+https://github.com/PennyLaneAI/pennylane-lightning.git@{lq_version}"
kokkos_dep = "" # Kokkos not available from git
print("=" * 80)
print("WARNING: PennyLane-Lightning is being installed from a git commit.")
print(f"Commit: {lq_version}")
print("Note: pennylane-lightning-kokkos is not available when installing from git.")
print("=" * 80)
else:
# For release versions, use standard version specifier
lightning_dep = f"pennylane-lightning=={lq_version}" # use TestPyPI wheels to avoid rebuild
kokkos_dep = f"pennylane-lightning-kokkos=={lq_version}"
else:
lightning_dep = f"pennylane-lightning>={lq_min_release}"
kokkos_dep = ""

# Handle JAX version - support both release versions and git commit hashes
if jax_version is not None:
if is_git_commit_hash(jax_version):
# For git commits, only specify jax from git source
# Note: When installing from git, jaxlib must be installed separately
jax_dep = f"jax @ git+https://github.com/google/jax.git@{jax_version}"
# Don't add jaxlib to requirements when using git - it needs to be installed separately
jaxlib_dep = None
print("=" * 80)
print("WARNING: JAX is being installed from a git commit.")
print("You may need to install a compatible jaxlib version separately:")
print(" pip install jaxlib==...")
print("Or build jaxlib from the same commit if needed.")
print("=" * 80)
else:
# For release versions, use standard version specifier
jax_dep = f"jax=={jax_version}"
jaxlib_dep = f"jaxlib=={jax_version}"
else:
# Fallback if no JAX version is specified
jax_dep = "jax"
jaxlib_dep = "jaxlib"

requirements = [
pennylane_dep,
lightning_dep,
kokkos_dep,
f"jax=={jax_version}",
f"jaxlib=={jax_version}",
jax_dep,
"numpy!=2.0.0",
"scipy-openblas32>=0.3.26", # symbol and library name
"diastatic-malt>=2.15.2",
]

# Add jaxlib only if it's not None (i.e., not using git commit)
if jaxlib_dep is not None:
requirements.insert(4, jaxlib_dep)

entry_points = {
"pennylane.plugins": [
"oqc.cloud = catalyst.third_party.oqc:OQCDevice",
Expand Down
Loading