@@ -1130,27 +1130,39 @@ impl ServerActor {
1130
1130
self . device_manager
1131
1131
. await_event ( request_streams, & current_exchange_event) ;
1132
1132
1133
- record_stream_time ! ( & self . device_manager, batch_streams, events, "db_reduce" , {
1134
- compact_device_sums. compute_dot_reducer_against_db(
1135
- & mut self . codes_engine,
1136
- & mut self . masks_engine,
1137
- code_db_slices,
1138
- mask_db_slices,
1139
- & dot_chunk_size,
1140
- offset,
1141
- request_streams,
1142
- ) ;
1143
- } ) ;
1133
+ record_stream_time ! (
1134
+ & self . device_manager,
1135
+ request_streams,
1136
+ events,
1137
+ "db_reduce" ,
1138
+ {
1139
+ compact_device_sums. compute_dot_reducer_against_db(
1140
+ & mut self . codes_engine,
1141
+ & mut self . masks_engine,
1142
+ code_db_slices,
1143
+ mask_db_slices,
1144
+ & dot_chunk_size,
1145
+ offset,
1146
+ request_streams,
1147
+ ) ;
1148
+ }
1149
+ ) ;
1144
1150
1145
1151
self . device_manager
1146
1152
. record_event ( request_streams, & next_dot_event) ;
1147
1153
1148
- record_stream_time ! ( & self . device_manager, batch_streams, events, "db_reshare" , {
1149
- self . codes_engine
1150
- . reshare_results( & dot_chunk_size, request_streams) ;
1151
- self . masks_engine
1152
- . reshare_results( & dot_chunk_size, request_streams) ;
1153
- } ) ;
1154
+ record_stream_time ! (
1155
+ & self . device_manager,
1156
+ request_streams,
1157
+ events,
1158
+ "db_reshare" ,
1159
+ {
1160
+ self . codes_engine
1161
+ . reshare_results( & dot_chunk_size, request_streams) ;
1162
+ self . masks_engine
1163
+ . reshare_results( & dot_chunk_size, request_streams) ;
1164
+ }
1165
+ ) ;
1154
1166
1155
1167
// ---- END PHASE 1 ----
1156
1168
@@ -1170,9 +1182,10 @@ impl ServerActor {
1170
1182
) ;
1171
1183
self . phase2
1172
1184
. set_chunk_size ( max_chunk_size * self . max_batch_size * ROTATIONS / 64 ) ;
1185
+
1173
1186
record_stream_time ! (
1174
1187
& self . device_manager,
1175
- batch_streams ,
1188
+ request_streams ,
1176
1189
events,
1177
1190
"db_threshold" ,
1178
1191
{
@@ -1190,20 +1203,22 @@ impl ServerActor {
1190
1203
. record_event ( request_streams, & next_exchange_event) ;
1191
1204
1192
1205
let res = self . phase2 . take_result_buffer ( ) ;
1193
- open (
1194
- & mut self . phase2 ,
1195
- & res,
1196
- & self . distance_comparator ,
1197
- db_match_bitmap,
1198
- max_chunk_size * self . max_batch_size * ROTATIONS / 64 ,
1199
- & dot_chunk_size,
1200
- & chunk_size,
1201
- offset,
1202
- & self . current_db_sizes ,
1203
- & ignore_device_results,
1204
- request_streams,
1205
- ) ;
1206
- self . phase2 . return_result_buffer ( res) ;
1206
+ record_stream_time ! ( & self . device_manager, request_streams, events, "db_open" , {
1207
+ open(
1208
+ & mut self . phase2,
1209
+ & res,
1210
+ & self . distance_comparator,
1211
+ db_match_bitmap,
1212
+ max_chunk_size * self . max_batch_size * ROTATIONS / 64 ,
1213
+ & dot_chunk_size,
1214
+ & chunk_size,
1215
+ offset,
1216
+ & self . current_db_sizes,
1217
+ & ignore_device_results,
1218
+ request_streams,
1219
+ ) ;
1220
+ self . phase2. return_result_buffer( res) ;
1221
+ } ) ;
1207
1222
}
1208
1223
self . device_manager
1209
1224
. record_event ( request_streams, & next_phase2_event) ;
@@ -1336,19 +1351,11 @@ impl ServerActor {
1336
1351
1337
1352
/// Internal helper function to log the timers of measured CUDA streams.
1338
1353
fn log_timers ( events : HashMap < & str , Vec < Vec < CUevent > > > ) {
1339
- for ( name, event_vecs) in & events {
1354
+ for ( name, event_vecs) in events {
1355
+ assert ! ( event_vecs. len( ) % 2 == 0 ) ;
1340
1356
let duration: f32 = event_vecs
1341
- . chunks ( 2 )
1342
- . map ( |pair| {
1343
- let ( start_events, end_events) = ( & pair[ 0 ] , & pair[ 1 ] ) ;
1344
- let total_duration: f32 = start_events
1345
- . iter ( )
1346
- . zip ( end_events. iter ( ) )
1347
- . map ( |( start, end) | unsafe { elapsed ( * start, * end) } . unwrap ( ) )
1348
- . sum ( ) ;
1349
-
1350
- total_duration / start_events. len ( ) as f32
1351
- } )
1357
+ . iter ( )
1358
+ . map ( |pair| unsafe { elapsed ( pair[ 0 ] , pair[ 1 ] ) } . unwrap ( ) )
1352
1359
. sum ( ) ;
1353
1360
1354
1361
tracing:: info!( "Event {}: {:?} ms" , name, duration) ;
0 commit comments