Skip to content

Commit

Permalink
red-knot: Change resolve_global_symbol to take Module as an argument
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser committed Jun 3, 2024
1 parent 2b28889 commit 18afec8
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 12 deletions.
18 changes: 14 additions & 4 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use ruff_python_ast::{ModModule, StringLiteral};
use crate::cache::KeyValueCache;
use crate::db::{LintDb, LintJar, QueryResult};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::module::{resolve_module, ModuleName};
use crate::parse::{parse, Parsed};
use crate::source::{source_text, Source};
use crate::symbols::{
Expand Down Expand Up @@ -144,9 +144,7 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> {
// TODO we should have a special marker on the real typing module (from typeshed) so if you
// have your own "typing" module in your project, we don't consider it THE typing module (and
// same for other stdlib modules that our lint rules care about)
let Some(typing_override) =
resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")?
else {
let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else {
// TODO once we bundle typeshed, this should be unreachable!()
return Ok(());
};
Expand Down Expand Up @@ -234,6 +232,18 @@ impl<'a> SemanticLintContext<'a> {
pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator<Item = String>) {
self.diagnostics.get_mut().extend(diagnostics);
}

pub fn resolve_global_symbol(
&self,
module: &str,
symbol_name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else {
return Ok(None);
};

resolve_global_symbol(self.db.upcast(), module, symbol_name)
}
}

#[derive(Debug)]
Expand Down
9 changes: 3 additions & 6 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::ast_ids::{NodeKey, TypedNodeKey};
use crate::cache::KeyValueCache;
use crate::db::{QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::{resolve_module, ModuleName};
use crate::module::{Module, ModuleName};
use crate::parse::parse;
use crate::Name;

Expand All @@ -35,13 +35,10 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult<Arc<Sym
#[tracing::instrument(level = "debug", skip(db))]
pub fn resolve_global_symbol(
db: &dyn SemanticDb,
module: ModuleName,
module: Module,
name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
let Some(typing_module) = resolve_module(db, module)? else {
return Ok(None);
};
let typing_file = typing_module.path(db)?.file();
let typing_file = module.path(db)?.file();
let typing_table = symbol_table(db, typing_file)?;
let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else {
return Ok(None);
Expand Down
2 changes: 1 addition & 1 deletion crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ impl ModuleTypeId {
}

fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? {
Ok(Some(infer_symbol_type(db, symbol_id)?))
} else {
Ok(None)
Expand Down
6 changes: 5 additions & 1 deletion crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ pub fn infer_definition_type(
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
let Some(module) = resolve_module(db, module_name.clone())? else {
return Ok(Type::Unknown);
};

if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? {
infer_symbol_type(db, remote_symbol)
} else {
Ok(Type::Unknown)
Expand Down

0 comments on commit 18afec8

Please sign in to comment.