Skip to content

Commit

Permalink
[red-knot] support for typing.reveal_type (#13384)
Browse files Browse the repository at this point in the history
Add support for the `typing.reveal_type` function, emitting a diagnostic
revealing the type of its single argument. This is a necessary piece for
the planned testing framework.

This puts the cart slightly in front of the horse, in that we don't yet
have proper support for validating call signatures / argument types. But
it's easy to do just enough to make `reveal_type` work.

This PR includes support for calling union types (this is necessary
because we don't yet support `sys.version_info` checks, so
`typing.reveal_type` itself is a union type), plus some nice
consolidated error messages for calls to unions where some elements are
not callable. This is mostly to demonstrate the flexibility in
diagnostics that we get from the `CallOutcome` enum.
  • Loading branch information
carljm authored Sep 18, 2024
1 parent 44d916f commit c173ec5
Show file tree
Hide file tree
Showing 3 changed files with 384 additions and 51 deletions.
231 changes: 209 additions & 22 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ pub enum Type<'db> {
None,
/// a specific function object
Function(FunctionType<'db>),
/// The `typing.reveal_type` function, which has special `__call__` behavior.
RevealTypeFunction(FunctionType<'db>),
/// a specific module object
Module(File),
/// a specific class object
Expand Down Expand Up @@ -324,14 +326,16 @@ impl<'db> Type<'db> {

pub const fn into_function_type(self) -> Option<FunctionType<'db>> {
match self {
Type::Function(function_type) => Some(function_type),
Type::Function(function_type) | Type::RevealTypeFunction(function_type) => {
Some(function_type)
}
_ => None,
}
}

pub fn expect_function(self) -> FunctionType<'db> {
self.into_function_type()
.expect("Expected a Type::Function variant")
.expect("Expected a variant wrapping a FunctionType")
}

pub const fn into_int_literal_type(self) -> Option<i64> {
Expand Down Expand Up @@ -367,6 +371,16 @@ impl<'db> Type<'db> {
}
}

pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
Type::Function(function) | Type::RevealTypeFunction(function) => {
function.is_stdlib_symbol(db, module_name, name)
}
_ => false,
}
}

/// Return true if this type is [assignable to] type `target`.
///
/// [assignable to]: https://typing.readthedocs.io/en/latest/spec/concepts.html#the-assignable-to-or-consistent-subtyping-relation
Expand Down Expand Up @@ -436,7 +450,7 @@ impl<'db> Type<'db> {
// TODO: attribute lookup on None type
Type::Unknown
}
Type::Function(_) => {
Type::Function(_) | Type::RevealTypeFunction(_) => {
// TODO: attribute lookup on function type
Type::Unknown
}
Expand Down Expand Up @@ -482,26 +496,39 @@ impl<'db> Type<'db> {
///
/// Returns `None` if `self` is not a callable type.
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {
Type::Function(function_type) => Some(function_type.return_type(db)),
// TODO validate typed call arguments vs callable signature
Type::Function(function_type) => CallOutcome::callable(function_type.return_type(db)),
Type::RevealTypeFunction(function_type) => CallOutcome::revealed(
function_type.return_type(db),
*arg_types.first().unwrap_or(&Type::Unknown),
),

// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => Some(Type::Instance(*class)),
Type::Class(class) => CallOutcome::callable(Type::Instance(class)),

// TODO: handle classes which implement the Callable protocol
Type::Instance(_instance_ty) => Some(Type::Unknown),
// TODO: handle classes which implement the `__call__` protocol
Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown),

// `Any` is callable, and its return type is also `Any`.
Type::Any => Some(Type::Any),
Type::Any => CallOutcome::callable(Type::Any),

Type::Unknown => Some(Type::Unknown),
Type::Unknown => CallOutcome::callable(Type::Unknown),

// TODO: union and intersection types, if they reduce to `Callable`
Type::Union(_) => Some(Type::Unknown),
Type::Intersection(_) => Some(Type::Unknown),
Type::Union(union) => CallOutcome::union(
self,
union
.elements(db)
.iter()
.map(|elem| elem.call(db, arg_types))
.collect::<Box<[CallOutcome<'db>]>>(),
),

_ => None,
// TODO: intersection types
Type::Intersection(_) => CallOutcome::callable(Type::Unknown),

_ => CallOutcome::not_callable(self),
}
}

Expand All @@ -513,7 +540,7 @@ impl<'db> Type<'db> {
/// for y in x:
/// pass
/// ```
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
fn iterate(self, db: &'db dyn Db) -> IterationOutcome<'db> {
if let Type::Tuple(tuple_type) = self {
return IterationOutcome::Iterable {
element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)),
Expand All @@ -526,18 +553,22 @@ impl<'db> Type<'db> {

let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let Some(iterator_ty) = dunder_iter_method.call(db) else {
let CallOutcome::Callable {
return_ty: iterator_ty,
} = dunder_iter_method.call(db, &[])
else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
};
};

let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method
.call(db)
.call(db, &[])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
});
}

Expand All @@ -550,10 +581,11 @@ impl<'db> Type<'db> {
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");

dunder_get_item_method
.call(db)
.call(db, &[])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
})
}

