Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checker: fix lambda generic param and return validation #22387

Merged
merged 4 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion vlib/v/ast/table.v
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ pub mut:
panic_handler FnPanicHandler = default_table_panic_handler
panic_userdata voidptr = unsafe { nil } // can be used to pass arbitrary data to panic_handler;
panic_npanics int
cur_fn &FnDecl = unsafe { nil } // previously stored in Checker.cur_fn and Gen.cur_fn
cur_fn &FnDecl = unsafe { nil } // previously stored in Checker.cur_fn and Gen.cur_fn
cur_lambda &LambdaExpr = unsafe { nil } // current lambda node
cur_concrete_types []Type // current concrete types, e.g. <int, string>
gostmts int // how many `go` statements there were in the parsed files.
// When table.gostmts > 0, __VTHREADS__ is defined, which can be checked with `$if threads {`
Expand Down
9 changes: 9 additions & 0 deletions vlib/v/checker/checker.v
Original file line number Diff line number Diff line change
Expand Up @@ -2735,6 +2735,13 @@ fn (mut c Checker) unwrap_generic(typ ast.Type) ast.Type {
{
return t_typ
}
if c.inside_lambda && c.table.cur_lambda.call_ctx != unsafe { nil } {
if t_typ := c.table.resolve_generic_to_concrete(typ, c.table.cur_lambda.func.decl.generic_names,
c.table.cur_lambda.call_ctx.concrete_types)
{
return t_typ
}
}
}
}
return typ
Expand Down Expand Up @@ -2971,8 +2978,10 @@ pub fn (mut c Checker) expr(mut node ast.Expr) ast.Type {
}
ast.LambdaExpr {
c.inside_lambda = true
c.table.cur_lambda = unsafe { &node }
defer {
c.inside_lambda = false
c.table.cur_lambda = unsafe { nil }
}
return c.lambda_expr(mut node, c.expected_type)
}
Expand Down
6 changes: 5 additions & 1 deletion vlib/v/checker/fn.v
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,10 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
for i, mut call_arg in node.args {
if call_arg.expr is ast.CallExpr {
node.args[i].typ = c.expr(mut call_arg.expr)
} else if mut call_arg.expr is ast.LambdaExpr {
if node.concrete_types.len > 0 {
call_arg.expr.call_ctx = unsafe { node }
}
}
}
c.check_expected_arg_count(mut node, func) or { return func.return_type }
Expand Down Expand Up @@ -1608,7 +1612,7 @@ fn (mut c Checker) fn_call(mut node ast.CallExpr, mut continue_check &bool) ast.
}
if mut call_arg.expr is ast.LambdaExpr {
// Calling fn is generic and lambda arg also is generic
if node.concrete_types.len > 0
if node.concrete_types.len > 0 && call_arg.expr.func != unsafe { nil }
&& call_arg.expr.func.decl.generic_names.len > 0 {
call_arg.expr.call_ctx = unsafe { node }
if c.table.register_fn_concrete_types(call_arg.expr.func.decl.fkey(),
Expand Down
5 changes: 2 additions & 3 deletions vlib/v/checker/return.v
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,8 @@ fn (mut c Checker) return_stmt(mut node ast.Return) {
} else {
got_type_sym.name
}
// ignore generic casting expr on lambda in this phase
if c.inside_lambda && exp_type.has_flag(.generic)
&& node.exprs[expr_idxs[i]] is ast.CastExpr {
// ignore generic lambda return in this phase
if c.inside_lambda && exp_type.has_flag(.generic) {
continue
}
c.error('cannot use `${got_type_name}` as ${c.error_type_name(exp_type)} in return argument',
Expand Down
7 changes: 7 additions & 0 deletions vlib/v/tests/generics/lamda_param_and_ret_test.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import arrays

fn test_main() {
items := ['item1', 'item2', 'item3']
list := arrays.map_indexed[string, string](items, |i, item| '${i}. ${item}')
assert list == ['0. item1', '1. item2', '2. item3']
}
Loading