Skip to content

Commit

Permalink
LS: Add "fill struct fields" code action
Browse files Browse the repository at this point in the history
commit-id:0051578a
  • Loading branch information
integraledelebesgue committed Nov 14, 2024
1 parent 920e8b0 commit efe63a2
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
use std::collections::HashMap;

use cairo_lang_defs::ids::LanguageElementId;
use cairo_lang_semantic::Expr;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::function_with_body::SemanticExprLookup;
use cairo_lang_semantic::items::structure::concrete_struct_members;
use cairo_lang_semantic::items::visibility::peek_visible_in;
use cairo_lang_semantic::lookup_item::LookupItemEx;
use cairo_lang_syntax::node::ast::{ExprStructCtorCall, StructArg};
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{SyntaxNode, TypedSyntaxNode};
use lsp_types::{CodeAction, CodeActionKind, CodeActionParams, Range, TextEdit, WorkspaceEdit};
use tracing::error;

use crate::lang::db::{AnalysisDatabase, LsSemanticGroup, LsSyntaxGroup};
use crate::lang::lsp::ToLsp;

/// Generates a completion adding all visible struct members that have not yet been specified
/// to the constructor call, filling their values with a placeholder unit type.
pub fn fill_struct_fields(
db: &AnalysisDatabase,
node: SyntaxNode,
params: &CodeActionParams,
) -> Option<CodeAction> {
let module_file_id = db.find_module_file_containing_node(&node)?;
let module_id = module_file_id.0;
let file_id = module_file_id.file_id(db).ok()?;
let function_id = db.find_lookup_item(&node)?.function_with_body()?;

let constructor = db.first_ancestor_of_kind(node, SyntaxKind::ExprStructCtorCall)?;
let constructor_expr = ExprStructCtorCall::from_syntax_node(db, constructor.clone());

let mut last_important_element = None;
let mut has_trailing_comma = false;

for node in constructor.descendants(db) {
match node.kind(db) {
SyntaxKind::TokenComma => {
has_trailing_comma = true;
last_important_element = Some(node)
}
SyntaxKind::StructArgSingle => {
has_trailing_comma = false;
last_important_element = Some(node)
}
// Don't complete any fields if initialization contains tail.
SyntaxKind::StructArgTail => return None,
_ => {}
}
}

let code_prefix = String::from(if !has_trailing_comma && last_important_element.is_some() {
", "
} else {
" "
});

let struct_arguments = constructor_expr.arguments(db);
let left_brace = struct_arguments.lbrace(db);
let struct_arguments = struct_arguments.arguments(db).elements(db);

let already_present_arguments = struct_arguments
.iter()
.map(|member| match member {
StructArg::StructArgSingle(argument) => {
argument.identifier(db).token(db).as_syntax_node().get_text_without_trivia(db)
}
StructArg::StructArgTail(_) => unreachable!(),
})
.collect::<Vec<_>>();

let constructor_expr_id =
db.lookup_expr_by_ptr(function_id, constructor_expr.stable_ptr().into()).ok()?;

let constructor_semantic = match db.expr_semantic(function_id, constructor_expr_id) {
Expr::StructCtor(semantic) => semantic,
_ => {
error!(
"Semantic expression obtained from StructCtorCall doesn't refer to constructor."
);
return None;
}
};

let concrete_struct_id = constructor_semantic.concrete_struct_id;
let struct_parent_module_id = concrete_struct_id.struct_id(db).parent_module(db);

let arguments_to_complete = concrete_struct_members(db, concrete_struct_id)
.ok()?
.iter()
.filter_map(|(name, member)| {
let name = name.to_string();

if already_present_arguments.contains(&name) {
None
} else if peek_visible_in(db, member.visibility, struct_parent_module_id, module_id) {
Some(format!("{name}: ()"))
} else {
None
}
})
.collect::<Vec<_>>();

let code_to_insert = code_prefix + &arguments_to_complete.join(", ");

let edit_start = last_important_element
.unwrap_or(left_brace.as_syntax_node())
.span_end_without_trivia(db)
.position_in_file(db, file_id)?
.to_lsp();

let mut changes = HashMap::new();
let url = params.text_document.uri.clone();
let change = TextEdit { range: Range::new(edit_start, edit_start), new_text: code_to_insert };

changes.insert(url, vec![change]);

let edit = WorkspaceEdit::new(changes);

Some(CodeAction {
title: String::from("Fill struct fields"),
kind: Some(CodeActionKind::QUICKFIX),
edit: Some(edit),
..Default::default()
})
}
104 changes: 67 additions & 37 deletions crates/cairo-lang-language-server/src/ide/code_actions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use cairo_lang_syntax::node::SyntaxNode;
use cairo_lang_utils::ordered_hash_map::{Entry, OrderedHashMap};
use itertools::Itertools;
use lsp_types::{
CodeAction, CodeActionOrCommand, CodeActionParams, CodeActionResponse, Diagnostic,
NumberOrString,
};
use tracing::debug;
use tracing::{debug, warn};