Expand All @@ -573,6 +605,7 @@ impl<'db> Type<'db> {
Type::BooleanLiteral(_)
| Type::BytesLiteral(_)
| Type::Function(_)
| Type::RevealTypeFunction(_)
| Type::Instance(_)
| Type::Module(_)
| Type::IntLiteral(_)
Expand All @@ -595,7 +628,7 @@ impl<'db> Type<'db> {
Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"),
Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"),
Type::IntLiteral(_) => builtins_symbol_ty(db, "int"),
Type::Function(_) => types_symbol_ty(db, "FunctionType"),
Type::Function(_) | Type::RevealTypeFunction(_) => types_symbol_ty(db, "FunctionType"),
Type::Module(_) => types_symbol_ty(db, "ModuleType"),
Type::None => typeshed_symbol_ty(db, "NoneType"),
// TODO not accurate if there's a custom metaclass...
Expand All @@ -619,6 +652,152 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum CallOutcome<'db> {
Callable {
return_ty: Type<'db>,
},
RevealType {
return_ty: Type<'db>,
revealed_ty: Type<'db>,
},
NotCallable {
not_callable_ty: Type<'db>,
},
Union {
called_ty: Type<'db>,
outcomes: Box<[CallOutcome<'db>]>,
},
}

impl<'db> CallOutcome<'db> {
/// Create a new `CallOutcome::Callable` with given return type.
fn callable(return_ty: Type<'db>) -> CallOutcome {
CallOutcome::Callable { return_ty }
}

/// Create a new `CallOutcome::NotCallable` with given not-callable type.
fn not_callable(not_callable_ty: Type<'db>) -> CallOutcome {
CallOutcome::NotCallable { not_callable_ty }
}

/// Create a new `CallOutcome::RevealType` with given revealed and return types.
fn revealed(return_ty: Type<'db>, revealed_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::RevealType {
return_ty,
revealed_ty,
}
}

/// Create a new `CallOutcome::Union` with given wrapped outcomes.
fn union(called_ty: Type<'db>, outcomes: impl Into<Box<[CallOutcome<'db>]>>) -> CallOutcome {
CallOutcome::Union {
called_ty,
outcomes: outcomes.into(),
}
}

/// Get the return type of the call, or `None` if not callable.
fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Self::Callable { return_ty } => Some(*return_ty),
Self::RevealType {
return_ty,
revealed_ty: _,
} => Some(*return_ty),
Self::NotCallable { not_callable_ty: _ } => None,
Self::Union {
outcomes,
called_ty: _,
} => outcomes
.iter()
// If all outcomes are NotCallable, we return None; if some outcomes are callable
// and some are not, we return a union including Unknown.
.fold(None, |acc, outcome| {
let ty = outcome.return_ty(db);
match (acc, ty) {
(None, None) => None,
(None, Some(ty)) => Some(UnionBuilder::new(db).add(ty)),
(Some(builder), ty) => Some(builder.add(ty.unwrap_or(Type::Unknown))),
}
})
.map(UnionBuilder::build),
}
}

