diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index f9d6cd767afa15..952e367003951e 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -496,30 +496,37 @@ impl<'db> Type<'db> { /// /// Returns `None` if `self` is not a callable type. #[must_use] - pub fn call(&self, db: &'db dyn Db) -> Option> { + fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> { match self { - Type::Function(function_type) => Some(function_type.return_type(db)), - Type::RevealTypeFunction(function_type) => { - // TODO emit the diagnostic - Some(function_type.return_type(db)) - } + Type::Function(function_type) => CallOutcome::callable(function_type.return_type(db)), + Type::RevealTypeFunction(function_type) => CallOutcome::RevealType { + return_ty: function_type.return_type(db), + revealed_ty: arg_types[0], + }, // TODO annotated return type on `__new__` or metaclass `__call__` - Type::Class(class) => Some(Type::Instance(*class)), + Type::Class(class) => CallOutcome::callable(Type::Instance(class)), - // TODO: handle classes which implement the Callable protocol - Type::Instance(_instance_ty) => Some(Type::Unknown), + // TODO: handle classes which implement the `__call__` protocol + Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown), // `Any` is callable, and its return type is also `Any`. - Type::Any => Some(Type::Any), + Type::Any => CallOutcome::callable(Type::Any), - Type::Unknown => Some(Type::Unknown), + Type::Unknown => CallOutcome::callable(Type::Unknown), - // TODO: union and intersection types, if they reduce to `Callable` - Type::Union(_) => Some(Type::Unknown), - Type::Intersection(_) => Some(Type::Unknown), + Type::Union(union) => CallOutcome::Union { + outcomes: union + .elements(db) + .iter() + .map(|elem| elem.call(db, arg_types)) + .collect(), + }, - _ => None, + // TODO: intersection types + Type::Intersection(_) => CallOutcome::callable(Type::Unknown), + + _ => CallOutcome::not_callable(self), } } @@ -531,7 +538,7 @@ impl<'db> Type<'db> { /// for y in x: /// pass /// ``` - fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> { + fn iterate(self, db: &'db dyn Db) -> IterationOutcome<'db> { if let Type::Tuple(tuple_type) = self { return IterationOutcome::Iterable { element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)), @@ -544,18 +551,22 @@ impl<'db> Type<'db> { let dunder_iter_method = iterable_meta_type.member(db, "__iter__"); if !dunder_iter_method.is_unbound() { - let Some(iterator_ty) = dunder_iter_method.call(db) else { + let CallOutcome::Callable { + return_ty: iterator_ty, + } = dunder_iter_method.call(db, &[]) + else { return IterationOutcome::NotIterable { - not_iterable_ty: *self, + not_iterable_ty: self, }; }; let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__"); return dunder_next_method - .call(db) + .call(db, &[]) + .return_ty(db) .map(|element_ty| IterationOutcome::Iterable { element_ty }) .unwrap_or(IterationOutcome::NotIterable { - not_iterable_ty: *self, + not_iterable_ty: self, }); } @@ -568,10 +579,11 @@ impl<'db> Type<'db> { let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__"); dunder_get_item_method - .call(db) + .call(db, &[]) + .return_ty(db) .map(|element_ty| IterationOutcome::Iterable { element_ty }) .unwrap_or(IterationOutcome::NotIterable { - not_iterable_ty: *self, + not_iterable_ty: self, }) } @@ -638,6 +650,98 @@ impl<'db> From<&Type<'db>> for Type<'db> { } } +#[derive(Debug, Clone, PartialEq, Eq)] +enum CallOutcome<'db> { + Callable { + return_ty: Type<'db>, + }, + RevealType { + return_ty: Type<'db>, + revealed_ty: Type<'db>, + }, + NotCallable { + not_callable_ty: Type<'db>, + }, + Union { + outcomes: Box<[CallOutcome<'db>]>, + }, +} + +impl<'db> CallOutcome<'db> { + /// Create a new `CallOutcome::Callable` with given return type. + fn callable(return_ty: Type<'db>) -> CallOutcome { + CallOutcome::Callable { return_ty } + } + + /// Create a new `CallOutcome::NotCallable` with given not-callable type. + fn not_callable(not_callable_ty: Type<'db>) -> CallOutcome { + CallOutcome::NotCallable { not_callable_ty } + } + + /// Get the return type of the call, or `None` if not callable. + fn return_ty(&self, db: &'db dyn Db) -> Option> { + match self { + Self::Callable { return_ty } => Some(*return_ty), + Self::RevealType { + return_ty, + revealed_ty: _, + } => Some(*return_ty), + Self::NotCallable { not_callable_ty: _ } => None, + Self::Union { outcomes } => outcomes + .iter() + .fold(None, |acc, outcome| { + let ty = outcome.return_ty(db); + match (acc, ty) { + (None, None) => None, + (None, Some(ty)) => Some(UnionBuilder::new(db).add(ty)), + (Some(builder), ty) => Some(builder.add(ty.unwrap_or(Type::Unknown))), + } + }) + .map(UnionBuilder::build), + } + } + + /// Get the return type of the call, emitting diagnostics if needed. + fn unwrap_with_diagnostic<'a>( + &self, + db: &'db dyn Db, + node: ast::AnyNodeRef, + builder: &'a mut TypeInferenceBuilder<'db>, + ) -> Type<'db> { + match self { + Self::Callable { return_ty } => *return_ty, + Self::RevealType { + return_ty, + revealed_ty, + } => { + builder.add_diagnostic( + node, + "revealed-type", + format_args!("Revealed type is '{}'.", revealed_ty.display(db)), + ); + *return_ty + } + Self::NotCallable { not_callable_ty } => { + builder.add_diagnostic( + node, + "call-non-callable", + format_args!( + "Object of type '{}' is not callable.", + not_callable_ty.display(db) + ), + ); + Type::Unknown + } + Self::Union { outcomes } => UnionType::from_elements( + db, + outcomes + .iter() + .map(|outcome| outcome.unwrap_with_diagnostic(db, node, builder)), + ), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum IterationOutcome<'db> { Iterable { element_ty: Type<'db> }, diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 14c9d864056dbf..d040c5fa3aa0ab 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2023,19 +2023,11 @@ impl<'db> TypeInferenceBuilder<'db> { arguments, } = call_expression; - self.infer_arguments(arguments); + let arg_types = self.infer_arguments(arguments); let function_type = self.infer_expression(func); - function_type.call(self.db).unwrap_or_else(|| { - self.add_diagnostic( - func.as_ref().into(), - "call-non-callable", - format_args!( - "Object of type '{}' is not callable", - function_type.display(self.db) - ), - ); - Type::Unknown - }) + function_type + .call(self.db, arg_types.as_slice()) + .unwrap_with_diagnostic(self.db, func.as_ref().into(), self) } fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> { @@ -2410,7 +2402,19 @@ impl<'db> TypeInferenceBuilder<'db> { /// Adds a new diagnostic. /// /// The diagnostic does not get added if the rule isn't enabled for this file. - fn add_diagnostic(&mut self, node: AnyNodeRef, rule: &str, message: std::fmt::Arguments) { + pub(super) fn add_diagnostic( + &mut self, + node: AnyNodeRef, + rule: &str, + message: std::fmt::Arguments, + ) { + self.add_diagnostic_string(node, rule, message.to_string()); + } + + /// Adds a new diagnostic with a string message. + /// + /// The diagnostic does not get added if the rule isn't enabled for this file. + pub(super) fn add_diagnostic_string(&mut self, node: AnyNodeRef, rule: &str, message: String) { if !self.db.is_file_open(self.file) { return; } @@ -2424,7 +2428,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.types.diagnostics.push(TypeCheckDiagnostic { file: self.file, rule: rule.to_string(), - message: message.to_string(), + message, range: node.range(), }); } @@ -2760,7 +2764,7 @@ mod tests { ", )?; - assert_file_diagnostics(&db, "/src/a.py", &["Revealed type of 'x' is 'Literal[1]'."]); + assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]); Ok(()) } @@ -3368,7 +3372,7 @@ mod tests { assert_file_diagnostics( &db, "/src/a.py", - &["Object of type 'Literal[123]' is not callable"], + &["Object of type 'Literal[123]' is not callable."], ); }