-
Notifications
You must be signed in to change notification settings - Fork 109
/
Copy pathsql.rs
110 lines (103 loc) · 3.69 KB
/
sql.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
use super::PreparedStatement;
use crate::webserver::database::StmtParam;
use crate::Database;
use anyhow::Context;
use sqlparser::ast::{visitor_fn_mut, DataType, DriveMut, Expr, Value, VisitorEvent};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use sqlx::any::{AnyKind, AnyTypeInfo};
use sqlx::postgres::types::Oid;
use sqlx::postgres::PgTypeInfo;
use sqlx::{Executor, Statement};
/// The result of parsing a SQL file: one entry per SQL statement, in file order.
///
/// Each statement is prepared independently, so a statement that fails to
/// prepare is stored as an `Err` without preventing the others from being
/// prepared.
#[derive(Default)]
pub struct ParsedSqlFile {
    pub(super) statements: Vec<anyhow::Result<PreparedStatement>>,
}

impl ParsedSqlFile {
    /// Parses `sql` and prepares every statement it contains on `db`.
    ///
    /// Placeholders in each statement are rewritten into the positional syntax
    /// of the connected database (see `extract_parameters`), and their original
    /// names are kept alongside the prepared statement so values can be bound
    /// later.
    ///
    /// # Errors
    /// Returns an error only if `sql` cannot be parsed. Failures to prepare an
    /// individual statement are recorded as `Err` entries in `statements`.
    pub(super) async fn new(db: &Database, sql: &str) -> anyhow::Result<ParsedSqlFile> {
        let dialect = GenericDialect {};
        let ast = Parser::parse_sql(&dialect, sql)?;
        let db_kind = db.connection.any_kind();
        let mut statements = Vec::with_capacity(ast.len());
        for mut stmt in ast {
            let param_names = extract_parameters(&mut stmt, db_kind);
            let parameters = map_params(param_names);
            let query = stmt.to_string();
            let param_types = get_param_types(&parameters);
            let stmt_res = db
                .connection
                .prepare_with(&query, &param_types)
                .await
                .with_context(|| format!("Preparing SQL statement: '{}'", query));
            statements.push(stmt_res.map(|statement| PreparedStatement {
                // `prepare_with` borrows `query`, a per-iteration local; `to_owned`
                // presumably detaches the statement from that borrow so it can be stored.
                statement: statement.to_owned(),
                parameters,
            }));
        }
        Ok(ParsedSqlFile { statements })
    }
}
/// Returns the declared type for each statement parameter, one per entry of
/// `parameters` and in the same order.
///
/// Every parameter is declared as Postgres type OID 25 (`text`), in line with
/// the text casts that `make_placeholder` wraps placeholders in. The contents
/// of `parameters` are not inspected; only its length matters.
fn get_param_types(parameters: &[StmtParam]) -> Vec<AnyTypeInfo> {
    let count = parameters.len();
    std::iter::repeat_with(|| PgTypeInfo::with_oid(Oid(25)).into())
        .take(count)
        .collect()
}
/// Converts raw placeholder texts (e.g. `"$id"`, `":id"`) into `StmtParam`s.
///
/// The first character selects where the value comes from, and the remainder
/// is the parameter name:
/// - `$name` → `StmtParam::GetOrPost`
/// - `:name` → `StmtParam::Post`
/// - anything else (including an empty string) → `StmtParam::Get`
///
/// The prefix is split off by `char`, not by byte, so an empty or non-ASCII
/// input cannot panic. The previous `split_at(1)` panicked in both cases.
fn map_params(names: Vec<String>) -> Vec<StmtParam> {
    names
        .into_iter()
        .map(|name| {
            let mut chars = name.chars();
            let prefix = chars.next();
            let name = chars.as_str().to_owned();
            match prefix {
                Some('$') => StmtParam::GetOrPost(name),
                Some(':') => StmtParam::Post(name),
                _ => StmtParam::Get(name),
            }
        })
        .collect()
}
/// Replaces every placeholder expression in `sql_ast` with the positional
/// parameter produced by `make_placeholder` for `db`, and returns the original
/// placeholder texts (e.g. `"$a"`) in the order they were rewritten.
///
/// A name that occurs several times is returned once per occurrence, since each
/// occurrence becomes its own positional parameter.
fn extract_parameters(sql_ast: &mut sqlparser::ast::Statement, db: AnyKind) -> Vec<String> {
    let mut collected_names = Vec::<String>::new();
    sql_ast.drive_mut(&mut visitor_fn_mut(|expr: &mut Expr, event| {
        // Rewrite a node only on its exit event, after its subtree has been
        // walked; presumably this keeps the freshly inserted CAST node from being
        // visited and rewritten again.
        if matches!(event, VisitorEvent::Enter) {
            return;
        }
        if let Expr::Value(Value::Placeholder(placeholder)) = expr {
            let original_name = std::mem::take(placeholder);
            let replacement = make_placeholder(db, collected_names.len());
            collected_names.push(original_name);
            *expr = replacement;
        }
    }));
    collected_names
}
/// Builds the expression that replaces one placeholder: a positional parameter
/// in the syntax `db` accepts, wrapped in a cast to a text type.
///
/// `current_count` is the number of placeholders already rewritten in the
/// statement. Postgres parameters are numbered from it (`$1`, `$2`, ...), and
/// every other database gets an anonymous `?`.
fn make_placeholder(db: AnyKind, current_count: usize) -> Expr {
    let positional = if matches!(db, AnyKind::Postgres) {
        // Postgres only supports numbered parameters
        format!("${}", current_count + 1)
    } else {
        String::from("?")
    };
    let target_type = if matches!(db, AnyKind::MySql) {
        // MySQL requires CAST(? AS CHAR) and does not understand CAST(? AS TEXT)
        DataType::Char(None)
    } else {
        DataType::Text
    };
    Expr::Cast {
        expr: Box::new(Expr::Value(Value::Placeholder(positional))),
        data_type: target_type,
    }
}
/// Placeholders are rewritten, in order, into numbered Postgres parameters
/// wrapped in `CAST(... AS TEXT)`. Their original names are returned once per
/// occurrence.
#[test]
fn test_statement_rewrite() {
    let mut statements =
        Parser::parse_sql(&GenericDialect, "select $a from t where $x > $a OR $x = 0").unwrap();
    let stmt = &mut statements[0];
    let names = extract_parameters(stmt, AnyKind::Postgres);
    let expected_sql = "SELECT CAST($1 AS TEXT) FROM t WHERE CAST($2 AS TEXT) > CAST($3 AS TEXT) OR CAST($4 AS TEXT) = 0";
    assert_eq!(stmt.to_string(), expected_sql);
    assert_eq!(names, ["$a", "$x", "$a", "$x"]);
}