/// Get the return type of the call, emitting diagnostics if needed.
fn unwrap_with_diagnostic<'a>(
&self,
db: &'db dyn Db,
node: ast::AnyNodeRef,
builder: &'a mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Callable { return_ty } => *return_ty,
Self::RevealType {
return_ty,
revealed_ty,
} => {
builder.add_diagnostic(
node,
"revealed-type",
format_args!("Revealed type is '{}'.", revealed_ty.display(db)),
);
*return_ty
}
Self::NotCallable { not_callable_ty } => {
builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Object of type '{}' is not callable.",
not_callable_ty.display(db)
),
);
Type::Unknown
}
Self::Union {
outcomes,
called_ty,
} => {
let mut not_callable = vec![];
let mut union_builder = UnionBuilder::new(db);
for outcome in &**outcomes {
let return_ty = if let Self::NotCallable { not_callable_ty } = outcome {
not_callable.push(*not_callable_ty);
Type::Unknown
} else {
outcome.unwrap_with_diagnostic(db, node, builder)
};
union_builder = union_builder.add(return_ty);
}
match not_callable[..] {
[] => {}
[elem] => builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Union element '{}' of type '{}' is not callable.",
elem.display(db),
called_ty.display(db)
),
),
_ => builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Union elements {} of type '{}' are not callable.",
not_callable.display(db),
called_ty.display(db)
),
),
}
union_builder.build()
}
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
Expand Down Expand Up @@ -654,6 +833,14 @@ pub struct FunctionType<'db> {
}

impl<'db> FunctionType<'db> {
/// Return true if this is a standard library function with given module name and name.
pub(crate) fn is_stdlib_symbol(self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
name == self.name(db)
&& file_to_module(db, self.definition(db).file(db)).is_some_and(|module| {
module.search_path().is_standard_library() && module.name() == module_name
})
}

pub fn has_decorator(self, db: &dyn Db, decorator: Type<'_>) -> bool {
self.decorators(db).contains(&decorator)
}
Expand Down
7 changes: 5 additions & 2 deletions crates/red_knot_python_semantic/src/types/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ impl Display for DisplayType<'_> {
| Type::BytesLiteral(_)
| Type::Class(_)
| Type::Function(_)
| Type::RevealTypeFunction(_)
) {
write!(f, "Literal[{representation}]",)
} else {
Expand Down Expand Up @@ -72,7 +73,9 @@ impl Display for DisplayRepresentation<'_> {
// TODO functions and classes should display using a fully qualified name
Type::Class(class) => f.write_str(class.name(self.db)),
Type::Instance(class) => f.write_str(class.name(self.db)),
Type::Function(function) => f.write_str(function.name(self.db)),
Type::Function(function) | Type::RevealTypeFunction(function) => {
f.write_str(function.name(self.db))
}
Type::Union(union) => union.display(self.db).fmt(f),
Type::Intersection(intersection) => intersection.display(self.db).fmt(f),
Type::IntLiteral(n) => n.fmt(f),
Expand Down Expand Up @@ -191,7 +194,7 @@ impl TryFrom<Type<'_>> for LiteralTypeKind {
fn try_from(value: Type<'_>) -> Result<Self, Self::Error> {
match value {
Type::Class(_) => Ok(Self::Class),
Type::Function(_) => Ok(Self::Function),
Type::Function(_) | Type::RevealTypeFunction(_) => Ok(Self::Function),
Type::IntLiteral(_) => Ok(Self::IntLiteral),
Type::StringLiteral(_) => Ok(Self::StringLiteral),
Type::BytesLiteral(_) => Ok(Self::BytesLiteral),
Expand Down
Loading

0 comments on commit c173ec5

Please sign in to comment.