Skip to content

Commit

Permalink
[red-knot] add if-statement support to FlowGraph (#11673)
Browse files Browse the repository at this point in the history
## Summary

Add if-statement support to FlowGraph. This introduces branches and
joins in the graph for the first time.

## Test Plan

Added tests.
  • Loading branch information
carljm authored Jun 4, 2024
1 parent 1645be0 commit d056d09
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 40 deletions.
170 changes: 153 additions & 17 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ pub(crate) enum Definition {
FunctionDef(TypedNodeKey<ast::StmtFunctionDef>),
Assignment(TypedNodeKey<ast::StmtAssign>),
AnnotatedAssignment(TypedNodeKey<ast::StmtAnnAssign>),
None,
// TODO with statements, except handlers, function args...
}

Expand Down Expand Up @@ -288,8 +289,8 @@ impl SymbolTable {
let flow_node_id = self.flow_graph.ast_to_flow[&node_key];
ReachableDefinitionsIterator {
table: self,
flow_node_id,
symbol_id,
pending: vec![flow_node_id],
}
}

Expand Down Expand Up @@ -545,25 +546,30 @@ where
/// Iterator over the [`Definition`]s of a single symbol (`symbol_id`), found by walking the
/// flow graph stored in `table` backwards along predecessor links.
#[derive(Debug)]
pub(crate) struct ReachableDefinitionsIterator<'a> {
table: &'a SymbolTable,
flow_node_id: FlowNodeId,
symbol_id: SymbolId,
pending: Vec<FlowNodeId>,
}

