
Commit f2c3191

Revert "[linalg] Broadcast batch for mask on sdpa lowering (#3824)"
This reverts commit 25738b8.
1 parent 0f5dbbe commit f2c3191

2 files changed: +25 -81 lines changed

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 23 additions & 79 deletions
@@ -1661,7 +1661,6 @@ class ConvertAtenScaledDotProductAttentionOp
     auto valueTy = cast<ShapedType>(value.getType());
     auto keyTy = cast<ShapedType>(key.getType());
 
-    auto loc = op.getLoc();
     Value dropoutP = op.getDropoutP();
     Value isCausal = op.getIsCausal();
     Value scale = op.getScale();
@@ -1672,46 +1671,49 @@ class ConvertAtenScaledDotProductAttentionOp
     double dropout;
     if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
         dropout > 0.0)
-      return rewriter.notifyMatchFailure(loc, "dropout not supported");
+      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
 
     bool causal;
     if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
       if (!isa<Torch::NoneType>(mask.getType())) {
         return rewriter.notifyMatchFailure(
-            loc, "expected no attention mask when isCausal is true");
+            op.getLoc(), "expected no attention mask when isCausal is true");
       }
 
       SmallVector<int64_t> maskStatic;
       SmallVector<Value> maskDyn;
       for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
         maskStatic.push_back(queryTy.getDimSize(i));
         if (maskStatic.back() == ShapedType::kDynamic)
-          maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
+          maskDyn.push_back(
+              rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
       }
 
       maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
       if (maskStatic.back() == ShapedType::kDynamic)
-        maskDyn.push_back(
-            rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));
+        maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                         keyTy.getRank() - 2));
 
       Type maskType = getElementTypeOrSelf(queryTy);
-      Value emptyMask =
-          rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
+      Value emptyMask = rewriter.create<tensor::EmptyOp>(
+          op.getLoc(), maskStatic, maskType, maskDyn);
 
       Value zero = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
+          op.getLoc(),
+          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
       Value negInf = rewriter.create<arith::ConstantOp>(
-          loc,
+          op.getLoc(),
           rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
 
-      mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);
+      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
+                 .getResult(0);
 
       int64_t rank = cast<ShapedType>(queryTy).getRank();
       AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
       SmallVector<utils::IteratorType> iteratorTypes(
           rank, utils::IteratorType::parallel);
       auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, mask.getType(), ValueRange{}, mask,
+          op.getLoc(), mask.getType(), ValueRange{}, mask,
           SmallVector<AffineMap>{maskMap}, iteratorTypes,
           [&](OpBuilder &b, Location loc, ValueRange args) {
             Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
@@ -1725,78 +1727,18 @@ class ConvertAtenScaledDotProductAttentionOp
       mask = genericOp.getResult(0);
     }
 
-    // Broadcast the batch dimensions of the mask:
-    if (!isa<Torch::NoneType>(mask.getType())) {
-      auto maskTy = cast<RankedTensorType>(mask.getType());
-      int64_t rank = maskTy.getRank();
-      bool needsBroadcast = false;
-      for (int i = 0, s = rank - 2; i < s; ++i) {
-        needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
-      }
-
-      if (needsBroadcast) {
-        SmallVector<int64_t> maskShape;
-        SmallVector<Value> maskDynDims;
-
-        SmallVector<AffineExpr> maskExprs;
-        for (int i = 0, s = rank - 2; i < s; ++i) {
-          maskShape.push_back(keyTy.getDimSize(i));
-
-          if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
-            maskExprs.push_back(rewriter.getAffineConstantExpr(0));
-          } else {
-            maskExprs.push_back(rewriter.getAffineDimExpr(i));
-          }
-
-          if (keyTy.isDynamicDim(i)) {
-            maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
-          }
-        }
-
-        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
-        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
-        maskShape.push_back(maskTy.getDimSize(rank - 2));
-        maskShape.push_back(maskTy.getDimSize(rank - 1));
-        if (maskTy.isDynamicDim(rank - 2))
-          maskDynDims.push_back(
-              rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
-        if (maskTy.isDynamicDim(rank - 1))
-          maskDynDims.push_back(
-              rewriter.create<tensor::DimOp>(loc, mask, rank - 1));
-
-        SmallVector<AffineMap> affineMaps = {
-            AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
-                           op.getContext()),
-            rewriter.getMultiDimIdentityMap(rank)};
-        SmallVector<utils::IteratorType> findMaxIteratorTypes(
-            rank, utils::IteratorType::parallel);
-
-        Value emptyMask = rewriter.create<tensor::EmptyOp>(
-            loc, maskShape, maskTy.getElementType(), maskDynDims);
-        Value newMask =
-            rewriter
-                .create<linalg::GenericOp>(
-                    loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
-                    affineMaps, findMaxIteratorTypes,
-                    [&](OpBuilder &b, Location loc, ValueRange args) {
-                      b.create<linalg::YieldOp>(loc, args[0]);
-                    })
-                .getResult(0);
-        mask = newMask;
-      }
-    }
-
     if (!isa<Torch::NoneType>(scale.getType())) {
       double scaleFloat;
       if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
           scaleFloat != 1.0)
-        return rewriter.notifyMatchFailure(loc, "only default scale supported");
+        return rewriter.notifyMatchFailure(op.getLoc(),
+                                           "only default scale supported");
     }
     bool isGQAEnabled;
     if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
         isGQAEnabled)
       return rewriter.notifyMatchFailure(
-          loc, "grouped query attention not supported");
+          op.getLoc(), "grouped query attention not supported");
 
     if (queryTy.getRank() != valueTy.getRank() ||
         queryTy.getRank() != keyTy.getRank())
@@ -1811,6 +1753,7 @@ class ConvertAtenScaledDotProductAttentionOp
     reassociation[1].push_back(valueTy.getRank() - 2);
     reassociation[2].push_back(valueTy.getRank() - 1);
 
+    auto loc = op.getLoc();
     auto collapseBatch = [&rewriter, &reassociation,
                           loc](Value value) -> Value {
       auto valueTy = cast<ShapedType>(value.getType());
@@ -1845,12 +1788,13 @@ class ConvertAtenScaledDotProductAttentionOp
     SmallVector<int64_t> valueSizes(
         cast<ShapedType>(value.getType()).getShape());
     outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
-    SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
+    SmallVector<Value> outSizesDynamic(
+        getTensorSizes(rewriter, op.getLoc(), query));
     outSizesDynamic[outSizesDynamic.size() - 1] =
-        getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
+        getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
     Type outType = RankedTensorType::get(outSizes, elementType);
-    Value output =
-        createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);
+    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
+                                        elementType);
 
     SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
 
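The block removed above was the batch-broadcast step added in #3824: it rebuilt the mask through a linalg.generic whose input map used constant-zero affine expressions for any leading dimension that did not match the key, so a 2x1x8x12 mask could be widened to 2x3x8x12 before the attention op. With the revert, this lowering again expects the mask to arrive with batch/head dimensions that already match query and key. A minimal caller-side sketch of that pre-expansion in PyTorch follows; the helper name and shapes are illustrative only and are not part of this commit:

import torch
import torch.nn.functional as F

def expand_mask_to_batch(mask, query):
    # Hypothetical helper: materialize the broadcast batch/head dims of the
    # mask so they match the query's leading dims exactly, e.g. turning a
    # (2, 1, 8, 12) mask into (2, 3, 8, 12). After this revert the linalg
    # lowering no longer performs this widening itself.
    target = query.shape[:-2] + mask.shape[-2:]
    return mask.expand(*target).contiguous()

query = torch.randn(2, 3, 8, 16)
key = torch.randn(2, 3, 12, 16)
value = torch.randn(2, 3, 12, 20)
mask = torch.randn(2, 1, 8, 12)  # broadcastable over the head dimension

full_mask = expand_mask_to_batch(mask, query)  # shape (2, 3, 8, 12)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=full_mask)
print(out.shape)  # torch.Size([2, 3, 8, 20])

Eager PyTorch broadcasts such a mask on its own; the explicit expansion only matters for consumers of this linalg lowering.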

projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py

Lines changed: 2 additions & 2 deletions
@@ -5501,7 +5501,7 @@ def __init__(self):
             ([2, 3, 8, 16], torch.float32, True),
             ([2, 3, 12, 16], torch.float32, True),
             ([2, 3, 12, 20], torch.float32, True),
-            ([2, 1, 8, 12], torch.float32, True),
+            ([2, 3, 8, 12], torch.float32, True),
         ]
     )
     def forward(self, query, key, value, mask):
@@ -5513,7 +5513,7 @@ def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils):
     query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
     key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
     value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
-    mask = torch.randn(2, 1, 8, 12, dtype=torch.float32)
+    mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)
     module.forward(query, key, value, mask)
 
 
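For reference, a sketch of what the updated e2e test now exercises in eager mode: the mask shares the full batch and head dimensions with query and key, so no broadcasting is involved. The reference computation below is the textbook SDPA expansion with an additive float mask, written out for illustration rather than taken from this repository:

import math
import torch
import torch.nn.functional as F

query = torch.randn(2, 3, 8, 16, dtype=torch.float32)
key = torch.randn(2, 3, 12, 16, dtype=torch.float32)
value = torch.randn(2, 3, 12, 20, dtype=torch.float32)
mask = torch.randn(2, 3, 8, 12, dtype=torch.float32)  # full batch/head dims

# Reference SDPA with an additive float mask:
#   softmax(Q @ K^T / sqrt(d_k) + mask) @ V
scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
reference = torch.softmax(scores + mask, dim=-1) @ value

out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
print(out.shape)  # torch.Size([2, 3, 8, 20])
print(torch.allclose(out, reference, atol=1e-5))  # expected: True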