Reformat support (psl crate)

Sergey Tatarintsev committed Mar 26, 2024
1 parent aad9355 commit 883146e

Showing 8 changed files with 125 additions and 45 deletions.
2 changes: 1 addition & 1 deletion prisma-fmt/src/get_dmmf.rs
@@ -18,7 +18,7 @@ pub(crate) fn get_dmmf(params: &str) -> Result<String, String> {
}
};

validate::run(params.prisma_schema, params.no_color).map(|schema| dmmf::dmmf_json_from_validate_schema(schema))
validate::run(params.prisma_schema, params.no_color).map(|schema| dmmf::dmmf_json_from_validated_schema(schema))
}

#[cfg(test)]
1 change: 0 additions & 1 deletion prisma-fmt/src/lib.rs
@@ -3,7 +3,6 @@ mod code_actions;
mod get_config;
mod get_dmmf;
mod lint;
mod merge_schemas;
mod native;
mod preview;
mod schema_file_input;
12 changes: 11 additions & 1 deletion psl/parser-database/src/lib.rs
@@ -179,6 +179,11 @@ impl ParserDatabase {
self.asts.into_iter().map(|(_, _, _, ast)| ast)
}

/// Iterate all file ids
pub fn iter_file_ids(&self) -> impl Iterator<Item = FileId> + '_ {
self.asts.iter().map(|(file_id, _, _, _)| file_id)
}

/// A parsed AST.
pub fn ast(&self, file_id: FileId) -> &ast::SchemaAst {
&self.asts[file_id].2
@@ -203,9 +208,14 @@
}

/// The source file contents.
pub(crate) fn source(&self, file_id: FileId) -> &str {
pub fn source(&self, file_id: FileId) -> &str {
self.asts[file_id].1.as_str()
}

/// The name of the file.
pub fn file_name(&self, file_id: FileId) -> &str {
self.asts[file_id].0.as_str()
}
}

impl std::ops::Index<FileId> for ParserDatabase {
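
A minimal sketch (not part of this commit) of how the accessors made public in this hunk could be used together; the helper function name is illustrative.

// Walk every file tracked by a ParserDatabase using the newly public
// `iter_file_ids`, together with `file_name` and the now-public `source`.
fn dump_files(db: &parser_database::ParserDatabase) {
    for file_id in db.iter_file_ids() {
        println!("{}: {} bytes", db.file_name(file_id), db.source(file_id).len());
    }
}
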
2 changes: 1 addition & 1 deletion psl/psl-core/src/lib.rs
@@ -20,7 +20,7 @@ pub use crate::{
configuration::{
Configuration, Datasource, DatasourceConnectorData, Generator, GeneratorConfigValue, StringFromEnvVar,
},
reformat::reformat,
reformat::{reformat, reformat_multiple},
};
pub use diagnostics;
pub use parser_database::{self, is_reserved_type_name};
146 changes: 106 additions & 40 deletions psl/psl-core/src/reformat.rs
@@ -1,46 +1,93 @@
use crate::ParserDatabase;
use diagnostics::FileId;
use parser_database::{ast::WithSpan, walkers};
use schema_ast::{ast, SourceFile};
use std::{borrow::Cow, sync::Arc};
use std::{borrow::Cow, collections::HashMap, sync::Arc};

/// Returns either the reformatted schema, or the original input if we can't reformat. This happens
/// if and only if the source does not parse to a well formed AST.
pub fn reformat(source: &str, indent_width: usize) -> Option<String> {
let file = SourceFile::new_allocated(Arc::from(source.to_owned().into_boxed_str()));
let reformatted = reformat_multiple(&vec![("schema.prisma", source)], indent_width);

let mut diagnostics = diagnostics::Diagnostics::new();
let db = parser_database::ParserDatabase::new_single_file(file, &mut diagnostics);
reformatted.first().map(|(_, source)| source).cloned()
}

let source_to_reformat = if diagnostics.has_errors() {
Cow::Borrowed(source)
pub fn reformat_multiple(sources: &[(&str, &str)], indent_width: usize) -> Vec<(String, String)> {
let sources = sources
.into_iter()
.map(|(name, source)| {
(
(*name).to_owned(),
SourceFile::new_allocated(Arc::from((*source).to_owned().into_boxed_str())),
)
})
.collect();
let mut diagnostics = diagnostics::Diagnostics::new();
let db = parser_database::ParserDatabase::new(sources, &mut diagnostics);

if diagnostics.has_errors() {
dbg!(&diagnostics);
db.iter_file_ids()
.filter_map(|file_id| {
let formatted_source = schema_ast::reformat(db.source(file_id), indent_width)?;
Some((db.file_name(file_id).to_owned(), formatted_source))
})
.collect()
} else {
let mut missing_bits = Vec::new();
let mut missing_bits = HashMap::new();

let mut ctx = MagicReformatCtx {
original_schema: source,
missing_bits: &mut missing_bits,
missing_bits_map: &mut missing_bits,
db: &db,
};

push_missing_fields(&mut ctx);
push_missing_attributes(&mut ctx);
push_missing_relation_attribute_args(&mut ctx);
missing_bits.sort_by_key(|bit| bit.position);
ctx.sort_missing_bits();

if missing_bits.is_empty() {
Cow::Borrowed(source)
} else {
Cow::Owned(enrich(source, &missing_bits))
}
};
db.iter_file_ids()
.filter_map(|file_id| {
let source = if let Some(missing_bits) = ctx.get_missing_bits(file_id) {
Cow::Owned(enrich(db.source(file_id), missing_bits))
} else {
Cow::Borrowed(db.source(file_id))
};

schema_ast::reformat(&source_to_reformat, indent_width)
let formatted_source = schema_ast::reformat(&source, indent_width)?;

Some((db.file_name(file_id).to_owned(), formatted_source))
})
.collect()
}
}

struct MagicReformatCtx<'a> {
original_schema: &'a str,
missing_bits: &'a mut Vec<MissingBit>,
missing_bits_map: &'a mut HashMap<FileId, Vec<MissingBit>>,
db: &'a ParserDatabase,
}

impl<'a> MagicReformatCtx<'a> {
fn add_missing_bit(&mut self, file_id: FileId, bit: MissingBit) {
self.missing_bits_map.entry(file_id).or_default().push(bit);
}

fn get_missing_bits(&self, file_id: FileId) -> Option<&Vec<MissingBit>> {
let bits_vec = self.missing_bits_map.get(&file_id)?;
if bits_vec.is_empty() {
None
} else {
Some(bits_vec)
}
}

fn sort_missing_bits(&mut self) {
self.missing_bits_map
.iter_mut()
.for_each(|(_, bits)| bits.sort_by_key(|bit| bit.position))
}
}

fn enrich(input: &str, missing_bits: &[MissingBit]) -> String {
let bits = missing_bits.iter().scan(0usize, |last_insert_position, missing_bit| {
let start: usize = *last_insert_position;
@@ -109,10 +156,13 @@ fn push_inline_relation_missing_arguments(
(", ", "", relation_attribute.span.end - 1)
};

ctx.missing_bits.push(MissingBit {
position,
content: format!("{prefix}{extra_args}{suffix}"),
});
ctx.add_missing_bit(
relation_attribute.span.file_id,
MissingBit {
position,
content: format!("{prefix}{extra_args}{suffix}"),
},
);
}
}

Expand All @@ -136,10 +186,14 @@ fn push_missing_relation_attribute(inline_relation: walkers::InlineRelationWalke
content.push_str(&references_argument(inline_relation));
content.push(')');

ctx.missing_bits.push(MissingBit {
position: after_type(forward.ast_field().field_type.span().end, ctx.original_schema),
content,
})
let file_id = forward.ast_field().span().file_id;
ctx.add_missing_bit(
file_id,
MissingBit {
position: after_type(forward.ast_field().field_type.span().end, ctx.db.source(file_id)),
content,
},
);
}
}

@@ -167,10 +221,14 @@ fn push_missing_relation_fields(inline: walkers::InlineRelationWalker<'_>, ctx:
};
let arity = if inline.is_one_to_one() { "?" } else { "[]" };

ctx.missing_bits.push(MissingBit {
position: inline.referenced_model().ast_model().span().end - 1,
content: format!("{referencing_model_name} {referencing_model_name}{arity} {ignore}\n"),
});
let span = inline.referenced_model().ast_model().span();
ctx.add_missing_bit(
span.file_id,
MissingBit {
position: span.end - 1,
content: format!("{referencing_model_name} {referencing_model_name}{arity} {ignore}\n"),
},
);
}

if inline.forward_relation_field().is_none() {
@@ -179,10 +237,14 @@
let arity = render_arity(forward_relation_field_arity(inline));
let fields_arg = fields_argument(inline);
let references_arg = references_argument(inline);
ctx.missing_bits.push(MissingBit {
position: inline.referencing_model().ast_model().span().end - 1,
content: format!("{field_name} {field_type}{arity} @relation({fields_arg}, {references_arg})\n"),
})
let span = inline.referencing_model().ast_model().span();
ctx.add_missing_bit(
span.file_id,
MissingBit {
position: span.end - 1,
content: format!("{field_name} {field_type}{arity} @relation({fields_arg}, {references_arg})\n"),
},
)
}
}

@@ -211,13 +273,17 @@ fn push_missing_scalar_fields(inline: walkers::InlineRelationWalker<'_>, ctx: &m

let mut attributes: String = String::new();
if let Some((_datasource_name, _type_name, _args, span)) = field.blueprint.raw_native_type() {
attributes.push_str(&ctx.original_schema[span.start..span.end]);
attributes.push_str(&ctx.db.source(span.file_id)[span.start..span.end]);
}

ctx.missing_bits.push(MissingBit {
position: inline.referencing_model().ast_model().span().end - 1,
content: format!("{field_name} {field_type}{arity} {attributes}\n"),
});
let span = inline.referencing_model().ast_model().span();
ctx.add_missing_bit(
span.file_id,
MissingBit {
position: span.end - 1,
content: format!("{field_name} {field_type}{arity} {attributes}\n"),
},
);
}
}

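
A usage sketch (not part of this commit) of the new multi-file entry point. File names and schema contents are illustrative, and the exact reformatted output is not claimed here; because missing bits are now keyed by FileId, a missing relation field should end up in the file that defines the model it belongs to.

// Reformat two schema files at once via the re-export added further down in
// psl/psl/src/lib.rs. `User.posts` has no counterpart on `Post`, so the magic
// reformat should add the missing relation field (and its scalar) to
// post.prisma rather than user.prisma.
fn reformat_two_files() {
    let user = "model User {\n  id    Int    @id\n  posts Post[]\n}\n";
    let post = "model Post {\n  id Int @id\n}\n";
    let sources = [("user.prisma", user), ("post.prisma", post)];
    for (file_name, formatted) in psl::reformat_multiple(&sources, 2) {
        println!("// {file_name}\n{formatted}");
    }
}
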
1 change: 1 addition & 0 deletions psl/psl/src/lib.rs
@@ -12,6 +12,7 @@ pub use psl_core::{
parser_database::{self, SourceFile},
reachable_only_with_capability,
reformat,
reformat_multiple,
schema_ast,
set_config_dir,
Configuration,
4 changes: 4 additions & 0 deletions psl/psl/tests/common/mod.rs
@@ -10,6 +10,10 @@ pub(crate) fn reformat(input: &str) -> String {
psl::reformat(input, 2).unwrap_or_else(|| input.to_owned())
}

pub(crate) fn reformat_multiple(sources: &[(&str, &str)]) -> Vec<(String, String)> {
psl::reformat_multiple(sources, 2)
}

pub(crate) fn parse_unwrap_err(schema: &str) -> String {
psl::parse_schema(schema).map(drop).unwrap_err()
}
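
A sketch (not part of this commit) of how the new test helper might be exercised; the schemas are illustrative, and the ordering assertions assume the output preserves the order in which files were passed in.

#[test]
fn reformat_multiple_returns_one_entry_per_file() {
    let result = reformat_multiple(&[
        ("a.prisma", "model A {\n  id Int @id\n}"),
        ("b.prisma", "model B {\n  id Int @id\n}"),
    ]);
    // One (file name, reformatted source) pair per input file, assuming the
    // ParserDatabase keeps files in the order they were provided.
    assert_eq!(result.len(), 2);
    assert_eq!(result[0].0, "a.prisma");
    assert_eq!(result[1].0, "b.prisma");
}
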
2 changes: 1 addition & 1 deletion query-engine/dmmf/src/lib.rs
@@ -16,7 +16,7 @@ pub fn dmmf_json_from_schema(schema: &str) -> String {
serde_json::to_string(&dmmf).unwrap()
}

pub fn dmmf_json_from_validate_schema(schema: ValidatedSchema) -> String {
pub fn dmmf_json_from_validated_schema(schema: ValidatedSchema) -> String {
let dmmf = from_precomputed_parts(&schema::build(Arc::new(schema), true));
serde_json::to_string(&dmmf).unwrap()
}