impl<'a> Iterator for ReachableDefinitionsIterator<'a> {
type Item = Definition;

fn next(&mut self) -> Option<Self::Item> {
loop {
match &self.table.flow_graph.flow_nodes_by_id[self.flow_node_id] {
FlowNode::Start => return None,
let flow_node_id = self.pending.pop()?;
match &self.table.flow_graph.flow_nodes_by_id[flow_node_id] {
FlowNode::Start => return Some(Definition::None),
FlowNode::Definition(def_node) => {
if def_node.symbol_id == self.symbol_id {
// we found a definition; previous definitions along this path are not
// reachable
self.flow_node_id = FlowGraph::start();
return Some(def_node.definition.clone());
}
self.flow_node_id = def_node.predecessor;
self.pending.push(def_node.predecessor);
}
FlowNode::Branch(branch_node) => {
self.pending.push(branch_node.predecessor);
}
FlowNode::Phi(phi_node) => {
self.pending.push(phi_node.first_predecessor);
self.pending.push(phi_node.second_predecessor);
}
}
}
Expand All @@ -579,15 +585,31 @@ struct FlowNodeId;
enum FlowNode {
/// Entry node of the graph; the only variant that carries no predecessor.
Start,
/// A point in control flow where a symbol is defined (one predecessor).
Definition(DefinitionFlowNode),
/// A point where control flow forks (one predecessor).
Branch(BranchFlowNode),
/// A point where two control flow paths join (two predecessors).
Phi(PhiFlowNode),
}

/// A Definition node represents a point in control flow where a symbol is defined
#[derive(Debug)]
struct DefinitionFlowNode {
// The symbol that this node defines.
symbol_id: SymbolId,
// The definition recorded for `symbol_id` at this point.
definition: Definition,
// The flow node that control flow passes through immediately before this one.
predecessor: FlowNodeId,
}

/// A Branch node represents a branch in control flow
#[derive(Debug)]
struct BranchFlowNode {
// The flow node that control flow passes through immediately before the branch.
predecessor: FlowNodeId,
}

/// A Phi node represents a join point where control flow paths come together
#[derive(Debug)]
struct PhiFlowNode {
// End of one of the two incoming control flow paths.
first_predecessor: FlowNodeId,
// End of the other incoming control flow path.
second_predecessor: FlowNodeId,
}

#[derive(Debug, Default)]
struct FlowGraph {
flow_nodes_by_id: IndexVec<FlowNodeId, FlowNode>,
Expand Down Expand Up @@ -636,6 +658,10 @@ impl SymbolTableBuilder {
.add_or_update_symbol(self.cur_scope(), identifier, flags)
}

/// Append `node` to the flow graph's node storage and return the id that `push` assigns to it.
fn new_flow_node(&mut self, node: FlowNode) -> FlowNodeId {
    let flow_nodes = &mut self.table.flow_graph.flow_nodes_by_id;
    flow_nodes.push(node)
}

fn add_or_update_symbol_with_def(
&mut self,
identifier: &str,
Expand All @@ -647,15 +673,11 @@ impl SymbolTableBuilder {
.entry(symbol_id)
.or_default()
.push(definition.clone());
let new_flow_node_id = self
.table
.flow_graph
.flow_nodes_by_id
.push(FlowNode::Definition(DefinitionFlowNode {
definition,
symbol_id,
predecessor: self.current_flow_node(),
}));
let new_flow_node_id = self.new_flow_node(FlowNode::Definition(DefinitionFlowNode {
definition,
symbol_id,
predecessor: self.current_flow_node(),
}));
self.set_current_flow_node(new_flow_node_id);
symbol_id
}
Expand Down Expand Up @@ -871,13 +893,127 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
ast::visitor::preorder::walk_stmt(self, stmt);
self.current_definition = None;
}
// Flow-graph construction for an `if` statement with optional elif/else clauses:
// - one Branch node is created for the `if` test and one for each `elif` clause, chained so
//   that each elif's Branch has the previous Branch as its predecessor;
// - after every elif/else clause a Phi node joins the end of that clause with the joined
//   state of all earlier clauses;
// - when there is no `else`, a final Phi node also joins in the last Branch node, which
//   stands for the path on which no clause was entered.
// The current flow node is left pointing at the last Phi node created (or, for a bare `if`,
// at the Phi joining the end of the body with the Branch node).
ast::Stmt::If(node) => {
// we visit the if "test" condition first regardless
self.visit_expr(&node.test);

// create branch node: does the if test pass or not?
let if_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode {
predecessor: self.current_flow_node(),
}));

// visit the body of the `if` clause
self.set_current_flow_node(if_branch);
self.visit_body(&node.body);

// Flow node for the last if/elif condition branch; represents the "no branch
// taken yet" possibility (where "taking a branch" means that the condition in an
// if or elif evaluated to true and control flow went into that clause).
let mut prior_branch = if_branch;

// Flow node for the state after the prior if/elif/else clause; represents "we have
// taken one of the branches up to this point." Initially set to the post-if-clause
// state, later will be set to the phi node joining that possible path with the
// possibility that we took a later if/elif/else clause instead.
let mut post_prior_clause = self.current_flow_node();

// Flag to mark if the final clause is an "else" -- if so, that means the "match no
// clauses" path is not possible, we have to go through one of the clauses.
let mut last_branch_is_else = false;

for clause in &node.elif_else_clauses {
if clause.test.is_some() {
// This is an elif clause. Create a new branch node. Its predecessor is the
// previous branch node, because we can only take one branch in an entire
// if/elif/else chain, so if we take this branch, it can only be because we
// didn't take the previous one.
prior_branch = self.new_flow_node(FlowNode::Branch(BranchFlowNode {
predecessor: prior_branch,
}));
self.set_current_flow_node(prior_branch);
} else {
// This is an else clause. No need to create a branch node; there's no
// branch here, if we haven't taken any previous branch, we definitely go
// into the "else" clause.
self.set_current_flow_node(prior_branch);
last_branch_is_else = true;
}
self.visit_elif_else_clause(clause);
// Update `post_prior_clause` to a new phi node joining the possibility that we
// took any of the previous branches with the possibility that we took the one
// just visited.
post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode {
first_predecessor: self.current_flow_node(),
second_predecessor: post_prior_clause,
}));
}

if !last_branch_is_else {
// Final branch was not an "else", which means it's possible we took zero
// branches in the entire if/elif chain, so we need one more phi node to join
// the "no branches taken" possibility.
post_prior_clause = self.new_flow_node(FlowNode::Phi(PhiFlowNode {
first_predecessor: post_prior_clause,
second_predecessor: prior_branch,
}));
}

// Onward, with current flow node set to our final Phi node.
self.set_current_flow_node(post_prior_clause);
}
_ => {
ast::visitor::preorder::walk_stmt(self, stmt);
}
}
}
}

impl std::fmt::Display for FlowGraph {
    /// Writes the graph as text: a `flowchart TD` header line, then for every flow node one
    /// labelled `idN…` line followed by an `idA-->idN` line per predecessor edge into it.
    /// (NOTE(review): the output looks like Mermaid flowchart syntax -- confirm intended
    /// consumer.)
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        writeln!(f, "flowchart TD")?;
        for (id, node) in self.flow_nodes_by_id.iter_enumerated() {
            let target = id.as_u32();
            write!(f, " id{}", target)?;
            // Emit the node's label, then gather its incoming edges in print order. A Phi
            // node lists its second predecessor before its first.
            let incoming = match node {
                FlowNode::Start => {
                    writeln!(f, r"[\Start/]")?;
                    [None, None]
                }
                FlowNode::Definition(def_node) => {
                    writeln!(f, r"(Define symbol {})", def_node.symbol_id.as_u32())?;
                    [Some(def_node.predecessor.as_u32()), None]
                }
                FlowNode::Branch(branch_node) => {
                    writeln!(f, r"{{Branch}}")?;
                    [Some(branch_node.predecessor.as_u32()), None]
                }
                FlowNode::Phi(phi_node) => {
                    writeln!(f, r"((Phi))")?;
                    [
                        Some(phi_node.second_predecessor.as_u32()),
                        Some(phi_node.first_predecessor.as_u32()),
                    ]
                }
            };
            for source in incoming.iter().flatten() {
                writeln!(f, r" id{}-->id{}", source, target)?;
            }
        }
        Ok(())
    }
}

