From e302c2de7c8707f8db8838ed87c08f7f3fe7e12e Mon Sep 17 00:00:00 2001
From: Dhruv Manilawala
Date: Mon, 4 Nov 2024 17:11:57 +0530
Subject: [PATCH] Cached inference of all definitions in an unpacking (#13979)

## Summary

This PR adds a new salsa query and an ingredient to resolve all the variables
involved in an unpacking assignment like `(a, b) = (1, 2)` at once. Previously,
we'd recursively try to match the correct type for each definition individually,
which would result in duplicate diagnostics.

This PR still doesn't solve the duplicate diagnostics issue, because that
requires a different solution such as a salsa accumulator or manual
de-duplication of the diagnostics.

Related: #13773

## Test Plan

Make sure that all unpack assignment test cases pass and that there are no
panics in the corpus tests.

## Todo

- [x] Look at the performance regression

---
 crates/red_knot_python_semantic/src/lib.rs    |   1 +
 .../src/semantic_index/builder.rs             |  77 ++++----
 .../src/semantic_index/definition.rs          |  87 +++++----
 crates/red_knot_python_semantic/src/types.rs  |   1 +
 .../src/types/infer.rs                        | 164 ++++----------
 .../src/types/unpacker.rs                     | 143 +++++++++++++++
 crates/red_knot_python_semantic/src/unpack.rs |  43 +++++
 7 files changed, 314 insertions(+), 202 deletions(-)
 create mode 100644 crates/red_knot_python_semantic/src/types/unpacker.rs
 create mode 100644 crates/red_knot_python_semantic/src/unpack.rs

diff --git a/crates/red_knot_python_semantic/src/lib.rs b/crates/red_knot_python_semantic/src/lib.rs
index f4c6637e5a3e3..013dec6219a94 100644
--- a/crates/red_knot_python_semantic/src/lib.rs
+++ b/crates/red_knot_python_semantic/src/lib.rs
@@ -22,6 +22,7 @@ pub(crate) mod site_packages;
 mod stdlib;
 pub(crate) mod symbol;
 pub mod types;
+mod unpack;
 mod util;
 
 type FxOrderSet = ordermap::set::OrderSet>;
diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs
index 2f4bf2c651a6a..7bcb35fd5ef8c 100644
--- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs
+++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs
@@ -25,12 +25,13 @@ use crate::semantic_index::symbol::{
 };
 use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
 use crate::semantic_index::SemanticIndex;
+use crate::unpack::Unpack;
 use crate::Db;
 
 use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
 use super::definition::{
-    AssignmentKind, DefinitionCategory, ExceptHandlerDefinitionNodeRef,
-    MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
+    DefinitionCategory, ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef,
+    WithItemDefinitionNodeRef,
 };
 
 mod except_handlers;
@@ -46,6 +47,13 @@ pub(super) struct SemanticIndexBuilder<'db> {
     current_assignments: Vec>,
     /// The match case we're currently visiting.
     current_match_case: Option>,
+
+    /// The [`Unpack`] ingredient for the current definition that belongs to an unpacking
+    /// assignment. This is used to correctly map multiple definitions to the *same* unpacking.
+    /// For example, in `a, b = 1, 2`, both `a` and `b` create separate definitions, but they
+    /// both belong to the same unpacking.
+    current_unpack: Option>,
+
     /// Flow states at each `break` in the current loop.
loop_break_states: Vec, /// Per-scope contexts regarding nested `try`/`except` statements @@ -75,6 +83,7 @@ impl<'db> SemanticIndexBuilder<'db> { scope_stack: Vec::new(), current_assignments: vec![], current_match_case: None, + current_unpack: None, loop_break_states: vec![], try_node_context_stack_manager: TryNodeContextStackManager::default(), @@ -211,7 +220,7 @@ impl<'db> SemanticIndexBuilder<'db> { let definition_node: DefinitionNodeRef<'_> = definition_node.into(); #[allow(unsafe_code)] // SAFETY: `definition_node` is guaranteed to be a child of `self.module` - let kind = unsafe { definition_node.into_owned(self.module.clone()) }; + let kind = unsafe { definition_node.into_owned(self.module.clone(), self.current_unpack) }; let category = kind.category(); let definition = Definition::new( self.db, @@ -619,25 +628,43 @@ where } ast::Stmt::Assign(node) => { debug_assert_eq!(&self.current_assignments, &[]); + self.visit_expr(&node.value); - self.add_standalone_expression(&node.value); - for (target_index, target) in node.targets.iter().enumerate() { - let kind = match target { - ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(AssignmentKind::Sequence), - ast::Expr::Name(_) => Some(AssignmentKind::Name), - _ => None, + let value = self.add_standalone_expression(&node.value); + + for target in &node.targets { + // We only handle assignments to names and unpackings here, other targets like + // attribute and subscript are handled separately as they don't create a new + // definition. + let is_assignment_target = match target { + ast::Expr::List(_) | ast::Expr::Tuple(_) => { + self.current_unpack = Some(Unpack::new( + self.db, + self.file, + self.current_scope(), + #[allow(unsafe_code)] + unsafe { + AstNodeRef::new(self.module.clone(), target) + }, + value, + countme::Count::default(), + )); + true + } + ast::Expr::Name(_) => true, + _ => false, }; - if let Some(kind) = kind { - self.push_assignment(CurrentAssignment::Assign { - assignment: node, - target_index, - kind, - }); + + if is_assignment_target { + self.push_assignment(CurrentAssignment::Assign(node)); } + self.visit_expr(target); - if kind.is_some() { - // only need to pop in the case where we pushed something + + if is_assignment_target { + // Only need to pop in the case where we pushed something self.pop_assignment(); + self.current_unpack = None; } } } @@ -971,18 +998,12 @@ where if is_definition { match self.current_assignment().copied() { - Some(CurrentAssignment::Assign { - assignment, - target_index, - kind, - }) => { + Some(CurrentAssignment::Assign(assign)) => { self.add_definition( symbol, AssignmentDefinitionNodeRef { - assignment, - target_index, + value: &assign.value, name: name_node, - kind, }, ); } @@ -1228,11 +1249,7 @@ where #[derive(Copy, Clone, Debug, PartialEq)] enum CurrentAssignment<'a> { - Assign { - assignment: &'a ast::StmtAssign, - target_index: usize, - kind: AssignmentKind, - }, + Assign(&'a ast::StmtAssign), AnnAssign(&'a ast::StmtAnnAssign), AugAssign(&'a ast::StmtAugAssign), For(&'a ast::StmtFor), diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index b2ac7acee052c..c723edff7d000 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -6,6 +6,7 @@ use crate::ast_node_ref::AstNodeRef; use crate::module_resolver::file_to_module; use crate::node_key::NodeKey; use crate::semantic_index::symbol::{FileScopeId, 
ScopeId, ScopedSymbolId}; +use crate::unpack::Unpack; use crate::Db; #[salsa::tracked] @@ -24,7 +25,7 @@ pub struct Definition<'db> { #[no_eq] #[return_ref] - pub(crate) kind: DefinitionKind, + pub(crate) kind: DefinitionKind<'db>, #[no_eq] count: countme::Count>, @@ -166,10 +167,8 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> { #[derive(Copy, Clone, Debug)] pub(crate) struct AssignmentDefinitionNodeRef<'a> { - pub(crate) assignment: &'a ast::StmtAssign, - pub(crate) target_index: usize, + pub(crate) value: &'a ast::Expr, pub(crate) name: &'a ast::ExprName, - pub(crate) kind: AssignmentKind, } #[derive(Copy, Clone, Debug)] @@ -213,7 +212,11 @@ pub(crate) struct MatchPatternDefinitionNodeRef<'a> { impl DefinitionNodeRef<'_> { #[allow(unsafe_code)] - pub(super) unsafe fn into_owned(self, parsed: ParsedModule) -> DefinitionKind { + pub(super) unsafe fn into_owned( + self, + parsed: ParsedModule, + unpack: Option>, + ) -> DefinitionKind<'_> { match self { DefinitionNodeRef::Import(alias) => { DefinitionKind::Import(AstNodeRef::new(parsed, alias)) @@ -233,17 +236,13 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::NamedExpression(named) => { DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named)) } - DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { - assignment, - target_index, - name, - kind, - }) => DefinitionKind::Assignment(AssignmentDefinitionKind { - assignment: AstNodeRef::new(parsed.clone(), assignment), - target_index, - name: AstNodeRef::new(parsed, name), - kind, - }), + DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { value, name }) => { + DefinitionKind::Assignment(AssignmentDefinitionKind { + target: TargetKind::from(unpack), + value: AstNodeRef::new(parsed.clone(), value), + name: AstNodeRef::new(parsed, name), + }) + } DefinitionNodeRef::AnnotatedAssignment(assign) => { DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) } @@ -315,12 +314,7 @@ impl DefinitionNodeRef<'_> { Self::Function(node) => node.into(), Self::Class(node) => node.into(), Self::NamedExpression(node) => node.into(), - Self::Assignment(AssignmentDefinitionNodeRef { - assignment: _, - target_index: _, - name, - kind: _, - }) => name.into(), + Self::Assignment(AssignmentDefinitionNodeRef { value: _, name }) => name.into(), Self::AnnotatedAssignment(node) => node.into(), Self::AugmentedAssignment(node) => node.into(), Self::For(ForStmtDefinitionNodeRef { @@ -382,13 +376,13 @@ impl DefinitionCategory { } #[derive(Clone, Debug)] -pub enum DefinitionKind { +pub enum DefinitionKind<'db> { Import(AstNodeRef), ImportFrom(ImportFromDefinitionKind), Function(AstNodeRef), Class(AstNodeRef), NamedExpression(AstNodeRef), - Assignment(AssignmentDefinitionKind), + Assignment(AssignmentDefinitionKind<'db>), AnnotatedAssignment(AstNodeRef), AugmentedAssignment(AstNodeRef), For(ForStmtDefinitionKind), @@ -400,7 +394,7 @@ pub enum DefinitionKind { ExceptHandler(ExceptHandlerDefinitionKind), } -impl DefinitionKind { +impl DefinitionKind<'_> { pub(crate) fn category(&self) -> DefinitionCategory { match self { // functions, classes, and imports always bind, and we consider them declarations @@ -445,6 +439,21 @@ impl DefinitionKind { } } +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum TargetKind<'db> { + Sequence(Unpack<'db>), + Name, +} + +impl<'db> From>> for TargetKind<'db> { + fn from(value: Option>) -> Self { + match value { + Some(unpack) => TargetKind::Sequence(unpack), + None => TargetKind::Name, + } + } +} + #[derive(Clone, Debug)] #[allow(dead_code)] pub 
struct MatchPatternDefinitionKind { @@ -506,36 +515,24 @@ impl ImportFromDefinitionKind { } #[derive(Clone, Debug)] -pub struct AssignmentDefinitionKind { - assignment: AstNodeRef, - target_index: usize, +pub struct AssignmentDefinitionKind<'db> { + target: TargetKind<'db>, + value: AstNodeRef, name: AstNodeRef, - kind: AssignmentKind, } -impl AssignmentDefinitionKind { - pub(crate) fn value(&self) -> &ast::Expr { - &self.assignment.node().value +impl AssignmentDefinitionKind<'_> { + pub(crate) fn target(&self) -> TargetKind { + self.target } - pub(crate) fn target(&self) -> &ast::Expr { - &self.assignment.node().targets[self.target_index] + pub(crate) fn value(&self) -> &ast::Expr { + self.value.node() } pub(crate) fn name(&self) -> &ast::ExprName { self.name.node() } - - pub(crate) fn kind(&self) -> AssignmentKind { - self.kind - } -} - -/// The kind of assignment target expression. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum AssignmentKind { - Sequence, - Name, } #[derive(Clone, Debug)] diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index f064f22c602ab..75eb6688f5d4c 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -27,6 +27,7 @@ mod diagnostic; mod display; mod infer; mod narrow; +mod unpacker; pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics { let _span = tracing::trace_span!("check_types", file=?file.path(db)).entered(); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index c16e1c1a11df9..c07472c5bbe85 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -26,7 +26,6 @@ //! stringified annotations. We have a fourth Salsa query for inferring the deferred types //! associated with a particular definition. Scope-level inference infers deferred types for all //! definitions once the rest of the types in the scope have been inferred. -use std::borrow::Cow; use std::num::NonZeroU32; use itertools::Itertools; @@ -42,7 +41,7 @@ use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; use crate::semantic_index::definition::{ - AssignmentKind, Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, + Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, TargetKind, }; use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; @@ -52,12 +51,14 @@ use crate::stdlib::builtins_module_scope; use crate::types::diagnostic::{ TypeCheckDiagnostic, TypeCheckDiagnostics, TypeCheckDiagnosticsBuilder, }; +use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ bindings_ty, builtins_symbol, declarations_ty, global_symbol, symbol, typing_extensions_symbol, Boundness, BytesLiteralType, ClassType, FunctionType, IterationOutcome, KnownClass, KnownFunction, SliceLiteralType, StringLiteralType, Symbol, Truthiness, TupleType, Type, TypeArrayDisplay, UnionBuilder, UnionType, }; +use crate::unpack::Unpack; use crate::util::subscript::{PyIndex, PySlice}; use crate::Db; @@ -161,6 +162,30 @@ pub(crate) fn infer_expression_types<'db>( TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index).finish() } +/// Infer the types for an [`Unpack`] operation. 
+/// +/// This infers the expression type and performs structural match against the target expression +/// involved in an unpacking operation. It returns a result-like object that can be used to get the +/// type of the variables involved in this unpacking along with any violations that are detected +/// during this unpacking. +#[salsa::tracked(return_ref)] +fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { + let file = unpack.file(db); + let _span = + tracing::trace_span!("infer_unpack_types", unpack=?unpack.as_id(), file=%file.path(db)) + .entered(); + + let value = unpack.value(db); + let scope = unpack.scope(db); + + let result = infer_expression_types(db, value); + let value_ty = result.expression_ty(value.node_ref(db).scoped_ast_id(db, scope)); + + let mut unpacker = Unpacker::new(db, file); + unpacker.unpack(unpack.target(db), value_ty, scope); + unpacker.finish() +} + /// A region within which we can infer types. pub(crate) enum InferenceRegion<'db> { /// infer types for a standalone [`Expression`] @@ -443,7 +468,6 @@ impl<'db> TypeInferenceBuilder<'db> { assignment.target(), assignment.value(), assignment.name(), - assignment.kind(), definition, ); } @@ -1321,10 +1345,9 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_assignment_definition( &mut self, - target: &ast::Expr, + target: TargetKind<'db>, value: &ast::Expr, name: &ast::ExprName, - kind: AssignmentKind, definition: Definition<'db>, ) { let expression = self.index.expression(value); @@ -1332,132 +1355,19 @@ impl<'db> TypeInferenceBuilder<'db> { self.extend(result); let value_ty = self.expression_ty(value); + let name_ast_id = name.scoped_ast_id(self.db, self.scope()); - let target_ty = match kind { - AssignmentKind::Sequence => self.infer_sequence_unpacking(target, value_ty, name), - AssignmentKind::Name => value_ty, + let target_ty = match target { + TargetKind::Sequence(unpack) => { + let unpacked = infer_unpack_types(self.db, unpack); + self.diagnostics.extend(unpacked.diagnostics()); + unpacked.get(name_ast_id).unwrap_or(Type::Unknown) + } + TargetKind::Name => value_ty, }; self.add_binding(name.into(), definition, target_ty); - self.types - .expressions - .insert(name.scoped_ast_id(self.db, self.scope()), target_ty); - } - - fn infer_sequence_unpacking( - &mut self, - target: &ast::Expr, - value_ty: Type<'db>, - name: &ast::ExprName, - ) -> Type<'db> { - // The inner function is recursive and only differs in the return type which is an `Option` - // where if the variable is found, the corresponding type is returned otherwise `None`. - fn inner<'db>( - builder: &mut TypeInferenceBuilder<'db>, - target: &ast::Expr, - value_ty: Type<'db>, - name: &ast::ExprName, - ) -> Option> { - match target { - ast::Expr::Name(target_name) if target_name == name => { - return Some(value_ty); - } - ast::Expr::Starred(ast::ExprStarred { value, .. }) => { - return inner(builder, value, value_ty, name); - } - ast::Expr::List(ast::ExprList { elts, .. }) - | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty { - Type::Tuple(tuple_ty) => { - let starred_index = elts.iter().position(ast::Expr::is_starred_expr); - - let element_types = if let Some(starred_index) = starred_index { - if tuple_ty.len(builder.db) >= elts.len() - 1 { - let mut element_types = Vec::with_capacity(elts.len()); - element_types.extend_from_slice( - // SAFETY: Safe because of the length check above. 
- &tuple_ty.elements(builder.db)[..starred_index], - ); - - // E.g., in `(a, *b, c, d) = ...`, the index of starred element `b` - // is 1 and the remaining elements after that are 2. - let remaining = elts.len() - (starred_index + 1); - // This index represents the type of the last element that belongs - // to the starred expression, in an exclusive manner. - let starred_end_index = tuple_ty.len(builder.db) - remaining; - // SAFETY: Safe because of the length check above. - let _starred_element_types = &tuple_ty.elements(builder.db) - [starred_index..starred_end_index]; - // TODO: Combine the types into a list type. If the - // starred_element_types is empty, then it should be `List[Any]`. - // combine_types(starred_element_types); - element_types.push(Type::Todo); - - element_types.extend_from_slice( - // SAFETY: Safe because of the length check above. - &tuple_ty.elements(builder.db)[starred_end_index..], - ); - Cow::Owned(element_types) - } else { - let mut element_types = tuple_ty.elements(builder.db).to_vec(); - // Subtract 1 to insert the starred expression type at the correct - // index. - element_types.resize(elts.len() - 1, Type::Unknown); - // TODO: This should be `list[Unknown]` - element_types.insert(starred_index, Type::Todo); - Cow::Owned(element_types) - } - } else { - Cow::Borrowed(tuple_ty.elements(builder.db).as_ref()) - }; - - for (index, element) in elts.iter().enumerate() { - if let Some(ty) = inner( - builder, - element, - element_types.get(index).copied().unwrap_or(Type::Unknown), - name, - ) { - return Some(ty); - } - } - } - Type::StringLiteral(string_literal_ty) => { - // Deconstruct the string literal to delegate the inference back to the - // tuple type for correct handling of starred expressions. We could go - // further and deconstruct to an array of `StringLiteral` with each - // individual character, instead of just an array of `LiteralString`, but - // there would be a cost and it's not clear that it's worth it. - let value_ty = Type::Tuple(TupleType::new( - builder.db, - vec![Type::LiteralString; string_literal_ty.len(builder.db)] - .into_boxed_slice(), - )); - if let Some(ty) = inner(builder, target, value_ty, name) { - return Some(ty); - } - } - _ => { - let value_ty = if value_ty.is_literal_string() { - Type::LiteralString - } else { - value_ty.iterate(builder.db).unwrap_with_diagnostic( - AnyNodeRef::from(target), - &mut builder.diagnostics, - ) - }; - for element in elts { - if let Some(ty) = inner(builder, element, value_ty, name) { - return Some(ty); - } - } - } - }, - _ => {} - } - None - } - - inner(self, target, value_ty, name).unwrap_or(Type::Unknown) + self.types.expressions.insert(name_ast_id, target_ty); } fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { diff --git a/crates/red_knot_python_semantic/src/types/unpacker.rs b/crates/red_knot_python_semantic/src/types/unpacker.rs new file mode 100644 index 0000000000000..9dd9466c05943 --- /dev/null +++ b/crates/red_knot_python_semantic/src/types/unpacker.rs @@ -0,0 +1,143 @@ +use std::borrow::Cow; + +use ruff_db::files::File; +use ruff_python_ast::{self as ast, AnyNodeRef}; +use rustc_hash::FxHashMap; + +use crate::semantic_index::ast_ids::{HasScopedAstId, ScopedExpressionId}; +use crate::semantic_index::symbol::ScopeId; +use crate::types::{TupleType, Type, TypeCheckDiagnostics, TypeCheckDiagnosticsBuilder}; +use crate::Db; + +/// Unpacks the value expression type to their respective targets. 
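+///
+/// For example, in `(a, (b, c)) = (1, (2, 3))`, the types recorded for the targets `a`, `b`
+/// and `c` would be `Literal[1]`, `Literal[2]` and `Literal[3]` respectively.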
+pub(crate) struct Unpacker<'db> { + db: &'db dyn Db, + targets: FxHashMap>, + diagnostics: TypeCheckDiagnosticsBuilder<'db>, +} + +impl<'db> Unpacker<'db> { + pub(crate) fn new(db: &'db dyn Db, file: File) -> Self { + Self { + db, + targets: FxHashMap::default(), + diagnostics: TypeCheckDiagnosticsBuilder::new(db, file), + } + } + + pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>, scope: ScopeId<'db>) { + match target { + ast::Expr::Name(target_name) => { + self.targets + .insert(target_name.scoped_ast_id(self.db, scope), value_ty); + } + ast::Expr::Starred(ast::ExprStarred { value, .. }) => { + self.unpack(value, value_ty, scope); + } + ast::Expr::List(ast::ExprList { elts, .. }) + | ast::Expr::Tuple(ast::ExprTuple { elts, .. }) => match value_ty { + Type::Tuple(tuple_ty) => { + let starred_index = elts.iter().position(ast::Expr::is_starred_expr); + + let element_types = if let Some(starred_index) = starred_index { + if tuple_ty.len(self.db) >= elts.len() - 1 { + let mut element_types = Vec::with_capacity(elts.len()); + element_types.extend_from_slice( + // SAFETY: Safe because of the length check above. + &tuple_ty.elements(self.db)[..starred_index], + ); + + // E.g., in `(a, *b, c, d) = ...`, the index of starred element `b` + // is 1 and the remaining elements after that are 2. + let remaining = elts.len() - (starred_index + 1); + // This index represents the type of the last element that belongs + // to the starred expression, in an exclusive manner. + let starred_end_index = tuple_ty.len(self.db) - remaining; + // SAFETY: Safe because of the length check above. + let _starred_element_types = + &tuple_ty.elements(self.db)[starred_index..starred_end_index]; + // TODO: Combine the types into a list type. If the + // starred_element_types is empty, then it should be `List[Any]`. + // combine_types(starred_element_types); + element_types.push(Type::Todo); + + element_types.extend_from_slice( + // SAFETY: Safe because of the length check above. + &tuple_ty.elements(self.db)[starred_end_index..], + ); + Cow::Owned(element_types) + } else { + let mut element_types = tuple_ty.elements(self.db).to_vec(); + // Subtract 1 to insert the starred expression type at the correct + // index. + element_types.resize(elts.len() - 1, Type::Unknown); + // TODO: This should be `list[Unknown]` + element_types.insert(starred_index, Type::Todo); + Cow::Owned(element_types) + } + } else { + Cow::Borrowed(tuple_ty.elements(self.db).as_ref()) + }; + + for (index, element) in elts.iter().enumerate() { + self.unpack( + element, + element_types.get(index).copied().unwrap_or(Type::Unknown), + scope, + ); + } + } + Type::StringLiteral(string_literal_ty) => { + // Deconstruct the string literal to delegate the inference back to the + // tuple type for correct handling of starred expressions. We could go + // further and deconstruct to an array of `StringLiteral` with each + // individual character, instead of just an array of `LiteralString`, but + // there would be a cost and it's not clear that it's worth it. 
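+                    // For example, `a, b, c = "abc"` is unpacked as if the value were the
+                    // tuple `(LiteralString, LiteralString, LiteralString)`.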
+                    let value_ty = Type::Tuple(TupleType::new(
+                        self.db,
+                        vec![Type::LiteralString; string_literal_ty.len(self.db)]
+                            .into_boxed_slice(),
+                    ));
+                    self.unpack(target, value_ty, scope);
+                }
+                _ => {
+                    let value_ty = if value_ty.is_literal_string() {
+                        Type::LiteralString
+                    } else {
+                        value_ty
+                            .iterate(self.db)
+                            .unwrap_with_diagnostic(AnyNodeRef::from(target), &mut self.diagnostics)
+                    };
+                    for element in elts {
+                        self.unpack(element, value_ty, scope);
+                    }
+                }
+            },
+            _ => {}
+        }
+    }
+
+    pub(crate) fn finish(mut self) -> UnpackResult<'db> {
+        self.targets.shrink_to_fit();
+        UnpackResult {
+            diagnostics: self.diagnostics.finish(),
+            targets: self.targets,
+        }
+    }
+}
+
+#[derive(Debug, Default, PartialEq, Eq)]
+pub(crate) struct UnpackResult<'db> {
+    targets: FxHashMap>,
+    diagnostics: TypeCheckDiagnostics,
+}
+
+impl<'db> UnpackResult<'db> {
+    pub(crate) fn get(&self, expr_id: ScopedExpressionId) -> Option> {
+        self.targets.get(&expr_id).copied()
+    }
+
+    pub(crate) fn diagnostics(&self) -> &TypeCheckDiagnostics {
+        &self.diagnostics
+    }
+}
diff --git a/crates/red_knot_python_semantic/src/unpack.rs b/crates/red_knot_python_semantic/src/unpack.rs
new file mode 100644
index 0000000000000..13d8d164a3adc
--- /dev/null
+++ b/crates/red_knot_python_semantic/src/unpack.rs
@@ -0,0 +1,43 @@
+use ruff_db::files::File;
+use ruff_python_ast::{self as ast};
+
+use crate::ast_node_ref::AstNodeRef;
+use crate::semantic_index::expression::Expression;
+use crate::semantic_index::symbol::{FileScopeId, ScopeId};
+use crate::Db;
+
+/// This ingredient represents a single unpacking.
+///
+/// It is required so that salsa can cache the complete unpacking of all the variables involved
+/// in a single assignment. It allows us to:
+/// 1. Avoid repeating the structural match for every definition
+/// 2. Avoid reporting the same diagnostic multiple times
+#[salsa::tracked]
+pub(crate) struct Unpack<'db> {
+    #[id]
+    pub(crate) file: File,
+
+    #[id]
+    pub(crate) file_scope: FileScopeId,
+
+    /// The target expression that is being unpacked. For example, in `(a, b) = (1, 2)`, the
+    /// target expression is `(a, b)`.
+    #[no_eq]
+    #[return_ref]
+    pub(crate) target: AstNodeRef,
+
+    /// The ingredient representing the value expression of the unpacking. For example, in
+    /// `(a, b) = (1, 2)`, the value expression is `(1, 2)`.
+    #[no_eq]
+    pub(crate) value: Expression<'db>,
+
+    #[no_eq]
+    count: countme::Count>,
+}
+
+impl<'db> Unpack<'db> {
+    /// Returns the scope where the unpacking is happening.
+    pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
+        self.file_scope(db).to_scope_id(db, self.file(db))
+    }
+}
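For illustration only (not part of the patch), a minimal sketch of the behaviour the new `infer_unpack_types` query caches; the revealed types are an assumption based on the `Unpacker` logic above, and starred targets still infer to a placeholder todo type:

```python
# A single unpacking with three definitions: `a`, `b` and `c` all resolve
# against the same `Unpack` ingredient, so the structural match against the
# tuple type runs once per unpacking instead of once per target.
(a, *b, c) = (1, 2, 3, 4)

reveal_type(a)  # Literal[1]
reveal_type(b)  # @Todo -- to become a list type once starred targets are fully supported
reveal_type(c)  # Literal[4]
```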