1
1
use core:: hash:: Hash ;
2
2
use std:: {
3
+ collections:: HashMap ,
3
4
fmt:: Debug ,
4
5
sync:: Arc ,
5
6
} ;
@@ -68,7 +69,8 @@ pub struct AsyncLru<RT: Runtime, Key, Value> {
68
69
pause_client : Option < Arc < tokio:: sync:: Mutex < PauseClient > > > ,
69
70
}
70
71
71
- pub type ValueGenerator < Value > = BoxFuture < ' static , anyhow:: Result < Value > > ;
72
+ pub type SingleValueGenerator < Value > = BoxFuture < ' static , anyhow:: Result < Value > > ;
73
+ pub type ValueGenerator < Key , Value > = BoxFuture < ' static , HashMap < Key , anyhow:: Result < Value > > > ;
72
74
73
75
impl < RT : Runtime , Key , Value > Clone for AsyncLru < RT , Key , Value > {
74
76
fn clone ( & self ) -> Self {
@@ -142,7 +144,7 @@ type BuildValueResult<Value> = Result<Arc<Value>, Arc<anyhow::Error>>;
142
144
143
145
type BuildValueRequest < Key , Value > = (
144
146
Key ,
145
- ValueGenerator < Value > ,
147
+ ValueGenerator < Key , Value > ,
146
148
async_broadcast:: Sender < BuildValueResult < Value > > ,
147
149
) ;
148
150
@@ -247,8 +249,8 @@ impl<
247
249
inner. current_size += new_value. size ( ) ;
248
250
// Ideally we'd not change the LRU order by putting here...
249
251
if let Some ( old_value) = inner. cache . put ( key, new_value) {
250
- anyhow :: ensure! ( !matches! ( old_value , CacheResult :: Ready { .. } ) ) ;
251
- // Just in case we ever assign a size to our Waiting entries .
252
+ // Allow overwriting entries (Waiting or Ready) which may have been populated
253
+ // by racing requests with prefetches .
252
254
inner. current_size -= old_value. size ( ) ;
253
255
}
254
256
Self :: trim_to_size ( & mut inner) ;
@@ -300,21 +302,45 @@ impl<
300
302
inner. current_size
301
303
}
302
304
303
- pub async fn get (
305
+ pub async fn get_and_prepopulate (
304
306
& self ,
305
307
key : Key ,
306
- value_generator : ValueGenerator < Value > ,
308
+ value_generator : ValueGenerator < Key , Value > ,
307
309
) -> anyhow:: Result < Arc < Value > > {
308
310
let timer = async_lru_get_timer ( self . label ) ;
309
311
let result = self . _get ( & key, value_generator) . await ;
310
312
timer. finish ( result. is_ok ( ) ) ;
311
313
result
312
314
}
313
315
316
+ pub async fn get (
317
+ & self ,
318
+ key : Key ,
319
+ value_generator : SingleValueGenerator < Value > ,
320
+ ) -> anyhow:: Result < Arc < Value > >
321
+ where
322
+ Key : Clone ,
323
+ {
324
+ let timer = async_lru_get_timer ( self . label ) ;
325
+ let key_ = key. clone ( ) ;
326
+ let result = self
327
+ . _get (
328
+ & key_,
329
+ Box :: pin ( async move {
330
+ let mut hashmap = HashMap :: new ( ) ;
331
+ hashmap. insert ( key, value_generator. await ) ;
332
+ hashmap
333
+ } ) ,
334
+ )
335
+ . await ;
336
+ timer. finish ( result. is_ok ( ) ) ;
337
+ result
338
+ }
339
+
314
340
async fn _get (
315
341
& self ,
316
342
key : & Key ,
317
- value_generator : ValueGenerator < Value > ,
343
+ value_generator : ValueGenerator < Key , Value > ,
318
344
) -> anyhow:: Result < Arc < Value > > {
319
345
match self . get_sync ( key, value_generator) ? {
320
346
Status :: Ready ( value) => Ok ( value) ,
@@ -336,7 +362,7 @@ impl<
336
362
fn get_sync (
337
363
& self ,
338
364
key : & Key ,
339
- value_generator : ValueGenerator < Value > ,
365
+ value_generator : ValueGenerator < Key , Value > ,
340
366
) -> anyhow:: Result < Status < Value > > {
341
367
let mut inner = self . inner . lock ( ) ;
342
368
log_async_lru_size ( inner. cache . len ( ) , inner. current_size , self . label ) ;
@@ -407,10 +433,16 @@ impl<
407
433
return ;
408
434
}
409
435
410
- let value = generator. await ;
436
+ let values = generator. await ;
411
437
412
- let to_broadcast = Self :: update_value ( rt, inner, key, value) . map_err ( Arc :: new) ;
413
- let _ = tx. broadcast ( to_broadcast) . await ;
438
+ for ( k, value) in values {
439
+ let is_requested_key = k == key;
440
+ let to_broadcast =
441
+ Self :: update_value ( rt. clone ( ) , inner. clone ( ) , k, value) . map_err ( Arc :: new) ;
442
+ if is_requested_key {
443
+ let _ = tx. broadcast ( to_broadcast) . await ;
444
+ }
445
+ }
414
446
}
415
447
} )
416
448
. await ;
@@ -420,7 +452,10 @@ impl<
420
452
421
453
#[ cfg( test) ]
422
454
mod tests {
423
- use std:: sync:: Arc ;
455
+ use std:: {
456
+ collections:: HashMap ,
457
+ sync:: Arc ,
458
+ } ;
424
459
425
460
use common:: {
426
461
pause:: PauseController ,
@@ -536,6 +571,36 @@ mod tests {
536
571
Ok ( ( ) )
537
572
}
538
573
574
+ #[ convex_macro:: test_runtime]
575
+ async fn test_get_and_prepopulate ( rt : TestRuntime ) -> anyhow:: Result < ( ) > {
576
+ let cache = AsyncLru :: new ( rt, 10 , 1 , "label" ) ;
577
+ let first = cache
578
+ . get_and_prepopulate (
579
+ "k1" ,
580
+ async move {
581
+ let mut hashmap = HashMap :: new ( ) ;
582
+ hashmap. insert ( "k1" , Ok ( 1 ) ) ;
583
+ hashmap. insert ( "k2" , Ok ( 2 ) ) ;
584
+ hashmap. insert ( "k3" , Err ( anyhow:: anyhow!( "k3 failed" ) ) ) ;
585
+ hashmap
586
+ }
587
+ . boxed ( ) ,
588
+ )
589
+ . await ?;
590
+ assert_eq ! ( * first, 1 ) ;
591
+ let k1_again = cache
592
+ . get ( "k1" , GenerateRandomValue :: generate_value ( "k1" ) . boxed ( ) )
593
+ . await ?;
594
+ assert_eq ! ( * k1_again, 1 ) ;
595
+ let k2_prepopulated = cache
596
+ . get ( "k2" , GenerateRandomValue :: generate_value ( "k2" ) . boxed ( ) )
597
+ . await ?;
598
+ assert_eq ! ( * k2_prepopulated, 2 ) ;
599
+ let k3_prepopulated = cache. get ( "k3" , async move { Ok ( 3 ) } . boxed ( ) ) . await ?;
600
+ assert_eq ! ( * k3_prepopulated, 3 ) ;
601
+ Ok ( ( ) )
602
+ }
603
+
539
604
#[ convex_macro:: test_runtime]
540
605
async fn get_generates_new_value_after_eviction ( rt : TestRuntime ) -> anyhow:: Result < ( ) > {
541
606
let cache = AsyncLru :: new ( rt, 1 , 1 , "label" ) ;
0 commit comments