Skip to content

Commit

Permalink
red-knot: Change resolve_global_symbol to take Module as an argum…
Browse files Browse the repository at this point in the history
…ent (#11723)
  • Loading branch information
MichaReiser authored Jun 4, 2024
1 parent 64165be commit 6ffb961
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 24 deletions.
18 changes: 14 additions & 4 deletions crates/red_knot/src/lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use ruff_python_parser::Parsed;
use crate::cache::KeyValueCache;
use crate::db::{LintDb, LintJar, QueryResult};
use crate::files::FileId;
use crate::module::ModuleName;
use crate::module::{resolve_module, ModuleName};
use crate::parse::parse;
use crate::source::{source_text, Source};
use crate::symbols::{
Expand Down Expand Up @@ -145,9 +145,7 @@ fn lint_bad_overrides(context: &SemanticLintContext) -> QueryResult<()> {
// TODO we should have a special marker on the real typing module (from typeshed) so if you
// have your own "typing" module in your project, we don't consider it THE typing module (and
// same for other stdlib modules that our lint rules care about)
let Some(typing_override) =
resolve_global_symbol(context.db.upcast(), ModuleName::new("typing"), "override")?
else {
let Some(typing_override) = context.resolve_global_symbol("typing", "override")? else {
// TODO once we bundle typeshed, this should be unreachable!()
return Ok(());
};
Expand Down Expand Up @@ -235,6 +233,18 @@ impl<'a> SemanticLintContext<'a> {
pub fn extend_diagnostics(&mut self, diagnostics: impl IntoIterator<Item = String>) {
self.diagnostics.get_mut().extend(diagnostics);
}

pub fn resolve_global_symbol(
&self,
module: &str,
symbol_name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
let Some(module) = resolve_module(self.db.upcast(), ModuleName::new(module))? else {
return Ok(None);
};

resolve_global_symbol(self.db.upcast(), module, symbol_name)
}
}

#[derive(Debug)]
Expand Down
9 changes: 3 additions & 6 deletions crates/red_knot/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::ast_ids::{NodeKey, TypedNodeKey};
use crate::cache::KeyValueCache;
use crate::db::{QueryResult, SemanticDb, SemanticJar};
use crate::files::FileId;
use crate::module::{resolve_module, ModuleName};
use crate::module::{Module, ModuleName};
use crate::parse::parse;
use crate::Name;

Expand All @@ -35,13 +35,10 @@ pub fn symbol_table(db: &dyn SemanticDb, file_id: FileId) -> QueryResult<Arc<Sym
#[tracing::instrument(level = "debug", skip(db))]
pub fn resolve_global_symbol(
db: &dyn SemanticDb,
module: ModuleName,
module: Module,
name: &str,
) -> QueryResult<Option<GlobalSymbolId>> {
let Some(typing_module) = resolve_module(db, module)? else {
return Ok(None);
};
let typing_file = typing_module.path(db)?.file();
let typing_file = module.path(db)?.file();
let typing_table = symbol_table(db, typing_file)?;
let Some(typing_override) = typing_table.root_symbol_id_by_name(name) else {
return Ok(None);
Expand Down
2 changes: 1 addition & 1 deletion crates/red_knot/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ impl ModuleTypeId {
}

fn get_member(self, db: &dyn SemanticDb, name: &Name) -> QueryResult<Option<Type>> {
if let Some(symbol_id) = resolve_global_symbol(db, self.name(db)?, name)? {
if let Some(symbol_id) = resolve_global_symbol(db, self.module, name)? {
Ok(Some(infer_symbol_public_type(db, symbol_id)?))
} else {
Ok(None)
Expand Down
34 changes: 21 additions & 13 deletions crates/red_knot/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ pub fn infer_definition_type(
// TODO relative imports
assert!(matches!(level, 0));
let module_name = ModuleName::new(module.as_ref().expect("TODO relative imports"));
if let Some(remote_symbol) = resolve_global_symbol(db, module_name, &name)? {
let Some(module) = resolve_module(db, module_name.clone())? else {
return Ok(Type::Unknown);
};

if let Some(remote_symbol) = resolve_global_symbol(db, module, &name)? {
infer_symbol_public_type(db, remote_symbol)
} else {
Ok(Type::Unknown)
Expand Down Expand Up @@ -248,30 +252,34 @@ mod tests {
Ok(TestCase { temp_dir, db, src })
}

fn write_to_path(case: &TestCase, relpath: &str, contents: &str) -> anyhow::Result<()> {
let path = case.src.path().join(relpath);
fn write_to_path(case: &TestCase, relative_path: &str, contents: &str) -> anyhow::Result<()> {
let path = case.src.path().join(relative_path);
std::fs::write(path, contents)?;
Ok(())
}

fn get_public_type(case: &TestCase, modname: &str, varname: &str) -> anyhow::Result<Type> {
fn get_public_type(
case: &TestCase,
module_name: &str,
variable_name: &str,
) -> anyhow::Result<Type> {
let db = &case.db;
let symbol =
resolve_global_symbol(db, ModuleName::new(modname), varname)?.expect("symbol to exist");
let module = resolve_module(db, ModuleName::new(module_name))?.expect("Module to exist");
let symbol = resolve_global_symbol(db, module, variable_name)?.expect("symbol to exist");

Ok(infer_symbol_public_type(db, symbol)?)
}

fn assert_public_type(
case: &TestCase,
modname: &str,
varname: &str,
tyname: &str,
module_name: &str,
variable_name: &str,
type_name: &str,
) -> anyhow::Result<()> {
let ty = get_public_type(case, modname, varname)?;
let ty = get_public_type(case, module_name, variable_name)?;

let jar = HasJar::<SemanticJar>::jar(&case.db)?;
assert_eq!(format!("{}", ty.display(&jar.type_store)), tyname);
assert_eq!(format!("{}", ty.display(&jar.type_store)), type_name);
Ok(())
}

Expand Down Expand Up @@ -399,8 +407,8 @@ mod tests {
.expect("module should be found")
.path(db)?
.file();
let syms = symbol_table(db, file)?;
let x_sym = syms
let symbols = symbol_table(db, file)?;
let x_sym = symbols
.root_symbol_id_by_name("x")
.expect("x symbol should be found");

Expand Down

0 comments on commit 6ffb961

Please sign in to comment.