diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index 7ca29f5f2d9c5..0801809f522c2 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -15,7 +15,7 @@ use crate::source::{source_text, Source}; use crate::symbols::{ resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable, }; -use crate::types::{infer_definition_type, infer_symbol_type, Type}; +use crate::types::{infer_definition_type, infer_symbol_public_type, Type}; #[tracing::instrument(level = "debug", skip(db))] pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult { @@ -104,14 +104,14 @@ fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> { for (symbol, definition) in context.symbols().all_definitions() { match definition { Definition::Import(import) => { - let ty = context.infer_symbol_type(symbol)?; + let ty = context.infer_symbol_public_type(symbol)?; if ty.is_unknown() { context.push_diagnostic(format!("Unresolved module {}", import.module)); } } Definition::ImportFrom(import) => { - let ty = context.infer_symbol_type(symbol)?; + let ty = context.infer_symbol_public_type(symbol)?; if ty.is_unknown() { let module_name = import.module().map(Deref::deref).unwrap_or_default(); @@ -217,8 +217,8 @@ impl<'a> SemanticLintContext<'a> { &self.symbols } - pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult { - infer_symbol_type( + pub fn infer_symbol_public_type(&self, symbol_id: SymbolId) -> QueryResult { + infer_symbol_public_type( self.db.upcast(), GlobalSymbolId { file_id: self.file_id, diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index f8b0201435555..8628a8549ce35 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap; pub(crate) mod infer; -pub(crate) use infer::{infer_definition_type, infer_symbol_type}; +pub(crate) use infer::{infer_definition_type, infer_symbol_public_type}; /// unique ID for a type #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] @@ -119,7 +119,7 @@ impl TypeStore { self.modules.remove(&file_id); } - pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) { + pub fn cache_symbol_public_type(&self, symbol: GlobalSymbolId, ty: Type) { self.add_or_get_module(symbol.file_id) .symbol_types .insert(symbol.symbol_id, ty); @@ -131,7 +131,7 @@ impl TypeStore { .insert(node_key, ty); } - pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option { + pub fn get_cached_symbol_public_type(&self, symbol: GlobalSymbolId) -> Option { self.try_get_module(symbol.file_id)? .symbol_types .get(&symbol.symbol_id) @@ -182,12 +182,12 @@ impl TypeStore { .add_class(name, scope_id, bases) } - fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId { + fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId { self.add_or_get_module(file_id).add_union(elems) } fn add_intersection( - &mut self, + &self, file_id: FileId, positive: &[Type], negative: &[Type], @@ -393,7 +393,7 @@ impl ModuleTypeId { fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? { - Ok(Some(infer_symbol_type(db, symbol_id)?)) + Ok(Some(infer_symbol_public_type(db, symbol_id)?)) } else { Ok(None) } @@ -441,7 +441,7 @@ impl ClassTypeId { let ClassType { scope_id, .. } = *self.class(db)?; let table = symbol_table(db, self.file_id)?; if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) { - Ok(Some(infer_symbol_type( + Ok(Some(infer_symbol_public_type( db, GlobalSymbolId { file_id: self.file_id, @@ -497,7 +497,7 @@ struct ModuleTypeStore { unions: IndexVec, /// arena of all intersection types created in this module intersections: IndexVec, - /// cached types of symbols in this module + /// cached public types of symbols in this module symbol_types: FxHashMap, /// cached types of AST nodes in this module node_types: FxHashMap, @@ -777,7 +777,7 @@ mod tests { #[test] fn add_union() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new()); @@ -794,7 +794,7 @@ mod tests { #[test] fn add_intersection() { - let mut store = TypeStore::default(); + let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new()); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index ff27f25c2dd09..8f032e61e0e29 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -15,22 +15,41 @@ use crate::types::{ModuleTypeId, Type}; use crate::{FileId, Name}; // FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. +/// Resolve the public-facing type for a symbol (the type seen by other scopes: other modules, or +/// nested functions). Because calls to nested functions and imports can occur anywhere in control +/// flow, this type must be conservative and consider all definitions of the symbol that could +/// possibly be seen by another scope. Currently we take the most conservative approach, which is +/// the union of all definitions. We may be able to narrow this in future to eliminate definitions +/// which can't possibly (or at least likely) be seen by any other scope, so that e.g. we could +/// infer `Literal["1"]` instead of `Literal[1] | Literal["1"]` for `x` in `x = x; x = str(x);`. #[tracing::instrument(level = "trace", skip(db))] -pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult { +pub fn infer_symbol_public_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult { let symbols = symbol_table(db, symbol.file_id)?; let defs = symbols.definitions(symbol.symbol_id); let jar: &SemanticJar = db.jar()?; - if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) { + if let Some(ty) = jar.type_store.get_cached_symbol_public_type(symbol) { return Ok(ty); } - // TODO handle multiple defs, conditional defs... - assert_eq!(defs.len(), 1); - - let ty = infer_definition_type(db, symbol, defs[0].clone())?; + let mut tys = defs + .iter() + .map(|def| infer_definition_type(db, symbol, def.clone())) + .peekable(); + let ty = if let Some(first) = tys.next() { + if tys.peek().is_some() { + Type::Union(jar.type_store.add_union( + symbol.file_id, + &Iterator::chain([first].into_iter(), tys).collect::>>()?, + )) + } else { + first? + } + } else { + Type::Unknown + }; - jar.type_store.cache_symbol_type(symbol, ty); + jar.type_store.cache_symbol_public_type(symbol, ty); // TODO record dependencies Ok(ty) @@ -65,7 +84,7 @@ pub fn infer_definition_type( assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? { - infer_symbol_type(db, remote_symbol) + infer_symbol_public_type(db, remote_symbol) } else { Ok(Type::Unknown) } @@ -158,7 +177,8 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu ast::Expr::Name(name) => { // TODO look up in the correct scope, don't assume global if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { - infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id }) + // TODO should use only reachable definitions, not public type + infer_symbol_public_type(db, GlobalSymbolId { file_id, symbol_id }) } else { Ok(Type::Unknown) } @@ -182,7 +202,7 @@ mod tests { resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind, }; use crate::symbols::{symbol_table, GlobalSymbolId}; - use crate::types::{infer_symbol_type, Type}; + use crate::types::{infer_symbol_public_type, Type}; use crate::Name; // TODO with virtual filesystem we shouldn't have to write files to disk for these @@ -228,7 +248,7 @@ mod tests { .root_symbol_id_by_name("E") .expect("E symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: a_file, @@ -259,7 +279,7 @@ mod tests { .root_symbol_id_by_name("Sub") .expect("Sub symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: file, @@ -300,7 +320,7 @@ mod tests { .root_symbol_id_by_name("C") .expect("C symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: file, @@ -345,7 +365,7 @@ mod tests { .root_symbol_id_by_name("D") .expect("D symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: a_file, @@ -375,7 +395,7 @@ mod tests { .root_symbol_id_by_name("x") .expect("x symbol should be found"); - let ty = infer_symbol_type( + let ty = infer_symbol_public_type( db, GlobalSymbolId { file_id: file, @@ -388,4 +408,37 @@ mod tests { assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]"); Ok(()) } + + #[test] + fn resolve_union() -> anyhow::Result<()> { + let case = create_test()?; + let db = &case.db; + + let path = case.src.path().join("a.py"); + std::fs::write(path, "if flag:\n x = 1\nelse:\n x = 2")?; + let file = resolve_module(db, ModuleName::new("a"))? + .expect("module should be found") + .path(db)? + .file(); + let syms = symbol_table(db, file)?; + let x_sym = syms + .root_symbol_id_by_name("x") + .expect("x symbol should be found"); + + let ty = infer_symbol_public_type( + db, + GlobalSymbolId { + file_id: file, + symbol_id: x_sym, + }, + )?; + + let jar = HasJar::::jar(db)?; + assert!(matches!(ty, Type::Union(_))); + assert_eq!( + format!("{}", ty.display(&jar.type_store)), + "(Literal[1] | Literal[2])" + ); + Ok(()) + } }