diff --git a/crates/red_knot/src/lint.rs b/crates/red_knot/src/lint.rs index 7ca29f5f2d9c5f..4441234d777016 100644 --- a/crates/red_knot/src/lint.rs +++ b/crates/red_knot/src/lint.rs @@ -9,7 +9,7 @@ use ruff_python_ast::{ModModule, StringLiteral}; use crate::cache::KeyValueCache; use crate::db::{LintDb, LintJar, QueryResult}; use crate::files::FileId; -use crate::module::ModuleName; +use crate::module::{resolve_module, ModuleName}; use crate::parse::{parse, Parsed}; use crate::source::{source_text, Source}; use crate::symbols::{ @@ -144,9 +144,7 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> { // TODO we should have a special marker on the real typing module (from typeshed) so if you // have your own "typing" module in your project, we don't consider it THE typing module (and // same for other stdlib modules that our lint rules care about) - let Some(typing_override) = - resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")? - else { + let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else { // TODO once we bundle typeshed, this should be unreachable!() return Ok(()); }; @@ -234,6 +232,18 @@ impl<'a> SemanticLintContext<'a> { pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator) { self.diagnostics.get_mut().extend(diagnostics); } + + pub fn resolve_global_symbol( + &self, + module: &str, + symbol_name: &str, + ) -> QueryResult> { + let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else { + return Ok(None); + }; + + resolve_global_symbol(self.db.upcast(), module, symbol_name) + } } #[derive(Debug)] diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 3b4db31469b88c..04a5ddccb4afaa 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -18,7 +18,7 @@ use crate::ast_ids::{NodeKey, TypedNodeKey}; use crate::cache::KeyValueCache; use crate::db::{QueryResult, SemanticDb, SemanticJar}; use crate::files::FileId; -use crate::module::{resolve_module, ModuleName}; +use crate::module::{Module, ModuleName}; use crate::parse::parse; use crate::Name; @@ -35,13 +35,10 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult QueryResult> { - let Some(typing_module) = resolve_module(db, module)? else { - return Ok(None); - }; - let typing_file = typing_module.path(db)?.file(); + let typing_file = module.path(db)?.file(); let typing_table = symbol_table(db, typing_file)?; let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else { return Ok(None); diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index f8b02014355557..f963adf734580c 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -392,7 +392,7 @@ impl ModuleTypeId { } fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult> { - if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? { + if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? { Ok(Some(infer_symbol_type(db, symbol_id)?)) } else { Ok(None) diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index ff27f25c2dd094..bc586a29be378b 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -64,7 +64,11 @@ pub fn infer_definition_type( // TODO relative imports assert!(matches!(level, 0)); let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports")); - if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? { + let Some(module) = resolve_module(db, module_name.clone())? else { + return Ok(Type::Unknown); + }; + + if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? { infer_symbol_type(db, remote_symbol) } else { Ok(Type::Unknown)