Skip to content

Commit

Permalink
reveal_type works
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Sep 17, 2024
1 parent 23c2cc7 commit d327c60
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 38 deletions.
148 changes: 126 additions & 22 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,30 +496,37 @@ impl<'db> Type<'db> {
///
/// Returns `None` if `self` is not a callable type.
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
match self {
Type::Function(function_type) => Some(function_type.return_type(db)),
Type::RevealTypeFunction(function_type) => {
// TODO emit the diagnostic
Some(function_type.return_type(db))
}
Type::Function(function_type) => CallOutcome::callable(function_type.return_type(db)),
Type::RevealTypeFunction(function_type) => CallOutcome::RevealType {
return_ty: function_type.return_type(db),
revealed_ty: arg_types[0],
},

// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => Some(Type::Instance(*class)),
Type::Class(class) => CallOutcome::callable(Type::Instance(class)),

// TODO: handle classes which implement the Callable protocol
Type::Instance(_instance_ty) => Some(Type::Unknown),
// TODO: handle classes which implement the `__call__` protocol
Type::Instance(_instance_ty) => CallOutcome::callable(Type::Unknown),

// `Any` is callable, and its return type is also `Any`.
Type::Any => Some(Type::Any),
Type::Any => CallOutcome::callable(Type::Any),

Type::Unknown => Some(Type::Unknown),
Type::Unknown => CallOutcome::callable(Type::Unknown),

// TODO: union and intersection types, if they reduce to `Callable`
Type::Union(_) => Some(Type::Unknown),
Type::Intersection(_) => Some(Type::Unknown),
Type::Union(union) => CallOutcome::Union {
outcomes: union
.elements(db)
.iter()
.map(|elem| elem.call(db, arg_types))
.collect(),
},

_ => None,
// TODO: intersection types
Type::Intersection(_) => CallOutcome::callable(Type::Unknown),

_ => CallOutcome::not_callable(self),
}
}

