From bd4a947b29564fdbe95a4ee6a188163cd2346eb7 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Fri, 16 Aug 2024 10:59:59 +0530 Subject: [PATCH] [red-knot] Add symbol and definition for parameters (#12862) ## Summary This PR adds support for adding symbols and definitions for function and lambda parameters to the semantic index. ### Notes * The default expression of a parameter is evaluated in the enclosing scope (not the type parameter or function scope). * The annotation expression of a parameter is evaluated in the type parameter scope if they're present other in the enclosing scope. * The symbols and definitions are added in the function parameter scope. ### Type Inference There are two definitions `Parameter` and `ParameterWithDefault` and their respective `*_definition` methods on the type inference builder. These methods are preferred and are re-used when checking from a different region. ## Test Plan Add test case for validating that the parameters are defined in the function / lambda scope. ### Benchmark update Validated the difference in diagnostics for benchmark code between `main` and this branch. All of them are either directly or indirectly referencing one of the function parameters. The diff is in the PR description. --- .../src/semantic_index.rs | 97 +++++++++++++++++++ .../src/semantic_index/builder.rs | 48 +++++++++ .../src/semantic_index/definition.rs | 33 +++++++ .../src/types/infer.rs | 50 +++++++++- crates/ruff_benchmark/benches/red_knot.rs | 4 +- 5 files changed, 227 insertions(+), 5 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index a3626c0bdc3b9..fef72fe74ca80 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -528,6 +528,103 @@ y = 2 )); } + #[test] + fn function_parameter_symbols() { + let TestCase { db, file } = test_case( + " +def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): + pass +", + ); + + let index = semantic_index(&db, file); + let global_table = symbol_table(&db, global_scope(&db, file)); + + assert_eq!(names(&global_table), vec!["f", "str", "int"]); + + let [(function_scope_id, _function_scope)] = index + .child_scopes(FileScopeId::global()) + .collect::>()[..] + else { + panic!("Expected a function scope") + }; + + let function_table = index.symbol_table(function_scope_id); + assert_eq!( + names(&function_table), + vec!["a", "b", "c", "args", "d", "kwargs"], + ); + + let use_def = index.use_def_map(function_scope_id); + for name in ["a", "b", "c", "d"] { + let [definition] = use_def.public_definitions( + function_table + .symbol_id_by_name(name) + .expect("symbol exists"), + ) else { + panic!("Expected parameter definition for {name}"); + }; + assert!(matches!( + definition.node(&db), + DefinitionKind::ParameterWithDefault(_) + )); + } + for name in ["args", "kwargs"] { + let [definition] = use_def.public_definitions( + function_table + .symbol_id_by_name(name) + .expect("symbol exists"), + ) else { + panic!("Expected parameter definition for {name}"); + }; + assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_))); + } + } + + #[test] + fn lambda_parameter_symbols() { + let TestCase { db, file } = test_case("lambda a, b, c=1, *args, d=2, **kwargs: None"); + + let index = semantic_index(&db, file); + let global_table = symbol_table(&db, global_scope(&db, file)); + + assert!(names(&global_table).is_empty()); + + let [(lambda_scope_id, _lambda_scope)] = index + .child_scopes(FileScopeId::global()) + .collect::>()[..] + else { + panic!("Expected a lambda scope") + }; + + let lambda_table = index.symbol_table(lambda_scope_id); + assert_eq!( + names(&lambda_table), + vec!["a", "b", "c", "args", "d", "kwargs"], + ); + + let use_def = index.use_def_map(lambda_scope_id); + for name in ["a", "b", "c", "d"] { + let [definition] = use_def + .public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists")) + else { + panic!("Expected parameter definition for {name}"); + }; + assert!(matches!( + definition.node(&db), + DefinitionKind::ParameterWithDefault(_) + )); + } + for name in ["args", "kwargs"] { + let [definition] = use_def + .public_definitions(lambda_table.symbol_id_by_name(name).expect("symbol exists")) + else { + panic!("Expected parameter definition for {name}"); + }; + assert!(matches!(definition.node(&db), DefinitionKind::Parameter(_))); + } + } + /// Test case to validate that the comprehension scope is correctly identified and that the target /// variable is defined only in the comprehension scope and not in the global scope. #[test] diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index ee17e228d9a34..7fa6fe1639d0c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -368,6 +368,16 @@ where .add_or_update_symbol(function_def.name.id.clone(), SymbolFlags::IS_DEFINED); self.add_definition(symbol, function_def); + // The default value of the parameters needs to be evaluated in the + // enclosing scope. + for default in function_def + .parameters + .iter_non_variadic_params() + .filter_map(|param| param.default.as_deref()) + { + self.visit_expr(default); + } + self.with_type_params( NodeWithScopeRef::FunctionTypeParameters(function_def), function_def.type_params.as_deref(), @@ -378,6 +388,16 @@ where } builder.push_scope(NodeWithScopeRef::Function(function_def)); + + // Add symbols and definitions for the parameters to the function scope. + for parameter in &*function_def.parameters { + let symbol = builder.add_or_update_symbol( + parameter.name().id().clone(), + SymbolFlags::IS_DEFINED, + ); + builder.add_definition(symbol, parameter); + } + builder.visit_body(&function_def.body); builder.pop_scope() }, @@ -574,9 +594,29 @@ where } ast::Expr::Lambda(lambda) => { if let Some(parameters) = &lambda.parameters { + // The default value of the parameters needs to be evaluated in the + // enclosing scope. + for default in parameters + .iter_non_variadic_params() + .filter_map(|param| param.default.as_deref()) + { + self.visit_expr(default); + } self.visit_parameters(parameters); } self.push_scope(NodeWithScopeRef::Lambda(lambda)); + + // Add symbols and definitions for the parameters to the lambda scope. + if let Some(parameters) = &lambda.parameters { + for parameter in &**parameters { + let symbol = self.add_or_update_symbol( + parameter.name().id().clone(), + SymbolFlags::IS_DEFINED, + ); + self.add_definition(symbol, parameter); + } + } + self.visit_expr(lambda.body.as_ref()); } ast::Expr::If(ast::ExprIf { @@ -654,6 +694,14 @@ where self.pop_scope(); } } + + fn visit_parameters(&mut self, parameters: &'ast ruff_python_ast::Parameters) { + // Intentionally avoid walking default expressions, as we handle them in the enclosing + // scope. + for parameter in parameters.iter().map(ast::AnyParameterRef::as_parameter) { + self.visit_parameter(parameter); + } + } } #[derive(Copy, Clone, Debug)] diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 0c4c9f39fe6a8..6886396160360 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -45,6 +45,7 @@ pub(crate) enum DefinitionNodeRef<'a> { Assignment(AssignmentDefinitionNodeRef<'a>), AnnotatedAssignment(&'a ast::StmtAnnAssign), Comprehension(ComprehensionDefinitionNodeRef<'a>), + Parameter(ast::AnyParameterRef<'a>), } impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { @@ -95,6 +96,12 @@ impl<'a> From> for DefinitionNodeRef<'a> { } } +impl<'a> From> for DefinitionNodeRef<'a> { + fn from(node: ast::AnyParameterRef<'a>) -> Self { + Self::Parameter(node) + } +} + #[derive(Copy, Clone, Debug)] pub(crate) struct ImportFromDefinitionNodeRef<'a> { pub(crate) node: &'a ast::StmtImportFrom, @@ -150,6 +157,14 @@ impl DefinitionNodeRef<'_> { first, }) } + DefinitionNodeRef::Parameter(parameter) => match parameter { + ast::AnyParameterRef::Variadic(parameter) => { + DefinitionKind::Parameter(AstNodeRef::new(parsed, parameter)) + } + ast::AnyParameterRef::NonVariadic(parameter) => { + DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter)) + } + }, } } @@ -168,6 +183,10 @@ impl DefinitionNodeRef<'_> { }) => target.into(), Self::AnnotatedAssignment(node) => node.into(), Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(), + Self::Parameter(node) => match node { + ast::AnyParameterRef::Variadic(parameter) => parameter.into(), + ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(), + }, } } } @@ -182,6 +201,8 @@ pub enum DefinitionKind { Assignment(AssignmentDefinitionKind), AnnotatedAssignment(AstNodeRef), Comprehension(ComprehensionDefinitionKind), + Parameter(AstNodeRef), + ParameterWithDefault(AstNodeRef), } #[derive(Clone, Debug)] @@ -273,3 +294,15 @@ impl From<&ast::Comprehension> for DefinitionNodeKey { Self(NodeKey::from_node(node)) } } + +impl From<&ast::Parameter> for DefinitionNodeKey { + fn from(node: &ast::Parameter) -> Self { + Self(NodeKey::from_node(node)) + } +} + +impl From<&ast::ParameterWithDefault> for DefinitionNodeKey { + fn from(node: &ast::ParameterWithDefault) -> Self { + Self(NodeKey::from_node(node)) + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 5b4ba7ffb3406..46b52ca62751d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -307,6 +307,12 @@ impl<'db> TypeInferenceBuilder<'db> { definition, ); } + DefinitionKind::Parameter(parameter) => { + self.infer_parameter_definition(parameter, definition); + } + DefinitionKind::ParameterWithDefault(parameter_with_default) => { + self.infer_parameter_with_default_definition(parameter_with_default, definition); + } } } @@ -421,6 +427,13 @@ impl<'db> TypeInferenceBuilder<'db> { .map(|decorator| self.infer_decorator(decorator)) .collect(); + for default in parameters + .iter_non_variadic_params() + .filter_map(|param| param.default.as_deref()) + { + self.infer_expression(default); + } + // If there are type params, parameters and returns are evaluated in that scope. if type_params.is_none() { self.infer_parameters(parameters); @@ -458,10 +471,12 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::ParameterWithDefault { range: _, parameter, - default, + default: _, } = parameter_with_default; - self.infer_parameter(parameter); - self.infer_optional_expression(default.as_deref()); + + self.infer_optional_expression(parameter.annotation.as_deref()); + + self.infer_definition(parameter_with_default); } fn infer_parameter(&mut self, parameter: &ast::Parameter) { @@ -470,7 +485,29 @@ impl<'db> TypeInferenceBuilder<'db> { name: _, annotation, } = parameter; + self.infer_optional_expression(annotation.as_deref()); + + self.infer_definition(parameter); + } + + fn infer_parameter_with_default_definition( + &mut self, + _parameter_with_default: &ast::ParameterWithDefault, + definition: Definition<'db>, + ) { + // TODO(dhruvmanila): Infer types from annotation or default expression + self.types.definitions.insert(definition, Type::Unknown); + } + + fn infer_parameter_definition( + &mut self, + _parameter: &ast::Parameter, + definition: Definition<'db>, + ) { + // TODO(dhruvmanila): Annotation expression is resolved at the enclosing scope, infer the + // parameter type from there + self.types.definitions.insert(definition, Type::Unknown); } fn infer_class_definition_statement(&mut self, class: &ast::StmtClassDef) { @@ -1277,6 +1314,13 @@ impl<'db> TypeInferenceBuilder<'db> { } = lambda_expression; if let Some(parameters) = parameters { + for default in parameters + .iter_non_variadic_params() + .filter_map(|param| param.default.as_deref()) + { + self.infer_expression(default); + } + self.infer_parameters(parameters); } diff --git a/crates/ruff_benchmark/benches/red_knot.rs b/crates/ruff_benchmark/benches/red_knot.rs index 6ad901614219b..727b50a452ed4 100644 --- a/crates/ruff_benchmark/benches/red_knot.rs +++ b/crates/ruff_benchmark/benches/red_knot.rs @@ -89,7 +89,7 @@ fn benchmark_incremental(criterion: &mut Criterion) { let Case { db, parser, .. } = case; let result = db.check_file(*parser).unwrap(); - assert_eq!(result.len(), 402); + assert_eq!(result.len(), 111); }, BatchSize::SmallInput, ); @@ -104,7 +104,7 @@ fn benchmark_cold(criterion: &mut Criterion) { let Case { db, parser, .. } = case; let result = db.check_file(*parser).unwrap(); - assert_eq!(result.len(), 402); + assert_eq!(result.len(), 111); }, BatchSize::SmallInput, );