#[derive(Debug, Default)]
pub struct SymbolTablesStorage(KeyValueCache<FileId, Arc<SymbolTable>>);

Expand Down
110 changes: 87 additions & 23 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pub fn infer_definition_type(
let file_id = symbol.file_id;

match definition {
Definition::None => Ok(Type::Unbound),
Definition::Import(ImportDefinition {
module: module_name,
}) => {
Expand Down Expand Up @@ -223,7 +224,7 @@ mod tests {
use crate::module::{
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::symbols::{resolve_global_symbol, symbol_table, GlobalSymbolId};
use crate::symbols::resolve_global_symbol;
use crate::types::{infer_symbol_public_type, Type};
use crate::Name;

Expand Down Expand Up @@ -399,30 +400,93 @@ mod tests {
#[test]
fn resolve_visible_def() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

let path = case.src.path().join("a.py");
std::fs::write(path, "y = 1; y = 2; x = y")?;
let file = resolve_module(db, ModuleName::new("a"))?
.expect("module should be found")
.path(db)?
.file();
let symbols = symbol_table(db, file)?;
let x_sym = symbols
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: x_sym,
},
write_to_path(&case, "a.py", "y = 1; y = 2; x = y")?;

assert_public_type(&case, "a", "x", "Literal[2]")
}

#[test]
fn join_paths() -> anyhow::Result<()> {
let case = create_test()?;

write_to_path(
&case,
"a.py",
"
y = 1
y = 2
if flag:
y = 3
x = y
",
)?;

let jar = HasJar::<SemanticJar>::jar(db)?;
assert!(matches!(ty, Type::IntLiteral(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[2]");
Ok(())
assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3])")
}

// Checks the public type of `x` when `y` is only assigned inside an `if` body: the expected
// display string includes `Unbound` alongside `Literal[1]`.
#[test]
fn maybe_unbound() -> anyhow::Result<()> {
let case = create_test()?;

// NOTE(review): the Python snippet below shows no indentation under `if flag:` in this
// view -- confirm `y = 1` is indented in the real file.
write_to_path(
&case,
"a.py",
"
if flag:
y = 1
x = y
",
)?;

assert_public_type(&case, "a", "x", "(Unbound | Literal[1])")
}

// Checks public types around a full if/elif/else chain. Per the expected strings: `x`
// (read after the chain) is the union of the values assigned in the three clauses, `r` is
// `Literal[2]`, and `s` is `Literal[5]`.
#[test]
fn if_elif_else() -> anyhow::Result<()> {
let case = create_test()?;

// NOTE(review): the Python snippet below shows no indentation in this view; presumably
// `r = y`, `y = 5` and `s = y` all belong to the `else` clause -- confirm in the real file.
write_to_path(
&case,
"a.py",
"
y = 1
y = 2
if flag:
y = 3
elif flag2:
y = 4
else:
r = y
y = 5
s = y
x = y
",
)?;

assert_public_type(&case, "a", "x", "(Literal[3] | Literal[4] | Literal[5])")?;
assert_public_type(&case, "a", "r", "Literal[2]")?;
assert_public_type(&case, "a", "s", "Literal[5]")
}

// Checks the public type of `x` after an if/elif chain with no `else`: the expected display
// string keeps `Literal[2]` (the value assigned before the chain) in the union with the
// values assigned in the two clauses.
#[test]
fn if_elif() -> anyhow::Result<()> {
let case = create_test()?;

// NOTE(review): the Python snippet below shows no indentation under `if`/`elif` in this
// view -- confirm the clause bodies are indented in the real file.
write_to_path(
&case,
"a.py",
"
y = 1
y = 2
if flag:
y = 3
elif flag2:
y = 4
x = y
",
)?;

assert_public_type(&case, "a", "x", "(Literal[2] | Literal[3] | Literal[4])")
}
}

0 comments on commit d056d09

Please sign in to comment.