Skip to content

Commit e096b0c

Browse files
Fix vectorization problem with fusion on reshaped not contiguous tensors (#3075)
* Fix vectorization problem with fusion on reshaped not contiguous tensors that are written to * Remove prints
1 parent c71b1cd commit e096b0c

File tree

1 file changed

+6
-11
lines changed
  • crates/burn-cubecl-fusion/src/shared/trace

1 file changed

+6
-11
lines changed

crates/burn-cubecl-fusion/src/shared/trace/runner.rs

+6-11
Original file line numberDiff line numberDiff line change
@@ -197,50 +197,45 @@ fn vectorization_default<'a, R: Runtime>(
197197
if let Some((s, o, mr, dims)) = swapped.iter().find(|(_s, o, _mr, _dims)| o.id == tensor.id)
198198
{
199199
let val = vectorization_swapped(handle, s, o, *mr, dims);
200-
multi_reads_vectorization_update(vectorizations, o.id, s.id, val);
200+
multi_reads_vectorization_update(vectorizations, o.id, val);
201201
} else {
202202
let val = vectorization_input(handle, tensor);
203203
vectorizations.insert(tensor.id, val);
204204
}
205205
}
206206

207+
for (reshaped, original, multi_reads) in reshaped {
208+
let val = vectorization_reshape(reshaped, original, multi_reads);
209+
multi_reads_vectorization_update(vectorizations, original.id, val);
210+
}
211+
207212
for tensor in outputs {
208213
let val = vectorization_output(tensor);
209214
vectorizations.insert(tensor.id, val);
210215
}
211-
212-
for (reshaped, original, multi_reads) in reshaped {
213-
let val = vectorization_reshape(reshaped, original, multi_reads);
214-
multi_reads_vectorization_update(vectorizations, original.id, reshaped.id, val);
215-
}
216216
}
217217

218218
fn multi_reads_vectorization_update(
219219
vectorizations: &mut BTreeMap<TensorId, Vect>,
220220
original: TensorId,
221-
view: TensorId,
222221
vect: Vect,
223222
) {
224223
if let Some(ori_vect) = vectorizations.get(&original).cloned() {
225224
match ori_vect {
226225
Vect::Broadcasted => {
227226
// keep the original as is.
228-
vectorizations.insert(view, vect.limit_to_one());
229227
}
230228
Vect::Aligned(ori) => match vect {
231229
Vect::Broadcasted => {
232230
vectorizations.insert(original, Vect::Aligned(1));
233-
vectorizations.insert(view, vect.limit_to_one());
234231
}
235232
Vect::Aligned(new) => {
236233
let val = if new != ori { 1 } else { new };
237234
vectorizations.insert(original, Vect::Aligned(val));
238-
vectorizations.insert(view, Vect::Aligned(val));
239235
}
240236
},
241237
};
242238
} else {
243239
vectorizations.insert(original, vect);
244-
vectorizations.insert(view, vect);
245240
}
246241
}

0 commit comments

Comments
 (0)