Skip to content

Commit 1f45d81

Browse files
committed
revert: using transmute to ignore mut refs of Sessions instead of Mutex lock
1 parent 8faa059 commit 1f45d81

File tree

4 files changed

+32
-42
lines changed

4 files changed

+32
-42
lines changed

ext/ai/lib.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ use tracing::error;
4040
use tracing::trace_span;
4141
use tracing::Instrument;
4242

43+
use crate::onnxruntime::session::as_mut_session;
44+
4345
deno_core::extension!(
4446
ai,
4547
ops = [
@@ -213,14 +215,10 @@ async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
213215
&*token_type_ids,
214216
))?;
215217

216-
let Ok(mut guard) = session.lock() else {
217-
let err = anyhow!("failed to lock session");
218-
error!(reason = ?err);
219-
return Err(err);
220-
};
218+
let session = unsafe { as_mut_session(&session) };
221219

222220
let outputs = trace_span!("infer_gte").in_scope(|| {
223-
guard.run(inputs! {
221+
session.run(inputs! {
224222
"input_ids" => input_ids_array,
225223
"token_type_ids" => token_type_ids_array,
226224
"attention_mask" => attention_mask_array,

ext/ai/onnxruntime/mod.rs

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use std::cell::RefCell;
1010
use std::collections::HashMap;
1111
use std::rc::Rc;
1212
use std::sync::Arc;
13-
use std::sync::Mutex;
1413

1514
use anyhow::anyhow;
1615
use anyhow::Context;
@@ -33,6 +32,8 @@ use tokio::sync::oneshot;
3332
use tracing::debug;
3433
use tracing::trace;
3534

35+
use crate::onnxruntime::session::as_mut_session;
36+
3637
#[op2(async)]
3738
#[to_v8]
3839
pub async fn op_ai_ort_init_session(
@@ -56,11 +57,8 @@ pub async fn op_ai_ort_init_session(
5657
};
5758

5859
let mut state = state.borrow_mut();
59-
let mut sessions = {
60-
state
61-
.try_take::<Vec<Arc<Mutex<Session>>>>()
62-
.unwrap_or_default()
63-
};
60+
let mut sessions =
61+
{ state.try_take::<Vec<Arc<Session>>>().unwrap_or_default() };
6462

6563
sessions.push(model.get_session());
6664
state.put(sessions);
@@ -107,12 +105,9 @@ pub async fn op_ai_ort_run_session(
107105
JsRuntime::op_state_from(state)
108106
.borrow_mut()
109107
.spawn_cpu_accumul_blocking_scope(move || {
110-
let Ok(mut session_guard) = model_session.lock() else {
111-
let _ = tx.send(Err(anyhow!("failed to lock model session")));
112-
return;
113-
};
108+
let session = unsafe { as_mut_session(&model_session) };
114109

115-
let outputs = match session_guard.run(input_values) {
110+
let outputs = match session.run(input_values) {
116111
Ok(v) => v,
117112
Err(err) => {
118113
let _ = tx.send(Err(anyhow::Error::from(err)));

ext/ai/onnxruntime/model.rs

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
use std::sync::Arc;
2-
use std::sync::Mutex;
3-
4-
use anyhow::anyhow;
51
use anyhow::Result;
62
use deno_core::serde_v8::to_v8;
73
use deno_core::ToV8;
84
use ort::session::Session;
95
use reqwest::Url;
6+
use std::sync::Arc;
107

118
use super::session::get_session;
129
use super::session::load_session_from_bytes;
@@ -29,26 +26,21 @@ impl std::fmt::Display for ModelInfo {
2926
#[derive(Debug)]
3027
pub struct Model {
3128
info: ModelInfo,
32-
session: Arc<Mutex<Session>>,
29+
session: Arc<Session>,
3330
}
3431

3532
impl Model {
3633
fn new(session_with_id: SessionWithId) -> Result<Self> {
3734
let (input_names, output_names) = {
38-
let Ok(session_guard) = session_with_id.session.lock() else {
39-
return Err(anyhow!(
40-
"Could not lock model session {}",
41-
session_with_id.id
42-
));
43-
};
44-
45-
let input_names = session_guard
35+
let session = { session_with_id.session.clone() };
36+
37+
let input_names = session
4638
.inputs
4739
.iter()
4840
.map(|input| input.name.clone())
4941
.collect::<Vec<_>>();
5042

51-
let output_names = session_guard
43+
let output_names = session
5244
.outputs
5345
.iter()
5446
.map(|output| output.name.clone())
@@ -71,7 +63,7 @@ impl Model {
7163
self.info.clone()
7264
}
7365

74-
pub fn get_session(&self) -> Arc<Mutex<Session>> {
66+
pub fn get_session(&self) -> Arc<Session> {
7567
self.session.clone()
7668
}
7769

ext/ai/onnxruntime/session.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use once_cell::sync::Lazy;
55
use reqwest::Url;
66
use std::hash::Hasher;
77
use std::sync::Arc;
8-
use std::sync::Mutex;
98
use tokio_util::compat::FuturesAsyncWriteCompatExt;
109
use tracing::debug;
1110
use tracing::instrument;
@@ -25,17 +24,16 @@ use ort::session::Session;
2524

2625
use crate::onnx::ensure_onnx_env_init;
2726

28-
static SESSIONS: Lazy<DashMap<String, Arc<Mutex<Session>>>> =
29-
Lazy::new(DashMap::new);
27+
static SESSIONS: Lazy<DashMap<String, Arc<Session>>> = Lazy::new(DashMap::new);
3028

3129
#[derive(Debug)]
3230
pub struct SessionWithId {
3331
pub(crate) id: String,
34-
pub(crate) session: Arc<Mutex<Session>>,
32+
pub(crate) session: Arc<Session>,
3533
}
3634

37-
impl From<(String, Arc<Mutex<Session>>)> for SessionWithId {
38-
fn from(value: (String, Arc<Mutex<Session>>)) -> Self {
35+
impl From<(String, Arc<Session>)> for SessionWithId {
36+
fn from(value: (String, Arc<Session>)) -> Self {
3937
Self {
4038
id: value.0,
4139
session: value.1,
@@ -50,7 +48,7 @@ impl std::fmt::Display for SessionWithId {
5048
}
5149

5250
impl SessionWithId {
53-
pub fn into_split(self) -> (String, Arc<Mutex<Session>>) {
51+
pub fn into_split(self) -> (String, Arc<Session>) {
5452
(self.id, self.session)
5553
}
5654
}
@@ -106,7 +104,7 @@ fn get_execution_providers() -> Vec<ExecutionProviderDispatch> {
106104
[cpu].to_vec()
107105
}
108106

109-
fn create_session(model_bytes: &[u8]) -> Result<Arc<Mutex<Session>>, Error> {
107+
fn create_session(model_bytes: &[u8]) -> Result<Arc<Session>, Error> {
110108
let session = {
111109
if let Some(err) = ensure_onnx_env_init() {
112110
return Err(anyhow!("failed to create onnx environment: {err}"));
@@ -117,7 +115,14 @@ fn create_session(model_bytes: &[u8]) -> Result<Arc<Mutex<Session>>, Error> {
117115
.commit_from_memory(model_bytes)?
118116
};
119117

120-
Ok(Arc::new(Mutex::new(session)))
118+
Ok(Arc::new(session))
119+
}
120+
121+
#[allow(mutable_transmutes)]
122+
#[allow(clippy::mut_from_ref)]
123+
pub(crate) unsafe fn as_mut_session(session: &Arc<Session>) -> &mut Session {
124+
// SAFETY: CPU EP https://github.com/pykeio/ort/issues/402#issuecomment-2949993914
125+
unsafe { std::mem::transmute::<&Session, &mut Session>(&session.clone()) }
121126
}
122127

123128
#[instrument(level = "debug", skip_all, fields(model_bytes = model_bytes.len()), err)]
@@ -174,7 +179,7 @@ pub(crate) async fn load_session_from_url(
174179
Ok((session_id, session).into())
175180
}
176181

177-
pub(crate) async fn get_session(id: &str) -> Option<Arc<Mutex<Session>>> {
182+
pub(crate) async fn get_session(id: &str) -> Option<Arc<Session>> {
178183
SESSIONS.get(id).map(|value| value.pair().1.clone())
179184
}
180185

0 commit comments

Comments
 (0)