use crate::lang::db::{AnalysisDatabase, LsSyntaxGroup};
use crate::lang::lsp::{LsProtoGroup, ToCairo};

mod add_missing_trait;
mod expand_macro;
mod fill_struct_fields;
mod rename_unused_variable;

/// Compute commands for a given text document and range. These commands are typically code fixes to
Expand All @@ -18,61 +21,88 @@ pub fn code_actions(params: CodeActionParams, db: &AnalysisDatabase) -> Option<C
let mut actions = Vec::with_capacity(params.context.diagnostics.len());
let file_id = db.file_for_url(&params.text_document.uri)?;
let node = db.find_syntax_node_at_position(file_id, params.range.start.to_cairo())?;
for diagnostic in params.context.diagnostics.iter() {
actions.extend(
get_code_actions_for_diagnostic(db, &node, diagnostic, &params)
.into_iter()
.map(CodeActionOrCommand::from),
);
}
actions.extend(expand_macro::expand_macro(db, node).into_iter().map(CodeActionOrCommand::from));

actions.extend(
get_code_actions_for_diagnostics(db, &node, &params)
.into_iter()
.map(CodeActionOrCommand::from),
);

actions.extend(
expand_macro::expand_macro(db, node.clone()).into_iter().map(CodeActionOrCommand::from),
);

Some(actions)
}

