Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Improve type inference for except handlers #14838

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,44 @@ def foo(
try:
help()
except x as e:
# TODO: should be `AttributeError`
reveal_type(e) # revealed: @Todo(exception type)
reveal_type(e) # revealed: AttributeError
except y as f:
# TODO: should be `OSError | RuntimeError`
reveal_type(f) # revealed: @Todo(exception type)
reveal_type(f) # revealed: OSError | RuntimeError
except z as g:
# TODO: should be `BaseException`
reveal_type(g) # revealed: @Todo(exception type)
reveal_type(g) # revealed: @Todo(full tuple[...] support)
```

## Invalid exception handlers

```py
try:
pass
# error: [invalid-exception] "Cannot catch object of type `Literal[3]` in an exception handler (must be a `BaseException` subclass or a tuple of `BaseException` subclasses)"
except 3 as e:
reveal_type(e) # revealed: Unknown

try:
pass
# error: [invalid-exception] "Cannot catch object of type `Literal["foo"]` in an exception handler (must be a `BaseException` subclass or a tuple of `BaseException` subclasses)"
# error: [invalid-exception] "Cannot catch object of type `Literal[b"bar"]` in an exception handler (must be a `BaseException` subclass or a tuple of `BaseException` subclasses)"
except (ValueError, OSError, "foo", b"bar") as e:
reveal_type(e) # revealed: ValueError | OSError | Unknown

def foo(
x: type[str],
y: tuple[type[OSError], type[RuntimeError], int],
z: tuple[type[str], ...],
):
try:
help()
# error: [invalid-exception]
except x as e:
reveal_type(e) # revealed: Unknown
# error: [invalid-exception]
except y as f:
reveal_type(f) # revealed: OSError | RuntimeError | Unknown
except z as g:
# TODO: should emit a diagnostic here:
reveal_type(g) # revealed: @Todo(full tuple[...] support)
```
Original file line number Diff line number Diff line change
@@ -1,30 +1,59 @@
# Except star
# `except*`

## Except\* with BaseException
## `except*` with `BaseException`

```py
try:
help()
except* BaseException as e:
# TODO: should be `BaseExceptionGroup[BaseException]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
```

## Except\* with specific exception
## `except*` with specific exception

```py
try:
help()
except* OSError as e:
# TODO(Alex): more precise would be `ExceptionGroup[OSError]`
# TODO: more precise would be `ExceptionGroup[OSError]` --Alex
# (needs homogenous tuples + generics)
reveal_type(e) # revealed: BaseExceptionGroup
```

## Except\* with multiple exceptions
## `except*` with multiple exceptions

```py
try:
help()
except* (TypeError, AttributeError) as e:
# TODO(Alex): more precise would be `ExceptionGroup[TypeError | AttributeError]`.
# TODO: more precise would be `ExceptionGroup[TypeError | AttributeError]` --Alex
# (needs homogenous tuples + generics)
reveal_type(e) # revealed: BaseExceptionGroup
```

## `except*` with mix of `Exception`s and `BaseException`s

```py
try:
help()
except* (KeyboardInterrupt, AttributeError) as e:
# TODO: more precise would be `BaseExceptionGroup[KeyboardInterrupt | AttributeError]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
```

## Invalid `except*` handlers

```py
try:
help()
except* 3 as e: # error: [invalid-exception]
# TODO: Should be `BaseExceptionGroup[Unknown]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup

try:
help()
except* (AttributeError, 42) as e: # error: [invalid-exception]
# TODO: Should be `BaseExceptionGroup[AttributeError | Unknown]` --Alex
reveal_type(e) # revealed: BaseExceptionGroup
```
20 changes: 20 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,8 @@ impl<'db> Type<'db> {
| KnownClass::Set
| KnownClass::Dict
| KnownClass::Slice
| KnownClass::BaseException
| KnownClass::BaseExceptionGroup
| KnownClass::GenericAlias
| KnownClass::ModuleType
| KnownClass::FunctionType
Expand Down Expand Up @@ -1857,6 +1859,8 @@ pub enum KnownClass {
Set,
Dict,
Slice,
BaseException,
BaseExceptionGroup,
// Types
GenericAlias,
ModuleType,
Expand Down Expand Up @@ -1887,6 +1891,8 @@ impl<'db> KnownClass {
Self::List => "list",
Self::Type => "type",
Self::Slice => "slice",
Self::BaseException => "BaseException",
Self::BaseExceptionGroup => "BaseExceptionGroup",
Self::GenericAlias => "GenericAlias",
Self::ModuleType => "ModuleType",
Self::FunctionType => "FunctionType",
Expand Down Expand Up @@ -1914,6 +1920,12 @@ impl<'db> KnownClass {
.unwrap_or(Type::Unknown)
}

pub fn to_subclass_of(self, db: &'db dyn Db) -> Option<Type<'db>> {
self.to_class_literal(db)
.into_class_literal()
.map(|ClassLiteralType { class }| Type::subclass_of(class))
}

/// Return the module in which we should look up the definition for this class
pub(crate) fn canonical_module(self, db: &'db dyn Db) -> CoreStdlibModule {
match self {
Expand All @@ -1928,6 +1940,8 @@ impl<'db> KnownClass {
| Self::Tuple
| Self::Set
| Self::Dict
| Self::BaseException
| Self::BaseExceptionGroup
| Self::Slice => CoreStdlibModule::Builtins,
Self::VersionInfo => CoreStdlibModule::Sys,
Self::GenericAlias | Self::ModuleType | Self::FunctionType => CoreStdlibModule::Types,
Expand Down Expand Up @@ -1971,6 +1985,8 @@ impl<'db> KnownClass {
| Self::ModuleType
| Self::FunctionType
| Self::SpecialForm
| Self::BaseException
| Self::BaseExceptionGroup
| Self::TypeVar => false,
}
}
Expand All @@ -1992,6 +2008,8 @@ impl<'db> KnownClass {
"dict" => Self::Dict,
"list" => Self::List,
"slice" => Self::Slice,
"BaseException" => Self::BaseException,
"BaseExceptionGroup" => Self::BaseExceptionGroup,
"GenericAlias" => Self::GenericAlias,
"NoneType" => Self::NoneType,
"ModuleType" => Self::ModuleType,
Expand Down Expand Up @@ -2028,6 +2046,8 @@ impl<'db> KnownClass {
| Self::GenericAlias
| Self::ModuleType
| Self::VersionInfo
| Self::BaseException
| Self::BaseExceptionGroup
| Self::FunctionType => module.name() == self.canonical_module(db).as_str(),
Self::NoneType => matches!(module.name().as_str(), "_typeshed" | "types"),
Self::SpecialForm | Self::TypeVar | Self::TypeAliasType | Self::NoDefaultType => {
Expand Down
12 changes: 12 additions & 0 deletions crates/red_knot_python_semantic/src/types/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,18 @@ impl<'db> TypeCheckDiagnosticsBuilder<'db> {
);
}

pub(super) fn add_invalid_exception(&mut self, db: &dyn Db, node: &ast::Expr, ty: Type) {
self.add(
node.into(),
"invalid-exception",
format_args!(
"Cannot catch object of type `{}` in an exception handler \
(must be a `BaseException` subclass or a tuple of `BaseException` subclasses)",
ty.display(db)
),
);
}

/// Adds a new diagnostic.
///
/// The diagnostic does not get added if the rule isn't enabled for this file.
Expand Down
76 changes: 46 additions & 30 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1512,40 +1512,56 @@ impl<'db> TypeInferenceBuilder<'db> {
except_handler_definition: &ExceptHandlerDefinitionKind,
definition: Definition<'db>,
) {
let node_ty = except_handler_definition
.handled_exceptions()
.map(|ty| self.infer_expression(ty))
// If there is no handled exception, it's invalid syntax;
// a diagnostic will have already been emitted
.unwrap_or(Type::Unknown);
let node = except_handler_definition.handled_exceptions();

// If there is no handled exception, it's invalid syntax;
// a diagnostic will have already been emitted
let node_ty = node.map_or(Type::Unknown, |ty| self.infer_expression(ty));

// If it's an `except*` handler, this won't actually be the type of the bound symbol;
// it will actually be the type of the generic parameters to `BaseExceptionGroup` or `ExceptionGroup`.
let symbol_ty = if let Type::Tuple(tuple) = node_ty {
let type_base_exception = KnownClass::BaseException
.to_subclass_of(self.db)
.unwrap_or(Type::Unknown);
let mut builder = UnionBuilder::new(self.db);
for element in tuple.elements(self.db).iter().copied() {
builder = builder.add(if element.is_assignable_to(self.db, type_base_exception) {
element.to_instance(self.db)
} else {
if let Some(node) = node {
self.diagnostics
.add_invalid_exception(self.db, node, element);
}
Type::Unknown
});
}
builder.build()
} else if node_ty.is_subtype_of(self.db, KnownClass::Tuple.to_instance(self.db)) {
todo_type!("Homogeneous tuple in exception handler")
} else {
let type_base_exception = KnownClass::BaseException
.to_subclass_of(self.db)
.unwrap_or(Type::Unknown);
if node_ty.is_assignable_to(self.db, type_base_exception) {
node_ty.to_instance(self.db)
} else {
if let Some(node) = node {
self.diagnostics
.add_invalid_exception(self.db, node, node_ty);
}
Type::Unknown
}
};

let symbol_ty = if except_handler_definition.is_star() {
// TODO should be generic --Alex
// TODO: we should infer `ExceptionGroup` if `node_ty` is a subtype of `tuple[type[Exception], ...]`
// (needs support for homogeneous tuples).
//
// TODO should infer `ExceptionGroup` if all caught exceptions
// are subclasses of `Exception` --Alex
builtins_symbol(self.db, "BaseExceptionGroup")
.ignore_possibly_unbound()
.unwrap_or(Type::Unknown)
.to_instance(self.db)
// TODO: should be generic with `symbol_ty` as the generic parameter
KnownClass::BaseExceptionGroup.to_instance(self.db)
} else {
// TODO: anything that's a consistent subtype of
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
// anything else is invalid and should lead to a diagnostic being reported --Alex
match node_ty {
Type::Any | Type::Unknown => node_ty,
Type::ClassLiteral(ClassLiteralType { class }) => Type::instance(class),
Type::Tuple(tuple) => UnionType::from_elements(
self.db,
tuple.elements(self.db).iter().map(|ty| {
ty.into_class_literal().map_or(
todo_type!("exception type"),
|ClassLiteralType { class }| Type::instance(class),
)
}),
),
_ => todo_type!("exception type"),
}
symbol_ty
};

self.add_binding(
Expand Down
Loading