From b02d3f3fd91292f67ae47fef0eee3e06224403a6 Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Mon, 3 Jun 2024 17:27:06 -0600 Subject: [PATCH] [red-knot] infer_symbol_public_type infers union of all definitions (#11669) ## Summary Rename `infer_symbol_type` to `infer_symbol_public_type`, and allow it to work on symbols with more than one definition. For now, use the most cautious/sound inference, which is the union of all definitions. We can prune this union more in future by eliminating definitions if we can show that they can't be visible (this requires both that the symbol is definitely later reassigned, and that there is no intervening call/import that might be able to see the over-written definition). ## Test Plan Added a test showing inference of union from multiple definitions. --- crates/red_knot/src/lint.rs | 10 ++-- crates/red_knot/src/types.rs | 20 +++---- crates/red_knot/src/types/infer.rs | 83 ++++++++++++++++++++++++------ 3 files changed, 83 insertions(+), 30 deletions(-) diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index 7ca29f5f2d9c5..0801809f522c2 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -15,7 +15,7 @@ use crate::source::{source_text, Source}; use crate::symbols::{ resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable, }; -use crate::types::{infer_definition_type, infer_symbol_type, Type}; +use crate::types::{infer_definition_type, infer_symbol_public_type, Type}; #[tracing::instrument(level = "debug", skip(db))] pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult { @@ -104,14 +104,14 @@ fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> { for (symbol, definition) in context.symbols().all_definitions() { match definition { Definition::Import(import) => { - let ty = context.infer_symbol_type(symbol)?; + let ty = context.infer_symbol_public_type(symbol)?; if ty.is_unknown() { context.push_diagnostic(format!("Unresolved module {}", import.module)); } } Definition::ImportFrom(import) => { - let ty = context.infer_symbol_type(symbol)?; + let ty = context.infer_symbol_public_type(symbol)?; if ty.is_unknown() { let module_name = import.module().map(Deref::deref).unwrap_or_default(); @@ -217,8 +217,8 @@ impl<'a> SemanticLintContext<'a> { &self.symbols } - pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult { - infer_symbol_type( + pub fn infer_symbol_public_type(&self, symbol_id: SymbolId) -> QueryResult { + infer_symbol_public_type( self.db.upcast(), GlobalSymbolId { file_id: self.file_id, diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index f8b0201435555..8628a8549ce35 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap; pub(crate) mod infer; -pub(crate) use infer::{infer_definition_type, infer_symbol_type}; +pub(crate) use infer::{infer_definition_type, infer_symbol_public_type}; /// unique ID for a type #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] @@ -119,7 +119,7 @@ impl TypeStore { self.modules.remove(&file_id); } - pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) { + pub fn cache_symbol_public_type(&self, symbol: GlobalSymbolId, ty: Type) { self.add_or_get_module(symbol.file_id) .symbol_types .insert(symbol.symbol_id, ty); @@ -131,7 +131,7 @@ impl TypeStore { .insert(node_key, ty); } - pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option { + pub fn get_cached_symbol_public_type(&self, symbol: GlobalSymbolId) -> Option { self.try_get_module(symbol.file_id)? .symbol_types .get(&symbol.symbol_id) @@ -182,12 +182,12 @@ impl TypeStore { .add_class(name, scope_id, bases) } - fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId { + fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId { self.add_or_get_module(file_id).add_union(elems) } fn add_intersection( - &mut self, + &self, file_id: FileId, positive: &[Type], negative: &[Type], @@ -393,7 +393,7 @@ impl ModuleTypeId { fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? { - Ok(Some(infer_symbol_type(db, symbol_id)?)) + Ok(Some(infer_symbol_public_type(db, symbol_id)?)) } else { Ok(None) } @@ -441,7 +441,7 @@ impl ClassTypeId { let ClassType { scope_id, .. } = *self.class(db)?; let table = symbol_table(db, self.file_id)?; if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) { - Ok(Some(infer_symbol_type( + Ok(Some(infer_symbol_public_type( db, GlobalSymbolId { file_id: self.file_id, @@ -497,7 +497,7 @@ struct ModuleTypeStore { unions: IndexVec, /// arena of all intersection types created in this module intersections: IndexVec, - /// cached types of symbols in this module + /// cached public types of symbols in this module symbol_types: FxHashMap, /// cached types of AST nodes in this module node_types: FxHashMap, @@ -777,7 +777,7 @@ mod tests { #[test] fn add_union() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new()); @@ -794,7 +794,7 @@ mod tests { #[test] fn add_intersection() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new()); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index ff27f25c2dd09..8f032e61e0e29 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -15,22 +15,41 @@ use crate::types::{ModuleTypeId, Type}; use crate::{FileId, Name}; // FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. +/// Resolve the public-facing type for a symbol (the type seen by other scopes: other modules, or +/// nested functions). Because calls to nested functions and imports can occur anywhere in control +/// flow, this type must be conservative and consider all definitions of the symbol that could +/// possibly be seen by another scope. Currently we take the most conservative approach, which is +/// the union of all definitions. We may be able to narrow this in future to eliminate definitions +/// which can't possibly (or at least likely) be seen by any other scope, so that e.g. we could +/// infer `Literal["1"]` instead of `Literal[1] | Literal["1"]` for `x` in `x = x; x = str(x);`. #[tracing::instrument(level = "trace", skip(db))] -pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult { +pub fn infer_symbol_public_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult { let symbols = symbol_table(db, symbol.file_id)?; let defs = symbols.definitions(symbol.symbol_id); let jar: &SemanticJar = db.jar()?; - if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) { + if let Some(ty) = jar.type_store.get_cached_symbol_public_type(symbol) { return Ok(ty); } - // TODO handle multiple defs, conditional defs... - assert_eq!(defs.len(), 1); - - let ty = infer_definition_type(db, symbol, defs[0].clone())?; + let mut tys = defs + .iter() + .map(|def| infer_definition_type(db, symbol, def.clone())) + .peekable(); + let ty = if let Some(first) = tys.next() { + if tys.peek().is_some() { + Type::Union(jar.type_store.add_union( + symbol.file_id, + &Iterator::chain([first].into_iter(), tys).collect::>>()?, + )) + } else { + first? + } + } else { + Type::Unknown + }; - jar.type_store.cache_symbol_type(symbol, ty); + jar.type_store.cache_symbol_public_type(symbol, ty); // TODO record dependencies Ok(ty) @@ -65,7 +84,7 @@ pub fn infer_definition_type( assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? { - infer_symbol_type(db, remote_symbol) + infer_symbol_public_type(db, remote_symbol) } else { Ok(Type::Unknown) } @@ -158,7 +177,8 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu ast::Expr::Name(name) => { // TODO look up in the correct scope, don't assume global if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { - infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id }) + // TODO should use only reachable definitions, not public type + infer_symbol_public_type(db, GlobalSymbolId { file_id, symbol_id }) } else { Ok(Type::Unknown) } @@ -182,7 +202,7 @@ mod tests { resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind, }; use crate::symbols::{symbol_table, GlobalSymbolId}; - use crate::types::{infer_symbol_type, Type}; + use crate::types::{infer_symbol_public_type, Type}; use crate::Name; // TODO with virtual filesystem we shouldn't have to write files to disk for these @@ -228,7 +248,7 @@ mod tests { .root_symbol_id_by_name("E") .expect("E symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: a_file, @@ -259,7 +279,7 @@ mod tests { .root_symbol_id_by_name("Sub") .expect("Sub symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: file, @@ -300,7 +320,7 @@ mod tests { .root_symbol_id_by_name("C") .expect("C symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: file, @@ -345,7 +365,7 @@ mod tests { .root_symbol_id_by_name("D") .expect("D symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: a_file, @@ -375,7 +395,7 @@ mod tests { .root_symbol_id_by_name("x") .expect("x symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: file, @@ -388,4 +408,37 @@ mod tests { assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]"); Ok(()) } + + #[test] + fn resolve_union() -> anyhow::Result<()> { + let case = create_test()?; + let db = &case.db; + + let path = case.src.path().join("a.py"); + std::fs::write(path, "if flag:\n x = 1\nelse:\n x = 2")?; + let file = resolve_module(db, ModuleName::new("a"))? + .expect("module should be found") + .path(db)? + .file(); + let syms = symbol_table(db, file)?; + let x_sym = syms + .root_symbol_id_by_name("x") + .expect("x symbol should be found"); + + let ty = infer_symbol_public_type( + db, + GlobalSymbolId { + file_id: file, + symbol_id: x_sym, + }, + )?; + + let jar = HasJar::::jar(db)?; + assert!(matches!(ty, Type::Union(_))); + assert_eq!( + format!("{}", ty.display(&jar.type_store)), + "(Literal[1] | Literal[2])" + ); + Ok(()) + } }