Skip to content

Commit

Permalink
Cached inference of all definitions in an unpacking (#13979)
Browse files Browse the repository at this point in the history
## Summary

This PR adds a new salsa query and an ingredient to resolve all the
variables involved in an unpacking assignment like `(a, b) = (1, 2)` at
once. Previously, we'd recursively try to match the correct type for
each definition individually which will result in creating duplicate
diagnostics.

This PR still doesn't solve the duplicate diagnostics issue because that
requires a different solution like using salsa accumulator or
de-duplicating the diagnostics manually.

Related: #13773 

## Test Plan

Make sure that all unpack assignment test cases pass, there are no
panics in the corpus tests.

## Todo

- [x] Look at the performance regression
  • Loading branch information
dhruvmanila authored Nov 4, 2024
1 parent 012f385 commit e302c2d
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 202 deletions.
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub(crate) mod site_packages;
mod stdlib;
pub(crate) mod symbol;
pub mod types;
mod unpack;
mod util;

type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;
77 changes: 47 additions & 30 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ use crate::semantic_index::symbol::{
};
use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex;
use crate::unpack::Unpack;
use crate::Db;

use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
use super::definition::{
AssignmentKind, DefinitionCategory, ExceptHandlerDefinitionNodeRef,
MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
DefinitionCategory, ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef,
WithItemDefinitionNodeRef,
};

mod except_handlers;
Expand All @@ -46,6 +47,13 @@ pub(super) struct SemanticIndexBuilder<'db> {
current_assignments: Vec<CurrentAssignment<'db>>,
/// The match case we're currently visiting.
current_match_case: Option<CurrentMatchCase<'db>>,

/// The [`Unpack`] ingredient for the current definition that belongs to an unpacking
/// assignment. This is used to correctly map multiple definitions to the *same* unpacking.
/// For example, in `a, b = 1, 2`, both `a` and `b` creates separate definitions but they both
/// belong to the same unpacking.
current_unpack: Option<Unpack<'db>>,

/// Flow states at each `break` in the current loop.
loop_break_states: Vec<FlowSnapshot>,
/// Per-scope contexts regarding nested `try`/`except` statements
Expand Down Expand Up @@ -75,6 +83,7 @@ impl<'db> SemanticIndexBuilder<'db> {
scope_stack: Vec::new(),
current_assignments: vec![],
current_match_case: None,
current_unpack: None,
loop_break_states: vec![],
try_node_context_stack_manager: TryNodeContextStackManager::default(),

Expand Down Expand Up @@ -211,7 +220,7 @@ impl<'db> SemanticIndexBuilder<'db> {
let definition_node: DefinitionNodeRef<'_> = definition_node.into();
#[allow(unsafe_code)]
// SAFETY: `definition_node` is guaranteed to be a child of `self.module`
let kind = unsafe { definition_node.into_owned(self.module.clone()) };
let kind = unsafe { definition_node.into_owned(self.module.clone(), self.current_unpack) };
let category = kind.category();
let definition = Definition::new(
self.db,
Expand Down Expand Up @@ -619,25 +628,43 @@ where
}
ast::Stmt::Assign(node) => {
debug_assert_eq!(&self.current_assignments, &[]);

self.visit_expr(&node.value);
self.add_standalone_expression(&node.value);
for (target_index, target) in node.targets.iter().enumerate() {
let kind = match target {
ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(AssignmentKind::Sequence),
ast::Expr::Name(_) => Some(AssignmentKind::Name),
_ => None,
let value = self.add_standalone_expression(&node.value);

for target in &node.targets {
// We only handle assignments to names and unpackings here, other targets like
// attribute and subscript are handled separately as they don't create a new
// definition.
let is_assignment_target = match target {
ast::Expr::List(_) | ast::Expr::Tuple(_) => {
self.current_unpack = Some(Unpack::new(
self.db,
self.file,
self.current_scope(),
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), target)
},
value,
countme::Count::default(),
));
true
}
ast::Expr::Name(_) => true,
_ => false,
};
if let Some(kind) = kind {
self.push_assignment(CurrentAssignment::Assign {
assignment: node,
target_index,
kind,
});

if is_assignment_target {
self.push_assignment(CurrentAssignment::Assign(node));
}

self.visit_expr(target);
if kind.is_some() {
// only need to pop in the case where we pushed something

if is_assignment_target {
// Only need to pop in the case where we pushed something
self.pop_assignment();
self.current_unpack = None;
}
}
}
Expand Down Expand Up @@ -971,18 +998,12 @@ where

if is_definition {
match self.current_assignment().copied() {
Some(CurrentAssignment::Assign {
assignment,
target_index,
kind,
}) => {
Some(CurrentAssignment::Assign(assign)) => {
self.add_definition(
symbol,
AssignmentDefinitionNodeRef {
assignment,
target_index,
value: &assign.value,
name: name_node,
kind,
},
);
}
Expand Down Expand Up @@ -1228,11 +1249,7 @@ where

#[derive(Copy, Clone, Debug, PartialEq)]
enum CurrentAssignment<'a> {
Assign {
assignment: &'a ast::StmtAssign,
target_index: usize,
kind: AssignmentKind,
},
Assign(&'a ast::StmtAssign),
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
Expand Down
87 changes: 42 additions & 45 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::ast_node_ref::AstNodeRef;
use crate::module_resolver::file_to_module;
use crate::node_key::NodeKey;
use crate::semantic_index::symbol::{FileScopeId, ScopeId, ScopedSymbolId};
use crate::unpack::Unpack;
use crate::Db;

#[salsa::tracked]
Expand All @@ -24,7 +25,7 @@ pub struct Definition<'db> {

#[no_eq]
#[return_ref]
pub(crate) kind: DefinitionKind,
pub(crate) kind: DefinitionKind<'db>,

#[no_eq]
count: countme::Count<Definition<'static>>,
Expand Down Expand Up @@ -166,10 +167,8 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> {

#[derive(Copy, Clone, Debug)]
pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) assignment: &'a ast::StmtAssign,
pub(crate) target_index: usize,
pub(crate) value: &'a ast::Expr,
pub(crate) name: &'a ast::ExprName,
pub(crate) kind: AssignmentKind,
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -213,7 +212,11 @@ pub(crate) struct MatchPatternDefinitionNodeRef<'a> {

impl DefinitionNodeRef<'_> {
#[allow(unsafe_code)]
pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind {
pub(super) unsafe fn into_owned(
self,
parsed: ParsedModule,
unpack: Option<Unpack<'_>>,
) -> DefinitionKind<'_> {
match self {
DefinitionNodeRef::Import(alias) => {
DefinitionKind::Import(AstNodeRef::new(parsed, alias))
Expand All @@ -233,17 +236,13 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef {
assignment,
target_index,
name,
kind,
}) => DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target_index,
name: AstNodeRef::new(parsed, name),
kind,
}),
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { value, name }) => {
DefinitionKind::Assignment(AssignmentDefinitionKind {
target: TargetKind::from(unpack),
value: AstNodeRef::new(parsed.clone(), value),
name: AstNodeRef::new(parsed, name),
})
}
DefinitionNodeRef::AnnotatedAssignment(assign) => {
DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign))
}
Expand Down Expand Up @@ -315,12 +314,7 @@ impl DefinitionNodeRef<'_> {
Self::Function(node) => node.into(),
Self::Class(node) => node.into(),
Self::NamedExpression(node) => node.into(),
Self::Assignment(AssignmentDefinitionNodeRef {
assignment: _,
target_index: _,
name,
kind: _,
}) => name.into(),
Self::Assignment(AssignmentDefinitionNodeRef { value: _, name }) => name.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Self::For(ForStmtDefinitionNodeRef {
Expand Down Expand Up @@ -382,13 +376,13 @@ impl DefinitionCategory {
}

#[derive(Clone, Debug)]
pub enum DefinitionKind {
pub enum DefinitionKind<'db> {
Import(AstNodeRef<ast::Alias>),
ImportFrom(ImportFromDefinitionKind),
Function(AstNodeRef<ast::StmtFunctionDef>),
Class(AstNodeRef<ast::StmtClassDef>),
NamedExpression(AstNodeRef<ast::ExprNamed>),
Assignment(AssignmentDefinitionKind),
Assignment(AssignmentDefinitionKind<'db>),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind),
Expand All @@ -400,7 +394,7 @@ pub enum DefinitionKind {
ExceptHandler(ExceptHandlerDefinitionKind),
}

impl DefinitionKind {
impl DefinitionKind<'_> {
pub(crate) fn category(&self) -> DefinitionCategory {
match self {
// functions, classes, and imports always bind, and we consider them declarations
Expand Down Expand Up @@ -445,6 +439,21 @@ impl DefinitionKind {
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
pub(crate) enum TargetKind<'db> {
Sequence(Unpack<'db>),
Name,
}

impl<'db> From<Option<Unpack<'db>>> for TargetKind<'db> {
fn from(value: Option<Unpack<'db>>) -> Self {
match value {
Some(unpack) => TargetKind::Sequence(unpack),
None => TargetKind::Name,
}
}
}

#[derive(Clone, Debug)]
#[allow(dead_code)]
pub struct MatchPatternDefinitionKind {
Expand Down Expand Up @@ -506,36 +515,24 @@ impl ImportFromDefinitionKind {
}

#[derive(Clone, Debug)]
pub struct AssignmentDefinitionKind {
assignment: AstNodeRef<ast::StmtAssign>,
target_index: usize,
pub struct AssignmentDefinitionKind<'db> {
target: TargetKind<'db>,
value: AstNodeRef<ast::Expr>,
name: AstNodeRef<ast::ExprName>,
kind: AssignmentKind,
}

impl AssignmentDefinitionKind {
pub(crate) fn value(&self) -> &ast::Expr {
&self.assignment.node().value
impl AssignmentDefinitionKind<'_> {
pub(crate) fn target(&self) -> TargetKind {
self.target
}

pub(crate) fn target(&self) -> &ast::Expr {
&self.assignment.node().targets[self.target_index]
pub(crate) fn value(&self) -> &ast::Expr {
self.value.node()
}

pub(crate) fn name(&self) -> &ast::ExprName {
self.name.node()
}

pub(crate) fn kind(&self) -> AssignmentKind {
self.kind
}
}

/// The kind of assignment target expression.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AssignmentKind {
Sequence,
Name,
}

#[derive(Clone, Debug)]
Expand Down
1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod diagnostic;
mod display;
mod infer;
mod narrow;
mod unpacker;

pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics {
let _span = tracing::trace_span!("check_types", file=?file.path(db)).entered();
Expand Down
Loading

0 comments on commit e302c2d

Please sign in to comment.