@@ -1661,7 +1661,6 @@ class ConvertAtenScaledDotProductAttentionOp
     auto valueTy = cast<ShapedType>(value.getType());
     auto keyTy = cast<ShapedType>(key.getType());
 
-    auto loc = op.getLoc();
     Value dropoutP = op.getDropoutP();
     Value isCausal = op.getIsCausal();
     Value scale = op.getScale();
@@ -1672,46 +1671,49 @@ class ConvertAtenScaledDotProductAttentionOp
     double dropout;
     if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) ||
         dropout > 0.0)
-      return rewriter.notifyMatchFailure(loc, "dropout not supported");
+      return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported");
 
     bool causal;
     if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) {
       if (!isa<Torch::NoneType>(mask.getType())) {
         return rewriter.notifyMatchFailure(
-            loc, "expected no attention mask when isCausal is true");
+            op.getLoc(), "expected no attention mask when isCausal is true");
       }
 
       SmallVector<int64_t> maskStatic;
       SmallVector<Value> maskDyn;
       for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
         maskStatic.push_back(queryTy.getDimSize(i));
         if (maskStatic.back() == ShapedType::kDynamic)
-          maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
+          maskDyn.push_back(
+              rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
       }
 
       maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
       if (maskStatic.back() == ShapedType::kDynamic)
-        maskDyn.push_back(
-            rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));
+        maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                         keyTy.getRank() - 2));
 
       Type maskType = getElementTypeOrSelf(queryTy);
-      Value emptyMask =
-          rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
+      Value emptyMask = rewriter.create<tensor::EmptyOp>(
+          op.getLoc(), maskStatic, maskType, maskDyn);
 
       Value zero = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
+          op.getLoc(),
+          rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0));
       Value negInf = rewriter.create<arith::ConstantOp>(
-          loc,
+          op.getLoc(),
           rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY));
 
-      mask = rewriter.create<linalg::FillOp>(loc, zero, emptyMask).getResult(0);
+      mask = rewriter.create<linalg::FillOp>(op.getLoc(), zero, emptyMask)
+                 .getResult(0);
 
       int64_t rank = cast<ShapedType>(queryTy).getRank();
       AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank);
       SmallVector<utils::IteratorType> iteratorTypes(
           rank, utils::IteratorType::parallel);
       auto genericOp = rewriter.create<linalg::GenericOp>(
-          loc, mask.getType(), ValueRange{}, mask,
+          op.getLoc(), mask.getType(), ValueRange{}, mask,
           SmallVector<AffineMap>{maskMap}, iteratorTypes,
           [&](OpBuilder &b, Location loc, ValueRange args) {
             Value i = b.create<linalg::IndexOp>(loc, queryTy.getRank() - 2);
@@ -1725,78 +1727,18 @@ class ConvertAtenScaledDotProductAttentionOp
       mask = genericOp.getResult(0);
     }
 
-    // Broadcast the batch dimensions of the mask:
-    if (!isa<Torch::NoneType>(mask.getType())) {
-      auto maskTy = cast<RankedTensorType>(mask.getType());
-      int64_t rank = maskTy.getRank();
-      bool needsBroadcast = false;
-      for (int i = 0, s = rank - 2; i < s; ++i) {
-        needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i);
-      }
-
-      if (needsBroadcast) {
-        SmallVector<int64_t> maskShape;
-        SmallVector<Value> maskDynDims;
-
-        SmallVector<AffineExpr> maskExprs;
-        for (int i = 0, s = rank - 2; i < s; ++i) {
-          maskShape.push_back(keyTy.getDimSize(i));
-
-          if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) {
-            maskExprs.push_back(rewriter.getAffineConstantExpr(0));
-          } else {
-            maskExprs.push_back(rewriter.getAffineDimExpr(i));
-          }
-
-          if (keyTy.isDynamicDim(i)) {
-            maskDynDims.push_back(rewriter.create<tensor::DimOp>(loc, key, i));
-          }
-        }
-
-        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2));
-        maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1));
-        maskShape.push_back(maskTy.getDimSize(rank - 2));
-        maskShape.push_back(maskTy.getDimSize(rank - 1));
-        if (maskTy.isDynamicDim(rank - 2))
-          maskDynDims.push_back(
-              rewriter.create<tensor::DimOp>(loc, mask, rank - 2));
-        if (maskTy.isDynamicDim(rank - 1))
-          maskDynDims.push_back(
-              rewriter.create<tensor::DimOp>(loc, mask, rank - 1));
-
-        SmallVector<AffineMap> affineMaps = {
-            AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs,
-                           op.getContext()),
-            rewriter.getMultiDimIdentityMap(rank)};
-        SmallVector<utils::IteratorType> findMaxIteratorTypes(
-            rank, utils::IteratorType::parallel);
-
-        Value emptyMask = rewriter.create<tensor::EmptyOp>(
-            loc, maskShape, maskTy.getElementType(), maskDynDims);
-        Value newMask =
-            rewriter
-                .create<linalg::GenericOp>(
-                    loc, emptyMask.getType(), mask, ValueRange({emptyMask}),
-                    affineMaps, findMaxIteratorTypes,
-                    [&](OpBuilder &b, Location loc, ValueRange args) {
-                      b.create<linalg::YieldOp>(loc, args[0]);
-                    })
-                .getResult(0);
-        mask = newMask;
-      }
-    }
-
     if (!isa<Torch::NoneType>(scale.getType())) {
       double scaleFloat;
       if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) ||
           scaleFloat != 1.0)
-        return rewriter.notifyMatchFailure(loc, "only default scale supported");
+        return rewriter.notifyMatchFailure(op.getLoc(),
+                                           "only default scale supported");
     }
     bool isGQAEnabled;
     if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled)) ||
         isGQAEnabled)
       return rewriter.notifyMatchFailure(
-          loc, "grouped query attention not supported");
+          op.getLoc(), "grouped query attention not supported");
 
     if (queryTy.getRank() != valueTy.getRank() ||
         queryTy.getRank() != keyTy.getRank())
@@ -1811,6 +1753,7 @@ class ConvertAtenScaledDotProductAttentionOp
     reassociation[1].push_back(valueTy.getRank() - 2);
     reassociation[2].push_back(valueTy.getRank() - 1);
 
+    auto loc = op.getLoc();
     auto collapseBatch = [&rewriter, &reassociation,
                           loc](Value value) -> Value {
       auto valueTy = cast<ShapedType>(value.getType());
@@ -1845,12 +1788,13 @@ class ConvertAtenScaledDotProductAttentionOp
     SmallVector<int64_t> valueSizes(
         cast<ShapedType>(value.getType()).getShape());
     outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
-    SmallVector<Value> outSizesDynamic(getTensorSizes(rewriter, loc, query));
+    SmallVector<Value> outSizesDynamic(
+        getTensorSizes(rewriter, op.getLoc(), query));
     outSizesDynamic[outSizesDynamic.size() - 1] =
-        getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1];
+        getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1];
     Type outType = RankedTensorType::get(outSizes, elementType);
-    Value output =
-        createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType);
+    Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic,
+                                        elementType);
 
     SmallVector<Value> inputs = SmallVector<Value>{query, key, value};
 
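As an illustration of the location-handling pattern this diff adopts — calling op.getLoc() directly in the early bail-out paths and caching a Location only once the rewrite is certain to proceed — here is a minimal, hedged sketch on a hypothetical pattern (DoubleAddToMul and its arith-based rewrite are invented for illustration and are not part of this PR):

// Hedged sketch, not from this PR: use op.getLoc() while the pattern can
// still bail out, and cache a Location only for the committed rewrite.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical pattern: folds `arith.addf %x, %x` into `%x * 2.0`.
struct DoubleAddToMul : public OpRewritePattern<arith::AddFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::AddFOp op,
                                PatternRewriter &rewriter) const override {
    // Match phase: no cached location; failures report op.getLoc() directly.
    if (op.getLhs() != op.getRhs())
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "operands must be identical");
    if (!isa<FloatType>(op.getType()))
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "only scalar floats handled");

    // Rewrite phase: lowering is now guaranteed, so caching the location is
    // safe and avoids repeating op.getLoc() at every builder call.
    Location loc = op.getLoc();
    Value two = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getFloatAttr(op.getType(), 2.0));
    rewriter.replaceOpWithNewOp<arith::MulFOp>(op, op.getLhs(), two);
    return success();
  }
};

In the PR itself, the cached loc reappears only right before the collapse-batch lambda (hunk @@ -1811,6 +1753,7 @@), after all notifyMatchFailure exits have been passed.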