@@ -5,7 +5,6 @@ use once_cell::sync::Lazy;
5
5
use reqwest:: Url ;
6
6
use std:: hash:: Hasher ;
7
7
use std:: sync:: Arc ;
8
- use std:: sync:: Mutex ;
9
8
use tokio_util:: compat:: FuturesAsyncWriteCompatExt ;
10
9
use tracing:: debug;
11
10
use tracing:: instrument;
@@ -25,17 +24,16 @@ use ort::session::Session;
25
24
26
25
use crate :: onnx:: ensure_onnx_env_init;
27
26
28
- static SESSIONS : Lazy < DashMap < String , Arc < Mutex < Session > > > > =
29
- Lazy :: new ( DashMap :: new) ;
27
+ static SESSIONS : Lazy < DashMap < String , Arc < Session > > > = Lazy :: new ( DashMap :: new) ;
30
28
31
29
#[ derive( Debug ) ]
32
30
pub struct SessionWithId {
33
31
pub ( crate ) id : String ,
34
- pub ( crate ) session : Arc < Mutex < Session > > ,
32
+ pub ( crate ) session : Arc < Session > ,
35
33
}
36
34
37
- impl From < ( String , Arc < Mutex < Session > > ) > for SessionWithId {
38
- fn from ( value : ( String , Arc < Mutex < Session > > ) ) -> Self {
35
+ impl From < ( String , Arc < Session > ) > for SessionWithId {
36
+ fn from ( value : ( String , Arc < Session > ) ) -> Self {
39
37
Self {
40
38
id : value. 0 ,
41
39
session : value. 1 ,
@@ -50,7 +48,7 @@ impl std::fmt::Display for SessionWithId {
50
48
}
51
49
52
50
impl SessionWithId {
53
- pub fn into_split ( self ) -> ( String , Arc < Mutex < Session > > ) {
51
+ pub fn into_split ( self ) -> ( String , Arc < Session > ) {
54
52
( self . id , self . session )
55
53
}
56
54
}
@@ -106,7 +104,7 @@ fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
106
104
[ cpu] . to_vec ( )
107
105
}
108
106
109
- fn create_session ( model_bytes : & [ u8 ] ) -> Result < Arc < Mutex < Session > > , Error > {
107
+ fn create_session ( model_bytes : & [ u8 ] ) -> Result < Arc < Session > , Error > {
110
108
let session = {
111
109
if let Some ( err) = ensure_onnx_env_init ( ) {
112
110
return Err ( anyhow ! ( "failed to create onnx environment: {err}" ) ) ;
@@ -117,7 +115,14 @@ fn create_session(model_bytes: &[u8]) -> Result<Arc<Mutex<Session>>, Error> {
117
115
. commit_from_memory ( model_bytes) ?
118
116
} ;
119
117
120
- Ok ( Arc :: new ( Mutex :: new ( session) ) )
118
+ Ok ( Arc :: new ( session) )
119
+ }
120
+
121
+ #[ allow( mutable_transmutes) ]
122
+ #[ allow( clippy:: mut_from_ref) ]
123
+ pub ( crate ) unsafe fn as_mut_session ( session : & Arc < Session > ) -> & mut Session {
124
+ // SAFETY: CPU EP https://github.com/pykeio/ort/issues/402#issuecomment-2949993914
125
+ unsafe { std:: mem:: transmute :: < & Session , & mut Session > ( & session. clone ( ) ) }
121
126
}
122
127
123
128
#[ instrument( level = "debug" , skip_all, fields( model_bytes = model_bytes. len( ) ) , err) ]
@@ -174,7 +179,7 @@ pub(crate) async fn load_session_from_url(
174
179
Ok ( ( session_id, session) . into ( ) )
175
180
}
176
181
177
- pub ( crate ) async fn get_session ( id : & str ) -> Option < Arc < Mutex < Session > > > {
182
+ pub ( crate ) async fn get_session ( id : & str ) -> Option < Arc < Session > > {
178
183
SESSIONS . get ( id) . map ( |value| value. pair ( ) . 1 . clone ( ) )
179
184
}
180
185
0 commit comments