Skip to content

Commit

Permalink
fix(macros): tell the compiler about external files/env vars to watch
Browse files Browse the repository at this point in the history
closes #663
closes #681
  • Loading branch information
abonander committed Jul 20, 2021
1 parent be189bd commit edb6b5f
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 31 deletions.
1 change: 1 addition & 0 deletions sqlx-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
not(any(feature = "postgres", feature = "mysql", feature = "offline")),
allow(dead_code, unused_macros, unused_imports)
)]
#![cfg_attr(procmacro2_semver_exempt, feature(track_path, proc_macro_tracked_env))]
extern crate proc_macro;

use proc_macro::TokenStream;
Expand Down
36 changes: 31 additions & 5 deletions sqlx-macros/src/migrate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct QuotedMigration {
version: i64,
description: String,
migration_type: QuotedMigrationType,
sql: String,
path: String,
checksum: Vec<u8>,
}

Expand All @@ -34,7 +34,7 @@ impl ToTokens for QuotedMigration {
version,
description,
migration_type,
sql,
path,
checksum,
} = &self;

Expand All @@ -43,7 +43,8 @@ impl ToTokens for QuotedMigration {
version: #version,
description: ::std::borrow::Cow::Borrowed(#description),
migration_type: #migration_type,
sql: ::std::borrow::Cow::Borrowed(#sql),
// this tells the compiler to watch this path for changes
sql: ::std::borrow::Cow::Borrowed(include_str!(#path)),
checksum: ::std::borrow::Cow::Borrowed(&[
#(#checksum),*
]),
Expand All @@ -59,7 +60,7 @@ pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result<TokenStream
let path = crate::common::resolve_path(&dir.value(), dir.span())?;
let mut migrations = Vec::new();

for entry in fs::read_dir(path)? {
for entry in fs::read_dir(&path)? {
let entry = entry?;
if !fs::metadata(entry.path())?.is_file() {
// not a file; ignore
Expand Down Expand Up @@ -89,18 +90,43 @@ pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result<TokenStream

let checksum = Vec::from(Sha384::digest(sql.as_bytes()).as_slice());

// canonicalize the path so we can pass it to `include_str!()`
let path = entry.path().canonicalize()?;
let path = path
.to_str()
.ok_or_else(|| {
format!(
"migration path cannot be represented as a string: {:?}",
path
)
})?
.to_owned();

migrations.push(QuotedMigration {
version,
description,
migration_type: QuotedMigrationType(migration_type),
sql,
path,
checksum,
})
}

// ensure that we are sorted by `VERSION ASC`
migrations.sort_by_key(|m| m.version);

#[cfg(procmacro2_semver_exempt)]
{
let path = path.canonicalize()?;
let path = path.to_str().ok_or_else(|| {
format!(
"migration directory path cannot be represented as a string: {:?}",
path
)
})?;

proc_macro::tracked_path::path(path);
}

Ok(quote! {
::sqlx::migrate::Migrator {
migrations: ::std::borrow::Cow::Borrowed(&[
Expand Down
20 changes: 17 additions & 3 deletions sqlx-macros/src/query/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,30 @@ pub mod offline {
/// Find and deserialize the data table for this query from a shared `sqlx-data.json`
/// file. The expected structure is a JSON map keyed by the SHA-256 hash of queries in hex.
pub fn from_data_file(path: impl AsRef<Path>, query: &str) -> crate::Result<Self> {
serde_json::Deserializer::from_reader(BufReader::new(
let this = serde_json::Deserializer::from_reader(BufReader::new(
File::open(path.as_ref()).map_err(|e| {
format!("failed to open path {}: {}", path.as_ref().display(), e)
})?,
))
.deserialize_map(DataFileVisitor {
query,
hash: hash_string(query),
})
.map_err(Into::into)
})?;

#[cfg(procmacr2_semver_exempt)]
{
let path = path.as_ref().canonicalize()?;
let path = path.to_str().ok_or_else(|| {
format!(
"sqlx-data.json path cannot be represented as a string: {:?}",
path
)
})?;

proc_macro::tracked_path::path(path);
}

Ok(this)
}
}

Expand Down
30 changes: 28 additions & 2 deletions sqlx-macros/src/query/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use syn::{ExprArray, Type};

/// Macro input shared by `query!()` and `query_file!()`
pub struct QueryMacroInput {
pub(super) src: String,
pub(super) sql: String,

#[cfg_attr(not(feature = "offline"), allow(dead_code))]
pub(super) src_span: Span,
Expand All @@ -18,6 +18,8 @@ pub struct QueryMacroInput {
pub(super) arg_exprs: Vec<Expr>,

pub(super) checked: bool,

pub(super) file_path: Option<String>,
}

enum QuerySrc {
Expand Down Expand Up @@ -94,12 +96,15 @@ impl Parse for QueryMacroInput {

let arg_exprs = args.unwrap_or_default();

let file_path = src.file_path(src_span)?;

Ok(QueryMacroInput {
src: src.resolve(src_span)?,
sql: src.resolve(src_span)?,
src_span,
record_type,
arg_exprs,
checked,
file_path,
})
}
}
Expand All @@ -112,6 +117,27 @@ impl QuerySrc {
QuerySrc::File(file) => read_file_src(&file, source_span),
}
}

fn file_path(&self, source_span: Span) -> syn::Result<Option<String>> {
if let QuerySrc::File(ref file) = *self {
let path = std::path::Path::new(file)
.canonicalize()
.map_err(|e| syn::Error::new(source_span, e))?;

Ok(Some(
path.to_str()
.ok_or_else(|| {
syn::Error::new(
source_span,
"query file path cannot be represented as a string",
)
})?
.to_string(),
))
} else {
Ok(None)
}
}
}

fn read_file_src(source: &str, source_span: Span) -> syn::Result<String> {
Expand Down
55 changes: 37 additions & 18 deletions sqlx-macros/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod input;
mod output;

struct Metadata {
#[allow(unused)]
manifest_dir: PathBuf,
offline: bool,
database_url: Option<String>,
Expand All @@ -35,42 +36,47 @@ struct Metadata {
// If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
static METADATA: Lazy<Metadata> = Lazy::new(|| {
use std::env;

let manifest_dir: PathBuf = env::var("CARGO_MANIFEST_DIR")
let manifest_dir: PathBuf = env("CARGO_MANIFEST_DIR")
.expect("`CARGO_MANIFEST_DIR` must be set")
.into();

#[cfg(feature = "offline")]
let target_dir =
env::var_os("CARGO_TARGET_DIR").map_or_else(|| "target".into(), |dir| dir.into());
let target_dir = env("CARGO_TARGET_DIR").map_or_else(|_| "target".into(), |dir| dir.into());

// If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
// otherwise fallback to default dotenv behaviour.
let env_path = manifest_dir.join(".env");
if env_path.exists() {

#[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))]
let env_path = if env_path.exists() {
let res = dotenv::from_path(&env_path);
if let Err(e) = res {
panic!("failed to load environment from {:?}, {}", env_path, e);
}

Some(env_path)
} else {
let _ = dotenv::dotenv();
dotenv::dotenv().ok()
};

// tell the compiler to watch the `.env` for changes, if applicable
#[cfg(procmacro2_semver_exempt)]
if let Some(env_path) = env_path.as_ref().and_then(|path| path.to_str()) {
proc_macro::tracked_path::path(env_path);
}

// TODO: Switch to `var_os` after feature(osstring_ascii) is stable.
// Stabilization PR: https://github.com/rust-lang/rust/pull/80193
let offline = env::var("SQLX_OFFLINE")
let offline = env("SQLX_OFFLINE")
.map(|s| s.eq_ignore_ascii_case("true") || s == "1")
.unwrap_or(false);

let database_url = env::var("DATABASE_URL").ok();
let database_url = env("DATABASE_URL").ok();

#[cfg(feature = "offline")]
let workspace_root = {
use serde::Deserialize;
use std::process::Command;

let cargo = env::var_os("CARGO").expect("`CARGO` must be set");
let cargo = env("CARGO").expect("`CARGO` must be set");

let output = Command::new(&cargo)
.args(&["metadata", "--format-version=1"])
Expand Down Expand Up @@ -151,7 +157,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
"postgres" | "postgresql" => {
let data = block_on(async {
let mut conn = sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.src).await
QueryData::from_db(&mut conn, &input.sql).await
})?;

expand_with_data(input, data, false)
Expand All @@ -164,7 +170,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
"mssql" | "sqlserver" => {
let data = block_on(async {
let mut conn = sqlx_core::mssql::MssqlConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.src).await
QueryData::from_db(&mut conn, &input.sql).await
})?;

expand_with_data(input, data, false)
Expand All @@ -177,7 +183,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
"mysql" | "mariadb" => {
let data = block_on(async {
let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.src).await
QueryData::from_db(&mut conn, &input.sql).await
})?;

expand_with_data(input, data, false)
Expand All @@ -190,7 +196,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
"sqlite" => {
let data = block_on(async {
let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url.as_str()).await?;
QueryData::from_db(&mut conn, &input.src).await
QueryData::from_db(&mut conn, &input.sql).await
})?;

expand_with_data(input, data, false)
Expand All @@ -207,7 +213,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result<TokenSt
pub fn expand_from_file(input: QueryMacroInput, file: PathBuf) -> crate::Result<TokenStream> {
use data::offline::DynQueryData;

let query_data = DynQueryData::from_data_file(file, &input.src)?;
let query_data = DynQueryData::from_data_file(file, &input.sql)?;
assert!(!query_data.db_name.is_empty());

match &*query_data.db_name {
Expand Down Expand Up @@ -288,7 +294,7 @@ where
.all(|it| it.type_info().is_void())
{
let db_path = DB::db_path();
let sql = &input.src;
let sql = &input.sql;

quote! {
::sqlx::query_with::<#db_path, _>(#sql, #query_args)
Expand Down Expand Up @@ -368,3 +374,16 @@ where

Ok(ret_tokens)
}

/// Get the value of an environment variable, telling the compiler about it if applicable.
fn env(name: &str) -> Result<String, std::env::VarError> {
#[cfg(procmacro2_semver_exempt)]
{
proc_macro::tracked_env::var(name)
}

#[cfg(not(procmacro2_semver_exempt))]
{
std::env::var(name)
}
}
11 changes: 9 additions & 2 deletions sqlx-macros/src/query/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,14 @@ pub fn quote_query_as<DB: DatabaseExt>(

let db_path = DB::db_path();
let row_path = DB::row_path();
let sql = &input.src;

// if this query came from a file, use `include_str!()` to tell the compiler where it came from
let sql = if let Some(ref path) = &input.file_path {
quote::quote_spanned! { input.src_span => include_str!(#path) }
} else {
let sql = &input.sql;
quote! { #sql }
};

quote! {
::sqlx::query_with::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| {
Expand Down Expand Up @@ -200,7 +207,7 @@ pub fn quote_query_scalar<DB: DatabaseExt>(
};

let db = DB::db_path();
let query = &input.src;
let query = &input.sql;

Ok(quote! {
::sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args)
Expand Down
Loading

0 comments on commit edb6b5f

Please sign in to comment.