/// Generate code actions for a given diagnostic.
/// Generate code actions for a given diagnostics in context of [`CodeActionParams`].
///
/// # Arguments
///
/// * `db` - A reference to the Salsa database.
/// * `node` - The syntax node where the diagnostic is located.
/// * `diagnostic` - The diagnostic for which to generate code actions.
/// * `params` - The parameters for the code action request.
///
/// # Returns
///
/// A vector of [`CodeAction`] objects that can be applied to resolve the diagnostic.
fn get_code_actions_for_diagnostic(
/// A vector of [`CodeAction`] objects that can be applied to resolve the diagnostics.
fn get_code_actions_for_diagnostics(
db: &AnalysisDatabase,
node: &SyntaxNode,
diagnostic: &Diagnostic,
params: &CodeActionParams,
) -> Vec<CodeAction> {
let code = match &diagnostic.code {
Some(NumberOrString::String(code)) => code,
Some(NumberOrString::Number(code)) => {
debug!("diagnostic code is not a string: `{code}`");
return vec![];
}
None => {
debug!("diagnostic code is missing");
return vec![];
}
};
let mut diagnostic_groups_by_codes: OrderedHashMap<String, Vec<&Diagnostic>> =
OrderedHashMap::default();

match code.as_str() {
"E0001" => {
vec![rename_unused_variable::rename_unused_variable(
db,
node,
diagnostic.clone(),
params.text_document.uri.clone(),
)]
for diagnostic in params.context.diagnostics.iter() {
if let Some(code) = extract_code(diagnostic) {
match diagnostic_groups_by_codes.entry(code.to_owned()) {
Entry::Occupied(mut entry) => {
entry.get_mut().push(diagnostic);
}
Entry::Vacant(entry) => {
entry.insert(vec![diagnostic]);
}
}
}
"E0002" => add_missing_trait::add_missing_trait(db, node, params.text_document.uri.clone()),
code => {
debug!("no code actions for diagnostic code: {code}");
vec![]
}

diagnostic_groups_by_codes
.into_iter()
.flat_map(|(code, diagnostics)| match code.as_str() {
"E0001" => diagnostics
.into_iter()
.map(|diagnostic| {
rename_unused_variable::rename_unused_variable(
db,
node,
diagnostic.clone(),
params.text_document.uri.clone(),
)
})
.collect_vec(),
"E0002" => {
add_missing_trait::add_missing_trait(db, node, params.text_document.uri.clone())
}
"E0003" => fill_struct_fields::fill_struct_fields(db, node.clone(), params)
.map(|result| vec![result])
.unwrap_or_default(),
_ => {
debug!("no code actions for diagnostic code: {code}");
vec![]
}
})
.collect_vec()
}

/// Extracts [`Diagnostic`] code if it's given as a string, returns None otherwise.
fn extract_code(diagnostic: &Diagnostic) -> Option<&str> {
match &diagnostic.code {
Some(NumberOrString::String(code)) => Some(code),
Some(NumberOrString::Number(code)) => {
warn!("diagnostic code is not a string: `{code}`");
None
}
None => None,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cairo_lang_test_utils::test_file_test!(
{
missing_trait: "missing_trait.txt",
macro_expand: "macro_expand.txt",
fill_struct_fields: "fill_struct_fields.txt",
},
test_quick_fix
);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//! > Test filling missing members in struct constructor.

//! > test_runner_name
test_quick_fix

//! > cairo_project.toml
[crate_roots]
hello = "src"

[config.global]
edition = "2024_07"

//! > cairo_code
mod some_module {
pub struct Struct {
x: u32,
pub y: felt252,
pub z: i16
}

fn build_struct() {
let s = Struct {
x: 0x0,
y: 0x0,
z: 0x0
};

let _a = Struct { <caret> };

let _b = Struct { x: 0x0, <caret> };

let _c = Struct { <caret>..s };
}
}

mod happy_cases {
use super::some_module::Struct;

fn foo() {
let _a = Struct { <caret> };
let _b = Struct { y: 0x0, <caret> };
let _c = Struct { y: 0x0, x: 0x0, <caret> }
}
}

mod unhappy_cases {
fn foo() {
let _a = NonexsitentStruct { <caret> };
}
}

//! > Code action #0
let _a = Struct { <caret> };
Title: Fill struct fields
Add new text: " x: (), y: (), z: ()"
At: Range { start: Position { line: 14, character: 25 }, end: Position { line: 14, character: 25 } }

//! > Code action #1
let _b = Struct { x: 0x0, <caret> };
Title: Fill struct fields
Add new text: " y: (), z: ()"
At: Range { start: Position { line: 16, character: 33 }, end: Position { line: 16, character: 33 } }

//! > Code action #2
let _c = Struct { <caret>..s };
No code actions.

//! > Code action #3
let _a = Struct { <caret> };
Title: Fill struct fields
Add new text: " y: (), z: ()"
At: Range { start: Position { line: 26, character: 25 }, end: Position { line: 26, character: 25 } }

//! > Code action #4
let _b = Struct { y: 0x0, <caret> };
Title: Fill struct fields
Add new text: " z: ()"
At: Range { start: Position { line: 27, character: 33 }, end: Position { line: 27, character: 33 } }

//! > Code action #5
let _c = Struct { y: 0x0, x: 0x0, <caret> }
Title: Fill struct fields
Add new text: " z: ()"
At: Range { start: Position { line: 28, character: 41 }, end: Position { line: 28, character: 41 } }

//! > Code action #6
let _a = NonexsitentStruct { <caret> };
No code actions.
1 change: 1 addition & 0 deletions crates/cairo-lang-semantic/src/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,7 @@ impl SemanticDiagnosticKind {
Self::CannotCallMethod { .. } => {
error_code!(E0002)
}
Self::MissingMember(_) => error_code!(E0003),
_ => return None,
})
}
Expand Down
4 changes: 2 additions & 2 deletions crates/cairo-lang-semantic/src/expr/test_data/pattern
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ struct MyStruct {
}

//! > expected_diagnostics
error: Missing member "b".
error[E0003]: Missing member "b".
--> lib.cairo:6:9
let MyStruct { a: _ } = s;
^***************^
Expand Down Expand Up @@ -163,7 +163,7 @@ error: Redefinition of member "a" on struct "test::MyStruct".
let MyStruct { a: _, c: _, a: _ } = s;
^

error: Missing member "b".
error[E0003]: Missing member "b".
--> lib.cairo:6:9
let MyStruct { a: _, c: _, a: _ } = s;
^***************************^
Expand Down

0 comments on commit efe63a2

Please sign in to comment.