Skip to content
190 changes: 119 additions & 71 deletions sqlx-macros-core/src/query/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::cell::RefCell;
use std::collections::{hash_map, HashMap};
use std::env::VarError;
use std::path::{Path, PathBuf};
use std::sync::{Arc, LazyLock, Mutex};
use std::{fs, io};
Expand Down Expand Up @@ -109,61 +111,64 @@ impl Metadata {
}
}

static METADATA: LazyLock<Mutex<HashMap<String, Metadata>>> = LazyLock::new(Default::default);
static METADATA: LazyLock<Mutex<HashMap<PathBuf, Arc<Metadata>>>> = LazyLock::new(Default::default);
static CRATE_ENV_FILE_VARS: LazyLock<Mutex<HashMap<PathBuf, HashMap<String, String>>>> =
LazyLock::new(Default::default);

thread_local! {
static CURRENT_CRATE_MANIFEST_DIR: RefCell<PathBuf> = RefCell::new(PathBuf::new());
}

// If we are in a workspace, look up `workspace_root` since `CARGO_MANIFEST_DIR` won't
// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
fn init_metadata(manifest_dir: &String) -> crate::Result<Metadata> {
let manifest_dir: PathBuf = manifest_dir.into();
fn init_metadata(manifest_dir: &Path) -> crate::Result<Arc<Metadata>> {
let config = Config::try_from_crate_or_default()?;

let (database_url, offline, offline_dir) = load_dot_env(&manifest_dir);
load_env(manifest_dir, &config);

let offline = env("SQLX_OFFLINE")
.ok()
.or(offline)
.map(|s| s.eq_ignore_ascii_case("true") || s == "1")
.unwrap_or(false);

let offline_dir = env("SQLX_OFFLINE_DIR").ok().or(offline_dir);

let config = Config::try_from_crate_or_default()?;
let offline_dir = env("SQLX_OFFLINE_DIR").ok();

let database_url = env(config.common.database_url_var()).ok().or(database_url);
let database_url = env(config.common.database_url_var()).ok();

Ok(Metadata {
manifest_dir,
Ok(Arc::new(Metadata {
manifest_dir: manifest_dir.to_path_buf(),
offline,
database_url,
offline_dir,
config,
workspace_root: Arc::new(Mutex::new(None)),
})
}))
}

pub fn expand_input<'a>(
input: QueryMacroInput,
drivers: impl IntoIterator<Item = &'a QueryDriver>,
) -> crate::Result<TokenStream> {
let manifest_dir = env("CARGO_MANIFEST_DIR").expect("`CARGO_MANIFEST_DIR` must be set");
// `CARGO_MANIFEST_DIR` can only be loaded from a real environment variable due to the filtering done
// by `load_env`, so the value of `CURRENT_CRATE_MANIFEST_DIR` does not matter here.
let manifest_dir =
PathBuf::from(env("CARGO_MANIFEST_DIR").expect("`CARGO_MANIFEST_DIR` must be set"));
CURRENT_CRATE_MANIFEST_DIR.set(manifest_dir.clone());

let mut metadata_lock = METADATA
.lock()
// Just reset the metadata on error
.unwrap_or_else(|poison_err| {
let mut guard = poison_err.into_inner();
*guard = Default::default();
guard
});
let mut metadata_lock = METADATA.lock().unwrap();

let metadata = match metadata_lock.entry(manifest_dir) {
hash_map::Entry::Occupied(occupied) => occupied.into_mut(),
hash_map::Entry::Occupied(occupied) => Arc::clone(occupied.get()),
hash_map::Entry::Vacant(vacant) => {
let metadata = init_metadata(vacant.key())?;
vacant.insert(metadata)
vacant.insert(Arc::clone(&metadata));
metadata
}
};

let data_source = match &metadata {
// Release the lock now so other expansions in other threads of this process can proceed concurrently.
drop(metadata_lock);

let data_source = match &*metadata {
Metadata {
offline: false,
database_url: Some(db_url),
Expand All @@ -181,7 +186,7 @@ pub fn expand_input<'a>(
];
let Some(data_file_path) = dirs
.iter()
.filter_map(|path| path(metadata))
.filter_map(|path| path(&metadata))
.map(|path| path.join(&filename))
.find(|path| path.exists())
else {
Expand Down Expand Up @@ -415,64 +420,107 @@ where
Ok(ret_tokens)
}

/// Get the value of an environment variable, telling the compiler about it if applicable.
/// Get the value of an environment variable for the current crate, telling the compiler about it if applicable.
///
/// The current crate is determined by the `CURRENT_CRATE_MANIFEST_DIR` thread-local variable, which is assumed
/// to be set to match the crate whose macro is being expanded before this function is called. It is also assumed
/// that the expansion of this macro happens on a single thread.
fn env(name: &str) -> Result<String, std::env::VarError> {
#[cfg(procmacro2_semver_exempt)]
{
proc_macro::tracked_env::var(name)
}

let tracked_value = Some(proc_macro::tracked_env::var(name));
#[cfg(not(procmacro2_semver_exempt))]
{
std::env::var(name)
let tracked_value = None;

match tracked_value.map_or_else(|| std::env::var(name), |var| var) {
Ok(v) => Ok(v),
Err(VarError::NotPresent) => CURRENT_CRATE_MANIFEST_DIR
.with_borrow(|manifest_dir| {
CRATE_ENV_FILE_VARS
.lock()
.unwrap()
.get(manifest_dir)
.cloned()
})
.and_then(|env_file_vars| env_file_vars.get(name).cloned())
.ok_or(VarError::NotPresent),
Err(e) => Err(e),
}
}

/// Get `DATABASE_URL`, `SQLX_OFFLINE` and `SQLX_OFFLINE_DIR` from the `.env`.
fn load_dot_env(manifest_dir: &Path) -> (Option<String>, Option<String>, Option<String>) {
let mut env_path = manifest_dir.join(".env");

// If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
// otherwise fallback to default dotenv file.
#[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))]
let env_file = if env_path.exists() {
let res = dotenvy::from_path_iter(&env_path);
match res {
Ok(iter) => Some(iter),
Err(e) => panic!("failed to load environment from {env_path:?}, {e}"),
}
/// Load configuration environment variables from a `.env` file. If applicable, the compiler is
/// told about the `.env` files they may come from.
fn load_env(manifest_dir: &Path, config: &Config) {
// A whitelist of environment variables to load from a `.env` file avoids
// such files overriding internal variables they should not (e.g., `CARGO`,
// `CARGO_MANIFEST_DIR`) when using the `env` function above.
let database_url_var = config.common.database_url_var();
let loadable_vars = if database_url_var == "DATABASE_URL" {
&["DATABASE_URL", "SQLX_OFFLINE", "SQLX_OFFLINE_DIR"][..]
} else {
#[allow(unused_assignments)]
{
env_path = PathBuf::from(".env");
}
dotenvy::dotenv_iter().ok()
&[
"DATABASE_URL",
"SQLX_OFFLINE",
"SQLX_OFFLINE_DIR",
database_url_var,
]
};

let mut offline = None;
let mut database_url = None;
let mut offline_dir = None;
let (found_dotenv, candidate_dotenv_paths) = find_dotenv(manifest_dir);

if let Some(env_file) = env_file {
// tell the compiler to watch the `.env` for changes.
#[cfg(procmacro2_semver_exempt)]
if let Some(env_path) = env_path.to_str() {
proc_macro::tracked_path::path(env_path);
// Tell the compiler to watch the candidate `.env` paths for changes. It's important to
// watch them all, because there are several possible locations where a `.env` file
// might be read, and we want to react to changes in any of them.
#[cfg(procmacro2_semver_exempt)]
for path in &candidate_dotenv_paths {
if let Some(path) = path.to_str() {
proc_macro::tracked_path::path(path);
}
}

for item in env_file {
let Ok((key, value)) = item else {
continue;
};
let loaded_vars = found_dotenv
.then_some(candidate_dotenv_paths)
.iter()
.flatten()
.last()
.map(|dotenv_path| {
dotenvy::from_path_iter(dotenv_path)
.ok()
.into_iter()
.flatten()
.filter_map(|dotenv_var_result| match dotenv_var_result {
Ok((key, value))
if loadable_vars.contains(&&*key) && std::env::var(&key).is_err() =>
{
Some((key, value))
}
_ => None,
})
})
.into_iter()
.flatten()
.collect();

match key.as_str() {
"DATABASE_URL" => database_url = Some(value),
"SQLX_OFFLINE" => offline = Some(value),
"SQLX_OFFLINE_DIR" => offline_dir = Some(value),
_ => {}
};
CRATE_ENV_FILE_VARS
.lock()
.unwrap()
.insert(manifest_dir.to_path_buf(), loaded_vars);
}

fn find_dotenv(mut dir: &Path) -> (bool, Vec<PathBuf>) {
let mut candidate_files = vec![];

loop {
candidate_files.push(dir.join(".env"));
let candidate_file = candidate_files.last().unwrap();

if candidate_file.is_file() {
return (true, candidate_files);
}
}

(database_url, offline, offline_dir)
if let Some(parent) = dir.parent() {
dir = parent;
} else {
return (false, candidate_files);
}
}
}
Loading