Expand All @@ -531,7 +538,7 @@ impl<'db> Type<'db> {
/// for y in x:
/// pass
/// ```
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
fn iterate(self, db: &'db dyn Db) -> IterationOutcome<'db> {
if let Type::Tuple(tuple_type) = self {
return IterationOutcome::Iterable {
element_ty: UnionType::from_elements(db, &**tuple_type.elements(db)),
Expand All @@ -544,18 +551,22 @@ impl<'db> Type<'db> {

let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let Some(iterator_ty) = dunder_iter_method.call(db) else {
let CallOutcome::Callable {
return_ty: iterator_ty,
} = dunder_iter_method.call(db, &[])
else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
};
};

let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method
.call(db)
.call(db, &[])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
});
}

Expand All @@ -568,10 +579,11 @@ impl<'db> Type<'db> {
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");

dunder_get_item_method
.call(db)
.call(db, &[])
.return_ty(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
not_iterable_ty: self,
})
}

Expand Down Expand Up @@ -638,6 +650,98 @@ impl<'db> From<&Type<'db>> for Type<'db> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum CallOutcome<'db> {
Callable {
return_ty: Type<'db>,
},
RevealType {
return_ty: Type<'db>,
revealed_ty: Type<'db>,
},
NotCallable {
not_callable_ty: Type<'db>,
},
Union {
outcomes: Box<[CallOutcome<'db>]>,
},
}

impl<'db> CallOutcome<'db> {
/// Create a new `CallOutcome::Callable` with given return type.
fn callable(return_ty: Type<'db>) -> CallOutcome {
CallOutcome::Callable { return_ty }
}

/// Create a new `CallOutcome::NotCallable` with given not-callable type.
fn not_callable(not_callable_ty: Type<'db>) -> CallOutcome {
CallOutcome::NotCallable { not_callable_ty }
}

/// Get the return type of the call, or `None` if not callable.
fn return_ty(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Self::Callable { return_ty } => Some(*return_ty),
Self::RevealType {
return_ty,
revealed_ty: _,
} => Some(*return_ty),
Self::NotCallable { not_callable_ty: _ } => None,
Self::Union { outcomes } => outcomes
.iter()
.fold(None, |acc, outcome| {
let ty = outcome.return_ty(db);
match (acc, ty) {
(None, None) => None,
(None, Some(ty)) => Some(UnionBuilder::new(db).add(ty)),
(Some(builder), ty) => Some(builder.add(ty.unwrap_or(Type::Unknown))),
}
})
.map(UnionBuilder::build),
}
}

/// Get the return type of the call, emitting diagnostics if needed.
fn unwrap_with_diagnostic<'a>(
&self,
db: &'db dyn Db,
node: ast::AnyNodeRef,
builder: &'a mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Callable { return_ty } => *return_ty,
Self::RevealType {
return_ty,
revealed_ty,
} => {
builder.add_diagnostic(
node,
"revealed-type",
format_args!("Revealed type is '{}'.", revealed_ty.display(db)),
);
*return_ty
}
Self::NotCallable { not_callable_ty } => {
builder.add_diagnostic(
node,
"call-non-callable",
format_args!(
"Object of type '{}' is not callable.",
not_callable_ty.display(db)
),
);
Type::Unknown
}
Self::Union { outcomes } => UnionType::from_elements(
db,
outcomes
.iter()
.map(|outcome| outcome.unwrap_with_diagnostic(db, node, builder)),
),
}
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
Expand Down
36 changes: 20 additions & 16 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2023,19 +2023,11 @@ impl<'db> TypeInferenceBuilder<'db> {
arguments,
} = call_expression;

self.infer_arguments(arguments);
let arg_types = self.infer_arguments(arguments);
let function_type = self.infer_expression(func);
function_type.call(self.db).unwrap_or_else(|| {
self.add_diagnostic(
func.as_ref().into(),
"call-non-callable",
format_args!(
"Object of type '{}' is not callable",
function_type.display(self.db)
),
);
Type::Unknown
})
function_type
.call(self.db, arg_types.as_slice())
.unwrap_with_diagnostic(self.db, func.as_ref().into(), self)
}

fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {
Expand Down Expand Up @@ -2410,7 +2402,19 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Adds a new diagnostic.
///
/// The diagnostic does not get added if the rule isn't enabled for this file.
fn add_diagnostic(&mut self, node: AnyNodeRef, rule: &str, message: std::fmt::Arguments) {
pub(super) fn add_diagnostic(
&mut self,
node: AnyNodeRef,
rule: &str,
message: std::fmt::Arguments,
) {
self.add_diagnostic_string(node, rule, message.to_string());
}

/// Adds a new diagnostic with a string message.
///
/// The diagnostic does not get added if the rule isn't enabled for this file.
pub(super) fn add_diagnostic_string(&mut self, node: AnyNodeRef, rule: &str, message: String) {
if !self.db.is_file_open(self.file) {
return;
}
Expand All @@ -2424,7 +2428,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.types.diagnostics.push(TypeCheckDiagnostic {
file: self.file,
rule: rule.to_string(),
message: message.to_string(),
message,
range: node.range(),
});
}
Expand Down Expand Up @@ -2760,7 +2764,7 @@ mod tests {
",
)?;

assert_file_diagnostics(&db, "/src/a.py", &["Revealed type of 'x' is 'Literal[1]'."]);
assert_file_diagnostics(&db, "/src/a.py", &["Revealed type is 'Literal[1]'."]);

Ok(())
}
Expand Down Expand Up @@ -3368,7 +3372,7 @@ mod tests {
assert_file_diagnostics(
&db,
"/src/a.py",
&["Object of type 'Literal[123]' is not callable"],
&["Object of type 'Literal[123]' is not callable."],
);
}

Expand Down

0 comments on commit d327c60

Please sign in to comment.