Skip to content

Commit a43cbe7

Browse files
address review comments
1 parent 8d56bd6 commit a43cbe7

File tree

1 file changed

+18
-57
lines changed

1 file changed

+18
-57
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp

Lines changed: 18 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
22
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
33
#include "triton/Dialect/Triton/IR/Utility.h"
4+
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
45
#include "llvm/ADT/PriorityWorklist.h"
56

67
namespace ttg = mlir::triton::gpu;
@@ -16,45 +17,11 @@ namespace gpu::intel {
1617

1718
namespace {
1819

19-
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
20-
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
21-
auto iterArg = forOp.getRegionIterArg(resultIdx);
22-
auto result = forOp.getResult(resultIdx);
23-
auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx);
24-
auto initVal = forOp.getInitArgs()[resultIdx];
25-
return {iterArg, result, yieldVal, initVal};
26-
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
27-
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
28-
auto result = whileOp.getResults()[resultIdx];
29-
auto yieldVal = whileOp.getConditionOp().getArgs()[resultIdx];
30-
auto initVal = whileOp.getOperands()[resultIdx];
31-
auto bodyArg = whileOp.getAfterArguments()[resultIdx];
32-
return {iterArg, result, yieldVal, initVal, bodyArg};
33-
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
34-
SmallVector<Value> values;
35-
for (auto &block : ifOp.getThenRegion().getBlocks()) {
36-
auto terminator = block.getTerminator();
37-
if (isa<scf::YieldOp>(terminator))
38-
values.push_back(terminator->getOperands()[resultIdx]);
39-
}
40-
for (auto &block : ifOp.getElseRegion().getBlocks()) {
41-
auto terminator = block.getTerminator();
42-
if (isa<scf::YieldOp>(terminator))
43-
values.push_back(terminator->getOperands()[resultIdx]);
44-
}
45-
values.push_back(ifOp->getResults()[resultIdx]);
46-
return values;
47-
}
48-
return {};
49-
}
50-
5120
struct EncodingInfo {
5221
Attribute desiredEncoding;
53-
bool requiresConvert = false;
5422

5523
bool operator==(const EncodingInfo &other) const {
56-
return desiredEncoding == other.desiredEncoding &&
57-
requiresConvert == other.requiresConvert;
24+
return desiredEncoding == other.desiredEncoding;
5825
}
5926
};
6027

@@ -77,10 +44,6 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
7744

7845
auto updateEncoding = [&](ArrayRef<Value> ptrValues, EncodingInfo info) {
7946
for (auto value : ptrValues) {
80-
bool requiresConvert = llvm::any_of(
81-
value.getUsers(), [](auto user) { return isa<LoadOp>(user); });
82-
info.requiresConvert = requiresConvert;
83-
8447
auto typedVal = cast<TypedValue<PointerType>>(value);
8548
auto itr = valueToEncodingInfo.find(typedVal);
8649
if (itr == valueToEncodingInfo.end()) {
@@ -157,24 +120,22 @@ void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
157120
oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding);
158121

159122
val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace()));
160-
if (einfo.requiresConvert) {
161-
for (auto user : val.getUsers()) {
162-
if (auto loadOp = dyn_cast<LoadOp>(user)) {
163-
164-
OpBuilder builder(loadOp);
165-
auto oldLoadType = loadOp.getType();
166-
Value result = loadOp.getResult();
167-
168-
builder.setInsertionPointAfter(loadOp);
169-
auto cvt = builder.create<ConvertLayoutOp>(loadOp.getLoc(),
170-
result.getType(), result);
171-
LLVM_DEBUG(DBGS() << "Added convert Op:\n"
172-
<< cvt << " after Load Op:\n"
173-
<< loadOp << "\n");
174-
result.setType(newTensorTy);
175-
176-
result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation());
177-
}
123+
for (auto user : val.getUsers()) {
124+
if (auto loadOp = dyn_cast<LoadOp>(user)) {
125+
126+
OpBuilder builder(loadOp);
127+
auto oldLoadType = loadOp.getType();
128+
Value result = loadOp.getResult();
129+
130+
builder.setInsertionPointAfter(loadOp);
131+
auto cvt = builder.create<ConvertLayoutOp>(loadOp.getLoc(),
132+
result.getType(), result);
133+
LLVM_DEBUG(DBGS() << "Added convert Op:\n"
134+
<< cvt << " after Load Op:\n"
135+
<< loadOp << "\n");
136+
result.setType(newTensorTy);
137+
138+
result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation());
178139
}
179140
}
180141
}

0 commit comments

Comments
 (0)