From 16a926d138108b22214a3c010c4d739baa56fe4a Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Fri, 31 May 2024 13:52:29 -0600 Subject: [PATCH] [red-knot] infer int literal types (#11623) ## Summary Give red-knot the ability to infer int literal types. This is quick and easy, mostly because these types are a convenient way to observe control-flow handling with simple assignments. ## Test Plan Added test. --- crates/red_knot/src/types.rs | 6 +++++ crates/red_knot/src/types/infer.rs | 40 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index 478a35f1c1658..f8b0201435555 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -36,6 +36,7 @@ pub enum Type { Instance(ClassTypeId), Union(UnionTypeId), Intersection(IntersectionTypeId), + IntLiteral(i64), // TODO protocols, callable types, overloads, generics, type vars } @@ -78,6 +79,10 @@ impl Type { // TODO return the intersection of those results todo!("attribute lookup on Intersection type") } + Type::IntLiteral(_) => { + // TODO raise error + Ok(Some(Type::Unknown)) + } } } } @@ -616,6 +621,7 @@ impl std::fmt::Display for DisplayType<'_> { .get_module(int_id.file_id) .get_intersection(int_id.intersection_id) .display(f, self.store), + Type::IntLiteral(n) => write!(f, "Literal[{n}]"), } } } diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index 0d6d23b8ce779..ff27f25c2dd09 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -145,6 +145,16 @@ fn infer_expr_type(db: &dyn SemanticDb, file_id: FileId, expr: &ast::Expr) -> Qu // TODO cache the resolution of the type on the node let symbols = symbol_table(db, file_id)?; match expr { + ast::Expr::NumberLiteral(ast::ExprNumberLiteral { value, .. }) => { + match value { + ast::Number::Int(n) => { + // TODO support big int literals + Ok(n.as_i64().map(Type::IntLiteral).unwrap_or(Type::Unknown)) + } + // TODO builtins.float or builtins.complex + _ => Ok(Type::Unknown), + } + } ast::Expr::Name(name) => { // TODO look up in the correct scope, don't assume global if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { @@ -348,4 +358,34 @@ mod tests { assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]"); Ok(()) } + + #[test] + fn resolve_literal() -> anyhow::Result<()> { + let case = create_test()?; + let db = &case.db; + + let path = case.src.path().join("a.py"); + std::fs::write(path, "x = 1")?; + let file = resolve_module(db, ModuleName::new("a"))? + .expect("module should be found") + .path(db)? + .file(); + let syms = symbol_table(db, file)?; + let x_sym = syms + .root_symbol_id_by_name("x") + .expect("x symbol should be found"); + + let ty = infer_symbol_type( + db, + GlobalSymbolId { + file_id: file, + symbol_id: x_sym, + }, + )?; + + let jar = HasJar::::jar(db)?; + assert!(matches!(ty, Type::IntLiteral(_))); + assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[1]"); + Ok(()) + } }