From d056d09547c5fe7ff5e787ada19fc37a73d9b444 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Tue, 4 Jun 2024 15:09:39 -0600 Subject: [PATCH] [red-knot] add if-statement support to FlowGraph (#11673) ## Summary Add if-statement support to FlowGraph. This introduces branches and joins in the graph for the first time. ## Test Plan Added tests. --- crates/red_knot/src/symbols.rs | 170 ++++++++++++++++++++++++++--- crates/red_knot/src/types/infer.rs | 110 +++++++++++++++---- 2 files changed, 240 insertions(+), 40 deletions(-) diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 84ee74c93975b..b8eb81dbc0757 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -184,6 +184,7 @@ pub(crate) enum Definition { FunctionDef(TypedNodeKey), Assignment(TypedNodeKey), AnnotatedAssignment(TypedNodeKey), + None, // TODO with statements, except handlers, function args... } @@ -288,8 +289,8 @@ impl SymbolTable { let flow_node_id = self.flow_graph.ast_to_flow[&node_key]; ReachableDefinitionsIterator { table: self, - flow_node_id, symbol_id, + pending: vec![flow_node_id], } } @@ -545,8 +546,8 @@ where #[derive(Debug)] pub(crate) struct ReachableDefinitionsIterator<'a> { table: &'a SymbolTable, - flow_node_id: FlowNodeId, symbol_id: SymbolId, + pending: Vec, } impl<'a> Iterator for ReachableDefinitionsIterator<'a> { @@ -554,16 +555,21 @@ impl<'a> Iterator for ReachableDefinitionsIterator<'a> { fn next(&mut self) -> Option { loop { - match &self.table.flow_graph.flow_nodes_by_id[self.flow_node_id] { - FlowNode::Start => return None, + let flow_node_id = self.pending.pop()?; + match &self.table.flow_graph.flow_nodes_by_id[flow_node_id] { + FlowNode::Start => return Some(Definition::None), FlowNode::Definition(def_node) => { if def_node.symbol_id == self.symbol_id { - // we found a definition; previous definitions along this path are not - // reachable - self.flow_node_id = FlowGraph::start(); return Some(def_node.definition.clone()); } - self.flow_node_id = def_node.predecessor; + self.pending.push(def_node.predecessor); + } + FlowNode::Branch(branch_node) => { + self.pending.push(branch_node.predecessor); + } + FlowNode::Phi(phi_node) => { + self.pending.push(phi_node.first_predecessor); + self.pending.push(phi_node.second_predecessor); } } } @@ -579,8 +585,11 @@ struct FlowNodeId; enum FlowNode { Start, Definition(DefinitionFlowNode), + Branch(BranchFlowNode), + Phi(PhiFlowNode), } +/// A Definition node represents a point in control flow where a symbol is defined #[derive(Debug)] struct DefinitionFlowNode { symbol_id: SymbolId, @@ -588,6 +597,19 @@ struct DefinitionFlowNode { predecessor: FlowNodeId, } +/// A Branch node represents a branch in control flow +#[derive(Debug)] +struct BranchFlowNode { + predecessor: FlowNodeId, +} + +/// A Phi node represents a join point where control flow paths come together +#[derive(Debug)] +struct PhiFlowNode { + first_predecessor: FlowNodeId, + second_predecessor: FlowNodeId, +} + #[derive(Debug, Default)] struct FlowGraph { flow_nodes_by_id: IndexVec, @@ -636,6 +658,10 @@ impl SymbolTableBuilder { .add_or_update_symbol(self.cur_scope(), identifier, flags) } + fn new_flow_node(&mut self, node: FlowNode) -> FlowNodeId { + self.table.flow_graph.flow_nodes_by_id.push(node) + } + fn add_or_update_symbol_with_def( &mut self, identifier: &str, @@ -647,15 +673,11 @@ impl SymbolTableBuilder { .entry(symbol_id) .or_default() .push(definition.clone()); - let new_flow_node_id = self - .table - .flow_graph - .flow_nodes_by_id - .push(FlowNode::Definition(DefinitionFlowNode { - definition, - symbol_id, - predecessor: self.current_flow_node(), - })); + let new_flow_node_id = self.new_flow_node(FlowNode::Definition(DefinitionFlowNode { + definition, + symbol_id, + predecessor: self.current_flow_node(), + })); self.set_current_flow_node(new_flow_node_id); symbol_id } @@ -871,6 +893,74 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { ast::visitor::preorder::walk_stmt(self, stmt); self.current_definition = None; } + ast::Stmt::If(node) => { + // we visit the if "test" condition first regardless + self.visit_expr(&node.test); + + // create branch node: does the if test pass or not? + let if_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode { + predecessor: self.current_flow_node(), + })); + + // visit the body of the `if` clause + self.set_current_flow_node(if_branch); + self.visit_body(&node.body); + + // Flow node for the last if/elif condition branch; represents the "no branch + // taken yet" possibility (where "taking a branch" means that the condition in an + // if or elif evaluated to true and control flow went into that clause). + let mut prior_branch = if_branch; + + // Flow node for the state after the prior if/elif/else clause; represents "we have + // taken one of the branches up to this point." Initially set to the post-if-clause + // state, later will be set to the phi node joining that possible path with the + // possibility that we took a later if/elif/else clause instead. + let mut post_prior_clause = self.current_flow_node(); + + // Flag to mark if the final clause is an "else" -- if so, that means the "match no + // clauses" path is not possible, we have to go through one of the clauses. + let mut last_branch_is_else = false; + + for clause in &node.elif_else_clauses { + if clause.test.is_some() { + // This is an elif clause. Create a new branch node. Its predecessor is the + // previous branch node, because we can only take one branch in an entire + // if/elif/else chain, so if we take this branch, it can only be because we + // didn't take the previous one. + prior_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode { + predecessor: prior_branch, + })); + self.set_current_flow_node(prior_branch); + } else { + // This is an else clause. No need to create a branch node; there's no + // branch here, if we haven't taken any previous branch, we definitely go + // into the "else" clause. + self.set_current_flow_node(prior_branch); + last_branch_is_else = true; + } + self.visit_elif_else_clause(clause); + // Update `post_prior_clause` to a new phi node joining the possibility that we + // took any of the previous branches with the possibility that we took the one + // just visited. + post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode { + first_predecessor: self.current_flow_node(), + second_predecessor: post_prior_clause, + })); + } + + if !last_branch_is_else { + // Final branch was not an "else", which means it's possible we took zero + // branches in the entire if/elif chain, so we need one more phi node to join + // the "no branches taken" possibility. + post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode { + first_predecessor: post_prior_clause, + second_predecessor: prior_branch, + })); + } + + // Onward, with current flow node set to our final Phi node. + self.set_current_flow_node(post_prior_clause); + } _ => { ast::visitor::preorder::walk_stmt(self, stmt); } @@ -878,6 +968,52 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { } } +impl std::fmt::Display for FlowGraph { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + writeln!(f, "flowchart TD")?; + for (id, node) in self.flow_nodes_by_id.iter_enumerated() { + write!(f, " id{}", id.as_u32())?; + match node { + FlowNode::Start => writeln!(f, r"[\Start/]")?, + FlowNode::Definition(def_node) => { + writeln!(f, r"(Define symbol {})", def_node.symbol_id.as_u32())?; + writeln!( + f, + r" id{}-->id{}", + def_node.predecessor.as_u32(), + id.as_u32() + )?; + } + FlowNode::Branch(branch_node) => { + writeln!(f, r"{{Branch}}")?; + writeln!( + f, + r" id{}-->id{}", + branch_node.predecessor.as_u32(), + id.as_u32() + )?; + } + FlowNode::Phi(phi_node) => { + writeln!(f, r"((Phi))")?; + writeln!( + f, + r" id{}-->id{}", + phi_node.second_predecessor.as_u32(), + id.as_u32() + )?; + writeln!( + f, + r" id{}-->id{}", + phi_node.first_predecessor.as_u32(), + id.as_u32() + )?; + } + } + } + Ok(()) + } +} + #[derive(Debug, Default)] pub struct SymbolTablesStorage(KeyValueCache>); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index 8472f9e646f2f..56ea8b334df69 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -79,6 +79,7 @@ pub fn infer_definition_type( let file_id = symbol.file_id; match definition { + Definition::None => Ok(Type::Unbound), Definition::Import(ImportDefinition { module: module_name, }) => { @@ -223,7 +224,7 @@ mod tests { use crate::module::{ resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind, }; - use crate::symbols::{resolve_global_symbol, symbol_table, GlobalSymbolId}; + use crate::symbols::resolve_global_symbol; use crate::types::{infer_symbol_public_type, Type}; use crate::Name; @@ -399,30 +400,93 @@ mod tests { #[test] fn resolve_visible_def() -> anyhow::Result<()> { let case = create_test()?; - let db = &case.db; - let path = case.src.path().join("a.py"); - std::fs::write(path, "y = 1; y = 2; x = y")?; - let file = resolve_module(db, ModuleName::new("a"))? - .expect("module should be found") - .path(db)? - .file(); - let symbols = symbol_table(db, file)?; - let x_sym = symbols - .root_symbol_id_by_name("x") - .expect("x symbol should be found"); - - let ty = infer_symbol_public_type( - db, - GlobalSymbolId { - file_id: file, - symbol_id: x_sym, - }, + write_to_path(&case, "a.py", "y = 1; y = 2; x = y")?; + + assert_public_type(&case, "a", "x", "Literal[2]") + } + + #[test] + fn join_paths() -> anyhow::Result<()> { + let case = create_test()?; + + write_to_path( + &case, + "a.py", + " + y = 1 + y = 2 + if flag: + y = 3 + x = y + ", )?; - let jar = HasJar::::jar(db)?; - assert!(matches!(ty, Type::IntLiteral(_))); - assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[2]"); - Ok(()) + assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3])") + } + + #[test] + fn maybe_unbound() -> anyhow::Result<()> { + let case = create_test()?; + + write_to_path( + &case, + "a.py", + " + if flag: + y = 1 + x = y + ", + )?; + + assert_public_type(&case, "a", "x", "(Unbound | Literal[1])") + } + + #[test] + fn if_elif_else() -> anyhow::Result<()> { + let case = create_test()?; + + write_to_path( + &case, + "a.py", + " + y = 1 + y = 2 + if flag: + y = 3 + elif flag2: + y = 4 + else: + r = y + y = 5 + s = y + x = y + ", + )?; + + assert_public_type(&case, "a", "x", "(Literal[3] | Literal[4] | Literal[5])")?; + assert_public_type(&case, "a", "r", "Literal[2]")?; + assert_public_type(&case, "a", "s", "Literal[5]") + } + + #[test] + fn if_elif() -> anyhow::Result<()> { + let case = create_test()?; + + write_to_path( + &case, + "a.py", + " + y = 1 + y = 2 + if flag: + y = 3 + elif flag2: + y = 4 + x = y + ", + )?; + + assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3] | Literal[4])") } }