diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index acb7c480259d2..68f665637b869 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,3 +1,4 @@ +use infer::TypeInferenceBuilder; use ruff_db::files::File; use ruff_python_ast as ast; @@ -400,28 +401,42 @@ impl<'db> Type<'db> { /// for y in x: /// pass /// ``` - /// - /// Returns `None` if `self` represents a type that is not iterable. - fn iterate(&self, db: &'db dyn Db) -> Option> { + fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> { // `self` represents the type of the iterable; // `__iter__` and `__next__` are both looked up on the class of the iterable: - let type_of_class = self.to_meta_type(db); + let iterable_meta_type = self.to_meta_type(db); - let dunder_iter_method = type_of_class.member(db, "__iter__"); + let dunder_iter_method = iterable_meta_type.member(db, "__iter__"); if !dunder_iter_method.is_unbound() { - let iterator_ty = dunder_iter_method.call(db)?; + let Some(iterator_ty) = dunder_iter_method.call(db) else { + return IterationOutcome::NotIterable { + not_iterable_ty: *self, + }; + }; + let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__"); - return dunder_next_method.call(db); + return dunder_next_method + .call(db) + .map(|element_ty| IterationOutcome::Iterable { element_ty }) + .unwrap_or(IterationOutcome::NotIterable { + not_iterable_ty: *self, + }); } // Although it's not considered great practice, // classes that define `__getitem__` are also iterable, // even if they do not define `__iter__`. // - // TODO this is only valid if the `__getitem__` method is annotated as + // TODO(Alex) this is only valid if the `__getitem__` method is annotated as // accepting `int` or `SupportsIndex` - let dunder_get_item_method = type_of_class.member(db, "__getitem__"); - dunder_get_item_method.call(db) + let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__"); + + dunder_get_item_method + .call(db) + .map(|element_ty| IterationOutcome::Iterable { element_ty }) + .unwrap_or(IterationOutcome::NotIterable { + not_iterable_ty: *self, + }) } #[must_use] @@ -463,6 +478,28 @@ impl<'db> Type<'db> { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum IterationOutcome<'db> { + Iterable { element_ty: Type<'db> }, + NotIterable { not_iterable_ty: Type<'db> }, +} + +impl<'db> IterationOutcome<'db> { + fn unwrap_with_diagnostic( + self, + iterable_node: ast::AnyNodeRef, + inference_builder: &mut TypeInferenceBuilder<'db>, + ) -> Type<'db> { + match self { + Self::Iterable { element_ty } => element_ty, + Self::NotIterable { not_iterable_ty } => { + inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty); + Type::Unknown + } + } + } +} + #[salsa::interned] pub struct FunctionType<'db> { /// name of the function at definition @@ -789,4 +826,65 @@ mod tests { &["Object of type 'NotIterable' is not iterable"], ); } + + #[test] + fn starred_expressions_must_be_iterable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class NotIterable: pass + + class Iterator: + def __next__(self) -> int: + return 42 + + class Iterable: + def __iter__(self) -> Iterator: + + x = [*NotIterable()] + y = [*Iterable()] + ", + ) + .unwrap(); + + let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); + let a_file_diagnostics = super::check_types(&db, a_file); + assert_diagnostic_messages( + &a_file_diagnostics, + &["Object of type 'NotIterable' is not iterable"], + ); + } + + #[test] + fn yield_from_expression_must_be_iterable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class NotIterable: pass + + class Iterator: + def __next__(self) -> int: + return 42 + + class Iterable: + def __iter__(self) -> Iterator: + + def generator_function(): + yield from Iterable() + yield from NotIterable() + ", + ) + .unwrap(); + + let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); + let a_file_diagnostics = super::check_types(&db, a_file); + assert_diagnostic_messages( + &a_file_diagnostics, + &["Object of type 'NotIterable' is not iterable"], + ); + } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 98a038afb0c6d..472a171579d24 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -243,7 +243,7 @@ impl<'db> TypeInference<'db> { /// Similarly, when we encounter a standalone-inferable expression (right-hand side of an /// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we /// don't infer its types more than once. -struct TypeInferenceBuilder<'db> { +pub(super) struct TypeInferenceBuilder<'db> { db: &'db dyn Db, index: &'db SemanticIndex<'db>, region: InferenceRegion<'db>, @@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_body(orelse); } + /// Emit a diagnostic declaring that the object represented by `node` is not iterable + pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) { + self.add_diagnostic( + node, + "not-iterable", + format_args!( + "Object of type '{}' is not iterable", + not_iterable_ty.display(self.db) + ), + ); + } + fn infer_for_statement_definition( &mut self, target: &ast::ExprName, @@ -1042,17 +1054,9 @@ impl<'db> TypeInferenceBuilder<'db> { .types .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); - let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| { - self.add_diagnostic( - iterable.into(), - "not-iterable", - format_args!( - "Object of type '{}' is not iterable", - iterable_ty.display(self.db) - ), - ); - Type::Unknown - }); + let loop_var_value_ty = iterable_ty + .iterate(self.db) + .unwrap_with_diagnostic(iterable.into(), self); self.types .expressions @@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> { ctx: _, } = starred; - self.infer_expression(value); + let iterable_ty = self.infer_expression(value); + iterable_ty + .iterate(self.db) + .unwrap_with_diagnostic(value.as_ref().into(), self); // TODO Type::Unknown @@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> { let ast::ExprYieldFrom { range: _, value } = yield_from; - self.infer_expression(value); + let iterable_ty = self.infer_expression(value); + iterable_ty + .iterate(self.db) + .unwrap_with_diagnostic(value.as_ref().into(), self); - // TODO get type from awaitable + // TODO get type from `ReturnType` of generator Type::Unknown }