1
1
#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
2
2
#include " intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
3
3
#include " triton/Dialect/Triton/IR/Utility.h"
4
+ #include " triton/Dialect/TritonGPU/Transforms/Utility.h"
4
5
#include " llvm/ADT/PriorityWorklist.h"
5
6
6
7
namespace ttg = mlir::triton::gpu;
@@ -16,45 +17,11 @@ namespace gpu::intel {
16
17
17
18
namespace {
18
19
19
- SmallVector<Value> getTiedArgs (Operation *op, int resultIdx) {
20
- if (auto forOp = dyn_cast<scf::ForOp>(op)) {
21
- auto iterArg = forOp.getRegionIterArg (resultIdx);
22
- auto result = forOp.getResult (resultIdx);
23
- auto yieldVal = forOp.getBody ()->getTerminator ()->getOperand (resultIdx);
24
- auto initVal = forOp.getInitArgs ()[resultIdx];
25
- return {iterArg, result, yieldVal, initVal};
26
- } else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
27
- auto iterArg = whileOp.getBeforeArguments ()[resultIdx];
28
- auto result = whileOp.getResults ()[resultIdx];
29
- auto yieldVal = whileOp.getConditionOp ().getArgs ()[resultIdx];
30
- auto initVal = whileOp.getOperands ()[resultIdx];
31
- auto bodyArg = whileOp.getAfterArguments ()[resultIdx];
32
- return {iterArg, result, yieldVal, initVal, bodyArg};
33
- } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
34
- SmallVector<Value> values;
35
- for (auto &block : ifOp.getThenRegion ().getBlocks ()) {
36
- auto terminator = block.getTerminator ();
37
- if (isa<scf::YieldOp>(terminator))
38
- values.push_back (terminator->getOperands ()[resultIdx]);
39
- }
40
- for (auto &block : ifOp.getElseRegion ().getBlocks ()) {
41
- auto terminator = block.getTerminator ();
42
- if (isa<scf::YieldOp>(terminator))
43
- values.push_back (terminator->getOperands ()[resultIdx]);
44
- }
45
- values.push_back (ifOp->getResults ()[resultIdx]);
46
- return values;
47
- }
48
- return {};
49
- }
50
-
51
20
struct EncodingInfo {
52
21
Attribute desiredEncoding;
53
- bool requiresConvert = false ;
54
22
55
23
bool operator ==(const EncodingInfo &other) const {
56
- return desiredEncoding == other.desiredEncoding &&
57
- requiresConvert == other.requiresConvert ;
24
+ return desiredEncoding == other.desiredEncoding ;
58
25
}
59
26
};
60
27
@@ -77,10 +44,6 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
77
44
78
45
auto updateEncoding = [&](ArrayRef<Value> ptrValues, EncodingInfo info) {
79
46
for (auto value : ptrValues) {
80
- bool requiresConvert = llvm::any_of (
81
- value.getUsers (), [](auto user) { return isa<LoadOp>(user); });
82
- info.requiresConvert = requiresConvert;
83
-
84
47
auto typedVal = cast<TypedValue<PointerType>>(value);
85
48
auto itr = valueToEncodingInfo.find (typedVal);
86
49
if (itr == valueToEncodingInfo.end ()) {
@@ -157,24 +120,22 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
157
120
oldTensorTy.getShape (), oldTensorTy.getElementType (), newEncoding);
158
121
159
122
val.setType (PointerType::get (newTensorTy, oldType.getAddressSpace ()));
160
- if (einfo.requiresConvert ) {
161
- for (auto user : val.getUsers ()) {
162
- if (auto loadOp = dyn_cast<LoadOp>(user)) {
163
-
164
- OpBuilder builder (loadOp);
165
- auto oldLoadType = loadOp.getType ();
166
- Value result = loadOp.getResult ();
167
-
168
- builder.setInsertionPointAfter (loadOp);
169
- auto cvt = builder.create <ConvertLayoutOp>(loadOp.getLoc (),
170
- result.getType (), result);
171
- LLVM_DEBUG (DBGS () << " Added convert Op:\n "
172
- << cvt << " after Load Op:\n "
173
- << loadOp << " \n " );
174
- result.setType (newTensorTy);
175
-
176
- result.replaceAllUsesExcept (cvt.getResult (), cvt.getOperation ());
177
- }
123
+ for (auto user : val.getUsers ()) {
124
+ if (auto loadOp = dyn_cast<LoadOp>(user)) {
125
+
126
+ OpBuilder builder (loadOp);
127
+ auto oldLoadType = loadOp.getType ();
128
+ Value result = loadOp.getResult ();
129
+
130
+ builder.setInsertionPointAfter (loadOp);
131
+ auto cvt = builder.create <ConvertLayoutOp>(loadOp.getLoc (),
132
+ result.getType (), result);
133
+ LLVM_DEBUG (DBGS () << " Added convert Op:\n "
134
+ << cvt << " after Load Op:\n "
135
+ << loadOp << " \n " );
136
+ result.setType (newTensorTy);
137
+
138
+ result.replaceAllUsesExcept (cvt.getResult (), cvt.getOperation ());
178
139
}
179
140
}
180
141
}
0 commit comments