@@ -197,50 +197,45 @@ fn vectorization_default<'a, R: Runtime>(
197
197
if let Some ( ( s, o, mr, dims) ) = swapped. iter ( ) . find ( |( _s, o, _mr, _dims) | o. id == tensor. id )
198
198
{
199
199
let val = vectorization_swapped ( handle, s, o, * mr, dims) ;
200
- multi_reads_vectorization_update ( vectorizations, o. id , s . id , val) ;
200
+ multi_reads_vectorization_update ( vectorizations, o. id , val) ;
201
201
} else {
202
202
let val = vectorization_input ( handle, tensor) ;
203
203
vectorizations. insert ( tensor. id , val) ;
204
204
}
205
205
}
206
206
207
+ for ( reshaped, original, multi_reads) in reshaped {
208
+ let val = vectorization_reshape ( reshaped, original, multi_reads) ;
209
+ multi_reads_vectorization_update ( vectorizations, original. id , val) ;
210
+ }
211
+
207
212
for tensor in outputs {
208
213
let val = vectorization_output ( tensor) ;
209
214
vectorizations. insert ( tensor. id , val) ;
210
215
}
211
-
212
- for ( reshaped, original, multi_reads) in reshaped {
213
- let val = vectorization_reshape ( reshaped, original, multi_reads) ;
214
- multi_reads_vectorization_update ( vectorizations, original. id , reshaped. id , val) ;
215
- }
216
216
}
217
217
218
218
fn multi_reads_vectorization_update (
219
219
vectorizations : & mut BTreeMap < TensorId , Vect > ,
220
220
original : TensorId ,
221
- view : TensorId ,
222
221
vect : Vect ,
223
222
) {
224
223
if let Some ( ori_vect) = vectorizations. get ( & original) . cloned ( ) {
225
224
match ori_vect {
226
225
Vect :: Broadcasted => {
227
226
// keep the original as is.
228
- vectorizations. insert ( view, vect. limit_to_one ( ) ) ;
229
227
}
230
228
Vect :: Aligned ( ori) => match vect {
231
229
Vect :: Broadcasted => {
232
230
vectorizations. insert ( original, Vect :: Aligned ( 1 ) ) ;
233
- vectorizations. insert ( view, vect. limit_to_one ( ) ) ;
234
231
}
235
232
Vect :: Aligned ( new) => {
236
233
let val = if new != ori { 1 } else { new } ;
237
234
vectorizations. insert ( original, Vect :: Aligned ( val) ) ;
238
- vectorizations. insert ( view, Vect :: Aligned ( val) ) ;
239
235
}
240
236
} ,
241
237
} ;
242
238
} else {
243
239
vectorizations. insert ( original, vect) ;
244
- vectorizations. insert ( view, vect) ;
245
240
}
246
241
}
0 commit comments