Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] infer_symbol_public_type infers union of all definitions #11669

Merged
merged 6 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::source::{source_text, Source};
use crate::symbols::{
resolve_global_symbol, symbol_table, Definition, GlobalSymbolId, SymbolId, SymbolTable,
};
use crate::types::{infer_definition_type, infer_symbol_type, Type};
use crate::types::{infer_definition_type, infer_symbol_public_type, Type};

#[tracing::instrument(level = "debug", skip(db))]
pub(crate) fn lint_syntax(db: &dyn LintDb, file_id: FileId) -> QueryResult<Diagnostics> {
Expand Down Expand Up @@ -104,14 +104,14 @@ fn lint_unresolved_imports(context: &SemanticLintContext) -> QueryResult<()> {
for (symbol, definition) in context.symbols().all_definitions() {
match definition {
Definition::Import(import) => {
let ty = context.infer_symbol_type(symbol)?;
let ty = context.infer_symbol_public_type(symbol)?;

if ty.is_unknown() {
context.push_diagnostic(format!("Unresolved module {}", import.module));
}
}
Definition::ImportFrom(import) => {
let ty = context.infer_symbol_type(symbol)?;
let ty = context.infer_symbol_public_type(symbol)?;

if ty.is_unknown() {
let module_name = import.module().map(Deref::deref).unwrap_or_default();
Expand Down Expand Up @@ -217,8 +217,8 @@ impl<'a> SemanticLintContext<'a> {
&self.symbols
}

pub fn infer_symbol_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_type(
pub fn infer_symbol_public_type(&self, symbol_id: SymbolId) -> QueryResult<Type> {
infer_symbol_public_type(
self.db.upcast(),
GlobalSymbolId {
file_id: self.file_id,
Expand Down
20 changes: 10 additions & 10 deletions crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_hash::FxHashMap;

pub(crate) mod infer;

pub(crate) use infer::{infer_definition_type, infer_symbol_type};
pub(crate) use infer::{infer_definition_type, infer_symbol_public_type};

/// unique ID for a type
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -119,7 +119,7 @@ impl TypeStore {
self.modules.remove(&file_id);
}

pub fn cache_symbol_type(&self, symbol: GlobalSymbolId, ty: Type) {
pub fn cache_symbol_public_type(&self, symbol: GlobalSymbolId, ty: Type) {
self.add_or_get_module(symbol.file_id)
.symbol_types
.insert(symbol.symbol_id, ty);
Expand All @@ -131,7 +131,7 @@ impl TypeStore {
.insert(node_key, ty);
}

pub fn get_cached_symbol_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
pub fn get_cached_symbol_public_type(&self, symbol: GlobalSymbolId) -> Option<Type> {
self.try_get_module(symbol.file_id)?
.symbol_types
.get(&symbol.symbol_id)
Expand Down Expand Up @@ -182,12 +182,12 @@ impl TypeStore {
.add_class(name, scope_id, bases)
}

fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
fn add_union(&self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
self.add_or_get_module(file_id).add_union(elems)
}

fn add_intersection(
&mut self,
&self,
file_id: FileId,
positive: &[Type],
negative: &[Type],
Expand Down Expand Up @@ -393,7 +393,7 @@ impl ModuleTypeId {

fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
Ok(Some(infer_symbol_type(db, symbol_id)?))
Ok(Some(infer_symbol_public_type(db, symbol_id)?))
} else {
Ok(None)
}
Expand Down Expand Up @@ -441,7 +441,7 @@ impl ClassTypeId {
let ClassType { scope_id, .. } = *self.class(db)?;
let table = symbol_table(db, self.file_id)?;
if let Some(symbol_id) = table.symbol_id_by_name(scope_id, name) {
Ok(Some(infer_symbol_type(
Ok(Some(infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: self.file_id,
Expand Down Expand Up @@ -497,7 +497,7 @@ struct ModuleTypeStore {
unions: IndexVec<ModuleUnionTypeId, UnionType>,
/// arena of all intersection types created in this module
intersections: IndexVec<ModuleIntersectionTypeId, IntersectionType>,
/// cached types of symbols in this module
/// cached public types of symbols in this module
symbol_types: FxHashMap<SymbolId, Type>,
/// cached types of AST nodes in this module
node_types: FxHashMap<NodeKey, Type>,
Expand Down Expand Up @@ -777,7 +777,7 @@ mod tests {

#[test]
fn add_union() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
Expand All @@ -794,7 +794,7 @@ mod tests {

#[test]
fn add_intersection() {
let mut store = TypeStore::default();
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1", SymbolTable::root_scope_id(), Vec::new());
Expand Down
83 changes: 68 additions & 15 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,41 @@ use crate::types::{ModuleTypeId, Type};
use crate::{FileId, Name};

// FIXME: Figure out proper deadlock-free synchronisation now that this takes `&db` instead of `&mut db`.
/// Resolve the public-facing type for a symbol (the type seen by other scopes: other modules, or
/// nested functions). Because calls to nested functions and imports can occur anywhere in control
/// flow, this type must be conservative and consider all definitions of the symbol that could
/// possibly be seen by another scope. Currently we take the most conservative approach, which is
/// the union of all definitions. We may be able to narrow this in the future to eliminate
/// definitions which can't possibly (or are at least unlikely to) be seen by any other scope, so
/// that e.g. we could infer `Literal["1"]` instead of `Literal[1] | Literal["1"]` for `x` in
/// `x = 1; x = str(x);`.
#[tracing::instrument(level = "trace", skip(db))]
pub fn infer_symbol_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
pub fn infer_symbol_public_type(db: &dyn SemanticDb, symbol: GlobalSymbolId) -> QueryResult<Type> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The distinction between public and non-public types is unclear to me, as is how we guarantee that you can't call this function for a local symbol. But that's something we can tackle independently.

It may help to add some documentation to the query to explain the distinction and for what this query should be used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a docstring to the method that replaces the comment I'd written internally, and tries to explain what the "public type" is. Let me know if it's not clear.

The "public type" is relevant for any symbol that can be seen from outside its own scope: this includes all public symbols, and it also includes all "cellvars" (function-internal symbols that are used by a nested function.)

We may need to refine our terminology here, because this means that even symbols inside a function (which I wouldn't call "public symbols" in terms of cross-module dependency tracking) still have a "public type" which is relevant to type checking.

I don't think we need to "guarantee that you can't call this function" for any symbol. This function shouldn't ever be needed in type inference for purely-local symbols (where every use of the symbol is at a specific known point in control flow and uses only definitions reachable from there), but that's a matter for validating correctness of type inference; this function can return a reasonable result for any symbol (it's just the union of all definitions.)

let symbols = symbol_table(db, symbol.file_id)?;
let defs = symbols.definitions(symbol.symbol_id);
let jar: &SemanticJar = db.jar()?;

if let Some(ty) = jar.type_store.get_cached_symbol_type(symbol) {
if let Some(ty) = jar.type_store.get_cached_symbol_public_type(symbol) {
return Ok(ty);
}

// TODO handle multiple defs, conditional defs...
assert_eq!(defs.len(), 1);

let ty = infer_definition_type(db, symbol, defs[0].clone())?;
let mut tys = defs
.iter()
.map(|def| infer_definition_type(db, symbol, def.clone()))
.peekable();
let ty = if let Some(first) = tys.next() {
if tys.peek().is_some() {
Type::Union(jar.type_store.add_union(
symbol.file_id,
&Iterator::chain([first].into_iter(), tys).collect::<QueryResult<Vec<_>>>()?,
carljm marked this conversation as resolved.
Show resolved Hide resolved
))
} else {
first?
}
} else {
Type::Unknown
};

jar.type_store.cache_symbol_type(symbol, ty);
jar.type_store.cache_symbol_public_type(symbol, ty);

// TODO record dependencies
Ok(ty)
Expand Down Expand Up @@ -65,7 +84,7 @@ pub fn infer_definition_type(
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
infer_symbol_type(db, remote_symbol)
infer_symbol_public_type(db, remote_symbol)
} else {
Ok(Type::Unknown)
}
Expand Down Expand Up @@ -158,7 +177,8 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu
ast::Expr::Name(name) => {
// TODO look up in the correct scope, don't assume global
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
infer_symbol_type(db, GlobalSymbolId { file_id, symbol_id })
// TODO should use only reachable definitions, not public type
infer_symbol_public_type(db, GlobalSymbolId { file_id, symbol_id })
} else {
Ok(Type::Unknown)
}
Expand All @@ -182,7 +202,7 @@ mod tests {
resolve_module, set_module_search_paths, ModuleName, ModuleSearchPath, ModuleSearchPathKind,
};
use crate::symbols::{symbol_table, GlobalSymbolId};
use crate::types::{infer_symbol_type, Type};
use crate::types::{infer_symbol_public_type, Type};
use crate::Name;

// TODO with virtual filesystem we shouldn't have to write files to disk for these
Expand Down Expand Up @@ -228,7 +248,7 @@ mod tests {
.root_symbol_id_by_name("E")
.expect("E symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: a_file,
Expand Down Expand Up @@ -259,7 +279,7 @@ mod tests {
.root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand Down Expand Up @@ -300,7 +320,7 @@ mod tests {
.root_symbol_id_by_name("C")
.expect("C symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand Down Expand Up @@ -345,7 +365,7 @@ mod tests {
.root_symbol_id_by_name("D")
.expect("D symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: a_file,
Expand Down Expand Up @@ -375,7 +395,7 @@ mod tests {
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_type(
let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
Expand All @@ -388,4 +408,37 @@ mod tests {
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]");
Ok(())
}

/// A symbol that is assigned in both branches of an `if`/`else` should get a public type that is
/// the union of the types of both definitions.
#[test]
fn resolve_union() -> anyhow::Result<()> {
let case = create_test()?;
let db = &case.db;

// Write module `a` to the test source directory; it binds `x` once in each branch of a
// conditional, so the symbol ends up with two definitions.
let path = case.src.path().join("a.py");
std::fs::write(path, "if flag:\n x = 1\nelse:\n x = 2")?;
// Resolve module `a` to its file id and look up the module-level symbol `x`.
let file = resolve_module(db, ModuleName::new("a"))?
.expect("module should be found")
.path(db)?
.file();
let syms = symbol_table(db, file)?;
let x_sym = syms
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

let ty = infer_symbol_public_type(
db,
GlobalSymbolId {
file_id: file,
symbol_id: x_sym,
},
)?;

// The type store is needed to render the inferred type for the display assertion below.
let jar = HasJar::<SemanticJar>::jar(db)?;
// Check both the variant and the rendered form, which lists the two literal types in
// definition order.
assert!(matches!(ty, Type::Union(_)));
assert_eq!(
format!("{}", ty.display(&jar.type_store)),
"(Literal[1] | Literal[2])"
);
Ok(())
}
}
Loading